Unity/AI

[Unity 인공지능 / ML Agent] Get Potion (추가중)

왹져박사 2023. 4. 11. 15:48

캐릭터(agent)가 독 포션은 피하고 힐링 포션만 먹도록 학습시키는 예제.

 

1. 환경 배치

 

2. 환경 세팅

1) 이전의 포션이 있다면 모두 삭제

2) 캐릭터를 (0, 0, 0)으로 초기화

3) 새로운 포션들을 랜덤 위치에 배치 (힐링 포션 1개 + 독 포션 1개를 한 세트로 하여 3세트 생성, 이전에 배운 Apple Catch 예제의 랜덤 생성 방식을 이용)

생성한 포션은 리스트에 저장하고, 캐릭터가 먹은 포션은 리스트에서 제거(삭제)하며, 환경을 초기화할 때는 남은 포션을 모두 삭제한다.

 

PotionArea

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;

/// <summary>
/// One training area: places the <see cref="CharacterAgent"/> at the area origin and
/// spawns / tracks the healing and poison potions it can pick up.
/// </summary>
public class PotionArea : MonoBehaviour
{
    public CharacterAgent characterAgent;

    public GameObject healingPotionPrefab;
    public GameObject poisonPotionPrefab;

    // Optional label showing the agent's cumulative reward. Duplicated training
    // areas often leave this unassigned, so Update() tolerates null.
    public TMP_Text cumulativeRewardText;

    // Number of (healing + poison) potion pairs spawned on every reset.
    public int potionSetCount = 3;

    // Each potion is spawned at a random X/Z offset in [-spawnRange, spawnRange]
    // from this area's world position.
    public float spawnRange = 4.5f;

    // Potions currently alive in this area. Created eagerly so PotionRemaining and
    // RemovePotion are safe even before the first ResetArea() call.
    private List<GameObject> potionList = new List<GameObject>();

    private void Start()
    {
        this.ResetArea();
    }

    private void Update()
    {
        // Debug helper: left-click resets the area manually.
        if (Input.GetMouseButtonDown(0))
        {
            this.ResetArea();
        }

        // Explicit != null (not ?.) because Unity overloads == for destroyed objects.
        if (this.cumulativeRewardText != null)
        {
            this.cumulativeRewardText.text = this.characterAgent.GetCumulativeReward().ToString("0.00");
        }
    }

    /// <summary>
    /// Destroys all existing potions, moves the agent back to the area origin and
    /// spawns <see cref="potionSetCount"/> fresh healing/poison pairs.
    /// </summary>
    public void ResetArea()
    {
        this.RemoveAllPotions();
        this.PlaceCharacter();
        this.SpawnPotion(this.potionSetCount);
    }

    // Stops the agent's rigidbody and resets its local position and rotation to zero.
    private void PlaceCharacter()
    {
        Rigidbody rigidbody = this.characterAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        this.characterAgent.transform.localPosition = Vector3.zero;
        this.characterAgent.transform.localRotation = Quaternion.identity;
    }

    // Destroys every tracked potion that still exists and starts a fresh, empty list.
    private void RemoveAllPotions()
    {
        if (this.potionList != null)
        {
            // Destroy() is deferred by Unity, so the list is not modified while iterating.
            foreach (GameObject potion in this.potionList)
            {
                if (potion != null)
                {
                    Destroy(potion);
                }
            }
        }
        this.potionList = new List<GameObject>();
    }

    /// <summary>
    /// Spawns <paramref name="count"/> sets of potions. Each set is one healing potion
    /// and one poison potion at independent random positions around the area.
    /// </summary>
    /// <param name="count">Number of healing/poison pairs to spawn.</param>
    private void SpawnPotion(int count)
    {
        for (int i = 0; i < count; i++)
        {
            this.SpawnAtRandomPosition(this.healingPotionPrefab);
            this.SpawnAtRandomPosition(this.poisonPotionPrefab);
        }
    }

    // Instantiates one prefab at a random X/Z offset from the area, parents it to the
    // area (keeping the prefab's own rotation) and starts tracking it.
    private GameObject SpawnAtRandomPosition(GameObject prefab)
    {
        float x = Random.Range(-this.spawnRange, this.spawnRange);
        float z = Random.Range(-this.spawnRange, this.spawnRange);
        Vector3 worldPosition = this.transform.position + new Vector3(x, 0, z);

        GameObject potion = Instantiate(prefab, worldPosition, prefab.transform.rotation, this.transform);
        this.potionList.Add(potion);
        return potion;
    }

    /// <summary>
    /// Stops tracking <paramref name="potion"/> and destroys it (e.g. after the agent ate it).
    /// </summary>
    public void RemovePotion(GameObject potion)
    {
        this.potionList.Remove(potion);
        Destroy(potion);
    }

    /// <summary>Number of potions (healing and poison) still tracked in this area.</summary>
    public int PotionRemaining
    {
        get { return this.potionList.Count; }
    }
}

 

CharacterAgent

healing potion을 먹으면 +2점, poison potion을 먹으면 -1점을 준다. 또한 빨리 행동하도록 유도하기 위해 움직임 여부와 관계없이 매 스텝마다 -1/MaxStep점의 페널티를 주므로, 에피소드가 MaxStep까지 진행되면 총 -1점이 누적된다(MaxStep이 0보다 클 때만 적용).

Heuristic의 이동 방식은 굳이 바꿀 필요가 없다고 판단해 Penguin 예제의 코드를 그대로 사용하였다.

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

/// <summary>
/// ML-Agents agent that walks around its parent <see cref="PotionArea"/>; it is rewarded
/// for touching objects tagged "heal" and penalized for objects tagged "poison".
/// </summary>
public class CharacterAgent : Agent
{
    // Reward for touching an object tagged "heal".
    private const float HealingReward = 2f;
    // Reward (penalty) for touching an object tagged "poison".
    private const float PoisonReward = -1f;

    private Rigidbody rbody;
    private PotionArea potionArea;

    private float moveSpeed = 5f;
    private float turnSpeed = 180f;

    public override void Initialize()
    {
        base.Initialize();

        this.rbody = this.GetComponent<Rigidbody>();
        // Assumes the agent lives somewhere under a PotionArea in the hierarchy.
        this.potionArea = GetComponentInParent<PotionArea>();
    }

    public override void OnEpisodeBegin()
    {
        this.potionArea.ResetArea();
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Only the facing direction (3 floats) is added here.
        // NOTE(review): potions are presumably perceived via a separate sensor component
        // configured in the inspector (e.g. ray perception) — confirm.
        sensor.AddObservation(transform.forward);
    }

    /// <summary>
    /// Discrete branch 0: forward speed multiplier (Heuristic emits 0 = stay, 1 = forward).
    /// Discrete branch 1: 0 = no turn, 1 = turn left, 2 = turn right.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var discreteActions = actionBuffers.DiscreteActions;

        float forwardAmount = discreteActions[0];

        float turnAmount = 0f;
        if (discreteActions[1] == 1)
        {
            turnAmount = -1f;
        }
        else if (discreteActions[1] == 2)
        {
            turnAmount = 1f;
        }

        // Apply movement
        this.rbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);

        // Rotate() defaults to Space.Self, so the axis must be expressed in local space.
        // Vector3.up is the agent's own up axis; the world-space transform.up only coincided
        // with it while the agent stayed perfectly upright.
        transform.Rotate(Vector3.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Tiny per-step penalty (totalling -1 over MaxStep steps) to encourage action.
        if (MaxStep > 0) AddReward(-1f / MaxStep);
    }

    private void OnTriggerEnter(Collider other)
    {
        if (other.CompareTag("heal"))
        {
            Debug.Log("healing potion");
            this.ConsumePotion(other.gameObject, HealingReward);
        }
        else if (other.CompareTag("poison"))
        {
            Debug.Log("poison potion");
            this.ConsumePotion(other.gameObject, PoisonReward);
        }
    }

    // Grants the reward, removes the potion from the area, and ends the episode once no
    // potions (healing or poison) remain.
    private void ConsumePotion(GameObject potion, float reward)
    {
        AddReward(reward);
        this.potionArea.RemovePotion(potion);
        if (this.potionArea.PotionRemaining <= 0) EndEpisode();
    }

    /// <summary>
    /// Manual control: W = move forward, A = turn left, D = turn right.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = Input.GetKey(KeyCode.W) ? 1 : 0;

        int turnAction = 0;
        if (Input.GetKey(KeyCode.A))
        {
            turnAction = 1; // turn left
        }
        else if (Input.GetKey(KeyCode.D))
        {
            turnAction = 2; // turn right
        }

        // Write through the ActionSegment indexer, which honours the segment's Offset;
        // indexing .Array directly addresses the raw backing array instead.
        var discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = forwardAction;
        discreteActions[1] = turnAction;
    }
}

 

Potion.yaml

# ML-Agents trainer configuration (passed to mlagents-learn) for the Get Potion example.
behaviors:
    # NOTE(review): this key presumably has to match the Behavior Name set in the
    # agent's Behavior Parameters component — confirm in the Unity inspector.
    Character:
        trainer_type: ppo
        hyperparameters:
            batch_size: 128
            buffer_size: 2048
            learning_rate: 0.0003
            beta: 0.01
            epsilon: 0.2
            lambd: 0.95
            num_epoch: 3
            learning_rate_schedule: linear
        network_settings:
            normalize: false
            hidden_units: 256
            num_layers: 2
            vis_encode_type: simple
        reward_signals:
            # Only the extrinsic signal is configured — presumably the AddReward()
            # rewards given by CharacterAgent; no curiosity/intrinsic signals.
            extrinsic:
                gamma: 0.99
                strength: 1.0
        keep_checkpoints: 5
        max_steps: 1000000
        time_horizon: 128
        summary_freq: 5000
        threaded: true

mlagents-learn 실행 시 --force 옵션을 주면, 같은 run-id로 저장된 이전 학습 결과가 있어도 강제로 덮어쓴다.

+영상 녹화 추가하기