캐릭터(agent)가 힐링포션만을 먹도록 학습시키는 예제.
1. 환경 배치

2. 환경 세팅
1) 이전의 포션이 있다면 모두 삭제
2) 캐릭터를 (0, 0, 0)으로 초기화
3) 새로운 포션들을 랜덤으로 배치(치유 포션 + 독 포션 3세트로 설정, 이전에 배운 Apple Catch 예제의 랜덤 생성을 이용)
리스트를 이용하여 포션을 저장, 먹었다면 삭제, 초기화하면 모두 삭제
PotionArea
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;
/// <summary>
/// Training area that owns the potions for one agent: it resets the agent to the
/// area origin, spawns healing/poison potion pairs at random positions, and tracks
/// the potions that are still present.
/// </summary>
public class PotionArea : MonoBehaviour
{
    // Number of (healing, poison) pairs spawned on every reset.
    private const int PotionPairsPerReset = 3;

    // Potions are offset from the area position by a random amount in
    // [-SpawnHalfExtent, SpawnHalfExtent] on both the X and Z axes.
    private const float SpawnHalfExtent = 4.5f;

    public CharacterAgent characterAgent;
    public GameObject healingPotionPrefab;
    public GameObject poisonPotionPrefab;
    public TMP_Text cumulativeRewardText;

    // Created eagerly so RemovePotion / PotionRemaining are safe to use even
    // before the first ResetArea call.
    private readonly List<GameObject> potionList = new List<GameObject>();

    /// <summary>Number of potions (healing and poison) still tracked by this area.</summary>
    public int PotionRemaining
    {
        get { return this.potionList.Count; }
    }

    private void Start()
    {
        this.ResetArea();
    }

    private void Update()
    {
        // Left mouse click resets the area manually.
        if (Input.GetMouseButtonDown(0))
        {
            this.ResetArea();
        }

        // The reward label is display-only; skip it when it has not been assigned
        // instead of throwing every frame.
        if (this.cumulativeRewardText != null)
        {
            this.cumulativeRewardText.text = this.characterAgent.GetCumulativeReward().ToString("0.00");
        }
    }

    /// <summary>
    /// Destroys every tracked potion, puts the agent back at the area origin and
    /// spawns a fresh set of potion pairs.
    /// </summary>
    public void ResetArea()
    {
        this.RemoveAllPotions();
        this.PlaceCharacter();
        this.SpawnPotion(PotionPairsPerReset);
    }

    /// <summary>Stops tracking the given potion and destroys its GameObject.</summary>
    /// <param name="potion">The potion object to remove.</param>
    public void RemovePotion(GameObject potion)
    {
        this.potionList.Remove(potion);
        Destroy(potion);
    }

    // Zeroes the agent's rigidbody velocities and resets its local position and rotation.
    private void PlaceCharacter()
    {
        Rigidbody body = this.characterAgent.GetComponent<Rigidbody>();
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;
        this.characterAgent.transform.localPosition = Vector3.zero;
        this.characterAgent.transform.localRotation = Quaternion.identity;
    }

    // Destroys every potion that still exists and empties the tracking list.
    private void RemoveAllPotions()
    {
        foreach (GameObject potion in this.potionList)
        {
            // Keep the == / != operators here: they are UnityEngine.Object-aware
            // and treat already-destroyed objects as null.
            if (potion != null)
            {
                Destroy(potion);
            }
        }

        this.potionList.Clear();
    }

    // Spawns `count` pairs; each pair is one healing and one poison potion.
    private void SpawnPotion(int count)
    {
        for (int i = 0; i < count; i++)
        {
            this.potionList.Add(this.SpawnSinglePotion(this.healingPotionPrefab));
            this.potionList.Add(this.SpawnSinglePotion(this.poisonPotionPrefab));
        }
    }

    // Instantiates one prefab at a random X/Z offset from this area's position
    // and parents it to the area.
    private GameObject SpawnSinglePotion(GameObject prefab)
    {
        float offsetX = Random.Range(-SpawnHalfExtent, SpawnHalfExtent);
        float offsetZ = Random.Range(-SpawnHalfExtent, SpawnHalfExtent);

        GameObject potion = Instantiate(prefab);
        // The object has no parent yet, so this is a world-space position;
        // SetParent then keeps that world position.
        potion.transform.position = this.transform.position + new Vector3(offsetX, 0f, offsetZ);
        potion.transform.SetParent(this.transform);
        return potion;
    }
}
CharacterAgent
healing potion을 먹으면 +2점, poison potion을 먹으면 -1점을 부여한다. 또한 움직임 여부와 관계없이 매 스텝마다 -1/MaxStep점을 부여하므로(MaxStep 동안 누적되면 총 -1점), 에이전트가 빠르게 행동하도록 유도한다.
Heuristic에서 이동하는 방법을 굳이 바꿔야 할 필요성을 느끼지 못해 Penguin 예제 그대로 사용하였다.
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
/// <summary>
/// ML-Agents agent that moves around a <see cref="PotionArea"/> and is rewarded
/// for touching healing potions and penalised for touching poison potions.
/// Uses two discrete action branches: branch 0 = forward (0 none, 1 forward),
/// branch 1 = turn (0 none, 1 left, 2 right).
/// </summary>
public class CharacterAgent : Agent
{
    private const float HealingReward = 2f;
    private const float PoisonReward = -1f;

    private Rigidbody rbody;
    private PotionArea potionArea;
    private float moveSpeed = 5f;
    private float turnSpeed = 180f;

    /// <summary>Caches the agent's Rigidbody and the PotionArea found on a parent object.</summary>
    public override void Initialize()
    {
        base.Initialize();
        this.rbody = this.GetComponent<Rigidbody>();
        this.potionArea = this.GetComponentInParent<PotionArea>();
    }

    /// <summary>Resets the potion area (agent pose and potions) at the start of every episode.</summary>
    public override void OnEpisodeBegin()
    {
        this.potionArea.ResetArea();
    }

    /// <summary>Adds the agent's forward direction vector as the only vector observation.</summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.transform.forward);
    }

    /// <summary>
    /// Applies the chosen forward/turn actions and the per-step time penalty.
    /// </summary>
    /// <param name="actionBuffers">Actions chosen for this step; only the discrete branches are read.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        ActionSegment<int> discreteActions = actionBuffers.DiscreteActions;

        // Branch 0 is used directly as the forward multiplier (0 or 1).
        float forwardAmount = discreteActions[0];

        // Branch 1: 1 = turn left, 2 = turn right, anything else = no turn.
        // Compared as integers; the action values are ints, not floats.
        float turnAmount = 0f;
        switch (discreteActions[1])
        {
            case 1:
                turnAmount = -1f;
                break;
            case 2:
                turnAmount = 1f;
                break;
        }

        // Apply movement
        this.rbody.MovePosition(
            this.transform.position + this.transform.forward * forwardAmount * this.moveSpeed * Time.fixedDeltaTime);

        // Rotate about the agent's own Y axis. The previous form passed
        // transform.up (a world-space direction) as Euler angles, which only
        // yields a pure yaw while the agent is perfectly upright.
        this.transform.Rotate(Vector3.up, turnAmount * this.turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        if (this.MaxStep > 0)
        {
            this.AddReward(-1f / this.MaxStep);
        }
    }

    /// <summary>
    /// Keyboard control for manual testing: W = forward, A = turn left, D = turn right.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = 0;
        int turnAction = 0;

        if (Input.GetKey(KeyCode.W))
        {
            // move forward
            forwardAction = 1;
        }

        if (Input.GetKey(KeyCode.A))
        {
            // turn left
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            // turn right
            turnAction = 2;
        }

        // Write through the segment's indexer rather than its backing Array,
        // so the segment's own offset handling is respected.
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = forwardAction;
        discreteActions[1] = turnAction;
    }

    // Dispatches on the tag of the trigger collider that was entered.
    private void OnTriggerEnter(Collider other)
    {
        // An object carries a single tag, so the two cases are mutually exclusive.
        if (other.transform.CompareTag("heal"))
        {
            Debug.Log("healing potion");
            this.GetHealingPotion(other.gameObject);
        }
        else if (other.transform.CompareTag("poison"))
        {
            Debug.Log("poison potion");
            this.GetPoisonPotion(other.gameObject);
        }
    }

    private void GetHealingPotion(GameObject potion)
    {
        this.ConsumePotion(potion, HealingReward);
    }

    private void GetPoisonPotion(GameObject potion)
    {
        this.ConsumePotion(potion, PoisonReward);
    }

    // Applies the reward, removes the potion from the area, and ends the episode
    // once the area reports no potions left (healing and poison both count).
    private void ConsumePotion(GameObject potion, float reward)
    {
        this.AddReward(reward);
        this.potionArea.RemovePotion(potion);

        if (this.potionArea.PotionRemaining <= 0)
        {
            this.EndEpisode();
        }
    }
}
Potion.yaml
# ML-Agents trainer configuration for the "Character" behavior (PPO).
# NOTE(review): the behavior name presumably has to match the Behavior Name set on
# the agent's Behavior Parameters component in Unity; confirm in the scene.
behaviors:
  Character:
    trainer_type: ppo
    hyperparameters:
      batch_size: 128
      buffer_size: 2048
      learning_rate: 0.0003
      beta: 0.01
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 256
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    keep_checkpoints: 5
    max_steps: 1000000
    time_horizon: 128
    summary_freq: 5000
    threaded: true

`mlagents-learn` 실행 시 `--force` 옵션을 주면, 같은 run-id로 저장된 이전 학습 결과가 있어도 강제로 덮어쓴다.

+영상 녹화 추가하기