Lowering PyTorch's Memory Consumption for Selective Differentiation

  • 2024-08-21 07:21:52
  • Samarth Bhatia, Felix Dangel

Abstract

Memory is a limiting resource for many deep learning tasks. Besides the neural network weights, one main memory consumer is the computation graph built up by automatic differentiation (AD) for backpropagation. We observe that PyTorch's current AD implementation neglects information about parameter differentiability when storing the computation graph. This information is, however, useful for reducing memory whenever gradients are requested only for a subset of parameters, as is the case in many modern fine-tuning tasks. Specifically, inputs to layers that act linearly in their parameters (dense, convolution, or normalization layers) can be discarded whenever the parameters are marked as non-differentiable. We provide a drop-in, differentiability-agnostic implementation of such layers and demonstrate its ability to reduce memory without affecting run time.
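The idea can be illustrated for a dense layer with a minimal sketch, assuming PyTorch's `torch.autograd.Function` API. For y = x Wᵀ + b, the weight gradient grad_outᵀ x needs the layer input x, while the input gradient grad_out W needs only W. So when W is frozen, x need not be stored. The class names `SelectiveLinearFn` and `SelectiveLinear` are hypothetical and this is not the authors' implementation; convolution and normalization layers would follow the same pattern.

```python
import torch
import torch.nn.functional as F


class SelectiveLinearFn(torch.autograd.Function):
    """Hypothetical y = x @ W.T + b that saves only what requested gradients need."""

    @staticmethod
    def forward(ctx, x, weight, bias):
        # Differentiability of each input is already known at forward time.
        needs_x, needs_w = ctx.needs_input_grad[0], ctx.needs_input_grad[1]
        # dL/dW = grad_out^T x needs the input x; dL/dx = grad_out W needs W.
        # A tensor is saved only if some requested gradient depends on it.
        ctx.save_for_backward(x if needs_w else None, weight if needs_x else None)
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_w = grad_b = None
        # Collapse leading batch dimensions: (..., out) -> (N, out).
        g2d = grad_out.reshape(-1, grad_out.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_x = grad_out @ weight
        if ctx.needs_input_grad[1]:
            grad_w = g2d.T @ x.reshape(-1, x.shape[-1])
        if ctx.needs_input_grad[2]:
            grad_b = g2d.sum(dim=0)
        return grad_x, grad_w, grad_b


class SelectiveLinear(torch.nn.Linear):
    """Hypothetical drop-in replacement for nn.Linear using SelectiveLinearFn."""

    def forward(self, x):
        return SelectiveLinearFn.apply(x, self.weight, self.bias)
```

Freezing a layer with `layer.weight.requires_grad_(False)` then keeps its input out of the saved graph, while gradients for the input and bias are computed as before. The memory is only actually released if no other operation in the graph also holds a reference to that input.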
