原範例程式可參考[Cutlass Github][1]
```cuda=
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <cstdlib>
#include <cstdio>
#include <cassert>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class AThreadLayout,
          class TB, class BStride, class BSmemLayout, class BThreadLayout,
          class TC, class CStride, class CSmemLayout, class CThreadLayout,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(CThreadLayout{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA,
            TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB,
            TC      * C, CStride dC, CSmemLayout          , CThreadLayout tC,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  static_assert(is_static<AThreadLayout>::value);
  static_assert(is_static<BThreadLayout>::value);
  static_assert(is_static<CThreadLayout>::value);

  CUTE_STATIC_ASSERT_V(size(tA) == size(tB));                          // NumThreads
  CUTE_STATIC_ASSERT_V(size(tC) == size(tA));                          // NumThreads

  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{});  // BLK_N / THR_N
  CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{});  // BLK_K / THR_K
  CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{});  // BLK_M / THR_M
  CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{});  // BLK_N / THR_N

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K)

  //
  // Partition the copying of A and B tiles across the threads
  //

  // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles

  Tensor tAgA = local_partition(gA, tA, threadIdx.x);                  // (THR_M,THR_K,k)
  Tensor tAsA = local_partition(sA, tA, threadIdx.x);                  // (THR_M,THR_K)

  Tensor tBgB = local_partition(gB, tB, threadIdx.x);                  // (THR_N,THR_K,k)
  Tensor tBsB = local_partition(sB, tB, threadIdx.x);                  // (THR_N,THR_K)

  CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // THR_K
  CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // THR_K

  //
  // Define A/B partitioning and C accumulators
  //

  // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC

  // Partition sA (M,K) by the rows of tC
  Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{});   // (THR_M,BLK_K)
  // Partition sB (N,K) by the cols of tC
  Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{});   // (THR_N,BLK_K)
  // Partition gC (M,N) by the tile of tC
  Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{});   // (THR_M,THR_N)

  // Allocate the accumulators -- same shape/layout as the partitioned data
  Tensor tCrC = make_tensor_like(tCgC);                                // (THR_M,THR_N)

  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC));                // THR_M
  CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA));                // THR_M
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB));                // THR_N
  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB));                // BLK_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print(" mA : "); print( mA); print("\n");
    print(" gA : "); print( gA); print("\n");
    print(" sA : "); print( sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print(" mB : "); print( mB); print("\n");
    print(" gB : "); print( gB); print("\n");
    print(" sB : "); print( sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print(" mC : "); print( mC); print("\n");
    print(" gC : "); print( gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory,
  //           and then computes on those tiles.
  //   copy(.) operates on the global and shared memory via the tA|tB partitioning
  //   gemm(.) operates on the shared and register memory via the tC partitioning

  auto K_TILE_MAX = size<2>(tAgA);

  for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile)
  {
    // Copy gmem to smem with tA|tB thread-partitioned tensors
    copy(tAgA(_,_,k_tile), tAsA);      // A   (THR_M,THR_K) -> (THR_M,THR_K)
    copy(tBgB(_,_,k_tile), tBsB);      // B   (THR_N,THR_K) -> (THR_N,THR_K)

    // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to
    //   Tensor tAgAk = tAgA(_,_,k_tile);
    //   CUTE_UNROLL
    //   for (int i = 0; i < size(tAsA); ++i) {
    //     tAsA(i) = tAgAk(i);
    //   }

    cp_async_fence();        // Label the end of (potential) cp.async instructions
    cp_async_wait<0>();      // Sync on all (potential) cp.async instructions
    __syncthreads();         // Wait for all threads to write to smem

    // Compute gemm on tC thread-partitioned smem
    gemm(tCsA, tCsB, tCrC);            // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K)

    // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to
    //   CUTE_UNROLL
    //   for (int k = 0; k < size<1>(tCsA); ++k) {
    //     CUTE_UNROLL
    //     for (int m = 0; m < size<0>(tCrC); ++m) {
    //       CUTE_UNROLL
    //       for (int n = 0; n < size<1>(tCrC); ++n) {
    //         tCrC(m,n) += tCsA(m,k) * tCsB(n,k);
    //       }
    //     }
    //   }

    __syncthreads();         // Wait for all threads to read from smem
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);

  // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to
  //   CUTE_UNROLL
  //   for (int i = 0; i < size(tCrC); ++i) {
  //     tCgC(i) = alpha * tCrC(i) + beta * tCgC(i);
  //   }
}

// Setup params for an NT GEMM
// Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB
// Requires M % 128 == 0, N % 128 == 0, K % 8 == 0: gemm_device does no bounds predication.
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM, bK));                 // (m,k) -> smem_idx; m-major
  auto sB = make_layout(make_shape(bN, bK));                 // (n,k) -> smem_idx; n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (m,k) -> thr_idx
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}));   // (n,k) -> thr_idx
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));   // (m,n) -> thr_idx

  // gemm_device never predicates its gmem accesses, so a partial CTA tile would read/write
  // out of bounds. Refuse to launch instead of corrupting memory.
  if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) {
    fprintf(stderr, "gemm_nt: M, N, K (%d, %d, %d) must be multiples of %d, %d, %d\n",
            M, N, K, int(bM), int(bN), int(bK));
    return;
  }

  dim3 dimBlock(size(tC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);
}

// Setup params for a TN GEMM
// Use k-major smem sA|sB (LayoutRight, no padding) and k-major threads tA|tB.
// Requires M % 128 == 0, N % 128 == 0, K % 8 == 0: gemm_device does no bounds predication.
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM,bK), LayoutRight{});   // (m,k) -> smem_idx; k-major
  auto sB = make_layout(make_shape(bN,bK), LayoutRight{});   // (n,k) -> smem_idx; k-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)
  auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (m,k) -> thr_idx; k-major
  auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{});  // (n,k) -> thr_idx; k-major
  auto tC = make_layout(make_shape(Int<16>{}, Int<16>{}));                 // (m,n) -> thr_idx; m-major

  // gemm_device never predicates its gmem accesses, so a partial CTA tile would read/write
  // out of bounds. Refuse to launch instead of corrupting memory.
  if (M % int(bM) != 0 || N % int(bN) != 0 || K % int(bK) != 0) {
    fprintf(stderr, "gemm_tn: M, N, K (%d, %d, %d) must be multiples of %d, %d, %d\n",
            M, N, K, int(bM), int(bN), int(bK));
    return;
  }

  dim3 dimBlock(size(tC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, tA,
       B, dB, sB, tB,
       C, dC, sC, tC,
       alpha, beta);
}

// Dispatch C = alpha * op(A) * op(B) + beta * C to the NT or TN setup above.
// Only the (N,T) and (T,N) combinations are implemented.
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  } else
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  // Report the failure even in NDEBUG builds, where the assert below is compiled out.
  fprintf(stderr, "gemm: transA/transB = %c/%c is not implemented\n", transA, transB);
  assert(false && "Not implemented");
}

// Usage: prog [M] [N] [K] [transA] [transB]   (defaults: 5120 5120 4096 N T)
int main(int argc, char** argv)
{
  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  // Only NT and TN are implemented by gemm(); reject anything else up front rather than
  // relying on assert(), which disappears under NDEBUG and would yield a bogus timing.
  bool const is_nt = (transA == 'N' && transB == 'T');
  bool const is_tn = (transA == 'T' && transB == 'N');
  if (!is_nt && !is_tn) {
    fprintf(stderr, "Unsupported transA/transB = %c/%c (only N/T and T/N are implemented)\n",
            transA, transB);
    return EXIT_FAILURE;
  }

  // gemm_device has no bounds predication; these must match (BLK_M, BLK_N, BLK_K) in gemm_nt/gemm_tn.
  if (m <= 0 || n <= 0 || k <= 0 || m % 128 != 0 || n % 128 != 0 || k % 8 != 0) {
    fprintf(stderr, "M, N, K must be positive multiples of 128, 128, 8 (got %d, %d, %d)\n",
            m, n, k);
    return EXIT_FAILURE;
  }

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  TI alpha = 1.0;
  TI beta  = 0.0;

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;
  std::cout << "C = A^" << transA << " B^" << transB << std::endl;

  cute::device_init(0);

  // Sizes computed in size_t so large M/N/K cannot overflow int.
  thrust::host_vector<TA> h_A(size_t(m) * size_t(k));
  thrust::host_vector<TB> h_B(size_t(n) * size_t(k));
  thrust::host_vector<TC> h_C(size_t(m) * size_t(n));

  for (size_t j = 0; j < h_A.size(); ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (size_t j = 0; j < h_B.size(); ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (size_t j = 0; j < h_C.size(); ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    assert(false);
  }

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    assert(false);
  }

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

  return 0;
}
```
我們可以先將debug message打開,將相關的shape及tiling印出來,再搭配程式碼來解讀。另外,本文預設都會是以thread 0的角度,來分析指派給thread 0的任務,若有其它thread參與其中,會再另行標出
```cuda=156
#if 1
  if(thread0()) {
    print(" mA : "); print( mA); print("\n");
    print(" gA : "); print( gA); print("\n");
    print(" sA : "); print( sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 1
  if(thread0()) {
    print(" mB : ");
print( mB); print("\n"); print(" gB : "); print( gB); print("\n"); print(" sB : "); print( sB); print("\n"); print("tBgB : "); print(tBgB); print("\n"); print("tBsB : "); print(tBsB); print("\n"); } #endif #if 1 if(thread0()) { print(" mC : "); print( mC); print("\n"); print(" gC : "); print( gC); print("\n"); print("tCsA : "); print(tCsA); print("\n"); print("tCsB : "); print(tCsB); print("\n"); print("tCgC : "); print(tCgC); print("\n"); print("tCrC : "); print(tCrC); print("\n"); } #endif ``` 我們並沒有更改預設矩陣M, N, K的值,所以MxNxK = 5120x5120x4096,此外,A矩陣是非轉置格式(Not transposed, M-major, column-major),而B矩陣是轉置格式(Transposed, N-major, row-major)。至於mA, gA, sA, tAgA, tAsA的意義我們會留待後面再解釋 > M = 5120 N = 5120 K = 4096 C = A^N B^T Using device 0: Tesla T4 (SM75, 40 SMs) mA : gmem_ptr[32b](0x7e5772000000) o (5120,4096):(_1,5120) gA : gmem_ptr[32b](0x7e5772000000) o (_128,_8,512):(_1,5120,40960) sA : smem_ptr[32b](0x7e57b3000000) o (_128,_8):(_1,_128) tAgA : gmem_ptr[32b](0x7e5772000000) o (_4,_1,512):(_32,_0,40960) tAsA : smem_ptr[32b](0x7e57b3000000) o (_4,_1):(_32,_0) mB : gmem_ptr[32b](0x7e576c000000) o (5120,4096):(_1,5120) gB : gmem_ptr[32b](0x7e576c000000) o (_128,_8,512):(_1,5120,40960) sB : smem_ptr[32b](0x7e57b3001000) o (_128,_8):(_1,_128) tBgB : gmem_ptr[32b](0x7e576c000000) o (_4,_1,512):(_32,_0,40960) tBsB : smem_ptr[32b](0x7e57b3001000) o (_4,_1):(_32,_0) mC : gmem_ptr[32b](0x7e5764000000) o (5120,5120):(_1,5120) gC : gmem_ptr[32b](0x7e5764000000) o (_128,_128):(_1,5120) tCsA : smem_ptr[32b](0x7e57b3000000) o (_8,_8):(_16,_128) tCsB : smem_ptr[32b](0x7e57b3001000) o (_8,_8):(_16,_128) tCgC : gmem_ptr[32b](0x7e5764000000) o (_8,_8):(_16,81920) tCrC : ptr[32b](0x7e57b1fffbe0) o (_8,_8):(_1,_8) 我們先從mA, mB, mC看起,它們是三塊GMEM(Global Memory) ```cuda=97 // Represent the full tensors Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K) Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K) Tensor mC = make_tensor(make_gmem_ptr(C), 
select<0,1>(shape_MNK), dC); // (M,N) ``` ,維度如下圖所示 ![image](https://hackmd.io/_uploads/SkuRGHBBJx.png) 我們要做的是C = A*B + C這樣的矩陣乘法累加(更精確地說,最後epilogue的axpby會計算C = alpha * (A*B) + beta * C,本例中alpha = 1、beta = 0),並把工作分配給大量的CUDA執行緒,達成平行運算的目的(注意:這個範例的內層運算是由一般CUDA Core以純量FMA完成,並沒有使用Tensor Core) 宏觀來說,整個過程會類似下圖,我們會切割出一塊塊的小block(一般稱做tile,就像貼磁磚一樣),將它從全域記憶體(global memory)搬到共享記憶體(shared memory)中,當我們需要做MMA(Matrix Multiply Accumulate)時,再去針對不同的執行緒(thread)設定共享記憶體(shared memory)跟暫存器檔案(register file)的對應關係 ![image](https://hackmd.io/_uploads/HySGjQHr1g.png) 而gA, gB, gC則是依照我們先前傳入的cta_tiler(CTA是Cooperative Thread Array的縮寫,指的就是thread block)對mA, mB, mC去做分割 ```cuda=102 // Get the appropriate blocks for this thread block auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) ``` cta_tiler在之前就已經定義好了,這邊習慣上我們會用BLK_M, BLK_N, BLK_K來做註解,分別代表一個block在M, N, K維度上的大小 ```cuda=272 // Define CTA tile sizes (static) auto bM = Int<128>{}; auto bN = Int<128>{}; auto bK = Int< 8>{}; auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) ``` ![image](https://hackmd.io/_uploads/S1rjZ5BB1e.png) 我們再回頭看一下cta_coord這個變數,這個變數是方便我們之後在做block分割時,依照每個thread block的編號(blockIdx.x, blockIdx.y)去分配相對應的block,裡頭的參數有_,它類似於python的slice概念,表示這個維度下的資料我全都要。之後,我們可以呼叫local_tile函數來做block分割,可以看一下此函數裡的參數,mA是我們想要分割的矩陣A,cta_tiler則是想要分割的block大小,cta_coord則是依據block id(blockIdx),來決定每個thread block要分配到分割完後的哪一個block,最後一個參數Step<_1, X, _1>{},它有點類似flag的作用,會決定我們在套用cta_tiler及cta_coord時,要不要把某個維度考慮進去,比如此例中,我們只考慮M和K,因為我們現在要分割的是A矩陣,所以只需考慮M及K維度,如果是B矩陣的話,因為只考慮N及K維度,所以會傳入Step<X, _1, _1>{}。由於cta_coord在k維度傳入的是_(也就是保留整個k維度),在cta_tiler的作用下,我們最後會在A矩陣中得到一整列的row block (即gA,shape為(BLK_M, BLK_K, k)),而在B矩陣中,則會得到一整行的column block (即gB,shape為(BLK_N, BLK_K, k)) ![image](https://hackmd.io/_uploads/SJ7zFEUS1e.png) ![image](https://hackmd.io/_uploads/BJu2Nr8Hye.png) 會得到這樣的結果,其實是很直覺的,仔細觀察C矩陣,我們若要得到C矩陣上某元素的值,勢必要從相對應的A矩陣的列(row),乘上相對應B矩陣的行(column),所以這樣的矩陣分割是很合理的 還記得我們之前印出跟mA, gA, mB, gB相關的debug message嗎 >mA : 
gmem_ptr[32b](0x7e5772000000) o (5120,4096):(_1,5120) gA : gmem_ptr[32b](0x7e5772000000) o (_128,_8,512):(_1,5120,40960) mB : gmem_ptr[32b](0x7e576c000000) o (5120,4096):(_1,5120) gB : gmem_ptr[32b](0x7e576c000000) o (_128,_8,512):(_1,5120,40960) 從它的shape及stride,大致可以了解它的memory layout,我們大概畫一下圖即可了解 ![image](https://hackmd.io/_uploads/ryooKtUBke.png) ![image](https://hackmd.io/_uploads/ryUTKFUB1x.png) ![image](https://hackmd.io/_uploads/Hy4eqtLrJg.png) ![image](https://hackmd.io/_uploads/r1L1sYUHye.png) 接下來,我們可以觀察到sA, sB的大小是(BLK_M, BLK_K)及(BLK_N, BLK_K),在這個例子中分別是(128, 8)及(128, 8),這2塊記憶體的功用在於我們會從gA、gB複製一小塊BLOCK到sA、sB,之後再從這2塊記憶體讀取出一小塊fragment到register file做運算 ```cuda=108 // Shared memory buffers __shared__ TA smemA[cosize_v<ASmemLayout>]; __shared__ TB smemB[cosize_v<BSmemLayout>]; Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K) Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K) ``` 同樣的,我們也可以畫出它們的layout ![image](https://hackmd.io/_uploads/By1bu9ISke.png) 緊接著,我們來看最關鍵的部分,也就是如何把gA/sA、gB/sB的資料分配給不同的執行緒(thread),我們可以先從tA, tB, tC看起,它們的thread layout分別是32\*8, 32\*8, 16\*16,這邊的共通點是,它們都有256個執行緒,只是有不同的layout ```cuda=283 // Define the thread layouts (static) auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (n,k) -> thr_idx; k-major auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx; m-major ``` (注意:上面這段thread layout節錄自gemm_tn,tA/tB使用了LayoutRight,也就是k-major;而本文實際執行的是C = A^N B^T,會走gemm_nt,其中tA/tB是以預設的column-major方式建立,即make_layout(make_shape(Int<32>{}, Int< 8>{})),分別是m-major與n-major,tC則兩者相同。兩種寫法下每個thread分到的元素個數相同,但相鄰thread的排列方向不同:在gemm_nt中,thread 1位於thread 0在m(或n)方向的下一格;在gemm_tn中,則位於k方向的下一格) ![image](https://hackmd.io/_uploads/B1q7DRIHyg.png) 之後要跑下面這段程式碼時,就會將前面thread layout套用到之前討論到的gA, sA, gB, sB,做thread partition的動作 ```cuda= Tensor tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k) Tensor tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K) Tensor tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k) Tensor tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K) ``` 我們可以看到local_partition的參數一共有3個,分別是gA, tA及threadIdx.x,我們可以想像利用tA這樣的thread 
tile對gA做分割時,會類似下圖這樣的情形 ![image](https://hackmd.io/_uploads/HJYbUswH1e.png) ![image](https://hackmd.io/_uploads/rJJQUiwS1e.png) 這動作像是在貼磁磚,會將整個block給鋪滿,由於gA的shape是(128, 8, 512),tA的shape是(32, 8),我們可以預期利用tA去對gA分割時,得到的shape會是(4, 1, 512)(128/32 = 4、8/8 = 1,k維度的512保持不變),通常我們用tAgA這樣的命名習慣(naming convention)來表示分割後的結果 ![image](https://hackmd.io/_uploads/SJyvIsDHye.png) ![image](https://hackmd.io/_uploads/SkI_8sDSke.png) 我們還有一個東西忘了講,local_partition的參數中,還有threadIdx.x,表示在每個thread tile中,要對應到哪一個thread id,本例中,在一個(32,8)的thread tile裡,thread 0會對應到左上角 ![image](https://hackmd.io/_uploads/ByuqM2wHyg.png) 所以真正的tAgA應該會長得像下面這樣子 ![image](https://hackmd.io/_uploads/BkFZ66PHJe.png) 如果是thread 1的話,會長這樣子 ![image](https://hackmd.io/_uploads/rJaPATvSkx.png) 我們可以再回去看一下關於tAgA, tBgB的debug message,跟我們畫出來的圖是相符的,每個thread在每個k tile中負責4個值,而沿著k dimension共有512個k tile >tAgA : gmem_ptr[32b](0x7e5772000000) o (_4,_1,512):(_32,_0,40960) >tBgB : gmem_ptr[32b](0x7e576c000000) o (_4,_1,512):(_32,_0,40960) 同樣地,當套用tA到sA時,也會有一樣的thread partition ![image](https://hackmd.io/_uploads/B169yCDHJg.png) 終於到了最後一步,我們都還沒有提到關於矩陣C的分割方式,我們可以注意到像tC這樣的partition pattern主要會套用至sA, sB及gC,而rC(即tCrC)則是依照tCgC的形狀配置在暫存器中、用來存放累加結果(accumulator)的一塊記憶體。還記得我們的矩陣運算是C = A*B + C嗎? ```cuda=137 // Partition sA (M,K) by the rows of tC Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K) // Partition sB (N,K) by the cols of tC Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K) // Partition gC (M,N) by the tile of tC Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N) // Allocate the accumulators -- same shape/layout as the partitioned data Tensor tCrC = make_tensor_like(tCgC); // (THR_M,THR_N) ``` 這樣的local_partition所傳入的參數,大致上跟之前是差不多的,唯一的不同是,我們依據要分割的是sA, sB, gC,而會傳入不同的Step參數,以tCsA來說,我們只考慮M維度,而tCsB只考慮N維度,這是因為我們若想求得某個C矩陣上元素的值,我們只要有相對應列(from A)跟行(from B)即可 ![image](https://hackmd.io/_uploads/B1ufLCwHJg.png) gC是我們最後要輸出的一塊記憶體,我們可以宣告另一塊記憶體rC,它的配置和gC相同,只是位在RMEM 當我們的thread partition都搞定之後,我們就可以讓每個thread沿著維度K,將相對應的block從GMEM載入到SMEM,並呼叫GEMM(General 
Matrix Multiply),做矩陣乘法累加 ```cuda=196 for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile) { // Copy gmem to smem with tA|tB thread-partitioned tensors copy(tAgA(_,_,k_tile), tAsA); // A (THR_M,THR_K) -> (THR_M,THR_K) copy(tBgB(_,_,k_tile), tBsB); // B (THR_N,THR_K) -> (THR_N,THR_K) // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to // Tensor tAgAk = tAgA(_,_,k_tile); // CUTE_UNROLL // for (int i = 0; i < size(tAsA); ++i) { // tAsA(i) = tAgAk(i); // } cp_async_fence(); // Label the end of (potential) cp.async instructions cp_async_wait<0>(); // Sync on all (potential) cp.async instructions __syncthreads(); // Wait for all threads to write to smem // Compute gemm on tC thread-partitioned smem gemm(tCsA, tCsB, tCrC); // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K) // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to // CUTE_UNROLL // for (int k = 0; k < size<1>(tCsA); ++k) { // CUTE_UNROLL // for (int m = 0; m < size<0>(tCrC); ++m) { // CUTE_UNROLL // for (int n = 0; n < size<1>(tCrC); ++n) { // tCrC(m,n) += tCsA(m,k) * tCsB(n,k); // } // } // } __syncthreads(); // Wait for all threads to read from smem } ``` 參考資料: 1. [sgemm_1.cu][1] 2. [CUTLASS: Fast Linear Algebra in CUDA C++][2] 3. [CuTe dense matrix-matrix multiply tutorial][3] [1]:https://github.com/NVIDIA/cutlass/blob/main/examples/cute/tutorial/sgemm_1.cu [2]:https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ [3]:https://github.com/NVIDIA/cutlass/blob/main/media/docs/cute/0x_gemm_tutorial.md