<style> img { display: block; margin-left: auto; margin-right: auto; } </style>

> [Paper link](https://arxiv.org/abs/2308.02223) | [Code link](https://github.com/wangclnlp/DeepSpeed-Chat-Extension/tree/main/examples/esrl) | AAAI 2024

:::success
**Thoughts**
Reinforcement Learning (RL) for sequence generation models can be computationally expensive. This study proposes an efficient RL method (ESRL) built on two-stage sampling and dynamic sampling.
:::

## Abstract
Applying Reinforcement Learning (RL) to sequence generation models allows the optimization of long-term rewards, such as BLEU scores and human feedback. However, it often demands extensive sampling of action sequences, which is computationally expensive due to the large action space and long sequences typical of machine translation. In this study, the authors introduce two-stage sampling and dynamic sampling methods to improve sampling efficiency when training these models.

## Background
RL for training sequence generation models has gained attention, but applying it in NLP is challenging because of the **large action space** and **long sequences**. Below is an illustration of the traditional RL loss calculation.

![image](https://hackmd.io/_uploads/H1ih0Pi5R.png)

To address these challenges, the authors explore strategies that reduce the computational load of exploration in RL for sequence generation models.

## Method
In sequence generation, given an input $x$, the model generates a sequence of $N$ tokens $y = \{ y_1, \dots, y_N \}$. During training, the model learns the probability:

$$
p_\theta(y \mid x) = \prod_{t=1}^N p_\theta (y_t \mid y_{<t}, x)
$$

At inference time, tokens are generated sequentially according to $p_\theta$. The RL loss for a training instance is:

$$
\mathcal{L}_{\mathrm{RL}} = \sum_{\hat{y} \in \Omega(x)} p_\theta(\hat{y} \mid x) r(\hat{y})
$$

where $\Omega(x)$ is the output space comprising all possible candidate target sequences for input $x$. Since $\Omega(x)$ is exponentially large, this sum has to be approximated by sampling candidate sequences from the model, and this sampling is exactly the expensive step that ESRL targets.

---

Their method, Efficient Sampling-based RL (ESRL), makes this exploration efficient:

1. A **two-stage sampling** framework implements the exploration.
2. A **dynamic sampling** approach reduces redundant sampling by taking the model's capability into account.

![image](https://hackmd.io/_uploads/ryTpCPicA.png)

### Two-stage Sampling
To avoid the excessive memory cost of storing the computational graph built during sampling, they split exploration into two stages. In the first stage, they sample candidate sequences autoregressively from the model without tracking gradients. In the second stage, they compute the probabilities of these sampled sequences in a forward pass with gradients, so the graph only has to be kept for this scoring step.

### Dynamic Sampling
This study also proposes dynamic sampling to further improve the efficiency of RL training. They first estimate the model's capability and then adjust the sampling size and temperature based on this estimate, allowing for efficient yet adequate sampling.

### Optimization
This study replaces the standard policy-gradient method with a fusion of MRT (Minimum Risk Training) and REINFORCE when computing the loss.

## Experiments
Their experiments cover three tasks, all based on a standard **Transformer** model:

1. Machine Translation
2. Abstractive Summarization
3. RLHF

> Machine Translation

![image](https://hackmd.io/_uploads/SyZJyOjcC.png)

> Abstractive Summarization

![image](https://hackmd.io/_uploads/SJ7D2Os9A.png)
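---

Since the two-stage idea carries most of the efficiency gain, here is a minimal toy sketch of it in PyTorch. This is my own illustrative re-implementation, not the authors' code: `TinyDecoder`, `stage1_sample`, `stage2_logprobs`, and the length-based reward are all stand-ins chosen to keep the script self-contained; the paper would use a real translation/summarization model and rewards such as BLEU or a reward model.

```python
"""Toy sketch of two-stage sampling for RL training (assumptions noted above)."""
import torch
import torch.nn as nn
import torch.nn.functional as F

VOCAB, HIDDEN, MAX_LEN, BOS, EOS = 32, 64, 12, 0, 1

class TinyDecoder(nn.Module):
    """A tiny GRU language model standing in for the sequence generation model."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.rnn = nn.GRU(HIDDEN, HIDDEN, batch_first=True)
        self.out = nn.Linear(HIDDEN, VOCAB)

    def forward(self, tokens, h=None):
        # tokens: (batch, time) -> logits: (batch, time, vocab)
        x = self.embed(tokens)
        y, h = self.rnn(x, h)
        return self.out(y), h

def stage1_sample(model, batch_size, temperature=1.0):
    """Stage 1: autoregressive sampling with no gradient tracking,
    so no computational graph is stored during exploration."""
    with torch.no_grad():
        tokens = torch.full((batch_size, 1), BOS, dtype=torch.long)
        h = None
        for _ in range(MAX_LEN):
            logits, h = model(tokens[:, -1:], h)
            probs = F.softmax(logits[:, -1] / temperature, dim=-1)
            nxt = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, nxt], dim=1)
    return tokens  # (batch, MAX_LEN + 1), detached by construction

def stage2_logprobs(model, tokens):
    """Stage 2: a single forward pass over the sampled sequences, this time
    with gradients, to get each sequence's log-probability."""
    logits, _ = model(tokens[:, :-1])  # predict token t from tokens < t
    logp = F.log_softmax(logits, dim=-1)
    picked = logp.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return picked.sum(dim=-1)  # (batch,)

model = TinyDecoder()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

samples = stage1_sample(model, batch_size=8)
# Dummy sequence-level reward (count of EOS tokens); a real setup would use
# BLEU, ROUGE, or a reward model as in the paper.
rewards = (samples == EOS).float().sum(dim=-1)
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)  # simple baseline

seq_logp = stage2_logprobs(model, samples)
loss = -(rewards * seq_logp).mean()  # REINFORCE-style objective
opt.zero_grad()
loss.backward()
opt.step()
```

The point of the split is that the sampling loop in stage 1 never builds a graph, while stage 2 recovers the gradients with one teacher-forced pass over the already-sampled tokens; dynamic sampling would additionally adjust `batch_size` (sampling size) and `temperature` per instance based on an estimate of the model's capability.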