# 淺談CUTLASS / CuTe的Swizzling Functor swizzle最主要的目的在於做GEMM(General Matrix Multiply)時,為了解決存取share memory的memory bank conflict的問題,我們需要去編排執行緒(thread)跟數值(value)對應到矩陣之間的關係 所以如果能用下圖這種方式來存取share memory,就可以減少memory bound ![image](https://hackmd.io/_uploads/rJoSj2HWJe.png) swizzle通常會搭配一個Layout去做函數合成(composition),主要是為了要對原本Layout所產生的索引值(index)去做變動(索引值並沒有不見,只是變成我們了想要的樣子),我們可以看看位在Cute裡關於swizzle及Layout的函數合成原始碼(source code) [https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle_layout.hpp][1] ![image](https://hackmd.io/_uploads/SkyRP-SZ1g.png) 這裡的composition接受了三個參數swizzle functor、offset及layout,我們再往下,可以發現此composition會產生一個ComposedLayout型別 [https://github.com/NVIDIA/cutlass/blob/main/include/cute/layout_composed.hpp][2] ![image](https://hackmd.io/_uploads/HJFNF0r-ke.png) 當有了些型別,之後再做邏輯座標(logical coordinate)轉換成索引值(index)時,會先透過之前的layout參數做轉換,再加上offset(=0),最後再呼叫swizzle functor ![image](https://hackmd.io/_uploads/ry3OoCHZyx.png) **所以整個的流程重點還是在於當我們得到索引值(index)時,swizzle functor對於這個索引值做了什麼樣的修改** 我們可以從Swizzle這個型別開始看起 [https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp][3] ![image](https://hackmd.io/_uploads/H11N-kUZ1l.png) ![image](https://hackmd.io/_uploads/ryT_by8Wke.png) Swizzle有三個重要的參數B, M, S,此三個參數的意義如下 B: 表示我們在做XOR時,有2^B^交換的模式 M: 在此交換模式下,最小單一元素應包含2^M^個數字(這些數字會是連續的,不受swizzle的影響),這些數字也稱做基底元素([BaseElements][5]) S: 每一交換模式應該套用幾個最小單一元素(=2^|S|^)。通常來說|S|>=B,若|S|>B,則此交換模式會重複2^|S|-B^次,若|S|=B,則只會套用一次 通常看到這裡,應該只會覺得霧煞煞,我們就試著舉一些例子,並搭配原始碼(source code)來看。 **為了簡化我們的問題,以下我們所用到的例子,都會有一個假設,也就是我們使用的Layout是[compact layout][6]** 假設我們有一個Layout,它的shape(8,8),stride(8, 1) ```cuda auto smem_atom = Layout<Shape<_8,_8>,Stride<_8,_1>>{}; print_layout(smem_atom); ``` 它的output如下,這是一個0 ~ 63的索引值(index),並且是列為主(row-major) ``` (_8,_8):(_8,_1) 0 1 2 3 4 5 6 7 +----+----+----+----+----+----+----+----+ 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | +----+----+----+----+----+----+----+----+ 1 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | +----+----+----+----+----+----+----+----+ 2 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | +----+----+----+----+----+----+----+----+ 3 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | +----+----+----+----+----+----+----+----+ 4 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | +----+----+----+----+----+----+----+----+ 5 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | +----+----+----+----+----+----+----+----+ 6 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | +----+----+----+----+----+----+----+----+ 7 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | +----+----+----+----+----+----+----+----+ ``` 如果合成了一個swizzle<1, 1, 1>之後,結果會變成 ```cuda auto smem_atom = composition(Swizzle<1, 1, 1>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{}); print_layout(smem_atom); ``` ``` Sw<1,1,1> o _0 o (_8,_8):(_8,_1) 0 1 2 3 4 5 6 7 +----+----+----+----+----+----+----+----+ 0 | 0 | 1 | 2 | 3 | 6 | 7 | 4 | 5 | +----+----+----+----+----+----+----+----+ 1 | 8 | 9 | 10 | 11 | 14 | 15 | 12 | 13 | +----+----+----+----+----+----+----+----+ 2 | 16 | 17 | 18 | 19 | 22 | 23 | 20 | 21 | +----+----+----+----+----+----+----+----+ 3 | 24 | 25 | 26 | 27 | 30 | 31 | 28 | 29 | +----+----+----+----+----+----+----+----+ 4 | 32 | 33 | 34 | 35 | 38 | 39 | 36 | 37 | +----+----+----+----+----+----+----+----+ 5 | 40 | 41 | 42 | 43 | 46 | 47 | 44 | 45 | +----+----+----+----+----+----+----+----+ 6 | 48 | 49 | 50 | 51 | 54 | 55 | 52 | 53 | +----+----+----+----+----+----+----+----+ 7 | 56 | 57 | 58 | 59 | 62 | 63 | 60 | 61 | +----+----+----+----+----+----+----+----+ ``` 我們可以看到,最小的單一元素是包含2個數字(=2^M^=2^1^=2),比如原本是4, 5, 6, 7,如今變成了6, 7, 4, 5,我們可以把4,5及6,7分別看成是最小的單一元素。如果我們將M代入2,那結果會是如何呢? ```cuda auto smem_atom = composition(Swizzle<1, 2, 1>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{}); print_layout(smem_atom); ``` 我們可以看到原本的8, 9, 10, 11, 12, 13, 14, 15,如今變成12, 13, 14, 15, 8, 9, 10, 11,最小單一元素已變成8, 9, 10, 11及12, 13, 14, 15這2組,而每一組所包含的數字為4(=2^M^=2^2^=4) ``` Sw<1,2,1> o _0 o (_8,_8):(_8,_1) 0 1 2 3 4 5 6 7 +----+----+----+----+----+----+----+----+ 0 | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | +----+----+----+----+----+----+----+----+ 1 | 12 | 13 | 14 | 15 | 8 | 9 | 10 | 11 | +----+----+----+----+----+----+----+----+ 2 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | +----+----+----+----+----+----+----+----+ 3 | 28 | 29 | 30 | 31 | 24 | 25 | 26 | 27 | +----+----+----+----+----+----+----+----+ 4 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | +----+----+----+----+----+----+----+----+ 5 | 44 | 45 | 46 | 47 | 40 | 41 | 42 | 43 | +----+----+----+----+----+----+----+----+ 6 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | +----+----+----+----+----+----+----+----+ 7 | 60 | 61 | 62 | 63 | 56 | 57 | 58 | 59 | +----+----+----+----+----+----+----+----+ ``` 接著我們以swizzle<2, 1, 2>來探討B參數的意思 ```cuda auto smem_atom = composition(Swizzle<2, 1, 2>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{}); print_layout(smem_atom); ``` 如果我們以原本的64個索引值(index),每2個數字為一組,只標示每組數字在在每一列中的行值(column),我們可以得出這樣圖 ![image](https://hackmd.io/_uploads/HkOuxG8W1g.png) 如果我們接著套用swizzle(2,1,2),那原本在每一列的行值(column)會變成這樣。感覺有點熟悉,很像我們一開始提的,希望存取每個行時,能分散開來 ![image](https://hackmd.io/_uploads/rJy9YXIWke.png) 那為什麼會有這樣的結果呢,如果我們仔細觀察原始碼(source code),我們可以看到B參數會決定bit_msk的大小,在此例中bit_msk = (1 << 2) - 1 = 0b11 ```cuda static constexpr int num_bits = BBits; using bit_msk = cute::constant<int, (1 << num_bits) - 1>; ``` 之後我們從Layout輸出的索引值時,會利用yyy_msk及zzz_msk去對索引值做修改 yyy_msk = 0b11 << (1 + max(0, 2)) = 0b11000 zzz_msk = 0b11 << (1 - min(0, 2)) = 0b110 ```cuda static constexpr int num_base = MBase; static constexpr int num_shft = SShift; using yyy_msk = cute::constant<int, bit_msk{} << (num_base + max(0,num_shft))>; using zzz_msk = cute::constant<int, bit_msk{} << (num_base - min(0,num_shft))>; ``` 當我們在執行以下動作時,是拿高位元的2個bits(bit4 & bit3),跟低位元的2個bits(bit2 & bit1)做XOR ```cuda apply(Offset const& offset) { return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY } ``` 2個bits的數值跟2個bits的數值去做XOR會有4種模式出現。這跟上面的結果是相符的 | | 0b00 | 0b01 | 0b10 | 0b11 | |--------|:------:|:------:|:------:|:------:| | 0b00 | 0 | 1 | 2 | 3 | | 0b01 | 1 | 0 | 3 | 2 | | 0b10 | 2 | 3 | 0 | 1 | | 0b11 | 3 | 2 | 1 | 0 | 我們己經知道B參數的意義了,那如果我們套用swizzle<3, 1, 2>,那照理說應該會有2^B^(=2^3^=8)種交換模式 ```cuda auto smem_atom = composition(Swizzle<3, 1, 2>{}, Layout<Shape<_8,_8>,Stride<_8,_1>>{}); print_layout(smem_atom); ``` 結果出現了一個error ``` error: static assertion failed with "abs(SShift) must be more than BBits." static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits."); ``` 這邊出錯的原因是原始碼(source code)中有一個檢查是長這樣 `static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits.");` 我們可以計算一下,在此例中的yyy_msk及zzz_msk yyy_msk = 0b111 << (1 + max(0, 2)) = 0b111000 zzz_msk = 0b111 << (1 - min(0, 2)) = 0b001110 我們可以看到2個mask在bit3有重疊,這會造成一些問題,所以當我們在設定B及|S|參數時,需要S>=B 我們重新以swizzle<3, 1, 3>再試一次,此時的輸出為 ![image](https://hackmd.io/_uploads/rkgspVL-yg.png) 從這個例子中,我們可以開始探討S參數的功用,此參數是指每一個交換模式下,需要幾個最小單一元素,由於S=3,我們可以得出需要2^|S|^(=2^3^=8)個最小單一元素,讓我們重新標記一下上面的圖,可以得到以下輸出,可以看出每2列,才能完成一次交換模式,在64個索引值(index)下,我們目前只有4種交換模式,照道理說應該有8種(=2^B^=2^3^=8),這是因為我們每一個交換模式需要2^|S|^個最小單一元素,所以總共需要2^|S|^*2^M^個數字(即2^3^*2^1^=16),但因為我們目前的索引值(index)只有64個,所以我們目前只會有4(=64/16)個交換模式 ![image](https://hackmd.io/_uploads/r1CufrIZ1x.png) 我們可以試著產生128個索引值(以增加行(column)的方式),來看看結果如何 ```cuda auto smem_atom = Layout<Shape<_8,_16>,Stride<_16,_1>>{}; print_layout(smem_atom); ``` ![image](https://hackmd.io/_uploads/BkN4rSLZkg.png) ```cuda auto smem_atom = composition(Swizzle<3, 1, 3>{}, Layout<Shape<_8,_16>,Stride<_16,_1>>{}); print_layout(smem_atom); ``` 一如我們的預期,總共有8種的交換模式,每個模式需要8個最小單位元素,每個最小單位元素包含2個數字 ![image](https://hackmd.io/_uploads/Sk57ESv-yx.png) 再來我們來看一下,當S參數大於B參數會發生什麼狀況 代入swizzle<3, 1, 4>,我們可以看到,同一種交換模式,重複了一次 ![image](https://hackmd.io/_uploads/Bk0yFUU-1l.png) 我們之前有提到,若|S|>B,則交換模式會套用2^|S|-B^次,才會再進行下一個模式,這個由yyy_msk及zzz_msk不難看出,以這個例子來說 yyy_msk = 0b111 << (1 + max(0, 4)) = 0b11100000 zzz_msk = 0b111 << (1 - min(0, 4)) = 0b00001110 yyy_msk在bit7, bit6, bit5為1,而zzz_msk在bit3, bit2, bit1為1,而二者之間還有一個bit4,所以索引值(index)是0b0xxxx或是0b1xxxx,在套用這2個mask做完計算後會是相同的結果 我們來看一下當B參數為0的情況,依我們之前的定義2^0^=1,也是說我們只有一種交換模式,也就是我們原本索引值(index)的排列順序,如果我們去看原始碼(source code),也可以得到相同的結論 ![image](https://hackmd.io/_uploads/BkfAs8I-kx.png) 最後我們來看當S為負值時會發生什麼事,我們這裡舉一個multi-stride的範例(也就是索引值的排列,並非列為主(row major),也非行為主(column major)) ```cuda auto smem_atom = Layout<Shape<Shape<_2,_2>, Shape<_2,_2>>, Stride<Stride<_1,_4>, Stride<_2,_8>>>{}; print_layout(smem_atom); ``` ``` ((_2,_2),(_2,_2)):((_1,_4),(_2,_8)) 0 1 2 3 +----+----+----+----+ 0 | 0 | 2 | 8 | 10 | +----+----+----+----+ 1 | 1 | 3 | 9 | 11 | +----+----+----+----+ 2 | 4 | 6 | 12 | 14 | +----+----+----+----+ 3 | 5 | 7 | 13 | 15 | +----+----+----+----+ ``` 我們先以swizzle<1, 0, 1>來試試看 ```cuda auto smem_atom = composition(Swizzle<1, 0, 1>{}, Layout<Shape<Shape<_2,_2>, Shape<_2,_2>>, Stride<Stride<_1,_4>, Stride<_2,_8>>>{}); print_layout(smem_atom); ``` 我們可以看到,在每一個2x2的tile中,右邊的那一行的索引值做了交換 ``` Sw<1,0,1> o _0 o ((_2,_2),(_2,_2)):((_1,_4),(_2,_8)) 0 1 2 3 +----+----+----+----+ 0 | 0 | 3 | 8 | 11 | +----+----+----+----+ 1 | 1 | 2 | 9 | 10 | +----+----+----+----+ 2 | 4 | 7 | 12 | 15 | +----+----+----+----+ 3 | 5 | 6 | 13 | 14 | +----+----+----+----+ ``` 緊接著我們嘗試Swizzle<1, 0, -1> ```cuda auto smem_atom = composition(Swizzle<1, 0, -1>{}, Layout<Shape<Shape<_2,_2>, Shape<_2,_2>>, Stride<Stride<_1,_4>, Stride<_2,_8>>>{}); print_layout(smem_atom); ``` 我們可以看到這次反而是每個tile中下面那一列的索引值做了交換 ``` Sw<1,0,-1> o _0 o ((_2,_2),(_2,_2)):((_1,_4),(_2,_8)) 0 1 2 3 +----+----+----+----+ 0 | 0 | 2 | 8 | 10 | +----+----+----+----+ 1 | 3 | 1 | 11 | 9 | +----+----+----+----+ 2 | 4 | 6 | 12 | 14 | +----+----+----+----+ 3 | 7 | 5 | 15 | 13 | +----+----+----+----+ ``` 當S為負值時,我們是以索引值的特定低位元(yyy_msk)去跟特定高位元(zzz_msk)去做XOR,這個從原始碼(source code)可以很容易的看出 S > 0, offset = 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx // ZZZ ^ = YYY S < 0, offset = 0bxxxxxxxxxxxxxxxZZZxxxxxxxYYYxxxx // ZZZ ^= YYY 在多維的情況,索引值交換的情形就不像一維那樣直覺,這個方面需要多練習,才可以根據不同的需求,做出不同的編排 參考資料: 1. [https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle_layout.hpp][1] 2. [https://github.com/NVIDIA/cutlass/blob/main/include/cute/layout_composed.hpp][2] 3. [https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp][3] 4. [GTC 2020: Developing CUDA kernels to push Tensor Cores to the Absolute Limit on NVIDIA A100][4] 5. 6. [compact layout][6] 7. [[QST] Why have a Swizzle with BBits = 0? #1704][7] 8. [[QST]How to create and use TiledMMA and ThrMMA in cute/atom/mma_atom.hpp #1028][8] 9. [cute 之 Swizzle][9] [1]: https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle_layout.hpp [2]: https://github.com/NVIDIA/cutlass/blob/main/include/cute/layout_composed.hpp [3]: https://github.com/NVIDIA/cutlass/blob/main/include/cute/swizzle.hpp [4]: https://developer.download.nvidia.com/video/gputechconf/gtc/2020/presentations/s21745-developing-cuda-kernels-to-push-tensor-cores-to-the-absolute-limit-on-nvidia-a100.pdf [5]:https://github.com/NVIDIA/cutlass/issues/1876 [6]: https://github.com/NVIDIA/cutlass/issues/944 [7]: https://github.com/NVIDIA/cutlass/issues/1704 [8]: https://github.com/NVIDIA/cutlass/issues/1028 [9]: https://zhuanlan.zhihu.com/p/671419093