***REF***: https://arxiv.org/pdf/2402.03300 (DeepSeekMath, where GRPO is introduced)

Algorithm: see the iterative GRPO algorithm in the paper above.

Objective:
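Written out from the paper for quick reference (the per-token loss in the training loop below is this objective, negated; $\varepsilon$ is the clipping range, called eta below, and $\beta$ weights the KL penalty):

$$
\mathcal{J}_{\text{GRPO}}(\theta) = \mathbb{E}\left[\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left(\min\Big(r_{i,t}(\theta)\,\hat{A}_{i,t},\ \operatorname{clip}\big(r_{i,t}(\theta),\,1-\varepsilon,\,1+\varepsilon\big)\,\hat{A}_{i,t}\Big) - \beta\,\mathbb{D}_{\text{KL}}\big[\pi_\theta\,\|\,\pi_{\text{ref}}\big]\right)\right]
$$

where $r_{i,t}(\theta) = \dfrac{\pi_\theta(o_{i,t}\mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t}\mid q, o_{i,<t})}$ is the per-token probability ratio and $\hat{A}_{i,t}$ is the group-wise advantage.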

# GRPO Quick Note
This note is for learning GRPO quickly, so it uses a lot of simplified abstractions and intuitive descriptions.
It still works and helps with understanding GRPO (and related ideas about RL on probability-based models; GRPO can be applied to any token-based AR model such as an LLM, AR image gen, or AR video gen).
The same concept should work for more general tasks, but the current implementation is specifically for probability-based models.
## Key Components
1. Model
    * reference model: the frozen original model (initially; it is replaced at the end of each iteration)
    * proxy model: the model being trained
2. Sampling
    * nucleus sampling or any random sampling method
    * we need to generate multiple outputs for the same input, so tricks like TGTS can be introduced
3. Trainer
    * PEFT methods can be used; I may try LyCORIS here (multiplier 1 acts as the proxy model, multiplier 0 as the ref model)
    * try to implement it in pytorch lightning
4. Dataset
    * In the RL procedure there is no "direct answer" for a given input, so the dataset should be seen as input–reward-function pairs
    * For a general model you will very likely need different reward functions/models for different types of task, so each entry should carry the info for the reward function it needs (like the prompt template for an LLM-based reward function, or the expected format for a format score)
5. Reward Function
    * Should be predefined and will be called during training.
    * Can be an LLM with a judge template, a direct scorer model, a desired format, or a fixed answer (see the sketch after this list)
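For concreteness, a minimal sketch of what an entry (input plus reward-function spec) and a simple reward function could look like; the class, field, and function names here are my own assumptions, not a fixed interface:
```python=
from dataclasses import dataclass, field
from typing import Callable

# A reward function takes the original input and one generated sample, returns a scalar score.
RewardFn = Callable[[str, str], float]

@dataclass
class Entry:
    input_text: str                # the prompt / conditioning input
    reward_fn: str                 # name of the reward function this entry needs
    reward_fn_kwargs: dict = field(default_factory=dict)  # e.g. judge prompt template, expected format

def format_reward(input_text: str, sample: str, expected_prefix: str = "Answer:") -> float:
    # Toy "desired format" scorer: 1 if the sample starts with the expected prefix, else 0.
    return 1.0 if sample.strip().startswith(expected_prefix) else 0.0
```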
## HyperParameters
***We ignore some general NN training hyperparameters***
* [BS] Batch Size: How many entries in each step
* [Iters] Iterations: The total number of outer GRPO training loops
    * In each iteration we do multiple generation and training steps, then update the reference model before starting the next GRPO training round, and update the reward function if needed.
* [Steps] Steps: The total number of GRPO steps
* [GRPO Iters] GRPO Iterations: The number of training iterations for each generation
* [GS] Group Size: How many LLM generations for each entry (a config sketch collecting these follows this list)
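Gathered into one place, these could live in a small config object (a sketch; the default values are placeholders, and eta/beta from the training loop below are included for completeness):
```python=
from dataclasses import dataclass

@dataclass
class GRPOConfig:
    batch_size: int = 8    # [BS] entries per step
    iterations: int = 4    # [Iters] outer GRPO training loops (ref model / reward updates)
    steps: int = 1000      # [Steps] total GRPO steps
    grpo_iters: int = 4    # [GRPO Iters] training iterations per generation
    group_size: int = 8    # [GS] generations per entry
    eta: float = 0.2       # clipping range for the probability ratio
    beta: float = 0.04     # weight of the KL penalty
```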
## Training Loop
1. RL Iteration
    1. Get BS entries from Dataset
    2. Generate GS samples with the proxy model for each entry
    3. Calculate the reward for each sample
        * remember to calculate the group-wise A (Advantages) as the final score (see the sketch after this list)
    4. Calculate the log prob of each token with the ref model (log_prob_ref)
    5. Calculate the log prob of each token with the proxy model (log_prob_p_old, detached)
    6. GRPO Loop
        1. Calculate the log prob of each token with the proxy model (log_prob)
        2. Calculate the reward scale for each token
            * scale = torch.exp(log_prob - log_prob_p_old)
        3. Calculate the clipped reward for each token (A is the group-wise advantage)
            * reward = min(A * scale, A * clip(scale, 1-eta, 1+eta))
        4. Calculate the KL divergence from the log probs
            * kl_div = torch.exp(log_prob_ref - log_prob) - (log_prob_ref - log_prob) - 1
        5. Calculate the loss with the following equation:
            * loss = -(reward - beta * kl_div).mean()
        6. loss.backward() and optimizer step
    7. Replace the ref model with the current proxy model
    8. Update the reward model if needed
        * some reward functions are fixed and never need to be updated
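A minimal sketch of the group-wise advantage in step 3 (mean/std normalization within each group of GS samples, as in the paper; reward values here are made up, and eps is my own addition to avoid division by zero):
```python=
import torch

def group_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize the rewards of one group (GS samples of the same entry).
    return (rewards - rewards.mean()) / (rewards.std() + eps)

rewards = torch.tensor([0.0, 1.0, 1.0, 0.5])   # GS = 4 samples for one entry
advantages = group_advantages(rewards)          # every token of sample i uses advantages[i] as its A
```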
## Intuition
For each sample we have a corresponding score (its group-wise advantage). We use torch.exp(log_prob - log_prob_old), which at the first GRPO iteration is just torch.exp(log_prob - log_prob.detach()), to let the model learn "how to make the probability higher for entries with higher score": the forward value starts at 1, but the gradient pushes log_prob up or down depending on the sign of the score.
This objective has no direct target value, so some values can explode; that is why we use the KL divergence as regularization (and why we clip the scale).
Also, the scaled reward lets the model pick up real-time updates inside the GRPO loop, while min(scaled_reward, clipped_scaled_reward) keeps "already well learned" tokens from being emphasized too much.
With clipping and the KL divergence it is fine to use quite a large number of GRPO iters, which SHOULD speed up training (more gradient steps per expensive generation pass).
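A toy illustration of that min/clip behaviour (numbers are made up): tokens whose ratio has already moved past 1±eta fall onto the clipped branch, which is constant with respect to the model, so they stop receiving gradient:
```python=
import torch

eta = 0.2
advantage = torch.tensor([1.0, 1.0, -1.0])   # good, good, bad sample
ratio     = torch.tensor([1.05, 1.8, 0.4])   # barely changed, strongly up-weighted, strongly down-weighted

clipped = torch.clamp(ratio, 1 - eta, 1 + eta)
surrogate = torch.minimum(ratio * advantage, clipped * advantage)
print(surrogate)  # tensor([ 1.0500,  1.2000, -0.8000]): the last two pick the clipped (gradient-free) branch
```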
## Implementation with pytorch lightning
***All the code in this section should be seen as pseudo code***
To fit everything into a single training loop (matching pytorch lightning's mechanism), we need to design our own IterableDataset-style wrapper which repeats the same batch until enough GRPO iterations have been done.
Therefore we have some special design for the dataloader/dataset:
1. Each "batch" sent to the lightning trainer has the following content:
    1. batch index
    2. sub index (for grad accumulation)
    3. GRPO iteration
    4. batch content
2. The wrapper samples a full-size (global BS) batch from the original dataset, then splits it into sub batches.
3. The wrapper/dataloader yields sub batches with the following mechanism:
    * ```python=
      def iter_batch():
          # Sample one global batch, then replay it grpo_iter times,
          # split into grad_acc sub batches each time.
          idx, batch = sample(org_dataset, bs*grad_acc)
          for i in range(grpo_iter):
              for sub in range(grad_acc):
                  yield idx, sub, i, batch[sub*bs:(sub+1)*bs]
      ```
Then we have the training_step() and related utilities designed as:
```python=
@torch.no_grad()
def calc_cache(self, entries):
    # Generate GS samples per entry, then compute the group-wise advantages and
    # reference-model log probs once; these are reused across all GRPO iterations.
    groups = [self.generate_groups(entry, self.group_size) for entry in entries]
    advantages = torch.concat([self.group_score(group, entry) for group, entry in zip(groups, entries)])
    log_prob_ref = torch.concat([self.log_prob(group, ref=True) for group in groups])
    return [groups, advantages, log_prob_ref]

def training_step(self, batch, idx):
    bid, subid, grpo_iter, entries = batch
    if bid != self.bid:
        # New global batch: drop the caches from the previous one.
        self.bid = bid
        self.cache = [None] * self.grad_acc
        self.log_prob_old = [None] * self.grad_acc
    if self.cache[subid] is None:
        self.cache[subid] = self.calc_cache(entries)
    groups, advantages, log_prob_ref = self.cache[subid]

    # Per-token log probs under the current proxy model.
    log_prob = torch.concat([self.log_prob(group) for group in groups])
    if self.log_prob_old[subid] is None:
        # First GRPO iteration for this sub batch: freeze the "old" policy log probs.
        self.log_prob_old[subid] = log_prob.detach()
    log_prob_old = self.log_prob_old[subid]

    # Clipped surrogate objective (advantages are broadcast over each output's tokens).
    reward_scale = torch.exp(log_prob - log_prob_old)
    clipped_scale = torch.clamp(reward_scale, 1 - self.eta, 1 + self.eta)
    result = torch.minimum(reward_scale * advantages, clipped_scale * advantages)

    # k3 estimator of KL(proxy || ref), computed from log probs only.
    kl_div = torch.exp(log_prob_ref - log_prob) - (log_prob_ref - log_prob) - 1
    objective = result - self.beta * kl_div
    return -objective.mean()
```
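Finally, a rough sketch of how this could be wired into a Lightning run; GRPOModule and GRPOIterBatchLoader are hypothetical names for the module and dataloader wrapper above, not an existing API:
```python=
import pytorch_lightning as pl

module = GRPOModule(eta=0.2, beta=0.04, group_size=8, grad_acc=2)         # wraps training_step() above
loader = GRPOIterBatchLoader(org_dataset, bs=4, grad_acc=2, grpo_iter=4)  # wraps iter_batch() above

trainer = pl.Trainer(
    max_steps=1000,
    accumulate_grad_batches=2,  # matches the sub-index splitting in iter_batch()
)
trainer.fit(module, train_dataloaders=loader)
```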