Abstract
Contrastive loss is a powerful approach for representation learning, where larger batch sizes enhance performance by providing more negative samples to better distinguish between similar and dissimilar data. However, scaling batch sizes is constrained by the quadratic growth in GPU memory consumption, primarily due to the full instantiation of the similarity matrix. To address this, we propose a tile-based computation strategy that partitions the contrastive loss calculation into arbitrarily small blocks, avoiding full materialization of the similarity matrix. Furthermore, we introduce a multi-level tiling strategy to leverage the hierarchical structure of distributed systems, employing ring-based communication at the GPU level to optimize synchronization and fused kernels at the CUDA core level to reduce I/O overhead. Experimental results show that the proposed method scales batch sizes to unprecedented levels. For instance, it enables contrastive training of a CLIP-ViT-L/14 model with a batch size of 4M or 12M using 8 or 32 A800 80GB GPUs, respectively, without sacrificing any accuracy. Compared to state-of-the-art memory-efficient solutions, it achieves a two-order-of-magnitude reduction in memory while maintaining comparable speed. The code will be made publicly available.
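
To illustrate the tile-based idea, the following is a minimal single-device sketch in NumPy, not the implementation described in this paper. It assumes a one-directional InfoNCE loss over L2-normalized embeddings and accumulates the row-wise log-sum-exp one tile at a time with a running maximum, so that only a `tile` x `tile` block of the similarity matrix exists at any moment. The function names, the temperature of 0.07, and the tile size are illustrative assumptions; the ring-based communication and fused CUDA kernels are not modeled.

```python
import numpy as np


def full_contrastive_loss(x, y, temperature=0.07):
    """Reference version: materializes the full b x b similarity matrix."""
    logits = x @ y.T / temperature
    row_max = logits.max(axis=1, keepdims=True)
    lse = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    # Positive pairs sit on the diagonal of the similarity matrix.
    return float(np.mean(lse - np.diag(logits)))


def tiled_contrastive_loss(x, y, temperature=0.07, tile=64):
    """Same loss, computed from tile x tile blocks only (hypothetical helper)."""
    b = x.shape[0]
    # Positive logits need only the matching rows, never the full matrix.
    pos = np.einsum("ij,ij->i", x, y) / temperature
    lse = np.empty(b)
    for r in range(0, b, tile):
        xr = x[r:r + tile]
        running_max = np.full(xr.shape[0], -np.inf)
        running_sum = np.zeros(xr.shape[0])
        for c in range(0, b, tile):
            block = xr @ y[c:c + tile].T / temperature
            new_max = np.maximum(running_max, block.max(axis=1))
            # Rescale the partial sum to the new maximum, then add this tile.
            running_sum = running_sum * np.exp(running_max - new_max)
            running_sum += np.exp(block - new_max[:, None]).sum(axis=1)
            running_max = new_max
        lse[r:r + tile] = running_max + np.log(running_sum)
    return float(np.mean(lse - pos))
```

In this sketch the peak size of the similarity buffer depends on `tile` rather than on the batch size, which is the property that removes the quadratic memory term; the result should match the full-matrix version up to floating-point rounding.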