본문 바로가기
Unity/AI

[Unity 인공지능 / ML Agent] Get Potion (추가중)

by 왹져박사 2023. 4. 11.
728x90

캐릭터(agent)가 힐링포션만을 먹도록 학습시키는 예제.

 

1. 환경 배치

 

2. 환경 세팅

1) 이전의 포션이 있다면 모두 삭제

2) 캐릭터를 (0, 0, 0)으로 초기화

3) 새로운 포션들을 랜덤으로 배치(치유 포션 + 독 포션 3세트로 설정, 이전에 배운 Apple Catch 예제의 랜덤 생성을 이용)

리스트를 이용하여 포션을 저장, 먹었다면 삭제, 초기화하면 모두 삭제

 

PotionArea

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using TMPro;

/// <summary>
/// Training area that owns the potions for one <see cref="CharacterAgent"/>.
/// On every reset it destroys the potions it spawned earlier, puts the character
/// back at the local origin and spawns a fresh set of healing/poison potion pairs
/// at random positions around this object's position.
/// </summary>
public class PotionArea : MonoBehaviour
{
    // Potions are spawned within +/- this distance (X and Z) of the area's position.
    private const float SpawnExtent = 4.5f;

    // Number of (healing, poison) pairs spawned by each ResetArea() call.
    private const int PotionPairsPerReset = 3;

    public CharacterAgent characterAgent;

    public GameObject healingPotionPrefab;
    public GameObject poisonPotionPrefab;
    public TMP_Text cumulativeRewardText;

    // Created eagerly so that PotionRemaining and RemovePotion are safe to call
    // even before the first ResetArea() has run.
    private readonly List<GameObject> potionList = new List<GameObject>();

    private void Start()
    {
        this.ResetArea();
    }

    private void Update()
    {
        // Left mouse button: manual reset of the area.
        if (Input.GetMouseButtonDown(0))
        {
            this.ResetArea();
        }

        // The reward label is only a display; skip it when no label is assigned
        // instead of throwing every frame.
        if (this.cumulativeRewardText != null)
        {
            this.cumulativeRewardText.text = this.characterAgent.GetCumulativeReward().ToString("0.00");
        }
    }

    /// <summary>
    /// Removes all potions spawned by this area, re-centres the character and
    /// spawns a new set of potion pairs.
    /// </summary>
    public void ResetArea()
    {
        this.RemoveAllPotions();
        this.PlaceCharacter();
        this.SpawnPotion(PotionPairsPerReset);
    }

    /// <summary>
    /// Zeroes the character's rigidbody velocities and moves it back to local
    /// position (0, 0, 0) with no rotation.
    /// </summary>
    private void PlaceCharacter()
    {
        Rigidbody body = this.characterAgent.GetComponent<Rigidbody>();
        body.velocity = Vector3.zero;
        body.angularVelocity = Vector3.zero;

        Transform agentTransform = this.characterAgent.transform;
        agentTransform.localPosition = Vector3.zero;
        agentTransform.localRotation = Quaternion.identity;
    }

    /// <summary>
    /// Destroys every potion still tracked by this area and empties the list.
    /// </summary>
    private void RemoveAllPotions()
    {
        foreach (GameObject potion in this.potionList)
        {
            // Entries may already have been destroyed elsewhere; Unity's
            // overloaded != treats destroyed objects as null.
            if (potion != null)
            {
                Destroy(potion);
            }
        }

        this.potionList.Clear();
    }

    /// <summary>
    /// Spawns <paramref name="count"/> pairs of potions; each pair is one healing
    /// potion followed by one poison potion.
    /// </summary>
    /// <param name="count">Number of healing/poison pairs to create.</param>
    private void SpawnPotion(int count)
    {
        for (int i = 0; i < count; i++)
        {
            Vector3 healingOffset = RandomOffset();
            Vector3 poisonOffset = RandomOffset();

            this.SpawnSinglePotion(this.healingPotionPrefab, healingOffset);
            this.SpawnSinglePotion(this.poisonPotionPrefab, poisonOffset);
        }
    }

    // Random offset on the XZ plane, each axis drawn from [-SpawnExtent, SpawnExtent].
    private static Vector3 RandomOffset()
    {
        float x = Random.Range(-SpawnExtent, SpawnExtent);
        float z = Random.Range(-SpawnExtent, SpawnExtent);
        return new Vector3(x, 0, z);
    }

    // Instantiates one potion at the area's position plus the given offset,
    // parents it under this area and starts tracking it.
    private void SpawnSinglePotion(GameObject prefab, Vector3 offset)
    {
        GameObject potion = Instantiate(prefab);

        // The instance has no parent yet, so this is a world-space position;
        // SetParent then keeps that world position.
        potion.transform.position = this.transform.position + offset;
        potion.transform.SetParent(this.transform);

        this.potionList.Add(potion);
    }

    /// <summary>
    /// Stops tracking the given potion and destroys it.
    /// </summary>
    /// <param name="potion">Potion to remove; a null reference is ignored.</param>
    public void RemovePotion(GameObject potion)
    {
        if (potion == null)
        {
            return;
        }

        this.potionList.Remove(potion);
        Destroy(potion);
    }

    /// <summary>
    /// Number of potions (healing and poison together) currently tracked by this area.
    /// </summary>
    public int PotionRemaining
    {
        get { return this.potionList.Count; }
    }
}

 

CharacterAgent

healing potion을 먹으면 +2점, poison potion을 먹으면 -1점, 그리고 매 스텝마다 -1/MaxStep점을 부여한다(움직임 여부와 관계없이 적용되며, 에피소드가 MaxStep까지 진행되면 합계 -1점이 된다).

Heuristic에서 이동하는 방법을 굳이 바꿔야 할 필요성을 느끼지 못해 Penguin 예제 그대로 사용하였다.

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

/// <summary>
/// ML-Agents agent that moves around a <see cref="PotionArea"/> and collects potions.
/// Rewards: +2 for a collider tagged "heal", -1 for a collider tagged "poison",
/// and -1/MaxStep on every received action (only when MaxStep is greater than 0).
/// The episode ends when the area reports that no potions remain.
/// </summary>
public class CharacterAgent : Agent
{
    private const float HealingReward = 2f;
    private const float PoisonReward = -1f;

    private Rigidbody rbody;
    private PotionArea potionArea;

    private float moveSpeed = 5f;
    private float turnSpeed = 180f;

    /// <summary>
    /// Caches the agent's rigidbody and the <see cref="PotionArea"/> found on this
    /// object or one of its parents.
    /// </summary>
    public override void Initialize()
    {
        base.Initialize();

        this.rbody = this.GetComponent<Rigidbody>();
        this.potionArea = this.GetComponentInParent<PotionArea>();
    }

    /// <summary>
    /// Resets the potion area at the start of every episode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        this.potionArea.ResetArea();
    }

    /// <summary>
    /// Adds the agent's forward direction (a Vector3) as the only vector observation.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(this.transform.forward);
    }

    /// <summary>
    /// Applies the received discrete actions.
    /// Branch 0 is used directly as the forward amount; branch 1 selects the turn
    /// (1 = left, 2 = right, anything else = no turn).
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var discreteActions = actionBuffers.DiscreteActions;

        float forwardAmount = discreteActions[0];

        // Discrete actions are integers, so compare them as integers.
        float turnAmount;
        switch (discreteActions[1])
        {
            case 1:
                turnAmount = -1f;
                break;
            case 2:
                turnAmount = 1f;
                break;
            default:
                turnAmount = 0f;
                break;
        }

        // Apply movement
        Vector3 step = this.transform.forward * forwardAmount * this.moveSpeed * Time.fixedDeltaTime;
        this.rbody.MovePosition(this.transform.position + step);

        // Transform.Rotate(Vector3) interprets its argument in the object's own
        // space by default, so the yaw axis is Vector3.up rather than the
        // world-space transform.up. Identical while the agent is upright.
        this.transform.Rotate(Vector3.up * turnAmount * this.turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        if (this.MaxStep > 0)
        {
            this.AddReward(-1f / this.MaxStep);
        }
    }

    /// <summary>
    /// Handles picking up a potion when the agent enters a trigger collider
    /// tagged "heal" or "poison".
    /// </summary>
    private void OnTriggerEnter(Collider other)
    {
        if (other.transform.CompareTag("heal"))
        {
            Debug.Log("healing potion");
            this.ConsumePotion(other.gameObject, HealingReward);
        }
        else if (other.transform.CompareTag("poison"))
        {
            Debug.Log("poison potion");
            this.ConsumePotion(other.gameObject, PoisonReward);
        }
    }

    // Adds the reward, removes the potion from the area and ends the episode
    // once the area has no potions left.
    private void ConsumePotion(GameObject potion, float reward)
    {
        this.AddReward(reward);
        this.potionArea.RemovePotion(potion);

        if (this.potionArea.PotionRemaining <= 0)
        {
            this.EndEpisode();
        }
    }

    /// <summary>
    /// Keyboard control for manual testing: W = move forward, A = turn left,
    /// D = turn right (A takes precedence when both A and D are held).
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = Input.GetKey(KeyCode.W) ? 1 : 0;

        int turnAction = 0;
        if (Input.GetKey(KeyCode.A))
        {
            // turn left
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            // turn right
            turnAction = 2;
        }

        // Write through the segment's indexer rather than its backing Array so
        // that the segment's offset into that array is respected.
        var discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = forwardAction;
        discreteActions[1] = turnAction;
    }
}

 

Potion.yaml

# ML-Agents trainer configuration for the potion-collecting agent.
# NOTE(review): key meanings below follow the ML-Agents training-configuration
# documentation; confirm against the ML-Agents version actually installed.
behaviors:
    # Presumably must match the Behavior Name set on the agent's
    # Behavior Parameters component — verify in the Unity scene.
    Character:
        trainer_type: ppo
        # PPO optimisation settings.
        hyperparameters:
            batch_size: 128
            buffer_size: 2048
            learning_rate: 0.0003
            beta: 0.01
            epsilon: 0.2
            lambd: 0.95
            num_epoch: 3
            learning_rate_schedule: linear
        # Policy network shape.
        network_settings:
            normalize: false
            hidden_units: 256
            num_layers: 2
            # NOTE(review): presumably only relevant when visual observations are used — confirm.
            vis_encode_type: simple
        # Only the environment (extrinsic) reward is configured.
        reward_signals:
            extrinsic:
                gamma: 0.99
                strength: 1.0
        keep_checkpoints: 5
        max_steps: 1000000
        time_horizon: 128
        summary_freq: 5000
        threaded: true

학습을 실행할 때(예: mlagents-learn Potion.yaml --run-id=<id> --force) --force 옵션을 주면, 같은 run-id로 저장된 이전 결과가 있어도 강제로 덮어쓴다.

+영상 녹화 추가하기

728x90