Consider replacing the softmax with a non-linear function in single-head attention, as suggested by a recent theoretical study of transformers [1].
**Empirical observations for non-linear attention.**
For three different non-linear activation functions, we observed the same sparsity pattern as in Theorem 1 after optimizing the training loss for linear models. Note that the data distribution is isotropic, matching the setting of Theorem 1.
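As an illustration of this kind of check (a minimal sketch, not our exact experimental setup: the dimensions, activations, and optimization hyperparameters below are illustrative assumptions), one can minimize the in-context loss $f(A)$ defined below over isotropic Gaussian data for several activations and verify that the learned $A$ is approximately a scaled identity:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n, batch, steps, lr = 5, 8, 2048, 2000, 0.005

# Each entry: (activation g, its derivative g').
activations = {
    "relu":     (lambda s: np.maximum(s, 0.0),   lambda s: (s > 0).astype(float)),
    "softplus": (lambda s: np.logaddexp(0.0, s), lambda s: 1.0 / (1.0 + np.exp(-s))),
    "tanh":     (np.tanh,                        lambda s: 1.0 - np.tanh(s) ** 2),
}

def loss_grad(A, g, dg):
    """One Monte-Carlo estimate of the in-context loss f(A) and its gradient."""
    X = rng.standard_normal((batch, n, d))   # context inputs x^(i)
    x = rng.standard_normal((batch, d))      # query x = x^(n+1)
    w = rng.standard_normal((batch, d))      # task vector w*
    s = np.einsum("bd,de,bne->bn", x, A, X)  # scores x^T A x^(i)
    y = np.einsum("bnd,bd->bn", X, w)        # labels <x^(i), w*>
    resid = np.einsum("bn,bn->b", g(s), y) - np.einsum("bd,bd->b", x, w)
    # d(resid)/dA = sum_i y_i g'(s_i) x (x^(i))^T, so df/dA = E[2 resid d(resid)/dA].
    coeff = 2.0 * resid[:, None] * y * dg(s)              # shape (batch, n)
    grad = np.einsum("bn,bd,bne->de", coeff, x, X) / batch
    return np.mean(resid ** 2), grad

for name, (g, dg) in activations.items():
    A = 0.1 * rng.standard_normal((d, d))    # random initialization
    for _ in range(steps):
        _, grad = loss_grad(A, g, dg)
        A -= lr * grad                       # plain SGD on fresh batches
    off_diag = np.abs(A - np.diag(np.diag(A))).max()
    print(f"{name:8s} diag(A) ~ {np.round(np.diag(A), 3)}, max |off-diag| = {off_diag:.3f}")
```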



**Theoretical analysis for ReLU attention.**
Following the theoretical setting of [1], we introduce a non-linear attention in which the softmax is replaced by ReLU:
$$\mathrm{Attn}_{Q}(Z) = P Z M\, g(Z^\top Q Z),$$
where $g$ is the ReLU activation function, $P = [0,\dots, 0, 1]^{\otimes 2} \in \mathbb{R}^{(d+1)\times (d+1)}$, and $Q = \begin{bmatrix} A & 0\\ 0 & 0\end{bmatrix}$.
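For intuition, under a standard in-context regression token layout, which we assume here for illustration (each context column of $Z$ is $z^{(i)} = (x^{(i)}, \langle x^{(i)}, w^* \rangle)$, the query column is $(x, 0)$, and $M$ masks out the query column), reading the label coordinate of the query column of the attention output gives, up to the transpose convention inside $Z^\top Q Z$, the prediction
$$\hat{y}(x) = \big[\mathrm{Attn}_Q(Z)\big]_{d+1,\, n+1} = \sum_{i=1}^{n} \langle x^{(i)}, w^* \rangle\, g\big(x^\top A\, x^{(i)}\big),$$
so that the loss below is simply $\mathbb{E}\big(\hat{y}(x) - \langle x, w^* \rangle\big)^2$.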
The in-context loss is defined as
$$f(A) = \mathbb{E} \left( \sum_{i=1}^{n} \langle x^{(i)}, w^* \rangle\, g(x^\top A x^{(i)}) - \langle x, w^* \rangle \right)^2 $$
where we use the compact notation $x = x^{(n+1)}$ and drop the constant $1/N$ factor for notational simplicity.
**Lemma.** Suppose that the data and $w^*$ follow isotropic Gaussian distributions. Then there is a constant $c$ such that $c I = \arg\min_A f(A)$.
**Remarks.**
This is a preliminary theoretical result derived within the limited rebuttal time. We will do more extensive proof-checking and also add the following remarks on its limitations:
- We fixed the parameter $P$ in the attention and only optimized $Q$. This is only due to the limited rebuttal time; we believe it is possible to study the joint optimization of $P$ and $Q$.
- While we only analyzed ReLU, the same proof technique may extend to other non-linear functions, as observed in the experiments.
- We only analyzed isotropic inputs, but we believe the same proof technique will work for non-isotropic Gaussian inputs.
- While the non-linear attention considered here differs from softmax attention, the above analysis shows that the structure of the global minimum in Theorem 1 extends to a non-linear attention.
**Proof Idea.** The proof idea is inspired by the elegant observation in [2] that, for Gaussian data, one can recover the solution of a generalized linear model by scaling the least-squares solution. This observation rests on a subtle application of Stein's characterization of Gaussian distributions. Here, we use Stein's lemma together with the symmetries of the Gaussian distribution and properties of ReLU to prove that the ReLU only affects the scaling of the global optimum.
**Proof.**
Since $w^*$ is isotropic Gaussian (and independent of the data), we can take the expectation over $w^*$ to get
$$ f(A) = \mathbb{E} \left\| \sum_{i} g(x^\top A x^{(i)})\, x^{(i)} - x \right\|^2 $$
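Explicitly, the quantity inside the square in the definition of $f$ equals $\langle v, w^* \rangle$ with $v = \sum_{i} g(x^\top A x^{(i)})\, x^{(i)} - x$, and conditioning on the data,
$$\mathbb{E}_{w^*}\big[\langle v, w^* \rangle^2\big] = v^\top\, \mathbb{E}[w^* w^{*\top}]\, v = \|v\|^2.$$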
Expanding the square, we get
$$ f(A) = \mathbb{E} \left[ \sum_{i,j} g(x^\top A x^{(i)})\, g(x^\top A x^{(j)})\, \langle x^{(i)}, x^{(j)} \rangle - 2 \sum_{i} g(x^\top A x^{(i)})\, \langle x^{(i)}, x \rangle \right] + \text{const.} $$
where the constant term $\mathbb{E}\|x\|^2$ does not depend on $A$.
Notably, thanks to Stein's lemma, it holds that [2, Equation 5]
$$ \mathbb{E}_{x \sim N(0,I)} \left[ x\, h(x^\top v) \right] = \mathbb{E}_{x \sim N(0,I)} \left[ h'(x^\top v) \right] v $$
Since Stein's identity is a consequence of integration by parts, it holds for all absolutely continuous functions $h$, and hence for $g$ despite its non-smoothness. Therefore, for the off-diagonal terms ($i \neq j$), applying the identity to $x^{(i)}$ conditionally on $x$ and $x^{(j)}$ gives
$$\mathbb{E}_{x^{(i)},x^{(j)},x} \left[ g(x^\top A x^{(i)}) g(x^\top A x^{(j)}) \langle x^{(i)}, x^{(j)} \rangle \right] = \mathbb{E}_{x^{(j)},x} [\mathbb{E}_{x^{(i)}} [g'(x^\top A x^{(i)})] g(x^\top A x^{(j)}) x^\top A x^{(j)}] $$
where $g'$ denotes the (almost everywhere) derivative of ReLU. Using the invariance of the distribution of $x^{(i)}$ under a sign flip, we have
$$\mathbb{E}_{x^{(i)}} [g'(x^\top A x^{(i)})] = \mathbb{E}_{x^{(i)}} [g'(-x^\top A x^{(i)})]$$
On the other hand, $g'(a) + g'(-a) = 1$ for almost every $a$. Averaging the two equal expectations above and using this identity yields $\mathbb{E}_{x^{(i)}} [g'(x^\top A x^{(i)})] = 1/2$. Thus,
$$\mathbb{E}_{x^{(i)},x^{(j)},x} \left[ g(x^\top A x^{(i)}) g(x^\top A x^{(j)}) \langle x^{(i)}, x^{(j)} \rangle \right] = \frac{1}{2}\mathbb{E}_{x^{(j)},x} [ g(x^\top A x^{(j)}) x^\top A x^{(j)} ]$$
Similarly, applying Stein's lemma to $x^{(j)}$ (still with $i \neq j$), we get
$$\frac{1}{2}\mathbb{E}_{x^{(j)},x} \big[ g(x^\top A x^{(j)})\, x^\top A x^{(j)} \big] = \frac{1}{4}\mathbb{E}_{x} \big[ x^\top A A^\top x \big] = \frac{1}{4}\| A \|_F^2$$
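In more detail, fix $x$ and write $u = A^\top x$, so that $x^\top A x^{(j)} = \langle u, x^{(j)} \rangle \sim N(0, \|u\|^2)$ conditionally on $x$. The scalar Stein identity with $h = g$ gives
$$\mathbb{E}_{x^{(j)}}\big[ g(\langle u, x^{(j)} \rangle)\, \langle u, x^{(j)} \rangle \big] = \|u\|^2\, \mathbb{E}\big[ g'(\langle u, x^{(j)} \rangle) \big] = \tfrac{1}{2}\, x^\top A A^\top x,$$
and taking the expectation over $x$ with $\mathbb{E}_x[x^\top A A^\top x] = \text{Tr}(A A^\top) = \|A\|_F^2$ yields the factor $\tfrac{1}{4}$.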
Applying Stein's lemma to $x^{(i)}$ in the remaining cross term, we get
$$ \mathbb{E}_{x^{(i)},x} [ g(x^\top A x^{(i)} ) \langle x^{(i)}, x \rangle] = \frac{1}{2} \mathbb{E}_{x} \left[ x^\top A x \right] = \text{Tr}(A)/2 $$
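Here $\text{Tr}$ denotes the trace, and the last step uses
$$\mathbb{E}_{x}\big[ x^\top A x \big] = \mathbb{E}_{x}\big[ \text{Tr}(A\, x x^\top) \big] = \text{Tr}\big( A\, \mathbb{E}[x x^\top] \big) = \text{Tr}(A).$$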
It remains to handle the diagonal terms $i = j$. Using the sign-flip symmetry of $x$ together with the fact that $g^2(a) + g^2(-a) = a^2$, we get
$$\mathbb{E}_{x^{(i)},x} \left[ g^2(x^\top A x^{(i)}) \|x^{(i)} \|^2 \right] = \frac{1}{2} \mathbb{E} [(x^\top A x^{(i)})^2 \| x^{(i)}\|^2 ] $$
The non-linearity has now been eliminated, so the remaining expectation can be computed directly:
$$\frac{1}{2} \mathbb{E} \big[(x^\top A x^{(i)})^2 \| x^{(i)}\|^2 \big] = \frac{1}{2} \mathbb{E} \big[ (x^{(i)})^\top A^\top A\, x^{(i)}\, \| x^{(i)}\|^2 \big] = \frac{d+2}{2} \| A \|_F^2$$
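The last equality follows from the Gaussian fourth-moment identity: for $z \sim N(0, I_d)$ and a symmetric matrix $M$,
$$\mathbb{E}\big[ z^\top M z\, \|z\|^2 \big] = \text{Tr}(M)\, \text{Tr}(I_d) + 2\, \text{Tr}(M) = (d+2)\, \text{Tr}(M),$$
applied with $M = A^\top A$, combined with the first step $\mathbb{E}_x\big[(x^\top A x^{(i)})^2 \mid x^{(i)}\big] = \|A x^{(i)}\|^2 = (x^{(i)})^\top A^\top A\, x^{(i)}$.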
Putting everything together, we obtain the following expression for the objective function:
$$f(A) = \alpha \| A \|_F^2 - \beta\, \text{Tr}(A) + \text{const.},$$
for constants $\alpha, \beta > 0$ depending only on $n$ and $d$.
The above function is strongly convex in $A$, and setting the gradient to zero shows that the optimal $A$ is a positive scaling of the identity matrix. $\square$
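Concretely, $\nabla f(A) = 2\alpha A - \beta I$, so the unique minimizer is
$$A^\star = \frac{\beta}{2\alpha}\, I.$$
(Tracking the term counts in the derivation above suggests $\beta = n$ and $\alpha = \tfrac{n(n-1)}{4} + \tfrac{n(d+2)}{2}$, though the lemma only needs $\alpha, \beta > 0$.)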
**References**
[1] Zhao, H., Panigrahi, A., Ge, R., & Arora, S. (2023). Do Transformers Parse while Predicting the Masked Word? arXiv preprint arXiv:2303.08117.
[2] Erdogdu, M. A., Dicker, L. H., & Bayati, M. (2016). Scaled least squares estimator for GLMs in large-scale problems. Advances in Neural Information Processing Systems, 2016.