## Gradient flow ODE/Diffusion SDE views of neural training dynamics + connections with sharpness of minima
### Background and thought process
Recently I've been reading a bunch of papers on gradient flow ODE/diffusion approximations of GD/SGD, and almost all of them assert that the implicit biases arising from discretization (of the gradient flow ODE, in the GD setting) or from stochasticity (in SGD) lead to flatter minima, which achieve higher test accuracy/correlate with better generalization. This led me down the rabbit hole of sharpness, generalization, and optimizers that explicitly regularize some notion of sharpness (notably [Sharpness Aware Minimization (SAM)](https://arxiv.org/abs/2010.01412), which achieved SOTA on a bunch of image benchmarks like CIFAR-100 and has since been combined with vision Transformers).
There are a bunch of measures of sharpness out there, but one of the most popular seems to be the leading eigenvalue of the Hessian of the loss, $\lambda_\text{max}(H_\theta(\mathcal{L}))$, where $\theta$ are the parameters of your network and $\mathcal{L}$ is your loss function. Near a minimum, this can be seen as measuring (to second order) the worst-case loss increase under a small adversarial perturbation to the weights. There's definitely some doubt over whether this measure of sharpness is actually correlated with generalization (see [On the Maximum Hessian Eigenvalue and Generalization](https://proceedings.mlr.press/v187/kaur23a.html)), but in any case I mention it because I already have some PyTorch code written to measure the leading eigenvalue of the Hessian of the loss, which I use in conjunction with Runge-Kutta integration of the gradient flow ODE (instead of SGD/Adam) for training :smiley:.
This was based on code from this fun paper [Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability](https://arxiv.org/abs/2103.00065), which looks at full-batch gradient descent on neural network training objectives and shows that it typically operates in a regime they call the "Edge of Stability" (:heart: this name). In this regime, the maximum eigenvalue of the training loss Hessian hovers just above the numerical value 2/(step size), and the training loss behaves non-monotonically over short timescales, yet consistently decreases over long timescales. They do a lot of experiments across different architectures and datasets, and overall I like how empirically driven this work is, along with other works from this group, one of which I'll talk about below.
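For reference, here is a minimal, dependency-free sketch of how $\lambda_\text{max}$ can be estimated without ever forming the Hessian: power iteration on Hessian-vector products. In PyTorch the Hessian-vector product would typically come from double backprop (`torch.autograd.grad` with `create_graph=True`); here, as an assumption to keep things self-contained, it is approximated by a central finite difference of a user-supplied gradient function. `hvp`, `top_hessian_eigenvalue`, and the toy quadratic are hypothetical names for illustration.

```python
import math
import random


def hvp(grad_fn, theta, v, h=1e-5):
    """Finite-difference Hessian-vector product: H v ≈ (∇L(θ + h v) - ∇L(θ - h v)) / (2h)."""
    plus = grad_fn([t + h * vi for t, vi in zip(theta, v)])
    minus = grad_fn([t - h * vi for t, vi in zip(theta, v)])
    return [(p - m) / (2 * h) for p, m in zip(plus, minus)]


def top_hessian_eigenvalue(grad_fn, theta, iters=100, seed=0):
    """Power iteration on H; returns the Rayleigh quotient v^T H v of the last iterate.

    Power iteration converges to the eigenvalue of largest *magnitude*, which equals
    lambda_max when the Hessian is positive semi-definite (e.g. near a minimum).
    """
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in theta]
    norm = math.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]
    eig = 0.0
    for _ in range(iters):
        hv = hvp(grad_fn, theta, v)
        eig = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient for unit v
        norm = math.sqrt(sum(x * x for x in hv))
        v = [x / norm for x in hv]
    return eig


# Toy quadratic loss L(θ) = 0.5 θ^T A θ with A = [[3, 1], [1, 2]], so ∇L(θ) = A θ.
A = [[3.0, 1.0], [1.0, 2.0]]


def quad_grad(theta):
    return [sum(a_ij * t_j for a_ij, t_j in zip(row, theta)) for row in A]


lam = top_hessian_eigenvalue(quad_grad, [0.5, -0.3])
print(lam)  # close to (5 + sqrt(5)) / 2 ≈ 3.618
```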
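The 2/(step size) threshold comes from the classical stability analysis of GD on a quadratic: on $\mathcal{L}(\theta) = \frac{\lambda}{2}\theta^2$ the update is $\theta_{t+1} = (1 - \eta\lambda)\theta_t$, which converges iff $|1 - \eta\lambda| < 1$, i.e. $\lambda < 2/\eta$. A tiny sketch on a toy 1D quadratic (`gd_on_quadratic` is a hypothetical name):

```python
def gd_on_quadratic(curvature, lr, theta0=1.0, steps=50):
    """Run full-batch GD on L(θ) = 0.5 * curvature * θ^2 and return the final iterate."""
    theta = theta0
    for _ in range(steps):
        grad = curvature * theta
        theta = theta - lr * grad  # same as theta *= (1 - lr * curvature)
    return theta


lr = 0.1  # stability threshold is 2 / lr = 20
below = gd_on_quadratic(curvature=19.0, lr=lr)  # |1 - 1.9| = 0.9 < 1: converges while oscillating
above = gd_on_quadratic(curvature=21.0, lr=lr)  # |1 - 2.1| = 1.1 > 1: diverges while oscillating
print(abs(below), abs(above))
```

The Edge of Stability observation is that on real networks, full-batch GD pushes $\lambda_\text{max}$ up to (and just above) this threshold without diverging over long timescales, unlike this quadratic.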
Below I've listed the works that stand out to me for each topic, along with some summary notes.
Side note: very cool repo full of relevant papers [here!](https://github.com/zeke-xie/deep-learning-dynamics-paper-list)
#### Diffusion/ODE view of training dynamics + implicit regularization:
- [Implicit Gradient Regularization](https://arxiv.org/abs/2009.11162)
- Uses backward error analysis from the dynamics literature to construct a modified loss whose gradient flow better models the trajectory of GD. It follows the form $$\begin{align*}\frac{d\theta}{dt} &= - \nabla_\theta \tilde{\mathcal{L}}_{\text{GD}}(\theta) \\ \tilde{\mathcal{L}}_{\text{GD}}(\theta) &= \mathcal{L}(\theta) + \lambda R_{\text{IG}}(\theta)\\ \lambda &\equiv \frac{\varepsilon m}{4}, \ R_{\text{IG}}(\theta) \equiv \frac{1}{m} \sum_{i=1}^{m} (\nabla_{\theta_i}\mathcal{L}(\theta))^2 \\ \tilde{\mathcal{L}}_{\text{GD}}(\theta) &= \mathcal{L}(\theta) + \frac{\varepsilon}{4} || \nabla \mathcal{L}(\theta)||^2 \end{align*}$$ where $\theta$ are the parameters of your network, $\mathcal{L}$ is your loss function, $\varepsilon$ is your learning rate, and $m$ is your network size (the number of parameters). In the paper they call $R_{\text{IG}}$ the "implicit gradient regularizer" because it penalizes regions of the loss landscape that have large gradient values, and because it is implicit in gradient descent rather than being explicitly added to our loss. From the modified loss equation above, the regularization coefficient $\lambda$ is proportional to the learning rate and network size; consequently, the paper predicts that networks with small learning rates, fewer parameters, or both will have less implicit gradient regularization and worse test error.
- Implicit gradient regularization (IGR) is therefore defined as the implicit regularization behaviour originating from the use of discrete update steps in gradient descent (as opposed to following the continuous gradient flow ODE). The paper goes on to predict that IGR encourages small values of $R_{\text{IG}}$ and biases GD toward flatter minima that have higher test accuracy and are more robust to parameter perturbations. They include some empirical experiments on a mix of toy problems, MNIST, and CIFAR-10 to back these up.
- The sharpness measure they use here is *"loss surface slope"*, which they define from a differential-geometric perspective. Need to read more on this to fully understand it, but in any case they seem to confirm the notion that flatter minima generalize better.
- In their formulation, the amount of IGR is partially controlled by the learning rate, i.e. the higher the learning rate, the more regularization. The authors suggest that *explicit* regularization may be useful in cases where large learning rates cause numerical instability. I feel like there's a connection here with other papers that find that larger learning rates are beneficial for generalization.
- In the conclusion the authors suggest extending the backward error analysis to Adam and RMSProp, which looks to be done [here](https://arxiv.org/abs/2309.00079) as a [submission to ICLR 2024](https://openreview.net/forum?id=ZA9XUTseA9)
- [On the Origin of Implicit Regularization in Stochastic Gradient Descent](https://arxiv.org/abs/2101.12176)
- Similar to the above paper, but instead focuses its backward error analysis on the mean evolution of SGD (the expected iterate after one epoch, averaging over the random order of the minibatches). The modified loss of SGD looks like $$\begin{align*}\frac{d\theta}{dt} &= - \nabla_\theta \tilde{\mathcal{L}}_{\text{SGD}}(\theta) \\ \tilde{\mathcal{L}}_{\text{SGD}}(\theta) &= \mathcal{L}(\theta) + \frac{\varepsilon}{4m}\sum_{k=0}^{m-1}||\nabla\hat{\mathcal{L}}_k(\theta) ||^2 \end{align*}$$ where $\theta$ are the parameters of your network, $\mathcal{L}(\theta)$ is your loss function averaged over all training examples, $\varepsilon$ is your learning rate, $m$ is the number of minibatches per epoch (the training set size divided by the batch size $B$), and $\hat{\mathcal{L}}_k$ is the loss on minibatch $k$. We can see the regularization term penalizes the mean squared norm of the minibatch gradients.
- Since the full-batch gradient $\nabla\mathcal{L}(\theta)$ is the average of the $m$ minibatch gradients, the mean squared minibatch gradient norm splits into the squared full-batch gradient norm plus a variance term, so the modified loss of SGD can be rewritten in terms of the modified loss of GD: $$ \tilde{\mathcal{L}}_{\text{SGD}} = \tilde{\mathcal{L}}_{\text{GD}}(\theta) + \frac{\varepsilon}{4m}\sum_{k=0}^{m-1}||\nabla\hat{\mathcal{L}}_k(\theta) - \nabla \mathcal{L}(\theta)||^2 $$ The additional regularizer term causes SGD to avoid parts of parameter space where the variance of gradients computed over different minibatches is high.
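- Another important note: every minimum of $\mathcal{L}(\theta)$ is also a minimum of $\tilde{\mathcal{L}}_{\text{GD}}(\theta)$, since the added term is non-negative and vanishes wherever $\nabla\mathcal{L}(\theta) = 0$. This is not necessarily true for $\mathcal{L}(\theta)$ and $\tilde{\mathcal{L}}_{\text{SGD}}(\theta)$: at a minimum where the minibatch gradients have high variance, the variance term stays positive, so the minimum of the modified loss can shift or disappear.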
- Empirical experiments that explicitly add this extra regularization term to the loss show improvements in test accuracy. They also mention and further demonstrate the well-known linear scaling rule (popularized by [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677)), under which different batch sizes achieve the same test accuracy if the ratio of the learning rate to the batch size ($\varepsilon/B$) is held constant and $B$ is not too large.
- [A Diffusion Theory For Deep Learning Dynamics: Stochastic Gradient Descent Exponentially Favors Flat Minima](https://arxiv.org/abs/2002.03495)
- How does diffusion theory help us? It lets us model the evolution of the probability density over parameters, instead of individual parameter trajectories.
- The main contribution is accounting for parameter-dependent, anisotropic gradient noise when modeling the training dynamics of SGD. Their theory explicitly takes into account gradient noise, batch size, learning rate, and the Hessian to show that SGD favors flat minima exponentially more than sharp ones, and that small learning rates and large-batch training require exponentially more iterations to escape sharp minima and reach flat ones.
- [On the Validity of Modeling SGD with Stochastic Differential Equations (SDEs)](https://arxiv.org/abs/2102.12470)
- [Disentangling the Mechanisms Behind Implicit Regularization in SGD](https://arxiv.org/abs/2211.15853)
- [A Bayesian Perspective on Generalization and Stochastic Gradient Descent](https://arxiv.org/abs/1710.06451)
- Fun lil Bayesian spin :smiley: tries to show that minima that generalize well have larger Bayesian evidence
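The IGR claim that GD follows the gradient flow of $\tilde{\mathcal{L}}_{\text{GD}}$ more closely than that of $\mathcal{L}$ can be checked numerically on a toy problem. This is a sketch under assumptions chosen for illustration: a 1D loss $\mathcal{L}(\theta) = \theta^4/4$, a single GD step, and RK4 with many small substeps standing in for the exact flow (`rk4_flow` and the other names are hypothetical). Backward error analysis says one GD step of size $\varepsilon$ lands $O(\varepsilon^3)$ away from the modified flow at time $\varepsilon$, but $O(\varepsilon^2)$ away from the original flow.

```python
def rk4_flow(vector_field, theta0, t_end, n_steps=1000):
    """Integrate dθ/dt = vector_field(θ) from t = 0 to t_end with classic RK4."""
    dt = t_end / n_steps
    theta = theta0
    for _ in range(n_steps):
        k1 = vector_field(theta)
        k2 = vector_field(theta + 0.5 * dt * k1)
        k3 = vector_field(theta + 0.5 * dt * k2)
        k4 = vector_field(theta + dt * k3)
        theta += dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
    return theta


lr = 0.1        # ε: the GD learning rate, also the flow time matched by one GD step
theta0 = 1.0


def grad_loss(theta):
    # L(θ) = θ^4 / 4  ->  ∇L = θ^3
    return theta ** 3


def grad_modified_loss(theta):
    # L~(θ) = L(θ) + (ε/4) (∇L)^2 = θ^4/4 + (ε/4) θ^6  ->  ∇L~ = θ^3 + 1.5 ε θ^5
    return theta ** 3 + 1.5 * lr * theta ** 5


gd_step = theta0 - lr * grad_loss(theta0)
flow_original = rk4_flow(lambda th: -grad_loss(th), theta0, lr)
flow_modified = rk4_flow(lambda th: -grad_modified_loss(th), theta0, lr)

err_original = abs(gd_step - flow_original)
err_modified = abs(gd_step - flow_modified)
print(err_original, err_modified)  # the modified flow is noticeably closer to the GD step
```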
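The decomposition of the SGD modified loss into the GD modified loss plus a minibatch-variance term is just the identity $\frac{1}{m}\sum_k \|g_k\|^2 = \|\bar g\|^2 + \frac{1}{m}\sum_k \|g_k - \bar g\|^2$, where $\bar g$ is the mean of the minibatch gradients $g_k$. A quick numerical check, assuming equal-sized minibatches and a mean loss so that $\nabla\mathcal{L}$ really is that mean; the gradients are made-up random vectors and the function name is hypothetical:

```python
import random


def sq_norm(v):
    return sum(x * x for x in v)


def sgd_regularizer_terms(minibatch_grads):
    """Split (1/m) Σ_k ||∇L_k||^2 into ||∇L||^2 + (1/m) Σ_k ||∇L_k - ∇L||^2.

    Assumes the full-batch gradient ∇L is the plain average of the m minibatch
    gradients (equal-sized minibatches partitioning the data, mean loss).
    """
    m = len(minibatch_grads)
    dim = len(minibatch_grads[0])
    full_grad = [sum(g[i] for g in minibatch_grads) / m for i in range(dim)]
    mean_sq_norm = sum(sq_norm(g) for g in minibatch_grads) / m
    variance = sum(sq_norm([gi - fi for gi, fi in zip(g, full_grad)])
                   for g in minibatch_grads) / m
    return mean_sq_norm, sq_norm(full_grad), variance


rng = random.Random(0)
fake_grads = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(8)]  # m = 8 made-up minibatch gradients
mean_sq, full_sq, var = sgd_regularizer_terms(fake_grads)

lr = 0.1
sgd_penalty = lr / 4 * mean_sq                       # L~_SGD - L
gd_penalty_plus_variance = lr / 4 * (full_sq + var)  # (L~_GD - L) + variance term
print(sgd_penalty, gd_penalty_plus_variance)         # equal up to floating-point error
```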
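The Bayesian perspective leans on a Laplace approximation of the evidence. Assuming an isotropic Gaussian prior with precision $\lambda$ and a cost $C(\theta)$ (negative log-likelihood plus $\frac{\lambda}{2}\|\theta\|^2$) whose Hessian at the minimum $\theta^*$ has eigenvalues $\lambda_i$, the standard Laplace approximation gives $\log P(y \mid x) \approx -C(\theta^*) - \frac{1}{2}\sum_i \log(\lambda_i/\lambda)$. So at equal cost, a sharper minimum (larger $\lambda_i$) has lower evidence. A sketch with made-up numbers (`laplace_log_evidence` is a hypothetical name):

```python
import math


def laplace_log_evidence(cost_at_min, hessian_eigs, prior_precision):
    """Laplace approximation: log P(y|x) ≈ -C(θ*) - 0.5 * Σ_i log(λ_i / λ).

    cost_at_min:     C(θ*) = negative log-likelihood + (λ/2)||θ*||^2 at the minimum
    hessian_eigs:    eigenvalues λ_i of the Hessian of C at θ* (each ≥ λ)
    prior_precision: λ, the precision of the isotropic Gaussian prior
    """
    return -cost_at_min - 0.5 * sum(math.log(l / prior_precision) for l in hessian_eigs)


# Two hypothetical minima with the same cost but different curvature.
flat = laplace_log_evidence(cost_at_min=1.0, hessian_eigs=[1.0, 2.0], prior_precision=1.0)
sharp = laplace_log_evidence(cost_at_min=1.0, hessian_eigs=[10.0, 20.0], prior_precision=1.0)
print(flat, sharp)  # the flatter minimum has the larger (less negative) log evidence
```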
#### Sharpness:
- [Sharp Minima Can Generalize For Deep Nets](https://arxiv.org/abs/1703.04933)
- Standard measures of sharpness are *not* invariant to reparametrization (e.g. BatchNorm, WeightNorm, etc. will affect the loss landscape), meaning we can take minima with high test accuracy and make them arbitrarily sharp according to standard measures of sharpness (e.g. trace or leading eigenvalue of the Hessian).
- The non-identifiability of most modern NN architectures (ReLU networks) induced by symmetries allows one to alter the flatness of a minimum without affecting the function it represents.
- The core of their argument relies on the non-negative homogeneity property, defined as: a function $\phi$ is non-negative homogeneous if $$\forall(z, \alpha) \in \mathbb{R} \times \mathbb{R}^+, \ \phi(\alpha z) = \alpha \phi(z)$$ Since ReLU activation functions have this property, given the weight matrices $\theta_1$ and $\theta_2$ of a two-layer network, input data $x$, and scalar $\alpha >0$, $$\phi_{\text{ReLU}}(x \cdot (\alpha \theta_1)) \cdot \theta_2 = \phi_{\text{ReLU}}(x \cdot \theta_1) \cdot (\alpha \theta_2) $$ Thus $(\alpha \theta_1, \theta_2)$ is observationally equivalent to $(\theta_1, \alpha\theta_2)$; equivalently, $(\alpha\theta_1, \alpha^{-1}\theta_2)$ computes the same function as $(\theta_1, \theta_2)$. This phenomenon also holds for convolutional layers. The rest of the paper plays with rescaling the weight matrices of different layers in ways that leave the network's output unchanged but give arbitrary sharpness in weight space.
- [Sharpness Aware Minimization For Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412)
- Derives a PAC-Bayes upper bound on generalization error and designs an optimization procedure based on that bound to explicitly find uniform regions of low loss and low curvature by optimizing both loss and loss sharpness.
- Given a training set $\mathcal{S}$ generated from distribution $\mathcal{D}$, loss function $\mathcal{L}$, and network parameters $\theta$, for any $\rho > 0$, with high probability over the draw of $\mathcal{S}$, $$\mathcal{L}_{\mathcal{D}}(\theta) \leq \max_{||\epsilon||_2 \leq \rho} \mathcal{L}_{\mathcal{S}}(\theta + \epsilon) + h(||\theta||^2_2/\rho^2)$$ where $h: \mathbb{R}^+\rightarrow\mathbb{R}^+$ is a strictly increasing function. $\rho$ is interpreted as the size of the neighborhood of uniformly low loss that one is seeking. Here the paper defines sharpness as $$ \max_{||\epsilon||_2 \leq \rho} \mathcal{L}_{\mathcal{S}}(\theta + \epsilon) - \mathcal{L}_\mathcal{S}(\theta) $$ which measures how quickly the training loss can increase when moving from $\theta$ to a nearby parameter value.
- Since $h$ is strictly increasing, you can replace it with a standard L2 regularization term to yield the "Sharpness Aware Minimization (SAM)" problem: $$ \min_\theta \max_{||\epsilon||_2 \leq \rho} \mathcal{L}_{\mathcal{S}}(\theta + \epsilon) + \lambda ||\theta||_2^2$$ with $\lambda$ controlling the degree of regularization. This min-max problem is solved iteratively: first approximate the inner maximization, then substitute the maximizing $\epsilon$ into the outer minimization, so that the gradient is computed at an adversarially perturbed parameter point. This two-step procedure for $t=0, 1, 2, \dots$ looks like $$ \begin{align*} \epsilon_t &= \rho \frac{\nabla \mathcal{L}_\mathcal{S}(\theta_t)}{|| \nabla \mathcal{L}_\mathcal{S}(\theta_t)||_2} \\ \theta_{t+1} &= \theta_t - \alpha_t(\nabla \mathcal{L}_\mathcal{S}(\theta_t + \epsilon_t) + 2\lambda\theta_t) \end{align*} $$ with $\alpha_t$ being an appropriately scheduled learning rate and $2\lambda\theta_t$ the gradient of the L2 term. See the original paper for details involving a first-order Taylor series approximation and a classical dual norm problem formulation to end up at this procedure (the gradient of $\mathcal{L}_\mathcal{S}(\theta + \epsilon(\theta))$ is also simplified by dropping the second-order terms that come from $\epsilon$'s dependence on $\theta$). Note that solving the maximization problem for $\epsilon_t$ boils down to rescaling the gradient so that its norm is $\rho$.
- A nice figure from the paper shows graphically what the SAM parameter updates look like

- [ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks](https://arxiv.org/abs/2102.11600)
- Tries to address the issues identified in the "Sharp Minima Can Generalize For Deep Nets" paper by defining a notion of "adaptive sharpness" that is scale-invariant.
- They begin by defining normalization operators for weights $\theta$: $\{T_\theta, \theta \in \mathbb{R}^k\}$ is a family of invertible linear operators, with $k$ the dimensionality of weight space, satisfying $T_{A\theta}^{-1}A = T_{\theta}^{-1}$ for any invertible scaling operator $A$ that does not change the loss function. $T^{-1}_{\theta}$ is then called a normalization operator of $\theta$.
- Adaptive sharpness is thus defined as $$\max_{|| T_{\theta}^{-1} \epsilon||_2 \leq \rho} \mathcal{L}_{\mathcal{S}}(\theta + \epsilon) - \mathcal{L}_{\mathcal{S}}(\theta) $$ Notice that this is identical to the SAM definition except for the normalization operator applied to $\epsilon$. Because $T_{A\theta}^{-1}A = T_{\theta}^{-1}$, rescaling $\theta \to A\theta$ together with $\epsilon \to A\epsilon$ leaves both the constraint $||T_\theta^{-1}\epsilon||_2 \leq \rho$ and the loss values unchanged, which is what makes adaptive sharpness scale-invariant. This definition seems to have become the de facto measure of sharpness in most later works due to this invariance. A popular choice of normalization operator introduced in this paper is the element-wise normalization operator $$T_{\theta} = \text{diag}(|\theta_1|, \dots, |\theta_k|), \quad \theta = [\theta_1, \dots, \theta_k]$$ I will say this definition was overall pretty unintuitive for me, and I had to work through some numerical examples before I kinda figured out how things worked.
- [A Modern Look at the Relationship between Sharpness and Generalization](https://arxiv.org/abs/2302.07011)
- Tests, in large-scale modern settings (i.e. not just CIFAR-10-sized and smaller problems), how well various sharpness measures correlate with good test performance, finding overall that the relationship is not as cut-and-dried as one would like. Sharpness does not correlate well with generalization, but rather with some training hyperparameters, like the learning rate, which can be positively or negatively correlated with generalization depending on the setup.
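The α-rescaling argument from "Sharp Minima Can Generalize For Deep Nets" shows up already in the smallest possible example. Assumptions chosen for illustration: a scalar two-layer ReLU "network" $f(x) = \theta_2\,\phi_{\text{ReLU}}(\theta_1 x)$, squared loss on the single point $(x, y) = (1, 1)$, and an analytic 2×2 Hessian valid where $\theta_1 x > 0$ (all function names are hypothetical). Both $(1, 1)$ and $(10, 0.1)$ are global minima computing the same function, but $\lambda_\text{max}$ differs by a factor of about 50.

```python
import math


def relu(z):
    return max(0.0, z)


def predict(theta1, theta2, x):
    return theta2 * relu(theta1 * x)


def hessian_max_eig(theta1, theta2, x=1.0, y=1.0):
    """λ_max of the Hessian of L = 0.5 * (θ2 relu(θ1 x) - y)^2, valid where θ1 x > 0."""
    r = theta2 * theta1 * x - y           # residual
    a = (theta2 * x) ** 2                 # ∂²L/∂θ1²
    c = (theta1 * x) ** 2                 # ∂²L/∂θ2²
    b = theta1 * theta2 * x ** 2 + r * x  # ∂²L/∂θ1∂θ2
    # Closed-form largest eigenvalue of the symmetric 2x2 matrix [[a, b], [b, c]].
    return (a + c) / 2 + math.sqrt(((a - c) / 2) ** 2 + b ** 2)


alpha = 10.0
original = (1.0, 1.0)
rescaled = (alpha * 1.0, 1.0 / alpha)  # (α θ1, α^{-1} θ2)

for x in [-2.0, 0.5, 3.0]:
    assert math.isclose(predict(*original, x), predict(*rescaled, x))  # same function

print(hessian_max_eig(*original))  # 2.0
print(hessian_max_eig(*rescaled))  # ≈ 100.01
```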
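The two-step SAM update from the SAM notes above, as a minimal sketch: plain Python lists instead of tensors, a user-supplied gradient function, and hypothetical names (`sam_step`, `rho`, `weight_decay`). The L2 term contributes $2\lambda\theta$, the gradient of $\lambda\|\theta\|_2^2$. This is not the reference implementation, which works on framework tensors and handles weight decay through the base optimizer.

```python
import math


def sam_step(theta, grad_fn, lr=0.01, rho=0.05, weight_decay=0.0):
    """One SAM update for a parameter list `theta` given a gradient function.

    1. Ascent:  ε = ρ ∇L(θ) / ||∇L(θ)||_2  (the gradient rescaled to norm ρ).
    2. Descent: θ ← θ - lr * (∇L(θ + ε) + 2 * weight_decay * θ).
    """
    g = grad_fn(theta)
    g_norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # guard against ∇L = 0
    eps = [rho * gi / g_norm for gi in g]
    g_adv = grad_fn([t + e for t, e in zip(theta, eps)])  # gradient at the perturbed point
    return [t - lr * (ga + 2 * weight_decay * t) for t, ga in zip(theta, g_adv)]


# Toy anisotropic quadratic L(θ) = 0.5 * (10 θ1^2 + θ2^2), so ∇L = (10 θ1, θ2).
def quad_grad(theta):
    return [10.0 * theta[0], theta[1]]


theta = [1.0, 1.0]
print(sam_step(theta, quad_grad))           # SAM step from the perturbed point
print(sam_step(theta, quad_grad, rho=0.0))  # rho = 0 recovers a plain GD step θ - lr ∇L(θ)
```

Note the design choice of two gradient evaluations per step: that doubling of compute is the main practical cost of SAM over plain SGD.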
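A small numerical example of why ASAM's adaptive sharpness is scale-invariant while SAM's is not. The toy setup is an assumption for illustration, not taken from the ASAM paper: a scalar two-layer ReLU "network" $f(x) = \theta_2\,\phi_{\text{ReLU}}(\theta_1 x)$ with squared loss on $(x, y) = (1, 1)$, evaluated at the observationally equivalent minima $(1, 1)$ and $(10, 0.1)$. With the element-wise operator $T_\theta = \text{diag}(|\theta_1|, |\theta_2|)$ the perturbation is $\epsilon = T_\theta u$ with $\|u\|_2 \leq \rho$. Both sharpness definitions are estimated by brute force over the boundary of the $\rho$-ball, which for this toy loss and small $\rho$ is where the maximum sits; `loss` and `sharpness` are hypothetical names.

```python
import math


def loss(theta1, theta2, x=1.0, y=1.0):
    # 0.5 * (θ2 relu(θ1 x) - y)^2 for a scalar two-layer ReLU "network"
    return 0.5 * (theta2 * max(0.0, theta1 * x) - y) ** 2


def sharpness(theta1, theta2, rho=0.1, adaptive=False, n_angles=3600):
    """Brute-force max over ||T^{-1} ε||_2 = rho of L(θ + ε) - L(θ).

    adaptive=False: SAM sharpness, T = identity.
    adaptive=True:  ASAM element-wise operator, T = diag(|θ1|, |θ2|), so ε = T u.
    """
    s1, s2 = (abs(theta1), abs(theta2)) if adaptive else (1.0, 1.0)
    base = loss(theta1, theta2)
    best = -math.inf
    for k in range(n_angles):
        angle = 2 * math.pi * k / n_angles
        u1, u2 = rho * math.cos(angle), rho * math.sin(angle)
        best = max(best, loss(theta1 + s1 * u1, theta2 + s2 * u2) - base)
    return best


for params in [(1.0, 1.0), (10.0, 0.1)]:  # same function, different weights
    print(params, sharpness(*params), sharpness(*params, adaptive=True))
# SAM sharpness is much larger at (10, 0.1); adaptive sharpness is (numerically) identical.
```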
### Potential future direction: Lyapunov exponents
[Gradient Flossing: Improving Gradient Descent through Dynamic Control of Jacobians](https://arxiv.org/abs/2312.17306)
- In order to alleviate the vanishing/exploding gradients problem when training RNNs, this paper introduces a method that explicitly regularizes Lyapunov exponents of the forward dynamics through backpropagation.
I still need to do some reading on Lyapunov exponents, but this seems like a cool example of how it's used for an adjacent problem. Maybe a way forward would be to look at different initialization schemes combined with the training trajectories of different optimizers and then do some analysis using Lyapunov exponents.
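As a starting point for that reading: the largest Lyapunov exponent of a 1D discrete-time system $h_{t+1} = F(h_t)$ is the long-run average $\lim_{T\to\infty}\frac{1}{T}\sum_t \log|F'(h_t)|$, i.e. the average exponential growth rate of infinitesimal perturbations. A minimal sketch for a hypothetical scalar tanh RNN $h_{t+1} = \tanh(w h_t + b)$ with no inputs; for high-dimensional RNNs the full spectrum is typically computed with repeated QR decompositions of Jacobian products, which this toy does not attempt.

```python
import math


def largest_lyapunov_exponent_1d(w, b, h0=0.1, steps=10_000, burn_in=1_000):
    """Average log|dh_{t+1}/dh_t| along a trajectory of h_{t+1} = tanh(w h_t + b).

    Assumes the state never saturates to exactly ±1.0 in floating point
    (otherwise the Jacobian would be 0 and log would fail).
    """
    h = h0
    for _ in range(burn_in):  # let transients die out
        h = math.tanh(w * h + b)
    total = 0.0
    for _ in range(steps):
        h_next = math.tanh(w * h + b)
        jac = w * (1.0 - h_next ** 2)  # derivative of tanh(w h + b) with respect to h
        total += math.log(abs(jac))
        h = h_next
    return total / steps


# Contracting regime: the trajectory settles at the fixed point h = 0, where the
# Jacobian is exactly w, so the exponent is log(0.5) ≈ -0.693 (nearby trajectories converge).
print(largest_lyapunov_exponent_1d(w=0.5, b=0.0))
```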
### Random fun arXiv paper that ties a lot of things together
[Training Recurrent Neural Networks by Diffusion](https://arxiv.org/abs/1601.04114)
### Other potential future directions
- Fisher-Rao metric stuff to make things invariant, maybe drawing inspiration from [Fisher-Rao Metric, Geometry, and Complexity of Neural Network](https://arxiv.org/abs/1711.01530) and [Fisher SAM: Information Geometry and Sharpness Aware Minimisation](https://arxiv.org/abs/2206.04920)
- Something with Fisher Information Matrix/natural gradients, need to look at [Universal Statistics of Fisher Information in Deep Neural Networks: Mean Field Approach](https://arxiv.org/abs/1806.01316), [Natural Gradient Works Efficiently in Learning](https://direct.mit.edu/neco/article/10/2/251/6143/Natural-Gradient-Works-Efficiently-in-Learning), and [Catastrophic Fisher Explosion: Early Phase Fisher Matrix Impacts Generalization](https://arxiv.org/abs/2012.14193)