# Multi Term Adam Explained
[TOC]

# Abstract
## Problem
- **Multiple losses in training**: When training neural networks, we often use more than one loss (for example, a classification loss plus an adversarial loss).
- **Balancing loss terms is hard**:
    - You need to choose manually how much weight to give each loss.
    - This takes time and effort.
    - It is also hard because the "best" balance may change during training.
- **Adversarial terms change more**: Some losses (like adversarial ones) can behave differently as training goes on, which makes balancing even harder.

## Solution: MTAdam (Multi-Term Adam)
- **MTAdam is a variant of the Adam optimizer**:
    - It is designed specifically to handle multiple loss terms at the same time.
- **Main idea**: Balance the gradient magnitudes of the different losses _for every layer_ of the network.
- **How it works**:
    - Computes the gradient of **each loss term separately**.
    - Tracks first and second moments (as in Adam) **for each parameter and each loss**.
    - Computes the **gradient magnitude per layer** for each loss.
    - Uses these magnitudes to **adjust the balance** between the losses during training.

## Benefits
- **Layer-wise balancing**: Each layer adjusts its balance between the losses independently.
- **Dynamic adjustment**: The balance changes automatically as training progresses.
- **Better recovery**: Even if the initial weights of the loss terms are wrong, MTAdam can correct them during training.
- **Good results**: Final model performance is on par with manually tuned training.

# About it
- MTAdam extends Adam and enables effective training of an unweighted multi-term loss objective.
- MTAdam can streamline the computationally demanding hyperparameter search that is otherwise required to weight the terms of a multi-term loss objective effectively.
- **Hyperband** is used as a baseline; it can directly optimize the FID score (**Fréchet Inception Distance**, which **measures how similar two sets of images are**).
- Image generation techniques use multiple loss terms, including pixel-wise norms, perceptual losses, and one or more adversarial loss terms.
- MTAdam was evaluated on 3 different image generation methods:
    - **pix2pix**: generates an image in domain B from an input image in domain A.
    - **CycleGAN**: performs the same task, but is trained in an unsupervised manner on unmatched images from the 2 domains.
    - **SRGAN**: generates high-resolution images from low-resolution ones and is trained in a supervised way.
- **Multi-Task Learning (MTL)** means training a model to perform several tasks at the same time.
    - Each task has its own loss term (an error to minimize).
    - Approaches in MTL:
        1. **Hard Parameter Sharing**:
            - One model shares most layers (parameters) across all tasks.
            - Separate **task-specific heads** (output layers) are used.
            - Works well when the tasks are **closely related** (e.g., Mask R-CNN).
        2. **Soft Parameter Sharing**:
            - Each task has its own separate model.
            - A **regularization term** keeps the models' parameters similar.
            - Encourages learning shared patterns while keeping per-task flexibility.

---

# Method
## Adam Algorithm
- **Adam** optimizes a single stochastic objective function $f_t(\theta)$ over the parameters $\theta$, where $t$ is the index of the current mini-batch of samples.
- Adam's task is to minimize the expected value $E_t[f_t(\theta)]$ w.r.t. $\theta$.
- 2 moments are continuously updated using a moving-average scheme:
    - $m_t$ is the first moment of the gradient $\nabla_\theta f_t(\theta)$.
    - $v_t$ is the second moment.
    - Both are vectors of the same size as $\theta$.
- The moving averages use the mixing coefficients $\beta_1$ and $\beta_2$ for the two moments, respectively.

## MTAdam
- **MTAdam** optimizes a set of such terms $f^1_t(\theta),\ldots,f^I_t(\theta)$.
    - $I$ is the number of loss terms.
- MTAdam minimizes a **weighted average** of the $I$ terms.
    - The weights of this mixture are all positive but otherwise unknown.
- **KEY IDEA:** Adjust the weights so that the **gradient magnitudes** of all terms are **equal**. This ensures **fair learning** across the loss terms.
    - The magnitude is evaluated and balanced at every layer of the neural network.
- MTAdam records the first and second moments for each term $i=1,\ldots,I$ separately.
- It also uses a mixing coefficient $\beta_3$ to maintain a moving average of the gradient magnitude of each term $i$ in each layer $\ell$, denoted $n^i_{\ell,t}$.
- In MTAdam, the first moment is computed from a weighted gradient: the gradient of each layer $\ell$ for every term $i$ is rescaled so that its magnitude is normalized by the factor $n^i_{\ell,t}$ and matched to that of the first term.

---

# Code And Algorithm
- The code was executed on a sequence of MNIST experiments.
- $I$ is the number of loss terms.
- $T$ is the number of training steps.
- $m_t$ is the first moment of the gradient $\nabla_\theta f_t(\theta)$.
- $v_t$ is the second moment.

## Step 1 - Initialization (Lines 1-2)
- MTAdam initializes $I$ pairs of **first and second moment vectors**.

$$
\textbf{for } i \text{ in } 1 \ldots I \textbf{ do } m^i_0, v^i_0 \leftarrow 0, 0 \tag{Line 1}
$$

- `state['exp_avg' + str(i)]` stores $m^i$.
- `state['exp_avg_sq' + str(i)]` stores $v^i$.
- MTAdam initializes $I$ **first moments of the gradient magnitude, one per layer**.

$$
\textbf{for } \ell \text{ in } 1 \ldots L \textbf{ do } n^i_{\ell,0} \leftarrow 1 \tag{Line 2}
$$

- `p.norms[i]` stores $n^i_\ell$ for the layer (parameter tensor) `p`.

```python=
import torch

# State initialization (inside step(), for each parameter tensor p)
if len(state) == 0:
    state['step'] = 1
    for j, _ in enumerate(loss_array):
        # Exponential moving average of gradient values (m^j)
        state['exp_avg' + str(j)] = torch.zeros_like(p.data)
        # Exponential moving average of squared gradient values (v^j)
        state['exp_avg_sq' + str(j)] = torch.zeros_like(p.data)
        if amsgrad:
            # Maintains max of all exp. moving avg. of sq. grad. values
            state['max_exp_avg_sq' + str(j)] = torch.zeros_like(p.data)
        # Per-term gradient magnitude n^j for this layer, initialized to 1 (Line 2);
        # allocated on p's device rather than hard-coding .cuda()
        if j == 0:
            p.norms = [torch.ones(1, device=p.device)]
        else:
            p.norms.append(torch.ones(1, device=p.device))
```

## Step 2 (Line 3)
- Iterate over the stochastic mini-batches, performing $T$ training steps (as in Adam).

$$
\textbf{while } t \in \{1, \ldots, T\} \textbf{ do} \tag{Line 3}
$$

## Step 3 - Loop Over the Loss Terms (Line 4)
- Iterate over each loss term.

$$
\textbf{for } i \text{ in } 1 \ldots I \textbf{ do} \tag{Line 4}
$$

## Step 4 - Calculate Gradients (Line 5)
- Compute the gradients of each loss term $i$ w.r.t. $\theta$, across all layers.
- For a layer index $\ell$, the gradient is denoted by $g^i_\ell$.
- $g^i$ is the concatenation of all per-layer gradients.

$$
g^i := (g^i_1, \ldots, g^i_L) \leftarrow \nabla_\theta f^i_t(\theta_{t-1}) \tag{Line 5}
$$

## Step 5 - Normalize Gradients (Lines 6-8)
- Iterate over all layers of the network.

$$
\textbf{for } \ell \text{ in } 1 \ldots L \textbf{ do} \tag{Line 6}
$$

- Normalize the gradient magnitude of every loss term to match the magnitude of the first loss term.

$$
\begin{cases}
n_{\ell,t}^i \leftarrow \beta_3 \cdot n_{\ell,t-1}^i + (1 - \beta_3) \cdot \| g_{\ell}^i \|_2 \\
g_{\ell}^i \leftarrow n_{\ell,t}^1 \cdot g_{\ell}^i / n_{\ell,t}^i
\end{cases}
\tag{Lines 7-8}
$$

- For each layer, update the moving average of the gradient magnitude of the current loss term, then normalize the current layer's gradient by multiplying it by $\frac{n^1_{\ell,t}}{n^i_{\ell,t}}$.

```python=
import torch

# Normalize the norm of the current loss gradients to match the anchor term.
# On the first step the raw norm is used instead of the beta3 moving average.
if state['step'] == 1:
    p.norms[loss_index] = torch.norm(p.grad)
else:
    p.norms[loss_index] = p.norms[loss_index] * beta3 + (1 - beta3) * torch.norm(p.grad)
# Skip (near-)zero gradients to avoid dividing by ~0
if p.norms[loss_index] > 1e-10:
    # The anchor is the first loss term whose running norm is non-negligible
    for anchor_index in range(len(loss_array)):
        if p.norms[anchor_index] > 1e-10:
            # ranks[loss_index]: per-term multiplier supplied by the caller
            p.grad = ranks[loss_index] * p.norms[anchor_index] * p.grad / p.norms[loss_index]
            break
```

- `p.norms[loss_index]` corresponds to $n^i_{\ell,t}$.
- `p.grad` corresponds to $g^i_\ell$.
- The normalization iterates over the loss terms. For each gradient, the moving average (first moment) of its magnitude is updated, and the gradient is then rescaled to the magnitude of the first loss term. In the code, the anchor is the first term whose running norm exceeds `1e-10`, and the result is additionally multiplied by `ranks[loss_index]`.

## Step 6 - Update Moments (Lines 9-12)
- MTAdam updates the first and second moments for each parameter and each loss term, and computes their bias corrections.
- This is similar to Adam, except that the moments are computed separately for each loss term.

$$
\begin{aligned}
m_t^i &\leftarrow \beta_1 \cdot m_{t-1}^i + (1 - \beta_1) \cdot g^i \\
v_t^i &\leftarrow \beta_2 \cdot v_{t-1}^i + (1 - \beta_2) \cdot (g^i)^2 \\
\hat{m}_t^i &= \frac{m_t^i}{1 - \beta_1^t} \\
\hat{v}_t^i &= \frac{v_t^i}{1 - \beta_2^t}
\end{aligned}
\tag{Lines 9-12}
$$

```python=
# exp_avg, exp_avg_sq: this term's buffers from Step 1, updated in place
# Lines 9-10: per-term first and second moments of the normalized gradient
exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)                  # m^i_t
exp_avg_sq.mul_(beta2).addcmul_(p.grad, p.grad, value=1 - beta2)   # v^i_t
```

## Step 7 - Update Parameters
- After all loss terms have been processed, the parameters are updated with the sum of the per-term steps, all sharing the element-wise maximum of the bias-corrected second moments in the denominator.

$$
\theta_t \leftarrow \theta_{t-1} - \alpha \cdot \sum_{i=1}^{I} \hat{m}_t^i \Big/ \left( \sqrt{ \max(\hat{v}_t^1, \ldots, \hat{v}_t^I) } + \epsilon \right)
$$

```python=
import math
import torch

# --- Inside the loop over loss terms (loss_index = i), after Step 6 ---
# Bias corrections 1 - beta^t (Lines 11-12), assuming state['step'] holds t
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
if amsgrad:
    # Maintains the maximum of all 2nd moment running avg. till now
    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
    # Use the max. for normalizing running avg. of gradient
    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
else:
    # sqrt(v_hat^i) + eps
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
# alpha / (1 - beta1^t), so step_size * exp_avg = alpha * m_hat^i
step_size = group['lr'] / bias_correction1

# Store this term's pieces; the lists are reset on the first term of each step
if loss_index == 0 or not hasattr(p, 'exp_avg'):
    p.exp_avg = [exp_avg]
    p.denom = [denom]
    p.step_size = [step_size]
else:
    p.exp_avg.append(exp_avg)
    p.denom.append(denom)
    p.step_size.append(step_size)

# Clear the gradient so the next loss term's backward pass starts from zero
if p.grad is not None:
    p.grad.detach_()
    p.grad.zero_()

# --- After all loss terms: apply the combined update (inside torch.no_grad()) ---
for group in self.param_groups:
    for p in group['params']:
        # Element-wise maximum of the per-term denominators
        max_denom = p.denom[0]
        for index in range(1, len(p.exp_avg)):
            max_denom = torch.max(max_denom, p.denom[index])
        # Sum of the per-term steps: -alpha * m_hat^i / (sqrt(max v_hat) + eps)
        temp = 0
        for index in range(len(p.exp_avg)):
            temp += -p.step_size[index] * (p.exp_avg[index] / max_denom)
        p.add_(temp)
```

<!--
# Steps
1. Run MTAdam on DNBP using TACC
2. Build a new optimizer based on prof idea
    - Implement MTSGHMC
3. Test the new optimizer on DNBP and compare performance with MTAdam/Adam etc
4. Build DNBP for Human Pose Estimation and tracking problem and use the new optimizer
-->
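To tie Steps 1-7 together, here is a minimal, framework-free sketch of one MTAdam step, written against the equations (Lines 1-12 and the Step 7 update) rather than against the repository code. It is a sketch under stated assumptions, not the reference implementation. Parameters and gradients are plain Python lists keyed by a hypothetical layer name. `init_state`, `mtadam_step` and `l2_norm` are illustrative helpers, not part of any library. Unlike the Step 5 code, it follows Lines 2 and 7 literally: every $n^i_{\ell}$ starts at 1 and is always averaged with $\beta_3$. It also omits the `ranks` multipliers and always uses the first term as the anchor.

```python
import math


def l2_norm(vec):
    """Euclidean norm ||g||_2 of a flat list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def init_state(params, num_terms):
    """Lines 1-2: zero moments per term, unit gradient magnitude per (term, layer)."""
    def zeros():
        return {name: [0.0] * len(w) for name, w in params.items()}

    return {
        "t": 0,
        "m": [zeros() for _ in range(num_terms)],
        "v": [zeros() for _ in range(num_terms)],
        "n": [{name: 1.0 for name in params} for _ in range(num_terms)],
    }


def mtadam_step(params, grads, state, lr=1e-3, beta1=0.9, beta2=0.999,
                beta3=0.9, eps=1e-8):
    """One MTAdam step; grads[i][name] is g^i_l for loss term i and layer `name`."""
    state["t"] += 1
    t = state["t"]
    m_hat, v_hat = [], []
    for i, grads_i in enumerate(grads):        # Line 4: loop over loss terms
        m_hat_i, v_hat_i = {}, {}
        for name, g in grads_i.items():         # Line 6: loop over layers
            n = state["n"][i]
            # Line 7: moving average of this term's gradient magnitude in this layer
            n[name] = beta3 * n[name] + (1 - beta3) * l2_norm(g)
            # Line 8: rescale to the first term's magnitude (term 0 is already updated)
            scale = state["n"][0][name] / n[name]
            g = [scale * x for x in g]
            # Lines 9-10: per-term first and second moments
            m = [beta1 * a + (1 - beta1) * x for a, x in zip(state["m"][i][name], g)]
            v = [beta2 * b + (1 - beta2) * x * x for b, x in zip(state["v"][i][name], g)]
            state["m"][i][name], state["v"][i][name] = m, v
            # Lines 11-12: bias correction
            m_hat_i[name] = [a / (1 - beta1 ** t) for a in m]
            v_hat_i[name] = [b / (1 - beta2 ** t) for b in v]
        m_hat.append(m_hat_i)
        v_hat.append(v_hat_i)

    # Step 7: sum the per-term steps over a shared element-wise max second moment
    for name, w in params.items():
        for k in range(len(w)):
            v_max = max(vh[name][k] for vh in v_hat)
            m_sum = sum(mh[name][k] for mh in m_hat)
            w[k] -= lr * m_sum / (math.sqrt(v_max) + eps)
    return params


if __name__ == "__main__":
    params = {"w": [0.0]}                      # one hypothetical layer with one weight
    state = init_state(params, num_terms=2)
    # The second term's gradient is 10x larger than the first term's
    mtadam_step(params, [{"w": [1.0]}, {"w": [10.0]}], state)
    print(params["w"][0])                      # about -0.00119
```

After one step, the second term's running magnitude has only moved from 1 to $0.9 \cdot 1 + 0.1 \cdot 10 = 1.9$. Its rescaled gradient ($10/1.9 \approx 5.26$) therefore still dominates, and the step is about $-1.19\alpha$. With constant gradients, $n^i_{\ell,t}$ converges to $\|g^i_\ell\|_2$, so over later steps the rescaled gradients approach the first term's magnitude.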