# SCALING FORWARD GRADIENT WITH LOCAL LOSSES

ICLR 2023 [[paper](https://arxiv.org/abs/2210.03310)]

## Introduction

<!-- - Artificial neural networks were inspired by biological neurons -->
- Backprop has long been considered "**biologically implausible**"
    - the brain does not form symmetric backward connections or perform synchronized computations
    - it requires separate phases for inference and learning
    - the learning signals are not local, but have to be propagated backward, layer-by-layer, from the output units
- Backprop is incompatible with a massive level of **model parallelism** and restricts potential hardware designs
- This paper proposes forward gradient learning as a biologically plausible alternative to backpropagation
    - Forward gradient learning suffers from **high variance**, which limits its **scalability**
    - The authors propose modifications to reduce the variance of the forward gradient estimator
    - They introduce the LocalMixer architecture to improve scalability

<!--
- Approaches based on **weight perturbation** directly send the loss signal back to the weight connections
    - do not require any backward weights
    - previously proposed as a biologically plausible alternative to backprop
    - In the forward pass, the network adds a slight perturbation to the synaptic connections
    - The weight update is then multiplied by the negative change in the loss
- Weight perturbation estimates the gradients by perturbing synaptic weights and observing the change in the objective function
- Unlike BP, weight perturbation is completely "**model-free**"
    - it does not depend on knowing anything about the functional dependence of the objective on the network weights
- The disadvantage of a completely model-free approach is the tradeoff between generality and learning speed
    - weight perturbation is far more widely applicable than BP, but BP is much faster when it is applicable
- Existing approaches suffer from the curse of dimensionality
    - the variance of the estimated gradients is too high to effectively train large networks
- In this paper, they revisit activity perturbation
-->

## Forward Gradient Learning

### Forward-mode Automatic Differentiation ([AD](https://arxiv.org/pdf/2202.08587.pdf))

Let $f:\mathbb{R}^m \to \mathbb{R}^n$. The Jacobian of $f$, $J_f$, is a matrix of size $n\times m$. Forward-mode AD computes the Jacobian-vector product $J_f v$, where $v\in\mathbb{R}^m$. It is the directional derivative along $v$ evaluated at $x$:

$$
J_f v := \lim_{\delta\to 0}\frac{f(x+\delta v)-f(x)}{\delta}
$$

- Backprop, also known as **reverse-mode AD**, computes the vector-Jacobian product $v^\top J_f$, where $v\in\mathbb{R}^n$, which corresponds to the last term in the chain rule
- Forward-mode AD only requires a single forward pass, augmented with the derivative information
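As a concrete illustration (not from the paper), here is a minimal sketch of forward-mode AD using `jax.jvp`; the toy function `f` and the vectors `x`, `v` are made up for the example:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy function f : R^3 -> R^2 (made up for the example)
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.array([1.0, 2.0, 3.0])   # point of evaluation
v = jnp.array([0.1, 0.0, -0.2])  # tangent / perturbation direction

# jax.jvp evaluates f(x) and the Jacobian-vector product J_f(x) v
# in a single augmented forward pass (no backward pass).
y, jvp_out = jax.jvp(f, (x,), (v,))
print(y, jvp_out)
```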
### Weight-perturbed Forward Gradient

Let $w_{ij}$ be the weight connecting unit $i$ to unit $j$, and let $f$ be the loss function. We can estimate the gradient by sampling a **random matrix** with i.i.d. elements $v_{ij}$ drawn from a **zero-mean unit-variance Gaussian distribution**. The estimator is

$$
g_w(w_{ij})=\left(\sum_{i^{\prime}j^{\prime}}\nabla w_{i^{\prime}j^{\prime}}\,v_{i^{\prime}j^{\prime}} \right)v_{ij}
$$

- The estimator samples a random perturbation direction $v_{ij}$ and tests how well it aligns with the true gradient $\nabla w_{i^{\prime}j^{\prime}}$, using forward-mode AD to compute the dot product
- It then multiplies this scalar alignment with the perturbation direction again

### Activity-perturbed Forward Gradient

- An alternative to perturbing the weights is to perturb the **activities** instead
    - this reduces the number of perturbation dimensions per example
- $x_i$ denotes the activity of the $i$-th presynaptic neuron
- $z_j$ denotes the $j$-th postsynaptic neuron before the non-linear activation function
- $u_j$ denotes the perturbation of $z_j$

The activity-perturbed forward gradient estimator is

$$
g_a(w_{ij})=x_i\left(\sum_{j^{\prime}}\nabla z_{j^{\prime}}\,u_{j^{\prime}}\right)u_j
$$

where the inner product between $\nabla\mathbf{z}$ and $\mathbf{u}$ is computed using forward-mode AD.

![](https://hackmd.io/_uploads/SJdFv7Bn3.png =50%x)
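To make the estimator concrete, here is a small self-contained sketch (my own illustration, not the authors' code) of the activity-perturbed forward gradient for a single linear layer in JAX; the toy loss `loss_from_preact` and all shapes are assumptions:

```python
import jax
import jax.numpy as jnp

def loss_from_preact(z, target):
    # Toy downstream computation (made up): ReLU then squared error.
    h = jax.nn.relu(z)
    return jnp.sum((h - target) ** 2)

key = jax.random.PRNGKey(0)
k_x, k_w, k_u = jax.random.split(key, 3)

x = jax.random.normal(k_x, (4,))      # presynaptic activities x_i
W = jax.random.normal(k_w, (4, 3))    # weights w_ij
target = jnp.ones((3,))

z = x @ W                             # postsynaptic pre-activations z_j
u = jax.random.normal(k_u, z.shape)   # activity perturbation u_j ~ N(0, I)

# Forward-mode AD gives the scalar alignment (sum_j dL/dz_j * u_j)
# in the same forward pass that computes the loss.
_, alignment = jax.jvp(lambda z_: loss_from_preact(z_, target), (z,), (u,))

# Activity-perturbed estimate: g_a(w_ij) = x_i * (alignment * u_j)
grad_W_estimate = jnp.outer(x, alignment * u)
```

The weight-perturbed variant is analogous, but the perturbation then has the shape of $W$ itself, so the number of perturbed dimensions grows with the whole weight matrix rather than with the number of output units.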
### Theoretical Properties

- Activity perturbation has a factor of $p$ smaller variance than weight perturbation
    - the number of perturbed elements is the number of output units instead of the size of the whole weight matrix

![](https://hackmd.io/_uploads/S1I6ZEr2n.png)

- For both activity and weight perturbation, the variance still grows with larger networks
    - they further reduce the variance by introducing **local loss functions**

## Scaling with Local Losses

- Perturbation learning can suffer from a **curse of dimensionality**: the variance grows with the number of perturbation dimensions
- One way to limit the number of learnable dimensions is to **divide the network into submodules, each with a separate loss function**

1. **Blockwise loss**
    - Divide the network into modules in depth
    - Each module consists of several layers
    - Compute a loss function at the end of each module
2. **Patchwise loss**
    - Apply a separate loss patchwise along the spatial dimensions
    - In the Vision Transformer architecture, each spatial token represents a patch in the image
    ![](https://hackmd.io/_uploads/ry7BKVSnn.png)
3. **Groupwise loss**
    - To create multiple losses, they split the channels into a number of groups
    - Each group is attached to a loss function

- **Feature aggregators**
    - Naively applying losses separately to the spatial and channel dimensions leads to **suboptimal performance**
    - Standard architectures obtain a global view by applying a global average pooling layer before the final classification layer
    <!-- We therefore explore strategies for aggregating information from other groups and spatial patches before the local loss function -->
    ![](https://hackmd.io/_uploads/BkRG6J822.png)
    - Channel groups are copied and communicated to one another, but every group except the active group itself is masked with **stop gradient** (a minimal sketch of this masking follows at the end of this section)
    $$\mathbf{x}_{p,g}=[\text{StopGrad}(x_{p,1},...,x_{p,g-1}),\,x_{p,g},\,\text{StopGrad}(x_{p,g+1},...,x_{p,G})]$$
    - Each spatial location is also copied, communicated, and masked, and then averaged locally
    $$\bar{\mathbf{x}}_{p,g}=\frac{1}{P}\left(\mathbf{x}_{p,g}+\sum_{p^{\prime}\neq p}\text{StopGrad}(\mathbf{x}_{p^{\prime},g})\right)$$
    - The output of feature aggregation is the same as that of a conventional global average pooling layer
    - The difference is that here the loss is replicated and a different patch group is active in each loss
- **Learning objectives**
    - They consider the supervised classification loss and the contrastive InfoNCE loss
    - For supervised classification, they attach a linear layer (shared across $p,g$) on top of the aggregated features for a cross-entropy loss
    $$L^s_{p,g}=-\sum_k t_k\log\text{softmax}(W_l\bar{\mathbf{x}}_{p,g})_k$$
    - For contrastive learning, the linear layer becomes a linear feature projector; the InfoNCE loss is
    $$L^c_{p,g}=-\sum_n\log\frac{(W\bar{\mathbf{x}}^{(1)}_{n,p,g})^{T}\text{StopGrad}(W\bar{\mathbf{x}}^{(2)}_n)}{\sum_m(W\bar{\mathbf{x}}^{(1)}_{n,p,g})^{T}\text{StopGrad}(W\bar{\mathbf{x}}^{(2)}_m)}$$
    - Perturbation-based methods require the stop gradient; otherwise the loss will not go down
        - this is likely because the perturbations are shared across both views
        - non-shared perturbations also work, but are worse than using a stop gradient

![](https://hackmd.io/_uploads/Bk8NYl8n3.png =50%x)
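The two aggregation equations can be read as a masking scheme; below is a minimal sketch (my own interpretation, not the paper's implementation), assuming a feature tensor of shape `(P, G, C)` and a hypothetical helper `aggregate_for_loss`. Every group and patch contributes to the aggregate in value, but only the active `(p, g)` block is left outside `stop_gradient`:

```python
import jax.numpy as jnp
from jax.lax import stop_gradient

def aggregate_for_loss(x, p, g):
    """Aggregated feature for the local loss at patch p, channel group g.

    x: (P, G, C) array -- P spatial patches, G channel groups, C channels
    per group (shapes are illustrative assumptions, not the paper's code).
    """
    P, G, C = x.shape
    # Mask that is 1 only for the active (p, g) block.
    active = jnp.zeros((P, G, 1), dtype=x.dtype).at[p, g, 0].set(1.0)
    # Copy every block, but stop gradients everywhere except the active one.
    masked = active * x + (1.0 - active) * stop_gradient(x)
    # Concatenate the channel groups, then average over spatial patches:
    # in value this equals global average pooling, but gradients (and
    # forward-gradient perturbations) only reach x[p, g].
    return jnp.mean(masked.reshape(P, G * C), axis=0)
```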
## Implementation

- **Network architecture**
    - [MLPMixer](https://arxiv.org/pdf/2105.01601.pdf)
    ![](https://hackmd.io/_uploads/Bkb9Pu8h3.png)
    - LocalMixer
    ![](https://hackmd.io/_uploads/Hyg9lHUnh.png)
    - An image is divided into non-overlapping patches (i.e. tokens)
    - Each block consists of token- and channel-mixing layers
    - They add a linear projector/classification layer to attach a loss function at the end of each block
    - Before the last channel-mixing layer, features are reshaped into a number of groups, and the last layer is fully connected within each feature group
- **Efficient implementation of replicated losses**
    - A naive implementation of groups can be very inefficient in terms of both memory consumption and compute
    - Each spatial group actually computes the **same** aggregated feature and loss function
    ![](https://hackmd.io/_uploads/Bkss5SU22.png =50%x)

## Experiments

They compare to a set of alternatives: Backprop, [Feedback Alignment](https://arxiv.org/pdf/1609.01596.pdf), and global variants of Forward Gradient.

1. **Backprop (BP)**
    - the standard backprop algorithm
    - Local Backprop (**L-BP**) adds local losses as proposed
    - Local Greedy Backprop (**LG-BP**) adds stop gradient operators in between blocks
2. **Feedback Alignment (FA)**
    - the standard [FA algorithm](https://arxiv.org/pdf/1411.0247.pdf) uses a set of random, fixed backward weights
    - Local Feedback Alignment (**L-FA**) adds local losses as proposed
    - Local Greedy Feedback Alignment (**LG-FA**) adds a stop gradient between blocks
3. **Forward Gradient (FG)**
    - Weight-perturbed forward gradient (**FG-W**)
    - Activity-perturbed variant (**FG-A**)
    - Local Greedy Forward Gradient, Weight-perturbed (**LG-FG-W**)
    - Local Greedy Forward Gradient, Activity-perturbed (**LG-FG-A**)

- **Datasets**
    - MNIST
    - CIFAR-10
    - ImageNet
- **Main results**
    ![](https://hackmd.io/_uploads/rJiw8uIh2.png)
    - The local forward gradient method can match the test error of backprop on MNIST and CIFAR
    - The error due to greediness grows as the problem gets more complex and requires more layers to cooperate
    - They significantly outperform the FA family on ImageNet (by 25% for supervised and 10% for contrastive learning)
    - Interestingly, local greedy FA also performs better than global feedback alignment
        - the benefit of local learning transfers to other types of gradient approximation
- **Effect of local losses**
    ![](https://hackmd.io/_uploads/BktCHO8nh.png)
- **Effect of groups**
    ![](https://hackmd.io/_uploads/BkTVL_Uh2.png)
    - Adding more groups brings significant improvements to local perturbation learning, lowering both training and test errors
    - The effect vanishes at around 8 channels per group

## Conclusion

- It is often believed that perturbation-based learning cannot scale to large and deep networks
- This paper shows that a huge number of local greedy losses can help forward gradient learning scale much better
- They explored blockwise, patchwise, and groupwise local losses, and a combination of all three, with a total of a quarter of a million losses in one of the larger networks
- Local activity-perturbed forward gradient performs better than previous backprop-free algorithms on larger networks