在MATLAB中使用强化学习(Reinforcement Learning)

环境简介


​ MATLAB在2019a版本中增加了Reinforcement Learning Toolbox,这使得在MATLAB中编写强化学习代码更加简单、方便。在开始之前,请确保MATLAB为2019以及更新的版本,确保已安装Deep Learning Toolbox以及Reinforcement Learning Toolbox。

​ MathWorks提供了详细的用户指南,想要了解更多内容可以参考用户指南

什么是强化学习


​ 强化学习是一种以目标为导向的算法,计算机可以通过强化学习在未知的动态环境中学习并完成任务。强化学习能让机器在没有人的参与下通过最大化累计rewards完成复杂的任务。下图为强化学习基本框架:

YYzEA.png

​ 强化学习的最终目标是在未知的环境中训练一个agent,这个agent接受来自环境的observation与reward并对环境输出action,其中的reward用来表示当前动作对于任务目标的贡献。

​ agent由两部分组成,分别为policy和reinforcement learning algorithm。

  • policy部分基于从环境中得到的observation做出action,通常来说 ,policy是一个由神经网络构成的可以进行调参的函数估计器。
  • reinforcement learning algorithm部分基于observation,action,reward对policy的参数进行调节,其目标是找到一个最优的policy最大化累计reward。

MATLAB强化学习入门


​ MATLAB提供了数个预设的环境供初学者学习,这里以PG Agent平衡倒立摆系统作为例子。

​ 下图为MATLAB中的倒立摆环境,环境中有一个可以水平移动的蓝色方块,和一个一端固定在蓝色方块上的黄杆,训练目标是通过移动蓝色方块,使得黄杆不会落下来。

YY34B.png

创建预设环境

可以使用rlPredefinedEnv调用MATLAB预设的环境。环境中包含了reset和step两个函数,这两个函数描述了环境的功能细节。


env = rlPredefinedEnv("CartPole-Discrete"); nobsInfo = getObservationInfo(env); numObservations = obsInfo.Dimension(1); actInfo = getActionInfo(env); rng(0);

创建Agent

接着需要创建PG Agent,首先需要创建Policy的神经网络结构,该网络结构决定了强化学习表现的上限,再使用rlStochasticActorRepresentation对该网络进行representation,接着按默认选项创建PGAgent即可。


actorNetwork = [ imageInputLayer([numObservations 1 1],'Normalization','none','Name','state') fullyConnectedLayer(2,'Name','fc') softmaxLayer('Name','actionProb') ]; actorOpts = rlRepresentationOptions('LearnRate',1e-2,'GradientThreshold',1); actor = rlStochasticActorRepresentation(actorNetwork,obsInfo,actInfo,'Observation',{'state'},actorOpts); agent = rlPGAgent(actor);

训练Agent

设置训练参数:选择最多训练1000个episode,每个episode最多运行200步,打开Episode Manager窗口,当平均reward达到195使停止训练,其中的平均reward只计算最多100个reward的平均。


trainOpts = rlTrainingOptions(... 'MaxEpisodes', 1000, ... 'MaxStepsPerEpisode', 200, ... 'Plots','training-progress',... 'StopTrainingCriteria','AverageReward',... 'StopTrainingValue',195,... 'ScoreAveragingWindowLength',100); plot(env); trainingStats = train(agent,env,trainOpts); save(opt.SaveAgentDirectory + "/finalAgent.mat",'agent')

训练完后可以看到如下窗口

YYrhn.png

黄色曲线为平均reward,蓝色曲线为当前reward,可见,随着训练平均reward越来越高,最后抵达195。刚开始训练时效果如下:

YYL3G.gif

训练完后效果如下:

YYjcT.gif