Abstract
Sequential social dilemmas pose a significant challenge in multi-agent reinforcement learning (MARL), requiring environments that accurately reflect the tension between individual and collective interests. Previous benchmarks and environments, such as Melting Pot, provide an evaluation protocol that measures generalization to new social partners across various test scenarios. However, running reinforcement learning algorithms in these traditional environments requires substantial computational resources. In this paper, we introduce SocialJax, a suite of sequential social dilemma environments and algorithms implemented in JAX, a high-performance numerical computing library for Python that substantially improves computational efficiency. Our experiments show that the SocialJax training pipeline achieves at least a 50× speedup in wall-clock time compared to the Melting Pot RLlib baselines. We also validate the effectiveness of baseline algorithms in SocialJax environments. Finally, we use Schelling diagrams to verify the social dilemma properties of these environments, confirming that they accurately capture the dynamics of social dilemmas.