Masked Contrastive Representation Learning for Reinforcement Learning

  • 2020-10-15 02:00:10
  • Jinhua Zhu, Yingce Xia, Lijun Wu, Jiajun Deng, Wengang Zhou, Tao Qin, Houqiang Li

Abstract

Improving sample efficiency is a key research problem in reinforcement learning (RL), and CURL, which uses contrastive learning to extract high-level features from the raw pixels of individual video frames, is an efficient algorithm (Srinivas et al., 2020). We observe that consecutive video frames in a game are highly correlated, but CURL processes them independently. To further improve data efficiency, we propose a new algorithm, masked contrastive representation learning for RL, which takes the correlation among consecutive inputs into account. In addition to the CNN encoder and the policy network of CURL, our method introduces an auxiliary Transformer module that exploits the correlations among video frames. During training, we randomly mask the features of several frames and use the CNN encoder and the Transformer to reconstruct them from the context frames. The CNN encoder and the Transformer are trained jointly via contrastive learning, so that each reconstructed feature is similar to its ground-truth feature and dissimilar to the others. During inference, only the CNN encoder and the policy network are used to select actions; the Transformer module is discarded. Our method achieves consistent improvements over CURL on 14 of 16 environments from the DMControl suite and on 21 of 26 Atari 2600 games. The code is available at https://github.com/teslacool/m-curl.
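To make the training objective concrete, here is a minimal plain-Python sketch of a masked contrastive loss of the kind the abstract describes. It is not the authors' implementation. Several assumptions apply: the per-frame features are taken as already produced by the CNN encoder; the Transformer is replaced by a hypothetical placeholder, `reconstruct`, that averages the nearest unmasked neighbouring frames; similarity is a plain dot product divided by a temperature; and the negatives for each masked frame are the ground-truth features of the other frames in the same clip. The function names, mask ratio, and temperature are illustrative. The sketch also omits details such as how the target features are produced in practice (for example, whether a separate key encoder is used).

```python
import math
import random
from typing import List, Sequence

Vector = List[float]


def sample_mask(num_frames: int, mask_ratio: float, rng: random.Random) -> List[int]:
    """Pick which frame features to mask, always leaving at least one context frame."""
    if num_frames < 2:
        raise ValueError("need at least two frames: one masked, one as context")
    k = max(1, min(num_frames - 1, round(num_frames * mask_ratio)))
    return sorted(rng.sample(range(num_frames), k))


def reconstruct(features: List[Vector], masked: List[int]) -> List[Vector]:
    """HYPOTHETICAL stand-in for the Transformer: each masked feature is replaced by
    the mean of the nearest unmasked feature on its left and on its right."""
    masked_set = set(masked)
    if len(masked_set) >= len(features):
        raise ValueError("at least one frame must stay unmasked")
    out: List[Vector] = []
    for i, f in enumerate(features):
        if i not in masked_set:
            out.append(list(f))  # context frames pass through unchanged
            continue
        left = next((j for j in range(i - 1, -1, -1) if j not in masked_set), None)
        right = next((j for j in range(i + 1, len(features)) if j not in masked_set), None)
        neighbours = [features[j] for j in (left, right) if j is not None]
        out.append([sum(v[d] for v in neighbours) / len(neighbours) for d in range(len(f))])
    return out


def info_nce(pred: Sequence[float], targets: List[Vector], positive: int,
             temperature: float = 1.0) -> float:
    """Cross-entropy of picking targets[positive] among all targets, using
    dot-product similarity (an assumption) scaled by 1 / temperature."""
    logits = [sum(p * t for p, t in zip(pred, tgt)) / temperature for tgt in targets]
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[positive]


def masked_contrastive_loss(features: List[Vector], masked: List[int],
                            temperature: float = 1.0) -> float:
    """Average InfoNCE over the masked frames: the reconstruction of frame i should
    match the ground-truth feature of frame i and not those of the other frames."""
    if not masked:
        raise ValueError("mask at least one frame")
    recon = reconstruct(features, masked)
    losses = [info_nce(recon[i], features, i, temperature) for i in masked]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    rng = random.Random(0)
    # Toy "CNN encoder outputs": 6 consecutive frames with 8-dimensional features.
    frames = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(6)]
    masked = sample_mask(len(frames), 0.3, rng)
    print("masked frames:", masked)
    print("loss:", masked_contrastive_loss(frames, masked, temperature=0.5))
```

Because consecutive frames are correlated, a reconstruction built from neighbouring context should already score its own ground-truth feature above those of distant frames; a learned Transformer, trained jointly with the encoder as the abstract states, would replace the fixed averaging so that this gradient also shapes the CNN features. At inference time none of this code is needed, matching the abstract's note that the Transformer is discarded.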
