# Notes on Natural Gradient Descent

Consider a probabilistic model represented by its likelihood $p(y\vert\theta)$. Maximizing this likelihood to find the most likely parameter $\theta$ is equivalent to minimizing the negative log-likelihood, i.e. the loss function $\mathcal{L}(\theta, \mathcal{D}) = -\sum_{y\in\mathcal{D}} \log p(y\vert \theta)$. Many common loss functions can be interpreted as the negative log-likelihood of a probabilistic model. For example, the squared loss is, up to scaling and an additive constant, the negative log-likelihood of a Normal distribution.

To tackle this optimization we use gradient descent (GD) over the parameter $\theta$. In each step, GD moves in the steepest descent direction within a small neighbourhood of the current value of $\theta$ in parameter space. If we change how we parametrise our model, we also change which direction counts as steepest. So if $\theta_1$ and $\theta_2$ are two parameter sets describing the same model, updating in the steepest descent direction around $\theta_1$ might result in a different model than updating in the steepest descent direction around $\theta_2$. This is why you saw those differences in Assignment 1.

Notice how the loss function can be thought of as a composition of two mappings:

$$
\theta \mapsto p(y\vert \theta) \mapsto \mathcal{L}(\theta, \mathcal{D})
$$

The first map is the parameter-to-hypothesis map, which maps from parameters to the space of probability distributions. The second map, the hypothesis-to-loss map, maps from the space of distributions to a scalar by evaluating the likelihood of the data. Parametrisation affects the parameter-to-hypothesis map, but the hypothesis-to-loss map remains unchanged when we change how we parametrise our models. Thus, to obtain an optimization method that is agnostic to how we parametrise our models, it makes sense to **follow the steepest descent direction in hypothesis space**, as opposed to steepest descent in the Euclidean space of parameters $\theta$.
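The parametrisation dependence of plain GD can be seen in a one-dimensional sketch. It assumes a hypothetical zero-mean Gaussian model $y \sim \mathcal{N}(0, \sigma^2)$, a hypothetical two-point data set, and an arbitrary step size and starting point. It takes one gradient step in $\sigma$ and one in $\log\sigma$ from the same starting model and maps both results back to $\sigma$. For comparison it also takes one natural-gradient step in each parametrisation, using the Fisher information defined in the next section ($2/\sigma^2$ for $\sigma$ and $2$ for $\log\sigma$, standard results for this model).

```python
import math

# Hypothetical data for a zero-mean Gaussian model y ~ N(0, sigma^2).
ys = [-1.0, 1.0]
S = sum(y * y for y in ys) / len(ys)  # mean of y^2

def grad_wrt_sigma(sigma):
    # d/d(sigma) of the mean NLL  log(sigma) + S / (2 sigma^2) + const
    return 1.0 / sigma - S / sigma**3

def grad_wrt_log_sigma(s):
    # Same loss written in s = log(sigma):  s + S * exp(-2 s) / 2 + const
    return 1.0 - S * math.exp(-2.0 * s)

alpha, sigma0 = 0.1, 2.0  # arbitrary step size and starting point
s0 = math.log(sigma0)

# One plain GD step in each parametrisation, mapped back to sigma.
gd_sigma = sigma0 - alpha * grad_wrt_sigma(sigma0)
gd_log = math.exp(s0 - alpha * grad_wrt_log_sigma(s0))

# One natural-gradient step: divide by the Fisher information of each
# parametrisation (2 / sigma^2 for sigma, 2 for log sigma).
ngd_sigma = sigma0 - alpha * grad_wrt_sigma(sigma0) / (2.0 / sigma0**2)
ngd_log = math.exp(s0 - alpha * grad_wrt_log_sigma(s0) / 2.0)

print(f"GD : sigma={gd_sigma:.4f}  vs  log-sigma={gd_log:.4f}")
print(f"NGD: sigma={ngd_sigma:.4f}  vs  log-sigma={ngd_log:.4f}")
```

The two GD steps land on noticeably different models, while the two natural-gradient steps agree closely. They do not agree exactly, because the invariance only holds to first order in the step size.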
By hypothesis space we now mean the set of probability distributions $p(\cdot\vert \theta)$ realisable by the parameters. Notice also that while the parameter-to-hypothesis map can be ugly and non-convex, the hypothesis-to-loss map is often nicely behaved and convex (e.g. for the squared loss). We can therefore hope that an algorithm that directly optimises this map would also converge faster and have other nice properties.

But what does steepest descent mean in hypothesis space? To talk about *steep*, we need to be able to measure distances, so we need to endow the hypothesis space with a Riemannian metric. Since our hypotheses are probability distributions, we can use a metric for probability distributions: the Fisher-Rao metric. This metric is closely related to the KL divergence between probability distributions. The Fisher-Rao metric on the distribution space parametrised by $\theta$ can be described in terms of the Fisher Information Matrix $F(\theta)$, defined as

$$
F(\theta) = \mathop{\mathbb{E}}_{p(y \vert \theta)} \left[ \nabla_\theta \log p(y \vert \theta) \, \nabla_\theta \log p(y \vert \theta)^{\text{T}} \right].
$$

Notice that we interpret $\nabla_\theta \log p(y \vert \theta)$ as a column vector, so the Fisher information matrix is indeed a matrix of size $d \times d$, where $d$ is the number of parameters in $\theta$. This implies that in the vicinity of any parameter $\theta$ we should measure distances not by the Euclidean distance $\|\theta' - \theta\|$ but by the Mahalanobis distance $\sqrt{(\theta' - \theta)^{\text{T}}F(\theta)(\theta' - \theta)}$. When $\theta'$ is close to $\theta$, half the square of this distance approximates the KL divergence between $p(y\vert \theta)$ and $p(y\vert \theta')$:

$$
\mathrm{KL}\left(p(y\vert\theta)\,\|\,p(y\vert\theta')\right) \approx \tfrac{1}{2}(\theta' - \theta)^{\text{T}}F(\theta)(\theta' - \theta).
$$

Using this new, "parametrisation-independent" notion of distance, the Natural Gradient Descent update becomes

$$
\theta_{i+1} = \theta_i - \alpha F^{-1}(\theta_i)\nabla_{\theta} \mathcal{L}(\theta_i).
$$

If our learning rate is sufficiently small (and with a few technical caveats we will ignore here), this algorithm can be interpreted as a version of gradient descent that directly minimises the hypothesis-to-loss map in distribution space.

## What about supervised learning?

In most applications of ML our parameters define not a single probabilistic model $p(y\vert \theta)$, but a conditional distribution $p(y \vert x; \theta)$. We can consider this as a family of probabilistic models, one for each value of $x$, each of which has its own Fisher-Rao metric and Fisher information matrix $F_x(\theta)$:

$$
F_x(\theta) = \mathop{\mathbb{E}}_{p(y \vert x; \theta)} \left[ \nabla_\theta \log p(y \vert x; \theta) \, \nabla_\theta \log p(y \vert x; \theta)^{\text{T}} \right].
$$

We can turn this into a single metric by averaging over values of $x$. For example, we can average over the inputs in the training data $\mathcal{D}_{train}$:

$$
F_{train}(\theta) = \frac{1}{\vert\mathcal{D}_{train}\vert}\sum_{x\in \mathcal{D}_{train}}F_x(\theta)
$$

Or we can average over the test set or a separate held-out validation set:

$$
F_{test}(\theta) = \frac{1}{\vert\mathcal{D}_{test}\vert}\sum_{x\in \mathcal{D}_{test}}F_x(\theta)
$$

In the Assignments we have considered models that define functions $f(x; \theta)$. Since we optimised the squared loss, one can re-interpret these functions as defining the mean of a Gaussian distribution with standard deviation $\sigma_n$, such that $p(y\vert x; \theta) = \mathcal{N}(\mu=f(x;\theta), \sigma_n)$. It's easy to check that the negative log-likelihood of this model is the squared loss scaled by $1/(2\sigma_n^2)$, plus a constant that does not depend on $\theta$; with $\sigma_n^2 = 1/2$ it is exactly the squared loss plus that constant.
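The "easy to check" claim can be verified numerically. The following is a minimal sketch with hypothetical predictions and labels standing in for $f(x;\theta)$ and $y$; `gaussian_nll` and `squared_loss` are illustrative helper names, not part of any library.

```python
import math

def gaussian_nll(preds, targets, sigma_n):
    # Sum over data points of -log N(t | mean=p, std=sigma_n).
    return sum(
        0.5 * math.log(2 * math.pi * sigma_n**2) + (t - p) ** 2 / (2 * sigma_n**2)
        for p, t in zip(preds, targets)
    )

def squared_loss(preds, targets):
    return sum((t - p) ** 2 for p, t in zip(preds, targets))

preds = [0.5, -1.0, 2.0]    # hypothetical model outputs f(x; theta)
targets = [1.0, -1.5, 2.5]  # hypothetical labels y

for sigma_n in (1.0 / math.sqrt(2.0), 1.0, 3.0):
    # This constant does not depend on the predictions, hence not on theta.
    const = len(preds) * 0.5 * math.log(2 * math.pi * sigma_n**2)
    nll_minus_const = gaussian_nll(preds, targets, sigma_n) - const
    scaled_sq = squared_loss(preds, targets) / (2 * sigma_n**2)
    print(f"sigma_n={sigma_n:.3f}: NLL - const = {nll_minus_const:.4f}, "
          f"squared loss / (2 sigma_n^2) = {scaled_sq:.4f}")
```

Because the scale factor is positive and the constant does not depend on $\theta$, every choice of $\sigma_n$ gives the same minimiser as the squared loss; only the effective learning rate changes.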
To calculate the Fisher information matrix $F_x$:

* try it yourself without looking at the hints below;
* consider the following hints: plug the Gaussian pdf into the definition of $F_x$, apply the chain rule for gradients, use the linearity of expectation, and use what you know about covariance matrices;
* if it doesn't work, you can look this up on the internet.

### Invertibility of $F(\theta)$ and advice on algorithms

What is relatively important, though, is that the Fisher information matrix should ideally be invertible. If you look at the definition of $F_x$, you will notice that it is low-rank, which means that we need to average over a large number of values of $x$ to obtain an invertible matrix. To overcome this issue, the NGD algorithm is often modified by adding a small damping term $\epsilon I$:

$$
\theta_{i+1} = \theta_i - \alpha (F(\theta_i) + \epsilon I)^{-1}\nabla_{\theta} \mathcal{L}(\theta_i).
$$

In practice, explicitly inverting the matrix is almost never necessary. If one wants to calculate $A^{-1}B$, the right numerical thing to do is almost always to solve the linear system $AX = B$ instead. Notice also that since $F$ is symmetric and positive semidefinite, $F + \epsilon I$ is symmetric positive definite for any $\epsilon > 0$, so the natural gradient can also be approximated using [conjugate gradients](https://en.wikipedia.org/wiki/Conjugate_gradient_method#:~:text=In%20mathematics%2C%20the%20conjugate%20gradient,whose%20matrix%20is%20positive%2Ddefinite), which only needs matrix-vector products with $F$. In machine learning, such an algorithm is used in [Hessian-free optimisation](https://www.cs.toronto.edu/~jmartens/docs/Deep_HessianFree.pdf).

### Other things related to Fisher information

Often the likelihood function is complicated and computing the expectation is intractable, so in practice the Empirical Fisher is often used instead:

\begin{align}
\tilde{F}(\theta)&= \frac{1}{N} \sum_{n=1}^{N} \nabla \log p(y_n \vert x_n; \theta) \, \nabla \log p(y_n \vert x_n; \theta)^{\text{T}} \, .
\end{align}

Notice the difference: the expectation over the model's distribution of $y$ is replaced by the observed labels $y_n$ from the data itself.
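The rank issue, the damped solve and the empirical Fisher can be illustrated numerically. This sketch deliberately uses a logistic-regression (Bernoulli) model rather than the Gaussian model from the exercise, synthetic data, and hypothetical helper names (`sigmoid`, `score`, `fisher_x`, `empirical_fisher`). The expectation over $y$ in $F_x$ is computed exactly by summing over the two possible labels, straight from the definition.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def score(theta, x, y):
    # Gradient of log p(y | x; theta) for a Bernoulli model with p(y=1) = sigmoid(x . theta).
    return (y - sigmoid(x @ theta)) * x

def fisher_x(theta, x):
    # F_x = E_{y ~ p(y|x;theta)}[score score^T], summing exactly over y in {0, 1}.
    p1 = sigmoid(x @ theta)
    F = np.zeros((x.size, x.size))
    for y, prob in ((1.0, p1), (0.0, 1.0 - p1)):
        g = score(theta, x, y)
        F += prob * np.outer(g, g)
    return F

def empirical_fisher(theta, X, Y):
    # Uses the observed labels y_n instead of an expectation under the model.
    G = np.stack([score(theta, x, y) for x, y in zip(X, Y)])
    return G.T @ G / len(X)

rng = np.random.default_rng(0)
d, n = 4, 50
theta = rng.normal(size=d)
X = rng.normal(size=(n, d))
Y = (rng.uniform(size=n) < sigmoid(X @ theta)).astype(float)

F_single = fisher_x(theta, X[0])                          # Fisher for one input x
F_train = np.mean([fisher_x(theta, x) for x in X], axis=0)  # averaged over training inputs
F_emp = empirical_fisher(theta, X, Y)                      # also uses the labels

# Every score vector for a single x is a multiple of x, so F_x has rank one;
# averaging over many inputs gives a full-rank d x d matrix.
print(np.linalg.matrix_rank(F_single), np.linalg.matrix_rank(F_train))

# Damped natural-gradient direction: solve (F + eps I) v = grad instead of inverting.
grad = -np.mean([score(theta, x, y) for x, y in zip(X, Y)], axis=0)  # gradient of mean NLL
eps = 1e-3
v = np.linalg.solve(F_train + eps * np.eye(d), grad)
```

The same `np.linalg.solve` call works with `F_emp` in place of `F_train` if the empirical Fisher is used as the approximation.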
While the Fisher information uses only the inputs from the training or test data (to average over $x$), the empirical Fisher also uses the real labels. There are other related matrices, such as the generalised Gauss-Newton matrix, which are also sometimes used in optimisation methods. Good summaries of these connections are found in [(Martens, 2020)](https://jmlr.org/papers/volume21/17-678/17-678.pdf) and [(Kunstner et al., 2019)](https://arxiv.org/abs/1905.12558).

In summary, the NGD algorithm is:

Repeat:

1. Do a forward pass on the model and compute the loss $\mathcal{L}(\theta_i)$.
2. Compute the gradient $\nabla_{\theta} \mathcal{L}(\theta_i)$.
3. Compute the Fisher information matrix $F(\theta_i)$ or an approximation thereof (you may skip this step if you have a way of performing step 4 without explicitly calculating $F$ first).
4. Compute or solve for the natural gradient $\tilde{\nabla}_{\theta} \mathcal{L}(\theta_i) = F^{-1}(\theta_i)\nabla_{\theta} \mathcal{L}(\theta_i)$.
5. Update the parameters: $\theta_{i+1} = \theta_i - \alpha \tilde{\nabla}_{\theta} \mathcal{L}(\theta_i)$, where $\alpha$ is the learning rate.

Until convergence.

For more detail and further reading, see this [great blog](https://agustinus.kristia.de/techblog/2018/03/14/natural-gradient/).
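The five-step loop summarised above can be sketched for a logistic-regression model, an assumption made only for illustration. Step 3 is skipped as allowed: the damped Fisher $F_{train} + \epsilon I$ is never formed. Instead, step 4 is solved with conjugate gradients using Fisher-vector products, in the spirit of the Hessian-free remark earlier. The products rely on the standard result $F_x = s(1-s)\,x x^{\text{T}}$ with $s = \text{sigmoid}(x^{\text{T}}\theta)$ for this model. All function names, the synthetic data, and the values of $\alpha$, $\epsilon$ and the step count are hypothetical choices.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mean_nll(theta, X, Y):
    # Mean of log(1 + e^z) - y z, computed stably with logaddexp.
    z = X @ theta
    return np.mean(np.logaddexp(0.0, z) - Y * z)

def grad_mean_nll(theta, X, Y):
    return X.T @ (sigmoid(X @ theta) - Y) / len(Y)

def damped_fisher_vec(theta, X, v, eps):
    # (F_train + eps I) v without building F, using F_x = s (1 - s) x x^T.
    s = sigmoid(X @ theta)
    return X.T @ (s * (1 - s) * (X @ v)) / len(X) + eps * v

def conjugate_gradient(matvec, b, iters=50, tol=1e-10):
    # Solves A x = b for symmetric positive-definite A using only products A @ p.
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        step = rs / (p @ Ap)
        x += step * p
        r -= step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

def ngd(X, Y, alpha=1.0, eps=1e-4, steps=20):
    theta = np.zeros(X.shape[1])
    losses = []
    for _ in range(steps):
        losses.append(mean_nll(theta, X, Y))                  # 1. forward pass and loss
        g = grad_mean_nll(theta, X, Y)                        # 2. gradient
        nat_g = conjugate_gradient(                           # 3-4. solve (F + eps I) v = g
            lambda v: damped_fisher_vec(theta, X, v, eps), g)
        theta = theta - alpha * nat_g                         # 5. update
    return theta, losses

# Usage on synthetic data drawn from a logistic model.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3))
Y = (rng.uniform(size=200) < sigmoid(X @ np.array([1.0, -2.0, 0.5]))).astype(float)
theta_hat, losses = ngd(X, Y)
print(theta_hat, losses[0], losses[-1])
```

For logistic regression the Fisher coincides with the Hessian of the mean negative log-likelihood. With $\alpha = 1$ and small $\epsilon$, this loop is therefore essentially Newton's method. For other models, such as deep networks, the two matrices differ.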