# Stochastic Optimization
### Ferenc Huszár (fh277)
DeepNN Lecture 4
---
## Empirical Risk Minimization via gradient descent
$$
\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D})
$$
Calculating the gradient:
* takes time: every step cycles through the whole dataset
* is constrained by limited GPU memory
* is wasteful: $\hat{L}$ is a sum over datapoints, so by the CLT a small random subsample already estimates it well
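For reference, a minimal sketch of the full-batch update; the `grad_fn` helper, data and hyperparameters are illustrative, not from the lecture:

```python
# Full-batch gradient descent sketch.
# `grad_fn(w, X, y)` is an assumed helper returning the gradient of the
# empirical risk over (X, y); w, X, y are NumPy arrays.
def full_batch_gd(w, X, y, grad_fn, eta=0.1, steps=100):
    for _ in range(steps):
        g = grad_fn(w, X, y)   # touches every datapoint in D at every step
        w = w - eta * g
    return w
```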
---
## Stochastic gradient descent
$$
\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t)
$$
where $\mathcal{D}_t$ is a random subset (minibatch) of $\mathcal{D}$.
Also known as minibatch-SGD.
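A corresponding minibatch-SGD sketch, assuming the same illustrative `grad_fn` and minibatches sampled uniformly at random:

```python
import numpy as np

# Minibatch SGD: each step uses a random subset D_t instead of the full D.
def minibatch_sgd(w, X, y, grad_fn, eta=0.1, batch_size=32, steps=1000):
    n = X.shape[0]
    for _ in range(steps):
        idx = np.random.choice(n, size=batch_size, replace=False)  # sample D_t
        w = w - eta * grad_fn(w, X[idx], y[idx])                   # step on the minibatch gradient
    return w
```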
---
## Does it converge?
Unbiased estimator of the loss, and hence of the gradient:
$$
\mathbb{E}[\hat{L}(\mathbf{w}, \mathcal{D}_t)] = \hat{L}(\mathbf{w}, \mathcal{D}),
\qquad
\mathbb{E}[\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t)] = \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})
$$
* empirical risk does not increase in expectation
* $\hat{L}(\mathbf{w}_t)$ is a supermartingale
* Doob's martingale convergence theorem: a.s. convergence.
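For concreteness (per-example notation introduced here, not on the slide): write $\hat{L}$ as an average of per-example losses $\ell_i$ over the $N$ datapoints and sample minibatches of size $B$ uniformly without replacement; each example then lands in $\mathcal{D}_t$ with probability $B/N$, so
$$
\small
\mathbb{E}\big[\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t)\big]
= \mathbb{E}\Big[\tfrac{1}{B}\textstyle\sum_{i \in \mathcal{D}_t} \nabla_\mathbf{w}\ell_i(\mathbf{w})\Big]
= \tfrac{1}{N}\textstyle\sum_{i=1}^{N} \nabla_\mathbf{w}\ell_i(\mathbf{w})
= \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})
$$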
---
## Does it behave the same way?
![](https://i.imgur.com/xRYHk0m.png =1200x)
---
## Improving SGD: Two key ideas
* idea 1: momentum
* **problem**:
* high variance of gradients due to stochasticity
* oscillation in narrow valleys of the loss landscape
* **solution**: maintain running average of gradients
https://distill.pub/2017/momentum/
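A minimal sketch of the running-average (momentum) idea, assuming an illustrative `grad_fn` and an iterable of minibatches:

```python
# Momentum sketch: step along an exponential running average of gradients.
# `grad_fn` and `batches` (an iterable of (X_b, y_b) minibatches) are assumed.
def sgd_momentum(w, grad_fn, batches, eta=0.01, beta=0.9):
    v = 0.0                              # running average of gradients
    for X_b, y_b in batches:
        g = grad_fn(w, X_b, y_b)         # noisy minibatch gradient
        v = beta * v + (1 - beta) * g    # exponential moving average smooths the noise
        w = w - eta * v                  # step along the smoothed direction
    return w
```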
---
## Improving SGD: Two key ideas
* idea 2: adaptive stepsizes
* **problem**:
* different parameters have gradients of very different magnitudes
* some parameters tolerate high learning rates, others don't
* **solution**: normalize by running average of gradient magnitudes
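A matching sketch of the adaptive-stepsize idea (an RMSProp-style normalization), under the same illustrative assumptions as the momentum snippet:

```python
import numpy as np

# Adaptive stepsizes: normalize each parameter's update by a running
# average of its squared gradient magnitude. `grad_fn`/`batches` are assumed.
def sgd_adaptive(w, grad_fn, batches, eta=1e-3, gamma=0.9, eps=1e-8):
    s = 0.0                                    # running average of squared gradients
    for X_b, y_b in batches:
        g = grad_fn(w, X_b, y_b)
        s = gamma * s + (1 - gamma) * g**2     # per-parameter estimate of gradient scale
        w = w - eta * g / (np.sqrt(s) + eps)   # large-gradient parameters get smaller steps
    return w
```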
---
## Adam: combines the two ideas
![](https://i.imgur.com/MpiCllk.png)
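For reference, the standard Adam update (Kingma & Ba, 2015), written as a single step so the momentum and adaptive-scaling ideas are visible side by side:

```python
import numpy as np

# One Adam update: m is the running average of gradients (momentum),
# v the running average of squared gradients (adaptive scaling); t starts at 1.
def adam_step(w, g, m, v, t, eta=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * g           # idea 1: momentum
    v = beta2 * v + (1 - beta2) * g**2        # idea 2: adaptive stepsizes
    m_hat = m / (1 - beta1**t)                # bias correction for zero initialization
    v_hat = v / (1 - beta2**t)
    w = w - eta * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v
```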
---
## How good is Adam?
optimization vs. generalization
---
## How good is Adam?
![](https://i.imgur.com/0yelxmm.png)
---
## How good is Adam?
![](https://i.imgur.com/5rAyMYc.png)
---
## Revisiting the cartoon example
![](https://i.imgur.com/xRYHk0m.png =1200x)
---
## Can we describe SGD's behaviour?
![](https://i.imgur.com/AyrSiFZ.png)
---
## Analysis of mean iterate
![](https://i.imgur.com/9j85UIv.png)
([Smith et al, 2021](https://arxiv.org/abs/2101.12176)) "On the Origin of Implicit Regularization in Stochastic Gradient Descent"
---
## Implicit regularization in SGD
$$
\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t)
$$
mean iterate in SGD:
$$
\mu_t = \mathbb{E}[\mathbf{w}_t]
$$
---
## Implicit regularization in SGD
([Smith et al, 2021](https://arxiv.org/abs/2101.12176)): mean iterate approximated as continuous gradient flow:
$$
\small
\dot{\mu}(t) = -\eta \nabla_\mathbf{w}\tilde{L}_{SGD}(\mu(t), \mathcal{D})
$$
where
$$
\small
\tilde{L}_{SGD}(\mathbf{w}, \mathcal{D}) = \tilde{L}_{GD}(\mathbf{w}, \mathcal{D}) + \frac{\eta}{4}\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2
$$
---
## Implicit regularization in SGD
([Smith et al, 2021](https://arxiv.org/abs/2101.12176)): mean iterate approximated as continuous gradient flow:
$$
\small
\dot{\mu}(t) = -\eta \nabla_\mathbf{w}\tilde{L}_{SGD}(\mu(t), \mathcal{D})
$$
where
$$
\small
\tilde{L}_{SGD}(\mathbf{w}, \mathcal{D}) = \tilde{L}_{GD}(\mathbf{w}, \mathcal{D}) + \frac{\eta}{4}\underbrace{\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2}_{\text{variance of gradients}}
$$
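One way to make the extra term concrete: a Monte-Carlo sketch that estimates $\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2$ by sampling minibatches (the `grad_fn` helper and uniform sampling are assumptions for illustration, not taken from the paper):

```python
import numpy as np

# Monte-Carlo estimate of the gradient-variance term penalized by the
# implicit regularizer: average squared distance between minibatch and
# full-batch gradients at the current w.
def grad_variance_penalty(w, X, y, grad_fn, batch_size=32, n_samples=20):
    n = X.shape[0]
    g_full = grad_fn(w, X, y)                    # full-batch gradient
    sq_norms = []
    for _ in range(n_samples):
        idx = np.random.choice(n, size=batch_size, replace=False)
        g_batch = grad_fn(w, X[idx], y[idx])     # minibatch gradient
        sq_norms.append(np.sum((g_batch - g_full) ** 2))
    return float(np.mean(sq_norms))
```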
---
## Revisiting cartoon example
![](https://i.imgur.com/xRYHk0m.png =1200x)
---
## Is Stochastic Training Necessary?
![](https://i.imgur.com/1WThiku.png)
---
## Is Stochastic Training Necessary?
![](https://i.imgur.com/dHJKsKS.png)
* reg $\approx$ flatness of the minimum
* bs32 $\approx$ variance of gradients over size-32 minibatches
---
## SGD summary
* gradient noise is a feature, not a bug
* SGD avoids regions with high gradient noise
* this may help with generalization
* improved variants of SGD, like Adam, may not always help
* an optimization algorithm can be "too good"
{"metaMigratedAt":"2023-06-15T19:20:45.570Z","metaMigratedFrom":"YAML","title":"DeepNN Lecture 4 Slides","breaks":true,"description":"Lecture slides on stochastic gradient descent, its variants like Adam, and some thoughts on generalisation","contributors":"[{\"id\":\"e558be3b-4a2d-4524-8a66-38ec9fea8715\",\"add\":10965,\"del\":6806}]"}