본문 바로가기
유니티/mlAgents

유니티 머신러닝 개발 ML Agents 12편, 목표 찾기 예제 개선. 6 타겟 출동검증방슥을 교체, 거리비교에서 트리거 사용하기

by NGVI 2021. 4. 12.

유니티 머신러닝 개발 ML Agents 12편, 목표 찾기 예제 개선. 6 타겟 출동검증방슥을 교체, 거리비교에서 트리거 사용하기

저번 작업에서는 관측시스템을 바꾸었죠.

수치로 포지션과 이동량등을 제공하다, mlagent가 제공하는 타겟 센서로 교체하였고,

교육이후 보여주는 퍼포먼스가 이전 수치로만 교육시키던 상황보다 눈으로 봐도 상당히 좋아진것을 볼수 있었습니다.

 

gRollerAgent.cs의 

public override void OnActionReceived(ActionBuffers actionBuffers)

함수를 보면

public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Rewards
        float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition);

        //타겟을 찻을시 리워드점수를 주고, 에피소드를 종료시킨다.
        // Reached target
        if (distanceToTarget < 1.42f)
        {
            SetReward(1.0f);
            EndEpisode();
        }

        MoveAgent(actionBuffers.DiscreteActions);
    }

타겟과 나의 거리를 직접 측정하여 충돌을 검사합니다.

 

해당 방식을 일단 유니티가 지원하는 트리거 시스템으로 교체합니다.

 

타켓 충돌 체크방식 변경, 트리거사용하기

gTarget.cs 를 새로 생성합니다.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class gTarget : MonoBehaviour
{
    private void OnTriggerEnter(Collider other)
    {
        if (other.tag == "agent")
        {
            gRollerAgent ra = other.GetComponent<gRollerAgent>();
            if (null != ra)
            {
                ra.EnteredTarget();
                this.gameObject.SetActive(false);
            }
        }
    }
}

위의 내용으로 코딩합니다.

간단하게 트리거에 들어오면 에이전트인지 검사해서 에이전트에 알리는 코드입니다.

에이전트쪽도 수정해야 됩니다.

 

gRollerAgent.cs 쪽 코드도 수정합니다.

전체소스를 공개하고 아래 추가소스만 따로 설명하겠습니다.

using System.Collections;
using System.Collections.Generic;
using UnityEngine;

using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
//mlAgent 사용시 포함해야 됨

public class gRollerAgent : Agent
{
    Rigidbody rBody;
    void Start()
    {
        rBody = GetComponent<Rigidbody>();
    }

    public Transform Target;

    public GameObject viewModel = null;

    int Pointlast = 0;
    int Point = 0;


    public override void OnEpisodeBegin()
    {
        //새로운 애피소드 시작시, 다시 에이전트의 포지션의 초기화
        // If the Agent fell, zero its momentum
        if (this.transform.localPosition.y < 0) //에이전트가 floor 아래로 떨어진 경우 추가 초기화
        {
            this.rBody.angularVelocity = Vector3.zero;
            this.rBody.velocity = Vector3.zero;
            this.transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        //타겟의 위치는 에피소드 시작시 랜덤하게 변경된다.
        // Move the target to a new spot
        float rx = 0;
        float rz = 0;

        rx = Random.value * 16 - 8;
        rz = Random.value * 16 - 8;

        Target.localPosition = new Vector3(rx,
                                           0.5f,
                                           rz);

        Target.gameObject.SetActive(true);

        Pointlast = 0;
        Point = 0;
    }

    /// <summary>
    /// 강화학습을 위한, 강화학습을 통한 행동이 결정되는 곳
    /// </summary>
    public float forceMultiplier = 10;
    float m_ForwardSpeed = 1.0f;

    public void EnteredTarget()
    {
        Point++;
    }

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        if (Pointlast < Point)
        {
            SetReward(1.0f);
            EndEpisode();
        }

        MoveAgent(actionBuffers.DiscreteActions);
    }
    public void MoveAgent(ActionSegment<int> act)
    {
        var dirToGo = Vector3.zero;
        var rotateDir = Vector3.zero;

        var forwardAxis = act[0];
        var rotateAxis = act[1];

        switch (forwardAxis)
        {
            case 1:
                dirToGo = transform.forward * m_ForwardSpeed;
                break;
        }

        switch (rotateAxis)
        {
            case 1:
                rotateDir = transform.up * -1f;
                break;
            case 2:
                rotateDir = transform.up * 1f;
                break;
        }

        transform.Rotate(rotateDir, Time.deltaTime * 100f);
        rBody.AddForce(dirToGo * forceMultiplier, ForceMode.VelocityChange);
    }

    /// <summary>
    /// 해당 함수는 직접조작 혹은 규칙성있는 코딩으로 조작시키기 위한 함수
    /// </summary>
    /// <param name="actionsOut"></param>

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut.Clear();
        //forward
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }

        //rotate
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[1] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[1] = 2;
        }
    }
}

전체소스 아래는 그중 변경부분만,

 

변경부분 설명

//변경 추가부분만 설명
int Pointlast = 0;
int Point = 0;
//타겟이 트리거를 알려주는 행위를 포인트로 계산함
    
public override void OnEpisodeBegin() //내부에 추가된 기능
    
Target.gameObject.SetActive(true);
Pointlast = 0;
Point = 0;
//점수 초기화와, 타겟 재 활성화
    
public void EnteredTarget()
{
    Point++;
}
//타겟 트리거에 에이전트가 들어오면 알려주는 함수
    
public override void OnActionReceived(ActionBuffers actionBuffers)
{
        if (Pointlast < Point)
        {
            SetReward(1.0f);
            EndEpisode();
        }

        MoveAgent(actionBuffers.DiscreteActions);
}
//기존 거리체크 로직이 트리거 시스템으로 교체됨

 

이제 다시 유니티에디터로 돌아옵니다.

 

에이전트 오브젝트 태그 달기

RollerAgent 에이전트의 태그를 agent로 설정해줍니다. 없으면 만드시면 됩니다.(태그를)

 

타겟에 우리가 생성한 GTarget.cs를 달아줍니다.

실행해봅니다.

 

현재까지 내용은 거리체크로직을 트리거로 교체한 내용입니다.

 

구동의 거의 똑같이 이루어집니다.

 

아니 그냥 하면 되지 왜 이렇게 번거롭게 바꾸느냐고 물으실수 있습니다.

 

일단 다켓을 여러개를 관리해 보려 하는데, 그것을 하기 위한 기반작업이라고 생각해주시면 됩니다.

 

봐주셔서 감사합니다.

댓글