Labeled Memory Networks for Online Model Adaptation
===

**Title**: Labeled Memory Networks for Online Model Adaptation
**Authors**: Shiv Shankar, Sunita Sarawagi
**Year**: 2018
**Link**: <https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewFile/17141/16672>

> Augmenting a NN with memory that can grow without growing the number of trained parameters is a recent powerful concept... We establish their potential in online adapting a batch trained NN to domain-relevant labeled data.

> Introduces **LMN (Labeled Memory Network)**: a memory-augmented neural network (MANN) for fast online model adaptation.

- LMNs treat the memory as a second boosted stage following the trained network, which allows the memory and the network to play complementary roles (*what does this mean??*)
- Better memory utilization by writing only labelled data with non-zero loss.
- Better memory organization by using the discrete class label as the primary key.
- Uses an RNN to determine the weighting between the memory and the network.

The system is evaluated on sequence prediction and language modelling tasks.

## Introduction

- Batch training is infrequent and expensive. What if we want to train online? Examples:
    - text auto-completion
    - user trajectory prediction
    - extension of few-shot and zero-shot learning
- All the example tasks above involve identifying/predicting existing class labels.
    - *This paper extends this to handling new labels.*

Current approaches to "model adaptation" require re-tuning all or parts of the model on domain-labelled data in a separate adaptation phase.

- This is a one-time adaptation, separate from training.
- Using it in an online setting is very slow.
- Meta-learning on pre-trained networks tends to destroy previously learned useful information.

The alternative is memorization:

- The NN is augmented with a memory that can grow without increasing the number of trained parameters.
- Applying existing MANNs (NTMs and DNTMs) did not beat the baselines.
- One problem is balancing the roles of the trained network and the memory.
### Online Model Adaptation: Problem Description

- Online model adaptation kicks in after a batch-trained model is deployed.
- The deployed model sees new inputs, its predicted response, and the actual response. For example, it predicts which link a user will click and then observes which link the user actually clicks.
- The online adapter decides how to improve the next prediction using this limited "live" labelled data together with the existing pre-trained model.
- This has to be fast.

The LMN solution to these problems consists of 3 parts:

1. **Primary Classification Network** (PCN)
2. **Memory Module**
3. **Combiner Network**

### PCN

- Classification networks/models only.
- Assumes the input $x_t$ is converted to a real vector $h_t$ before going into the softmax layer. The softmax prediction of class $y$ for input $x_t$ is $r_{ty}$.

### Memory

Memory consists of $N$ cells. Each cell $m$ is a 3-tuple $(l_m, v_m, \alpha_m)$:

- $l_m$ is the label of the cell
- $v_m$ is the hidden vector stored in $m$
- $\alpha_m$ is the weight attached to the cell

How it works:

- $l_m$ acts as an index to enumerate all cells with label $y$.
- The memory provides a score for each class label $y$ given input $x_t$.
- The pre-softmax output of the PCN ($h_t$) serves as the embedding of $x_t$. With it, we compute a kernel between $h_t$ and each memory vector $v_m$.

![Memory read equations](https://i.imgur.com/lYwulpN.png)

:arrow_up: In the equations above :arrow_up:

- First compute the weights $w_{tm}$, then take their dot product with the cell vectors $v_m$ (i.e. a $w_{tm}$-weighted sum) to get $M_{ty}$, the vector read from memory.
- All sums run only over cells with the same label, i.e. cells with $l_m = y$.

With the "read" vector, you can compute the "memory score" $s_{ty}$:

![Memory score equation](https://i.imgur.com/wdwOtOT.png)

### Combining PCN score and Memory score

- The final score is a convex sum of the two scores, with $\theta_{ty}$ as the mixing weight.
- $\theta_{ty}$ is obtained from the output of an RNN followed by a sigmoid layer.

## Training

The memory module is set up on top of a pre-trained PCN.
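Because the memory sits on top of the pre-trained PCN's embedding $h_t$, one prediction step (memory read plus combination) can be sketched in plain Python. This is a minimal sketch, not the authors' code: the exact equations are only in the images above, so the weight form $w_{tm} = \alpha_m \exp(\cos(h_t, v_m))$, the softmax over $h_t \cdot M_{ty}$ as the memory score, and a single fixed scalar `theta` standing in for the combiner RNN's $\theta_{ty}$ are all assumptions. The names `Cell`, `memory_scores` and `combine` are hypothetical.

```python
import math
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Cell:
    label: str           # l_m: class label, used as the primary key
    vector: List[float]  # v_m: stored hidden vector (a PCN embedding)
    weight: float        # alpha_m: weight attached to the cell


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: List[float], b: List[float]) -> float:
    na, nb = math.sqrt(dot(a, a)), math.sqrt(dot(b, b))
    return dot(a, b) / (na * nb) if na > 0 and nb > 0 else 0.0


def memory_scores(h: List[float], memory: List[Cell], labels: List[str]) -> Dict[str, float]:
    """Memory score s_ty for every label y, given the PCN embedding h = h_t."""
    logits = {}
    for y in labels:
        read = [0.0] * len(h)  # M_ty, the vector read from memory for label y
        for cell in memory:
            if cell.label != y:  # the label is an index: only cells with l_m == y count
                continue
            w = cell.weight * math.exp(cosine(h, cell.vector))  # ASSUMED form of w_tm
            read = [r + w * v for r, v in zip(read, cell.vector)]
        logits[y] = dot(h, read)  # ASSUMED: score the read vector against h_t
    top = max(logits.values())  # subtract the max for a numerically stable softmax
    exps = {y: math.exp(z - top) for y, z in logits.items()}
    total = sum(exps.values())
    return {y: e / total for y, e in exps.items()}


def combine(r: Dict[str, float], s: Dict[str, float], theta: float) -> Dict[str, float]:
    """Convex sum of PCN probabilities r_ty and memory scores s_ty.

    theta is a fixed scalar here; in LMN it comes from the combiner RNN + sigmoid.
    """
    return {y: theta * r[y] + (1.0 - theta) * s[y] for y in r}


# Toy usage: the PCN prefers "b", but h_t is much closer to the stored "a" cell.
memory = [Cell("a", [1.0, 0.0], 1.0), Cell("b", [0.0, 1.0], 0.5)]
h = [0.9, 0.1]
r = {"a": 0.3, "b": 0.7}  # PCN softmax output r_ty
s = memory_scores(h, memory, list(r))
p = combine(r, s, theta=0.5)
print(max(p, key=p.get))  # a
```

In the toy run the memory score tips the combined prediction from `b` to `a`. Since both inputs are distributions over the same labels, the convex sum is also a valid distribution, which is one reason a loose convex coupling is easy to reason about.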
The full network is then trained in an online setting, with data provided sequentially.

## Adapting to True Label $y_t$

1. Update the state of the combiner RNN.
2. Write to memory **if needed**.

### When to write

- Only when the true-label probability fails to beat the closest *wrong*-label probability by at least a margin, i.e. when the margin loss is non-zero. In other words, only write when the model is misled (or nearly so). By contrast, MANNs write every example and fill up the memory.

### What to write

Both $v_m$ and $\alpha_m$ of the cells are updated based on the similarity of $h_t$ to the cell contents.

*IF PREDICTION IS INCORRECT, WE ATTEMPT TO CREATE A NEW CELL BY REPLACING AN EXISTING ONE.*

### What cell to replace

- MANNs use an LRU (least recently used) policy; this is wasteful and can also evict good cells.
- LMN replaces the cell with the smallest weight among all cells with the same label. The idea is to replace the cell that is least useful for classification. (**POSSIBLE IDEA: THIS IS ALSO JUST A HEURISTIC; WE COULD ADD ANOTHER PARAM TO DECAY ON THIS OR USE SOMETHING LIKE ATTENTION**)

## Related Work

**Diffs from MANNs**

- LMN updates memory only when the margin loss is non-zero.
- In MANNs, memory reads are fed into the network itself for prediction. LMN instead keeps the two `loosely coupled` through a convex sum of scores.
- MANNs use an LRU policy for replacing cells; LMN replaces the lowest-weighted cell with the same label.

## Experiments

LMN was tested on:

- online sequence prediction (user trajectory)
- online image classification (with labels unseen during training appearing at test time)
- language modeling

It significantly outperforms the baselines.

- Note that in some instances the authors implemented the baselines themselves instead of using the official code release (possibly because it was unavailable or the published results were on a different dataset).
- Online image classification results were shockingly good.

## Thoughts

I would like to spend more time understanding the link between the memory and the neural network, because it seems they have achieved exactly that.
In some sense, they are going over the entire batch of "online data" when they sum over all cells with the same label as $y_t$, so it's not really "online" in the "online SGD" sense.

- The update/prediction is not based on a concise summary statistic that captures the current state. Since there are multiple cells for the same label, you have access to a lot more information.
- The memory score is like a weighted version of all scores, weighted by the similarity of the current embedding $h_t$ to the stored vectors.
- Is $v_m$ for a cell learned via backprop? Not clear; will check the code.
- It is also not clear how there can be multiple cells for the same label, since the section `Adapting to true label` suggests that cells are replaced. Will check the code.
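To reason about the last question, here is a toy sketch of the write path from *Adapting to True Label*. It is not the authors' code. The margin test follows these notes; the update rule for correct-but-unconfident predictions, the initial weight of 1.0, and the per-label budget `cells_per_label` are hypothetical. The budget is just one way several cells per label could coexist before replacement kicks in. `should_write` and `write` are illustrative names.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Cell:
    label: str           # l_m
    vector: List[float]  # v_m
    weight: float        # alpha_m


def should_write(p: Dict[str, float], y_true: str, margin: float) -> bool:
    """True when the margin loss is non-zero: the true label does not beat
    the closest wrong label by at least `margin`."""
    best_wrong = max(v for y, v in p.items() if y != y_true)
    return p[y_true] - best_wrong < margin


def write(memory: List[Cell], h: List[float], y_true: str, p: Dict[str, float],
          margin: float = 0.1, cells_per_label: int = 2) -> str:
    if not should_write(p, y_true, margin):
        return "skip"  # confidently correct: memory is left untouched
    same = [c for c in memory if c.label == y_true]
    if max(p, key=p.get) == y_true and same:
        # Correct but inside the margin: nudge the most similar same-label cell.
        # HYPOTHETICAL update rule; the paper's actual v_m / alpha_m update is not reproduced.
        cell = max(same, key=lambda c: sum(a * b for a, b in zip(c.vector, h)))
        cell.vector = [(a + b) / 2 for a, b in zip(cell.vector, h)]
        cell.weight += 1.0
        return "update"
    new_cell = Cell(y_true, list(h), 1.0)  # HYPOTHETICAL initial weight
    if len(same) < cells_per_label:
        # HYPOTHETICAL per-label budget: a label may own several cells
        memory.append(new_cell)
        return "append"
    # Replace the lowest-weight cell among those with the same label.
    victim = min(same, key=lambda c: c.weight)
    idx = next(i for i, c in enumerate(memory) if c is victim)
    memory[idx] = new_cell
    return "replace"


memory = [Cell("a", [1.0, 0.0], 3.0), Cell("a", [0.8, 0.2], 0.5)]
results = [
    write(memory, [0.0, 1.0], "b", {"a": 0.9, "b": 0.1}),  # "append": first cell for "b"
    write(memory, [0.7, 0.3], "a", {"a": 0.9, "b": 0.1}),  # "skip": margin satisfied
    write(memory, [0.1, 0.9], "a", {"a": 0.2, "b": 0.8}),  # "replace": weight-0.5 "a" cell evicted
]
print(results)
```

Under this assumption, replacement only happens once a label's slots are full, so "cells are replaced" and "multiple cells per label" are not in conflict. Whether the paper bounds cells per label this way is worth confirming in the authors' code.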