使用强化学习进行目标跟踪

如何应用

​ 目标跟踪中常常会遇到跟踪目标被遮挡、形变、出界、模糊等问题,但这些问题常常不会持续很久,那么解决这个问题的一种思路是调节学习速率,使模型少学习”有问题“的目标,从而使得模型不被污染,增强跟踪效果。

​ 而强化学习是训练一个Agent做决策。那么我们不妨把目标跟踪作为一个任务,何时进行模型更新交由Agent处理,确保跟踪的准确率。

​ 于是考虑将跟踪器输出的response map作为observation,将更新/不更新作为action,当跟踪框与真实框的交并比(IOU)高于0.7时记为跟踪成功,奖励1,反之惩罚1。

​ 考虑到当下跟踪器运行速度极慢,我的电脑配置不高,于是决定使用速度很快但准确率不高的MOSSE跟踪算法。由于MOSSE准确率不高,使得增强学习的效果能更加明显。

IMPLEMENTATION

建立环境

创建环境中最重要的两个函数step,reset

  • reset

    reset需要返回两个参数[InitialObservation, LoggedSignal] = CF_Reset(),其中InitialObservation为第一步的observation,在目标跟踪中初始化为二维高斯函数。

    LoggedSignal中存放需要传递的参数。

    reset函数中主要完成了两个任务:1、使用MOSSE算法根据第一帧的信息训练好模型,并将模型传递下去。2.随机选择一个图片序列用以训练。

  • step

    step根据当前的action,需要返回四个参数[NextObs,Reward,IsDone,LoggedSignals] = CF_Step(Action,LoggedSignals)

    step函数中主要完成了四个任务:

    • 根据上一帧得到的模型计算当前帧的目标位置,使用MOSSE算法计算当前帧的模型,并由Policy选择是否进行更新。
    • reward更新:IOU小于0.7的reward为-1,否则为1。
    • 判断episode是否结束:当IOU小于0.2时,判断为跟踪失败,episode结束;当跑完序列最后一帧时,episode结束。
    • 可视化跟踪结果。

定义好上面两个函数后,还需要定义observationaction这两个参数:

ObservationInfo = rlNumericSpec([163 115]); 
ObservationInfo.Name = 'CartPole States';

ActionInfo = rlFiniteSetSpec([0 1]);  
ActionInfo.Name = 'CartPole Action';

创建环境所具备的条件都聚齐了,使用rlFunctionEnv便可创建环境

env = rlFunctionEnv(ObservationInfo,ActionInfo,'CF_Step','CF_Reset');

运行上面代码后,MATLAB会自动调用validate程序进行校验,若没有报错,说明环境创建成功。

创建PG Agent

区别于上一篇文章,本次的任务会复杂许多,因此网络结构会设计得更复杂些:

actorNetwork = [
    imageInputLayer([163 115 1],'Normalization','none','Name','state')
    convolution2dLayer([5 5], 4, 'Padding', 2)
    reluLayer()
    maxPooling2dLayer(2, 'Stride', 3)
    convolution2dLayer([3 3], 8, 'Padding', 2)
    reluLayer()
    maxPooling2dLayer(1, 'Stride',1)
    fullyConnectedLayer(64)
    reluLayer()
    fullyConnectedLayer(2)
    softmaxLayer
    ];

再默认创建按actor即可。

训练Agent

当平均Reward高于200即可结束训练。

运行效果

下图为MOSSE效果

mosse

可见在目标被遮挡时,跟踪器会对遮挡的目标照单全收的学习,当遮挡面积较大时甚至可能对遮挡物进行学习,以至于最后跟踪遮挡物去了。

下图为使用了强化学习的MOSSE

mosse_RL

经过强化学习后,跟踪器能准确地更新模型,使得模型不被污染。

代码可以从此处获得。