In our notation, the above equation becomes
\begin{align*}
\text{Attention} \colon \mathbb{R}^{\key} \times \mathbb{R}^{\seq \times \key} \times \mathbb{R}^{\seq \times\val} &\rightarrow \mathbb{R}^{\val} \\
\text{Attention}(Q,K,V) &= \nfun{\seq}{softmax} \left( \frac{Q \ndot{\key} K}{\sqrt{|\key|}} \right) \ndot{\seq} V.
\end{align*}
The $\key$ axis corresponds to the coordinates of the keys and queries. The $\val$ axis corresponds to the coordinates of the value vectors. The $\seq$ axis corresponds to the positions of tokens (and so its length is the number of tokens). This notation makes clear what the type of each input is, and how it is acted upon. The tensor $Q$ is a query, which is a vector over the $\key$ axis. The tensor $K$ maps positions to keys, and so has the $\seq$ and $\key$ axes. You can think of it as a matrix, but the reader does not need to remember whether $\seq$ corresponds to the columns and $\key$ to the rows or vice versa. The tensor $V$ maps every position in $\seq$ to a value vector over the axis $\val$, and so has the $\seq$ and $\val$ axes. The description of the functions makes it clear on which axis each operation acts. For example, to parse the expression $Q \ndot{\key} K$ it does not matter whether $\key$ corresponds to rows or columns of $K$, since the dot product is taken over the $\key$ axis, which is shared between $K$ and $Q$.
Our notation also makes it easy to ``broadcast'' or ``vectorize'' a function. For example, if instead of being a vector with the $\key$ axis, $Q$ is a tensor with the $\key$, $\seq$ and $\batch$ axes (corresponding to the position of a token in a sequence and to the index in a minibatch), then these axes are ``propagated'' to the output, with the function acting on each element of them independently. That is, in this case $\text{Attention}(Q,K,V)$ will be the tensor $A$ with axes $\val$, $\seq$ and $\batch$, such that for every index $s$ of $\seq$ and $b$ of $\batch$, the corresponding element of $A$ is obtained by applying $\text{Attention}$ to the corresponding restriction of $Q$ (together with $K$ and $V$). Similarly, we can also add a $\heads$ axis to the inputs for multiple attention heads.
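For concreteness, this broadcast case can be spelled out as follows (the slice notation $Q_{\seq(s),\batch(b)}$ for restricting $Q$ to a single position and batch index is an assumption here, in the style of the indexing used in the identities below and in the V2 notation):
$A_{\seq(s),\batch(b)} = \text{Attention}\left(Q_{\seq(s),\batch(b)},\ K,\ V\right) \qquad \text{for every } s \in \seq,\ b \in \batch.$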
$\sum_{i \in \mathsf{foo}} A_i \left(\sum_{j \in \mathsf{bar}} B_j C_j \right)_i =
\sum_{j \in \mathsf{bar}} \left( \sum _{i \in \mathsf{foo}} A_iB_i \right)_j C_j = \sum_{i \in \mathsf{foo}, j \in \mathsf{bar}} A_i B_{i,j} C_j$
$\sum_{i \in \mathsf{ax}} A_{\mathsf{ax}(i)} \left(B \otimes C \right)_{\mathsf{ax}(i)} =
\sum_{i \in \mathsf{ax}} (A \otimes B)_{\mathsf{ax}(i)} C_{\mathsf{ax}(i)} = \sum_{i \in \mathsf{ax}} A_{\mathsf{ax}(i)}B_{\mathsf{ax}(i)}C_{\mathsf{ax}(i)}$
$\sum_{i \in \mathsf{foo}} A_{\mathsf{foo}(i)} \left(\sum_{j \in \mathsf{bar}} B_{\mathsf{bar}(j)} C_{\mathsf{bar}(j)} \right)_{\mathsf{foo}(i)} =
\sum_{j \in \mathsf{bar}} \left( \sum _{i \in \mathsf{foo}} A_{\mathsf{foo}(i)}B_{\mathsf{foo}(i)} \right)_{\mathsf{bar}(j)} C_{\mathsf{bar}(j)} = \sum_{i \in \mathsf{foo}, j \in \mathsf{bar}} A_{\mathsf{foo}(i)} B_{\mathsf{foo}(i),{\mathsf{bar}(j)}} C_{\mathsf{bar}(j)}$
# Tensor notation - V2
Here is a suggestion for notation. It largely tracks the notation used in the current document, except that axes are not identified by mere "names" alone. I find it useful, while defining a notion, to think about how it would look in code as well. By the way, I think that aside from correctness issues, using named tensors might also help with efficiency, since it gives more of a chance to optimize the particular way tensors are stored.
We define an __axis__ to be a set of pairs of the form $\{ (\text{name}, x) | x\in X \}$, where $X$ is some ordered set. We say that $\text{name}$ is the name of the axis. We will typically denote the axis with a certain name by writing that name in bold font; for example, $\mathbf{channel}$ might correspond to the set $\{ (\text{channel},1), (\text{channel},2),(\text{channel},3) \}$.
For notation, we might write $\mathbf{ax} \underset{ax}{=} X$, where $X$ is the index set, to mean that
$\mathbf{ax} = \{ (\text{ax},x) | x\in X \}$. For example, we might write $\mathbf{width} \underset{ax}{=} [n]$.
In the common case $X$ will be the set $[n]$ for some integer $n$; this will be the default choice.
An element $(\text{name},x)$ of an axis $\mathbf{ax}$ is called an _index_.
We can think of an axis as a _type_: the difference between $(\text{channel},1)$ and the plain integer $1$ is that the former has the type $\mathbf{channel}$ and so changing from (for example) $(\text{channel},1)$ to $(\text{width},1)$ requires explicit casting.
__Comment:__ It may seem strange that we use $\mathbf{channel}$ both as the name of the _type_ and as a variable name to identify an element of this type. This is a little like "Hungarian" programming notation, where the name of a variable is prefixed by its type. However, here in many cases we won't need more than a single variable for an axis, and so the prefix is the entire name. We will use _annotation_ if we want more than one variable of the same type. Note that in math as well, variable names such as $i,j,n$ or $f,g$ or $\epsilon,\delta$ are used to indicate the type of the object they denote.
__Python:__ While I proposed to conceptually view an axis as a type, I think it would actually make sense in Python to treat it as an object (potentially a singleton element of a class). I do not think we want an axis to be simply a string, since different packages might use the string "width" in different ways (for example, they might use different coordinate systems). Code might look something like
```python
import axes

width = axes.width  # a standard, pre-registered axis
rgb = axes.register("rgb", ["red", "green", "blue"])  # register a new axis with an explicit index set
```
I imagine there would be standard constant axes that everyone can use, but also a way to add new axes.
It would also be possible to register casting transformations that specify how to transform an index of one axis into an index of another: for example, translating from a coordinate system where $(0,0)$ is the top-left corner to one where it is the bottom-left, or transforming a pair $(i,j) \in (\mathbf{width},\mathbf{height})$ into $j \in \mathbf{layer}$ for flattening.
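A rough sketch of what registering such a cast might look like, continuing the hypothetical `axes` module above (`register_cast`, `axes.layer`, and the `.value` accessor are all assumptions, not an existing API):
```python
# Hypothetical API: register a cast that flattens a (width, height) pair
# into a single layer index, in row-major order.
layer = axes.layer
axes.register_cast(
    source=(width, height), target=layer,
    fn=lambda w, h: layer[w.value * len(height) + h.value],
)
```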
An axis $\mathbf{ax'}$ is a _restriction_ of an axis $\mathbf{ax}$ if $\mathbf{ax'} \subseteq \mathbf{ax}$.
In other words, if $\mathbf{ax}$ has the form $\{ (\text{name},x) | x \in X \}$ then $\mathbf{ax'}$ has the form $\{ (\text{name},x) | x \in X' \}$ for some $X' \subseteq X$.
We denote the name of an axis $\mathbf{ax}$ by $name(\mathbf{ax})$.
The _support_ of an axis $\mathbf{ax}$, denoted by $Supp(\mathbf{ax})$ is the set $X$ above.
That is, $\mathbf{ax}$ is the set $\{ (name(\mathbf{ax}), x) | x\in Supp(\mathbf{ax}) \}$.
We define $|\mathbf{ax}| = |Supp(\mathbf{ax})|$.
For an axis $\mathbf{ax}$ with support $X$, if $S \subseteq X$ then we use the notation $\mathbf{ax}[S]$ for the restricted axis $\{ (name(\mathbf{ax}),x) | x\in S \}$. We use the notation $\mathbf{ax}[..n]$ (I prefer this to $\mathbf{ax}[n]$ - I want to keep $\mathbf{ax}[n]$ for the $n$-th index of this axis, or maybe for casting the number $n$ into the axis) for the axis $\mathbf{ax}[\{ x_1,x_2,\ldots,x_n \}]$ where $x_1,\ldots,x_n$ are the first $n$ elements of $Supp(\mathbf{ax})$.
__Python:__ In Python we will write something like
```python
width_input = width[:64]
print(len(width_input)==64) # prints True
```
A _shape_ $\mathcal{S}$ is an unordered set of axes.
__Annotation__ We sometimes want to use two different names for the same axis. We will use primes, subscripts or superscripts in such a case.
Mathematically it means that we can distinguish between the two copies when we need to, and treat them uniformly when we don't.
For example, we could write $\mathbf{channel}_{in}$ and $\mathbf{channel}_{out}$. The notation is uniform regardless of whether $\mathbf{channel}_{in}$ and $\mathbf{channel}_{out}$ happen to have the same length or not. If a tensor $T$ has only one axis of type $\mathbf{channel}$ then we can write something like $T_{\mathbf{channel}=17}$; if it has both, we need to write $T_{\mathbf{channel}_{in}=17}$. In Python we might write this as `channel_in = channel.annotate('in')` or something like that.
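A small hypothetical sketch of how two annotated copies of the same axis might coexist in one tensor (the `annotate` method and the `namedtensor.zeros` constructor used here are assumptions, in the spirit of the examples further below):
```python
channel_in = channel.annotate('in')
channel_out = channel.annotate('out')

# A weight tensor with two distinct copies of the channel axis; since both are
# present, indexing has to say which copy is meant.
W = namedtensor.zeros(size=[channel_in[:3], channel_out[:3]])
print(W[channel_in[1], channel_out[2]])  # prints 0.0
```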
A _tensor_ $T$ of shape $\mathcal{S} = \{ \mathbf{ax}_1 ,\ldots , \mathbf{ax}_d \}$ over the reals is a map that, on input a $d$-sized set of indices $\{ i_1 , \ldots, i_d \}$ with $i_j \in \mathbf{ax}_j$, outputs a real number, which we denote by $T(i_1,\ldots,i_d)$ or $T_{i_1,\ldots,i_d}$.
We denote the shape of tensor $T$ by $shape(T)$.
As an example, if we set $\mathbf{width} \underset{ax}{=} [64]$, $\mathbf{height} \underset{ax}{=} [64]$, and $\mathbf{channel} \underset{ax}{=} \{ \text{"R"},\text{"G"},\text{"B"} \}$, then to define a tensor of shape $\{ \mathbf{width}, \mathbf{height}, \mathbf{channel} \}$ we can write:
$A \in \mathbb{R}^{\mathbf{width} , \mathbf{height} , \mathbf{channel}}$
(This notation is ordered to be familiar and consistent with standard tensors, but our tensors will always be symmetric in the sense that $A_{i,j,k} = A_{k,i,j}=\ldots$)
If all axes are indexed by integers, then we can define them implicitly through their restrictions and just write
$A \in \mathbb{R}^{\mathbf{width}[..64] , \mathbf{height}[..64] , \mathbf{channel}[..3]}$
In this case we assume (for example) that $\mathbf{width}$ is an axis with support $[n]$ where $n$ is set to be large enough to allow all the restrictions used in the paper.
In Python we might write this as
```python
A = namedtensor.zeros(size=[width[:64], height[:64], channel[:3]])
# the size is unordered, can be a list or also a set
print(A.shape)
# prints { width[:64], height[:64], channel[:3] }
```
We can also convert a standard tensor `T` to a named one.
```python
A = namedtensor.tensor(T,[width, height])
# first axis becomes width and second height.
# lengths are inherited from T and so do not need to be specified.
```
### Indexing
If $A$ has order $d$ (i.e., $|shape(A)|=d$) and $i \in \mathbf{ax}$ for some $\mathbf{ax} \in shape(A)$, then $A(i)$ (which can also be written as $A_i$) is the "slice" that maps every $i_1,\ldots,i_{d-1}$ in the axes of $shape(A) \setminus \{ \mathbf{ax} \}$ to $A(i,i_1,\ldots,i_{d-1})$.
If $x$ is a member of $Supp(\mathbf{ax})$ (for example $x$ is an integer if this is an integral axis) and $name = name(\mathbf{ax})$ then the slice $A_{(name,x)}$ is also denoted by $A_{\mathbf{ax}[x]}$ or $A_{\mathbf{ax}=x}$.
In Python we might write something like:
```python
print(A[width[10],height[3]]) # print A_{width=10,height=3}
# print all elements
for w in width(A):
    for h in height(A):
        print(A[w, h])
```
### Partial indexing
If $S=\{ \mathbf{ax}_1,\ldots,\mathbf{ax}_\ell \}$ is a strict subset of $shape(A)$ and $i_1,\ldots,i_\ell$ are indices in $\mathbf{ax}_1,\ldots,\mathbf{ax}_\ell$ respectively then $A' = A[i_1,\ldots,i_\ell]$ will be the tensor of shape $shape(A) \setminus S$ consisting of the corresponding "slice" of $A$.
(That is, for every indices $j_1,\ldots,j_{d-\ell}$ in the axes $shape(A) \setminus S$, $A'(j_1,\ldots,j_{d-\ell}) = A(i_1,\ldots,i_{\ell},j_1,\ldots,j_{d-\ell})$.)
In Python, if `A.shape == [width,height]` we can write `A[width[10]]` to get the corresponding vector of shape `{height}` or `A[height[5]]` to get the corresponding vector of shape `{width}`.
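As a concrete (hypothetical) sketch, using the assumed `namedtensor.zeros` constructor and `.shape` attribute from above:
```python
A = namedtensor.zeros(size=[width[:64], height[:64]])
row = A[height[5]]   # partial indexing: the remaining shape is { width[:64] }
print(row.shape)     # prints { width[:64] }
```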
### Missing indices
I suggest that if $\mathbf{ax}$ is missing from a tensor then it is treated as if that axis has length one. For example, the following code will print either all coordinates of a batch of images or all coordinates of a single image, depending on whether or not `A` has the $\mathbf{batch}$ axis.
```python
# print all elements
for b in batch(A):
    for w in width(A):
        for h in height(A):
            print(A[w, h, b])
```
## Operators
An __operator__ takes as input one or more tensors and outputs a tensor. For example:
$CONV: \mathbb{R}^{\mathbf{width},\mathbf{height}} \times \mathbb{R}^{\mathbf{width},\mathbf{height}} \rightarrow \mathbb{R}^{\mathbf{width},\mathbf{height}}$
is the convolution over the $\mathbf{width}$ and $\mathbf{height}$ axes, defined as follows: $CONV(W,X)=Y$ has shape $\{ \mathbf{width}(X),\mathbf{height}(X) \}$ and satisfies, for every $(x,y) \in (\mathbf{width}(X),\mathbf{height}(X))$,
$Y_{x,y} = \sum_{(a,b) \in (\mathbf{width}(W),\mathbf{height}(W))} X_{x+a,y+b} W_{a,b}$
Another way to write it is that
$Y_{x,y} = X_{\mathbf{width}[x-\ell:x+\ell],\mathbf{height}[y-\ell:y+\ell]} \cdot W$ where $|\mathbf{width}(W)|=|\mathbf{height}(W)|=2\ell+1$.
(I am thinking of the axes of $W$ here as indexed by integers $\{ -\ell, -\ell+1, \ldots, 0, 1,\ldots, \ell \}$. We assume here that any "out of bound" value defaults to zero.)
In Python we might write this as follows:
```python
from itertools import product

@signature(X=[width, height], W=[width, height], output=[width, height])
def CONV(X, W):
    # sum over all (a, b) offsets of W; out-of-bound entries of X default to zero
    return namedtensor.tensor({(x, y): sum(X[x+a, y+b] * W[a, b] for (a, b) in product(width(W), height(W)))
                               for (x, y) in product(width(X), height(X))})
```
If we wanted to add the $\mathbf{channel}$ axis, we could add it to the signature above for both `X` and `W`; the product `X[x+a,y+b]*W[a,b]` would then be interpreted
as a dot product of these two $\{ \mathbf{channel} \}$-shaped vectors.
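For illustration, the change might look like this (a sketch reusing `itertools.product` and the hypothetical `signature`/`namedtensor` API from the `CONV` sketch above; the body is unchanged except for the signature):
```python
@signature(X=[width, height, channel], W=[width, height, channel], output=[width, height])
def CONV_CHANNEL(X, W):
    # same body as CONV: X[x+a, y+b] * W[a, b] is now a dot product over channel
    return namedtensor.tensor({(x, y): sum(X[x+a, y+b] * W[a, b] for (a, b) in product(width(W), height(W)))
                               for (x, y) in product(width(X), height(X))})
```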
We will also have a way in Python to extend existing functions and modules that were written for unnamed tensors.
Maybe something like
```python
g = named(f, X=[width, height], output=[layer])
```
where `f` is a function that takes `X` as an (unnamed) tensor input. In this case `width` will be mapped to the first axis of the input and `height` to the second.
We should also have a way to automatically transform PyTorch modules (i.e., subclasses of `nn.Module`) into ones that use named tensors for their weights, inputs, and outputs.
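A rough sketch of what such a wrapper could look like (`named_module` and the `classes` axis are hypothetical; `nn.Linear` is the standard PyTorch module):
```python
import torch.nn as nn

linear = nn.Linear(64, 10)  # an ordinary PyTorch module
# Hypothetical wrapper: its weights, inputs and outputs become named tensors.
named_linear = named_module(linear, input=[layer], output=[classes])
Y = named_linear(X)  # X has a layer axis, plus e.g. a dangling batch axis
```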
As usual, the signature of an operator does not specify all the information about the relation between its inputs and outputs. For example, we can have the "halving" operator $HALVE$ that drops the bottom half of an image.
$HALVE:\mathbb{R}^{\mathbf{width},\mathbf{height}} \rightarrow \mathbb{R}^{\mathbf{width},\mathbf{height}}$
The abstract types of the input and output tensors are the same, but we can add the constraint that $|\mathbf{height}(HALVE(X))|=\lfloor |\mathbf{height}(X)|/2 \rfloor$ separately.
In Python we might write this as an assertion.
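For instance, a minimal sketch of such an assertion (assuming a `HALVE` implementation and the axis accessors used above):
```python
Y = HALVE(X)
# the shape constraint that is not captured by the signature
assert len(height(Y)) == len(height(X)) // 2
```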
We might also want to somehow be able to specify meta-operators such as the identity, which can take a tensor of any type. (Though I think that by our conventions on dangling axes, the operator $ID:\mathbb{R}^{\emptyset} \rightarrow \mathbb{R}^\emptyset$ defined as $ID(x)=x$ would actually do that.)
### Dangling and aligned axes
We use the following conventions for an operator $F$ taking one or more tensors $X_1,X_2,\ldots,X_\ell$ when these tensors contain axes that do not appear in the signature of $F$:
* If an axis $\mathbf{ax}$ appears in only one of the $X_i$'s and not in any other, then it is called __dangling__. In such a case, we execute $F$ in parallel on every slice of the form $X_i(\mathbf{ax}=j)$, and the output tensor will contain the axis $\mathbf{ax}$ with the same support as in $X_i$, with the corresponding slice holding the output of $F$ on that slice.
* If an axis $\mathbf{ax}$ that does not appear in the signature appears in more than one of the $X_i$'s, then all appearances must have the same support (unless we explain how to align them). In such a case these axes are __aligned__: for every $j$ in the support of this axis, we run $F$ on the corresponding slices of all the tensors in which the axis appears.
For example, consider the following operator $F:\mathbb{R}^{\mathbf{layer}}\times \mathbb{R}^{\mathbf{layer}} \rightarrow \mathbb{R}$ defined as $F(X,W) = ReLU(X \cdot W)$. If we execute this operator on $X\in \mathbb{R}^{\mathbf{layer},\mathbf{batch}}$ and $W \in \mathbb{R}^{\mathbf{layer},\mathbf{output}}$ then the output will be a tensor in $\mathbb{R}^{\mathbf{batch},\mathbf{output}}$: the $\mathbf{batch}$ and $\mathbf{output}$ axes are dangling and propagate to the output, while the shared $\mathbf{layer}$ axis is consumed by the dot product inside $F$.
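A hypothetical sketch of how this might look in code (reusing the assumed `signature` decorator and `namedtensor.zeros` constructor; `relu` and the axes `output` and `batch` are placeholders):
```python
@signature(X=[layer], W=[layer], output=[])
def F(X, W):
    return relu(X * W)  # X * W is a dot product over the shared layer axis

X = namedtensor.zeros(size=[layer[:128], batch[:32]])
W = namedtensor.zeros(size=[layer[:128], output[:10]])
Y = F(X, W)
print(Y.shape)  # prints { batch[:32], output[:10] }
```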
## Summations / contractions
We can sum, take means, and do all sorts of contraction operations on certain axes.
In Python we might write something like
```python
sum(X[batch])
```
or
```python
X.sum(axis=batch)
```
In math we'll write $\sum_{\mathbf{batch}} X$ or $\sum_{b \in \mathbf{batch}} X_b$.
We can also use dot product across different axes, maybe something like
```python=
X[cols] * Y[rows]
```
which we can write as a dot product with $\mathbf{cols}|\mathbf{rows}$ under it:
$
X \underset{\mathbf{cols}|\mathbf{rows}}{\cdot} Y
$
## Groupings / flattening
We can convert two axes $\mathbf{ax1},\mathbf{ax2}$ into a single axis which we can simply denote by $\{ \mathbf{ax1},\mathbf{ax2}\}$.
We can use these groupings to flatten or group together axes:
$v = A_{\{ \mathbf{width}, \mathbf{height} \}}$
will flatten a tensor $A \in \mathbb{R}^{\mathbf{width},\mathbf{height}}$ into a vector with the single axis $\{ \mathbf{width},\mathbf{height} \}$, i.e. of shape $\{ \{ \mathbf{width},\mathbf{height} \} \}$.
If $A$ had extra axes then $v$ will have them too.
I am not sure if we want a notation for flattening all axes, but we could perhaps write this as
$v = A_{\{ shape(A) \}}$
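A hypothetical sketch of what grouping could look like in Python (the `group` method is an assumption that mirrors the math notation, as is `namedtensor.zeros`):
```python
A = namedtensor.zeros(size=[width[:64], height[:64], channel[:3]])
v = A.group({width, height})   # hypothetical: merge width and height into one axis
print(v.shape)                 # prints { {width[:64], height[:64]}, channel[:3] }
```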
## Broadcasting
The convention of ignoring missing axes automatically allows some broadcasting. Conceptually, we can think of every tensor as having all of the axes, with length one in the axes it is missing.
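For example, a sketch of how this could play out (assuming the `namedtensor.zeros` constructor and an elementwise `+` in the hypothetical API above):
```python
image = namedtensor.zeros(size=[width[:64], height[:64], channel[:3]])
bias = namedtensor.zeros(size=[channel[:3]])
# bias is missing width and height, so it is broadcast over them
shifted = image + bias
print(shifted.shape)  # prints { width[:64], height[:64], channel[:3] }
```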
## Concatenating and taking parts of tensors
To be continued. Maybe the notation will be $X \| Y$, where the tensors are concatenated along all axes that exist in both; or we can concatenate across certain matching axes and use annotations for the others.
Taking a half of a tensor could be something like
$B = A_{\mathbf{height}[..\tfrac{|\mathbf{height}|}{2}]}$
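In Python, a hedged sketch of the same operation (indexing a tensor by a restricted axis, as well as the `height(A)` accessor, are assumptions here):
```python
# keep only the top half of the height axis
B = A[height[:len(height(A)) // 2]]
```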