此文只是記錄一下在解[Triton Puzzles][1]的過程。由於Triton的說明文件並不多,所以藉由做中學來熟悉一些基本概念,以下用到的一些命名習慣(naming convention),基本上會類似於官方的[tutorial][2],建議先將官方的tutorial看過一遍會比較有感覺。Triton作為橋接PyTorch跟CUDA、建構在MLIR (Multi-Level Intermediate Representation)之上的語言與編譯器,希望能達到類似於CUDA那樣的效能發揮,但又有PyTorch那樣易於撰寫的特性(相較於CUDA而言),裡面的一些觀念需要你對CUDA有一些基本的了解會比較好[^3^][3] [^4^][4] [^5^][5],尤其是平行化的運算思維。有任何較好的寫法(比如puzzle 10),或任何的錯誤,歡迎指教,謝謝!

# Puzzle 1: Constant Add

Add a constant to a vector. Uses one program id axis. Block size `B0` is always the same as vector `x` with length `N0`.

$$ z_i = 10 + x_i\; for\; i = 1\; ...\; N_0$$

```triton
def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    "This is the spec that you should implement. Uses typing to define sizes."
    return x + 10.

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    x_range = tl.arange(0, B0)
    # B0 == N0 in this puzzle, but masking the load as well as the store
    # keeps the kernel safe if B0 is ever rounded up past N0.
    x_mask = x_range < N0
    x = tl.load(x_ptr + x_range, x_mask)
    x = x + 10  # vector add
    z_ptrs = z_ptr + x_range
    tl.store(z_ptrs, x, x_mask)
    return

test(add_kernel, add_spec, nelem={"N0": 32}, viz=True)
```

# Puzzle 2: Constant Add Block

Add a constant to a vector. Uses one program block axis (no `for` loops yet). Block size `B0` is now smaller than the shape vector `x` which is `N0`.

$$ z_i = 10 + x_i\; for\; i=1\; ...\; N_0 $$

```triton
def add2_spec(x: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    return x + 10.

@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid_0 = tl.program_id(0)
    # Each program handles the B0 consecutive elements starting at pid_0 * B0.
    offs = pid_0 * B0 + tl.arange(0, B0)
    # N0 need not be a multiple of B0, so the last block must be masked.
    mask = offs < N0
    x_ptrs = x_ptr + offs
    x = tl.load(x_ptrs, mask)
    x = x + 10
    z_ptrs = z_ptr + offs
    tl.store(z_ptrs, x, mask)
    return

test(add_mask2_kernel, add2_spec, nelem={"N0": 200})
```

# Puzzle 3: Outer Vector Add

Add two vectors. Uses one program block axis. Block size `B0` is always the same as vector `x` length `N0`. Block size `B1` is always the same as vector `y` length `N1`.
$$ z_{j,i} = x_i + y_j\; for\; i=1\; ...\; B_0\; ,\; j =1\; ...\; B_1$$

```triton
def add_vec_spec(x: Float32[Tensor, "32"], y: Float32[Tensor, "32"]) -> Float32[Tensor, "32 32"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # A single program covers all of x (B0 == N0) and all of y (B1 == N1).
    offs_x = tl.arange(0, B0)
    offs_y = tl.arange(0, B1)
    x_mask = offs_x < N0
    y_mask = offs_y < N1
    x = tl.load(x_ptr + offs_x, x_mask)
    y = tl.load(y_ptr + offs_y, y_mask)
    # Outer add by broadcasting the row vector x against the column vector y -> (B1, B0).
    z = x[None, :] + y[:, None]
    # z is stored row-major with N0 columns.
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    z_mask = y_mask[:, None] & x_mask[None, :]
    tl.store(z_ptr + offs_z, z, z_mask)
    return

test(add_vec_kernel, add_vec_spec, nelem={"N0": 32, "N1": 32})
```

# Puzzle 4: Outer Vector Add Block

Add a row vector to a column vector. Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`. Block size `B1` is always less than vector `y` length `N1`.

$$ z_{j,i} = x_i + y_j\; for\; i = 1\; ...\; N_0, j=1\; ...\; N_1$$

```triton
def add_vec_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    return x[None, :] + y[:, None]

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Each program owns one (B1, B0) tile of z: axis 0 walks x, axis 1 walks y.
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    offs_x = pid_0 * B0 + tl.arange(0, B0)
    offs_y = pid_1 * B1 + tl.arange(0, B1)
    x_mask = offs_x < N0
    y_mask = offs_y < N1
    x = tl.load(x_ptr + offs_x, x_mask)
    y = tl.load(y_ptr + offs_y, y_mask)
    z = x[None, :] + y[:, None]
    # z is stored row-major with N0 columns; tiles on the right/bottom edge are partial.
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    z_mask = y_mask[:, None] & x_mask[None, :]
    tl.store(z_ptr + offs_z, z, z_mask)
    return

test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": 100, "N1": 90})
```

# Puzzle 5: Fused Outer Multiplication

Multiply a row vector to a column vector and take a relu. Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`. Block size `B1` is always less than vector `y` length `N1`.
$$ z_{j,i} = relu(x_i\,\times\,y_j) \; for\; i = 1\; ...\; N_0, j = 1\; ...\; N_1 $$

```triton
def mul_relu_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    return torch.relu(x[None, :] * y[:, None])

@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    # Each program owns one (B1, B0) tile of z: axis 0 walks x, axis 1 walks y.
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    offs_x = pid_0 * B0 + tl.arange(0, B0)
    offs_y = pid_1 * B1 + tl.arange(0, B1)
    x_mask = offs_x < N0
    y_mask = offs_y < N1
    x_ptrs = x_ptr + offs_x
    y_ptrs = y_ptr + offs_y
    x = tl.load(x_ptrs, x_mask)
    y = tl.load(y_ptrs, y_mask)
    z = x[None, :] * y[:, None]
    z = tl.where(z > 0, z, 0)  # relu, fused into the same kernel
    # z is stored row-major with N0 columns.
    offs_z = offs_y[:, None] * N0 + offs_x[None, :]
    z_mask = y_mask[:, None] & x_mask[None, :]
    z_ptrs = z_ptr + offs_z
    tl.store(z_ptrs, z, z_mask)
    return

test(mul_relu_block_kernel, mul_relu_block_spec, nelem={"N0": 100, "N1": 90})
```

# Puzzle 6: Fused Outer Multiplication - Backwards

Backwards of a function that multiplies a matrix with a row vector and takes a relu. Uses two program blocks. Block size `B0` is always less than the vector `x` length `N0`. Block size `B1` is always less than vector `y` length `N1`.
Chain rule backward `dz` is of shape `N1` by `N0`.

$$ f(x,y) = relu(x_{j,i} \times y_j)\; for\; i = 1\; ...\; N_0, j = 1\; ...\; N_1$$

$$ dx_{j,i} = f_x'(x, y)_{j,i} \times dz_{j,i} $$

```triton
def mul_relu_block_back_spec(x: Float32[Tensor, "90 100"], y: Float32[Tensor, "90"],
                             dz: Float32[Tensor, "90 100"]) -> Float32[Tensor, "90 100"]:
    x = x.clone()
    y = y.clone()
    x = x.requires_grad_(True)
    y = y.requires_grad_(True)
    z = torch.relu(x * y[:, None])
    z.backward(dz)
    dx = x.grad
    return dx

@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    offs_col = pid_0 * B0 + tl.arange(0, B0)
    offs_row = pid_1 * B1 + tl.arange(0, B1)
    # x, dz and dx are all (N1, N0) row-major, so they share one offset/mask pair.
    offs_2d = offs_row[:, None] * N0 + offs_col[None, :]
    mask_2d = (offs_row[:, None] < N1) & (offs_col[None, :] < N0)
    y_mask = offs_row < N1
    x = tl.load(x_ptr + offs_2d, mask_2d)
    y = tl.load(y_ptr + offs_row, y_mask)
    dz = tl.load(dz_ptr + offs_2d, mask_2d)
    # d/dx relu(x * y) = y where x * y > 0, else 0:
    # gradient only flows through the positive part of relu (jacobian vector product).
    xy = x * y[:, None]
    dx = tl.where(xy > 0, y[:, None] * dz, 0)
    tl.store(dx_ptr + offs_2d, dx, mask_2d)
    return

test(mul_relu_block_back_kernel, mul_relu_block_back_spec, nelem={"N0": 100, "N1": 90})
```

# Puzzle 7: Long Sum

Sum of a batch of numbers. Uses one program block axis. Block size `B0` represents a range of batches of `x` of length `N0`. Each element is of length `T`. Process it `B1` < `T` elements at a time.

$$ z_i = \sum_{j=1}^{T} x_{i,j} \; ,\; for\; i = 1\; ...\; N_0 $$

Hint: You will need a for loop for this problem. These work and look the same as in Python.
```triton
def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]:
    return x.sum(1)

@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    accum = tl.zeros((B0,), tl.float32)
    offs_row = pid_0 * B0 + tl.arange(0, B0)
    row_mask = offs_row < N0
    # Walk the T columns B1 at a time. tl.cdiv keeps the trip count an integer;
    # plain `/` is true division and may produce a float.
    for i in range(tl.cdiv(T, B1)):
        offs_col = i * B1 + tl.arange(0, B1)
        offs_block = offs_row[:, None] * T + offs_col[None, :]
        block_mask = row_mask[:, None] & (offs_col[None, :] < T)
        # Masked-out lanes read as 0 so they leave the sum unchanged.
        block = tl.load(x_ptr + offs_block, block_mask, other=0.0)
        accum += tl.sum(block, axis=1)
    tl.store(z_ptr + offs_row, accum, row_mask)
    return

test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "N1": 32, "T": 200})
```

# Puzzle 8: Long Softmax

Softmax of a batch of logits. Uses one program block axis. Block size `B0` represents the batch of `x` of length `N0`. Block logit length `T`. Process it `B1` < `T` elements at a time.

$$ z_{i,j} = softmax(x_{i,1}\; ...\; x_{i,T})_j\; for\; i = 1\; ...\; N_0 $$

Note softmax needs to be computed in numerically stable form as in Python. In addition in Triton they recommend not using `exp` but instead using `exp2`. You need the identity

$$ exp(x) = 2^{log_{2}(e)x} $$

Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever.
Hint: you will find this identity useful:

$$ exp(x_i - m) = exp(x_i - m/2 - m/2) = exp(x_i - m/2)\;/exp\;(m/2)$$

```triton
def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    return x_exp / x_exp.sum(1, keepdim=True)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    pid_0 = tl.program_id(0)
    log2_e = 1.44269504
    # reference:
    # - https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
    # 2-loop solution (online softmax):
    #   loop 1 keeps a running row max m and a running denominator d, rescaling d
    #   by exp(m_old - m_new) whenever the max grows;
    #   loop 2 writes exp(x - m) / d.
    # exp(a) is computed as exp2(log2_e * a), as the puzzle recommends.
    m = tl.zeros((B0,), tl.float32) - float('inf')  # running row max, init -inf
    d = tl.zeros((B0,), tl.float32)  # running denominator
    num_iters = tl.cdiv(T, B1)
    row_offs = pid_0 * B0 + tl.arange(0, B0)
    row_mask = row_offs < N0
    for i in range(num_iters):
        col_offs = i * B1 + tl.arange(0, B1)
        offs_x = row_offs[:, None] * T + col_offs[None, :]
        x_mask = row_mask[:, None] & (col_offs[None, :] < T)
        # Padding lanes must read as -inf: a 0 could raise the max and would add exp(0 - m) to d.
        x = tl.load(x_ptr + offs_x, x_mask, other=-float('inf'))
        # get current max, rescale denominator
        m_new = tl.maximum(tl.max(x, axis=1), m)
        d = d * tl.exp2(log2_e * (m - m_new)) + tl.sum(tl.exp2(log2_e * (x - m_new[:, None])), axis=1)
        m = m_new
    for i in range(num_iters):
        col_offs = i * B1 + tl.arange(0, B1)
        offs_x = row_offs[:, None] * T + col_offs[None, :]
        x_mask = row_mask[:, None] & (col_offs[None, :] < T)
        x = tl.load(x_ptr + offs_x, x_mask, other=-float('inf'))
        # m and d are per row (B0,), so broadcast them across the B1 columns.
        soft = tl.div_rn(tl.exp2(log2_e * (x - m[:, None])), d[:, None])
        tl.store(z_ptr + offs_x, soft, x_mask)
    return
    '''
    # 3-loop solution: max pass, denominator pass, output pass
    m = tl.zeros((B0,), tl.float32) - float('inf')
    d = tl.zeros((B0,), tl.float32)
    num_iters = tl.cdiv(T, B1)
    row_offs = pid_0 * B0 + tl.arange(0, B0)
    row_mask = row_offs < N0
    for i in range(num_iters):
        col_offs = i * B1 + tl.arange(0, B1)
        offs_x = row_offs[:, None] * T + col_offs[None, :]
        x_mask = row_mask[:, None] & (col_offs[None, :] < T)
        x = tl.load(x_ptr + offs_x, x_mask, other=-float('inf'))
        m = tl.maximum(tl.max(x, axis=1), m)
    for i in range(num_iters):
        col_offs = i * B1 + tl.arange(0, B1)
        offs_x = row_offs[:, None] * T + col_offs[None, :]
        x_mask = row_mask[:, None] & (col_offs[None, :] < T)
        x = tl.load(x_ptr + offs_x, x_mask, other=-float('inf'))
        d += tl.sum(tl.exp(x - m[:, None]), axis=1)
    for i in range(num_iters):
        col_offs = i * B1 + tl.arange(0, B1)
        offs_x = row_offs[:, None] * T + col_offs[None, :]
        x_mask = row_mask[:, None] & (col_offs[None, :] < T)
        x = tl.load(x_ptr + offs_x, x_mask, other=-float('inf'))
        soft = tl.div_rn(tl.exp(x - m[:, None]), d[:, None])
        tl.store(z_ptr + offs_x, soft, x_mask)
    return
    '''

test(softmax_kernel, softmax_spec, B={"B0": 1, "B1":32}, nelem={"N0": 4, "N1": 32, "T": 200})
```

# Puzzle 9: Simple FlashAttention

A scalar version of FlashAttention. Uses zero programs. Block size `B0` represents `k` of length `N0`. Block size `B0` represents `q` of length `N0`. Block size `B0` represents `v` of length `N0`. Sequence length is `T`. Process it `B1` < `T` elements at a time.

$$ z_i = \sum_{j}softmax(q_i k_1\; ...\; q_i k_T)_j v_j\quad for\; i = 1\; ...\; N_0$$

This can be done in 1 loop using a similar trick from the last puzzle.
```triton
def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    x = q[:, None] * k[None, :]
    x_max = x.max(1, keepdim=True)[0]
    x = x - x_max
    x_exp = x.exp()
    soft = x_exp / x_exp.sum(1, keepdim=True)
    return (v[None, :] * soft).sum(1)

@triton.jit
def flashatt_kernel(q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr):
    # Single-block version: q, k and v (length N0 each) fit in one padded block,
    # so no loop over the sequence is needed. Assumes N0 <= B0.
    B0 = triton.next_power_of_2(B0)  # tl.arange needs a power-of-2 length
    offs_word = tl.arange(0, B0)
    mask = offs_word < N0
    q = tl.load(q_ptr + offs_word, mask, other=0)
    k = tl.load(k_ptr + offs_word, mask, other=0)
    v = tl.load(v_ptr + offs_word, mask, other=0)
    # self attention by softmax(qk)v
    qk = q[:, None] * k[None, :]
    # Padded key columns give qk == 0. Set them to -inf so they get zero weight
    # instead of adding exp(0 - max) to every softmax denominator.
    qk = tl.where(mask[None, :], qk, float('-inf'))
    qk_max = qk.max(axis=1, keep_dims=True)
    qk = qk - qk_max  # safe softmax
    qk_exp = tl.exp(qk)
    qk_soft = qk_exp / qk_exp.sum(1, keep_dims=True)
    z = (v[None, :] * qk_soft).sum(1)
    tl.store(z_ptr + offs_word, z, mask)
    return

test(flashatt_kernel, flashatt_spec, B={"B0":200}, nelem={"N0": 200, "T": 200})
```

# Puzzle 10: Two Dimensional Convolution

A batched 2D convolution. Uses one program id axis. Block size `B0` represents the batches to process out of `N0`. Image `x` is of size `H` by `W` with only 1 channel, and kernel `k` is of size `KH` by `KW`.
$$ z_{i,j,k} = \sum_{oj, ok} k_{oj,ok} \times x_{i,\:j+oj,\:k+ok}\; for\; i = 1\; ...\; N_0 $$

```triton
def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]:
    z = torch.zeros(4, 8, 8)
    x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
    print(x.shape, k.shape)
    for i in range(8):
        for j in range(8):
            z[:, i, j] = (k[None, :, :] * x[:, i: i+4, j: j + 4]).sum(1).sum(1)
    return z

@triton.jit
def conv2d_kernel(x_ptr, k_ptr, z_ptr, N0, H, W, KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr):
    pid_0 = tl.program_id(0)
    # The B0 images handled by this program, and the flat offset where each starts.
    # (Scaling arange by H * W is what makes B0 > 1 work.)
    offs_batch = pid_0 * B0 + tl.arange(0, B0)
    batch_mask = offs_batch < N0
    batch_base = offs_batch * (H * W)

    # Load the whole KH x KW kernel once; it is reused for every output pixel.
    row_offs_k = tl.arange(0, KH)
    col_offs_k = tl.arange(0, KW)
    kernel = tl.load(k_ptr + row_offs_k[:, None] * KW + col_offs_k[None, :])

    for i in tl.range(H):
        for j in tl.range(W):
            # Window x[b, i:i+KH, j:j+KW]. Pixels past the bottom/right edge read as 0,
            # matching the zero padding in conv2d_spec.
            row_offs_block = i + row_offs_k
            col_offs_block = j + col_offs_k
            block_mask = (batch_mask[:, None, None]
                          & (row_offs_block[None, :, None] < H)
                          & (col_offs_block[None, None, :] < W))
            block_ptrs = (x_ptr + batch_base[:, None, None]
                          + row_offs_block[None, :, None] * W
                          + col_offs_block[None, None, :])
            block = tl.load(block_ptrs, block_mask, other=0)  # (B0, KH, KW)
            z = tl.sum(tl.sum(block * kernel[None, :, :], axis=2), axis=1)  # (B0,)
            tl.store(z_ptr + batch_base + i * W + j, z, batch_mask)
    return

test(conv2d_kernel, conv2d_spec, B={"B0": 1}, nelem={"N0": 4, "H": 8, "W": 8, "KH": 4, "KW": 4})
```

# Puzzle 11: Matrix Multiplication

A blocked matrix multiplication. Uses three program id axes. Block size `B2` represents the batches to process out of `N2`. Block size `B0` represents the rows of `x` to process out of `N0`. Block size `B1` represents the cols of `y` to process out of `N1`. The middle shape is `MID`.

$$ z_{i,j,k} = \sum_{l}x_{i,j,l} \times y_{i,l,k}\quad for\; i =\; 1\; ...\; N_2, j = 1\; ...\; N_0, k = 1\; ...\; N_1 $$

You are allowed to use `tl.dot` which computes a smaller mat mul.

Hint: the main trick is that you can split a matmul into smaller parts.
$$ z_{i,j,k} = \sum_{l=1}^{L/2}x_{i,j,l} \times y_{i,l,k} + \sum_{l=L/2+1}^{L}x_{i,j,l} \times y_{i,l,k} $$

```triton
def dot_spec(x: Float32[Tensor, "4 32 32"], y: Float32[Tensor, "4 32 32"]) -> Float32[Tensor, "4 32 32"]:
    return x @ y

@triton.jit
def dot_kernel(x_ptr, y_ptr, z_ptr, N0, N1, N2, MID, B0: tl.constexpr, B1: tl.constexpr, B2: tl.constexpr, B_MID: tl.constexpr):
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    pid_2 = tl.program_id(2)
    # Tile coordinates; they do not change across the MID loop.
    offs_batch = pid_2 * B2 + tl.arange(0, B2)  # (B2,)
    offs_row = pid_0 * B0 + tl.arange(0, B0)    # rows of x and z, (B0,)
    offs_col = pid_1 * B1 + tl.arange(0, B1)    # cols of y and z, (B1,)
    batch_mask = offs_batch[:, None, None] < N2
    accum = tl.zeros((B2, B0, B1), tl.float32)
    # z = sum over MID chunks of x[:, :, chunk] @ y[:, chunk, :]
    for i in range(tl.cdiv(MID, B_MID)):
        offs_mid = i * B_MID + tl.arange(0, B_MID)
        # x is (N2, N0, MID) row-major: batch stride N0 * MID, row stride MID.
        x_ptrs = (x_ptr + offs_batch[:, None, None] * (N0 * MID)
                  + offs_row[None, :, None] * MID + offs_mid[None, None, :])
        x_mask = batch_mask & (offs_row[None, :, None] < N0) & (offs_mid[None, None, :] < MID)
        # y is (N2, MID, N1) row-major: batch stride MID * N1, row stride N1.
        y_ptrs = (y_ptr + offs_batch[:, None, None] * (MID * N1)
                  + offs_mid[None, :, None] * N1 + offs_col[None, None, :])
        y_mask = batch_mask & (offs_mid[None, :, None] < MID) & (offs_col[None, None, :] < N1)
        # Padded MID lanes must be 0 so they add nothing to the dot product.
        x = tl.load(x_ptrs, x_mask, other=0.0)
        y = tl.load(y_ptrs, y_mask, other=0.0)
        accum = tl.dot(x, y, accum)
    # z is (N2, N0, N1) row-major: batch stride N0 * N1, row stride N1.
    z_ptrs = (z_ptr + offs_batch[:, None, None] * (N0 * N1)
              + offs_row[None, :, None] * N1 + offs_col[None, None, :])
    z_mask = batch_mask & (offs_row[None, :, None] < N0) & (offs_col[None, None, :] < N1)
    tl.store(z_ptrs, accum, z_mask)
    return
    '''
    # 2D alternative (assumes B2 == 1: one batch per program)
    offs_row = pid_0 * B0 + tl.arange(0, B0)
    offs_col = pid_1 * B1 + tl.arange(0, B1)
    accum = tl.zeros((B0, B1), tl.float32)
    for i in range(tl.cdiv(MID, B_MID)):
        offs_mid = i * B_MID + tl.arange(0, B_MID)
        x_ptrs = x_ptr + pid_2 * N0 * MID + offs_row[:, None] * MID + offs_mid[None, :]
        x_mask = (offs_row[:, None] < N0) & (offs_mid[None, :] < MID)
        y_ptrs = y_ptr + pid_2 * MID * N1 + offs_mid[:, None] * N1 + offs_col[None, :]
        y_mask = (offs_mid[:, None] < MID) & (offs_col[None, :] < N1)
        x = tl.load(x_ptrs, x_mask, other=0.0)
        y = tl.load(y_ptrs, y_mask, other=0.0)
        accum = tl.dot(x, y, accum)
    z_ptrs = z_ptr + pid_2 * N0 * N1 + offs_row[:, None] * N1 + offs_col[None, :]
    z_mask = (offs_row[:, None] < N0) & (offs_col[None, :] < N1)
    tl.store(z_ptrs, accum, z_mask)
    return
    '''

test(dot_kernel, dot_spec, B={"B0": 16, "B1": 16, "B2": 1, "B_MID": 16}, nelem={"N0": 32, "N1": 32, "N2": 4, "MID": 32})
```

# Puzzle 12: Quantized Matrix Mult

When doing matrix multiplication with quantized neural networks a common strategy is to store the weight matrix in lower precision, with a shift and scale term.

For this problem our `weight` will be stored in 4 bits. We can store `FPINT` of these in a 32 bit integer. In addition for every `group` weights in order we will store 1 `scale` float value and 1 `shift` 4 bit value. We store these for the column of weight. The `activations` are stored separately in standard floats.

Mathematically it looks like:

$$ z_{j,k} = \sum_{l}sc_{j, \lfloor l/g \rfloor}(w_{j,l}\; -\; sh_{j,\lfloor l/g \rfloor}) \times y_{l,k} \quad for\; j = 1\; ...\; N_0, k = 1\; ...\; N_1$$

However, it is a bit more complex since we need to also extract the 4-bit values into floats to begin.
```triton
FPINT = 32 // 4
GROUP = 8

def quant_dot_spec(scale : Float32[Tensor, "32 8"],
                   offset : Int32[Tensor, "32"],
                   weight: Int32[Tensor, "32 8"],
                   activation: Float32[Tensor, "64 32"]) -> Float32[Tensor, "32 32"]:
    offset = offset.view(32, 1)
    def extract(x):
        over = torch.arange(8) * 4
        mask = 2**4 - 1
        return (x[..., None] >> over) & mask
    scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64)
    offset = extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64)
    return ( scale * (extract(weight).view(-1, 64) - offset)) @ activation

@triton.jit
def extract_int4(x):
    # Each weight is only 4 bits, so one 32-bit integer packs 8 (= FPINT) of them.
    # Shift and mask to unpack them, lowest nibble first: shape S -> S + (8,).
    # (Kept at module level: nested `def`s are not supported inside a compiled @triton.jit kernel.)
    over = tl.arange(0, 8) * 4
    mask = 2**4 - 1
    return (x.expand_dims(-1) >> over) & mask

@triton.jit
def quant_dot_kernel(scale_ptr, offset_ptr, weight_ptr, activation_ptr,
                     z_ptr, N0, N1, MID, B0: tl.constexpr, B1: tl.constexpr, B_MID: tl.constexpr):
    # One (B0, B1) tile of z per program. The reduction dimension is not looped
    # over: this assumes B_MID == MID (64 in the test below).
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    # weight sub-matrix shape is (B0, B_MID // FPINT): FPINT packed 4-bit weights per int32
    row_offs_weight = pid_0 * B0 + tl.arange(0, B0)
    col_offs_weight = tl.arange(0, B_MID // FPINT)
    weight_mask = (col_offs_weight[None, :] < MID // FPINT) & (row_offs_weight[:, None] < N0)
    weight_ptrs = weight_ptr + col_offs_weight[None, :] + row_offs_weight[:, None] * (MID // FPINT)
    weight = tl.load(weight_ptrs, weight_mask)

    # activation sub-matrix shape is (B_MID, B1)
    row_offs_act = tl.arange(0, B_MID)
    col_offs_act = pid_1 * B1 + tl.arange(0, B1)
    act_mask = (col_offs_act[None, :] < N1) & (row_offs_act[:, None] < MID)
    act_ptrs = activation_ptr + col_offs_act[None, :] + row_offs_act[:, None] * N1
    activation = tl.load(act_ptrs, act_mask)

    # offset shape is (B0,): one int32 per row packing the 4-bit shift of each group
    offs_offset = pid_0 * B0 + tl.arange(0, B0)
    offset_mask = offs_offset < N0
    offset_ptrs = offset_ptr + offs_offset
    offset = tl.load(offset_ptrs, offset_mask)

    # scale shape is (B0, B_MID // GROUP). Because GROUP == FPINT here, it reuses
    # the weight sub-matrix layout.
    row_offs_scale = row_offs_weight
    col_offs_scale = col_offs_weight
    scale_mask = weight_mask
    scale_ptrs = scale_ptr + col_offs_scale[None, :] + row_offs_scale[:, None] * (MID // FPINT)
    scale = tl.load(scale_ptrs, scale_mask)

    # Broadcast each group's scale / shift over its GROUP weights, then dequantize and multiply.
    scale = scale.expand_dims(-1).broadcast_to((B0, B_MID // GROUP, GROUP)).reshape((B0, B_MID))
    offset = extract_int4(offset).expand_dims(-1).broadcast_to((B0, B_MID // GROUP, GROUP)).reshape((B0, B_MID))
    weight = extract_int4(weight).reshape((B0, B_MID))
    z = tl.dot(scale * (weight - offset), activation)

    row_offs_z = pid_0 * B0 + tl.arange(0, B0)
    col_offs_z = pid_1 * B1 + tl.arange(0, B1)
    z_mask = (col_offs_z[None, :] < N1) & (row_offs_z[:, None] < N0)
    z_ptrs = z_ptr + col_offs_z[None, :] + row_offs_z[:, None] * N1
    tl.store(z_ptrs, z, z_mask)
    return

test(quant_dot_kernel, quant_dot_spec, B={"B0": 16, "B1": 16, "B_MID": 64},
     nelem={"N0": 32, "N1": 32, "MID": 64})
```

Reference:

1. [Triton Puzzles][1]
2. [Triton Tutorial][2]
3. [An Even Easier Introduction to CUDA][3]
4. [CUDA kernel launch and thread hierarchy][4]
5. [Parallel Computing Using Cuda-C][5]

[1]: https://github.com/srush/Triton-Puzzles?tab=readme-ov-file
[2]: https://triton-lang.org/main/getting-started/tutorials/index.html
[3]: https://developer.nvidia.com/blog/even-easier-introduction-cuda/
[4]: https://medium.com/@linyan98765/cuda-kernel-launch-and-thread-hierarchy-315f4aa3b355
[5]: https://github.com/CisMine/Parallel-Computing-Cuda-C/tree/main