---
title: vLLM Qwen3.6 DFlash + FA Cute DSL
---

# Install vLLM on DGX Spark

1. Install uv
```bash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
2. Create environment

```bash
sudo apt install python3-dev
uv venv .vllm --python 3.12
source .vllm/bin/activate
```

3. Install vllm

```bash
uv pip install -U vllm --torch-backend=auto --extra-index-url https://wheels.vllm.ai/nightly/cu130
uv pip install --prerelease=allow --force-reinstall triton --index-url https://download.pytorch.org/whl/test/cu132
```

4. Export variables
```bash
export TORCH_CUDA_ARCH_LIST=12.1a
```

5. Drop filesystem caches (optional; frees host memory before loading a large model)
```bash
sync && sudo sysctl -w vm.drop_caches=3
```


# Wiring `flash-attn-4-sm120` into vLLM on SM120 / SM121

How to make vLLM use the [SecondNatureComputing/flash-attn-4-sm120](https://huggingface.co/SecondNatureComputing/flash-attn-4-sm120) Hugging Face kernel for the FA4 path on consumer Blackwell GPUs (RTX 5090, RTX PRO 6000 Blackwell, DGX Spark GB10 / SM121a).

Tested on:

- DGX Spark (GB10), compute capability **12.1**
- vLLM installed in a venv at `~/Projects/vllm/.vllm`
- Python 3.12

> **Heads up before you start.** On SM120/121 this kernel is not a free speedup. The HF README's own benchmarks show FA4 is **4–10% slower** than vLLM's bundled FA2 on realistic Qwen3 prefill shapes, and **~1.7× slower** on very short sequences (S = 128). Use this when you specifically need FA4-only features (paged KV with FA4, `score_mod`, block sparse, dropout in attention). For pure throughput, stay on FA2 and don't read further.

---

## 1. Why a shim is needed

On CUDA, vLLM does **not** import the upstream `flash_attn` package. It uses its own bundled fork:

```python
# vllm/v1/attention/backends/fa_utils.py
if current_platform.is_cuda():
    from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
```

`vllm.vllm_flash_attn` calls compiled C++ extensions (`_vllm_fa2_C`, `_vllm_fa3_C`) and its own CuTe DSL implementation under `vllm/vllm_flash_attn/cute/` for FA4. So aliasing `sys.modules['flash_attn']` to the HF kernel **does nothing** for the main attention path.
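A toy illustration of why the alias is inert (all module names here are stand-ins, not vLLM's): code that imports a vendored module by its own name never consults the `sys.modules` entry for the upstream name.

```python
import sys
import types

# Stand-in for the upstream package someone tries to alias.
fake = types.ModuleType("flash_attn_stub")
fake.flash_attn_varlen_func = lambda *a, **k: "hf-kernel"
sys.modules["flash_attn_stub"] = fake

# Stand-in for vLLM's vendored fork, registered under its own name.
vendored = types.ModuleType("vendored_flash_stub")
vendored.flash_attn_varlen_func = lambda *a, **k: "bundled-fork"
sys.modules["vendored_flash_stub"] = vendored

# An importer of the vendored name is untouched by the alias.
import vendored_flash_stub
print(vendored_flash_stub.flash_attn_varlen_func())  # bundled-fork
```

The alias only matters to code that literally runs `import flash_attn`, which vLLM's CUDA attention path never does.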

vLLM already ships SM120 CuTe kernels (`vllm/vllm_flash_attn/cute/flash_fwd_sm120.py`), but in the version most people have installed they are **gated**:

```python
# vllm/vllm_flash_attn/cute/interface.py
assert page_table is None,    "Paged KV not supported on SM 12.0 in this PR"
assert not is_split_kv,       "SplitKV not supported on SM 12.0 in this PR"
```

vLLM always serves with a paged KV cache, so the bundled FA4 SM120 path can't run during serving. Net effect: on SM12x, vLLM falls back to FA2.

The HF kernel bundles two PRs that vLLM's bundled version is missing:

- **#2348** — SM120 kernel-level paged KV cache support
- **#2336** — SM120 split-KV (FlashDecoding)

So the only way to use FA4-with-paged-KV on SM12x today is to redirect vLLM's FA4 call site (`vllm.vllm_flash_attn.cute.interface._flash_attn_fwd`) to the HF kernel. That's what this shim does.

---

## 2. Prerequisites

```bash
# Inside your vLLM venv
uv pip install -U "kernels>=0.4" "nvidia-cutlass-dsl>=4.4.1" einops apache-tvm-ffi
```

Confirm hardware:

```bash
nvidia-smi --query-gpu=name,compute_cap --format=csv
# Expect compute_cap >= 12.0  (e.g. 12.1 on DGX Spark GB10)
```

CUDA Toolkit must be **12.8 or newer** (FA4 baseline).

---

## 3. Pre-download the kernel

```bash
python -c "from kernels import get_kernel; get_kernel('SecondNatureComputing/flash-attn-4-sm120')"
```

Sanity check:

```bash
python - <<'PY'
import torch
from kernels import get_kernel
fa4 = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
q = k = v = torch.randn(1, 1024, 16, 128, device="cuda", dtype=torch.bfloat16)
out, _ = fa4.flash_attn_func(q, k, v, causal=True)
print("OK", out.shape, out.dtype)
PY
```

---

## 4. Install the shim

The shim does three things, lazily, when the relevant vLLM modules first load:

1. Patches `vllm.vllm_flash_attn.flash_attn_interface._is_fa4_supported` to accept SM 9.x / 10.x / 11.x / 12.x.
2. Wraps `vllm.vllm_flash_attn.cute.interface._flash_attn_fwd` so that on SM12x the call dispatches to the HF kernel's `_flash_attn_fwd` (which has paged KV).
3. Wraps `vllm.v1.attention.backends.fa_utils.get_flash_attn_version` so that on SM12x with `head_dim ≤ 128` it returns `4` (instead of the default `2`).

The wiring is done via a `sys.meta_path` finder that runs the patches the moment those vLLM modules finish loading — before any other module captures `get_flash_attn_version` by name.
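The same pattern in miniature, using a throwaway `demo_target` module written to a temp directory (all names here are stand-ins, not vLLM's code):

```python
import os
import sys
import tempfile
from importlib.abc import Loader, MetaPathFinder

# Throwaway stand-in for a vLLM module: version() returns 2 until patched.
d = tempfile.mkdtemp()
with open(os.path.join(d, "demo_target.py"), "w") as f:
    f.write("def version():\n    return 2\n")
sys.path.insert(0, d)

def _patch(mod):
    mod.version = lambda: 4  # the "shim" patch

class WrappedLoader(Loader):
    def __init__(self, real):
        self._real = real
    def create_module(self, spec):
        if hasattr(self._real, "create_module"):
            return self._real.create_module(spec)
        return None
    def exec_module(self, module):
        self._real.exec_module(module)
        _patch(module)  # runs the moment the module finishes loading

class PatchingFinder(MetaPathFinder):
    def find_spec(self, name, path=None, target=None):
        if name != "demo_target":
            return None
        # Delegate to the real finders, then wrap the loader they return.
        for finder in list(sys.meta_path):
            if finder is self or not hasattr(finder, "find_spec"):
                continue
            spec = finder.find_spec(name, path, target)
            if spec is not None and spec.loader is not None:
                spec.loader = WrappedLoader(spec.loader)
                return spec
        return None

sys.meta_path.insert(0, PatchingFinder())

import demo_target  # importing triggers the wrapped loader
print(demo_target.version())  # 4, not 2
```

Because the patch runs inside `exec_module`, no importer can ever bind the unpatched `version` by name; the shim below relies on exactly this property for `get_flash_attn_version`.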

> **Note on `sitecustomize.py`.** Ubuntu ships `/usr/lib/python3.12/sitecustomize.py` which takes precedence over a venv-local one, so we use a `.pth` file (`zzz_*.pth`) instead. `.pth` files with `import …` lines are processed by `site.py` at interpreter startup, regardless of OS-level `sitecustomize`.

### 4.1 Create the shim file

```bash
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
# Remove any earlier experimental shim
rm -f "$SP/zzz_fa4_sm120_shim.pth" "$SP/fa4_sm120_shim.py" "$SP/sitecustomize.py"

cat > "$SP/fa4_sm120_shim.py" <<'PY'
"""
Force vLLM on SM12x (RTX 5090 / RTX PRO 6000 / DGX Spark) to use the
SecondNatureComputing/flash-attn-4-sm120 HF kernel for the FA4 path.

Disable by setting env var FA4_SM120_SHIM=0 before launching.
"""
from __future__ import annotations

import os
import sys
import warnings
from importlib.abc import Loader, MetaPathFinder

if os.environ.get("FA4_SM120_SHIM", "1") != "0":

    _HF_KERNEL = None

    def _hf_kernel():
        global _HF_KERNEL
        if _HF_KERNEL is None:
            from kernels import get_kernel
            _HF_KERNEL = get_kernel("SecondNatureComputing/flash-attn-4-sm120")
        return _HF_KERNEL

    _PATCHED: set[str] = set()

    def _is_sm12x() -> bool:
        try:
            import torch
            if not torch.cuda.is_available():
                return False
            major, _ = torch.cuda.get_device_capability()
            return major == 12
        except Exception:
            return False

    def _patch_fa_iface(mod):
        def _is_fa4_supported_patched():
            if not getattr(mod, "FA4_AVAILABLE", False):
                return False, getattr(mod, "FA4_UNAVAILABLE_REASON", "FA4 unavailable")
            try:
                import torch
                major, _ = torch.cuda.get_device_capability()
            except Exception:
                return False, "no CUDA device"
            if major in (9, 10, 11, 12):
                return True, None
            return False, f"FA4 not supported on capability {major}.x"
        mod._is_fa4_supported = _is_fa4_supported_patched

    def _patch_cute_iface(mod):
        orig = mod._flash_attn_fwd

        def _dispatch(*args, **kwargs):
            if _is_sm12x():
                return _hf_kernel().interface._flash_attn_fwd(*args, **kwargs)
            return orig(*args, **kwargs)

        mod._flash_attn_fwd = _dispatch

    def _patch_fa_utils(mod):
        import functools
        orig = mod.get_flash_attn_version

        @functools.wraps(orig)
        def patched(requires_alibi: bool = False,
                    head_size: int | None = None,
                    head_size_v: int | None = None,
                    has_sinks: bool = False):
            if not _is_sm12x():
                return orig(requires_alibi=requires_alibi,
                            head_size=head_size,
                            head_size_v=head_size_v,
                            has_sinks=has_sinks)
            if requires_alibi:
                return 2
            if head_size is not None and head_size > 128:
                return 2
            try:
                from vllm.vllm_flash_attn.flash_attn_interface import is_fa_version_supported
                if is_fa_version_supported(4):
                    return 4
            except Exception:
                pass
            return 2

        mod.get_flash_attn_version = patched

    _DISPATCH = {
        "vllm.vllm_flash_attn.flash_attn_interface": _patch_fa_iface,
        "vllm.vllm_flash_attn.cute.interface":      _patch_cute_iface,
        "vllm.v1.attention.backends.fa_utils":      _patch_fa_utils,
    }

    def _try_patch(name: str):
        if name in _PATCHED or name not in _DISPATCH:
            return
        mod = sys.modules.get(name)
        if mod is None:
            return
        try:
            _DISPATCH[name](mod)
            _PATCHED.add(name)
        except Exception as e:
            warnings.warn(f"[fa4_sm120_shim] failed to patch {name}: {e!r}")

    class _WrappedLoader(Loader):
        def __init__(self, real, name):
            self._real = real
            self._name = name

        def create_module(self, spec):
            if hasattr(self._real, "create_module"):
                return self._real.create_module(spec)
            return None

        def exec_module(self, module):
            self._real.exec_module(module)
            _try_patch(self._name)

    class _PatchingFinder(MetaPathFinder):
        def find_spec(self, name, path=None, target=None):
            if name not in _DISPATCH or name in _PATCHED:
                return None
            for finder in list(sys.meta_path):
                if finder is self or not hasattr(finder, "find_spec"):
                    continue
                spec = finder.find_spec(name, path, target)
                if spec is not None and spec.loader is not None:
                    spec.loader = _WrappedLoader(spec.loader, name)
                    return spec
            return None

    if not getattr(sys, "_fa4_sm120_shim_installed", False):
        sys.meta_path.insert(0, _PatchingFinder())
        try:
            sys._fa4_sm120_shim_installed = True  # type: ignore[attr-defined]
        except Exception:
            pass

    for _n in list(_DISPATCH):
        _try_patch(_n)
PY
```

### 4.2 Auto-load via a `.pth` file

```bash
cat > "$SP/zzz_fa4_sm120_shim.pth" <<'PTH'
import fa4_sm120_shim
PTH
```

`.pth` files starting with `import …` are executed by `site.py` at interpreter startup. The `zzz_` prefix sorts our file last so it runs after `torch`, `kernels`, etc.
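You can watch `site.py` do this with a throwaway directory and `site.addsitedir`, which processes `.pth` files the same way site-packages is processed at startup (names here are stand-ins):

```python
import os
import site
import sys
import tempfile

# Build a throwaway "site-packages" containing a shim module and a .pth
# whose single line imports it.
d = tempfile.mkdtemp()
with open(os.path.join(d, "demo_shim.py"), "w") as f:
    f.write("import sys\nsys.demo_shim_loaded = True\n")
with open(os.path.join(d, "zzz_demo_shim.pth"), "w") as f:
    f.write("import demo_shim\n")

# addsitedir() adds the dir to sys.path, then executes `import` lines
# found in its .pth files -- the same mechanism the real shim uses.
site.addsitedir(d)
print(getattr(sys, "demo_shim_loaded", False))  # True
```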

---

## 5. Verify

### 5.1 Hook is installed

```bash
python - <<'PY'
import sys
import fa4_sm120_shim
print("hook installed:", any("PatchingFinder" in type(f).__name__ for f in sys.meta_path))
PY
```

Expected: `hook installed: True`.

### 5.2 Patches applied to vLLM

```bash
python - <<'PY'
import vllm.vllm_flash_attn.flash_attn_interface as fai
import vllm.vllm_flash_attn.cute.interface as cute
import vllm.v1.attention.backends.fa_utils as fau

print("FA4_AVAILABLE:                       ", fai.FA4_AVAILABLE)
print("_is_fa4_supported():                  ", fai._is_fa4_supported())
print("get_flash_attn_version(head_size=128):", fau.get_flash_attn_version(head_size=128))
print("get_flash_attn_version(head_size=256):", fau.get_flash_attn_version(head_size=256))
print("_flash_attn_fwd qualname:             ", cute._flash_attn_fwd.__qualname__)
PY
```

Expected:

```
FA4_AVAILABLE:                        True
_is_fa4_supported():                  (True, None)
get_flash_attn_version(head_size=128): 4
get_flash_attn_version(head_size=256): 2
_flash_attn_fwd qualname:              _patch_cute_iface.<locals>._dispatch
```

If `_flash_attn_fwd qualname` shows the original (something like `cute.interface._flash_attn_fwd`), the hook didn't run for that module — see "Troubleshooting".

### 5.3 End-to-end FA4 fwd through vLLM's call site

This is closer to what vLLM actually invokes:

```bash
python - <<'PY'
import torch
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd

B, S, Hq, Hkv, D = 1, 1024, 16, 8, 128
q = torch.randn(B*S, Hq,  D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B*S, Hkv, D, device="cuda", dtype=torch.bfloat16)
cu = torch.tensor([0, B*S], device="cuda", dtype=torch.int32)

out, lse = _flash_attn_fwd(
    q, k, v,
    cu_seqlens_q=cu, cu_seqlens_k=cu,
    max_seqlen_q=B*S, max_seqlen_k=B*S,
    softmax_scale=D**-0.5, causal=True, return_lse=True,
)
print("OK", out.shape, out.dtype, "lse:", None if lse is None else lse.shape)
PY
```

If this completes without error, the HF kernel is being driven through vLLM's FA4 entry point.

---

## 6. Launch vLLM

Run with debug logging the first time so you can verify the routing:

```bash
VLLM_LOGGING_LEVEL=INFO vllm serve Qwen/Qwen3.5-27B \
  --speculative-config '{"method": "dflash", "model": "z-lab/Qwen3.5-27B-DFlash", "num_speculative_tokens": 15}' \
  --attention-backend flash_attn \
  --gpu-memory-utilization 0.85 \
  --max-model-len 65536 \
  --load-format fastsafetensors \
  --max-num-batched-tokens 32768 2>&1 | tee vllm.log
```

In another terminal:

```bash
grep -iE "fa version|flash_attn|attention backend|sm12|fa4" vllm.log | head -50
```

Look for log lines mentioning `fa_version=4` or the chosen attention backend.

---

## 7. Disable the shim

Without removing files:

```bash
FA4_SM120_SHIM=0 vllm serve ...
```

Permanently:

```bash
SP="$VIRTUAL_ENV/lib/python3.12/site-packages"
rm -f "$SP/fa4_sm120_shim.py" "$SP/zzz_fa4_sm120_shim.pth"
```

---

## 8. Troubleshooting

### `_flash_attn_fwd qualname` still shows the original

Most likely some import path loaded the module before the finder was active. Make sure:

- You created `zzz_fa4_sm120_shim.pth` (the `zzz_` prefix matters for ordering).
- You don't have a competing `.pth` that imports vLLM modules earlier.
- The shim file is in the *venv's* site-packages, not the system one.

Sanity:

```bash
python -c "import sys; print('\n'.join(p for p in sys.path if 'site-packages' in p))"
ls "$VIRTUAL_ENV/lib/python3.12/site-packages/" | grep fa4_sm120
```

### `assert is_fa_version_supported(4)` fires somewhere

We already patch `_is_fa4_supported`, but if some code path imported `is_fa_version_supported` before the patch ran and captured it by name, it might still see the unpatched version. Reproduce with the verifier in §5.2 and pinpoint which module — open an issue or extend `_DISPATCH` to also patch the offending module.

### Tensor-shape / stride asserts inside the HF kernel

Possible signature drift between vLLM's call site and the HF kernel's `_flash_attn_fwd`. Compare:

- vLLM call: `vllm/vllm_flash_attn/flash_attn_interface.py` around `elif fa_version == 4:`
- HF `_flash_attn_fwd`: `~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/snapshots/<hash>/build/torch-cuda/interface.py`

If a kwarg vLLM passes isn't accepted by the HF kernel, drop it inside `_dispatch` before forwarding. As of HF kernel v0.1.0 all kwargs vLLM passes (`cu_seqlens_q/k`, `seqused_k`, `max_seqlen_q/k`, `page_table`, `softmax_scale`, `causal`, `softcap`, `window_size_left/right`, `num_splits`, `return_lse`, `out`, `learnable_sink`) are present.

### `head_dim > 128` model

The SM120 kernel cannot fit `head_dim > 128` in 99 KB SMEM. The shim's `get_flash_attn_version` already returns 2 in that case, so vLLM falls back to FA2. If the model has `head_dim == 256`, **don't bother with this shim** — neither this kernel nor vLLM's bundled FA4 will run, and vLLM's FA2 already handles it.

Check head_dim:

```bash
python - <<'PY'
from huggingface_hub import hf_hub_download
import json
p = hf_hub_download("Qwen/Qwen3.5-27B", "config.json")
c = json.load(open(p))
print("head_dim:", c.get("head_dim") or (c["hidden_size"] // c["num_attention_heads"]))
PY
```

### `Paged KV not supported on SM 12.0 in this PR`

This assert is in vLLM's *bundled* `cute/interface.py`. If you see it, the dispatch didn't reach our wrapper. Re-run the §5.2 verifier; if `_flash_attn_fwd qualname` doesn't say `_dispatch`, the patch is not in effect.

---

## 9. How it works (under the hood)

```
.vllm/lib/python3.12/site-packages/
├── zzz_fa4_sm120_shim.pth          ← processed by site.py at startup; runs `import fa4_sm120_shim`
├── fa4_sm120_shim.py               ← installs sys.meta_path finder, patches three vLLM modules
│
└── vllm/
    ├── v1/attention/backends/
    │   └── fa_utils.py             ← get_flash_attn_version()  — patched: returns 4 on SM12x
    └── vllm_flash_attn/
        ├── flash_attn_interface.py ← _is_fa4_supported()       — patched: accepts SM 9-12.x
        │   └── flash_attn_varlen_func — dispatches on fa_version, calls _flash_attn_fwd if 4
        └── cute/
            └── interface.py        ← _flash_attn_fwd()         — patched: dispatches to HF kernel on SM12x

~/.cache/huggingface/hub/models--SecondNatureComputing--flash-attn-4-sm120/
└── snapshots/<hash>/build/torch-cuda/
    └── interface.py                ← real _flash_attn_fwd with paged KV (PR #2348)
```

Flow at request time:

1. Some vLLM backend (`vllm/v1/attention/backends/flash_attn.py`) calls `get_flash_attn_version(head_size=128)`.
2. **Patched** version returns `4` (instead of `2`).
3. The backend calls `flash_attn_varlen_func(..., fa_version=4)`.
4. Inside, the FA4 branch does `from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd`.
5. **Patched** `_flash_attn_fwd` is a `_dispatch` wrapper. On SM12x it forwards to `hf_kernel.interface._flash_attn_fwd`, otherwise it calls the original.
6. The HF kernel runs, returns `(out, softmax_lse)`.
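Stripped to its essence, step 5 is a wrapper-with-fallback dispatch (toy function names, not vLLM's):

```python
# Toy stand-ins for vLLM's bundled _flash_attn_fwd and the HF kernel.
def bundled_fwd(*args, **kwargs):
    return "bundled CuTe kernel"

def hf_fwd(*args, **kwargs):
    return "flash-attn-4-sm120 kernel"

def is_sm12x():
    return True  # pretend we're on a GB10 / RTX 5090

def dispatch(*args, **kwargs):
    # Same shape as the shim's _dispatch: HF kernel on SM12x, original otherwise.
    if is_sm12x():
        return hf_fwd(*args, **kwargs)
    return bundled_fwd(*args, **kwargs)

print(dispatch())  # flash-attn-4-sm120 kernel
```

On non-SM12x hardware `is_sm12x()` is false and the wrapper is a transparent pass-through, which is why the shim is safe to leave installed on other GPUs.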

---

## 10. Caveats

- The shim assumes vLLM's internal API names (`_flash_attn_fwd`, `is_fa_version_supported`, `get_flash_attn_version`). If you upgrade vLLM and these change, the shim's verifier in §5.2 will tell you immediately.
- On SM12x, the HF kernel's `interface.py` clamps `num_splits` to 1 (no SplitKV). Decode workloads use a single split anyway, but if you've tuned vLLM with `num_splits > 1` it'll be silently ignored.
- Dropout in this kernel falls back to smaller tiles to avoid register spills. Throughput cost is small but real.
- Backward pass on SM12x has its own restrictions (no block sparse, no `score_mod`, no `mask_mod`, no deterministic). Inference is forward-only so this doesn't affect serving.

---

## License

The HF kernel is BSD-3-Clause (inherited from Dao-AILab/flash-attention).
The shim above is provided as-is, no warranty.