* Quick review of symmetric power attention, with the theoretical FLOP improvement and real training curves.
* The challenges of an efficient implementation:
* There is no sympow function callable from PyTorch, and doing a good job is much harder than implementing the embedding function and binding it to PyTorch.
* A few dimensions relevant to the problem: $t$ and $D$ are big; $c$ and $d$ are small. We want to avoid ever materializing an object in main memory whose shape is big dim × big dim.
* The 3 virtual matmul kernels needed
* Interpolation between the symmetric power and the tensor power
* We also include learnable gating
* Current benchmarks. We should be able to squeeze out a lot more performance.
* MFU
* H100
::: {.hidden}
$$
\newcommand{\R}{\mathbb{R}}
\newcommand{\Z}{\mathbb{Z}}
\newcommand{\N}{\mathbb{N}}
\newcommand{\sft}{\text{softmax}}
\newcommand{\List}{\text{List}}
\newcommand{\Seq}{\text{Seq}}
\newcommand{\SeqT}{\text{SeqT}}
\newcommand{\CSeqT}{\text{CSeqT}}
\newcommand{\Dist}{\text{Dist}}
\newcommand{\SM}{\text{SM}}
\newcommand{\Fn}{\text{Fn}}
\newcommand{\Tok}{\text{Tok}}
\newcommand{\Aij}{ A_{[i,j]}}
\newcommand{\ten}{\small\text{tensor}}
\newcommand{\sym}{\small\text{symmetric}}
$$
:::
# Power Attention Kernels: Alpha Release
[We have released an open-source implementation of GPU kernels that implement symmetric power attention.](TODO) This is still an early alpha, and these kernels will continue to receive love over the next several months as we improve performance, extend the feature set, and polish the interface and maintainability.
[Symmetric power attention](todo) is a linear-cost alternative to standard softmax attention that has 2 main advantages:
1. It has a **far more favorable per-FLOP learning efficiency** than softmax attention.

2. It allows for **much longer effective context for both training and inference**.
Mathematically, the change is straightforward. Start with standard softmax attention:
$$
Y_i = \sum_{j=1}^i A_{ij} V_j \qquad A_{ij} = \frac{e^{Q_i^T K_j}}{\sum_{k=1}^i e^{Q_i^T K_k}}
$$
Simply replace $e^{(\cdot)}$ in the attention equation with $(\cdot)^p$, where $p$ is some integer power.
$$
e^{QK^T}V \to (QK^T)^pV
$$
Unlike standard attention, which must retain a state that grows with the context length (i.e. the KV cache), symmetric power attention can be unrolled using a fixed-size state. (The size of the state can be controlled by adjusting the degree $p$.) When context lengths are large, this translates into a massive reduction in training FLOPs, and into reductions in both FLOPs and memory at inference time.
What's more, transformer architectures built around symmetric power attention have learning ability competitive with standard softmax-attention transformers, in terms of loss-per-update. Alongside the FLOP reduction, this translates into better overall performance. Optimal learning requires selecting the state size in accordance with the compute budget: large states should be used only when ample compute is available. (This is a mirror of parameter scaling laws.)
(TODO image)
In the plot above, we see that state-size-optimal symmetric power transformers dominate classic transformers at long contexts, in terms of loss-per-FLOP. However, an implementation which translates these theoretical benefits into wall-clock speedups presents a further challenge. In this post, we explain the design principles behind the efficient GPU kernels in [our open-source repository](TODO).
## 1. Symmetric Power Transformers
We begin with a high-level overview of symmetric power transformers. The inputs to the layer are sequences of queries, keys, and values $Q_i, K_i, V_i \in \R^d$, where $i$ ranges from $1$ to the sequence length $t$. The output is a sequence $Y_i\in \R^d$. In the *attention formulation*, the output vectors are given by:
$$
Y_i = \sum_{j=1}^i A_{ij} V_j \qquad A_{ij} = \frac{B_{ij}}{\sum_{k=1}^i B_{ik}} \qquad B_{ij} = (Q_i^T K_j)^p
\qquad \text{(sympow)}
$$
We refer to $A_{ij}$ as the attention scores and $B_{ij}$ as the preattention scores (mirroring the preactivation/activation language often used to describe the hidden values of an MLP before and after the nonlinearity).
::: {.column-margin}
It is important that the power $p$ is even, because that guarantees every $B_{ij}$ is nonnegative, so the denominator is positive whenever at least one $B_{ik}$ is nonzero. This makes $A_{i1}, \cdots, A_{ii}$ a valid probability distribution, and in turn makes the outputs $Y_i$ a convex combination of $V_1, \cdots, V_i$.
:::
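A direct, quadratic-cost reference for the (sympow) equation helps pin down the indexing and the causal mask. This is a minimal NumPy sketch of our own, not the released kernels; the function name `sympow_attention` and the random test shapes are illustrative assumptions.

```python
import numpy as np

def sympow_attention(Q, K, V, p=2):
    """Causal symmetric power attention in its attention formulation.

    Q, K, V have shape (t, d); p should be even so that every B_ij is >= 0.
    Cost is quadratic in t because the full t x t score matrix is materialized.
    """
    t = Q.shape[0]
    B = (Q @ K.T) ** p                     # preattention scores B_ij = (Q_i^T K_j)^p
    B = B * np.tril(np.ones((t, t)))       # causal mask: keep only j <= i
    A = B / B.sum(axis=1, keepdims=True)   # attention scores; each row sums to 1
    return A @ V                           # outputs Y_i = sum_j A_ij V_j

rng = np.random.default_rng(0)
t, d = 8, 4
Q, K, V = (rng.standard_normal((t, d)) for _ in range(3))
Y = sympow_attention(Q, K, V, p=2)
print(Y.shape)  # (8, 4)
```

Because each row of `A` is a distribution, every output coordinate stays between the smallest and largest value of the corresponding coordinate among the visible values.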
The exact same outputs $Y_i$ can be computed via a *recurrent formulation*. Doing so involves an embedding function $\phi^p : \R^d \to \R^D$. The vector $\phi^p(k)$ contains the same information as $k\otimes \cdots \otimes k$, the tensor product of $k$ with itself $p$ times, but stores it much more compactly. The tensor product is highly redundant, because entries whose indices differ only by a permutation are equal, and $\phi^p$ keeps a single (suitably weighted) copy of each, so $D \ll d^p$. The embedding is chosen so that $\phi^p(Q_i)^T \phi^p(K_j) = (Q_i^T K_j)^p = B_{ij}$. Using this embedding function, we can write the recurrent equations:
$$
Y_{i} = \frac{S_i \phi^p(Q_i)}{Z_i \phi^p(Q_i)} \qquad Z_i = Z_{i-1} + \phi^p(K_i)^T \qquad S_i = S_{i-1} + V_i \phi^p(K_i)^T
$$
where $Z_0 = 0$ and $S_0 = 0$ are the zero vector and zero matrix in their respective spaces.
Since $S_i \in \R^{d \times D}$ and $Z_i \in \R^{D}$, the size of the state is $D(d+1)$.
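For a sense of scale: $\phi^p$ needs one coordinate per multiset of $p$ indices drawn from $\{1, \dots, d\}$, which gives the standard count below. The head size $d = 64$ is purely an illustrative choice.

$$
D = \binom{d+p-1}{p}, \qquad
d = 64:\quad
\begin{cases}
p = 2: & D = \binom{65}{2} = 2{,}080 \ \ (\text{vs. } d^p = 4{,}096), \quad D(d+1) = 135{,}200 \\
p = 4: & D = \binom{67}{4} = 766{,}480 \ \ (\text{vs. } d^p = 16{,}777{,}216), \quad D(d+1) = 49{,}821{,}200
\end{cases}
$$

For $d \gg p$, $D \approx d^p / p!$, so the raw tensor product carries roughly $p!$ times more coordinates than it needs.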
These two forms give rise to a variety of algorithms for training linear transformers, with differing computational properties. [Read our earlier article on linear transformers for a detailed explanation.](https://manifestai.com/articles/linear-transformers-are-faster/)
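To check that the two formulations agree, here is a small NumPy sketch using one valid choice of $\phi^p$: one coordinate per sorted index tuple, weighted by the square root of the number of orderings of that tuple. That weighting makes $\phi^p(q)^T \phi^p(k) = (q^T k)^p$ exactly. The names `sympow_embed` and `sympow_recurrent` are hypothetical, and the per-token Python loop is for clarity only; it is not how the GPU kernels are organized.

```python
import itertools
import math
import numpy as np

def sympow_embed(x, p):
    """Map x of shape (..., d) to shape (..., D) with D = C(d + p - 1, p).

    One coordinate per sorted index tuple i_1 <= ... <= i_p, scaled by the square
    root of how many orderings of that tuple exist, so that (up to rounding)
    sympow_embed(q, p) @ sympow_embed(k, p) == (q @ k) ** p.
    """
    d = x.shape[-1]
    feats = []
    for idx in itertools.combinations_with_replacement(range(d), p):
        counts = [idx.count(i) for i in set(idx)]
        n_orderings = math.factorial(p) // math.prod(math.factorial(c) for c in counts)
        feats.append(math.sqrt(n_orderings) * np.prod(x[..., list(idx)], axis=-1))
    return np.stack(feats, axis=-1)

def sympow_recurrent(Q, K, V, p=2):
    """Recurrent formulation: a fixed-size state (S, Z) is updated once per token."""
    t, d = Q.shape
    D = math.comb(d + p - 1, p)
    S = np.zeros((d, D))  # S_0 = 0
    Z = np.zeros(D)       # Z_0 = 0
    Y = np.empty_like(V)
    for i in range(t):
        phi_k, phi_q = sympow_embed(K[i], p), sympow_embed(Q[i], p)
        S += np.outer(V[i], phi_k)        # S_i = S_{i-1} + V_i phi(K_i)^T
        Z += phi_k                        # Z_i = Z_{i-1} + phi(K_i)^T
        Y[i] = (S @ phi_q) / (Z @ phi_q)  # Y_i = S_i phi(Q_i) / (Z_i phi(Q_i))
    return Y

# Check against the quadratic attention formulation on random inputs.
rng = np.random.default_rng(0)
t, d, p = 6, 3, 2
Q, K, V = (rng.standard_normal((t, d)) for _ in range(3))
B = (Q @ K.T) ** p * np.tril(np.ones((t, t)))
Y_attention = (B / B.sum(axis=1, keepdims=True)) @ V
print(np.allclose(sympow_recurrent(Q, K, V, p), Y_attention))  # True
```

The state `(S, Z)` has $D(d+1)$ entries no matter how long the sequence is, which is the fixed-size-state property described earlier.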
## The 3 core operations of power attention
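**Attention** $Y = A V$, where $A = (QK^T)^p \odot M$ and $M$ is the causal mask. Note that the attention matrix is $t \times t$, while the output $Y$ is only $t \times d$.
**Update state** $\phi^p(K)^T V$, where $\phi^p(K) \in \R^{t \times D}$ stacks the embedded keys as rows. The result is the $D \times d$ state ($S_t^T$ in the notation above), but the embedded keys form a $t \times D$ intermediate.
**Query state** $\phi^p(Q)\, S^T$, where $\phi^p(Q) \in \R^{t \times D}$ stacks the embedded queries and $S$ is a previously accumulated state. The output is $t \times d$, but the embedded queries form a $t \times D$ intermediate.
Note how in all 3 cases we are computing an object with one large and one small dimension, but in the process we need an intermediate object with two large dimensions, as summarized in the table below. Materializing that intermediate in main memory is exactly what an efficient implementation must avoid.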
| | Attention | Update state | Query state |
|-|-----------|--------------|-------------|
| Output | $t\times d$ | $D\times d$ | $t\times d$ |
| Large intermediate | $t\times t$ | $t\times D$ | $t\times D$ |
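The three operations fit together in a chunked algorithm, one way of combining the attention and recurrent formulations. Below is a minimal NumPy sketch under our own assumptions: a chunk length we call `chunk`, the plain tensor-power embedding in place of $\phi^p$ for brevity, and a state stored as the $D \times d$ transpose to match the table. The helper names are hypothetical. Real kernels would avoid writing the $c \times D$ intermediates to main memory; this sketch materializes everything.

```python
import numpy as np

def tensor_power_embed(X, p):
    """Rows of X (shape (n, d)) -> flattened p-fold tensor powers, shape (n, d**p).

    The plain tensor product: simple, but wastes dimensions relative to phi^p.
    """
    out = X
    for _ in range(p - 1):
        out = (out[:, :, None] * X[:, None, :]).reshape(X.shape[0], -1)
    return out

def chunked_power_attention(Q, K, V, p=2, chunk=4):
    """Chunked algorithm combining the three core operations.

    Only c x c and c x D intermediates are formed (c = chunk length), never t x t.
    """
    t, d = Q.shape
    S = np.zeros((d ** p, d))  # D x d state: sum over past keys of phi(K_j) V_j^T
    z = np.zeros(d ** p)       # normalizer state: sum over past keys of phi(K_j)
    Y = np.empty_like(V)
    for start in range(0, t, chunk):
        q, k, v = Q[start:start + chunk], K[start:start + chunk], V[start:start + chunk]
        c = q.shape[0]
        # 1. Attention within the chunk (c x c intermediate).
        B = (q @ k.T) ** p * np.tril(np.ones((c, c)))
        num, den = B @ v, B.sum(axis=1)
        # 2. Query state: contributions of all earlier chunks (c x D intermediate).
        phi_q = tensor_power_embed(q, p)
        num += phi_q @ S
        den += phi_q @ z
        Y[start:start + c] = num / den[:, None]
        # 3. Update state with this chunk's keys and values (c x D intermediate).
        phi_k = tensor_power_embed(k, p)
        S += phi_k.T @ v
        z += phi_k.sum(axis=0)
    return Y

# Compare with the quadratic attention formulation.
rng = np.random.default_rng(1)
t, d = 10, 4
Q, K, V = (rng.standard_normal((t, d)) for _ in range(3))
B = (Q @ K.T) ** 2 * np.tril(np.ones((t, t)))
Y_ref = (B / B.sum(axis=1, keepdims=True)) @ V
print(np.allclose(chunked_power_attention(Q, K, V, p=2, chunk=4), Y_ref))  # True
```

The query-state step runs before the update-state step in each chunk, so every query sees only keys from earlier chunks through the state and keys from its own chunk through the masked attention.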
## On the embedding function: tensor product, symmetric power and their interpolation
* The tensor product is simple and maps well onto GPU code, but it wastes a huge number of dimensions (a factor of roughly $p!$ when $d \gg p$).
* The symmetric power doesn't waste a single dimension, but it's really tricky to efficiently break up the work of computing it into code for each thread of the GPU.
* Break up the whole tensor product into cubes of tensor products. There are symmetries at the cube level (multiple cubes contain exactly the same elements). This allows us to interpolate between the two cases. As long as we can distribute the work of computing....
* For query state, each CTA works on a cube: 16×16 for the $D$ feature dimension and 16 for the batch dimension. The main loop of the kernel
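One way to make the cube idea concrete, as a sketch under our own assumptions (a block size `b` that divides `d`, NumPy, and the hypothetical helpers `cube_embed` and `cube_dim`; the released kernels may organize this differently): split the $d$ features into $d/b$ blocks, keep one $b \times \cdots \times b$ cube per sorted tuple of block indices, and weight it by the square root of the number of equivalent cubes. With $b = 1$ this reduces to the symmetric power, with $b = d$ to the plain tensor product, and in between the embedding dimension is $\binom{d/b + p - 1}{p}\, b^p$.

```python
import itertools
import math
import numpy as np

def cube_dim(d, p, b):
    """Embedding dimension when features are grouped into blocks of size b."""
    n_blocks = d // b
    return math.comb(n_blocks + p - 1, p) * b ** p

def cube_embed(x, p, b):
    """Blocked embedding of a single vector x of shape (d,).

    Each sorted tuple of block indices selects one cube (the tensor product of
    those blocks). Equivalent cubes, which hold the same entries in permuted
    order, are represented once and weighted by sqrt(number of equivalents).
    """
    d = x.shape[0]
    assert d % b == 0, "block size must divide d"
    blocks = x.reshape(d // b, b)
    feats = []
    for cube in itertools.combinations_with_replacement(range(d // b), p):
        counts = [cube.count(i) for i in set(cube)]
        n_equiv = math.factorial(p) // math.prod(math.factorial(c) for c in counts)
        outer = blocks[cube[0]]
        for blk in cube[1:]:
            outer = np.outer(outer, blocks[blk]).ravel()  # grow the cube one block at a time
        feats.append(math.sqrt(n_equiv) * outer)
    return np.concatenate(feats)

# Dimension for d = 64, p = 2 as b moves from 1 (symmetric power) to d (tensor product).
for b in (1, 2, 4, 8, 16, 32, 64):
    print(b, cube_dim(64, 2, b))
```

With $d = 64$ and $p = 2$ the loop prints dimensions 2080, 2112, 2176, 2304, 2560, 3072 and 4096 for $b = 1, 2, 4, 8, 16, 32, 64$. Under these assumptions, $b = 16$ costs about 23% more coordinates than the exact symmetric power, $b = 8$ about 11%, and $b = 4$ about 5%, versus about 97% for the full tensor product. Meanwhile every cube remains a dense, regular $b^p$ tile.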
## Temporal vs Feature Normalization
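Requiring each attention row to be a distribution (nonnegative numbers summing to 1) causes issues when $p$ is odd: the preattention scores $B_{ij}$ can then be negative, so the denominator can be zero or negative.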
Layer norm vs ball norm.
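A two-token example of our own shows the failure with temporal normalization. With $p$ odd, a key pointing opposite to the query contributes a negative preattention score, and the row's denominator can cancel to zero. The helper `row_denominators` is hypothetical.

```python
import numpy as np

def row_denominators(Q, K, p):
    """Sum of causal preattention scores B_ij = (Q_i^T K_j)^p for each query i."""
    t = Q.shape[0]
    B = (Q @ K.T) ** p * np.tril(np.ones((t, t)))
    return B.sum(axis=1)

Q = np.array([[1.0, 0.0], [1.0, 0.0]])
K = np.array([[1.0, 0.0], [-1.0, 0.0]])  # second key points the opposite way
print(row_denominators(Q, K, p=2))  # [1. 2.]  even p: every B_ij >= 0
print(row_denominators(Q, K, p=3))  # [1. 0.]  odd p: the second row cancels to zero
```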
## Upcoming features and optimizations
* Generic $p$
* Smaller cube sizes: 8 or 4 should be possible
* H100