# Install vLLM on DGX Spark

1. Install uv

```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```

2. Create environment

```bash
sudo apt install python3-dev
uv venv .vllm --python 3.12
source .vllm/bin/activate
```

3. Install vLLM

```bash
uv pip install -U vllm --torch-backend=auto --extra-index-url https://wheels.vllm.ai/nightly/cu130
uv pip install --prerelease=allow --force-reinstall triton --index-url https://download.pytorch.org/whl/test/cu132
```

4. Export variables

```bash
export TORCH_CUDA_ARCH_LIST=12.1a
```

5. Clean memory

```bash
# Flush dirty pages first so drop_caches can free them too; GB10 shares
# system memory between CPU and GPU, so page cache competes with vLLM.
sync
sudo sysctl -w vm.drop_caches=3
```

# Wiring `flash-attn-4-sm120` into vLLM on SM120 / SM121

How to make vLLM use the [SecondNatureComputing/flash-attn-4-sm120](https://huggingface.co/SecondNatureComputing/flash-attn-4-sm120) Hugging Face kernel for the FA4 path on consumer Blackwell GPUs (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10 / SM121a).

Tested on:

- DGX Spark (GB10), compute capability **12.1**
- vLLM installed in a venv at `~/Projects/vllm/.vllm`
- Python 3.12

> **Heads up before you start.** On SM120/121 this kernel is not a free speedup. The HF README's own benchmarks show FA4 is **4–10% slower** than vLLM's bundled FA2 on realistic Qwen3 prefill shapes, and **~1.7× slower** on very short sequences (S = 128). Use it only when you specifically need FA4-only features (paged KV with FA4, `score_mod`, block sparse, dropout in attention). For pure throughput, stay on FA2 and stop reading here.

---

## 1. Why a shim is needed

On CUDA, vLLM does **not** import the upstream `flash_attn` package. It uses its own bundled fork:

```python
# vllm/v1/attention/backends/fa_utils.py
if current_platform.is_cuda():
    from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
```

`vllm.vllm_flash_attn` calls compiled C++ extensions (`_vllm_fa2_C`, `_vllm_fa3_C`) and, for FA4, its own CuTe DSL implementation under `vllm/vllm_flash_attn/cute/`.
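
The effect of the two import paths can be reproduced with stand-in modules. This is a toy model, not vLLM code: `toy_vllm_flash_attn` is a hypothetical name standing in for `vllm.vllm_flash_attn`, and `fwd` / `hf_fwd` stand in for the bundled and HF entry points. It shows that an alias registered under the upstream name is never consulted, and that patching the attribute on the module that *is* imported only reaches call sites that look the name up at call time.

```python
import sys
import types

# Stand-in for vLLM's bundled fork (hypothetical module name).
bundled = types.ModuleType("toy_vllm_flash_attn")
bundled.fwd = lambda: "bundled FA"
sys.modules["toy_vllm_flash_attn"] = bundled

# A call site that binds the function by name at import time.
from toy_vllm_flash_attn import fwd as captured_fwd


def lazy_call_site():
    # A call site that imports at call time, like vLLM's FA4 branch.
    from toy_vllm_flash_attn import fwd
    return fwd()


def hf_fwd():
    # Stand-in for the HF kernel's entry point.
    return "HF kernel"


# 1) Alias the *upstream* package name. Nothing imports it, so nothing changes.
sys.modules["flash_attn"] = types.ModuleType("flash_attn")
sys.modules["flash_attn"].fwd = hf_fwd
print(lazy_call_site(), "|", captured_fwd())  # bundled FA | bundled FA

# 2) Patch the attribute on the module the call sites actually import from.
bundled.fwd = hf_fwd
print(lazy_call_site(), "|", captured_fwd())  # HF kernel | bundled FA
```

Names bound before a patch keep pointing at the original object, which is why the patches need to land as soon as the target modules finish loading.
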
So aliasing `sys.modules['flash_attn']` to the HF kernel **does nothing** for the main attention path.

vLLM already ships SM120 CuTe kernels (`vllm/vllm_flash_attn/cute/flash_fwd_sm120.py`), but in the version most people have installed they are **gated**:

```python
# vllm/vllm_flash_attn/cute/interface.py
assert page_table is None, "Paged KV not supported on SM 12.0 in this PR"
assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR"
```

vLLM always serves with paged KV, so the bundled FA4 SM120 path can't run during serving. Net effect: on SM12x, vLLM falls back to FA2.

The HF kernel bundles two PRs that vLLM's bundled version is missing:

- **#2348**: SM120 kernel-level paged KV cache support
- **#2336**: SM120 split-KV (FlashDecoding)

So the only way to use FA4 with paged KV on SM12x today is to redirect vLLM's FA4 call site (`vllm.vllm_flash_attn.cute.interface._flash_attn_fwd`) to the HF kernel. That's what this shim does.

---

## 2. Prerequisites

```bash
# Inside your vLLM venv
uv pip install -U "kernels>=0.4" "nvidia-cutlass-dsl>=4.4.1" einops apache-tvm-ffi
```

Confirm hardware:

```bash
nvidia-smi --query-gpu=name,compute_cap --format=csv
# Expect compute_cap >= 12.0 (e.g. 12.1 on DGX Spark GB10)
```

The CUDA Toolkit must be **12.8 or newer** (the FA4 baseline).

---

## 3. Pre-download the kernel

```bash
python -c "from kernels import get_kernel; get_kernel('SecondNatureComputing/flash-attn-4-sm120')"
```

Sanity check:

```bash
python - <<'PY'
import torch
from kernels import get_kernel

fa4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
# (batch, seqlen, heads, head_dim); q, k and v share one tensor for this smoke test
q = k = v = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
out, _ = fa4.flash_attn_func(q, k, v, causal=True)
print("OK", out.shape, out.dtype)
PY
```

---

## 4. Install the shim

The shim does three things, lazily, when the relevant vLLM modules first load:

1. Patches `vllm.vllm_flash_attn.flash_attn_interface._is_fa4_supported` to accept SM 9.x / 10.x / 11.x / 12.x.
2. Wraps `vllm.vllm_flash_attn.cute.interface._flash_attn_fwd` so that on SM12x the call dispatches to the HF kernel's `_flash_attn_fwd` (which supports paged KV).
3. Wraps `vllm.v1.attention.backends.fa_utils.get_flash_attn_version` so that on SM12x with `head_dim ≤ 128` (and no ALiBi) it returns `4` instead of the default `2`.

The wiring is done via a `sys.meta_path` finder that applies each patch the moment its vLLM module finishes executing, before any other module can capture `get_flash_attn_version` by name.

> **Note on `sitecustomize.py`.** Ubuntu ships `/usr/lib/python3.12/sitecustomize.py`, which takes precedence over a venv-local one, so we use a `.pth` file (`zzz_*.pth`) instead. `site.py` executes `import …` lines in `.pth` files at interpreter startup, regardless of the OS-level `sitecustomize`.

### 4.1 Create the shim file

```bash
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"

# Remove any earlier experimental shim
rm -f "$SP/zzz_fa4_sm120_shim.pth" "$SP/fa4_sm120_shim.py" "$SP/sitecustomize.py"

cat > "$SP/fa4_sm120_shim.py" <<'PY'
"""
Force vLLM on SM12x (RTX 5090 / RTX PRO 6000 / DGX Spark) to use the
SecondNatureComputing/flash-attn-4-sm120 HF kernel for the FA4 path.

Disable by setting env var FA4_SM120_SHIM=0 before launching.
"""
from __future__ import annotations

import os
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder

if os.environ.get("FA4_SM120_SHIM", "1") != "0":
    _HF_KERNEL = None

    def _hf_kernel():
        # Load the HF kernel on first use and cache it.
        global _HF_KERNEL
        if _HF_KERNEL is None:
            from kernels import get_kernel
            _HF_KERNEL = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
        return _HF_KERNEL

    _PATCHED: set[str] = set()

    def _is_sm12x() -> bool:
        try:
            import torch
            if not torch.cuda.is_available():
                return False
            major, _ = torch.cuda.get_device_capability()
            return major == 12
        except Exception:
            return False

    def _patch_fa_iface(mod):
        # is_fa_version_supported() passes a `device` argument through,
        # so the replacement keeps an optional `device` parameter.
        def _is_fa4_supported_patched(device=None):
            if not getattr(mod, "FA4_AVAILABLE", False):
                return False, getattr(mod, "FA4_UNAVAILABLE_REASON", "FA4 unavailable")
            try:
                import torch
                major, _ = torch.cuda.get_device_capability(device)
            except Exception:
                return False, "no CUDA device"
            if major in (9, 10, 11, 12):
                return True, None
            return False, f"FA4 not supported on capability {major}.x"

        mod._is_fa4_supported = _is_fa4_supported_patched

    def _patch_cute_iface(mod):
        orig = mod._flash_attn_fwd

        def _dispatch(*args, **kwargs):
            # SM12x: HF kernel (has paged KV); anything else: vLLM's own FA4.
            if _is_sm12x():
                return _hf_kernel().interface._flash_attn_fwd(*args, **kwargs)
            return orig(*args, **kwargs)

        mod._flash_attn_fwd = _dispatch

    def _patch_fa_utils(mod):
        import functools

        orig = mod.get_flash_attn_version

        @functools.wraps(orig)
        def patched(requires_alibi: bool = False, head_size: int | None = None,
                    head_size_v: int | None = None, has_sinks: bool = False):
            if not _is_sm12x():
                return orig(requires_alibi=requires_alibi, head_size=head_size,
                            head_size_v=head_size_v, has_sinks=has_sinks)
            if requires_alibi:
                return 2
            # The SM120 kernel cannot fit head_dim > 128 in shared memory.
            if any(h is not None and h > 128 for h in (head_size, head_size_v)):
                return 2
            try:
                from vllm.vllm_flash_attn.flash_attn_interface import is_fa_version_supported
                if is_fa_version_supported(4):
                    return 4
            except Exception:
                pass
            return 2

        mod.get_flash_attn_version = patched

    # Fully qualified module name -> function that patches it in place.
    _DISPATCH = {
        "vllm.vllm_flash_attn.flash_attn_interface": _patch_fa_iface,
        "vllm.vllm_flash_attn.cute.interface": _patch_cute_iface,
        "vllm.v1.attention.backends.fa_utils": _patch_fa_utils,
    }

    def _try_patch(name: str):
        if name in _PATCHED or name not in _DISPATCH:
            return
        mod = sys.modules.get(name)
        if mod is None:
            return
        try:
            _DISPATCH[name](mod)
            _PATCHED.add(name)
        except Exception as e:
            warnings.warn(f"[fa4_sm120_shim] failed to patch {name}: {e!r}")

    class _WrappedLoader(Loader):
        """Delegates to the real loader, then patches the freshly executed module."""

        def __init__(self, real, name):
            self._real = real
            self._name = name

        def create_module(self, spec):
            if hasattr(self._real, "create_module"):
                return self._real.create_module(spec)
            return None

        def exec_module(self, module):
            self._real.exec_module(module)
            _try_patch(self._name)

    class _PatchingFinder(MetaPathFinder):
        """Wraps the loader of each target module on its first import."""

        def find_spec(self, name, path=None, target=None):
            if name not in _DISPATCH or name in _PATCHED:
                return None
            for finder in list(sys.meta_path):
                if finder is self or not hasattr(finder, "find_spec"):
                    continue
                spec = finder.find_spec(name, path, target)
                if spec is not None and spec.loader is not None:
                    spec.loader = _WrappedLoader(spec.loader, name)
                    return spec
            return None

    if not getattr(sys, "_fa4_sm120_shim_installed", False):
        sys.meta_path.insert(0, _PatchingFinder())
        try:
            sys._fa4_sm120_shim_installed = True  # type: ignore[attr-defined]
        except Exception:
            pass

    # Patch target modules that were already imported before the shim loaded.
    for _n in list(_DISPATCH):
        _try_patch(_n)
PY
```

### 4.2 Auto-load via a `.pth` file

```bash
cat > "$SP/zzz_fa4_sm120_shim.pth" <<'PTH'
import fa4_sm120_shim
PTH
```

`site.py` executes lines starting with `import` in `.pth` files at interpreter startup, processing the files in alphabetical order. The `zzz_` prefix sorts ours after the other `.pth` files in site-packages.

---

## 5. Verify

### 5.1 Hook is installed

```bash
python - <<'PY'
import sys
import fa4_sm120_shim
print("hook installed:", any("PatchingFinder" in type(f).__name__ for f in sys.meta_path))
PY
```

Expected: `hook installed: True`.
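
The finder/loader pattern from §4.1 can be exercised without vLLM or a GPU. This is a minimal sketch under stated assumptions: `toy_target`, `version()`, `PostExecLoader` and `PatchingFinder` are hypothetical names that mirror the shim's `_WrappedLoader` and `_PatchingFinder`, and the target module is written to a temporary directory so the example is self-contained. It shows the property the shim relies on: the patch runs right after the module body executes, before the importer gets the module back.

```python
import importlib
import sys
import tempfile
from importlib.abc import Loader, MetaPathFinder
from pathlib import Path

# Write a throwaway target module (hypothetical) so the sketch is self-contained.
tmp = Path(tempfile.mkdtemp())
(tmp / "toy_target.py").write_text("def version():\n    return 2\n")
sys.path.insert(0, str(tmp))
importlib.invalidate_caches()


def patch_toy_target(mod):
    orig = mod.version

    def patched():
        # Same shape as the shim's override: report 4 where the original says 2.
        v = orig()
        return 4 if v == 2 else v

    mod.version = patched


PATCHES = {"toy_target": patch_toy_target}


class PostExecLoader(Loader):
    """Runs the real loader, then patches the module before import returns."""

    def __init__(self, real, name):
        self._real = real
        self._name = name

    def create_module(self, spec):
        return self._real.create_module(spec)

    def exec_module(self, module):
        self._real.exec_module(module)  # the module body runs here
        PATCHES[self._name](module)     # then the patch is applied


class PatchingFinder(MetaPathFinder):
    """Gets the real spec from the other finders and swaps in PostExecLoader."""

    def find_spec(self, name, path=None, target=None):
        if name not in PATCHES:
            return None
        for finder in sys.meta_path:
            if finder is self or not hasattr(finder, "find_spec"):
                continue
            spec = finder.find_spec(name, path, target)
            if spec is not None and spec.loader is not None:
                spec.loader = PostExecLoader(spec.loader, name)
                return spec
        return None


sys.meta_path.insert(0, PatchingFinder())
toy_target = importlib.import_module("toy_target")
print(toy_target.version())  # 4
```

A second import of `toy_target` returns the cached, already patched module from `sys.modules` without consulting any finder. For the same reason, the shim has to patch modules that were imported before it loaded by looking them up directly.
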
### 5.2 Patches applied to vLLM

```bash
python - <<'PY'
import vllm.vllm_flash_attn.flash_attn_interface as fai
import vllm.vllm_flash_attn.cute.interface as cute
import vllm.v1.attention.backends.fa_utils as fau

print("FA4_AVAILABLE:", fai.FA4_AVAILABLE)
print("_is_fa4_supported():", fai._is_fa4_supported())
print("get_flash_attn_version(head_size=128):", fau.get_flash_attn_version(head_size=128))
print("get_flash_attn_version(head_size=256):", fau.get_flash_attn_version(head_size=256))
print("_flash_attn_fwd qualname:", cute._flash_attn_fwd.__qualname__)
PY
```

Expected:

```
FA4_AVAILABLE: True
_is_fa4_supported(): (True, None)
get_flash_attn_version(head_size=128): 4
get_flash_attn_version(head_size=256): 2
_flash_attn_fwd qualname: _patch_cute_iface.<locals>._dispatch
```

If `_flash_attn_fwd qualname` shows plain `_flash_attn_fwd` (the original, unwrapped function), the hook didn't run for that module; see §8, "Troubleshooting".

### 5.3 End-to-end FA4 fwd through vLLM's call site

This is closer to what vLLM actually invokes:

```bash
python - <<'PY'
import torch
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd

B, S, Hq, Hkv, D = 1, 1024, 16, 8, 128
# Varlen layout: (total_tokens, heads, head_dim); GQA with 16 query / 8 KV heads
q = torch.randn(B * S, Hq, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B * S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B * S, Hkv, D, device="cuda", dtype=torch.bfloat16)
# Cumulative sequence lengths for a single sequence (valid because B == 1)
cu = torch.tensor([0, B * S], device="cuda", dtype=torch.int32)

out, lse = _flash_attn_fwd(
    q, k, v,
    cu_seqlens_q=cu, cu_seqlens_k=cu,
    max_seqlen_q=B * S, max_seqlen_k=B * S,
    softmax_scale=D ** -0.5,
    causal=True,
    return_lse=True,
)
print("OK", out.shape, out.dtype, "lse:", None if lse is None else lse.shape)
PY
```

If this completes without error on SM12x, the HF kernel is being driven through vLLM's FA4 entry point.

---

## 6. Launch vLLM

Run with debug logging the first time so you can verify the routing:

```bash
VLLM_LOGGING_LEVEL=DEBUG vllm serve Qwen/Qwen3.5-27B \
  --speculative-config '{"method": "dflash", "model": "z-lab/Qwen3.5-27B-DFlash", "num_speculative_tokens": 15}' \
  --attention-backend flash_attn \
  --gpu-memory-utilization 0.85 \
  --max-model-len 65536 \
  --load-format fastsafetensors \
  --max-num-batched-tokens 32768 2>&1 | tee vllm.log
```

In another terminal:

```bash
grep -iE "fa version|flash_attn|attention backend|sm12|fa4" vllm.log | head -50
```

Look for log lines mentioning `fa_version=4` or the chosen attention backend.

---

## 7. Disable the shim

Without removing files:

```bash
FA4_SM120_SHIM=0 vllm serve ...
```

Permanently:

```bash
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
rm -f "$SP/fa4_sm120_shim.py" "$SP/zzz_fa4_sm120_shim.pth"
```

---

## 8. Troubleshooting

### `_flash_attn_fwd qualname` still shows the original

Most likely the module was imported before the finder was installed (for example, by another `.pth` file or startup hook that runs earlier), or the process was started by an interpreter that never processes this venv's `.pth` files. Make sure:

- You created `zzz_fa4_sm120_shim.pth` (the `zzz_` prefix matters for ordering).
- You don't have a competing `.pth` that imports vLLM modules earlier.
- The shim file is in the *venv's* site-packages, not the system one.

Sanity check:

```bash
python -c "import sys; print('\n'.join(p for p in sys.path if 'site-packages' in p))"
ls "$VIRTUAL_ENV/lib/python3.12/site-packages/" | grep fa4_sm120
```

### `assert is_fa_version_supported(4)` fires somewhere

The shim patches `_is_fa4_supported`, not `is_fa_version_supported`, so the latter picks up the patch as long as it looks `_is_fa4_supported` up in its module at call time. If the assert still fires, some code path most likely captured `_is_fa4_supported` itself (or a result computed from it) before the patch ran. Reproduce with the verifier in §5.2 to pinpoint the module, then open an issue or extend `_DISPATCH` to patch the offending module as well.

### Tensor-shape / stride asserts inside the HF kernel

Possible signature drift between vLLM's call site and the HF kernel's `_flash_attn_fwd`.
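
Drift of this kind can also be checked programmatically. A minimal sketch, assuming the HF entry point is a plain Python function that `inspect.signature` can introspect: the helper `filter_kwargs_for` and the toy `old_fwd` below are hypothetical and are not part of the shim or of either kernel. The helper splits a kwargs dict into the names the callee accepts and the names it would reject.

```python
import inspect
from typing import Any, Callable


def filter_kwargs_for(fn: Callable[..., Any], kwargs: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
    """Split kwargs into those `fn` accepts by name and those it would reject."""
    params = inspect.signature(fn).parameters.values()
    # A **kwargs parameter accepts any name, so nothing has to be dropped.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs), []
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    kept = {k: v for k, v in kwargs.items() if k in accepted}
    dropped = sorted(k for k in kwargs if k not in accepted)
    return kept, dropped


# Toy stand-in (hypothetical) for an entry point that lacks `learnable_sink`.
def old_fwd(q, k, v, *, causal=False, softmax_scale=None, return_lse=False):
    return q


kept, dropped = filter_kwargs_for(
    old_fwd, {"causal": True, "return_lse": True, "learnable_sink": None}
)
print(sorted(kept), dropped)  # ['causal', 'return_lse'] ['learnable_sink']
```

Dropping a semantic argument such as `page_table` or `seqused_k` would silently change results, so anything that lands in `dropped` is better treated as drift to investigate than as something to discard blindly.
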
Compare:

- vLLM call: `vllm/vllm_flash_attn/flash_attn_interface.py` around `elif fa_version == 4:`
- HF `_flash_attn_fwd`: `~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/snapshots/<hash>/build/torch-cuda/interface.py`

If vLLM passes a kwarg that the HF kernel doesn't accept, drop it inside `_dispatch` before forwarding. As of HF kernel v0.1.0, every kwarg vLLM passes (`cu_seqlens_q/k`, `seqused_k`, `max_seqlen_q/k`, `page_table`, `softmax_scale`, `causal`, `softcap`, `window_size_left/right`, `num_splits`, `return_lse`, `out`, `learnable_sink`) is accepted.

### `head_dim > 128` model

The SM120 kernel cannot fit `head_dim > 128` in 99 KB of shared memory (SMEM). The shim's `get_flash_attn_version` already returns 2 in that case, so vLLM falls back to FA2. If the model has `head_dim == 256`, **don't bother with this shim**: neither this kernel nor vLLM's bundled FA4 will run, and vLLM's FA2 already handles it.

Check head_dim:

```bash
python - <<'PY'
import json
from huggingface_hub import hf_hub_download

p = hf_hub_download("Qwen/Qwen3.5-27B", "config.json")
with open(p) as f:
    c = json.load(f)
# Multimodal configs nest the language-model settings under "text_config".
c = c.get("text_config", c)
print("head_dim:", c.get("head_dim") or (c["hidden_size"] // c["num_attention_heads"]))
PY
```

### `Paged KV not supported on SM 12.0 in this PR`

This assert lives in vLLM's *bundled* `cute/interface.py`. If you see it, the call never reached our wrapper. Re-run the §5.2 verifier; if `_flash_attn_fwd qualname` doesn't say `_dispatch`, the patch is not in effect.

---

## 9. How it works (under the hood)

```
.vllm/lib/python3.12/site-packages/
├── zzz_fa4_sm120_shim.pth          ← processed by site.py at startup; runs `import fa4_sm120_shim`
├── fa4_sm120_shim.py               ← installs sys.meta_path finder, patches three vLLM modules
│
└── vllm/
    ├── v1/attention/backends/
    │   └── fa_utils.py             ← get_flash_attn_version(): patched to return 4 on SM12x
    └── vllm_flash_attn/
        ├── flash_attn_interface.py ← _is_fa4_supported(): patched to accept SM 9.x-12.x
        │   └── flash_attn_varlen_func: dispatches on fa_version, calls _flash_attn_fwd if 4
        └── cute/
            └── interface.py        ← _flash_attn_fwd(): patched to dispatch to the HF kernel on SM12x

~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/
└── snapshots/<hash>/build/torch-cuda/
    └── interface.py                ← real _flash_attn_fwd with paged KV (PR #2348)
```

Flow at request time:

1. vLLM's FlashAttention backend (`vllm/v1/attention/backends/flash_attn.py`) calls `get_flash_attn_version(head_size=128)`.
2. The **patched** version returns `4` (instead of `2`).
3. The backend calls `flash_attn_varlen_func(..., fa_version=4)`.
4. Inside, the FA4 branch does `from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd`.
5. The **patched** `_flash_attn_fwd` is the `_dispatch` wrapper. On SM12x it forwards to `hf_kernel.interface._flash_attn_fwd`; otherwise it calls the original.
6. The HF kernel runs and returns `(out, softmax_lse)`.

---

## 10. Caveats

- The shim assumes vLLM's internal API names (`_flash_attn_fwd`, `is_fa_version_supported`, `get_flash_attn_version`). If you upgrade vLLM and these change, the verifier in §5.2 will tell you immediately.
- On SM12x, the HF kernel's `interface.py` clamps `num_splits` to 1 (no SplitKV). Decode workloads use a single split anyway, but if you've tuned vLLM with `num_splits > 1` it will be silently ignored.
- Dropout in this kernel falls back to smaller tiles to avoid register spills. The throughput cost is small but real.
- The backward pass on SM12x has its own restrictions (no block sparse, no `score_mod`, no `mask_mod`, no deterministic mode). Inference is forward-only, so this doesn't affect serving.

---

## License

The HF kernel is BSD-3-Clause (inherited from Dao-AILab/flash-attention). The shim above is provided as-is, with no warranty.