# 淺談CUTLASS / CuTe的Swizzling Functor
swizzle最主要的目的在於做GEMM(General Matrix Multiply)時,為了解決存取share memory的memory bank conflict的問題,我們需要去編排執行緒(thread)跟數值(value)對應到矩陣之間的關係
所以如果能用下圖這種方式來存取share memory,就可以減少memory bound

swizzle通常會搭配一個Layout去做函數合成(composition),主要是為了要對原本Layout所產生的索引值(index)去做變動(索引值並沒有不見,只是變成我們了想要的樣子),我們可以看看位在Cute裡關於swizzle及Layout的函數合成原始碼(source code) [https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle_layout.hpp][1]

這裡的composition接受了三個參數swizzle functor、offset及layout,我們再往下,可以發現此composition會產生一個ComposedLayout型別
[https://github.com/NVIDIA/cutlass/blob/main/include/cute/layout_composed.hpp][2]

當有了些型別,之後再做邏輯座標(logical coordinate)轉換成索引值(index)時,會先透過之前的layout參數做轉換,再加上offset(=0),最後再呼叫swizzle functor

**所以整個的流程重點還是在於當我們得到索引值(index)時,swizzle functor對於這個索引值做了什麼樣的修改**
我們可以從Swizzle這個型別開始看起
[https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp][3]


Swizzle有三個重要的參數B, M, S,此三個參數的意義如下
B: 表示我們在做XOR時,有2^B^交換的模式
M: 在此交換模式下,最小單一元素應包含2^M^個數字(這些數字會是連續的,不受swizzle的影響),這些數字也稱做基底元素([BaseElements][5])
S: 每一交換模式應該套用幾個最小單一元素(=2^|S|^)。通常來說|S|>=B,若|S|>B,則此交換模式會重複2^|S|-B^次,若|S|=B,則只會套用一次
通常看到這裡,應該只會覺得霧煞煞,我們就試著舉一些例子,並搭配原始碼(source code)來看。
**為了簡化我們的問題,以下我們所用到的例子,都會有一個假設,也就是我們使用的Layout是[compact layout][6]**
假設我們有一個Layout,它的shape(8,8),stride(8, 1)
```cuda
auto smem_atom = Layout<Shape<_8,_8>,Stride<_8,_1>>{};
print_layout(smem_atom);
```
它的output如下,這是一個0 ~ 63的索引值(index),並且是列為主(row-major)
```
(_8,_8):(_8,_1)
0 1 2 3 4 5 6 7
+----+----+----+----+----+----+----+----+
0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+----+----+----+----+----+----+----+----+
1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
+----+----+----+----+----+----+----+----+
2 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
+----+----+----+----+----+----+----+----+
3 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
+----+----+----+----+----+----+----+----+
4 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 |
+----+----+----+----+----+----+----+----+
5 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 |
+----+----+----+----+----+----+----+----+
6 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 |
+----+----+----+----+----+----+----+----+
7 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 |
+----+----+----+----+----+----+----+----+
```
如果合成了一個swizzle<1, 1, 1>之後,結果會變成
```cuda
auto smem_atom = composition(Swizzle<1, 1, 1>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{});
print_layout(smem_atom);
```
```
Sw<1,1,1> o _0 o (_8,_8):(_8,_1)
0 1 2 3 4 5 6 7
+----+----+----+----+----+----+----+----+
0 | 0 | 1 | 2 | 3 | 6 | 7 | 4 | 5 |
+----+----+----+----+----+----+----+----+
1 | 8 | 9 | 10 | 11 | 14 | 15 | 12 | 13 |
+----+----+----+----+----+----+----+----+
2 | 16 | 17 | 18 | 19 | 22 | 23 | 20 | 21 |
+----+----+----+----+----+----+----+----+
3 | 24 | 25 | 26 | 27 | 30 | 31 | 28 | 29 |
+----+----+----+----+----+----+----+----+
4 | 32 | 33 | 34 | 35 | 38 | 39 | 36 | 37 |
+----+----+----+----+----+----+----+----+
5 | 40 | 41 | 42 | 43 | 46 | 47 | 44 | 45 |
+----+----+----+----+----+----+----+----+
6 | 48 | 49 | 50 | 51 | 54 | 55 | 52 | 53 |
+----+----+----+----+----+----+----+----+
7 | 56 | 57 | 58 | 59 | 62 | 63 | 60 | 61 |
+----+----+----+----+----+----+----+----+
```
我們可以看到,最小的單一元素是包含2個數字(=2^M^=2^1^=2),比如原本是4, 5, 6, 7,如今變成了6, 7, 4, 5,我們可以把4,5及6,7分別看成是最小的單一元素。如果我們將M代入2,那結果會是如何呢?
```cuda
auto smem_atom = composition(Swizzle<1, 2, 1>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{});
print_layout(smem_atom);
```
我們可以看到原本的8, 9, 10, 11, 12, 13, 14, 15,如今變成12, 13, 14, 15, 8, 9, 10, 11,最小單一元素已變成8, 9, 10, 11及12, 13, 14, 15這2組,而每一組所包含的數字為4(=2^M^=2^2^=4)
```
Sw<1,2,1> o _0 o (_8,_8):(_8,_1)
0 1 2 3 4 5 6 7
+----+----+----+----+----+----+----+----+
0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+----+----+----+----+----+----+----+----+
1 | 12 | 13 | 14 | 15 | 8 | 9 | 10 | 11 |
+----+----+----+----+----+----+----+----+
2 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
+----+----+----+----+----+----+----+----+
3 | 28 | 29 | 30 | 31 | 24 | 25 | 26 | 27 |
+----+----+----+----+----+----+----+----+
4 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 |
+----+----+----+----+----+----+----+----+
5 | 44 | 45 | 46 | 47 | 40 | 41 | 42 | 43 |
+----+----+----+----+----+----+----+----+
6 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 |
+----+----+----+----+----+----+----+----+
7 | 60 | 61 | 62 | 63 | 56 | 57 | 58 | 59 |
+----+----+----+----+----+----+----+----+
```
接著我們以swizzle<2, 1, 2>來探討B參數的意思
```cuda
auto smem_atom = composition(Swizzle<2, 1, 2>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{});
print_layout(smem_atom);
```
如果我們以原本的64個索引值(index),每2個數字為一組,只標示每組數字在在每一列中的行值(column),我們可以得出這樣圖

如果我們接著套用swizzle(2,1,2),那原本在每一列的行值(column)會變成這樣。感覺有點熟悉,很像我們一開始提的,希望存取每個行時,能分散開來

那為什麼會有這樣的結果呢,如果我們仔細觀察原始碼(source code),我們可以看到B參數會決定bit_msk的大小,在此例中bit_msk = (1 << 2) - 1 = 0b11
```cuda
static constexpr int num_bits = BBits;
using bit_msk = cute::constant<int, (1 << num_bits) - 1>;
```
之後我們從Layout輸出的索引值時,會利用yyy_msk及zzz_msk去對索引值做修改
yyy_msk = 0b11 << (1 + max(0, 2)) = 0b11000
zzz_msk = 0b11 << (1 - min(0, 2)) = 0b110
```cuda
static constexpr int num_base = MBase;
static constexpr int num_shft = SShift;
using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>;
using zzz_msk = cute::constant<int, bit_msk{} << (num_base - min(0,num_shft))>;
```
當我們在執行以下動作時,是拿高位元的2個bits(bit4 & bit3),跟低位元的2個bits(bit2 & bit1)做XOR
```cuda
apply(Offset const& offset)
{
return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY
}
```
2個bits的數值跟2個bits的數值去做XOR會有4種模式出現。這跟上面的結果是相符的
| | 0b00 | 0b01 | 0b10 | 0b11 |
|--------|:------:|:------:|:------:|:------:|
| 0b00 | 0 | 1 | 2 | 3 |
| 0b01 | 1 | 0 | 3 | 2 |
| 0b10 | 2 | 3 | 0 | 1 |
| 0b11 | 3 | 2 | 1 | 0 |
我們己經知道B參數的意義了,那如果我們套用swizzle<3, 1, 2>,那照理說應該會有2^B^(=2^3^=8)種交換模式
```cuda
auto smem_atom = composition(Swizzle<3, 1, 2>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{});
print_layout(smem_atom);
```
結果出現了一個error
```
error: static assertion failed with "abs(SShift) must be more than BBits."
static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits.");
```
這邊出錯的原因是原始碼(source code)中有一個檢查是長這樣
`static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits.");`
我們可以計算一下,在此例中的yyy_msk及zzz_msk
yyy_msk = 0b111 << (1 + max(0, 2)) = 0b111000
zzz_msk = 0b111 << (1 - min(0, 2)) = 0b001110
我們可以看到2個mask在bit3有重疊,這會造成一些問題,所以當我們在設定B及|S|參數時,需要S>=B
我們重新以swizzle<3, 1, 3>再試一次,此時的輸出為

從這個例子中,我們可以開始探討S參數的功用,此參數是指每一個交換模式下,需要幾個最小單一元素,由於S=3,我們可以得出需要2^|S|^(=2^3^=8)個最小單一元素,讓我們重新標記一下上面的圖,可以得到以下輸出,可以看出每2列,才能完成一次交換模式,在64個索引值(index)下,我們目前只有4種交換模式,照道理說應該有8種(=2^B^=2^3^=8),這是因為我們每一個交換模式需要2^|S|^個最小單一元素,所以總共需要2^|S|^*2^M^個數字(即2^3^*2^1^=16),但因為我們目前的索引值(index)只有64個,所以我們目前只會有4(=64/16)個交換模式

我們可以試著產生128個索引值(以增加行(column)的方式),來看看結果如何
```cuda
auto smem_atom = Layout<Shape<_8,_16>,Stride<_16,_1>>{};
print_layout(smem_atom);
```

```cuda
auto smem_atom = composition(Swizzle<3, 1, 3>{}, Layout<Shape<_8,_16>,Stride<_16,_1>>{});
print_layout(smem_atom);
```
一如我們的預期,總共有8種的交換模式,每個模式需要8個最小單位元素,每個最小單位元素包含2個數字

再來我們來看一下,當S參數大於B參數會發生什麼狀況
代入swizzle<3, 1, 4>,我們可以看到,同一種交換模式,重複了一次

我們之前有提到,若|S|>B,則交換模式會套用2^|S|-B^次,才會再進行下一個模式,這個由yyy_msk及zzz_msk不難看出,以這個例子來說
yyy_msk = 0b111 << (1 + max(0, 4)) = 0b11100000
zzz_msk = 0b111 << (1 - min(0, 4)) = 0b00001110
yyy_msk在bit7, bit6, bit5為1,而zzz_msk在bit3, bit2, bit1為1,而二者之間還有一個bit4,所以索引值(index)是0b0xxxx或是0b1xxxx,在套用這2個mask做完計算後會是相同的結果
我們來看一下當B參數為0的情況,依我們之前的定義2^0^=1,也是說我們只有一種交換模式,也就是我們原本索引值(index)的排列順序,如果我們去看原始碼(source code),也可以得到相同的結論

最後我們來看當S為負值時會發生什麼事,我們這裡舉一個multi-stride的範例(也就是索引值的排列,並非列為主(row major),也非行為主(column major))
```cuda
auto smem_atom = Layout<Shape<Shape<_2,_2>, Shape<_2,_2>>, Stride<Stride<_1,_4>, Stride<_2,_8>>>{};
print_layout(smem_atom);
```
```
((_2,_2),(_2,_2)):((_1,_4),(_2,_8))
0 1 2 3
+----+----+----+----+
0 | 0 | 2 | 8 | 10 |
+----+----+----+----+
1 | 1 | 3 | 9 | 11 |
+----+----+----+----+
2 | 4 | 6 | 12 | 14 |
+----+----+----+----+
3 | 5 | 7 | 13 | 15 |
+----+----+----+----+
```
我們先以swizzle<1, 0, 1>來試試看
```cuda
auto smem_atom = composition(Swizzle<1, 0, 1>{}, Layout<Shape<Shape<_2,_2>, Shape<_2,_2>>, Stride<Stride<_1,_4>, Stride<_2,_8>>>{});
print_layout(smem_atom);
```
我們可以看到,在每一個2x2的tile中,右邊的那一行的索引值做了交換
```
Sw<1,0,1> o _0 o ((_2,_2),(_2,_2)):((_1,_4),(_2,_8))
0 1 2 3
+----+----+----+----+
0 | 0 | 3 | 8 | 11 |
+----+----+----+----+
1 | 1 | 2 | 9 | 10 |
+----+----+----+----+
2 | 4 | 7 | 12 | 15 |
+----+----+----+----+
3 | 5 | 6 | 13 | 14 |
+----+----+----+----+
```
緊接著我們嘗試Swizzle<1, 0, -1>
```cuda
auto smem_atom = composition(Swizzle<1, 0, -1>{}, Layout<Shape<Shape<_2,_2>, Shape<_2,_2>>, Stride<Stride<_1,_4>, Stride<_2,_8>>>{});
print_layout(smem_atom);
```
我們可以看到這次反而是每個tile中下面那一列的索引值做了交換
```
Sw<1,0,-1> o _0 o ((_2,_2),(_2,_2)):((_1,_4),(_2,_8))
0 1 2 3
+----+----+----+----+
0 | 0 | 2 | 8 | 10 |
+----+----+----+----+
1 | 3 | 1 | 11 | 9 |
+----+----+----+----+
2 | 4 | 6 | 12 | 14 |
+----+----+----+----+
3 | 7 | 5 | 15 | 13 |
+----+----+----+----+
```
當S為負值時,我們是以索引值的特定低位元(yyy_msk)去跟特定高位元(zzz_msk)去做XOR,這個從原始碼(source code)可以很容易的看出
S > 0, offset = 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx // ZZZ ^ = YYY
S < 0, offset = 0bxxxxxxxxxxxxxxxZZZxxxxxxxYYYxxxx // ZZZ ^= YYY
在多維的情況,索引值交換的情形就不像一維那樣直覺,這個方面需要多練習,才可以根據不同的需求,做出不同的編排
參考資料:
1. [https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle_layout.hpp][1]
2. [https://github.com/NVIDIA/cutlass/blob/main/include/cute/layout_composed.hpp][2]
3. [https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp][3]
4. [GTC 2020: Developing CUDA kernels to push Tensor Cores to the Absolute Limit on NVIDIA A100][4]
5.
6. [compact layout][6]
7. [[QST] Why have a Swizzle with BBits = 0? #1704][7]
8. [[QST]How to create and use TiledMMA and ThrMMA in cute/atom/mma_atom.hpp #1028][8]
9. [cute 之 Swizzle][9]
[1]: https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle_layout.hpp
[2]: https://github.com/NVIDIA/cutlass/blob/main/include/cute/layout_composed.hpp
[3]: https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp
[4]: https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21745-developing-cuda-kernels-to-push-tensor-cores-to-the-absolute-limit-on-nvidia-a100.pdf
[5]:https://github.com/NVIDIA/cutlass/issues/1876
[6]: https://github.com/NVIDIA/cutlass/issues/944
[7]:
https://github.com/NVIDIA/cutlass/issues/1704
[8]: https://github.com/NVIDIA/cutlass/issues/1028
[9]: https://zhuanlan.zhihu.com/p/671419093