---
title: "Addressing Distribution Shift in Online Reinforcement Learning with Offline Datasets - Notes"
tags: "IvLabs, RL"
---

# Addressing Distribution Shift in Online Reinforcement Learning with Offline Datasets

Link to the [Research Paper](https://offline-rl-neurips.github.io/pdf/13.pdf)

{%pdf https://offline-rl-neurips.github.io/pdf/13.pdf%}

Strong RL agents can be trained on previously collected, static datasets (offline RL), but it is often desirable to improve such offline RL agents with further online interaction.

## Introduction

Offline RL agents may be suboptimal:
- the dataset they were trained on may contain only suboptimal data
- the environment in which the agent is deployed may differ from the environment in which the dataset was generated

This calls for fine-tuning.

Challenges
- Offline RL algorithms based on modeling the dataset-generating policy are not amenable to fine-tuning
  - due to the difficulty of modeling the dataset-generating policy in the online setup
- Conservative Q-Learning (CQL)
  - does not require explicit behaviour modeling
  - is therefore amenable to fine-tuning
- In a non-trivial task, however, fine-tuning suffers from distribution shift
  - the agent encounters out-of-distribution samples
  - and loses its good initial policy
- This can be attributed to bootstrapping error
  - the error introduced when the Q-function is updated with an inaccurate target value evaluated at unfamiliar states and actions

The appeal of offline RL lies in safe deployment at test time.

### Contribution

- Demonstration that fine-tuning a CQL agent leads to unstable training
- A balanced replay scheme and an ensemble distillation scheme
- Balanced replay
  - separate replay buffers for offline and online samples
  - modulate the sampling ratio to balance the effect of
    - widening the data distribution the agent sees (offline data)
    - exploiting the environment feedback (online data)
- Ensemble distillation
  - learn an ensemble of independent CQL agents
  - distill the multiple policies into a single policy
  - improve the policy using the mean of the Q-functions, so policy updates are more
    robust to error in each individual Q-function

## Background

### Reinforcement Learning

Off-policy RL algorithms
- train an agent with samples generated by any behaviour policy
- are well suited for fine-tuning a pre-trained RL agent, since they can leverage both offline and online samples

### Soft Actor-Critic

- Off-policy actor-critic algorithm that learns a soft Q-function $Q_\theta(s,a)$ parameterized by $\theta$ and a stochastic policy $\pi_\phi$, modeled as a Gaussian with parameters $\phi$
- SAC alternates between soft policy evaluation and soft policy improvement

Soft policy evaluation: $\theta$ is updated to minimize

$$\mathcal L_{\text{critic}}^{\text{SAC}}(\theta) = \Bbb E_{\tau_t\sim\mathcal B}[\mathcal L_Q^{\text{SAC}}(\tau_t,\theta)]$$

$$\mathcal L_Q^{\text{SAC}}(\tau_t,\theta) = \Big(Q_\theta(s_t, a_t) - \big(r_t + \gamma\, \Bbb E_{a_{t+1}\sim\pi_\phi}[Q_{\bar\theta}(s_{t+1}, a_{t+1}) - \alpha \log \pi_\phi(a_{t+1}|s_{t+1})]\big)\Big)^2$$

where
- $\tau_t = (s_t, a_t, r_t, s_{t+1})$ is a transition
- $\mathcal B$ is the replay buffer
- $\bar\theta$ is the moving target of the soft Q-function parameter $\theta$
- $\alpha$ is the temperature parameter

Soft policy improvement: $\phi$ is updated to minimize

$$\mathcal L_{\text{actor}}^{\text{SAC}}(\phi) = \Bbb E_{s_t\sim\mathcal B}[\mathcal L_\pi^{\text{SAC}}(s_t,\phi)]$$

$$\mathcal L_\pi^{\text{SAC}}(s_t,\phi) = \Bbb E_{a_t\sim\pi_\phi}[\alpha \log \pi_\phi(a_t|s_t) - Q_\theta(s_t, a_t)]$$

### Conservative Q-Learning

CQL
- an offline RL algorithm that learns a lower bound of the Q-function $Q_\theta(s,a)$
- this prevents extrapolation error, i.e. value overestimation caused by bootstrapping from out-of-distribution actions

CQL($\mathcal H$)
- imposes a regularization that minimizes the expected Q-value at unseen actions and maximizes the expected Q-value at seen actions

## Challenge: Distribution Shift

- Distribution shift: the offline RL agent encounters data distributed away from the offline data when interacting with the environment
  - it involves an interplay between actor and
    critic updates with newly collected out-of-distribution samples
  - it occurs because the online data distribution differs from the offline one
- When using both online and offline data, sampled uniformly
  - the chance that the agent sees online samples in an update becomes too low, because offline samples far outnumber them
  - this prevents timely updates at the unfamiliar states encountered online
- When using online data exclusively
  - the agent is exposed only to unseen samples, for which the Q-function does not provide a reliable value estimate
  - this causes bootstrapping error

There is a need to balance this trade-off.

## BRED: Balanced Replay with Ensemble Distillation

- Addresses the distribution shift
- Separate offline and online replay buffers
  - to select a balanced mix of samples
  - Advantages
    - the Q-function is updated with a wide distribution of samples
    - Q-values are updated at novel, unseen states from online interaction
- Multiple actor-critic models are trained together and their policies are distilled into a single policy
  - this distilled policy is then improved via the Q-ensemble

### Balancing experiences from online and offline replay buffers

At timestep $t$, $B\cdot(1-\rho_t^{\text{on}})$ samples are drawn from the offline replay buffer and $B\cdot\rho_t^{\text{on}}$ samples from the online replay buffer, where $B$ is the batch size and $\rho_t^{\text{on}}$ is the fraction of online samples:

$$\rho_t^{\text{on}} = \rho_0^{\text{on}} + (1-\rho_0^{\text{on}})\cdot\frac{\min(t,t_{\text{final}})}{t_{\text{final}}}$$

- $t_{\text{final}}$ is the final timestep of the annealing schedule
- $\rho_0^{\text{on}}$ is the initial fraction of online samples

Since $\min(t, t_{\text{final}})/t_{\text{final}} = 1$ for $t \ge t_{\text{final}}$, the fraction rises linearly from $\rho_0^{\text{on}}$ to $1$, after which every sample comes from the online buffer.

Effect
- Early on, better Q-function updates with a wide distribution of both offline and online samples
- Later, exploiting the online samples once enough of them have been gathered

### Ensemble of offline RL agents for online fine-tuning

- During fine-tuning, each individual Q-function may be inaccurate due to bootstrapping error from unfamiliar online samples
- Consider an ensemble of $N$ CQL agents pre-trained on the offline dataset
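Why an ensemble should help can be checked with a toy simulation. This is a minimal sketch under a loudly simplified assumption: each of the $N$ critics estimates a hypothetical true Q-value with independent, zero-mean Gaussian error. Under that assumption, the mean of the $N$ estimates has roughly $1/N$ of the squared error of a single critic. Real CQL critics share data and architecture, so their errors are correlated and the gain is smaller. The helper `ensemble_errors` and all numbers in it are hypothetical.

```python
import random
from statistics import fmean


def ensemble_errors(n_critics=5, n_pairs=10_000, noise_std=1.0, seed=0):
    """Return (MSE of a single critic, MSE of the ensemble mean).

    Toy assumption (not from the paper): every critic's estimate equals a
    hypothetical true Q-value plus independent zero-mean Gaussian noise.
    """
    rng = random.Random(seed)
    single_sq_errors = []
    ensemble_sq_errors = []
    for _ in range(n_pairs):
        q_true = rng.uniform(-10.0, 10.0)  # hypothetical true Q(s, a)
        estimates = [q_true + rng.gauss(0.0, noise_std) for _ in range(n_critics)]
        # Squared error of each individual critic.
        single_sq_errors.extend((q - q_true) ** 2 for q in estimates)
        # Squared error of the ensemble mean (1/N) * sum_i Q_i(s, a).
        ensemble_sq_errors.append((fmean(estimates) - q_true) ** 2)
    return fmean(single_sq_errors), fmean(ensemble_sq_errors)


single_mse, ensemble_mse = ensemble_errors(n_critics=5)
print(f"single critic MSE: {single_mse:.3f}  ensemble-mean MSE: {ensemble_mse:.3f}")
```

With $N = 5$, as in the experiments, the ensemble-mean error comes out close to one fifth of the single-critic error; correlated errors between critics would shrink this gap.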
Before online interaction, the ensemble of $N$ independent policies is distilled into a single policy $\pi_{\phi_{\text{pd}}}$ by minimizing

$$\mathcal L_{\text{distill}}^{\text{pd}}(\phi_{\text{pd}}) = \Bbb E_{s_t\sim\mathcal D}\left[\|\mu_{\phi_{\text{pd}}}(s_t) - \hat\mu(s_t)\|^2 + \|\sigma_{\phi_{\text{pd}}}(s_t) - \hat\sigma(s_t)\|^2\right]$$

where $\mathcal D$ is the offline dataset and $\hat\mu$, $\hat\sigma^2$ are the mean and variance of the equally weighted mixture of the $N$ Gaussian policies:

$$\hat\mu(s_t) = \frac1N\sum_{i=1}^{N}\mu_{\phi_i}(s_t)$$

$$\hat\sigma^2(s_t) = \frac1N\sum_{i=1}^{N}\left(\sigma^2_{\phi_i}(s_t) + \mu^2_{\phi_i}(s_t)\right) - \hat\mu^2(s_t)$$

The distilled policy is then updated by minimizing

$$\mathcal L_{\text{actor}}^{\text{pd}}(\phi_{\text{pd}}) = \Bbb E_{s_t\sim\mathcal B}[\mathcal L_\pi^{\text{pd}}(s_t,\phi_{\text{pd}})]$$

$$\mathcal L_\pi^{\text{pd}}(s_t,\phi_{\text{pd}}) = \Bbb E_{a_t\sim\pi_{\phi_{\text{pd}}}}\left[\alpha \log \pi_{\phi_{\text{pd}}}(a_t|s_t) - \frac1N\sum_{i=1}^{N} Q_{\theta_i}(s_t, a_t)\right]$$

Each $Q_{\theta_i}$ has its own separate target Q-function $Q_{\bar\theta_i}$, and the critic losses are minimized independently to ensure diversity among the Q-functions.

## Related Work

Offline RL
- [CQL](https://dl.acm.org/doi/pdf/10.5555/3495724.3495824)

Online RL with offline datasets
- [Optimality of Dataset](https://arxiv.org/pdf/2006.09359.pdf)

Replay Buffer
- [Hard Exploration Problem](https://arxiv.org/pdf/1707.08817.pdf)
- [Continual Learning](https://papers.nips.cc/paper/2019/file/fa7cdfad1a5aaf8370ebeda47a1ff1c3-Paper.pdf)

Ensemble Methods
- Addressing the Q-function's [overestimation bias](https://papers.nips.cc/paper/2010/file/091d584fced301b442654dd8c23b3fc9-Paper.pdf)
- [Better exploration and reducing bootstrap error propagation](https://arxiv.org/pdf/2002.06487.pdf)

## Experiments

### Setups

#### Tasks and Implementation

MuJoCo
- `halfcheetah`, `hopper`, `walker2d`

D4RL dataset types
- `random`
- `medium`
- `medium-replay`
- `medium-expert`

Offline RL agent
- trained for 1000 epochs without early stopping
- $N = 5$
- $\rho_0^{\text{on}} \in \{0.5,0.75\}$
- $t_{\text{final}} = 125K$
- Report mean and standard deviation across four runs for 250K timesteps

#### Baselines

- Advantage-Weighted Actor Critic (AWAC)
  - actor-critic scheme for fine-tuning
  - the policy is trained to imitate actions with high advantage estimates
  - comparing BRED to AWAC shows the benefit of exploiting the generalization ability of the Q-function for policy learning
- Batch-Constrained deep Q-learning (BCQ-ft)
  - offline RL method
  - updates the policy by modeling the data-generating policy using a conditional [VAE](https://papers.nips.cc/paper/2015/file/8d55a249e6baa5c06772297520da2051-Paper.pdf)
- CQL
  - CQL-trained agent, fine-tuned with SAC
  - excludes the CQL regularization
- SAC
  - SAC agent trained from scratch
  - no access to the offline dataset
  - shows the benefit of fine-tuning a pre-trained agent in terms of sample efficiency
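The balanced replay schedule from the BRED section can be sketched with the setup values listed above ($\rho_0^{\text{on}} \in \{0.5, 0.75\}$, $t_{\text{final}} = 125K$). This is a minimal sketch, not the authors' implementation. It assumes the buffers are plain Python lists of transitions sampled uniformly with replacement, rounds $B\cdot\rho_t^{\text{on}}$ to the nearest integer, and uses a hypothetical batch size $B = 256$; `online_fraction` and `sample_balanced_batch` are hypothetical helper names.

```python
import random


def online_fraction(t, rho0, t_final):
    """Annealed online fraction: rho_t = rho0 + (1 - rho0) * min(t, t_final) / t_final."""
    return rho0 + (1.0 - rho0) * min(t, t_final) / t_final


def sample_balanced_batch(offline_buffer, online_buffer, batch_size, t,
                          rho0, t_final, rng=None):
    """Draw B * (1 - rho_t) offline and B * rho_t online transitions.

    Hypothetical assumptions (not from the paper): buffers are Python lists,
    sampling is uniform with replacement, and B * rho_t is rounded to the
    nearest integer. The online buffer must be non-empty when n_online > 0.
    """
    if rng is None:
        rng = random.Random()
    rho_t = online_fraction(t, rho0, t_final)
    n_online = round(batch_size * rho_t)
    n_offline = batch_size - n_online
    # Offline part first, then online part, concatenated into one batch.
    batch = (rng.choices(offline_buffer, k=n_offline)
             + rng.choices(online_buffer, k=n_online))
    return batch, n_offline, n_online


# Usage with toy buffers: tagged tuples stand in for transitions.
offline = [("offline", i) for i in range(1_000)]
online = [("online", i) for i in range(100)]
rng = random.Random(0)
for t in (0, 62_500, 125_000, 250_000):
    _, n_off, n_on = sample_balanced_batch(
        offline, online, batch_size=256, t=t,
        rho0=0.5, t_final=125_000, rng=rng)
    print(f"t={t:>7}: {n_off} offline + {n_on} online")
```

With $\rho_0^{\text{on}} = 0.5$, half of each batch is online at $t = 0$, three quarters at $t = 62.5K$, and the whole batch from $t_{\text{final}}$ onward, so the second half of the 250K fine-tuning timesteps uses online samples only.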