# Notes on Recurrent Independent Mechanisms (RIMs)
#### Author: [Sharath Chandra](https://sharathraparthy.github.io/)
## [Paper Link](https://arxiv.org/pdf/1909.10893.pdf)
This paper aims to learn modular structures in order to better model the world and generalize to out-of-distribution tasks. The authors propose a new recurrent architecture in which "small/sub" recurrent units (or mechanisms in general) follow their own dynamics and interact with each other only sparingly.
## Introduction
Let's examine the following example. Consider two researchers, Alice and Bob, who specialize in machine learning and economics respectively. These two people work independently on their own interests, so we can view them as two independent mechanisms, and the overall system therefore has a modular structure. But when Alice needs to solve an economics problem, she collaborates with Bob to get the job (project) done. This can be viewed as two independent mechanisms communicating sparsely, and it is the core idea that this paper exploits.
If we look at current deep learning architectures, we can see that every neuron shares information with every other neuron; for large neural networks this places a heavy burden on the weight matrices. But most physical processes in the world have a modular structure, and the human brain does a great job of maintaining this modularity, letting modules communicate with one another only when necessary. Why can't we bring that structure to current deep learning systems?
This is the main motivation behind the paper: the authors design a way to modularize an RNN into smaller RNNs that carry out their own independent dynamics and interact sparsely only when required. They call this new architecture "Recurrent Independent Mechanisms" (RIMs).
## RIMs with sparse interactions
![RIMs in one picture](https://i.imgur.com/9CeTN6k.png)
The idea is to divide the overall model into $k$ subsystems called RIMs, where each RIM has its own function that is learned from the data during training. A RIM is activated only when the input is relevant to it, and this gating is implemented with an attention mechanism.
This input-attention process can be understood from a programming-languages perspective: think of each RIM as a function that accepts only arguments of a specific type. When an argument's type matches, a binding happens and the function carries out the steps it is coded for. Here, the binding happens through the input attention mechanism, which looks at a candidate input object's key and evaluates whether its "type" matches what that particular RIM expects. If it matches, that RIM gets activated.
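To make the modular structure concrete, here is a minimal PyTorch sketch (not the authors' code) of holding $k$ recurrent cells with completely separate parameters; all sizes are made-up placeholders, and each RIM will later consume an attended version of the input rather than the raw input.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
num_rims   = 6    # k independent mechanisms
hidden_dim = 100  # hidden state size of each RIM
attn_dim   = 32   # size of the attended input each RIM receives (see below)

# Each RIM gets its own LSTMCell, so no parameters are shared between mechanisms.
rim_cells = nn.ModuleList(
    [nn.LSTMCell(attn_dim, hidden_dim) for _ in range(num_rims)]
)

# One hidden (and cell) state per RIM, kept as (k, hidden_dim) tensors.
h = torch.randn(num_rims, hidden_dim)
c = torch.zeros(num_rims, hidden_dim)
```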
More formally, at each step the top-k RIMs are selected based on how strongly they attend to the real input. The input $x_t$ at time $t$ is treated as a set of elements arranged as the rows of a matrix. A row of zeros (a null input) is then concatenated to obtain $X = \Phi \bigoplus x_t$. The keys, values and queries are computed by linear transformations:
$$
\begin{aligned}
K &= XW^e && \text{(one key per element)} \\
V &= XW^v && \text{(one value per element)} \\
Q &= h_t W_k^q && \text{(one query per RIM)}
\end{aligned}
$$
where $W^e$ is a learnable weight matrix mapping an input element to its key vector, $W^v$ is a learnable weight matrix mapping an input element to its value vector, and $W_k^q$ is a per-RIM weight matrix mapping RIM $k$'s hidden state to its query. The attention score is computed by
\begin{equation}
A_k^{(in)} = \text{softmax} \left(\frac{h_t W_k^q (XW^e)^T}{d_e^{1/2}} \right)X W^v
\end{equation}
Based on these scores, the top-k RIMs with the least attention on the null input are selected. The intuition behind concatenating the zero row is to provide a threshold that separates the parts of the input a RIM attends to strongly from the parts it mostly ignores. RIMs use transformer-style multi-headed attention; this does not change the computation, the final scores are simply averaged over the heads.
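As a rough illustration of the input-attention and selection step (single attention head, a single input element, placeholder dimensions; the names `W_e`, `W_v`, `W_q` mirror $W^e$, $W^v$, $W_k^q$ above, but this is only a sketch, not the reference implementation):

```python
import torch.nn.functional as F

input_dim, d_e, top_k = 64, 32, 4  # placeholder sizes; attn_dim above plays the role of the value dimension

W_e = nn.Linear(input_dim, d_e, bias=False)       # shared key projection W^e
W_v = nn.Linear(input_dim, attn_dim, bias=False)  # shared value projection W^v
W_q = nn.Parameter(torch.randn(num_rims, hidden_dim, d_e) * 0.02)  # per-RIM query projections W_k^q

x_t = torch.randn(1, input_dim)  # a single input element at time t

# Concatenate a null (all-zero) row so a RIM can choose to "attend to nothing".
X = torch.cat([torch.zeros(1, input_dim), x_t], dim=0)  # (2, input_dim)

K = W_e(X)                              # (2, d_e)       one key per element
V = W_v(X)                              # (2, attn_dim)  one value per element
Q = torch.einsum('kh,khd->kd', h, W_q)  # (k, d_e)       one query per RIM

scores = F.softmax(Q @ K.T / d_e ** 0.5, dim=-1)  # (k, 2) attention over {null, x_t}
A_in   = scores @ V                               # (k, attn_dim) attended input per RIM

# Activate the RIMs that pay the LEAST attention to the null element (column 0),
# i.e. those whose "type" matches the real input best.
active = torch.topk(-scores[:, 0], top_k).indices
```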
After activation, no information passes between the activated RIMs; each activated RIM follows its own default dynamics, implemented with a GRU or an LSTM. The hidden states of the RIMs that are not activated remain unchanged ($h_{t+1, k} = h_{t, k}$). The active RIMs are updated by
\begin{equation}
\tilde{h}_{t, k} = \text{LSTM}(h_{t, k}, A_k^{(in)}; \theta_{k}) \ \ \forall k \in \mathcal{S_t}
\end{equation}
where $\mathcal{S_t}$ denotes the set of activated RIMs and $A_k^{(in)}$ is the relevant input computed by the attention mechanism discussed earlier.
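Continuing the sketch above, the default-dynamics update could look like the following; only the activated RIMs are stepped, each with its own `LSTMCell`, while the rest keep their previous state. Again, this is just an illustration of the update rule, not the authors' implementation.

```python
# h, c, A_in, active and rim_cells carry over from the previous snippets.
h_tilde, c_new = h.clone(), c.clone()
for k in active.tolist():
    # Each activated RIM runs its own LSTM on its attended input A_k^{(in)}.
    h_k, c_k = rim_cells[k](A_in[k].unsqueeze(0),
                            (h[k].unsqueeze(0), c[k].unsqueeze(0)))
    h_tilde[k], c_new[k] = h_k.squeeze(0), c_k.squeeze(0)
# Inactive RIMs are untouched, so h_{t+1,k} = h_{t,k} for k not in S_t.
```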
Each RIM then carries out its own dynamics, and for the sparse communication step the authors again employ an attention bottleneck. During this communication step they use a separate set of weight matrices $\theta_k^{(c)} = (\tilde{W}_k^q, \tilde{W}_k^e, \tilde{W}_k^v)$ to compute the attention scores.
\begin{equation}
h_{t+1, k} = \text{softmax}\left(\frac{Q_{t, k}(K_{t, :})^T}{d_e^{1/2}} \right)V_{t, :} + \tilde{h}_{t, k} \ \ \forall k \in \mathcal{S_t}
\end{equation}
To maintain sparsity, only the activated RIMs perform this communication update; their queries attend over the keys and values of all RIMs, while the states of the inactive RIMs are left untouched.
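A final sketch of the communication step under the same assumptions (single head, placeholder sizes; `Wq_c`, `Wk_c`, `Wv_c` stand in for the communication parameters $\theta_k^{(c)}$): only the activated RIMs write a residual update, but they read from all RIMs.

```python
d_c = 32  # placeholder size of the communication keys/queries

# Communication projections, separate from the input-attention parameters.
Wq_c = nn.Parameter(torch.randn(num_rims, hidden_dim, d_c) * 0.02)         # per-RIM query proj.
Wk_c = nn.Parameter(torch.randn(num_rims, hidden_dim, d_c) * 0.02)         # per-RIM key proj.
Wv_c = nn.Parameter(torch.randn(num_rims, hidden_dim, hidden_dim) * 0.02)  # per-RIM value proj.

Q_c = torch.einsum('kh,khd->kd', h_tilde, Wq_c)  # (k, d_c)
K_c = torch.einsum('kh,khd->kd', h_tilde, Wk_c)  # (k, d_c)
V_c = torch.einsum('kh,khd->kd', h_tilde, Wv_c)  # (k, hidden_dim)

attn = F.softmax(Q_c @ K_c.T / d_c ** 0.5, dim=-1)  # (k, k) RIM-to-RIM attention

h_next = h_tilde.clone()
for k in active.tolist():
    # Residual update: an activated RIM adds whatever it reads from the other RIMs.
    h_next[k] = h_tilde[k] + attn[k] @ V_c
```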
That is the gist of the paper; I am omitting the discussion of the experiments and results. For those details, check out the [paper](https://arxiv.org/pdf/1909.10893.pdf).