Abstract
The scaling of Large Language Models (LLMs) for retrieval-based tasks, particularly in Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning on long prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs but fall short of accommodating the efficient parameter distribution required for retrieved context. To address this gap, we introduce a novel framework for fine-tuning Llama-2 models that is compatible with parameter-efficient fine-tuning (PEFT) methods and leverages distributed training. Our framework uses JAX's just-in-time (JIT) compilation and tensor sharding for efficient resource management, enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. Our experiments show a more than 12x improvement in runtime compared to the Hugging Face/DeepSpeed implementation with four GPUs, while consuming less than half the VRAM per GPU.
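The abstract does not spell out the training loop. The sketch below is a rough illustration of how JIT compilation, tensor sharding, and a PEFT method can fit together in JAX. It assumes a LoRA-style low-rank adapter, although the abstract only says "PEFT-compatible", and applies it to a single toy linear layer. The frozen base weight is sharded column-wise over a one-dimensional device mesh. Only the adapter factors receive gradients, and the update step is compiled with `jax.jit`. All names, sizes, and hyperparameters (`init_params`, `train_step`, `run`, `RANK`, `ALPHA`, `LR`) are hypothetical. This is not the authors' implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical toy sizes; real Llama-2 projection layers are far larger.
D_IN, D_OUT, RANK, BATCH = 64, 64, 4, 8
ALPHA = 4.0            # hypothetical LoRA scaling numerator
SCALE = ALPHA / RANK
LR = 0.1               # hypothetical learning rate for plain SGD

# One-dimensional mesh over every visible device (several GPUs, or a single CPU).
# D_OUT must be divisible by the number of devices for the column sharding below.
mesh = Mesh(np.array(jax.devices()), ("model",))


def init_params(key):
    """Frozen base weight plus LoRA-style factors A (random) and B (zeros)."""
    k_w, k_a = jax.random.split(key)
    frozen = {"w": jax.random.normal(k_w, (D_IN, D_OUT)) / jnp.sqrt(D_IN)}
    trainable = {
        "a": jax.random.normal(k_a, (D_IN, RANK)) / jnp.sqrt(D_IN),
        "b": jnp.zeros((RANK, D_OUT)),  # zero init: the adapter starts as a no-op
    }
    return frozen, trainable


def loss_fn(trainable, frozen, x, y):
    """Mean-squared error of the adapted layer: x @ W + SCALE * (x @ A) @ B."""
    pred = x @ frozen["w"] + SCALE * (x @ trainable["a"]) @ trainable["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # traced and compiled by XLA on the first call, then reused
def train_step(trainable, frozen, x, y):
    # Differentiate only w.r.t. the first argument, so the frozen weight gets no gradient.
    loss, grads = jax.value_and_grad(loss_fn)(trainable, frozen, x, y)
    trainable = jax.tree_util.tree_map(lambda p, g: p - LR * g, trainable, grads)
    return trainable, loss


def run(num_steps=100, seed=0):
    k_params, k_x, k_y = jax.random.split(jax.random.PRNGKey(seed), 3)
    frozen, trainable = init_params(k_params)
    # Shard the frozen weight column-wise over the "model" mesh axis ...
    frozen = {"w": jax.device_put(frozen["w"], NamedSharding(mesh, P(None, "model")))}
    # ... and replicate the small adapter factors on every device.
    replicated = NamedSharding(mesh, P())
    trainable = {k: jax.device_put(v, replicated) for k, v in trainable.items()}
    x = jax.random.normal(k_x, (BATCH, D_IN))
    y = jax.random.normal(k_y, (BATCH, D_OUT))
    losses = []
    for _ in range(num_steps):
        # The loss returned is computed before this step's update.
        trainable, loss = train_step(trainable, frozen, x, y)
        losses.append(float(loss))
    return losses


if __name__ == "__main__":
    losses = run()
    print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Running the module prints the first and last training loss. Because the mesh is built from `jax.devices()`, the same code shards `w` across several GPUs without changes, as long as `D_OUT` is divisible by the device count. Keeping the large frozen weight sharded while updating only the small replicated factors is one plausible way to cut per-device memory, but the paper's actual partitioning scheme may differ.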