Spike-based causal inference for weight alignment

  • 2019-10-03 19:07:58
  • Jordan Guerguiev, Konrad P. Kording, Blake A. Richards

Abstract

In artificial neural networks trained with gradient descent, the weights used for processing stimuli are also used during backward passes to calculate gradients. For the real brain to approximate gradients, gradient information would have to be propagated separately, such that one set of synaptic weights is used for processing and another set is used for backward passes. This produces the so-called "weight transport problem" for biological models of learning: the backward weights used to calculate gradients need to mirror the forward weights used to process stimuli. This problem has been considered so hard that popular proposals for biological learning assume that the backward weights are simply random, as in the feedback alignment algorithm. However, such random weights do not appear to work well for large networks. Here we show how the discontinuity that spiking introduces can lead to a solution to this problem. The resulting algorithm is a special case of regression discontinuity design, an estimator used for causal inference in econometrics. We show empirically that this algorithm rapidly makes the backward weights approximate the forward weights. As the backward weights become correct, learning performance improves over feedback alignment on tasks such as Fashion-MNIST and CIFAR-10. Our results demonstrate that a simple learning rule in a spiking network can allow neurons to produce the right backward connections and thus solve the weight transport problem.
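The core idea above can be illustrated with a toy simulation. A unit's causal effect on a downstream unit is estimated from the jump in downstream activity at the unit's spike threshold, and that estimate is used as the backward weight. This is a minimal sketch under stated assumptions, not the method or code from the paper. Each hypothetical "pre" unit fires when its drive `u` reaches a threshold `theta`. A shared input `c` confounds both spiking and downstream output, and each "post" unit's output is a weighted sum of spikes plus `c` and noise. All names (`simulate`, `rdd_effect`, `naive_effect`) and parameters (`theta`, window half-width `h`, trial count) are illustrative choices. Near threshold, trials that just spiked and trials that just missed have almost the same drive, so the gap between local linear fits on each side isolates the spike's effect. A plain spike/no-spike difference in means is biased by `c`, and random feedback weights (feedback alignment) carry no information about `W`.

```python
import math
import random


def fit_line(xs, ys):
    """Ordinary least squares fit y = a + b*x; returns (intercept, slope)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return my - slope * mx, slope


def rdd_effect(drive, outcome, theta, h):
    """Regression discontinuity estimate of the jump in `outcome` at drive == theta.

    Fits one line to trials just below threshold (no spike) and one to trials
    just above (spike), with drive centred on theta, and returns the
    difference between the two intercepts.
    """
    below = [(u - theta, y) for u, y in zip(drive, outcome) if theta - h < u < theta]
    above = [(u - theta, y) for u, y in zip(drive, outcome) if theta <= u < theta + h]
    a_below, _ = fit_line(*zip(*below))
    a_above, _ = fit_line(*zip(*above))
    return a_above - a_below


def naive_effect(spikes, outcome):
    """Difference in mean outcome between spike and no-spike trials (confounded)."""
    on = [y for s, y in zip(spikes, outcome) if s]
    off = [y for s, y in zip(spikes, outcome) if not s]
    return sum(on) / len(on) - sum(off) / len(off)


def simulate(n_pre=4, n_post=3, trials=40000, theta=0.5, h=0.3, seed=0):
    rng = random.Random(seed)
    # Hypothetical forward weights: W[j][i] connects pre unit i to post unit j.
    W = [[rng.uniform(-1, 1) for _ in range(n_pre)] for _ in range(n_post)]
    drive = [[] for _ in range(n_pre)]
    spikes = [[] for _ in range(n_pre)]
    out = [[] for _ in range(n_post)]
    for _ in range(trials):
        c = rng.gauss(0, 1)  # shared input: drives the pre units AND the outputs
        u = [c + rng.gauss(0, 1) for _ in range(n_pre)]
        s = [1.0 if ui >= theta else 0.0 for ui in u]  # spike iff drive >= theta
        for i in range(n_pre):
            drive[i].append(u[i])
            spikes[i].append(s[i])
        for j in range(n_post):
            y = sum(W[j][i] * s[i] for i in range(n_pre)) + c + rng.gauss(0, 0.5)
            out[j].append(y)
    # Candidate backward weights B[i][j]; the backprop choice would be W^T.
    B_rdd = [[rdd_effect(drive[i], out[j], theta, h) for j in range(n_post)]
             for i in range(n_pre)]
    B_naive = [[naive_effect(spikes[i], out[j]) for j in range(n_post)]
               for i in range(n_pre)]
    # Feedback alignment: fixed random backward weights.
    B_rand = [[rng.uniform(-1, 1) for _ in range(n_post)] for _ in range(n_pre)]
    return W, B_rdd, B_naive, B_rand


def transpose_flat(W):
    """Entries of W^T in row-major order, to line up with flat(B)."""
    return [W[j][i] for i in range(len(W[0])) for j in range(len(W))]


def flat(B):
    return [x for row in B for x in row]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))


W, B_rdd, B_naive, B_rand = simulate()
wt = transpose_flat(W)
rdd_cos = cosine(wt, flat(B_rdd))
naive_cos = cosine(wt, flat(B_naive))
rand_cos = cosine(wt, flat(B_rand))
rdd_err = max(abs(a - b) for a, b in zip(wt, flat(B_rdd)))
naive_err = max(abs(a - b) for a, b in zip(wt, flat(B_naive)))
print(f"cosine(W^T, B): RDD {rdd_cos:.3f}, naive {naive_cos:.3f}, random {rand_cos:.3f}")
print(f"max |W^T - B|:  RDD {rdd_err:.3f}, naive {naive_err:.3f}")
```

In this toy setup the RDD-based backward matrix should line up closely with `W^T`. The naive estimate is shifted by the shared input, and the random matrix is unrelated to `W`. The window half-width `h` trades bias (a wide window lets drive-dependent effects leak in) against variance (a narrow window keeps fewer trials).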

 
