# Prefix tuning (transposed)

Dimensions

- $D$ = hidden dimension
- $d$ = dimensions per head
- $h$ = number of heads
- $t$ = num tokens (usually query tokens, often $t=1$)
- $s$ = num past tokens
- $m$ = num prefix tokens

Variables

* $x: [t,D]$ = hidden states entering attention
* $\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{O}: [D,d]$ = per-head QKVO weights
* $P_k, P_v: [m,d]$ = soft prefixes

---

Normal attention is expressed as

$$
\text{S}(x\mathbf{Q}\mathbf{K}^Tx^T)x\mathbf{V}\mathbf{O}^T
$$

for each head, where S is the softmax operator (applied row-wise). The $1/\sqrt{d}$ scaling and the attention mask are omitted throughout.

Adding the soft prefixes, we get

$$
\text{S}\left( x\mathbf{Q} \begin{bmatrix} \mathbf{K}^Tx^T & P_k^T \end{bmatrix} \right) \begin{bmatrix} x\mathbf{V} \\ P_v \end{bmatrix} \mathbf{O}^T
$$

or

$$
\text{S}\left( \begin{bmatrix} x\mathbf{Q} \mathbf{K}^Tx^T & x\mathbf{Q}P_k^T \end{bmatrix} \right) \begin{bmatrix} x\mathbf{V}\mathbf{O}^T \\ P_v\mathbf{O}^T \end{bmatrix}
$$

(Note that while we could pre-multiply $\mathbf{V}\mathbf{O}^T$ and $\mathbf{Q}\mathbf{K}^T$ to save computation, we do not. Each of the component matrices is $[D,d]$, whereas the products would be $[D,D]$ (per head!), which is quite large. $\mathbf{V}\mathbf{O}^T$ and $\mathbf{Q}\mathbf{K}^T$ are low-rank matrices, of rank at most $d$, kept in factored form.)

Analogously, if $x: [1,D]$ is the hidden state of a single token and $z: [s,D]$ holds the hidden states of the $s$ previous tokens, we can write single-token query attention as:

$$
\text{S}\left( \begin{bmatrix} x\mathbf{Q}\mathbf{K}^Tx^T & x\mathbf{Q}\mathbf{K}^Tz^T \end{bmatrix} \right) \begin{bmatrix} x \mathbf{V}\mathbf{O}^T \\ z \mathbf{V} \mathbf{O}^T \end{bmatrix}
$$

We thus see the parallel between past keys/values and soft prefixes:

$$
\begin{aligned}
z\mathbf{V} &\rightarrow P_v \\
z\mathbf{K} &\rightarrow P_k
\end{aligned}
$$

except that $z\mathbf{V}: [s, d]$ while $P_v: [m, d]$ (and likewise for $\mathbf{K}$). Our goal is to find a way to map $s$ vectors to $m$ vectors. Can we try a learned low-rank projection?
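To make the parallel concrete, here is a minimal single-head sketch, assuming NumPy (the note itself has no code). `prefix_attention` and `past_kv_attention` are hypothetical helper names that follow the equations above literally, with no $1/\sqrt{d}$ scaling and no mask. Setting $m = s$, $P_k = z\mathbf{K}$ and $P_v = z\mathbf{V}$ should make the two agree exactly; the last line checks that $\mathbf{Q}\mathbf{K}^T$ is $[D,D]$ but only rank $d$ for random weights.

```python
import numpy as np


def softmax(scores):
    # Row-wise softmax, shifted by the row max for numerical stability.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


def prefix_attention(x, Q, K, V, O, P_k, P_v):
    """One head: S(xQ [K^T x^T, P_k^T]) [xV; P_v] O^T.

    x: [t, D]; Q, K, V, O: [D, d]; P_k, P_v: [m, d]. Returns [t, D].
    """
    keys = np.concatenate([x @ K, P_k], axis=0)    # [t + m, d]
    values = np.concatenate([x @ V, P_v], axis=0)  # [t + m, d]
    scores = (x @ Q) @ keys.T                      # [t, t + m]
    return softmax(scores) @ values @ O.T          # [t, D]


def past_kv_attention(x, z, Q, K, V, O):
    """One head, x attending to itself and to past hidden states z: [s, D]."""
    keys = np.concatenate([x @ K, z @ K], axis=0)    # [t + s, d]
    values = np.concatenate([x @ V, z @ V], axis=0)  # [t + s, d]
    return softmax((x @ Q) @ keys.T) @ values @ O.T  # [t, D]


rng = np.random.default_rng(0)
D, d, s = 16, 4, 5
x = rng.standard_normal((1, D))  # t = 1 query token
z = rng.standard_normal((s, D))  # s past tokens
Q, K, V, O = (rng.standard_normal((D, d)) for _ in range(4))

# Soft prefixes with m = s, P_k = zK, P_v = zV reproduce past-KV attention.
out_prefix = prefix_attention(x, Q, K, V, O, P_k=z @ K, P_v=z @ V)
out_past = past_kv_attention(x, z, Q, K, V, O)
print(np.allclose(out_prefix, out_past))  # True

# QK^T is [D, D] but has rank d: a low-rank matrix kept in factored form.
print(np.linalg.matrix_rank(Q @ K.T))
```

With $m \neq s$ the same `prefix_attention` call still works shape-wise; the open question in the text is how to choose $P_k, P_v$ so the output stays close to `out_past`.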
Putting the softmax aside temporarily, we have

$$
\begin{aligned}
\begin{bmatrix} x\mathbf{Q}\mathbf{K}^Tx^T & x\mathbf{Q}\mathbf{K}^Tz^T \end{bmatrix} \begin{bmatrix} x \mathbf{V}\mathbf{O}^T \\ z \mathbf{V} \mathbf{O}^T \end{bmatrix} &= x\mathbf{Q}\mathbf{K}^Tx^Tx \mathbf{V}\mathbf{O}^T + x\mathbf{Q}\mathbf{K}^Tz^Tz \mathbf{V} \mathbf{O}^T \\
&= x\mathbf{Q}\left(\mathbf{K}^Tx^Tx \mathbf{V} + \mathbf{K}^Tz^Tz \mathbf{V} \right)\mathbf{O}^T \\
&\rightarrow x\mathbf{Q}\left(\mathbf{K}^Tx^Tx \mathbf{V} + P_k^TP_v \right)\mathbf{O}^T
\end{aligned}
$$

In the second-to-last line, we see $\mathbf{K}^Tx^Tx \mathbf{V}$ and $\mathbf{K}^Tz^Tz \mathbf{V}$, both $[d,d]$. We want to convert the latter to $P_k^TP_v$. As proposed above, we want to project down from $s$ to $m$; in other words, we need two matrices of shape $[m, s]$. We propose to learn two projection matrices $A, B: [D, m]$. Then $A^Tz^T$ and $B^Tz^T$ are both of shape $[m, s]$.

To elaborate, remember that we wish to approximate

$$\mathbf{K}^Tz^Tz \mathbf{V}$$

We can approximate this with a low-rank projection, replacing $z^Tz$ by $z^TzAB^Tz^Tz$ (which has rank at most $m$):

$$\mathbf{K}^Tz^TzAB^Tz^Tz \mathbf{V}$$

As such, we have:

$$
\begin{aligned}
P_v &= B^Tz^Tz \mathbf{V}: [m,d] \\
P_k &= A^Tz^Tz \mathbf{K}: [m,d]
\end{aligned}
$$

so that $P_k^TP_v = \mathbf{K}^Tz^TzAB^Tz^Tz \mathbf{V}$.

(Note: $\mathbf{K}^Tz^Tz \mathbf{V}$ has dimensions $[d,d]$, but remember that this is a simplification without the softmax. With the softmax, the degrees of freedom instead scale with $s$, and $s > d$. But hopefully this formulation is fine, since $d > m$.)

## WIP

- Two ways of training:
  - $\mathbf{K}^Tz^Tz \mathbf{V} \Leftrightarrow \mathbf{K}^Tz^TzAB^Tz^Tz \mathbf{V}$
  - Straight LM with $\mathbf{K}^Tz^TzA$ and $B^Tz^Tz \mathbf{V}$ (big mem)
- $\text{S}(x\mathbf{Q}\mathbf{K}^Tx^TxA)B^Tx^Tx\mathbf{V}\mathbf{O}^T$
- Wait, how do we even do attention masking?
- Okay, we need to do this with conditional LM
- RoPE messes everything up. Post-RoPE?
- LLaMA-Adapter
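As a sanity check on the shapes, here is a minimal sketch, assuming NumPy, of the compression $P_k = A^Tz^Tz\mathbf{K}$, $P_v = B^Tz^Tz\mathbf{V}$ and of the first training option in the WIP list, matching $\mathbf{K}^Tz^Tz\mathbf{V}$ against $P_k^TP_v$ with the softmax ignored. The helper names `compress_past` and `reconstruction_loss` and the choice of a squared Frobenius loss are hypothetical, for illustration only; nothing is trained here. The final check uses $m = D$, $A = (z^Tz)^+$ (Moore-Penrose pseudoinverse) and $B = I$, for which the approximation is exact because $MM^+M = M$ for any matrix $M$.

```python
import numpy as np


def compress_past(z, K, V, A, B):
    """Map s past hidden states to m soft-prefix rows.

    z: [s, D]; K, V: [D, d]; A, B: [D, m] (hypothetical learned projections).
    Returns P_k = A^T z^T z K and P_v = B^T z^T z V, both [m, d].
    """
    gram = z.T @ z        # [D, D], the z^T z factor shared by both prefixes
    P_k = A.T @ gram @ K  # [m, d]
    P_v = B.T @ gram @ V  # [m, d]
    return P_k, P_v


def reconstruction_loss(z, K, V, A, B):
    """Assumed objective: ||K^T z^T z V - P_k^T P_v||_F^2 (softmax ignored)."""
    target = K.T @ z.T @ z @ V  # [d, d]
    P_k, P_v = compress_past(z, K, V, A, B)
    approx = P_k.T @ P_v        # [d, d] = K^T z^T z A B^T z^T z V
    return float(np.sum((target - approx) ** 2))


rng = np.random.default_rng(0)
D, d, s, m = 16, 4, 32, 2
z = rng.standard_normal((s, D))
K, V = rng.standard_normal((D, d)), rng.standard_normal((D, d))

# Untrained random projections: the shapes line up; the error is generally nonzero.
A, B = rng.standard_normal((D, m)), rng.standard_normal((D, m))
P_k, P_v = compress_past(z, K, V, A, B)
print(P_k.shape, P_v.shape)  # (2, 4) (2, 4)

# Exact case (m = D): A = pinv(z^T z), B = I, since M M^+ M = M.
A_exact, B_exact = np.linalg.pinv(z.T @ z), np.eye(D)
print(reconstruction_loss(z, K, V, A_exact, B_exact) < 1e-6)  # True
```

In this linear picture, $P_k^TP_v$ has rank at most $m$ while the target generally has rank $d$, so with $m < d$ the reconstruction cannot be exact; a trained $A, B$ could only minimize the residual.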