Statistical Context Detection for Deep Lifelong Reinforcement Learning

  • 2024-05-29 13:44:41
  • Jeffery Dick, Saptarshi Nath, Christos Peridis, Eseoghene Benjamin, Soheil Kolouri, Andrea Soltoggio

Abstract

Context detection involves labeling segments of an online stream of data as belonging to different tasks. Task labels are used in lifelong learning algorithms to perform consolidation or other procedures that prevent catastrophic forgetting. Inferring task labels from online experiences remains a challenging problem. Most approaches assume finite and low-dimension observation spaces or a preliminary training phase during which task labels are learned. Moreover, changes in the transition or reward functions can be detected only in combination with a policy, and therefore are more difficult to detect than changes in the input distribution. This paper presents an approach to learning both policies and labels in an online deep reinforcement learning setting. The key idea is to use distance metrics, obtained via optimal transport methods, i.e., Wasserstein distance, on suitable latent action-reward spaces to measure distances between sets of data points from past and current streams. Such distances can then be used for statistical tests based on an adapted Kolmogorov-Smirnov calculation to assign labels to sequences of experiences. A rollback procedure is introduced to learn multiple policies by ensuring that only the appropriate data is used to train the corresponding policy. The combination of task detection and policy deployment allows for the optimization of lifelong reinforcement learning agents without an oracle that provides task labels. The approach is tested using two benchmarks and the results show promising performance when compared with related context detection algorithms. The results suggest that optimal transport statistical methods provide an explainable and justifiable procedure for online context detection and reward optimization in lifelong reinforcement learning.
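
The detection idea in the abstract can be illustrated with a small sketch. This is not the authors' implementation: it assumes a sliced approximation of the Wasserstein distance, the standard (unadapted) two-sample Kolmogorov-Smirnov test with its asymptotic critical value, and synthetic 2-D points standing in for latent action-reward features. All function names (`wasserstein_1d`, `sliced_wasserstein`, `ks_statistic`, `ks_rejects`, `detect_change`, `make_batch`) are hypothetical.

```python
import math
import random


def wasserstein_1d(a, b):
    """W1 distance between two equal-size 1-D samples (sorted matching)."""
    assert len(a) == len(b)
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


def sliced_wasserstein(X, Y, n_proj=32, seed=0):
    """Approximate W1 between point sets by averaging over random 1-D projections."""
    rng = random.Random(seed)  # fixed seed: same projections on every call
    dim = len(X[0])
    total = 0.0
    for _ in range(n_proj):
        d = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(v * v for v in d)) or 1.0
        d = [v / norm for v in d]
        px = [sum(xi * di for xi, di in zip(x, d)) for x in X]
        py = [sum(yi * di for yi, di in zip(y, d)) for y in Y]
        total += wasserstein_1d(px, py)
    return total / n_proj


def ks_statistic(a, b):
    """Two-sample KS statistic: largest gap between the empirical CDFs."""
    a, b = sorted(a), sorted(b)
    i = j = 0
    gap = 0.0
    while i < len(a) and j < len(b):
        v = min(a[i], b[j])
        while i < len(a) and a[i] <= v:
            i += 1
        while j < len(b) and b[j] <= v:
            j += 1
        gap = max(gap, abs(i / len(a) - j / len(b)))
    return gap


def ks_rejects(a, b, alpha=0.01):
    """Asymptotic two-sample KS test: True if the samples differ at level alpha."""
    n, m = len(a), len(b)
    critical = math.sqrt(-0.5 * math.log(alpha / 2.0)) * math.sqrt((n + m) / (n * m))
    return ks_statistic(a, b) > critical


def detect_change(reference_batches, current_batches, alpha=0.01):
    """Compare distances to an anchor batch within the reference vs. the current stream."""
    anchor = reference_batches[0]
    within = [sliced_wasserstein(anchor, b) for b in reference_batches[1:]]
    across = [sliced_wasserstein(anchor, b) for b in current_batches]
    return ks_rejects(within, across, alpha)


def make_batch(rng, mean, size=64):
    """Synthetic stand-in for a batch of 2-D latent action-reward points."""
    return [[rng.gauss(mean, 1.0), rng.gauss(mean, 1.0)] for _ in range(size)]


rng = random.Random(1)
reference = [make_batch(rng, 0.0) for _ in range(21)]
shifted = [make_batch(rng, 3.0) for _ in range(20)]
print("change detected on shifted stream:", detect_change(reference, shifted))
```

In this sketch a detected change would be the point at which a new task label is assigned. The rollback procedure and policy training described in the abstract are not modeled here.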

 
