# ENCODING RECURRENCE INTO TRANSFORMERS
###### tags: `Meeting`
ICLR 2023
<!-- *Feiqing Huang, Kexin Lu, Yuxi CAI, Zhen Qin, Yanwen Fang, Guangjian Tian, Guodong Li* -->
## Introduction
Recurrent models are well known to suffer from two drawbacks.
1. **Vanishing gradient problem**: recurrent models have difficulty capturing the possibly high correlation between distant inputs.
2. **Sequential nature**: the step-by-step computation makes these models difficult to train in parallel.
The success of **Transformers** is due to the fact that the similarity (or dependence) between any two tokens (or inputs) is well taken into account, and hence they can model long range dependence effortlessly.
Moreover, contrary to the recurrent models, the self-attention mechanism in Transformers is **feed-forward** in nature, and this enables them to be computed in **parallel** on the GPU infrastructure.
However, this flexibility also leads to sample inefficiency when training a Transformer, i.e. many more samples are needed to guarantee good generalization ability.
Moreover, chronological order is usually ignored by Transformers since they are **time-invariant**, and additional effort, in the form of positional encoding, is required to aggregate the temporal information.

Transformers may achieve improved performance if a recurrent model is brought in to handle the recurrent patterns in the data, especially when the sample size is relatively small.
Specifically, if the recurrent and non-recurrent components are separable, one can apply a parsimonious recurrent model to the recurrent component and a Transformer to the non-recurrent one, so that the **sample efficiency is improved** compared to the Transformer-only baseline.
## Relationship between RNN and multihead self-attention
An RNN layer can be approximated by a series of simple RNNs with scalar coefficients, which can further be represented in the form of a multihead self-attention.

### Breaking down an RNN layer
Consider an RNN layer with input variables $\{x_t\in\mathbb{R}^{d_{in}},1\leq t\leq T\}$; it has the form $h_t=g(W_hh_{t-1}+W_xx_t+b)$, where $g(\cdot)$ is the activation function, $h_t\in\mathbb{R}^d$ is the output or hidden variable with $h_0=0$, $b\in\mathbb{R}^d$ is the bias term, and $W_h\in\mathbb{R}^{d\times d}$ and $W_x\in\mathbb{R}^{d\times d_{in}}$ are the weights.
When the **activation function is linear**, i.e. $g(x)=x$, the RNN becomes
$$
h_t = W_h h_{t-1}+W_x x_t,\;\text{or equivalently} \; h_t=\sum^{t-1}_{j=0}W_h^jW_xx_{t-j} \tag 1
$$
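As a quick sanity check, here is a minimal NumPy sketch (not from the paper) showing that the unrolled form in (1) matches the recursion exactly:

```python
import numpy as np

# Minimal sketch: the linear RNN h_t = W_h h_{t-1} + W_x x_t with h_0 = 0
# equals its unrolled form h_t = sum_{j=0}^{t-1} W_h^j W_x x_{t-j}.
rng = np.random.default_rng(0)
d, d_in, T = 4, 3, 6
W_h = 0.5 * rng.standard_normal((d, d))      # recurrent weights
W_x = rng.standard_normal((d, d_in))         # input weights
X = rng.standard_normal((T, d_in))           # rows are x_1, ..., x_T

h, H_rec = np.zeros(d), []
for t in range(T):                           # recursive computation
    h = W_h @ h + W_x @ X[t]
    H_rec.append(h)

H_unrolled = [                               # unrolled computation from (1)
    sum(np.linalg.matrix_power(W_h, j) @ W_x @ X[t - j] for j in range(t + 1))
    for t in range(T)
]
assert np.allclose(H_rec, H_unrolled)
```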
Suppose that $W_h$ has $r$ real nonzero eigenvalues $\lambda_1,\dots,\lambda_r$ and $s$ pairs of complex nonzero eigenvalues $\lambda_{r+1},\dots,\lambda_{r+2s}$, where $(\lambda_{r+2k-1},\lambda_{r+2k})=(\gamma_ke^{i\theta_k},\gamma_ke^{-i\theta_k})$ for $1\leq k\leq s$, $i$ is the imaginary unit, and $R=r+2s$.
We then have the Jordan decomposition in real form, $W_h=BJB^{-1}$, where $B\in\mathbb{R}^{d\times d}$ is invertible and $J\in\mathbb{R}^{d\times d}$ is a block diagonal matrix. It holds that $W_h^j=BJ^jB^{-1}$ for all $j\geq1$, so the recurrence induced by $W_h$ can be broken down into that of the $p\times p$ block matrices in $J$ with $p=1\,\text{or}\,2$.
$$
h_t^R(\lambda)=\sum_{j=1}^{t-1}\lambda^jW_x^Rx_{t-j}, \quad h_t^{C1}(\gamma,\theta)=\sum_{j=1}^{t-1}\gamma^j\cos(j\theta)W_x^{C1}x_{t-j},\\
\text{and}\quad h_t^{C2}(\gamma,\theta)=\sum_{j=1}^{t-1}\gamma^j\sin(j\theta)W_x^{C2}x_{t-j} \tag 2
$$
where the first one corresponds to the real eigenvalues, i.e. the $1\times1$ block matrices in $J$, while the last two correspond to the complex eigenvalues, i.e. the $2\times2$ block matrices in $J$. Each of the three RNNs has the recurrent weights of $\lambda$ or $(\gamma,\theta)$, and its form with a nonlinear activation function is given in the Appendix.
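To make the link explicit (a reasoning step added here; the $2\times2$ blocks are written in the rotation-scaling convention, which may differ in sign from the paper's appendix), note that $W_h^jW_x=BJ^jB^{-1}W_x$ and that powers of the Jordan blocks reduce to geometric and damped-oscillation factors:
$$
(\lambda)^j=\lambda^j,
\qquad
\left[\gamma\begin{pmatrix}\cos\theta & \sin\theta\\ -\sin\theta & \cos\theta\end{pmatrix}\right]^j
=\gamma^j\begin{pmatrix}\cos(j\theta) & \sin(j\theta)\\ -\sin(j\theta) & \cos(j\theta)\end{pmatrix},
$$
so each $1\times1$ block contributes the geometric weight $\lambda^j$, and each $2\times2$ block contributes the damped-oscillation weights $\gamma^j\cos(j\theta)$ and $\gamma^j\sin(j\theta)$, which are exactly the factors appearing in (2).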
**Proposition 1.** Let $h_{0,t}=W_x x_t$, and then the RNN with linear activation can be equivalently rewritten into
$$
h_t=\sum_{k=1}^r h_t^R(\lambda_k)+\sum_{k=1}^s h_t^{C1}(\gamma_k, \theta_k)+\sum_{k=1}^s h_t^{C2}(\gamma_k,\theta_k) + h_{0,t}.
$$
### An equivalent MHSA representation
Consider the RNN of $\{ h_t^R\}$, and let $X=(x_1,...,x_T)^{\prime}\in\mathbb{R}^{T\times d_{in}}$ be an input matrix consisting of $T$ tokens with dimension $d_{in}$, where the transpose of a matrix $A$ is denoted by $A^{\prime}$ throughout this paper.
They first give the value matrix $V$ by projecting $X$ with a linear transformation, i.e. $V=XW_V$ with $W_V=W_x^{R\prime}\in\mathbb{R}^{d_{in}\times d}$, and the relative **positional encoding matrix** is set to
$$
P_{mask}^R(\lambda)=
\begin{pmatrix}
0 & 0 & 0 & \cdots & 0 \\
f_1(\lambda) & 0 & 0 & \cdots & 0 \\
f_2(\lambda) & f_1(\lambda) & 0 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
f_{T-1}(\lambda) & f_{T-2}(\lambda) & f_{T-3}(\lambda) & \cdots & 0 \\
\end{pmatrix} \tag 3
$$
where $f_t(\lambda)=\lambda^t$ for $1\leq t\leq T-1$. As a result, the first RNN at (2) can be represented in a self-attention (SA) form,
$$
(h_1^R,...,h_T^R)^{\prime} = \text{SA}^R(X)=[\text{softmax}(QK^{\prime})+P_{mask}^R(\lambda)]V ,
$$
where $Q$ and $K$ are zero query and key matrices, respectively. They call $P_{mask}^R(\lambda)$ the **recurrence encoding matrix (REM)** since it summarizes all the recurrence in $\{h_t^R\}$.
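For concreteness, here is a small NumPy sketch (my own check, not the authors' code) that builds $P_{mask}^R(\lambda)$ as in (3) and verifies that the REM term alone, $P_{mask}^R(\lambda)V$, reproduces $\{h_t^R\}$ from (2):

```python
import numpy as np

def rem_real(lam, T):
    """REM of (3): lower-triangular Toeplitz matrix with f_t(lambda) = lambda**t."""
    P = np.zeros((T, T))
    for i in range(T):
        for j in range(i):
            P[i, j] = lam ** (i - j)
    return P

rng = np.random.default_rng(1)
T, d_in, d = 6, 3, 4
lam = 0.9
X = rng.standard_normal((T, d_in))           # rows are x_1, ..., x_T
W_xR = rng.standard_normal((d, d_in))        # input weights of the elementary RNN
V = X @ W_xR.T                               # value matrix with W_V = W_x^{R'}

H_sa = rem_real(lam, T) @ V                  # the REM term of the SA representation
H_rnn = np.stack([                           # h_t^R from (2), computed directly
    sum(lam**j * W_xR @ X[t - 1 - j] for j in range(1, t)) if t > 1 else np.zeros(d)
    for t in range(1, T + 1)
])
assert np.allclose(H_sa, H_rnn)
```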
For the RNN of $\{h_t^{C1}\}$, the REM is denoted by $P_{mask}^{C1}(\gamma,\theta)$, which has the form of (3) with $f_t(\lambda)$ replaced by $f_t(\gamma,\theta)=\gamma^t\cos(t\theta)$ for $1\leq t\leq T-1$, and the value matrix is $V=XW_V$ with $W_V=W_x^{C1\prime}\in\mathbb{R}^{d_{in}\times d}$. Similarly, for the RNN of $\{h_t^{C2}\}$, the REM $P_{mask}^{C2}(\gamma,\theta)$ has the form of (3) with $f_t(\lambda)$ replaced by $f_t(\gamma,\theta)=\gamma^t\sin(t\theta)$ for $1\leq t\leq T-1$, and the value matrix is $V=XW_V$ with $W_V=W_x^{C2\prime}\in\mathbb{R}^{d_{in}\times d}$. Thus, these two RNNs at (2) can also be represented in SA form,
$$
(h_1^{C_i},...,h_T^{C_i})^{\prime}=\text{SA}^{C_i}(X)=[\text{softmax}(QK^{\prime})+P_{mask}^{C_i}(\gamma,\theta)]V\quad\text{with}\quad i=1\;\text{or}\;2,
$$
where query and key matrices $Q$ and $K$ are both zero.
Finally, for the remaining term in Proposition 1, $h_{0,t}$ depends on $x_t$ only, and no interdependence is involved. Mathematically, it can be represented as an SA with the **identity relative positional encoding matrix** and zero query and key matrices.
**Proposition 2.** If the conditions of Proposition 1 hold, then the RNN with linear activation at (1) can be represented into a multihead self-attention (MHSA) with $r+2s+1$ heads, where the query and key matrices are zero, and relative positional encoding matrices are $\{P_{mask}^R(\lambda_k),1\leq k\leq r\}, \{P_{mask}^{C_1}(\gamma_k,\theta_k),P_{mask}^{C_2}(\gamma_k,\theta_k),1\leq k\leq s\}$ and an identity matrix, respectively.
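Putting the pieces together, here is a hedged end-to-end sketch of Proposition 2 for a toy $W_h$ with one real eigenvalue and one complex pair ($r=s=1$, so $r+2s+1=4$ heads). The per-head value weights below come from my own reading of the real Jordan decomposition (not the paper's appendix), and only the REM terms are used, i.e. the zero-query/key softmax part is dropped:

```python
import numpy as np

def rem(f, T):
    """Lower-triangular Toeplitz REM with entries f(i - j) below the diagonal."""
    return np.array([[f(i - j) if i > j else 0.0 for j in range(T)] for i in range(T)])

rng = np.random.default_rng(2)
d_in, T = 2, 7
lam, gamma, theta = 0.6, 0.8, 0.4
Rot = np.array([[np.cos(theta), np.sin(theta)],
                [-np.sin(theta), np.cos(theta)]])
J = np.block([[np.array([[lam]]), np.zeros((1, 2))],   # real Jordan form of W_h
              [np.zeros((2, 1)), gamma * Rot]])
B = rng.standard_normal((3, 3))
W_h = B @ J @ np.linalg.inv(B)
W_x = rng.standard_normal((3, d_in))
X = rng.standard_normal((T, d_in))

# Per-head value weights implied by W_h^j W_x = B J^j (B^{-1} W_x).
C = np.linalg.inv(B) @ W_x
heads = [
    (rem(lambda t: lam**t, T), np.outer(B[:, 0], C[0])),                   # P^R head
    (rem(lambda t: gamma**t * np.cos(t * theta), T),
     np.outer(B[:, 1], C[1]) + np.outer(B[:, 2], C[2])),                   # P^C1 head
    (rem(lambda t: gamma**t * np.sin(t * theta), T),
     np.outer(B[:, 1], C[2]) - np.outer(B[:, 2], C[1])),                   # P^C2 head
    (np.eye(T), W_x),                                                      # identity head for h_{0,t}
]
H_mhsa = sum(P @ (X @ W.T) for P, W in heads)          # sum of the four REM heads

H_rnn, h = [], np.zeros(3)                             # reference: run the RNN recursively
for t in range(T):
    h = W_h @ h + W_x @ X[t]
    H_rnn.append(h)
assert np.allclose(H_mhsa, np.stack(H_rnn))
```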
## Encoding recurrence into self-attention

They propose the Self-Attention with Recurrence (RSA) module to seamlessly combine the strengths of RNNs and Transformers:
$$
\text{RSA}(X)=\{[1-\sigma(\mu)]\text{softmax}(QK^{\prime})+\sigma(\mu)P\}V
$$
where $P$ is a regular or cyclical REM, and $\sigma(\mu)\in[0,1]$ is a gate with $\sigma$ being the sigmoid function and $\mu$ being the learnable gate-control parameter.
The learnable gate $\sigma(\mu)$ measures the proportion or strength of recurrent patterns. When the sample size is relatively small, it is also possible to use REMs to approximate part of the non-recurrent patterns to obtain a better bias-variance trade-off, leading to a higher value of $\sigma(\mu)$. On the other hand, the non-recurrent patterns can be taken care of by the flexible Transformer.
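Below is a minimal single-head PyTorch sketch of this formula (my own reading of the display above, not the authors' released code). For simplicity it uses only the regular real-eigenvalue REM with $f_t(\lambda)=\lambda^t$, keeps $\lambda\in(0,1)$ via a sigmoid, and omits the usual $1/\sqrt{d}$ attention scaling to mirror the display:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RSAHead(nn.Module):
    """Single-head self-attention with recurrence, following the RSA display above."""
    def __init__(self, d_in, d_head):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_head, bias=False)
        self.W_k = nn.Linear(d_in, d_head, bias=False)
        self.W_v = nn.Linear(d_in, d_head, bias=False)
        self.mu = nn.Parameter(torch.zeros(1))    # gate-control parameter
        self.eta = nn.Parameter(torch.zeros(1))   # parametrizes the REM decay lambda

    def rem(self, T, device):
        lam = torch.sigmoid(self.eta)             # keep lambda in (0, 1)
        idx = torch.arange(T, device=device)
        powers = (idx[:, None] - idx[None, :]).clamp(min=1).float()
        return torch.tril(lam ** powers, diagonal=-1)   # strictly lower-triangular Toeplitz REM

    def forward(self, x):                         # x: (batch, T, d_in)
        T = x.size(1)
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        attn = F.softmax(q @ k.transpose(-2, -1), dim=-1)   # softmax(QK')
        g = torch.sigmoid(self.mu)                # proportion of recurrent patterns
        return ((1 - g) * attn + g * self.rem(T, x.device)) @ v
```

For example, `RSAHead(d_in=16, d_head=8)(torch.randn(2, 10, 16))` returns a `(2, 10, 8)` tensor; a multi-head version would use several such heads with different REMs (real, complex, cyclical) before the usual output projection.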
## Experiments
<!-- This section contains four sequential modeling tasks ,and for each task, they modify some popular Transformer baselines by adding the REMs to their attention weights via a gated mechanism.

They argue that time series data have the strongest recurrent sinals, followed by the regular language and finally the code or natural languages. -->
### Time Series Forecasting
Datasets:
1. The ETT datasets comprise seven features related to long-term electric power deployment, where $\text{ETTh}_1$ and $\text{ETTh}_2$ are recorded hourly and $\text{ETTm}_1$ is recorded at 15-minute intervals.
2. The Weather dataset contains twelve climate indicators collected every hour over a 4-year period.

### Regular Language Learning
Regular languages are intimately related to linear recurrence sequences, such as the Fibonacci numbers (0, 1, 1, 2, 3, 5, 8, ...). Some works report that Transformers have difficulty generalizing the rules of regular languages.

> [On the ability and limitations of transformers to recognize formal languages](https://arxiv.org/abs/2009.11264)

> Benchmark tasks: Parity and the Tomita grammars

### Code and Natural Language Modeling
Unlike regular languages, the recurrence relationship in programming or natural languages is weaker and harder to interpret.
They first conduct a defect detection task on a C-language dataset, a binary classification task that evaluates whether a piece of code is vulnerable to external attacks.

## Conclusion
- By formulating an RNN in an MHSA form, they propose an RSA module to incorporate the recurrent dynamics of an RNN into a Transformer.
- The lightweight REMs are combined with the self-attention weights via a gated mechanism, maintaining a parallel and efficient computation on the GPU infrastructure.
- Experiments on four sequential learning tasks show that the proposed RSA module can boost the sample efficiency of the baseline Transformer, supported by significantly improved performance.