# Improving NLP models without scaling
## Motivation
- Easy to improve by scaling
  - Bigger models
  - More (better) data
- Scaling comes at a cost
  - Curating data can get expensive and take a long time
  - Not all data is created equal
  - Larger models come with memory/speed tradeoffs in both training and inference
- How can we improve model performance without scaling up?
  - Bonus points for techniques that can be used right away
- Take inspiration from computer vision (CV) research
  - Lots of tricks to improve model performance without additional data or larger models
## Sharpness Aware Minimization (SAM)
[Sharpness-Aware Minimization for Efficiently Improving Generalization](https://arxiv.org/abs/2010.01412)
### Motivation
- Simply minimizing commonly used loss functions (e.g., cross-entropy) on the training set is typically not sufficient to achieve satisfactory generalization. The training loss landscapes of today’s models are commonly complex and non-convex, with many local/global minima that have different generalization capabilities
- Prior work studying the relationship between the geometry of the loss landscape and model performance shows that flatter minima tend to generalize better
- Want an optimizer that minimizes both the loss value and the "sharpness" of the loss
- Rather than just seeking model parameters with low loss values, SAM seeks parameters that lie in neighborhoods with uniformly low loss values.

### Algorithm
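
Following the paper (omitting its weight-decay term and dropping second-order terms, as the paper does), SAM minimizes the worst-case training loss $L_S$ within an $L_2$ ball of radius $\rho$ around the weights $w$, and approximates the inner maximization with a single normalized gradient step. One update with learning rate $\eta$ therefore needs two gradient evaluations:

$$ \min_{w} \max_{\|\epsilon\|_2 \le \rho} L_S(w + \epsilon) $$
$$ \hat{\epsilon}(w) = \rho \, \frac{\nabla_w L_S(w)}{\|\nabla_w L_S(w)\|_2} $$
$$ w \leftarrow w - \eta \, \nabla_w L_S(w) \big|_{w + \hat{\epsilon}(w)} $$

The first gradient gives the perturbation $\hat{\epsilon}(w)$; the second, taken at the perturbed point $w + \hat{\epsilon}(w)$, is applied to the original weights. In the linked implementation these two stages correspond to `first_step` and `second_step`.
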



[Code implementation](https://github.com/davda54/sam/blob/295977586de6a6e38f4730adb2ae496fd75f94ec/sam.py)
```python
import torch

from sam import SAM
...
model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...
for input, output in data:
    # first forward-backward pass
    loss = loss_function(output, model(input))  # use this loss for any training statistics
    loss.backward()
    optimizer.first_step(zero_grad=True)
    # second forward-backward pass
    loss_function(output, model(input)).backward()  # make sure to do a full forward pass
    optimizer.second_step(zero_grad=True)
...
```
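
To show what `first_step` and `second_step` do, here is a minimal pure-Python sketch of one SAM update on a hypothetical two-parameter quadratic loss with a hand-written gradient. The loss, `grad`, `sam_step`, and the `lr`/`rho` values are illustrative assumptions; this is not the davda54/sam code, which applies the same idea to PyTorch parameter tensors.

```python
import math


def loss(w):
    # toy loss: sharp along w[0] (curvature 10), flat along w[1] (curvature 0.1)
    return 5.0 * w[0] ** 2 + 0.05 * w[1] ** 2


def grad(w):
    # analytic gradient of the toy loss
    return [10.0 * w[0], 0.1 * w[1]]


def sam_step(w, lr=0.05, rho=0.05):
    """One SAM update with an L2 neighborhood of radius rho."""
    g = grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
    # "first_step": move to the (approximate) worst point in the neighborhood
    eps = [rho * gi / norm for gi in g]
    w_adv = [wi + ei for wi, ei in zip(w, eps)]
    # "second_step": gradient at the perturbed point, applied to the ORIGINAL weights
    g_adv = grad(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


w = [1.0, 1.0]
for _ in range(50):
    w = sam_step(w)
print(w, loss(w))
```

On the first step plain gradient descent would move `w[0]` from 1.0 to 0.5; SAM uses the gradient from the perturbed point, so it moves slightly further along the sharp direction (to about 0.475).
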
### Experimental Results
- Tested a variety of models, both trained from scratch and fine-tuned from pretrained weights
- Because each SAM update requires two forward-backward passes, non-SAM baselines were trained with both the same epoch count and double the epoch count
- Authors also noted that "SAM enables increasing the number of training epochs while continuing to improve accuracy without overfitting. In contrast, the standard training procedure (without SAM) generally significantly overfits as training extends from 200 to 400 epochs"
- Bonus: SAM also acts as a regularization method for training


## SAM in NLP
[Sharpness-Aware Minimization Improves Language Model Generalization](https://arxiv.org/abs/2110.08529)
### Experimental Results

SAM also improves performance uniformly across all model sizes on IR/QA tasks

- SAM also improves model performance in the low-data regime (trained on 5% of the SGLUE and CBQA data)
### Sidenote: Adaptive Sharpness-Aware Minimization (ASAM)
[ASAM: Adaptive Sharpness-Aware Minimization for Scale-Invariant Learning of Deep Neural Networks](https://arxiv.org/abs/2102.11600)
- Extension of SAM that shows improvements on several vision and NLP benchmarks
- Not really validated on any "major" benchmarks (ImageNet, GLUE, SGLUE, etc.)
- *May* improve on SAM, but not well tested at the moment
### Example Code
```python
import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from sam import SAM  # sam.py from https://github.com/davda54/sam


class PLClassifier(pl.LightningModule):
    def __init__(
        self,
        pytorch_model: nn.Module,  # any HF model with a classification head (outputs `.logits`)
        total_steps: int,  # number of optimizer steps (not batches), used by OneCycleLR
        lr: float = 3e-5,
        rho: float = 0.05,
        asam: bool = False,
        is_ddp: bool = False,
        grad_acc_batches: int = 1,
    ):
        super().__init__()
        self.lr = lr
        self.rho = rho  # rho is neighborhood size for SAM
        self.asam = asam  # whether to use Adaptive SAM
        self.total_steps = total_steps
        self.is_ddp = is_ddp
        self.grad_acc_batches = grad_acc_batches
        self.pytorch_model = pytorch_model
        self.criterion = nn.BCEWithLogitsLoss()
        # IMPORTANT: SAM needs two forward-backward passes per update,
        # so the optimizer steps are defined manually
        self.automatic_optimization = False

    def configure_optimizers(self):
        optimizer = SAM(
            self.pytorch_model.parameters(),
            base_optimizer=AdamW,
            rho=self.rho,
            adaptive=self.asam,  # davda54/sam uses ASAM when adaptive=True
            lr=self.lr,
            betas=(0.9, 0.99),
        )
        scheduler = OneCycleLR(
            optimizer=optimizer,
            max_lr=self.lr,
            pct_start=0.3,
            total_steps=self.total_steps,
        )
        return [optimizer], [scheduler]

    def forward(self, input_ids, attention_mask):
        out = self.pytorch_model(input_ids, attention_mask=attention_mask)
        return out.logits  # HF models return an output object; keep the logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"].float()  # BCEWithLogitsLoss expects float targets
        opt = self.optimizers()
        # first forward-backward pass: gradient at the current weights
        out = self(input_ids, attention_mask)
        loss = self.criterion(out, labels)
        if self.is_ddp:
            # skip the gradient all-reduce: each process perturbs with its local gradient
            with self.trainer.model.no_sync():
                self.manual_backward(loss)
        else:
            self.manual_backward(loss)
        # optimizer step with gradient accumulation
        if (batch_idx + 1) % self.grad_acc_batches == 0:
            opt.first_step(zero_grad=True)  # perturb weights to w + e(w)
            # second forward-backward pass at the perturbed weights
            # (only the current batch is used here, not the accumulated ones)
            out2 = self(input_ids, attention_mask)
            loss2 = self.criterion(out2, labels)
            self.manual_backward(loss2)
            opt.second_step(zero_grad=True)  # restore w, then take the AdamW step
            # step the LR scheduler once per optimizer step
            self.lr_schedulers().step()
        return {"train_loss": loss}
```
### Critiques
- The original SAM paper also compares against non-SAM models trained for double the epoch count (matching SAM's compute), but this paper does not
  - Extrapolating from the original paper, an argument *could* be made that SAM would still outperform even with additional training time
- All hyperparameters were kept the same across SAM/non-SAM experiments
  - Is it truly an improvement, or were the hyperparameters selected to be favorable for SAM?
- Trade-off between generalization and compute time: is the ~2x increase in compute worth the gains?
### Final thoughts
- Has potential, given that SAM seems to yield generalization improvements across a variety of datasets and model architectures
- Additional benefit of regularization: no more worrying about overfitting?
- Need more papers/experiments to fully study the benefits of SAM
- Definitely something worth trying; fairly easy to incorporate into any model training
## Mixup
[mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)
- Another paper from computer vision
- Proposes a simple data augmentation to improve model performance and robustness
$$ \tilde{x} = \lambda x_{i} + (1 - \lambda)x_{j} $$
$$ \tilde{y} = \lambda y_{i} + (1 - \lambda)y_{j} $$
where $(x_i, y_i)$ and $(x_j, y_j)$ are two examples drawn at random from the training data, and $0 < \lambda < 1$ is sampled from a $\text{Beta}(\alpha, \alpha)$ distribution for each mixed example, with $\alpha > 0$ a hyperparameter (values below 1 are common).
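
The two equations apply elementwise to feature vectors and one-hot label vectors. A minimal pure-Python sketch for a single pair; `mixup_pair` is a hypothetical helper and the default `alpha=0.2` is an assumed value, not one taken from the paper:

```python
import random


def mixup_pair(x_i, y_i, x_j, y_j, alpha=0.2, lam=None):
    """Mix two examples following the mixup equations.

    x_* are feature vectors and y_* are one-hot label vectors (lists of floats).
    If lam is None, it is drawn from Beta(alpha, alpha).
    """
    if lam is None:
        lam = random.betavariate(alpha, alpha)
    x_mix = [lam * a + (1 - lam) * b for a, b in zip(x_i, x_j)]
    y_mix = [lam * a + (1 - lam) * b for a, b in zip(y_i, y_j)]
    return x_mix, y_mix, lam


# a class-0 example mixed 70/30 with a class-1 example
x_mix, y_mix, lam = mixup_pair([1.0, 2.0], [1.0, 0.0], [3.0, 4.0], [0.0, 1.0], lam=0.7)
print(x_mix, y_mix)
```

The mixed label stays a valid probability distribution (here roughly 0.7 for class 0 and 0.3 for class 1), which is why mixup is trained with soft targets.
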
*Figure: example of mixup for image classification.*

## Mixup for NLP
### Motivation
- We know data augmentation improves model generalization
- Data augmentation for NLP can be pretty hard
  - Typically relies on generative methods (back translation, text generation, etc.)
  - Hard to balance text diversity against preserving semantic meaning
### Mixup for Sentence Classification
[Augmenting Data with Mixup for Sentence Classification: An Empirical Study](https://arxiv.org/abs/1905.08941)
- Looked at the effects of mixup augmentation for sentence classification using CNNs and LSTMs
- Studied effects of mixup at the word embedding level, as well as sentence embedding level after encoder
### Mixup-Transformer
[Mixup-Transformer: Dynamic Data Augmentation for NLP Tasks](https://arxiv.org/abs/2010.02394)
- Introduces a way to incorporate mixup into NLP using transformers
- Instead of applying mixup to the raw input vectors, apply it to the pooled output vectors, then feed the mixed vectors to a fully connected linear layer for classification

- Key differences from previous approaches:
  - The hidden representations for the mixup vectors are dynamic, trained together in the fine-tuning process
  - Can dynamically choose to use mixup for only a subset of training epochs
### Experimental Results
- Fine-tuned BERT-base and BERT-large on GLUE
  - Did not use mixup in the first half of training, to obtain good hidden representations first, then used mixup for the second half
  - Yields improvements in both full- and limited-data scenarios

- Also fine-tuned BERT-large on various GLUE tasks with 10%-100% of the training data
  - Mixup consistently improves model performance across low-data scenarios

### Discussion
- $\lambda$ is typically sampled from a $\text{Beta}(\alpha, \alpha)$ distribution. Experiments showed Mixup-Transformer to be largely insensitive to the value of $\lambda$, so it was fixed at 0.5
- In CV work exploring mixup, longer training is required to get improved results, but this paper shows improvements in just 3 epochs
  - Likely due to the dynamic activation of mixup: using mixup for all of training might take longer to learn good representations
- The benefits of mixup seem to be greater when training data is limited
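
The recipe described above (no mixup for the first half of training, then mixup with $\lambda$ fixed at 0.5) fits in a small per-epoch helper. This is a hypothetical sketch: `mixup_lambda` and `start_fraction` are illustrative names, not taken from the paper's code. It returns `None` when mixup should be off and $\lambda$ otherwise:

```python
import random
from typing import Optional


def mixup_lambda(
    epoch: int,
    num_epochs: int,
    start_fraction: float = 0.5,
    fixed_lam: Optional[float] = 0.5,
    alpha: float = 1.0,
) -> Optional[float]:
    """Return the mixing weight for this epoch, or None if mixup is off.

    Mixup is activated once the 0-based `epoch` reaches start_fraction * num_epochs.
    With fixed_lam=None, lambda is sampled from Beta(alpha, alpha) instead.
    """
    if epoch < start_fraction * num_epochs:
        return None  # early epochs: learn good representations without mixup
    if fixed_lam is not None:
        return fixed_lam  # Mixup-Transformer kept lambda at 0.5
    return random.betavariate(alpha, alpha)


schedule = [mixup_lambda(e, num_epochs=4) for e in range(4)]
print(schedule)  # [None, None, 0.5, 0.5]
```

A training loop would skip mixing for any epoch where the helper returns `None`.
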
### Example Code
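```python
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel


def mixup_data(
    x: torch.Tensor,
    y: torch.Tensor,
    alpha: float = 1.0,
    lam: Optional[float] = None,
):
    '''Returns mixed inputs and mixed (soft) targets.

    `y` must be a float tensor of shape (batch_size, n_classes), e.g. one-hot
    or multi-hot labels, so it can be interpolated the same way as `x`.
    '''
    if lam is None:
        # sample lambda ~ Beta(alpha, alpha); alpha <= 0 disables mixing
        lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    # pair every example with a randomly chosen partner from the same batch
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y + (1 - lam) * y[index]
    return mixed_x, mixed_y


class HFModel(nn.Module):
    def __init__(self, model_name: str, n_classes: int, mixup_lam: Optional[float] = 0.5):
        super().__init__()
        # assumes an encoder whose output has `pooler_output` (e.g. BERT)
        self.model = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.model.config.hidden_size, n_classes)
        self.mixup_lam = mixup_lam  # None -> sample lambda from Beta(alpha, alpha)

    def forward(self, input_ids, attention_mask, labels=None, use_mixup=False):
        out = self.model(input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output
        # mix the pooled representations rather than the raw inputs;
        # `use_mixup` lets the training loop enable mixup only for later epochs
        if labels is not None and use_mixup:
            pooled, labels = mixup_data(pooled, labels.float(), lam=self.mixup_lam)
        logits = self.classifier(pooled)
        # mixed labels are soft targets: use a loss that accepts probabilities
        # (e.g. BCEWithLogitsLoss, or CrossEntropyLoss with class-probability targets)
        return (logits, labels) if labels is not None else logits
```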
## Bonus content
- [More NLP data augmentation strategies](https://www.kaggle.com/c/chaii-hindi-and-tamil-question-answering/discussion/287923?fbclid=IwAR2unJF_zFq0wGTU8d4h_FqcA9JnVrGieriOgVigQoTrWNXnADGYE-E0I1M)
- [Exponential Moving Average for better models](https://github.com/fadel/pytorch_ema)

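
The EMA idea is to keep a "shadow" copy of the weights, update it after every optimizer step as $\theta_{\text{ema}} \leftarrow d\,\theta_{\text{ema}} + (1 - d)\,\theta$, and evaluate with the shadow weights. A minimal pure-Python sketch of just the update rule; `ema_update` is hypothetical (not the API of the linked package) and the decay $d = 0.99$ is an assumed value:

```python
def ema_update(shadow, params, decay=0.99):
    """shadow <- decay * shadow + (1 - decay) * params, elementwise."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]


params = [1.0, -2.0]
shadow = list(params)  # initialize the shadow weights from the model
for step in range(3):
    params = [p - 0.1 for p in params]  # stand-in for an optimizer step
    shadow = ema_update(shadow, params)
print(params, shadow)
```

The shadow weights move much more slowly than the raw weights, which smooths out noise from individual optimizer steps.
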
Label smoothing
- Mostly for discriminative tasks, i.e. multi-class/multi-label classification

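
Label smoothing replaces a one-hot target $y$ over $K$ classes with $(1 - \varepsilon)\,y + \varepsilon / K$. A minimal pure-Python sketch, assuming a smoothing value of $\varepsilon = 0.1$ (a common choice); `smooth_labels` is a hypothetical helper. In PyTorch 1.10 and later, `nn.CrossEntropyLoss(label_smoothing=0.1)` applies the same smoothing internally.

```python
from typing import List


def smooth_labels(label: int, n_classes: int, smoothing: float = 0.1) -> List[float]:
    """Return the smoothed target distribution for integer class `label`."""
    off_value = smoothing / n_classes  # mass given to every class
    on_value = 1.0 - smoothing + off_value  # true class keeps the rest
    return [on_value if k == label else off_value for k in range(n_classes)]


print(smooth_labels(2, n_classes=4))
```
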
Mixed-precision training
- Faster training (on certain accelerators), lower memory use, regularization
- May require increasing the optimizer's `eps` from its default value (e.g. `1e-08`) for numerical stability
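
One concrete reason behind the `eps` advice: if the optimizer arithmetic itself runs in float16 (e.g. pure half-precision training; PyTorch AMP keeps parameters and optimizer state in float32, so this applies mainly outside AMP), the default `1e-08` is smaller than the smallest positive float16 value (about `6e-08`) and rounds to zero, so it no longer guards the division in Adam's update. A quick check with Python's half-precision `struct` format code `'e'` (`to_float16` is a hypothetical helper):

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


for eps in (1e-8, 1e-6, 1e-4):
    print(eps, "->", to_float16(eps))
```

Only the default `1e-08` collapses to `0.0`; the larger values survive the conversion (with some rounding).
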