# Install vLLM on DGX Spark
1. Install uv
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. Create environment
```bash
sudo apt install python3-dev
uv venv .vllm --python 3.12
source .vllm/bin/activate
```
3. Install vLLM
```bash
uv pip install -U vllm --torch-backend=auto --extra-index-url https://wheels.vllm.ai/nightly/cu130
uv pip install --prerelease=allow --force-reinstall triton --index-url https://download.pytorch.org/whl/test/cu132
```
4. Export variables
```bash
export TORCH_CUDA_ARCH_LIST=12.1a
```
5. Drop filesystem caches (optional; frees page cache before loading large models)
```bash
sync && sudo sysctl -w vm.drop_caches=3
```
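Optional sanity check before moving on, to confirm the fresh install sees the GPU:
```bash
python - <<'PY'
import torch, vllm
print("vllm:", vllm.__version__)
print("torch:", torch.__version__, "| built for CUDA:", torch.version.cuda)
print("capability:", torch.cuda.get_device_capability())  # expect (12, 1) on GB10
PY
```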
# Wiring `flash-attn-4-sm120` into vLLM on SM120 / SM121
How to make vLLM use the [SecondNatureComputing/flash-attn-4-sm120](https://huggingface.co/SecondNatureComputing/flash-attn-4-sm120) Hugging Face kernel for the FA4 path on consumer Blackwell GPUs (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10 / SM121a).
Tested on:
- DGX Spark (GB10), compute capability **12.1**
- vLLM installed in a venv at `~/Projects/vllm/.vllm`
- Python 3.12
> **Heads up before you start.** On SM120/121 this kernel is not a free speedup. The HF README's own benchmarks show FA4 is **4–10% slower** than vLLM's bundled FA2 on realistic Qwen3 prefill shapes, and **~1.7× slower** on very short sequences (S = 128). Use this when you specifically need FA4-only features (paged KV with FA4, `score_mod`, block sparse, dropout in attention). For pure throughput, stay on FA2 and don't read further.
---
## 1. Why a shim is needed
On CUDA, vLLM does **not** import the upstream `flash_attn` package. It uses its own bundled fork:
```python
# vllm/v1/attention/backends/fa_utils.py
if current_platform.is_cuda():
    from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
```
`vllm.vllm_flash_attn` calls compiled C++ extensions (`_vllm_fa2_C`, `_vllm_fa3_C`) and its own CuTe DSL implementation under `vllm/vllm_flash_attn/cute/` for FA4. So aliasing `sys.modules['flash_attn']` to the HF kernel **does nothing** for the main attention path.
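A minimal sketch of that dead end, to make the point concrete:
```python
import sys
from kernels import get_kernel

# Alias the upstream package name to the HF kernel...
sys.modules["flash_attn"] = get_kernel("SecondNatureComputing/flash-attn-4-sm120")

# ...but vLLM on CUDA never imports `flash_attn`; it binds its bundled fork:
from vllm.vllm_flash_attn import flash_attn_varlen_func
print(flash_attn_varlen_func.__module__)  # a vllm.vllm_flash_attn submodule; the alias is ignored
```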
vLLM already ships SM120 CuTe kernels (`vllm/vllm_flash_attn/cute/flash_fwd_sm120.py`), but in the version most people have installed they are **gated**:
```python
# vllm/vllm_flash_attn/cute/interface.py
assert page_table is None, "Paged KV not supported on SM 12.0 in this PR"
assert not is_split_kv, "SplitKV not supported on SM 12.0 in this PR"
```
vLLM always serves with a paged KV cache, so the bundled FA4 SM120 path can never run during serving. Net effect: on SM12x, vLLM falls back to FA2.
The HF kernel bundles two PRs that vLLM's bundled version is missing:
- **#2348** — SM120 kernel-level paged KV cache support
- **#2336** — SM120 split-KV (FlashDecoding)
So the only way to use FA4-with-paged-KV on SM12x today is to redirect vLLM's FA4 call site (`vllm.vllm_flash_attn.cute.interface._flash_attn_fwd`) to the HF kernel. That's what this shim does.
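Conceptually, the redirect is a two-line monkeypatch; the real shim below just defers it until the module actually loads:
```python
import vllm.vllm_flash_attn.cute.interface as cute
from kernels import get_kernel

# Swap vLLM's gated FA4 forward for the HF kernel's paged-KV-capable version:
cute._flash_attn_fwd = get_kernel("SecondNatureComputing/flash-attn-4-sm120").interface._flash_attn_fwd
```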
---
## 2. Prerequisites
```bash
# Inside your vLLM venv
uv pip install -U "kernels>=0.4" "nvidia-cutlass-dsl>=4.4.1" einops apache-tvm-ffi
```
Confirm hardware:
```bash
nvidia-smi --query-gpu=name,compute_cap --format=csv
# Expect compute_cap >= 12.0 (e.g. 12.1 on DGX Spark GB10)
```
CUDA Toolkit must be **12.8 or newer** (FA4 baseline).
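If you're unsure which CUDA your torch build targets, check from inside the venv (the nightly cu130 wheel above should report `13.0`):
```bash
python - <<'PY'
import torch
print("torch CUDA:", torch.version.cuda)  # needs >= 12.8 for FA4
PY
```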
---
## 3. Pre-download the kernel
```bash
python -c "from kernels import get_kernel; get_kernel('SecondNatureComputing/flash-attn-4-sm120')"
```
Sanity check:
```bash
python - <<'PY'
import torch
from kernels import get_kernel
fa4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
q = k = v = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
out, _ = fa4.flash_attn_func(q, k, v, causal=True)
print("OK", out.shape, out.dtype)
PY
```
---
## 4. Install the shim
The shim does three things, lazily, when the relevant vLLM modules first load:
1. Patches `vllm.vllm_flash_attn.flash_attn_interface._is_fa4_supported` to accept SM 9.x / 10.x / 11.x / 12.x.
2. Wraps `vllm.vllm_flash_attn.cute.interface._flash_attn_fwd` so that on SM12x the call dispatches to the HF kernel's `_flash_attn_fwd` (which has paged KV).
3. Wraps `vllm.v1.attention.backends.fa_utils.get_flash_attn_version` so that on SM12x with `head_dim ≤ 128` it returns `4` (instead of the default `2`).
The wiring is done via a `sys.meta_path` finder that runs the patches the moment those vLLM modules finish loading — before any other module captures `get_flash_attn_version` by name.
> **Note on `sitecustomize.py`.** Ubuntu ships `/usr/lib/python3.12/sitecustomize.py` which takes precedence over a venv-local one, so we use a `.pth` file (`zzz_*.pth`) instead. `.pth` files with `import …` lines are processed by `site.py` at interpreter startup, regardless of OS-level `sitecustomize`.
### 4.1 Create the shim file
```bash
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
# Remove any earlier experimental shim
rm -f "$SP/zzz_fa4_sm120_shim.pth" "$SP/fa4_sm120_shim.py" "$SP/sitecustomize.py"
cat > "$SP/fa4_sm120_shim.py" <<'PY'
"""
Force vLLM on SM12x (RTX 5090 / RTX PRO 6000 / DGX Spark) to use the
SecondNatureComputing/flash-attn-4-sm120 HF kernel for the FA4 path.
Disable by setting env var FA4_SM120_SHIM=0 before launching.
"""
from __future__ import annotations
import os
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder
if os.environ.get("FA4_SM120_SHIM", "1") != "0":
_HF_KERNEL = None
def _hf_kernel():
global _HF_KERNEL
if _HF_KERNEL is None:
from kernels import get_kernel
_HF_KERNEL = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
return _HF_KERNEL
_PATCHED: set[str] = set()
def _is_sm12x() -> bool:
try:
import torch
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major == 12
except Exception:
return False
def _patch_fa_iface(mod):
def _is_fa4_supported_patched():
if not getattr(mod, "FA4_AVAILABLE", False):
return False, getattr(mod, "FA4_UNAVAILABLE_REASON", "FA4 unavailable")
try:
import torch
major, _ = torch.cuda.get_device_capability()
except Exception:
return False, "no CUDA device"
if major in (9, 10, 11, 12):
return True, None
return False, f"FA4 not supported on capability {major}.x"
mod._is_fa4_supported = _is_fa4_supported_patched
def _patch_cute_iface(mod):
orig = mod._flash_attn_fwd
def _dispatch(*args, **kwargs):
if _is_sm12x():
return _hf_kernel().interface._flash_attn_fwd(*args, **kwargs)
return orig(*args, **kwargs)
mod._flash_attn_fwd = _dispatch
def _patch_fa_utils(mod):
import functools
orig = mod.get_flash_attn_version
@functools.wraps(orig)
def patched(requires_alibi: bool = False,
head_size: int | None = None,
head_size_v: int | None = None,
has_sinks: bool = False):
if not _is_sm12x():
return orig(requires_alibi=requires_alibi,
head_size=head_size,
head_size_v=head_size_v,
has_sinks=has_sinks)
if requires_alibi:
return 2
if head_size is not None and head_size > 128:
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import is_fa_version_supported
if is_fa_version_supported(4):
return 4
except Exception:
pass
return 2
mod.get_flash_attn_version = patched
_DISPATCH = {
"vllm.vllm_flash_attn.flash_attn_interface": _patch_fa_iface,
"vllm.vllm_flash_attn.cute.interface": _patch_cute_iface,
"vllm.v1.attention.backends.fa_utils": _patch_fa_utils,
}
def _try_patch(name: str):
if name in _PATCHED or name not in _DISPATCH:
return
mod = sys.modules.get(name)
if mod is None:
return
try:
_DISPATCH[name](mod)
_PATCHED.add(name)
except Exception as e:
warnings.warn(f"[fa4_sm120_shim] failed to patch {name}: {e!r}")
class _WrappedLoader(Loader):
def __init__(self, real, name):
self._real = real
self._name = name
def create_module(self, spec):
if hasattr(self._real, "create_module"):
return self._real.create_module(spec)
return None
def exec_module(self, module):
self._real.exec_module(module)
_try_patch(self._name)
class _PatchingFinder(MetaPathFinder):
def find_spec(self, name, path=None, target=None):
if name not in _DISPATCH or name in _PATCHED:
return None
for finder in list(sys.meta_path):
if finder is self or not hasattr(finder, "find_spec"):
continue
spec = finder.find_spec(name, path, target)
if spec is not None and spec.loader is not None:
spec.loader = _WrappedLoader(spec.loader, name)
return spec
return None
if not getattr(sys, "_fa4_sm120_shim_installed", False):
sys.meta_path.insert(0, _PatchingFinder())
try:
sys._fa4_sm120_shim_installed = True # type: ignore[attr-defined]
except Exception:
pass
for _n in list(_DISPATCH):
_try_patch(_n)
PY
```
### 4.2 Auto-load via a `.pth` file
```bash
cat > "$SP/zzz_fa4_sm120_shim.pth" <<'PTH'
import fa4_sm120_shim
PTH
```
`.pth` files starting with `import …` are executed by `site.py` at interpreter startup. The `zzz_` prefix sorts our file last so it runs after `torch`, `kernels`, etc.
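To confirm the ordering, list the `.pth` files the way `site.py` will process them (alphabetically within the directory):
```bash
python - <<'PY'
import os, sysconfig
sp = sysconfig.get_paths()["purelib"]  # the active venv's site-packages
for name in sorted(f for f in os.listdir(sp) if f.endswith(".pth")):
    print(name)  # zzz_fa4_sm120_shim.pth should print last
PY
```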
---
## 5. Verify
### 5.1 Hook is installed
```bash
python - <<'PY'
import sys
import fa4_sm120_shim
print("hook installed:", any("PatchingFinder" in type(f).__name__ for f in sys.meta_path))
PY
```
Expected: `hook installed: True`.
### 5.2 Patches applied to vLLM
```bash
python - <<'PY'
import vllm.vllm_flash_attn.flash_attn_interface as fai
import vllm.vllm_flash_attn.cute.interface as cute
import vllm.v1.attention.backends.fa_utils as fau
print("FA4_AVAILABLE: ", fai.FA4_AVAILABLE)
print("_is_fa4_supported(): ", fai._is_fa4_supported())
print("get_flash_attn_version(head_size=128):", fau.get_flash_attn_version(head_size=128))
print("get_flash_attn_version(head_size=256):", fau.get_flash_attn_version(head_size=256))
print("_flash_attn_fwd qualname: ", cute._flash_attn_fwd.__qualname__)
PY
```
Expected:
```
FA4_AVAILABLE: True
_is_fa4_supported(): (True, None)
get_flash_attn_version(head_size=128): 4
get_flash_attn_version(head_size=256): 2
_flash_attn_fwd qualname: _patch_cute_iface.<locals>._dispatch
```
If `_flash_attn_fwd qualname` shows the original (something like `cute.interface._flash_attn_fwd`), the hook didn't run for that module — see "Troubleshooting".
### 5.3 End-to-end FA4 fwd through vLLM's call site
This is closer to what vLLM actually invokes:
```bash
python - <<'PY'
import torch
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
B, S, Hq, Hkv, D = 1, 1024, 16, 8, 128
q = torch.randn(B*S, Hq, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
cu = torch.tensor([0, B*S], device="cuda", dtype=torch.int32)
out, lse = _flash_attn_fwd(
q, k, v,
cu_seqlens_q=cu, cu_seqlens_k=cu,
max_seqlen_q=B*S, max_seqlen_k=B*S,
softmax_scale=D**-0.5, causal=True, return_lse=True,
)
print("OK", out.shape, out.dtype, "lse:", None if lse is None else lse.shape)
PY
```
If this completes without error, the HF kernel is being driven through vLLM's FA4 entry point.
---
## 6. Launch vLLM
Run with debug logging the first time so you can verify the routing:
```bash
VLLM_LOGGING_LEVEL=INFO vllm serve Qwen/Qwen3.5-27B \
--speculative-config '{"method": "dflash", "model": "z-lab/Qwen3.5-27B-DFlash", "num_speculative_tokens": 15}' \
--attention-backend flash_attn \
--gpu-memory-utilization 0.85 \
--max-model-len 65536 \
--load-format fastsafetensors \
--max-num-batched-tokens 32768 2>&1 | tee vllm.log
```
In another terminal:
```bash
grep -iE "fa version|flash_attn|attention backend|sm12|fa4" vllm.log | head -50
```
Look for log lines mentioning `fa_version=4` or the chosen attention backend.
---
## 7. Disable the shim
Without removing files:
```bash
FA4_SM120_SHIM=0 vllm serve ...
```
Permanently:
```bash
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
rm -f "$SP/fa4_sm120_shim.py" "$SP/zzz_fa4_sm120_shim.pth"
```
---
## 8. Troubleshooting
### `_flash_attn_fwd qualname` still shows the original
Most likely the module was imported before the finder was active (for example by another `.pth` file or an eager import in your launcher). Make sure:
- You created `zzz_fa4_sm120_shim.pth` (the `zzz_` prefix matters for ordering).
- You don't have a competing `.pth` that imports vLLM modules earlier.
- The shim file is in the *venv's* site-packages, not the system one.
Sanity:
```bash
python -c "import sys; print('\n'.join(p for p in sys.path if 'site-packages' in p))"
ls "$VIRTUAL_ENV/lib/python3.12/site-packages/" | grep fa4_sm120
```
### `assert is_fa_version_supported(4)` fires somewhere
We already patch `_is_fa4_supported`, but if some code path imported `is_fa_version_supported` before the patch ran and captured it by name, it might still see the unpatched version. Reproduce with the verifier in §5.2 and pinpoint which module — open an issue or extend `_DISPATCH` to also patch the offending module.
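As a sketch, an extra `_DISPATCH` entry would look like this (`vllm.some.offender` is a placeholder for whatever module you pinpoint, not a real vLLM path):
```python
def _patch_offender(mod):
    # Rebind the stale by-name capture to the current (patched) function.
    from vllm.vllm_flash_attn import flash_attn_interface as fai
    mod.is_fa_version_supported = fai.is_fa_version_supported

_DISPATCH["vllm.some.offender"] = _patch_offender  # placeholder module name
```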
### Tensor-shape / stride asserts inside the HF kernel
Possible signature drift between vLLM's call site and the HF kernel's `_flash_attn_fwd`. Compare:
- vLLM call: `vllm/vllm_flash_attn/flash_attn_interface.py` around `elif fa_version == 4:`
- HF `_flash_attn_fwd`: `~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/snapshots/<hash>/build/torch-cuda/interface.py`
If a kwarg vLLM passes isn't accepted by the HF kernel, drop it inside `_dispatch` before forwarding. As of HF kernel v0.1.0 all kwargs vLLM passes (`cu_seqlens_q/k`, `seqused_k`, `max_seqlen_q/k`, `page_table`, `softmax_scale`, `causal`, `softcap`, `window_size_left/right`, `num_splits`, `return_lse`, `out`, `learnable_sink`) are present.
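If a future version does drift, a defensive variant of `_dispatch` can filter kwargs against the HF function's signature before forwarding. A sketch, assuming the HF `_flash_attn_fwd` declares explicit keywords rather than `**kwargs`:
```python
import inspect

def _dispatch(*args, **kwargs):
    if _is_sm12x():
        fn = _hf_kernel().interface._flash_attn_fwd
        accepted = set(inspect.signature(fn).parameters)
        # Silently drop anything the HF kernel does not declare.
        kwargs = {k: v for k, v in kwargs.items() if k in accepted}
        return fn(*args, **kwargs)
    return orig(*args, **kwargs)
```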
### `head_dim > 128` model
The SM120 kernel cannot fit `head_dim > 128` in 99 KB SMEM. The shim's `get_flash_attn_version` already returns 2 in that case, so vLLM falls back to FA2. If the model has `head_dim == 256`, **don't bother with this shim** — neither this kernel nor vLLM's bundled FA4 will run, and vLLM's FA2 already handles it.
Check head_dim:
```bash
python - <<'PY'
from huggingface_hub import hf_hub_download
import json
p = hf_hub_download("Qwen/Qwen3.5-27B", "config.json")
c = json.load(open(p))
print("head_dim:", c.get("head_dim") or (c["hidden_size"] // c["num_attention_heads"]))
PY
```
### `Paged KV not supported on SM 12.0 in this PR`
This assert is in vLLM's *bundled* `cute/interface.py`. If you see it, the dispatch didn't reach our wrapper. Re-run the §5.2 verifier; if `_flash_attn_fwd qualname` doesn't say `_dispatch`, the patch is not in effect.
---
## 9. How it works (under the hood)
```
.vllm/lib/python3.12/site-packages/
├── zzz_fa4_sm120_shim.pth   ← processed by site.py at startup; runs `import fa4_sm120_shim`
├── fa4_sm120_shim.py        ← installs sys.meta_path finder, patches three vLLM modules
│
└── vllm/
    ├── v1/attention/backends/
    │   └── fa_utils.py                 ← get_flash_attn_version() — patched: returns 4 on SM12x
    └── vllm_flash_attn/
        ├── flash_attn_interface.py     ← _is_fa4_supported() — patched: accepts SM 9-12.x
        │   └── flash_attn_varlen_func  — dispatches on fa_version, calls _flash_attn_fwd if 4
        └── cute/
            └── interface.py            ← _flash_attn_fwd() — patched: dispatches to HF kernel on SM12x

~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/
└── snapshots/<hash>/build/torch-cuda/
    └── interface.py                    ← real _flash_attn_fwd with paged KV (PR #2348)
```
Flow at request time:
1. Some vLLM backend (`vllm/v1/attention/backends/flash_attn.py`) calls `get_flash_attn_version(head_size=128)`.
2. **Patched** version returns `4` (instead of `2`).
3. The backend calls `flash_attn_varlen_func(..., fa_version=4)`.
4. Inside, the FA4 branch does `from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd`.
5. **Patched** `_flash_attn_fwd` is a `_dispatch` wrapper. On SM12x it forwards to `hf_kernel.interface._flash_attn_fwd`, otherwise it calls the original.
6. The HF kernel runs, returns `(out, softmax_lse)`.
---
## 10. Caveats
- The shim assumes vLLM's internal API names (`_flash_attn_fwd`, `is_fa_version_supported`, `get_flash_attn_version`). If you upgrade vLLM and these change, the shim's verifier in §5.2 will tell you immediately.
- On SM12x, the HF kernel's `interface.py` clamps `num_splits` to 1 (no SplitKV). Decode workloads use a single split anyway, but if you've tuned vLLM with `num_splits > 1` it'll be silently ignored.
- Dropout in this kernel falls back to smaller tiles to avoid register spills. Throughput cost is small but real.
- Backward pass on SM12x has its own restrictions (no block sparse, no `score_mod`, no `mask_mod`, no deterministic). Inference is forward-only so this doesn't affect serving.
---
## License
The HF kernel is BSD-3-Clause (inherited from Dao-AILab/flash-attention).
The shim above is provided as-is, no warranty.