
Transformers using Named Tensor Notation

written by @marc_lelarge

We use the Named Tensor Notation of David Chiang, Alexander M. Rush and Boaz Barak to describe the basic encoder block of Transformers (i.e. without positional encoding and autoregressive mask) defined in Attention Is All You Need by Vaswani et al.

For a presentation of the attention mechanism and transformers, see Module 12 - Attention and Transformers


Basics about Named Tensor Notation

These notations should feel natural, as they implement the elementwise operations, broadcasting, reductions and contractions familiar from NumPy and PyTorch.
Recall that functions from vectors to vectors lift to functions on tensors: they operate along one axis but leave the tensor shape unchanged. For example, the softmax function defined by

$$\mathop{\mathrm{softmax}}_{\mathsf{ax}} \colon \mathbb{R}^{\mathsf{ax}} \to \mathbb{R}^{\mathsf{ax}}, \qquad \mathop{\mathrm{softmax}}_{\mathsf{ax}}(X) = \frac{\exp(X)}{\sum_{\mathsf{ax}} \exp(X)}$$

will act on any $X \in \mathbb{R}^{\mathsf{ax}\times\mathsf{bx}}$, so that $\mathop{\mathrm{softmax}}_{\mathsf{ax}}(X) = Y \in \mathbb{R}^{\mathsf{ax}\times\mathsf{bx}}$ satisfies $\sum_{\mathsf{ax}} Y = 1 \in \mathbb{R}^{\mathsf{bx}}$. The function $\mathop{\mathrm{softmax}}_{\mathsf{ax}}$ is only defined for tensors having an axis named $\mathsf{ax}$; this is the meaning of the line $\mathop{\mathrm{softmax}}_{\mathsf{ax}} \colon \mathbb{R}^{\mathsf{ax}} \to \mathbb{R}^{\mathsf{ax}}$. Note in particular that $\mathbb{R}^{\mathsf{ax}}$ is NOT the domain of definition of the function: it specifies the minimal set of named axes required to apply the function.
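
As a sanity check, here is a minimal PyTorch sketch of this behaviour, assuming the named axis $\mathsf{ax}$ is mapped to dimension 0 of an ordinary positional tensor (the mapping is our choice, not part of the notation):

```python
import torch

# Minimal sketch: emulate softmax_ax on X in R^{ax x bx},
# assuming the named axis "ax" is mapped to dimension 0.
X = torch.randn(4, 3)          # |ax| = 4, |bx| = 3
Y = torch.softmax(X, dim=0)    # softmax along "ax" only

print(Y.shape)        # torch.Size([4, 3]): the shape is unchanged
print(Y.sum(dim=0))   # all ones, i.e. sum_ax Y = 1 in R^{bx}
```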

Recall that elementwise multiplication is denoted by $\odot$ and the dot-product along axis $\mathsf{ax}$ is denoted $\odot_{\mathsf{ax}}$. Here is one example taken from the original paper:

$$A = \begin{bmatrix} 3 & 1 & 4\\ 1 & 5 & 9\\ 2 & 6 & 5 \end{bmatrix} \in \mathbb{R}^{\mathsf{height}\times\mathsf{width}},$$

with $\mathsf{height}$ indexing rows and $\mathsf{width}$ indexing columns, and

$$x = \begin{bmatrix} 2\\ 7\\ 1 \end{bmatrix} \in \mathbb{R}^{\mathsf{height}}.$$

Then, we have

$$A \odot x = \begin{bmatrix} 3\cdot 2 & 1\cdot 2 & 4\cdot 2\\ 1\cdot 7 & 5\cdot 7 & 9\cdot 7\\ 2\cdot 1 & 6\cdot 1 & 5\cdot 1 \end{bmatrix} \in \mathbb{R}^{\mathsf{height}\times\mathsf{width}},$$

and

$$A \odot_{\mathsf{height}} x = \sum_{\mathsf{height}} A \odot x = \begin{bmatrix} 6+7+2 & 2+35+6 & 8+63+5 \end{bmatrix} \in \mathbb{R}^{\mathsf{width}},$$

corresponding to the standard vector-matrix product $x^\top A$.
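
The same computation in plain PyTorch, assuming $\mathsf{height}$ is dimension 0 and $\mathsf{width}$ is dimension 1 (again, the positional mapping is our choice):

```python
import torch

A = torch.tensor([[3., 1., 4.],
                  [1., 5., 9.],
                  [2., 6., 5.]])   # A in R^{height x width}, height = dim 0
x = torch.tensor([2., 7., 1.])     # x in R^{height}

elementwise = A * x[:, None]           # A ⊙ x: broadcast x along the width axis
contraction = elementwise.sum(dim=0)   # A ⊙_height x: sum over the height axis

print(elementwise)    # [[6, 2, 8], [7, 35, 63], [2, 6, 5]]
print(contraction)    # [15, 43, 76]
```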

Linear layers

For $W \in \mathbb{R}^{\mathsf{ax}\times\mathsf{bx}}$ and $b \in \mathbb{R}^{\mathsf{bx}}$, we define a linear layer by:

$$\mathrm{LN}_W \colon \mathbb{R}^{\mathsf{ax}} \to \mathbb{R}^{\mathsf{bx}}, \qquad \mathrm{LN}_W(X) = X \odot_{\mathsf{ax}} W + b.$$

Note that if $X \in \mathbb{R}^{\mathsf{ax}\times\mathsf{seq}}$, then $\mathrm{LN}_W(X) \in \mathbb{R}^{\mathsf{bx}\times\mathsf{seq}}$.
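
As a sketch of this definition in PyTorch (the function name `linear_layer` and the mapping of named axes to einsum letters are ours for illustration), the contraction $X \odot_{\mathsf{ax}} W$ becomes an `einsum` over the shared dimension:

```python
import torch

def linear_layer(X, W, b):
    """LN_W(X) = X ⊙_ax W + b, with "ax" as the leading dimension of both X and W.

    X: (ax, seq)   W: (ax, bx)   b: (bx,)   ->   output: (bx, seq)
    """
    return torch.einsum("as,ab->bs", X, W) + b[:, None]

ax, bx, seq = 5, 4, 7
X = torch.randn(ax, seq)
W = torch.randn(ax, bx)
b = torch.randn(bx)
print(linear_layer(X, W, b).shape)   # torch.Size([4, 7]) = (bx, seq)
```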

Feedforward neural networks

Now a feedforward neural network is defined for $W_1 \in \mathbb{R}^{\mathsf{ax}\times\mathsf{hidden}}$, $b_1 \in \mathbb{R}^{\mathsf{hidden}}$ and $W_2 \in \mathbb{R}^{\mathsf{hidden}\times\mathsf{bx}}$, $b_2 \in \mathbb{R}^{\mathsf{bx}}$ as:

$$\mathrm{FFN} \colon \mathbb{R}^{\mathsf{ax}} \to \mathbb{R}^{\mathsf{bx}}, \qquad \mathrm{FFN}(X) = \mathrm{ReLU}(X \odot_{\mathsf{ax}} W_1 + b_1) \odot_{\mathsf{hidden}} W_2 + b_2.$$

Again, if $X \in \mathbb{R}^{\mathsf{ax}\times\mathsf{seq}}$, then $\mathrm{FFN}(X) \in \mathbb{R}^{\mathsf{bx}\times\mathsf{seq}}$.
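
A corresponding sketch, with the same positional conventions (the name `ffn` and the dimension ordering are assumptions for illustration):

```python
import torch

def ffn(X, W1, b1, W2, b2):
    """FFN(X) = ReLU(X ⊙_ax W1 + b1) ⊙_hidden W2 + b2, with X of shape (ax, seq)."""
    H = torch.relu(torch.einsum("as,ah->hs", X, W1) + b1[:, None])   # (hidden, seq)
    return torch.einsum("hs,hb->bs", H, W2) + b2[:, None]            # (bx, seq)

ax, hidden, bx, seq = 5, 16, 4, 7
X = torch.randn(ax, seq)
out = ffn(X,
          torch.randn(ax, hidden), torch.randn(hidden),   # W1, b1
          torch.randn(hidden, bx), torch.randn(bx))       # W2, b2
print(out.shape)   # torch.Size([4, 7]) = (bx, seq)
```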

Attention and SelfAttention

$$\mathrm{Attention} \colon \mathbb{R}^{\mathsf{key}} \times \mathbb{R}^{\mathsf{seq}\times\mathsf{key}} \times \mathbb{R}^{\mathsf{seq}\times\mathsf{val}} \to \mathbb{R}^{\mathsf{val}}, \qquad \mathrm{Attention}(Q,K,V) = \mathop{\mathrm{softmax}}_{\mathsf{seq}}\left(\frac{Q \odot_{\mathsf{key}} K}{\sqrt{|\mathsf{key}|}}\right) \odot_{\mathsf{seq}} V.$$

This definition takes a single query vector $Q$ and returns a single result vector (it could actually be reduced further to a scalar value, since the $\mathsf{val}$ axis is not strictly necessary). To apply it to a sequence of queries, we can give $Q$ a $\mathsf{seq}$ axis, and the function will compute an output sequence. Providing $Q$, $K$ and $V$ with a $\mathsf{heads}$ axis lifts the function to compute multiple attention heads.
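
Here is a minimal sketch of this function in PyTorch, where queries are given their own sequence dimension (written `qseq` in the comments) so that the softmax normalizes over the key/value sequence only; the function name and the layout choices are ours:

```python
import math
import torch

def attention(Q, K, V):
    """Attention(Q, K, V) = softmax_seq((Q ⊙_key K) / sqrt(|key|)) ⊙_seq V.

    Q: (qseq, key)   K: (seq, key)   V: (seq, val)   ->   output: (qseq, val)
    """
    scores = torch.einsum("qk,sk->qs", Q, K) / math.sqrt(Q.shape[-1])  # (qseq, seq)
    weights = torch.softmax(scores, dim=-1)                            # softmax over "seq"
    return torch.einsum("qs,sv->qv", weights, V)                       # (qseq, val)

Q = torch.randn(7, 8)     # 7 queries, |key| = 8
K = torch.randn(10, 8)    # sequence of length 10
V = torch.randn(10, 6)    # |val| = 6
print(attention(Q, K, V).shape)   # torch.Size([7, 6])
```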

We can now define SelfAttention. Let $W_Q \in \mathbb{R}^{\mathsf{chans}\times\mathsf{key}}$, $b_Q \in \mathbb{R}^{\mathsf{key}}$, $W_K \in \mathbb{R}^{\mathsf{chans}\times\mathsf{key}}$, $b_K \in \mathbb{R}^{\mathsf{key}}$ and $W_V \in \mathbb{R}^{\mathsf{chans}\times\mathsf{val}}$, $b_V \in \mathbb{R}^{\mathsf{val}}$, with $|\mathsf{val}| = |\mathsf{chans}|$:

$$\mathrm{SelfAttention} \colon \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}} \to \mathbb{R}^{\mathsf{val}\times\mathsf{seq}}, \qquad \mathrm{SelfAttention}(X) = \mathrm{Attention}\big(\mathrm{LN}_Q(X), \mathrm{LN}_K(X), \mathrm{LN}_V(X)\big),$$

where $\mathrm{LN}_Q$ (resp. $\mathrm{LN}_K$ and $\mathrm{LN}_V$) is the linear layer associated with $W_Q$ (resp. $W_K$ and $W_V$).
Note that the axes of the output are $\mathsf{seq}$ and $\mathsf{val}$; to add it to the input, we need to rename $\mathsf{val} \to \mathsf{chans}$, which is possible because $|\mathsf{val}| = |\mathsf{chans}|$. So in the end, we can compute

$$X + \mathrm{SelfAttention}(X)_{\mathsf{val}\to\mathsf{chans}} \in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}.$$
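
A sketch of SelfAttention and of the renamed residual sum, reusing the hypothetical `linear_layer` and `attention` functions from the sketches above (their definitions must be in scope); the positional layout $(\mathsf{chans}, \mathsf{seq})$ is our choice:

```python
import torch

def self_attention(X, WQ, bQ, WK, bK, WV, bV):
    """SelfAttention(X) = Attention(LN_Q(X), LN_K(X), LN_V(X)) for X of shape (chans, seq)."""
    Q = linear_layer(X, WQ, bQ)   # (key, seq)
    K = linear_layer(X, WK, bK)   # (key, seq)
    V = linear_layer(X, WV, bV)   # (val, seq)
    # attention() above expects (seq, key) / (seq, val) layouts, so transpose in and out.
    return attention(Q.T, K.T, V.T).T   # (val, seq)

chans, key, seq = 6, 8, 10
X = torch.randn(chans, seq)
WQ, WK = torch.randn(chans, key), torch.randn(chans, key)
bQ, bK = torch.randn(key), torch.randn(key)
WV, bV = torch.randn(chans, chans), torch.randn(chans)   # |val| = |chans|

out = X + self_attention(X, WQ, bQ, WK, bK, WV, bV)      # X + SelfAttention(X)_{val -> chans}
print(out.shape)                                         # torch.Size([6, 10])
```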

Normalization Layers

We can define a single generic standardization function as:

$$\mathop{\mathrm{standardize}}_{\mathsf{ax}} \colon \mathbb{R}^{\mathsf{ax}} \to \mathbb{R}^{\mathsf{ax}}, \qquad \mathop{\mathrm{standardize}}_{\mathsf{ax}}(X) = \frac{X - \mathop{\mathrm{mean}}_{\mathsf{ax}}(X)}{\sqrt{\mathop{\mathrm{var}}_{\mathsf{ax}}(X) + \epsilon}},$$

where $\epsilon > 0$ is a small constant for numerical stability.

Note that $\mathop{\mathrm{mean}}_{\mathsf{ax}}(X) = \frac{1}{|\mathsf{ax}|}\sum_{\mathsf{ax}} X$, so that if $X \in \mathbb{R}^{\mathsf{ax}\times\mathsf{bx}}$, then $\mathop{\mathrm{mean}}_{\mathsf{ax}}(X) \in \mathbb{R}^{\mathsf{bx}}$, and we are using broadcasting when we write $X - \mathop{\mathrm{mean}}_{\mathsf{ax}}(X) \in \mathbb{R}^{\mathsf{ax}\times\mathsf{bx}}$ in the numerator above, and similarly for the denominator.

Then, we can define the three kinds of normalization layers, all with type $\mathbb{R}^{\mathsf{batch}\times\mathsf{chans}\times\mathsf{layer}} \to \mathbb{R}^{\mathsf{batch}\times\mathsf{chans}\times\mathsf{layer}}$:

$$\begin{aligned}
\mathrm{BatchNorm}(X; \gamma, \beta) &= \mathop{\mathrm{standardize}}_{\mathsf{batch},\mathsf{layer}}(X) \odot \gamma + \beta, & \gamma, \beta &\in \mathbb{R}^{\mathsf{chans}}\\
\mathrm{InstanceNorm}(X; \gamma, \beta) &= \mathop{\mathrm{standardize}}_{\mathsf{layer}}(X) \odot \gamma + \beta, & \gamma, \beta &\in \mathbb{R}^{\mathsf{chans}}\\
\mathrm{LayerNorm}(X; \gamma, \beta) &= \mathop{\mathrm{standardize}}_{\mathsf{layer},\mathsf{chans}}(X) \odot \gamma + \beta, & \gamma, \beta &\in \mathbb{R}^{\mathsf{chans}\times\mathsf{layer}}
\end{aligned}$$

Note that the shape of the output is always the same as the shape of the input for normalization layers.

For transformers, we will use a particular LayerNorm where $|\mathsf{layer}| = 1$, so that we can simplify it as follows:

$$\mathrm{LayerNorm} \colon \mathbb{R}^{\mathsf{chans}} \to \mathbb{R}^{\mathsf{chans}}, \qquad \mathrm{LayerNorm}(X) = \mathop{\mathrm{standardize}}_{\mathsf{chans}}(X) \odot \gamma + \beta, \qquad \gamma, \beta \in \mathbb{R}^{\mathsf{chans}}.$$
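
A sketch of this simplified LayerNorm in PyTorch, assuming $\mathsf{chans}$ is dimension 0 of a $(\mathsf{chans}, \mathsf{seq})$ tensor (the function name `layer_norm` is ours):

```python
import torch

def layer_norm(X, gamma, beta, eps=1e-5):
    """LayerNorm(X) = standardize_chans(X) ⊙ gamma + beta, with "chans" as dimension 0.

    X: (chans, seq)   gamma, beta: (chans,)   ->   output: (chans, seq)
    """
    mean = X.mean(dim=0, keepdim=True)                 # mean over "chans", broadcast back
    var = X.var(dim=0, unbiased=False, keepdim=True)   # (biased) variance over "chans"
    return (X - mean) / torch.sqrt(var + eps) * gamma[:, None] + beta[:, None]

chans, seq = 6, 10
X = torch.randn(chans, seq)
Y = layer_norm(X, torch.ones(chans), torch.zeros(chans))
print(Y.shape, Y.mean(dim=0))   # shape unchanged, per-position means ~ 0
```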

A simple Transformer block

To simplify, we omit multiple heads here. We also present the pre-LN Transformer (see On Layer Normalization in the Transformer Architecture by Xiong et al.), where LayerNorm is applied before the SelfAttention and the feedforward network.


$$\begin{aligned}
X &\in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}\\
X_1 &= \mathrm{LayerNorm}(X) \in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}\\
X_2 &= X + \mathrm{SelfAttention}(X_1)_{\mathsf{val}\to\mathsf{chans}} \in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}\\
Y &= X_2 + \mathrm{FFN}(\mathrm{LayerNorm}(X_2)) \in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}
\end{aligned}$$

Note that we take the simple version of LayerNorm with $|\mathsf{layer}| = 1$. The values of the $\mathsf{seq}$ axis are mixed only in the SelfAttention layer.
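
Putting the pieces together, here is a sketch of one pre-LN block that reuses the hypothetical `layer_norm`, `self_attention` and `ffn` functions from the sketches above; the parameter grouping and names are ours:

```python
import torch

def pre_ln_block(X, attn_params, ffn_params, ln1_params, ln2_params):
    """Pre-LN block: X1 = LN(X); X2 = X + SelfAttention(X1); Y = X2 + FFN(LN(X2))."""
    X1 = layer_norm(X, *ln1_params)                          # (chans, seq)
    X2 = X + self_attention(X1, *attn_params)                # residual; needs |val| = |chans|
    Y = X2 + ffn(layer_norm(X2, *ln2_params), *ffn_params)   # (chans, seq)
    return Y

chans, key, hidden, seq = 6, 8, 16, 10
X = torch.randn(chans, seq)
attn_params = (torch.randn(chans, key), torch.randn(key),     # WQ, bQ
               torch.randn(chans, key), torch.randn(key),     # WK, bK
               torch.randn(chans, chans), torch.randn(chans)) # WV, bV with |val| = |chans|
ffn_params = (torch.randn(chans, hidden), torch.randn(hidden),
              torch.randn(hidden, chans), torch.randn(chans))
ln_params = (torch.ones(chans), torch.zeros(chans))           # gamma, beta

print(pre_ln_block(X, attn_params, ffn_params, ln_params, ln_params).shape)  # torch.Size([6, 10])
```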

Summary

$$\mathrm{LayerNorm} \colon \mathbb{R}^{\mathsf{chans}} \to \mathbb{R}^{\mathsf{chans}}, \qquad \mathrm{LayerNorm}(X) = \mathop{\mathrm{standardize}}_{\mathsf{chans}}(X) \odot \gamma + \beta, \qquad \gamma, \beta \in \mathbb{R}^{\mathsf{chans}}.$$

$$\mathrm{Attention} \colon \mathbb{R}^{\mathsf{key}} \times \mathbb{R}^{\mathsf{seq}\times\mathsf{key}} \times \mathbb{R}^{\mathsf{seq}\times\mathsf{val}} \to \mathbb{R}^{\mathsf{val}}, \qquad \mathrm{Attention}(Q,K,V) = \mathop{\mathrm{softmax}}_{\mathsf{seq}}\left(\frac{Q \odot_{\mathsf{key}} K}{\sqrt{|\mathsf{key}|}}\right) \odot_{\mathsf{seq}} V.$$

Let $W_Q, W_K \in \mathbb{R}^{\mathsf{chans}\times\mathsf{key}}$, $b_Q, b_K \in \mathbb{R}^{\mathsf{key}}$, and $W_V \in \mathbb{R}^{\mathsf{chans}\times\mathsf{val}}$, $b_V \in \mathbb{R}^{\mathsf{val}}$, with $|\mathsf{val}| = |\mathsf{chans}|$:

$$\mathrm{SelfAttention} \colon \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}} \to \mathbb{R}^{\mathsf{val}\times\mathsf{seq}}, \qquad \mathrm{SelfAttention}(X) = \mathrm{Attention}\big(X \odot_{\mathsf{chans}} W_Q + b_Q,\; X \odot_{\mathsf{chans}} W_K + b_K,\; X \odot_{\mathsf{chans}} W_V + b_V\big).$$

$$\mathrm{FFN} \colon \mathbb{R}^{\mathsf{chans}} \to \mathbb{R}^{\mathsf{chans}}, \qquad \mathrm{FFN}(X) = \mathrm{ReLU}(X \odot_{\mathsf{chans}} W_1 + b_1) \odot_{\mathsf{hidden}} W_2 + b_2.$$

$$\begin{aligned}
X &\in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}\\
X_1 &= \mathrm{LayerNorm}(X) \in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}\\
X_2 &= X + \mathrm{SelfAttention}(X_1)_{\mathsf{val}\to\mathsf{chans}} \in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}\\
Y &= X_2 + \mathrm{FFN}(\mathrm{LayerNorm}(X_2)) \in \mathbb{R}^{\mathsf{chans}\times\mathsf{seq}}
\end{aligned}$$

For a presentation of the attention mechanism and transformers, see Module 12 - Attention and Transformers

tags: public dataflowr transformers