이번 포스팅은 MATLAB
을 이용한 강화학습(Reinforcement Learning)
실습입니다. 제가 예전에 RLCode 팀(Reinforcement Learning Code Team)
의 코드를 천천히 봤는데, 거기서도 이론과 실제 구현하는 두 수준의 간극이 차이가 난다고 얘기했습니다. 저도 마찬가지인데요. 그래서 MATLAB, Python, Keras
를 적절한 때에 사용하여 실습할 예정입니다.
1. Summary
본격적인 실습 전에 이전의 두 포스팅(강화학습 개요, MDP)에 대해서 간단하게 정리하고 넘어가겠습니다.
1.1. Reinforcement Learning
먼저 강화학습을 통해 얻는 최종 산출물은 최적 정책
함수(Optimal Policy)
Function입니다. 최적의 정책, 최적의 가치 함수들이 핵심이죠. 그리고 강화학습에서는 Agent
와 환경(Environmet)
간의 상호작용이 존재하고 Agent
는 보상을 통해 어떤 행동(Action)
을 할 지 결정하죠. 우리가 자전거를 타고, 연필을 쓰고, 흥분해서 하는 주먹질에도 어떤 사전지식 없이 할 수 있습니다. 다만, 숙달된 것의 차이만 있을 뿐입니다.
1.2. The Problem people face
풀고자하는 문제들을 순차적으로 행동을 정의한다면, MDP(Markov Decision Process)
로서 정의할 수 있습니다. MDP
에는 상태(State)
, 보상(Reward)
, 행동(Action)
, 정책(Policy)
이 포함되어 있죠. 기존에 강화학습
은 한정된 상태(State)
에 대해 어려움이 있었습니다. 하지만 인공지능(Artificial Intelligence)
의 신경망(Neural Network)
을 통해서 방대한 상태
에 대해서 학습이 가능합니다. 어지간한 문제에 대해서 마음껏 다뤄볼 수 있겠죠. 이에 대해 자세히 설명하기 위해서는 다이나믹 프로그래밍
이 필요하고 이는 다음에 설명하겠습니다. 그 전에, Grid World
를 익혀보죠.
2. Train Reinforcement Learning Agent in Basic Grid World
David Silver
교수님의 강화학습(Reinforcement Learning)
강의의 Planning by Dynamic programming
에서도 Grid World
를 이용한 예시를 들었고, 전통적인 라이트 게임류(미로찾기, 테트리스, 아타리, 틀린그림찾기 등)에서 직, 간접적으로 Grid World
를 사용하고 있습니다. 그래서 이번 강화학습 실습이 여러분들께도 도움이 됐으면 좋겠습니다. 먼저 이번 포스팅의 소스는 MATLAB 도움말 이며, 이를 한글로 번역했습니다.
이번 실습은 Q-Learning 그리고 SARSA agents
를 학습하여 Grid world
의 환경을 어떻게 해결하는 지 보여줍니다. Q-Learning Agents와 SARSA Agents를 알고싶다면, 링크를 클릭하여 주세요. (번역된 내용을 밑에 따로 적어놨습니다.) 또한 아래는 MATLAB
에서 제공하는 Agents
입니다.
Agent | Type | Action sauce |
---|---|---|
Q-Learning Agents (Q) | Value-Based | Discrete |
Deep Q-Network Agents(DQN) | Value-Based | Discrete |
SARSA Agents | Value-Based | Discrete |
Policy Gradient Agents (PG) | Policy-Based | Discrete or continuous |
Actor-Critic Agents (AC) | Actor-Critic | Discrete or continuous |
Proximal Policy Optimization Agents (PPO) | Actor-Critic | Discrete or continuous |
Deep Deterministic Policy Gradient Agents (DDPG) | Actor-Critic | Continuous |
Twin-Delayed Deep Deterministic Policy Gradient Agents (TD3) | Actor-Critic | Continuous |
Soft Actor-Critic Agents (SAC) | Actor-Critic | Continuous |
Grid World Environment
는 아래 조건들을 따르고 있습니다.
- 5x5 로 구성된 맵과 동,서,남,북으로 움질일 수 있다.
Agent
는 2행 1열에서 출발한다.Agent
가 끝지점(terminal state)
에 도달하면 +10의 보상을 받는다.- [2,4]는 순간이동 구간
(special jump)
이라 [4,4]로 바로 이동할 수 있으며 +5의 보상을 받는다. - 검정으로된 구간은 장애물 구간으로 막혀있다. [3,3] [3,4] [3,5], [4,3]
- 그 외에 모든 행동들은 -1의 패널티가 주어진다.
2.1. Create Grid World Environment
1
env = rlPredefinedEnv("BasicGridWorld");
MATLAB
에는 사전에 정의된 환경들이 있습니다. 그 중에서 BasicGridWorld
를 꺼낼 건데요. 그 외의 환경은 아래 표와 같습니다.
MATLAB Environment
'BasicGridWorld'
'CartPole-Discrete'
'CartPole-Continuous'
'DoubleIntegrator-Discrete'
'DoubleIntegrator-Continuous'
'SimplePendulumWithImage-Discrete'
'SimplePendulumWithImage-Continuous'
'WaterFallGridWorld-Stochastic'
'WaterFallGridWorld-Deterministic'
Simulink Environment
'SimplePendulumModel-Discrete'
'SimplePendulumModel-Continuous'
'CartPoleSimscapeModel-Discrete'
'CartPoleSimscapeModel-Continuous'
1
env.ResetFcn = @() 2;
이번 실습에서는 초기 상태가 [2,1]에서 늘 시작해야합니다. 그래서 Initial agent state
의 값을 리턴해주는 함수를 만들었는데요. 이 함수는 매 에피소드, 시뮬레이션 마다 처음 시작할때 호출됩니다. 완전 처음 생성되자마자 [1,1]로 지정된 상태를 그저 2로 옮겨주는 역할일 뿐입니다.
1
rng(0)
랜덤 값을 추출할 때, seed
를 먼저 설정합니다. rng(0)
인 경우에는 seed
가 0인 상태로 메르센 트위스트 알고리즘을 이용하여 난수생성기를 설정합니다.
2.2. Create Q-Learning agent
환경을 만들었다면, Agent
를 생성할 차례입니다. 그 전에 Q-Learning algorithm
을 간단하게 봅시다.
먼저 Q-Learning algorithm
은 모델을 사용하지 않는 방식(model-free)
입니다. Q
라는 의미에서 알 수 있게, Q-Learning agent
는 가치 기반 강화학습 agent
입니다. 벨만 방정식
에서 보는 것처럼, 현재가치 + 미래가치를 계산하여 앞으로의 행동을 결정합니다. 이러한 특성에 의해서 Online, Off-Policy
를 가리지 않죠. (On-Policy : 스스로 시행착오를 겪음, Off-Policy: 다른 policy의 시행착오를 학습함)
Q-Learning agent
는 아래의 observation 과 action space
를 따르며 환경(environment)
속에서 학습합니다.
Observation Space | Action Space |
---|---|
Continuous or discrete (연속 혹은 이산) | Discrete (이산) |
Critic
표현식(미래가치를 정해주는)은 아래를 따릅니다.
Critic | Actor |
---|---|
Q-value function critic Q(S,A), which you create usingrlQValueRepresentation | Q agents do not use an actor. |
좀 더 자세한 설명은 다음 포스팅에 쓸 듯합니다. 혹시 그 전에, 한글로 번역될 수 있으니 미리 링크를 남겨놓겠습니다.
1
2
3
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
qRepresentation = rlQValueRepresentation(qTable,getObservationInfo(env),getActionInfo(env));
qRepresentation.Options.LearnRate = 1;
Q-Learning agent
를 만들때 첫번째는 Q table
을 가져오는 것입니다. Q table
에는 Grid World
환경에 대한 observation, action spec
이 들어가 있고, Q representation
은 Q table
을 넣은 다음 Learning Rate(학습률)
은 1로 설정합니다.
1
2
3
agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;
qAgent = rlQAgent(qRepresentation,agentOpts);
그 다음엔 Epsilon Greedy exploration
을 Learning option
으로 설정합니다.
2.3. Train Q-Learning agent
1
2
3
4
5
6
trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes= 200;
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 11;
trainOpts.ScoreAveragingWindowLength = 30;
Train option
을 설정해줍니다.
- 최대 에피소드는 200개에 한 에피소드에는 50번의 스텝으로 제한합니다.
- 종료조건은 연속으로 30회 이상 평균 누적 보상(10점 넘어서)을 받으면 안정적이라 판단되고 종료합니다.
1
2
3
4
5
6
7
8
9
doTraining = true;
if doTraining
% Train the agent.
trainingStats = train(qAgent,env,trainOpts);
else
% Load the pretrained agent for the example.
load('basicGWQAgent.mat','qAgent')
end
본격적으로 학습하는 코드입니다. 몇 분정도 소요된다는데, PC 사양마다 다름을 알고 계시고 시간을 절약하기 위해 학습이 필요없는 경우엔 doTraining
을 false
로 설정하시면 됩니다.
2.4. Validate Q-Learning agent
1
2
3
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
결론적으로 학습이 전부가 아니라 검증을 해야합니다. 그러기 위해서는 Train
된 환경에서 agent
를 시뮬레이션합니다.
시뮬레이션을 실행하기 전에 시각화를 구성합니다.
1
sim(qAgent,env)
MATLAB
, Simulation
하면 기가막힌 툴이죠? 역시나 sim
이라는 function
이 존재합니다.
2.5. Create SARSA agent
이번에는 SARSA Agent
를 학습할 차례입니다. Q-Learning agent
처럼 SARSA agent
도 model-free
형태의 알고리즘입니다. 그러니 흐름이 비슷하겠죠?
1
2
3
agentOpts = rlSARSAAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;
sarsaAgent = rlSARSAAgent(qRepresentation,agentOpts);
심지어 Epsilon Greedy Exploration
을 쓰는 것도 같습니다.
1
2
3
4
5
6
7
8
9
doTraining = false;
if doTraining
% Train the agent.
trainingStats = train(sarsaAgent,env,trainOpts);
else
% Load the pretrained agent for the example.
load('basicGWSarsaAgent.mat','sarsaAgent')
end
학습하구요.
2.6. Validate SARSA agent
1
2
3
4
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(sarsaAgent,env)
시각화하고 시뮬레이션하며 끝납니다.