# Prefix tuning (transposed)
Dimensions
- $D$ = hidden dimension
- $d$ = dimensions per head
- $h$ = number of heads
- $t$ = num tokens (usually query tokens, often=1)
- $s$ = num past tokens
- $m$ = num prefix tokens
Variables
* $x: [t,D]$ hidden states entering attention
* $\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{O}: [D,d]$ = per-head QKVO weights
* $P_k, P_v: [m,d]$ = soft prefixes
---
Normal attention is expressed as
$$
\text{S}(x\mathbf{Q}\mathbf{K}^Tx^T)x\mathbf{V}\mathbf{O}^T
$$
for each head, where S is the softmax operator.
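As a shape check, a minimal per-head sketch of this formula (PyTorch, random weights; the usual $1/\sqrt{d}$ scaling and causal mask are left out to match the math above):

```python
import torch

D, d, t = 16, 4, 3
x = torch.randn(t, D)                               # [t, D] hidden states
Q, K, V, O = (torch.randn(D, d) for _ in range(4))  # [D, d] per-head weights

scores = x @ Q @ K.T @ x.T                          # [t, t] = xQ K^T x^T
out = torch.softmax(scores, dim=-1) @ x @ V @ O.T   # [t, D] = S(...) xV O^T
print(out.shape)                                    # torch.Size([3, 16])
```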
Adding the soft prefixes, we get
$$
\text{S}\left(
x\mathbf{Q}
\begin{bmatrix}
\mathbf{K}^Tx^T & P_k^T
\end{bmatrix}
\right)
\begin{bmatrix}
x\mathbf{V} \\ P_v
\end{bmatrix}
\mathbf{O}^T
$$
or
$$
\text{S}\left(
\begin{bmatrix}
x\mathbf{Q} \mathbf{K}^Tx^T & x\mathbf{Q}P_k^T
\end{bmatrix}
\right)
\begin{bmatrix}
x\mathbf{V}\mathbf{O}^T \\ P_v\mathbf{O}^T
\end{bmatrix}
$$
(Note that while we could pre-multiply $\mathbf{V}\mathbf{O}^T$ and $\mathbf{Q}\mathbf{K}^T$ to save computation, we do not.
Each of the component matrices is $[D,d]$, whereas the products would be $[D,D]$ (per head!), which is quite large. Kept as factors, $\mathbf{V}\mathbf{O}^T$ and $\mathbf{Q}\mathbf{K}^T$ act as low-rank (rank at most $d$) matrices.)
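A sketch of the prefix-augmented version (first block form above), making the concatenation explicit; same caveats as before, and $P_k$, $P_v$ are just random stand-ins here:

```python
import torch

D, d, t, m = 16, 4, 3, 5
x = torch.randn(t, D)
Q, K, V, O = (torch.randn(D, d) for _ in range(4))
P_k, P_v = torch.randn(m, d), torch.randn(m, d)     # [m, d] soft prefixes

keys = torch.cat([x @ K, P_k], dim=0)               # [t+m, d]; keys.T = [K^T x^T  P_k^T]
vals = torch.cat([x @ V, P_v], dim=0)               # [t+m, d] = [xV ; P_v]
scores = x @ Q @ keys.T                             # [t, t+m]
out = torch.softmax(scores, dim=-1) @ vals @ O.T    # [t, D]
```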
Analogously, if $x$ is the hidden state for a single token and $z$ contains the previous hidden states, we can write single-token-query attention as:
$$
\text{S}\left(
\begin{bmatrix}
x\mathbf{Q}\mathbf{K}^Tx^T
& x\mathbf{Q}\mathbf{K}^Tz^T
\end{bmatrix}
\right)
\begin{bmatrix}
x \mathbf{V}\mathbf{O}^T
\\ z \mathbf{V} \mathbf{O}^T
\end{bmatrix}
$$
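A sketch of this cached form (single query token, $s$ past states), with the same caveats as the earlier snippets:

```python
import torch

D, d, s = 16, 4, 7
x = torch.randn(1, D)                               # current token, t = 1
z = torch.randn(s, D)                               # [s, D] past hidden states
Q, K, V, O = (torch.randn(D, d) for _ in range(4))

keys = torch.cat([x @ K, z @ K], dim=0)             # [1+s, d]; zK sits where P_k sat
vals = torch.cat([x @ V, z @ V], dim=0)             # [1+s, d]; zV sits where P_v sat
scores = x @ Q @ keys.T                             # [1, 1+s]
out = torch.softmax(scores, dim=-1) @ vals @ O.T    # [1, D]
```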
We thus see the parallel between past keys/values and soft prefixes:
$$
z\mathbf{V} \rightarrow P_v \\
z\mathbf{K} \rightarrow P_k
$$
Except $z\mathbf{V}: [s, d]$ while $P_v: [m, d]$ (and likewise for $\mathbf{K}$).
Our goal is to find a way to map from $s$ vectors to $m$ vectors. Can we try a learned low-rank projection? Putting the softmax aside temporarily, we have
$$
\begin{align}
\begin{bmatrix}
x\mathbf{Q}\mathbf{K}^Tx^T
& x\mathbf{Q}\mathbf{K}^Tz^T
\end{bmatrix}
\begin{bmatrix}
x \mathbf{V}\mathbf{O}^T
\\ z \mathbf{V} \mathbf{O}^T
\end{bmatrix}
&= x\mathbf{Q}\mathbf{K}^Tx^Tx \mathbf{V}\mathbf{O}^T + x\mathbf{Q}\mathbf{K}^Tz^Tz \mathbf{V} \mathbf{O}^T
\\&=x\mathbf{Q}\left(\mathbf{K}^Tx^Tx \mathbf{V} + \mathbf{K}^Tz^Tz \mathbf{V} \right)\mathbf{O}^T
\\&\rightarrow x\mathbf{Q}\left(\mathbf{K}^Tx^Tx \mathbf{V} + P_k^TP_v \right)\mathbf{O}^T
\end{align}
$$
In the second-to-last line, we see $\mathbf{K}^Tx^Tx \mathbf{V}$ and $\mathbf{K}^Tz^Tz \mathbf{V}$. We want to convert the latter to $P_k^TP_v$. As proposed above, we want to project down from $s$ to $m$; in other words, we need to insert a pair of projection matrices of shapes $[s, m]$ and $[m, s]$.
We propose to learn two low-rank matrices $A, B: [D, m]$. Then $zA$ has shape $[s, m]$ and $B^Tz^T$ has shape $[m, s]$.
To elaborate, remember that we wish to approximate
$$\mathbf{K}^Tz^Tz \mathbf{V}$$
We can approximate this by inserting a low-rank projection in the middle:
$$\mathbf{K}^Tz^TzAB^Tz^Tz \mathbf{V}$$
As such, we have:
$$
P_v = B^Tz^Tz \mathbf{V}: [m,d] \\
P_k = A^Tz^Tz \mathbf{K}: [m,d]
$$
(Note: $\mathbf{K}^Tz^Tz \mathbf{V}$ has dimensions $[d,d]$, but remember that this is a simplification without the softmax. With the softmax, the past actually contributes on the order of $s > d$ degrees of freedom. Hopefully this formulation is still fine, since $d > m$ anyway.)
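A minimal sketch of this construction with untrained $A$, $B$ and random stand-ins for $z$, $\mathbf{K}$, $\mathbf{V}$; it only checks that the shapes line up and that $P_k^TP_v$ has the same $[d,d]$ form as the term it replaces:

```python
import torch

D, d, s, m = 16, 4, 9, 5
z = torch.randn(s, D)                         # past hidden states
K, V = torch.randn(D, d), torch.randn(D, d)   # frozen per-head weights
A, B = torch.randn(D, m), torch.randn(D, m)   # learned low-rank projections (untrained here)

P_k = A.T @ z.T @ z @ K                       # [m, d]
P_v = B.T @ z.T @ z @ V                       # [m, d]

target = K.T @ z.T @ z @ V                    # [d, d] exact term
approx = P_k.T @ P_v                          # [d, d] = K^T z^T zA B^T z^T z V
print(P_k.shape, P_v.shape, approx.shape)
```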
## WIP
Two ways of training:
- $\mathbf{K}^Tz^Tz \mathbf{V} \Leftrightarrow \mathbf{K}^Tz^TzAB^Tz^Tz \mathbf{V}$ (rough sketch after this list)
- Straight LM with $\mathbf{K}^Tz^TzA$ and $B^Tz^Tz \mathbf{V}$ (big mem)
- $\text{S}(x\mathbf{Q}\mathbf{K}^Tx^TxA)B^Tx^Tx\mathbf{V}\mathbf{O}^T$
- Wait how do we even do attention masking
- Okay, we need to do this with conditional LM
- RoPE messes everything up. Post-rope?
- LLaMA-adapter
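A rough sketch of the first option (matching the un-softmaxed products), assuming synthetic hidden states and a plain Frobenius reconstruction loss; the optimizer, loss, and data here are placeholders, not a settled recipe:

```python
import torch

D, d, s, m = 16, 4, 9, 5
K, V = torch.randn(D, d), torch.randn(D, d)          # frozen model weights
A = torch.randn(D, m, requires_grad=True)            # learned projection
B = torch.randn(D, m, requires_grad=True)            # learned projection
opt = torch.optim.Adam([A, B], lr=1e-2)

for step in range(200):
    z = torch.randn(s, D)                            # stand-in for real hidden states
    target = K.T @ z.T @ z @ V                       # [d, d]
    approx = (K.T @ z.T @ z @ A) @ (B.T @ z.T @ z @ V)
    loss = (target - approx).pow(2).mean()           # reconstruction objective
    opt.zero_grad(); loss.backward(); opt.step()
```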