# State of IREE MXFP4 GEMM Codegen

## Table of Contents

- [The main hacks](#the-main-hacks)
  - [XOR swizzling](#xor-swizzling)
  - [Packing scales](#packing-scales)
  - [Heuristics (copying ASM kernel heuristics)](#heuristics-copying-asm-kernel-heuristics)
- [Experiment](#experiment)
- [Results](#results)
  - [Bank conflicts](#bank-conflicts)
  - [Performance: experimental_codegen vs ToM](#performance-experimental_codegen-vs-tom)
  - [Performance: experimental_codegen vs experimental_no_swizzle](#performance-experimental_codegen-vs-experimental_no_swizzle)
  - [Performance: ToM vs xor_swizzle](#performance-tom-vs-xor_swizzle)
  - [Performance: Unpruned](#performance-unpruned)
- [Summary](#summary)
  - [Unpruned shapes](#unpruned-shapes)
  - [Pruned shapes](#pruned-shapes)
  - [Takeaways](#takeaways)

---

## Summary (2025-12-25)

Currently, MXFP4 GEMMs compiled through IREE Codegen at Top of Main (ToM) perform far below a handwritten assembly kernel. To quantify the gap, we measured the throughput of the compute kernel produced by each approach over shapes relevant to LLAMA 405B. Throughput was derived from the latency of these dispatches as measured in Tracy. The results are tabulated below.
In the table, K/2 is the K extent in bytes of packed fp4 data, and K/32 is the number of scales along K (MXFP4 uses one scale per 32 elements). The last three columns are throughput; higher is better.

| M | N | K/2 | K/32 | asm kernel | ToM | experimental |
|-------|-------|-------|------|----------------|----------|--------------|
| 512 | 512 | 256 | 16 | 15.82662909 | 8.461856 | 12.831523 |
| 1024 | 512 | 256 | 16 | 32.26387692 | 16.07350 | 26.8435456 |
| 8192 | 512 | 256 | 16 | 240.73579373 | 165.6817 | 181.0694476 |
| 16384 | 512 | 256 | 16 | 457.88563923 | 287.0967 | 305.9093516 |
| 53248 | 512 | 256 | 16 | 924.38288216 | 541.8833 | 674.9664521 |
| 512 | 1024 | 8192 | 512 | 199.85422843 | 75.12591 | 71.5983681 |
| 1024 | 1024 | 8192 | 512 | 378.90780172 | 152.3274 | 463.5753095 |
| 8192 | 1024 | 8192 | 512 | 2829.09713716 | 1097.400 | 1251.954632 |
| 16384 | 1024 | 8192 | 512 | 3988.36196959 | 1956.440 | 2273.5408298 |
| 53248 | 1024 | 8192 | 512 | 3516.95857129 | 1810.577 | 2140.1038907 |
| 512 | 16384 | 8192 | 512 | 3094.44258863 | 1183.913 | 769.8626043 |
| 1024 | 16384 | 8192 | 512 | 4214.62598810 | 2113.372 | 2332.943547 |
| 8192 | 16384 | 8192 | 512 | 5200.83736511 | 2366.870 | 2673.3293607 |
| 16384 | 16384 | 8192 | 512 | 5208.63823157 | 2110.495 | 2543.9122511 |
| 53248 | 16384 | 8192 | 512 | 4055.99681196 | 2179.195 | 2469.2284563 |
| 512 | 53248 | 8192 | 512 | 3432.63362024 | 1693.769 | 1112.591597 |
| 1024 | 53248 | 8192 | 512 | 3365.53069058 | 1799.282 | 2197.4546110 |
| 8192 | 53248 | 8192 | 512 | 2971.20211716 | 1674.111 | 2028.4375840 |
| 16384 | 53248 | 8192 | 512 | 2963.43797902 | 1734.363 | 2035.0343922 |
| 53248 | 53248 | 8192 | 512 | 3054.71551623 | 1876.655 | 2138.6656562 |
| 512 | 16384 | 26624 | 1664 | 3094.44258863 | 1183.913 | 793.1444136 |
| 1024 | 16384 | 26624 | 1664 | 4967.48886548 | 2297.846 | 1952.7315358 |
| 8192 | 16384 | 26624 | 1664 | 4065.22006480 | 2000.249 | 2362.5201934 |
| 16384 | 16384 | 26624 | 1664 | 3679.96100111 | 2097.239 | 2379.9390850 |
| 53248 | 16384 | 26624 | 1664 | 3974.31085006 | 2251.034 | 2622.5220145 |

The geomean throughput of ToM is 49% of the asm kernel's, i.e. 51% lower.
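The ratios quoted in this document are geometric means over shapes. The following is a minimal sketch of that computation, assuming every shape is weighted equally and the metric is the geometric mean of per-shape throughput ratios (candidate / baseline); `geomeanRatio` is a hypothetical helper, not part of IREE.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: geometric mean of per-shape throughput ratios
// (candidate / baseline). Both vectors must list the same shapes in the
// same order. Averaging in log space avoids overflow on long products.
double geomeanRatio(const std::vector<double>& candidate,
                    const std::vector<double>& baseline) {
  assert(candidate.size() == baseline.size() && !candidate.empty());
  double logSum = 0.0;
  for (std::size_t i = 0; i < candidate.size(); ++i)
    logSum += std::log(candidate[i] / baseline[i]);
  return std::exp(logSum / static_cast<double>(candidate.size()));
}
```

The geometric mean is the usual way to average ratios because a 2x gain on one shape and a 2x loss on another cancel to 1x. Under the assumption above, passing the ToM column as `candidate` and the asm kernel column as `baseline` is the kind of computation behind the 49% figure for ToM.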
This ratio drops to 33% of the asm kernel when shapes whose M is not a multiple of a large power of 2 (such as M = 500 or 8100) are included (not shown here). Through several experiments, we found opportunities to raise this to 63% of the asm kernel in geomean throughput. These improvements are:

- Improvements to Virtual(Scaled)MMAAttr to allow for scale packing.
- Support for XOR swizzling to reduce bank conflicts.

## State of MXFP4 GEMM in IREE (2025-12-25)

### Context:

Given that IREE's current codegen performance for FP4 scaled-matmul GEMMs is 49% of handwritten assembly, we have identified several opportunities for performance gains:

- [XOR swizzling](https://github.com/iree-org/iree/issues/22256)
- [Mega-intrinsics](https://github.com/iree-org/iree/discussions/22820)
- Improved heuristics

To understand what kind of improvements these changes represent, we made an experimental, hacky implementation of them [in this PR](https://github.com/iree-org/iree/pull/22896).

### The main hacks:

#### XOR swizzling

To eliminate the large number of LDS bank conflicts that were occurring ([explained in this section](#bank-conflicts)), we add a swizzle hint during bufferization to the alloc ops for the fp4 inputs.
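For intuition about why an XOR swizzle removes these conflicts, the following is a minimal sketch of the classic XOR swizzle applied to flat fp4 element indices, together with a simplified bank-conflict count. It assumes an LDS with 32 banks of 4 bytes and reads the first two `XORShuffleAttr` arguments (256, 32) as a row width and an access width in fp4 elements; both are assumptions rather than IREE's documented semantics, and `xorSwizzle`, `maxBankConflict`, and `columnReads` are hypothetical helpers.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Simplified LDS model (assumption): 32 banks, each 4 bytes wide.
constexpr int64_t kNumBanks = 32;
constexpr int64_t kBankBytes = 4;

// Assumed reading of XORShuffleAttr(256, 32): row width and access width,
// both counted in fp4 elements (two fp4 values per byte).
constexpr int64_t kRowElems = 256;
constexpr int64_t kAccessElems = 32;
constexpr int64_t kGroupsPerRow = kRowElems / kAccessElems;  // 8 groups

// Hypothetical helper: classic XOR swizzle of a flat element index. The
// access-group index inside a row is XORed with (row mod groups-per-row),
// which permutes groups within each row and never crosses row boundaries.
int64_t xorSwizzle(int64_t flatElem) {
  assert(flatElem >= 0);
  int64_t row = flatElem / kRowElems;
  int64_t col = flatElem % kRowElems;
  int64_t group = col / kAccessElems;
  int64_t offsetInGroup = col % kAccessElems;
  int64_t swizzledGroup = group ^ (row % kGroupsPerRow);
  return row * kRowElems + swizzledGroup * kAccessElems + offsetInGroup;
}

// Hypothetical helper: given the starting fp4 element of each lane's
// kAccessElems-wide read, return the largest number of lanes touching any
// single bank (1 means conflict-free under this model).
int64_t maxBankConflict(const std::vector<int64_t>& laneStartElems) {
  std::map<int64_t, int64_t> lanesPerBank;
  const int64_t bytesPerAccess = kAccessElems / 2;  // 16 bytes per read
  for (int64_t start : laneStartElems) {
    int64_t firstByte = start / 2;
    for (int64_t b = firstByte; b < firstByte + bytesPerAccess; b += kBankBytes)
      ++lanesPerBank[(b / kBankBytes) % kNumBanks];
  }
  int64_t worst = 0;
  for (const auto& entry : lanesPerBank) worst = std::max(worst, entry.second);
  return worst;
}

// Hypothetical helper: starting elements for lanes 0..7 that read access
// group 0 of rows 0..7, with or without the swizzle applied.
std::vector<int64_t> columnReads(bool swizzle) {
  std::vector<int64_t> starts;
  for (int64_t row = 0; row < kGroupsPerRow; ++row) {
    int64_t elem = row * kRowElems;
    starts.push_back(swizzle ? xorSwizzle(elem) : elem);
  }
  return starts;
}
```

Under this model, eight lanes reading the first 32-element group of rows 0 to 7 all land on banks 0 to 3 without the swizzle (an 8-way conflict), because each 256-element row is exactly 128 bytes, one full pass over the banks. With the swizzle, row r reads group r instead, so the eight reads cover all 32 banks without conflicts.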
The corresponding change in the PR is:

```c++
// Apply the swizzle only to allocations of fp4 (f4E2M1FN) data.
if (isa<Float4E2M1FNType>(memRefType.getElementType())) {
  // Allocate a flat 1-D buffer in workgroup memory holding the same number
  // of elements as the original allocation.
  auto flatAllocType = MemRefType::get(
      ArrayRef<int64_t>{memRefType.getNumElements()},
      memRefType.getElementType(), AffineMap(), workgroupSpace);
  Value flatAlloc = memref::AllocOp::create(builder, loc, flatAllocType);
  // Attach an XOR-shuffle swizzle hint to the flat buffer.
  Value swizzled = iree_compiler::IREE::Codegen::SwizzleHintOp::create(
      builder, loc, flatAlloc,
      iree_compiler::IREE::Codegen::XORShuffleAttr::get(
          builder.getContext(), 256, 32, int64_t(), int64_t()));
  // Expand the swizzled 1-D buffer back to the original shape so users of
  // the allocation still see that shape.
  ReassociationIndices reassoc =
      llvm::to_vector(llvm::seq(allocType.getRank()));
  Value expanded = memref::ExpandShapeOp::create(
      builder, loc, allocType.getShape(), swizzled, {reassoc});
  return expanded;
}
```

Note that a more principled way of doing this is explained [here](https://github.com/iree-org/iree/issues/22256#issuecomment-3423389275).

#### Packing scales

In ToM, the scales are loaded from LDS one byte at a time, whereas the handwritten kernel loads them four bytes at a time. This suggested that ToM could benefit from some form of scale packing, so, as a very rough approximation, we changed the [MMA single subgroup layout]() to do something _similar_ to packing scales:

```c++
case kScaledMMAOperandLhsScale:
  // element={1, 4}: each thread holds a 1x4 block of scales.
  return {/*outer=*/{1, 1}, /*thread=*/{16, 1}, /*tstrides=*/{1, 16},
          /*element=*/{1, 4}};
case kScaledMMAOperandRhsScale:
  // element={4, 1}: the transposed 4x1 block for the RHS scales.
  return {/*outer=*/{1, 1}, /*thread=*/{1, 16}, /*tstrides=*/{16, 1},
          /*element=*/{4, 1}};
```

Note that a more robust approach will involve modifying ScaledMMAIntrinsic to add a notion of repeating the underlying intrinsic, which will also generalize VirtualMMAAttr.
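To make the one-byte versus four-byte distinction concrete, here is a minimal sketch of what packing scales buys: four one-byte E8M0 scales stored contiguously can be moved as one 32-bit word instead of four separate byte loads. It assumes the scale tensors are laid out as M x K/32 (LHS) and K/32 x N (RHS), so the `{1, 4}` and `{4, 1}` element tiles above each give a thread four scales along K, and it assumes little-endian byte order. `packScales`, `unpackScale`, and `decodeE8M0` are hypothetical helpers, not IREE code.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper: pack four one-byte E8M0 scales into one 32-bit word
// so they can be moved with a single 4-byte load. Scale i occupies bits
// [8*i, 8*i + 7], i.e. byte i in little-endian memory order.
uint32_t packScales(const std::array<uint8_t, 4>& scales) {
  uint32_t word = 0;
  for (int i = 0; i < 4; ++i)
    word |= static_cast<uint32_t>(scales[i]) << (8 * i);
  return word;
}

// Hypothetical helper: extract scale i (0..3) from a packed word.
uint8_t unpackScale(uint32_t word, int i) {
  assert(i >= 0 && i < 4);
  return static_cast<uint8_t>((word >> (8 * i)) & 0xFFu);
}

// E8M0 (OCP MX) encodes 2^(e - 127); e = 0xFF is NaN and is not handled here.
double decodeE8M0(uint8_t e) {
  return std::ldexp(1.0, static_cast<int>(e) - 127);
}
```

For example, packing the scale bytes `0x7F, 0x80, 0x01, 0xFE` yields the word `0xFE01807F`, and byte `0x7F` decodes to a scale of 1.0.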
#### Heuristics (copying ASM kernel heuristics)

Our tiling heuristics had not been tuned for MXFP4 GEMMs until this point, so they could also have been part of the problem. We therefore manually changed the seeds used by the tiling heuristics for scaled GEMMs, bringing them closer to the asm kernel:

```c++
if (scaled) {
  seeds.bestMNTileCountPerSubgroup = 64;
  seeds.bestKElementCountPerSubgroup = 256;
}
```

This change does two things. First, it sets the output tile per subgroup to 128x128 (64 intrinsic tiles of 16x16). Second, it sets the number of K elements processed per subgroup to 256, meaning two iterations of the scaled_mfma intrinsic in the K loop.

### Experiment

With the three hacks listed above, we compared the resulting experimental kernel against the handwritten assembly kernel to see what impact these changes could have if implemented more robustly within IREE. We swept 4 configurations across 50 shapes relevant to LLAMA 405B. Throughput was calculated by measuring the latency of the compute kernel alone and dividing the GEMM's floating-point operation count by that latency.

**Configurations tested**:

1. ToM (Top of Main): none of the hacks enabled
2. xor_swizzle: swizzling of fp4 inputs enabled
3. experimental_no_swizzle: heuristic changes and scale packing enabled
4. experimental_codegen: all hacks enabled

These configurations were chosen specifically to thoroughly test the importance of XOR swizzling.
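A minimal sketch of the latency-to-throughput conversion, assuming the conventional count of 2·M·N·K floating-point operations per GEMM (one multiply and one add per multiply-accumulate) and reporting TFLOP/s. The document does not state its units, so both the operation count and the unit are assumptions, and `gemmTflops` and `kFromPackedBytes` are hypothetical helpers.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: throughput in TFLOP/s of an M x N x K GEMM whose
// compute dispatch took `latencyUs` microseconds, counting 2 * M * N * K
// operations (one multiply and one add per multiply-accumulate).
double gemmTflops(int64_t m, int64_t n, int64_t k, double latencyUs) {
  assert(latencyUs > 0.0);
  double ops = 2.0 * static_cast<double>(m) * static_cast<double>(n) *
               static_cast<double>(k);
  double seconds = latencyUs * 1e-6;
  return ops / seconds / 1e12;
}

// The shape lists record K/2 (bytes of packed fp4), so recover K first.
int64_t kFromPackedBytes(int64_t kHalf) { return 2 * kHalf; }
```

For example, the 16384 x 16384 shape with K/2 = 8192 (K = 16384) amounts to 2 · 16384³ ≈ 8.8 × 10¹² operations under this convention.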
**Shapes tested**:

```
M,N,K/2,K/32
512,512,256,16
500,512,256,16
1024,512,256,16
1000,512,256,16
8192,512,256,16
8100,512,256,16
16384,512,256,16
16300,512,256,16
53248,512,256,16
53200,512,256,16
512,1024,8192,512
500,1024,8192,512
1024,1024,8192,512
1000,1024,8192,512
8192,1024,8192,512
8100,1024,8192,512
16384,1024,8192,512
16300,1024,8192,512
53248,1024,8192,512
53200,1024,8192,512
512,16384,8192,512
500,16384,8192,512
1024,16384,8192,512
1000,16384,8192,512
8192,16384,8192,512
8100,16384,8192,512
16384,16384,8192,512
16300,16384,8192,512
53248,16384,8192,512
53200,16384,8192,512
512,53248,8192,512
500,53248,8192,512
1024,53248,8192,512
1000,53248,8192,512
8192,53248,8192,512
8100,53248,8192,512
16384,53248,8192,512
16300,53248,8192,512
53248,53248,8192,512
53200,53248,8192,512
512,16384,26624,1664
500,16384,26624,1664
1024,16384,26624,1664
1000,16384,26624,1664
8192,16384,26624,1664
8100,16384,26624,1664
16384,16384,26624,1664
16300,16384,26624,1664
53248,16384,26624,1664
```

### Results

#### Bank conflicts

Using the rocprofv3 counter `SQ_LDS_BANK_CONFLICT`, we measured the number of cycles lost to LDS bank conflicts. The following graph shows the bank conflicts observed at Top of Main across all measured shapes.

![image](https://hackmd.io/_uploads/By2A4mtGWx.png)

As shown above, the number of cycles lost to LDS bank conflicts ranges from tens of thousands to billions. After applying XOR swizzling, **there are no bank conflicts for any shape.**

#### Performance: experimental_codegen vs ToM

With all of the hacks applied, we found a modest improvement: about 5% in geomean throughput (1.0571x) across all shapes.

![image](https://hackmd.io/_uploads/BJSoFzFMZg.png)

The performance of the hacky version suffers a lot for M values that are not multiples of a large power of 2 (500, 1000, 8100, 16300, and 53200).
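Strictly speaking, 53248 (= 13 · 4096) is not a power of two, yet the well-performing shapes in the first table include M = 53248, so the split between friendly and unfriendly shapes is better described by alignment. The following is a minimal sketch of such a pruning filter, assuming the criterion is that M is a multiple of 512; the threshold and the names `GemmShape`, `isMultipleOf`, and `pruneShapes` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical record for one row of the shape list (M,N,K/2,K/32).
struct GemmShape {
  int64_t m, n, kHalf, kDiv32;
};

// True when `value` is a multiple of `alignment` (alignment must be > 0).
bool isMultipleOf(int64_t value, int64_t alignment) {
  assert(alignment > 0);
  return value % alignment == 0;
}

// Keep the "friendly" shapes: those whose M is a multiple of `alignment`.
// 512 is an assumed threshold that separates the friendly M values from the
// unfriendly ones in the shape list.
std::vector<GemmShape> pruneShapes(const std::vector<GemmShape>& shapes,
                                   int64_t alignment = 512) {
  std::vector<GemmShape> kept;
  for (const GemmShape& s : shapes)
    if (isMultipleOf(s.m, alignment)) kept.push_back(s);
  return kept;
}
```

Applied to the shape list, a filter of this form keeps the M values 512, 1024, 8192, 16384, and 53248 and drops 500, 1000, 8100, 16300, and 53200.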
Excluding the shapes with these M values from our results, experimental_codegen reaches a 19.27% improvement over Top of Main (this is significant!).

![image](https://hackmd.io/_uploads/HkqfoMKz-l.png)

#### Performance: experimental_codegen vs experimental_no_swizzle

When we remove XOR swizzling from the set of applied hacks, performance drops by about 9% (0.9132x, which matches the ratio of the all-shapes geomeans in the summary, 0.9653 / 1.0571).

![image](https://hackmd.io/_uploads/HJEHTzFM-x.png)

#### Performance: ToM vs xor_swizzle

Applying the swizzle alone to Top of Main yields a 5.04% geomean improvement on this smaller (pruned) subset of shapes.

![image](https://hackmd.io/_uploads/Sk0ByQYGZl.png)

#### Performance: Unpruned

For those curious what the performance looks like when the unfriendly shapes are included:

![image](https://hackmd.io/_uploads/B1wPlmYG-x.png)

### Summary

#### Unpruned shapes

| Method | Geomean | Change vs ToM |
|--------|---------|---------------|
| ToM | 1x | +0.00% |
| xor_swizzle | 0.9904x | -0.96% |
| experimental_codegen | 1.0571x | +5.71% |
| experimental_no_swizzle | 0.9653x | -3.47% |

#### Pruned shapes

| Method | Geomean | Change vs ToM |
|--------|---------|---------------|
| ToM | 1x | +0.00% |
| xor_swizzle | 1.0504x | +5.04% |
| experimental_codegen | 1.1927x | +19.27% |
| experimental_no_swizzle | 1.0139x | +1.39% |

#### Takeaways

- With the unfriendly shapes included, it is very hard to argue that our hacks help (we are looking into this).
- We expect that, with the corresponding changes at top of main, the improvements observed from our hacks (19% over ToM on friendly shapes) can be realized in IREE.
- Almost all of this gain depends on XOR swizzling of the inputs. Removing LDS bank conflicts provides a 5% improvement on its own right now, and it is required for the further gains.
- **Swizzling is important!**