Posts MATLAB, MDP에서 강화학습 실습
Post
Cancel

MATLAB, MDP에서 강화학습 실습

Preview Image

이번 포스팅은 MATLAB 을 이용한 강화학습(Reinforcement Learning) 실습입니다. 제가 예전에 RLCode 팀(Reinforcement Learning Code Team) 의 코드를 천천히 봤는데, 거기서도 이론과 실제 구현하는 두 수준의 간극이 차이가 난다고 얘기했습니다. 저도 마찬가지인데요. 그래서 MATLAB, Python, Keras 를 적절한 때에 사용하여 실습할 예정입니다.

1. Summary

본격적인 실습 전에 이전의 두 포스팅(강화학습 개요, MDP)에 대해서 간단하게 정리하고 넘어가겠습니다.

1.1. Reinforcement Learning

먼저 강화학습을 통해 얻는 최종 산출물은 최적 정책 함수(Optimal Policy) Function입니다. 최적의 정책, 최적의 가치 함수들이 핵심이죠. 그리고 강화학습에서는 Agent환경(Environmet) 간의 상호작용이 존재하고 Agent 는 보상을 통해 어떤 행동(Action)을 할 지 결정하죠. 우리가 자전거를 타고, 연필을 쓰고, 흥분해서 하는 주먹질에도 어떤 사전지식 없이 할 수 있습니다. 다만, 숙달된 것의 차이만 있을 뿐입니다.

자전거와 주먹

1.2. The Problem people face

풀고자하는 문제들을 순차적으로 행동을 정의한다면, MDP(Markov Decision Process) 로서 정의할 수 있습니다. MDP 에는 상태(State), 보상(Reward), 행동(Action), 정책(Policy) 이 포함되어 있죠. 기존에 강화학습 은 한정된 상태(State) 에 대해 어려움이 있었습니다. 하지만 인공지능(Artificial Intelligence)신경망(Neural Network) 을 통해서 방대한 상태 에 대해서 학습이 가능합니다. 어지간한 문제에 대해서 마음껏 다뤄볼 수 있겠죠. 이에 대해 자세히 설명하기 위해서는 다이나믹 프로그래밍 이 필요하고 이는 다음에 설명하겠습니다. 그 전에, Grid World 를 익혀보죠.

2. Train Reinforcement Learning Agent in Basic Grid World

Grid World

David Silver 교수님의 강화학습(Reinforcement Learning) 강의의 Planning by Dynamic programming 에서도 Grid World 를 이용한 예시를 들었고, 전통적인 라이트 게임류(미로찾기, 테트리스, 아타리, 틀린그림찾기 등)에서 직, 간접적으로 Grid World 를 사용하고 있습니다. 그래서 이번 강화학습 실습이 여러분들께도 도움이 됐으면 좋겠습니다. 먼저 이번 포스팅의 소스는 MATLAB 도움말 이며, 이를 한글로 번역했습니다.

이번 실습은 Q-Learning 그리고 SARSA agents 를 학습하여 Grid world 의 환경을 어떻게 해결하는 지 보여줍니다. Q-Learning AgentsSARSA Agents를 알고싶다면, 링크를 클릭하여 주세요. (번역된 내용을 밑에 따로 적어놨습니다.) 또한 아래는 MATLAB 에서 제공하는 Agents 입니다.

AgentTypeAction sauce
Q-Learning Agents (Q)Value-BasedDiscrete
Deep Q-Network Agents(DQN)Value-BasedDiscrete
SARSA AgentsValue-BasedDiscrete
Policy Gradient Agents (PG)Policy-BasedDiscrete or continuous
Actor-Critic Agents (AC)Actor-CriticDiscrete or continuous
Proximal Policy Optimization Agents (PPO)Actor-CriticDiscrete or continuous
Deep Deterministic Policy Gradient Agents (DDPG)Actor-CriticContinuous
Twin-Delayed Deep Deterministic Policy Gradient Agents (TD3)Actor-CriticContinuous
Soft Actor-Critic Agents (SAC)Actor-CriticContinuous

Grid World

Grid World Environment 는 아래 조건들을 따르고 있습니다.


  • 5x5 로 구성된 맵과 동,서,남,북으로 움질일 수 있다.
  • Agent 는 2행 1열에서 출발한다.
  • Agent 가 끝지점(terminal state) 에 도달하면 +10의 보상을 받는다.
  • [2,4]는 순간이동 구간(special jump)이라 [4,4]로 바로 이동할 수 있으며 +5의 보상을 받는다.
  • 검정으로된 구간은 장애물 구간으로 막혀있다. [3,3] [3,4] [3,5], [4,3]
  • 그 외에 모든 행동들은 -1의 패널티가 주어진다.

2.1. Create Grid World Environment

1
env = rlPredefinedEnv("BasicGridWorld");

MATLAB 에는 사전에 정의된 환경들이 있습니다. 그 중에서 BasicGridWorld 를 꺼낼 건데요. 그 외의 환경은 아래 표와 같습니다.


MATLAB Environment

  • 'BasicGridWorld'
  • 'CartPole-Discrete'
  • 'CartPole-Continuous'
  • 'DoubleIntegrator-Discrete'
  • 'DoubleIntegrator-Continuous'
  • 'SimplePendulumWithImage-Discrete'
  • 'SimplePendulumWithImage-Continuous'
  • 'WaterFallGridWorld-Stochastic'
  • 'WaterFallGridWorld-Deterministic'

Simulink Environment

  • 'SimplePendulumModel-Discrete'
  • 'SimplePendulumModel-Continuous'
  • 'CartPoleSimscapeModel-Discrete'
  • 'CartPoleSimscapeModel-Continuous'

1
env.ResetFcn = @() 2;

이번 실습에서는 초기 상태가 [2,1]에서 늘 시작해야합니다. 그래서 Initial agent state 의 값을 리턴해주는 함수를 만들었는데요. 이 함수는 매 에피소드, 시뮬레이션 마다 처음 시작할때 호출됩니다. 완전 처음 생성되자마자 [1,1]로 지정된 상태를 그저 2로 옮겨주는 역할일 뿐입니다.

1
rng(0)

랜덤 값을 추출할 때, seed 를 먼저 설정합니다. rng(0) 인 경우에는 seed 가 0인 상태로 메르센 트위스트 알고리즘을 이용하여 난수생성기를 설정합니다.

2.2. Create Q-Learning agent

환경을 만들었다면, Agent 를 생성할 차례입니다. 그 전에 Q-Learning algorithm 을 간단하게 봅시다.

먼저 Q-Learning algorithm 은 모델을 사용하지 않는 방식(model-free)입니다. Q 라는 의미에서 알 수 있게, Q-Learning agent 는 가치 기반 강화학습 agent 입니다. 벨만 방정식 에서 보는 것처럼, 현재가치 + 미래가치를 계산하여 앞으로의 행동을 결정합니다. 이러한 특성에 의해서 Online, Off-Policy 를 가리지 않죠. (On-Policy : 스스로 시행착오를 겪음, Off-Policy: 다른 policy의 시행착오를 학습함)

Q-Learning agent 는 아래의 observation 과 action space 를 따르며 환경(environment) 속에서 학습합니다.

Observation SpaceAction Space
Continuous or discrete (연속 혹은 이산)Discrete (이산)

Critic 표현식(미래가치를 정해주는)은 아래를 따릅니다.

CriticActor
Q-value function critic Q(S,A), which you create usingrlQValueRepresentationQ agents do not use an actor.

좀 더 자세한 설명은 다음 포스팅에 쓸 듯합니다. 혹시 그 전에, 한글로 번역될 수 있으니 미리 링크를 남겨놓겠습니다.

1
2
3
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;

Q-Learning agent 를 만들때 첫번째는 Q table 을 가져오는 것입니다. Q table 에는 Grid World 환경에 대한 observation, action spec 이 들어가 있고, Q representationQ table 을 넣은 다음 Learning Rate(학습률) 은 1로 설정합니다.

1
2
3
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
qAgent = rlQAgent(qRepresentation,agentOpts);

그 다음엔 Epsilon Greedy explorationLearning option 으로 설정합니다.

2.3. Train Q-Learning agent

1
2
3
4
5
6
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;

Train option 을 설정해줍니다.

  • 최대 에피소드는 200개에 한 에피소드에는 50번의 스텝으로 제한합니다.
  • 종료조건은 연속으로 30회 이상 평균 누적 보상(10점 넘어서)을 받으면 안정적이라 판단되고 종료합니다.
1
2
3
4
5
6
7
8
9
doTraining = true;

if doTraining
    % Train the agent.
    trainingStats = train(qAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('basicGWQAgent.mat','qAgent')
end

본격적으로 학습하는 코드입니다. 몇 분정도 소요된다는데, PC 사양마다 다름을 알고 계시고 시간을 절약하기 위해 학습이 필요없는 경우엔 doTrainingfalse 로 설정하시면 됩니다.

Grid World

2.4. Validate Q-Learning agent

1
2
3
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;

결론적으로 학습이 전부가 아니라 검증을 해야합니다. 그러기 위해서는 Train 된 환경에서 agent 를 시뮬레이션합니다.

시뮬레이션을 실행하기 전에 시각화를 구성합니다.

1
sim(qAgent,env)

MATLAB, Simulation 하면 기가막힌 툴이죠? 역시나 sim 이라는 function 이 존재합니다.

Grid World

2.5. Create SARSA agent

이번에는 SARSA Agent 를 학습할 차례입니다. Q-Learning agent 처럼 SARSA agentmodel-free 형태의 알고리즘입니다. 그러니 흐름이 비슷하겠죠?

1
2
3
agentOpts = rlSARSAAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
sarsaAgent = rlSARSAAgent(qRepresentation,agentOpts);

심지어 Epsilon Greedy Exploration 을 쓰는 것도 같습니다.

1
2
3
4
5
6
7
8
9
doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(sarsaAgent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load('basicGWSarsaAgent.mat','sarsaAgent')
end

학습하구요.

2.6. Validate SARSA agent

1
2
3
4
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(sarsaAgent,env)

시각화하고 시뮬레이션하며 끝납니다.

This post is licensed under CC BY 4.0 by the author.