# Stochastic Optimization

### Ferenc Huszár (fh277)

DeepNN Lecture 4

---

## Empirical Risk Minimization via gradient descent

$$
\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D})
$$

Calculating the gradient:

* takes time to cycle through the whole dataset
* is limited by available GPU memory
* is wasteful: $\hat{L}$ is a sum, so the CLT applies

---

## Stochastic gradient descent

$$
\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t)
$$

where $\mathcal{D}_t$ is a random subset (minibatch) of $\mathcal{D}$.

Also known as minibatch SGD.

---

## Does it converge?

Unbiased gradient estimator:

$$
\mathbb{E}[\nabla_\mathbf{w} \hat{L}(\mathbf{w}, \mathcal{D}_t)] = \nabla_\mathbf{w} \hat{L}(\mathbf{w}, \mathcal{D})
$$

* the empirical risk does not increase in expectation
* $\hat{L}(\mathbf{w}_t)$ is a supermartingale
* Doob's martingale convergence theorem: a.s. convergence

---

## Does it behave the same way?

![](https://i.imgur.com/xRYHk0m.png =1200x)

---

## Improving SGD: two key ideas

* idea 1: momentum
* **problem:**
    * high variance of gradients due to stochasticity
    * oscillation in narrow-valley situations
* **solution**: maintain a running average of gradients

https://distill.pub/2017/momentum/

---

## Improving SGD: two key ideas

* idea 2: adaptive step sizes
* **problem**:
    * parameters have gradients of different magnitudes
    * some parameters tolerate high learning rates, others don't
* **solution**: normalize by a running average of gradient magnitudes

---

## Adam: combines the two ideas

![](https://i.imgur.com/MpiCllk.png)

---

## How good is Adam?

optimization vs. generalisation

---

## How good is Adam?

![](https://i.imgur.com/0yelxmm.png)

---

## How good is Adam?

![](https://i.imgur.com/5rAyMYc.png)

---

## Revisiting the cartoon example

![](https://i.imgur.com/xRYHk0m.png =1200x)

---

## Can we describe SGD's behaviour?

![](https://i.imgur.com/AyrSiFZ.png)

---

## Analysis of the mean iterate

![](https://i.imgur.com/9j85UIv.png)

([Smith et al., 2021](https://arxiv.org/abs/2101.12176)) "On the Origin of Implicit Regularization in Stochastic Gradient Descent"

---

## Implicit regularization in SGD

$$
\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \nabla_\mathbf{w} \hat{L}(\mathbf{w}_t, \mathcal{D}_t)
$$

mean iterate of SGD:

$$
\mu_t = \mathbb{E}[\mathbf{w}_t]
$$

---

## Implicit regularization in SGD

([Smith et al., 2021](https://arxiv.org/abs/2101.12176)): the mean iterate is approximated by a continuous gradient flow:

$$
\small \dot{\mu}(t) = -\eta \nabla_\mathbf{w}\tilde{L}_{SGD}(\mu(t), \mathcal{D})
$$

where

$$
\small \tilde{L}_{SGD}(\mathbf{w}, \mathcal{D}) = \tilde{L}_{GD}(\mathbf{w}, \mathcal{D}) + \frac{\eta}{4}\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2
$$

---

## Implicit regularization in SGD

([Smith et al., 2021](https://arxiv.org/abs/2101.12176)): the mean iterate is approximated by a continuous gradient flow:

$$
\small \dot{\mu}(t) = -\eta \nabla_\mathbf{w}\tilde{L}_{SGD}(\mu(t), \mathcal{D})
$$

where

$$
\small \tilde{L}_{SGD}(\mathbf{w}, \mathcal{D}) = \tilde{L}_{GD}(\mathbf{w}, \mathcal{D}) + \frac{\eta}{4}\underbrace{\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2}_{\text{variance of gradients}}
$$

---

## Revisiting the cartoon example

![](https://i.imgur.com/xRYHk0m.png =1200x)

---
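## Implicit regularizer: a numerical sketch

A minimal NumPy sketch (not from the lecture) of how one could estimate the Smith et al. penalty $\frac{\eta}{4}\mathbb{E}\|\nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D}_t) - \nabla_\mathbf{w}\hat{L}(\mathbf{w}, \mathcal{D})\|^2$ at a fixed $\mathbf{w}$. The toy least-squares data, batch size and learning rate are illustrative assumptions.

```python
# Sketch: estimate SGD's implicit regularization penalty, i.e.
# (eta / 4) * average squared distance between minibatch gradients
# and the full-batch gradient. Toy data and hyperparameters are made up.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 10))                              # toy inputs
y = X @ rng.normal(size=10) + 0.1 * rng.normal(size=256)    # toy targets
w = rng.normal(size=10)                                     # current weights
eta, batch_size = 0.1, 32

def grad(w, idx):
    """Gradient of the mean squared error on the examples indexed by idx."""
    Xb, yb = X[idx], y[idx]
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(idx)

full_grad = grad(w, np.arange(len(X)))      # gradient on the full dataset D

# Average over minibatches D_t of || grad(w, D_t) - grad(w, D) ||^2
batches = np.split(rng.permutation(len(X)), len(X) // batch_size)
grad_var = np.mean([np.sum((grad(w, b) - full_grad) ** 2) for b in batches])

penalty = eta / 4.0 * grad_var              # extra term in L_SGD vs. L_GD
print(f"gradient variance: {grad_var:.4f}, implicit penalty: {penalty:.4f}")
```

Larger values of this penalty mark regions with high gradient noise, which the modified loss $\tilde{L}_{SGD}$ pushes the mean iterate away from.

---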
## Is Stochastic Training Necessary?

![](https://i.imgur.com/1WThiku.png)

---

## Is Stochastic Training Necessary?

![](https://i.imgur.com/dHJKsKS.png)

* reg $\approx$ flatness of the minimum
* bs32 $\approx$ variance of gradients over size-32 minibatches

---

## SGD summary

* gradient noise is a feature, not a bug
* SGD avoids regions with high gradient noise
* this may help with generalization
* improved SGD, like Adam, may not always help
* an optimization algorithm can be "too good"
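---

## Momentum + adaptive step sizes: a minimal sketch

To make the two key ideas concrete, here is a minimal NumPy sketch (not the lecture's code) of a minibatch loop with an Adam-style update: a running average of gradients ($m$, momentum) and a running average of squared gradients ($v$, adaptive step sizes). The toy least-squares problem and all hyperparameter values are illustrative assumptions.

```python
# Sketch: minibatch training with an Adam-style update.
# m tracks a running average of gradients (idea 1: momentum),
# v tracks a running average of squared gradients (idea 2: adaptive steps).
# Toy problem and hyperparameters are made up for illustration.
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(512, 20))
y = X @ rng.normal(size=20)
w = np.zeros(20)

eta, beta1, beta2, eps = 1e-2, 0.9, 0.999, 1e-8
m, v = np.zeros_like(w), np.zeros_like(w)
step = 0

for epoch in range(20):
    for idx in np.split(rng.permutation(len(X)), len(X) // 32):
        step += 1
        Xb, yb = X[idx], y[idx]
        g = 2.0 * Xb.T @ (Xb @ w - yb) / len(idx)   # minibatch gradient
        m = beta1 * m + (1 - beta1) * g             # running avg of gradients
        v = beta2 * v + (1 - beta2) * g**2          # running avg of magnitudes
        m_hat = m / (1 - beta1**step)               # bias correction
        v_hat = v / (1 - beta2**step)
        w -= eta * m_hat / (np.sqrt(v_hat) + eps)   # normalized update

print("final training loss:", np.mean((X @ w - y) ** 2))
```

Dropping the $v$-normalization roughly recovers SGD with momentum; as the summary notes, the "better" optimizer does not automatically generalize better.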
{"metaMigratedAt":"2023-06-15T19:20:45.570Z","metaMigratedFrom":"YAML","title":"DeepNN Lecture 4 Slides","breaks":true,"description":"Lecture slides on stochastic gradient descent, its variants like Adam, and some thoughts on generalisation","contributors":"[{\"id\":\"e558be3b-4a2d-4524-8a66-38ec9fea8715\",\"add\":10965,\"del\":6806}]"}