# Making custom Kernels

The TLDR is: you need to write new kernels for each combination of (GPU Architecture, Model Architecture, InferenceType) for it to go fast. We needed (RTX Pro 6000, Qwen3, Prefill), it didn't exist online, so we wrote it.

> Note that there are two InferenceTypes: Decode and Prefill. Decode is autoregressive, token-by-token generation, and we didn't write kernels for it. Decode kernels are generally bandwidth-bound and 3-5x less efficient than prefill, which is why output tokens are 3-5x more expensive than input tokens.

Projects like vLLM started with the dream of making kernels super modular. TRT-LLM started with the dream of deterministically generating the CUDA code from just the model architecture and the GPU architecture. Honestly, neither dream really worked out. Each new GPU architecture is _so_ different, and each new model architecture is _so_ different, that both projects end up as very large repositories of open-source-contributed kernels, one for every specific combination.

Being open-source inference engines, they (TRT-LLM, SGLang, vLLM) inevitably have the broadest support. Projects like HF Transformers simply use PyTorch: very inefficient, but it supports backprop, whereas inference engines do not.

For our use case, the inference engines didn't support our needs, so we were stuck at the speed of standard PyTorch: ~20% GPU utilization. Writing custom kernels for each of the 4 GEMMs plus attention brought us to ~80% GPU utilization.

## Details

By default, each PyTorch line reads its inputs from VRAM, executes its task, and writes the result back to VRAM. Consider the first half of Qwen's MLP (the gate-up part). In PyTorch, it's simply 4 LOC:

```python
import torch.nn.functional as F

gate = activations @ W_gate.T   # [M, intermediate]
up = activations @ W_up.T       # [M, intermediate]
gate_silu = F.silu(gate)        # [M, intermediate]
out = gate_silu * up            # elementwise *, not matmul
```

Each PyTorch operation does some number of FLOPs and needs to read/write some number of bytes of memory (bandwidth).
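To put rough numbers on this, here is a back-of-the-envelope count for the four lines above. It is a sketch under our own assumptions: BF16 tensors (2 bytes per element), every op reading each input from VRAM once and writing its output once, SiLU costing about 4 FLOPs per element (a guess), and illustrative sizes (8192 tokens, hidden 4096, intermediate 12288) that we picked for the example. The helper names `matmul_cost` and `gate_up_costs` are hypothetical.

```python
# Rough FLOPs and VRAM-byte counts for the unfused gate-up. Assumptions: BF16 tensors,
# one VRAM read per input and one VRAM write per output for every op.
BYTES = 2        # bytes per BF16 element
SILU_FLOPS = 4   # assumed per-element cost of SiLU (negate, exp, add, divide); a rough guess


def matmul_cost(m, n, k):
    """(FLOPs, bytes) for [m, k] @ [k, n] -> [m, n]."""
    flops = 2 * m * n * k                     # one multiply + one add per term
    nbytes = BYTES * (m * k + k * n + m * n)  # read both operands, write the output
    return flops, nbytes


def gate_up_costs(m, hidden, inter):
    """(FLOPs, bytes) for each of the four PyTorch lines, run as separate kernels."""
    mi = m * inter
    return {
        "gate = activations @ W_gate.T": matmul_cost(m, inter, hidden),
        "up = activations @ W_up.T": matmul_cost(m, inter, hidden),
        "gate_silu = F.silu(gate)": (SILU_FLOPS * mi, BYTES * 2 * mi),  # 1 read + 1 write
        "out = gate_silu * up": (mi, BYTES * 3 * mi),                   # 2 reads + 1 write
    }


if __name__ == "__main__":
    # Illustrative sizes (our choice): 8192 tokens, hidden 4096, intermediate 12288.
    for line, (flops, nbytes) in gate_up_costs(8192, 4096, 12288).items():
        print(f"{line:32s} {flops / nbytes:10.2f} FLOPs/byte")
```

With these sizes, each GEMM comes out at roughly 2,234 FLOPs per byte, the SiLU at 1.0, and the elementwise multiply at about 0.17: the GEMMs have intensity to spare, while the element-wise ops are starved for bandwidth.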
In general, to saturate a GPU, the desired Bandwidth:FLOPs ratio is about 1:100; i.e., for every byte you load from VRAM, you want to do at least ~100 operations to be compute-bound.

GEMMs are pretty easy to saturate. An MxK @ KxN = MxN matmul does 2MNK FLOPs (one multiply and one add per term) and moves about 2(MK + KN + MN) bytes in BF16 (read both operands, write the output, 2 bytes per element). FLOPs grow cubically while bytes grow quadratically; e.g., for M = N = K = n, the ratio is n/3 FLOPs per byte. So matmuls are FLOPs-bound as long as M, N, and K are all large enough. Great!

However, element-wise operations such as `F.silu` and the elementwise `*` are O(MN) FLOPs and O(MN) bandwidth, so they're massively bandwidth-bound. You get ~1% GPU utilization, since you load a byte from VRAM, do ~1 operation, then write it back to VRAM.

When writing a single CUDA kernel for the entire gate-up, you want to "fuse" the element-wise operations into the GEMM, i.e., do the element-wise operations _before_ the GEMM writes its output to VRAM. `F.silu` is pretty easy to fuse: just apply SiLU right before writing the output to VRAM. But `gate * up` is non-trivial, because it uses the outputs of two separate GEMMs.

To solve this problem (specific to Qwen's gate-up), we combine `gate = activations @ W_gate.T` and `up = activations @ W_up.T` into one bigger GEMM (since they share the same left operand, there's no wasted work here). Crucially, we also interleave their weight rows, so that the elements that need to be multiplied together end up on the same Streaming Multiprocessor (the tiny but highly parallel processing units in a GPU).

```python
# Preprocessing step (once, at load time)
# W_gate, W_up: [intermediate, hidden]  ->  weights_gate_up: [2 * intermediate, hidden]
weights_gate_up = torch.stack([W_gate, W_up], dim=1).reshape(2 * intermediate, hidden)
# rows: [W_gate[0], W_up[0], W_gate[1], W_up[1], ...]
```

Now, you run:

```python
out = torch.empty(M, intermediate, dtype=activations.dtype, device=activations.device)
custom_kernel(out, weights_gate_up, activations)  # our fused CUDA kernel
```

As for any GEMM, you break up the work into a bunch of small output tiles so that it's embarrassingly parallel. A kernel is a task executor that runs the same code on each of the Streaming Multiprocessors (SMs). The RTX Pro 6000 has 188 SMs.
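Here is a minimal NumPy emulation, built on our own assumptions, of why the interleaved layout makes the fused epilogue possible. It builds the interleaved weight matrix as in the preprocessing step, walks over output tiles one at a time (a sequential stand-in for SMs pulling tiles in parallel), applies SiLU and the multiply to each tile's GEMM result before "writing" it, and checks the result against the unfused 4-line version. The helper names (`silu`, `interleave`, `fused_gate_up`), the tile sizes, and the matrix sizes are hypothetical; float64 is used instead of BF16 so the comparison is exact up to rounding.

```python
import numpy as np


def silu(x):
    """SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))."""
    return x / (1.0 + np.exp(-x))


def interleave(W_gate, W_up):
    """[inter, hidden] x 2 -> [2 * inter, hidden]; rows: gate[0], up[0], gate[1], up[1], ..."""
    inter, hidden = W_gate.shape
    return np.stack([W_gate, W_up], axis=1).reshape(2 * inter, hidden)


def fused_gate_up(x, W_gate_up, tile_m=16, tile_n=16):
    """Emulate the fused kernel one output tile at a time (sequential stand-in for SMs)."""
    M = x.shape[0]
    inter = W_gate_up.shape[0] // 2
    out = np.empty((M, inter), dtype=x.dtype)
    tiles = [(r, c) for r in range(0, M, tile_m) for c in range(0, inter, tile_n)]
    for r, c in tiles:
        r_end, c_end = min(r + tile_m, M), min(c + tile_n, inter)
        w = W_gate_up[2 * c : 2 * c_end]          # the gate/up row pairs for columns c..c_end
        gu = x[r:r_end] @ w.T                     # one GEMM; on a GPU this stays on-chip
        gate, up = gu[:, 0::2], gu[:, 1::2]       # even columns = gate, odd columns = up
        out[r:r_end, c:c_end] = silu(gate) * up   # fused epilogue: one write per tile
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    M, hidden, inter = 40, 64, 48                 # small illustrative sizes
    x = rng.standard_normal((M, hidden))
    W_gate = rng.standard_normal((inter, hidden))
    W_up = rng.standard_normal((inter, hidden))
    ref = silu(x @ W_gate.T) * (x @ W_up.T)       # the unfused 4-line version
    out = fused_gate_up(x, interleave(W_gate, W_up))
    print("matches unfused reference:", np.allclose(out, ref))
```

Because each tile takes a contiguous block of `2 * tile_n` interleaved rows, every gate row arrives together with its matching up row, so the multiply never needs data from another tile.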
We write the kernel as a work-stealing loop: it looks at the big heap of work to do, takes a job, computes that 16x16 output tile, then writes the result to VRAM. It repeats until there's no more work to do. Our fused gate-up kernel looks like this:

```python
# Pythonic pseudocode; the real kernel is written in CUDA, a C++-style DSL
def custom_kernel(out, W_gate_up, activations):
    while (tile := work.get_work()):             # tile = (output rows, output columns)
        rows, cols = tile
        w = W_gate_up[2 * cols.start : 2 * cols.stop]  # interleaved gate/up rows for these columns
        gu = activations[rows] @ w.T             # one fused GEMM; stays in registers, never hits VRAM
        gate, up = gu[:, 0::2], gu[:, 1::2]      # gate and up sit in adjacent columns
        out[rows, cols] = F.silu(gate) * up      # epilogue: single VRAM write
```

## General GEMM

The above makes it sound like `custom_kernel` is like 5 LOC. But that glosses over the fact that just doing an MxK @ KxN GEMM at all is quite involved. Remember, the name of the game is to reduce memory bandwidth as much as possible, so there are layers of tricks to "share" as much memory as possible.

- Consider a Streaming Multiprocessor (SM) naively pulling a random 16x16 output tile on each work iteration. For that, it needs to read 16 length-K rows of the left operand and 16 length-K columns of the right operand, and compute the dot products of all 16x16 combinations. CUDA exposes specialized tensor-core matmul instructions (`mma`) that compute a whole block of these dot products, one short slice of K at a time. But this requires reading 32 vectors of length K. What if, instead, the SM pulled a 2x2 grid of 16x16 output tiles (i.e., a 32x32 tile, issuing 4x as many `mma` instructions)? Then you would get 4x as much output work done, but only need to read 32 rows of the left operand and 32 columns of the right operand, i.e., only 64 vectors of length K. At 4x the output work for only 2x the bytes read, that's a 2x improvement in FLOPs per byte.
- In our kernels, we found 128x128 "megatiles" to be best (mostly because that's the largest size that still fits in the SM's on-chip memory).
  With tiles this large, you need a large batch (M); otherwise there are fewer tiles than SMs, and the other SMs won't have any work to do.
- Now, consider two SMs that are physically adjacent on the chip. Blackwell has a new feature that lets an SM "read" memory from its neighbor. This is complex to orchestrate, but the result is that, e.g., if the two SMs compute adjacent 128x128 megatiles, they can share operands: each reads 64 different rows of the left operand, and they read each other's memory to complete the MMAs. This again saves memory bandwidth, making the kernel more compute-bound.

You will also want to "pre-fetch", i.e., grab the memory for the next work item while computing the 128x128 dot products for the current one. That way, on the next iteration of the loop, the operands are already loaded and you're ready to matmul. On Blackwell, this is generally done with two workers synchronized, mutex-style, in a "ping-pong" alternating pattern: one issues a batch of loads while the other executes a batch of MMAs on the previous load, and they alternate back and forth, using the same SM's shared memory.

An example of an end-to-end basic `C = A @ B` matmul kernel on the Blackwell B200, with all of the low-level intricacies, can be found here: https://gau-nernst.github.io/tcgen05/

- Despite the high complexity of that article, matmul on B200 is _much_ easier than on Hopper (H100/H200). Hopper is a nightmare in comparison.
- Matmul on the RTX Pro 6000 is _completely_ different from that article, though approximately equal in complexity. But since the RTX Pro 6000 doesn't have much support, there are no clean tutorials for it like there are for B200 (lack of documentation is the main thing that makes these kernels difficult, as even the B200 article notes).
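The tiling arithmetic from the General GEMM list can be checked in a few lines. This sketch, built on our own simplifying assumptions, counts FLOPs per operand byte an SM loads for one output tile (BF16, 2 bytes per element), optionally with a pair of neighboring SMs splitting the left operand as described above, and counts how many 128x128 tiles a batch produces relative to the RTX Pro 6000's 188 SMs. It ignores L2-cache hits, the output write, and everything else a real kernel does; the helper names `tile_intensity` and `tiles_per_sm` are hypothetical.

```python
import math

NUM_SMS = 188  # RTX Pro 6000
BYTES = 2      # BF16


def tile_intensity(tile, k, pair_share=False):
    """FLOPs per operand byte an SM loads for one tile x tile output tile of depth k.

    The tile needs `tile` length-k rows of the left operand and `tile` length-k
    columns of the right operand. With pair_share=True, a neighboring SM supplies
    half of the left-operand rows, so this SM loads only tile / 2 of them.
    Ignores L2-cache hits and the output write.
    """
    flops = 2 * tile * tile * k
    vectors = (tile / 2 if pair_share else tile) + tile
    return flops / (BYTES * vectors * k)


def tiles_per_sm(m, n, tile):
    """Average output tiles per SM for an [m, n] output; below 1.0, some SMs idle."""
    num_tiles = math.ceil(m / tile) * math.ceil(n / tile)
    return num_tiles / NUM_SMS


if __name__ == "__main__":
    for tile in (16, 32, 128):
        print(f"{tile}x{tile}: {tile_intensity(tile, 4096):.1f} FLOPs/byte")
    print(f"128x128, SM pair: {tile_intensity(128, 4096, pair_share=True):.1f} FLOPs/byte")
    for m in (128, 8192):  # illustrative batch sizes, intermediate width 12288
        print(f"M={m}: {tiles_per_sm(m, 12288, 128):.2f} tiles per SM")
```

Doubling the tile edge doubles the FLOPs per byte (8, 16, and 64 for 16, 32, and 128), matching the 2x step in the list, and the SM pair pushes 128x128 to about 85. At M=128 there are only 96 tiles for a 12288-wide output, about half a tile per SM, which is why large tiles need a large batch.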
## Resources

- https://veitner.bearblog.dev/an-applied-introduction-to-cutedsl/
- https://gau-nernst.github.io/tcgen05/
- https://hazyresearch.stanford.edu/blog/2026-02-19-tk-2
- https://hc2023.hotchips.org/assets/program/conference/day2/ML%20training/HC2023.Session5.ML_Training.Cerebras.Sean_Lie.final_v02.pdf
- https://christianjmills.com/posts/cuda-mode-notes/lecture-008/#case-study-6-rewriting-algorithms-with-better-math-flash-attention
- https://www.lesswrong.com/posts/9yJxKSEoidX4HEeMp/google-seemingly-solved-efficient-attention
- https://www.youtube.com/watch?v=xmkSf5IS-zw

CUDA Docs:

- https://docs.nvidia.com/cuda/pdf/CUDA_Math_API.pdf

QAT:

- SageAttention: https://arxiv.org/pdf/2505.11594
- https://arxiv.org/pdf/2504.19874
- ArcQuant: https://arxiv.org/pdf/2601.07475
- https://arxiv.org/html/2602.02047v1
- https://arxiv.org/pdf/2509.23202
- https://arxiv.org/pdf/2601.20088
- https://arxiv.org/pdf/2505.19115
- https://arxiv.org/pdf/2004.07320
- https://arxiv.org/pdf/1807.04629
- https://arxiv.org/pdf/2603.22370