# Notes on "Selective Replay Enhances Learning in Online Continual Analogical Reasoning"

###### tags: `continual learning` `analogical reasoning`

### Author
[Rishika Bhagwatkar](https://github.com/rishika2110)

## Introduction
* Analogical encoding has been shown to facilitate forward knowledge transfer and backward transfer for memory retrieval.
* The authors integrate both regularization and replay continual learning mechanisms into neural networks for analogical reasoning to establish baseline results.
* They train and test on the Relational and Analogical Visual rEasoNing (RAVEN) dataset. It contains objects arranged in structured patterns for Raven’s Progressive Matrices (RPMs) problems, which forces the models to perform both structural and analogical reasoning.

## Related Work
### Neural Networks for RPMs
Currently, Rel-Base and Rel-AIR have shown the best results on RAVEN and compete with each other on the Procedurally Generated Matrices (PGM) dataset. PGM was the first large-scale RPM-based dataset containing enough problems to successfully train deep neural networks.
* Rel-Base: An object encoder network processes the image panels individually. These encodings are then fed to a sequence encoder that extracts the relationships among them before they are scored.
* Rel-AIR: An Attend-Infer-Repeat (AIR) module is trained to extract objects from images. These objects are encoded and paired with additional position and scale information before being processed by a sequence encoder.

### Continual Learning in Neural Networks
* Methods adopted for mitigating catastrophic forgetting are:
    * Regularisation schemes: Constrain weight updates during gradient descent
    * Network expansion techniques: Add new parameters for learning new tasks
    * Replay mechanisms: Store representations of previously seen data to mix with the new data
* Several regularisation mechanisms seek to directly preserve important network parameters over time.
* Variants of distillation seek to preserve outputs at various locations in the network.
* For image classification, replay variants are the state-of-the-art approach for mitigating catastrophic forgetting. The most closely related task on which continual learning has been studied is VQA. However, its input is in natural language, and the model needs to mitigate several biases present in the dataset.

## Continual Learning Models and Baselines
1. Base Initialization Phase: Rel-Base is trained offline on the first task.
2. Then, each learner learns the remaining tasks one by one. Testing is done on completion of every task.
3. There are two kinds of models: batch models and incremental (streaming) models.

Methods used to enable continual learning on Rel-Base:
* Fine-Tune - Both settings
* Distillation - Batch setting
* Elastic Weight Consolidation - Batch setting
* Partial Replay - Streaming setting
* Cumulative Replay - Batch setting
* Offline - Batch setting

The Partial Replay model is trained in two phases: a base initialisation phase (trained offline) and a streaming phase. All of the base initialisation data is stored in the replay buffer. The streaming phase then starts, with the model learning on a mixture of the current sample $(S_i, y_i)$ and $r$ labelled samples stored in the replay buffer, selected on the basis of probabilities calculated by:
\begin{equation} p_i = \frac{v_i}{\sum_{v_j \in \mathcal{B}} v_j} \end{equation}
where $v_i$ is the value associated with choosing sample $S_i$ from buffer $\mathcal{B}$ for replay.
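
The sampling step described above can be sketched as follows. This is a minimal illustration and not the authors' implementation: the function names `replay_probabilities` and `sample_replay_indices` are hypothetical, drawing the $r$ samples without replacement is an assumption, and these notes do not specify how each policy turns its score into the value $v_i$.

```python
import random


def replay_probabilities(values):
    """Normalise non-negative selection values v_i into probabilities p_i."""
    total = sum(values)
    if total <= 0:
        raise ValueError("values must have a positive sum")
    return [v / total for v in values]


def sample_replay_indices(values, r, rng=None):
    """Draw up to r distinct buffer indices, each pick proportional to v_i.

    Assumption: sampling is without replacement, so after each pick the
    remaining values are implicitly renormalised.
    """
    rng = rng or random.Random()
    remaining = list(range(len(values)))
    chosen = []
    for _ in range(min(r, len(remaining))):
        weights = [values[i] for i in remaining]
        if sum(weights) > 0:
            pick = rng.choices(remaining, weights=weights, k=1)[0]
        else:
            # Every remaining value is zero: fall back to a uniform pick.
            pick = rng.choice(remaining)
        chosen.append(pick)
        remaining.remove(pick)
    return chosen
```

With `values = [1.0, 3.0]`, `replay_probabilities` gives the second sample three times the selection probability of the first, which matches the formula for $p_i$.
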

The following seven selective replay policies are studied:
* Uniform Random: Examples are selected from the memory buffer with uniform probability.
* Minimum Logit Distance: Samples are scored according to their distance to a decision boundary.
    \begin{equation} s_i = \sum_{j=1}^{K} |{\phi(\mathbf{S_i})}_j \space \mathbf{y}_j| \end{equation}
* Minimum Confidence: Samples are selected based on network confidence.
    \begin{equation} s_i = \sum_{j=1}^{K} \mathrm{softmax}(\phi(\mathbf{S_i}))_{j} \space \mathbf{y}_j = P(C = y_i | \mathbf{S_i}) \end{equation}
    Here $\mathbf{y}$ is the one-hot encoding of the label $y_i$.
* Minimum Margin:
    \begin{equation} s_i = P(C = y_i | \mathbf{S_i}) - \max_{y' \neq y_i} P(C = y' | \mathbf{S_i}) \end{equation}
* Maximum Loss:
    \begin{equation} s_i = -\sum_{j=1}^{K} \mathbf{y}_j \log P(C = j | \mathbf{S_i}) \end{equation}
* Maximum Time Since Last Replay: Samples that have not been replayed for a while might be forgotten, so it is essential to replay them. Samples are chosen based on the time they were last seen by the model.
* Minimum Replays: Samples that have been replayed very few times might not have been learnt by the model. Each sample's replay count is initialised to the number of epochs in base initialisation.

After the $r$ samples are chosen, the network updates on this batch of $r + 1$ samples for a single iteration, and the associated $s_i$ values of the $r + 1$ samples are subsequently updated.

## Metrics
An $R$ matrix is defined, $R \in \mathbb{R}^{T \times T}$, where $R_{i,j}$ denotes the test accuracy of the model on task $t_j$ after learning task $t_i$. $T = 7$ is the total number of tasks (Center, Out-InCenter, Left-Right, Up-Down, 2x2Grid, 3x3Grid, Out-InGrid).

The accuracy of the online model at time $i$ is given by $\gamma_{i} = \frac{1}{T}\sum_{j=1}^{T}R_{i,j}$.

The continual learner's performance is determined by $\Omega = \frac{1}{T}\sum_{i=1}^{T}\frac{\gamma_{i}}{\gamma_{offline, i}}$, where $\gamma_{offline, i}$ is the offline model's accuracy at time $i$.

Average accuracy is given by $A = \frac{2}{T(T+1)}\sum_{i\geq j}^{T}R_{i,j}$.

Forward transfer is given by $FWT = \frac{2}{T(T-1)}\sum_{i<j}^{T}R_{i,j}$.

Backward transfer is given by $BWT = \frac{2}{T(T-1)}\sum_{i=2}^{T}\sum_{j=1}^{i-1}\left(R_{i,j} - R_{j,j}\right)$. A negative $BWT$ implies that the model suffered from catastrophic forgetting, since accuracy on earlier tasks dropped after later tasks were learnt.

## Conclusion
1. Replay methods had the best global performance and the best backward/forward knowledge transfer.
2. All other sample selection methods performed better than the uniform random policy.
3. Replaying samples based on the Min Replays and Max Loss strategies yielded the best overall results.
4. Baseline models and metrics for continual analogical learning were presented.
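
As a sanity check on the definitions in the Metrics section, the following sketch computes them from an $R$ matrix stored as a list of rows. The function names and the `gamma_offline` argument (a list of per-step offline accuracies) are hypothetical. The sketch assumes that forward transfer averages $R_{i,j}$ over $i < j$ and that backward transfer averages $R_{i,j} - R_{j,j}$ over $j < i$, which is the convention under which a negative $BWT$ signals forgetting.

```python
def gamma(R, i):
    """Mean test accuracy over all T tasks after learning task i (0-indexed)."""
    return sum(R[i]) / len(R[i])


def omega(R, gamma_offline):
    """Online accuracy normalised by offline accuracy, averaged over steps."""
    T = len(R)
    return sum(gamma(R, i) / gamma_offline[i] for i in range(T)) / T


def average_accuracy(R):
    """Mean of the lower triangle of R, diagonal included (i >= j)."""
    T = len(R)
    total = sum(R[i][j] for i in range(T) for j in range(i + 1))
    return 2.0 * total / (T * (T + 1))


def forward_transfer(R):
    """Mean accuracy on tasks not yet learnt (i < j)."""
    T = len(R)
    total = sum(R[i][j] for i in range(T) for j in range(i + 1, T))
    return 2.0 * total / (T * (T - 1))


def backward_transfer(R):
    """Mean change in accuracy on task j after learning a later task i."""
    T = len(R)
    total = sum(R[i][j] - R[j][j] for i in range(1, T) for j in range(i))
    return 2.0 * total / (T * (T - 1))
```

Because `backward_transfer` subtracts the accuracy right after a task was learnt ($R_{j,j}$) from its accuracy at later steps, any drop on old tasks pushes the result below zero.
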