These are my notes on the following paper: https://arxiv.org/abs/2104.04874, which analyses the inductive bias of SGD relative to full-batch GD using Taylor expansions. To make the paper easier to understand, I present the key ideas using modern machine learning notation.
The generalisation gap, for the purposes of this paper, is the difference between a model's test loss and training loss. In this work we look at how this generalisation gap changes in a single step of gradient descent on the training data, starting from a fixed parameter $\theta$. Since the training and test data are random, the change in the generalisation gap is also random; we will be interested in the average of this change when we average over realisations of the training and test data.
Let $D$ and $D'$ be a random training and a random test set, each drawn i.i.d. from the same underlying distribution $\mathcal{P}$ over datapoints $x$.
Let $L_S(\theta) = \frac{1}{|S|} \sum_{x \in S} \ell(\theta, x)$ be the empirical loss evaluated on a set of data $S$.
Let $g_S(\theta) = \nabla_\theta L_S(\theta)$ be the gradient evaluated on a set of data $S$.
Let's say we take a gradient step on $D$ with learning rate $\eta$, i.e. $\theta' = \theta - \eta\, g_D(\theta)$. We can approximate how the training loss changes via a first-order Taylor expansion of $L_D$ around $\theta$:

$$
L_D(\theta') \approx L_D(\theta) - \eta\, \|g_D(\theta)\|^2 + \mathcal{O}(\eta^2).
$$
Similarly, we can express the change in the test loss as:

$$
L_{D'}(\theta') \approx L_{D'}(\theta) - \eta\, \langle g_{D'}(\theta), g_D(\theta)\rangle + \mathcal{O}(\eta^2).
$$
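To make the first-order approximation concrete, here is a minimal numpy sketch (my own, not from the paper) that compares the actual change in training and test loss after one gradient step with the Taylor predictions above, on a toy linear regression problem; the model, data distribution and all function names are assumptions made purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear regression: per-example loss is 0.5 * (x @ theta - y)^2.
d, N, eta = 5, 100, 0.01
theta_true = rng.normal(size=d)
theta = rng.normal(size=d)          # fixed starting parameter

def make_dataset(n):
    X = rng.normal(size=(n, d))
    y = X @ theta_true + 0.1 * rng.normal(size=n)
    return X, y

def loss(theta, data):
    X, y = data
    return 0.5 * np.mean((X @ theta - y) ** 2)

def grad(theta, data):
    X, y = data
    return X.T @ (X @ theta - y) / len(y)

D, D_test = make_dataset(N), make_dataset(N)

g_D = grad(theta, D)
theta_new = theta - eta * g_D       # one full-batch gradient step on the training set

# Actual loss changes vs. the first-order Taylor predictions above.
print("train: actual", loss(theta_new, D) - loss(theta, D),
      "  predicted", -eta * g_D @ g_D)
print("test:  actual", loss(theta_new, D_test) - loss(theta, D_test),
      "  predicted", -eta * grad(theta, D_test) @ g_D)
```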
To move forward, we will want to average the above quantities over random realisations of the training and test datasets. Let's consider a dataset of size $N$, $D = \{x_1, \ldots, x_N\}$, where each datapoint $x_n$ is drawn independently from the same distribution $\mathcal{P}$. Let's say we have two index sets $I, J \subseteq \{1, \ldots, N\}$, and corresponding subsets $D_I$ and $D_J$. We now consider the following expectation:

$$
\begin{aligned}
\mathbb{E}\,\langle g_{D_I}(\theta), g_{D_J}(\theta)\rangle
&= \frac{1}{|I||J|} \sum_{i \in I} \sum_{j \in J} \mathbb{E}\,\langle g_i, g_j\rangle \\
&= \frac{1}{|I||J|} \Bigg( \sum_{i \in I \cap J} \mathbb{E}\,\|g_i\|^2 + \sum_{\substack{i \in I,\ j \in J \\ i \neq j}} \langle \mathbb{E}\, g_i, \mathbb{E}\, g_j\rangle \Bigg) \\
&= \frac{|I \cap J|}{|I||J|}\big(\|\mu\|^2 + \operatorname{tr}\Sigma\big) + \frac{|I||J| - |I \cap J|}{|I||J|}\,\|\mu\|^2 \\
&= \|\mu\|^2 + \frac{|I \cap J|}{|I||J|}\,\operatorname{tr}\Sigma,
\end{aligned}
$$

where $g_i$ is shorthand for the single-datapoint gradient $g_{\{x_i\}}(\theta)$, and $\mu$ and $\Sigma$ denote its mean and covariance.
Here we used the fact that the expectation of the product of independent random variables is the product of their expectations on the second term in the second line, and the following identity for a vector-valued random variable $g$ with mean $\mu$ and covariance $\Sigma$, which we prove below using the cyclic property of the trace:

$$
\mathbb{E}\,\|g\|^2 = \mathbb{E}\,\operatorname{tr}\big(g g^\top\big) = \operatorname{tr}\big(\mathbb{E}\, g g^\top\big) = \operatorname{tr}\big(\Sigma + \mu\mu^\top\big) = \operatorname{tr}\Sigma + \|\mu\|^2.
$$
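The identity and the overlap formula above can be sanity-checked numerically. The sketch below (again mine, not the paper's) models the per-datapoint gradients directly as i.i.d. Gaussian vectors with a known mean and covariance, and estimates $\mathbb{E}\,\langle g_{D_I}, g_{D_J}\rangle$ by Monte Carlo; the Gaussian gradient model is an assumption made only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

# Model the per-datapoint gradient directly as an i.i.d. random vector with
# known mean mu and covariance Sigma (a stand-in for real gradients).
d = 4
mu = rng.normal(size=d)
A = rng.normal(size=(d, d))
Sigma = A @ A.T

N = 10
I = np.arange(0, 6)                  # index set I = {0, ..., 5}
J = np.arange(3, 10)                 # index set J = {3, ..., 9}
overlap = len(np.intersect1d(I, J))  # |I ∩ J| = 3

trials = 100_000
g = rng.multivariate_normal(mu, Sigma, size=(trials, N))  # gradients for many "datasets"
g_I = g[:, I].mean(axis=1)           # minibatch gradient over I
g_J = g[:, J].mean(axis=1)           # minibatch gradient over J
monte_carlo = np.mean(np.sum(g_I * g_J, axis=1))

closed_form = mu @ mu + overlap / (len(I) * len(J)) * np.trace(Sigma)
print("Monte Carlo:", monte_carlo, "  closed form:", closed_form)
```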
Using this, we can now express the expected change in the training and test losses, under the assumption that the test set $D'$ is drawn independently of the training set $D$. For the training loss we take $I = J = \{1, \ldots, N\}$, so the overlap is $N$; for the test loss, the data used to compute the gradient and the data used to evaluate the loss do not overlap at all:

$$
\begin{aligned}
\mathbb{E}\big[L_D(\theta') - L_D(\theta)\big] &\approx -\eta \left( \|\mu\|^2 + \frac{\operatorname{tr}\Sigma}{N} \right),\\
\mathbb{E}\big[L_{D'}(\theta') - L_{D'}(\theta)\big] &\approx -\eta\, \|\mu\|^2.
\end{aligned}
$$
Thus, the training error is reduced more than the test error, in expectation. This is what the authors mean when they write that 'GD develops a “bias” toward overfitting', because in expectation, one gradient step increases the gap between training and test losses.
This is the first main finding of the paper, equation (1), just written differently: in expectation, a single gradient step increases the generalisation gap by

$$
\mathbb{E}\Big[\big(L_{D'}(\theta') - L_D(\theta')\big) - \big(L_{D'}(\theta) - L_D(\theta)\big)\Big] \approx \frac{\eta\,\operatorname{tr}\Sigma}{N},
$$

where $\Sigma$ is the covariance matrix of gradients on a single datapoint.
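Here is a rough numerical illustration of this finding (my own sketch, not an experiment from the paper): for a toy linear regression problem with a fixed $\theta$ chosen independently of the data, we average the change in the generalisation gap after one full-batch gradient step over many fresh draws of the training and test sets, and compare it to $\eta\,\operatorname{tr}\Sigma / N$. The toy model and all helper names (`make_dataset`, `per_example_grads`, etc.) are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy setup: linear regression with squared loss, fixed theta chosen
# independently of the data, fresh train and test sets on every trial.
d, N, eta = 5, 50, 0.01
theta_true = rng.normal(size=d)
theta = rng.normal(size=d)

def make_dataset(n):
    X = rng.normal(size=(n, d))
    y = X @ theta_true + 0.5 * rng.normal(size=n)
    return X, y

def loss(theta, data):
    X, y = data
    return 0.5 * np.mean((X @ theta - y) ** 2)

def grad(theta, data):
    X, y = data
    return X.T @ (X @ theta - y) / len(y)

def per_example_grads(theta, data):
    X, y = data
    return X * (X @ theta - y)[:, None]      # one gradient per datapoint (rows)

gap_changes, trace_sigma = [], []
for _ in range(20_000):
    D, D_test = make_dataset(N), make_dataset(N)
    theta_new = theta - eta * grad(theta, D)  # one full-batch step on D
    gap_before = loss(theta, D_test) - loss(theta, D)
    gap_after = loss(theta_new, D_test) - loss(theta_new, D)
    gap_changes.append(gap_after - gap_before)
    trace_sigma.append(per_example_grads(theta, D).var(axis=0).sum())

# The two numbers should roughly agree, up to O(eta^2) terms and Monte Carlo noise.
print("mean increase in generalisation gap:", np.mean(gap_changes))
print("eta * tr(Sigma) / N:               ", eta * np.mean(trace_sigma) / N)
```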
This also provides a nice, intuitive explanation of how the local total variance of gradients, $\operatorname{tr}\Sigma$, is related to generalisation. If the variance of per-datapoint gradients is negligible, or at least small, then running gradient descent is a 'safe' way to learn, because in this case the gradient one would compute on the training set is nearly the same as the one one would compute on the test set.
A key limitation of this analysis is that it assumes $\theta$ is chosen independently of $D$. So long as we never recycle data in subsequent training steps, this would be true in each step of gradient descent. In practice, however, we cycle through the training data multiple times. If $\theta$ in our equations is some intermediate state, itself found by running gradient descent on data that overlaps with the training data we currently evaluate on, the above analysis is flawed (because it treats $\theta$ as fixed and not as a function of $D$). In this situation it can no longer be guaranteed that the distribution of gradients is the same on the training and test sets, and the whole argument fails.
The second finding of this paper contrasts stochastic and full-batch gradient descent.