first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 Qwen, Alibaba
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

206
README.md Normal file
View File

@@ -0,0 +1,206 @@
> [!IMPORTANT]
> This repository is an experimental SM70/SM75 fork of [QwenLM/FlashQLA](https://github.com/QwenLM/FlashQLA).
>
> It is not an official FlashQLA release and does not replace the upstream Hopper/SM90 implementation.
# FlashQLA-SM70-SM75
Experimental forward-inference support for Qwen-style Gated DeltaNet on SM70/SM75-class NVIDIA GPUs.
This fork keeps the upstream Hopper/SM90 TileLang path intact and adds an explicit legacy backend entry point for Volta/Turing inference devices. The current runtime validation target is RTX 2080 Ti / SM75. SM70 currently has compile coverage, but V100-class runtime validation is still required before making performance claims.
## Changes in This Fork
- Adds `flash_qla.ops.gated_delta_rule.legacy.chunk_gated_delta_rule_fwd_legacy`.
- Adds a lazy-built CUDA extension for a forward-only SM70/SM75-class Gated DeltaNet backend.
- Keeps the upstream Hopper/SM90 TileLang path unchanged.
- Keeps the legacy path explicit instead of silently replacing the upstream high-level API.
- Adds CUDA correctness tests for the supported legacy path.
- Documents the supported scope, validation status, and benchmark caveats separately from upstream Hopper results.
## Supported Scope
Supported:
- forward inference only
- SM70/SM75-class CUDA devices as the intended legacy target family
- scalar-gate Gated DeltaNet
- Qwen-style grouped-query head mapping
- primary optimized shape: `D=128`
- explicit legacy API entry point
Not supported:
- backward kernels or training
- automatic dispatch from the upstream high-level API
- generic support for all pre-Hopper NVIDIA GPUs
- runtime performance claims for SM70 before V100-class validation
- SM80/SM86/SM89 support claims
- automatic default dispatch for non-Hopper devices
## Current Validation
Runtime validation was performed on RTX 2080 Ti / SM75.
Standalone kernel timing for a Qwen-like shape:
- `B=1, T=512, Hq=16, Hv=32, D=128`
- control recurrent path: about `1.126 ms`
- optimized legacy path on SM75: about `0.520-0.533 ms`
- GDN-stage speedup: about `2.1x`
GGUF runtime profiling on SM75:
- default fused GDN: `406.656 ms`
- legacy fast path: `195.105 ms`
- GDN-stage speedup: about `2.08x`
Whole-request impact under the same server parameters:
- prefill: `+7.17%`
- decode: `+0.61%`
- wall time: `-3.49%`
SM70 status:
- compile check passes
- runtime validation is pending
- V100-class benchmarking is needed before claiming SM70 performance
Fork wrapper status:
- Python syntax check passes
- CUDA tests are included under `tests/test_legacy_sm_gdn.py`
- CUDA PyTorch runtime validation still requires a CUDA-enabled PyTorch environment
## Positioning
This fork is meant to make the SM70/SM75 experiment reproducible and reviewable. It should be treated as an upstreamable experimental branch, not as a separate long-term replacement for FlashQLA.
---
The original upstream README follows below.
<p align="center">
<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/flashqla/flashqla.png" width="1000"/>
<p>
<p align="center">|&nbsp&nbsp 📜 <a href="https://qwen.ai/blog?id=flashqla">Blog</a>&nbsp&nbsp |</p>
## Introduction
FlashQLA is a high-performance linear attention kernel library built on [TileLang](https://github.com/tile-ai/tilelang). FlashQLA applies **reasonable operator fusion and performance optimization** to the forward and backward passes of GDN Chunked Prefill, achieving **2-3× forward speedup** and **2× backward speedup** over the FLA Triton kernel across multiple scenarios on NVIDIA Hopper. The efficiency gains are particularly pronounced in pretraining scenarios and edge-side agentic inference.
Key features:
1.**Gate-driven automatic intra-card context parallelism**. By exploiting the exponential decay property of the GDN gate, FlashQLA automatically enables intra-card CP under TP, long-sequence, and small-head-count settings, improving GPU SM utilization.
2.**Hardware-friendly algebraic reformulation**. We reformulate the forward and backward flows of GDN Chunked Prefill to a certain extent, effectively reducing Tensor Core, CUDA Core, and SFU overhead without sacrificing numerical precision.
3.**TileLang fused warp-specialized kernels**. Rather than following the step-by-step decomposition into independent kernels, nor fusing the entire computation flow into a single kernel, we take CP and backward requirements into account, use TileLang to build several key fused kernels, and manually implement warpgroup specialization to overlap data movement, Tensor Core computation, and CUDA Core computation.
## Requirements
- SM90 or above
- CUDA 12.8 or above
- PyTorch 2.8 or above
## Installation
```bash
git clone https://github.com/QwenLM/FlashQLA.git
cd FlashQLA
pip install -v .
```
## Usage
### High-level API
```python
import torch
from flash_qla import chunk_gated_delta_rule
o, final_state = chunk_gated_delta_rule(
q=q, # [B, T, H_q, K]
k=k, # [B, T, H_q, K]
v=v, # [B, T, H_v, V]
g=g, # [B, T, H_v]
beta=beta, # [B, T, H_v]
scale=scale,
initial_state=initial_state, # optional, [B, H_v, K, V]
output_final_state=True,
cu_seqlens=cu_seqlens, # optional, for variable-length sequences
)
```
### Low-level API
For separate forward and backward calls:
```python
from flash_qla import chunk_gated_delta_rule_fwd, chunk_gated_delta_rule_bwd
# Forward
g, A, o, h, final_state = chunk_gated_delta_rule_fwd(
q, k, v, g, beta, scale=scale, initial_state=h0, cu_seqlens=cu_seqlens
)
# Backward
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
q, k, v, g, beta, A, do, dht=dht, scale=scale, initial_state=h0, cu_seqlens=cu_seqlens
)
```
## Tests
```bash
# require flash linear attention for comparison
pip install flash_linear_attention==0.5.0
cd tests
python test_gdr.py --set develop
python test_gdr.py --set varlen --num-heads 32
python test_gdr.py --set profile --num-heads 32
python test_gdr.py --set product --ref-dtype float32 --num-heads 32
```
## Benchmark
We benchmarked FlashQLA against the FLA Triton and FlashInfer baseline (FLA 0.5.0, Triton 3.5.1, FlashInfer 0.6.9, TileLang 0.1.8) on the head configurations used by the Qwen3.5 / Qwen3.6 family h_k,v \in {64, 48, 32, 24, 16, 8}, corresponding to TP1 through TP8.
<p align="center">
<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/flashqla/fwd_bwd_latency_comparison.png" width="1000"/>
<p>
Specifically, the forward (FWD) benchmarks measure single-kernel latency for different models and TP settings under varying batch lengths, while the backward (BWD) benchmarks examine the relationship between total token count within a batch and latency during a single update step.
More detail in [benchmark_results_H200.txt](./benchmark/benchmark_results_H200.txt).
```bash
# require flash linear attention and flashinfer for comparison
pip install flash_linear_attention==0.5.0 flashinfer-python==0.6.9
cd benchmark
python bench_gated_delta_rule.py
```
## Acknowledge
FlashQLA is inspired by [Flash Linear Attention](https://github.com/fla-org/flash-linear-attention), [TileLang](https://github.com/tile-ai/tilelang) and [FlashInfer](https://github.com/flashinfer-ai/flashinfer/) projects.
## License
FlashQLA is released under the MIT License.
## Citation
```bibtex
@misc{flashqla2025,
title={FlashQLA: Flash Qwen Linear Attention},
author={Zhang, Chengruidong and Lin, Xi and Jiang, Huiqiang and Wang, Zekun and Li, Xiao and Cao, Yizhong and Zhuang, Bohan and Men, Rui and Zhang, Jianwei and Zheng, Bo and Lin, Junyang and Liu, Dayiheng and Zhou, Jingren},
year={2026},
publisher={GitHub},
howpublished={\url{https://github.com/QwenLM/FlashQLA}},
}
```

View File

@@ -0,0 +1,585 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
#
# Benchmark Script for FlashQLA
import argparse
import math
import gc
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any
import torch
import torch.nn.functional as F
import tilelang
# Kernel Imports
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd as fla_fwd,
chunk_gated_delta_rule_bwd as fla_bwd,
)
from flash_qla import (
chunk_gated_delta_rule_fwd as qla_fwd,
chunk_gated_delta_rule_bwd as qla_bwd,
)
from flash_qla.utils import l2norm
try:
from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_fwd
HAS_FI = True
except ImportError:
HAS_FI = False
HEAD_DIM = 128
BWD_SPLIT_SIZE = 8
@dataclass
class ModelConfig:
label: str
h_qk: int
h_v: int
@dataclass
class SeqLenConfig:
label: str
seqlens: List[int]
def generate_rand_seqlens(batch_size, num_tokens):
bars = (
torch.sort(
torch.randperm(num_tokens - 1, device="cuda", dtype=torch.int32)[
: batch_size - 1
]
).values
+ 1
)
cu_seqlens = torch.nn.functional.pad(bars, (1, 1))
cu_seqlens[-1] = num_tokens
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return seqlens.tolist()
FWD_MODEL_CONFIGS = [
ModelConfig("397B/122B TP8", 2, 8),
ModelConfig("397B/122B TP4", 4, 16),
ModelConfig("397B/122B TP2", 8, 32),
ModelConfig("397B/122B TP1", 16, 64),
ModelConfig("35B/9B/4B TP1", 16, 32),
ModelConfig("27B TP2", 8, 24),
ModelConfig("27B TP1", 16, 48),
ModelConfig("2B/0.8B TP1", 16, 16),
ModelConfig("Sym h32", 32, 32),
]
FWD_SEQLEN_CONFIGS = [
SeqLenConfig("1x32768", [32768]),
SeqLenConfig("1x16384", [16384]),
SeqLenConfig("1x8192", [8192]),
SeqLenConfig("1x4096", [4096]),
SeqLenConfig("1x2048", [2048]),
SeqLenConfig("28672+4096", [28672, 4096]),
SeqLenConfig("24576+8192", [24576, 8192]),
SeqLenConfig("16384+16384", [16384, 16384]),
SeqLenConfig("8192+24576", [8192, 24576]),
SeqLenConfig("4096+28672", [4096, 28672]),
SeqLenConfig("12288+4096", [12288, 4096]),
SeqLenConfig("6144+2048", [6144, 2048]),
SeqLenConfig("4096+4096", [4096, 4096]),
SeqLenConfig("2048+6144", [2048, 6144]),
SeqLenConfig("1024+7168", [1024, 7168]),
SeqLenConfig("8192x4", [8192] * 4),
SeqLenConfig("4096x8", [4096] * 8),
SeqLenConfig("2048x4", [2048] * 4),
SeqLenConfig("1024x8", [1024] * 8),
]
BWD_MODEL_CONFIGS = [
ModelConfig("32", 32, 32),
ModelConfig("48", 48, 48),
ModelConfig("64", 64, 64),
]
BWD_SEQLEN_CONFIGS = [
SeqLenConfig("16k", generate_rand_seqlens(8, 16384)),
SeqLenConfig("32k", generate_rand_seqlens(8, 32768)),
SeqLenConfig("64k", generate_rand_seqlens(8, 65536)),
SeqLenConfig("128k", generate_rand_seqlens(8, 131072)),
SeqLenConfig("256k", generate_rand_seqlens(8, 262144)),
]
def cleanup_cuda():
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
except Exception:
pass
def get_lib_versions() -> Dict[str, str]:
"""Collects version strings for relevant libraries."""
versions = {}
# Torch
try:
versions["torch"] = torch.__version__
except Exception:
versions["torch"] = "N/A"
# Flash Linear Attention (FLA)
try:
import fla
if hasattr(fla, "__version__"):
versions["fla"] = fla.__version__
else:
versions["fla"] = "Installed (ver unknown)"
except ImportError:
versions["fla"] = "Not Installed"
# FlashInfer
try:
import flashinfer
if hasattr(flashinfer, "__version__"):
versions["flashinfer"] = flashinfer.__version__
elif hasattr(flashinfer, "version"):
versions["flashinfer"] = str(flashinfer.version)
else:
versions["flashinfer"] = "Installed (ver unknown)"
except ImportError:
versions["flashinfer"] = "Not Installed"
# TileLang
try:
import tilelang
if hasattr(tilelang, "__version__"):
versions["tilelang"] = tilelang.__version__
elif hasattr(tilelang, "version"):
versions["tilelang"] = str(tilelang.version)
else:
versions["tilelang"] = "Installed (ver unknown)"
except ImportError:
versions["tilelang"] = "Not Installed"
return versions
def prepare_tensors(
seqlens: List[int], h_qk: int, h_v: int, head_dim: int = HEAD_DIM
) -> Optional[Dict[str, Any]]:
device = "cuda"
num_seqs = len(seqlens)
total_tokens = sum(seqlens)
scale = head_dim ** (-0.5)
offsets = [0]
for s in seqlens:
offsets.append(offsets[-1] + s)
cu_seqlens = torch.tensor(offsets, dtype=torch.int32, device=device)
try:
q = l2norm(
torch.randn(
1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16
)
)
k = l2norm(
torch.randn(
1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16
)
)
v = torch.randn(
1, total_tokens, h_v, head_dim, device=device, dtype=torch.bfloat16
)
g = (
F.logsigmoid(
torch.randn(1, total_tokens, h_v, device=device, dtype=torch.float32)
)
/ 16
)
beta = torch.randn(
1, total_tokens, h_v, device=device, dtype=torch.float32
).sigmoid()
h0 = torch.randn(
num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32
)
do = torch.randn_like(v)
dht = (
torch.randn(
num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32
)
/ 8
)
except RuntimeError as e:
if "out of memory" in str(e).lower():
return None
raise e
swa_ratio = 0.75
swa_mask = torch.zeros(h_v, dtype=torch.bool, device=device)
swa_mask[: math.ceil(swa_ratio * h_v)] = True
swa_mask = swa_mask[torch.randperm(h_v, device=device)]
g[:, :, ~swa_mask] = 0.0
return {
"device": device,
"num_seqs": num_seqs,
"total_tokens": total_tokens,
"scale": scale,
"cu_seqlens": cu_seqlens,
"q": q,
"k": k,
"v": v,
"g": g,
"beta": beta,
"h0": h0,
"do": do,
"dht": dht,
}
def bench_fwd(
seqlens: List[int],
h_qk: int,
h_v: int,
head_dim: int = HEAD_DIM,
warmup: int = 10,
repeats: int = 5,
) -> Tuple[float, float, float]:
"""
Run Forward Pass Benchmark.
Returns: (qla_mean_ms, fi_mean_ms, fla_mean_ms)
"""
cleanup_cuda()
data = prepare_tensors(seqlens, h_qk, h_v, head_dim)
if data is None:
return float("nan"), float("nan"), float("nan")
q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"]
h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]
results = {}
def call_qla_fwd():
qla_fwd(
q,
k,
v,
g,
beta,
scale=scale,
initial_state=h0,
output_final_state=True,
output_h=False,
cu_seqlens=cu_seqlens,
auto_cp=True,
)
try:
mean = tilelang.profiler.do_bench(call_qla_fwd, warmup=warmup, rep=repeats)
results["flash_qla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FlashQLA Fwd failed: {e}")
cleanup_cuda()
results["flash_qla"] = float("nan")
if HAS_FI:
def call_fi_fwd():
fi_fwd(
q=q.view(-1, h_qk, head_dim),
k=k.view(-1, h_qk, head_dim),
v=v.view(-1, h_v, head_dim),
g=g.view(-1, h_v),
beta=beta.view(-1, h_v),
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
output_final_state=True,
)
try:
mean = tilelang.profiler.do_bench(call_fi_fwd, warmup=warmup, rep=repeats)
results["fi"] = mean
except RuntimeError as e:
print(f"\n[WARN] FI Fwd failed: {e}")
cleanup_cuda()
results["fi"] = float("nan")
else:
results["fi"] = float("nan")
def call_fla_fwd():
fla_fwd(
q,
k,
v,
g,
beta,
scale=scale,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens,
)
try:
mean = tilelang.profiler.do_bench(call_fla_fwd, warmup=warmup, rep=repeats)
results["fla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FLA Fwd failed: {e}")
cleanup_cuda()
results["fla"] = float("nan")
try:
torch.cuda.synchronize()
except Exception:
pass
return (
results.get("flash_qla", float("nan")),
results.get("fi", float("nan")),
results.get("fla", float("nan")),
)
def bench_bwd(
seqlens: List[int],
h_qk: int,
h_v: int,
head_dim: int = HEAD_DIM,
warmup: int = 10,
repeats: int = 100,
) -> Tuple[float, float]:
"""
Run Backward Pass Benchmark.
Returns: (qla_mean_ms, fla_mean_ms)
"""
unified_h = h_qk
cleanup_cuda()
data = prepare_tensors(seqlens, unified_h, unified_h, head_dim)
if data is None:
return float("nan"), float("nan")
q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"]
h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]
do, dht = data["do"], data["dht"]
g_cumsum = None
A = None
# Pre-run FWD to get intermediates
try:
result = qla_fwd(
q,
k,
v,
g,
beta,
scale=scale,
initial_state=h0,
output_final_state=True,
output_h=False,
cu_seqlens=cu_seqlens,
auto_cp=True,
)
if isinstance(result, tuple) and len(result) >= 2:
g_cumsum, A = result[0], result[1]
else:
raise RuntimeError("FlashQLA FWD did not return expected intermediates")
except RuntimeError as e:
print(f"[FWD Error] Failed at seqlens={seqlens}, heads={h_qk}. Error: {e}")
cleanup_cuda()
return float("nan"), float("nan")
results = {}
def call_qla_bwd():
return qla_bwd(
q,
k,
v,
g_cumsum,
beta,
A,
do,
dht,
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
)
try:
mean = tilelang.profiler.do_bench(call_qla_bwd, warmup=warmup, rep=repeats)
results["flash_qla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FlashQLA Bwd failed: {e}")
cleanup_cuda()
results["flash_qla"] = float("nan")
def call_fla_bwd():
return fla_bwd(q, k, v, g_cumsum, beta, A, scale, h0, do, dht, cu_seqlens)
try:
mean = tilelang.profiler.do_bench(call_fla_bwd, warmup=warmup, rep=repeats)
results["fla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FLA Bwd failed: {e}")
cleanup_cuda()
results["fla"] = float("nan")
return results.get("flash_qla", float("nan")), results.get("fla", float("nan"))
FWD_HDR = (
f"{'Model Config':<16} {'Seqlens':<17} {'h_qk':>5} {'h_v':>5} "
f"{'flash_qla [fwd]':>10} {'FI [fwd]':>10} {'FLA [fwd]':>10} "
f"{'vs FLA':>7} {'vs FI':>7}"
)
BWD_HDR = (
f"{'Heads':<8} {'SeqLen':<15} "
f"{'flash_qla [bwd]':>10} {'FLA [bwd]':>10} {'Speedup':>8}"
)
def fmt_time(ms: float) -> str:
if math.isnan(ms):
return " N/A "
return f"{ms:>8.3f}ms"
def fmt_ratio(base: float, other: float) -> str:
if math.isnan(base) or math.isnan(other) or base == 0:
return " N/A "
return f"{other / base:>6.2f}x"
def main():
parser = argparse.ArgumentParser(description="Benchmark FlashQLA Gated Delta Rule")
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--repeats", type=int, default=100)
parser.add_argument("--mode", choices=["fwd", "bwd", "all"], default="all")
parser.add_argument("--skip-fi", action="store_true")
args = parser.parse_args()
if not torch.cuda.is_available():
print("CUDA not available.")
return
global HAS_FI
if args.skip_fi:
HAS_FI = False
gpu_name = torch.cuda.get_device_properties(0).name
print(f"GPU: {gpu_name}")
print("Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128")
print(f"Config: Warmup={args.warmup}, Repeats={args.repeats}")
libs = get_lib_versions()
print("Library Versions:")
ver_str = " | ".join([f"{k}: {v}" for k, v in libs.items()])
print(f" {ver_str}")
print("=" * 110)
# Forward
if args.mode in ("fwd", "all"):
print("\n>>> FORWARD BENCHMARKS")
print(FWD_HDR)
print("-" * len(FWD_HDR))
prev_model = None
for cfg in FWD_MODEL_CONFIGS:
if prev_model is not None and cfg.label != prev_model:
print()
prev_model = cfg.label
for sl_cfg in FWD_SEQLEN_CONFIGS:
try:
qla_ms, fi_ms, fla_ms = bench_fwd(
sl_cfg.seqlens,
cfg.h_qk,
cfg.h_v,
warmup=args.warmup,
repeats=args.repeats,
)
if math.isnan(qla_ms) and math.isnan(fla_ms):
cleanup_cuda()
ratio_fla = fmt_ratio(qla_ms, fla_ms)
ratio_fi = fmt_ratio(qla_ms, fi_ms)
print(
f"{cfg.label:<16} {sl_cfg.label:<17} {cfg.h_qk:>5} {cfg.h_v:>5} "
f"{fmt_time(qla_ms)} {fmt_time(fi_ms)} {fmt_time(fla_ms)} "
f"{ratio_fla} {ratio_fi}",
flush=True,
)
except Exception as e:
print(
f"\n[ERROR] Forward Case Failed: {cfg.label} / {sl_cfg.label}"
)
print(f"Exception: {e}")
cleanup_cuda()
continue
cleanup_cuda()
# Backward
if args.mode in ("bwd", "all"):
print("\n" + "=" * 110)
print(f"\n>>> BACKWARD BENCHMARKS (Split seq into {BWD_SPLIT_SIZE} sequences)")
print(BWD_HDR)
print("-" * len(BWD_HDR))
prev_model = None
for cfg in BWD_MODEL_CONFIGS:
if prev_model is not None and cfg.label != prev_model:
print()
prev_model = cfg.label
for sl_cfg in BWD_SEQLEN_CONFIGS:
try:
qla_bwd_ms, fla_bwd_ms = bench_bwd(
sl_cfg.seqlens,
cfg.h_qk,
cfg.h_v,
warmup=args.warmup,
repeats=args.repeats,
)
if math.isnan(qla_bwd_ms) and math.isnan(fla_bwd_ms):
speedup = " Skip "
else:
speedup = fmt_ratio(qla_bwd_ms, fla_bwd_ms)
print(
f"{cfg.label:<8} {sl_cfg.label:<15} "
f"{fmt_time(qla_bwd_ms)} {fmt_time(fla_bwd_ms)} {speedup} ",
flush=True,
)
except Exception as e:
print(
f"\n[CRITICAL ERROR] Case Crashed: Heads={cfg.label}, TotalSeqLen={sl_cfg.label}"
)
print(f"Exception: {e}")
cleanup_cuda()
continue
cleanup_cuda()
print("\nBenchmark Finished.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,214 @@
GPU: NVIDIA H200
Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128
Config: Warmup=10, Repeats=100
Library Versions:
torch: 2.9.1+cu128 | fla: 0.5.0 | flashinfer: 0.6.9 | tilelang: 0.1.8
==============================================================================================================
>>> FORWARD BENCHMARKS
Model Config Seqlens h_qk h_v flash_qla [fwd] FI [fwd] FLA [fwd] vs FLA vs FI
------------------------------------------------------------------------------------------------------------
397B/122B TP8 1x32768 2 8 0.310ms 1.657ms 0.910ms 2.93x 5.34x
397B/122B TP8 1x16384 2 8 0.184ms 0.832ms 0.465ms 2.53x 4.52x
397B/122B TP8 1x8192 2 8 0.122ms 0.420ms 0.578ms 4.74x 3.45x
397B/122B TP8 1x4096 2 8 0.107ms 0.215ms 0.277ms 2.59x 2.01x
397B/122B TP8 1x2048 2 8 0.106ms 0.114ms 0.261ms 2.46x 1.07x
397B/122B TP8 28672+4096 2 8 0.304ms 1.449ms 0.841ms 2.77x 4.77x
397B/122B TP8 24576+8192 2 8 0.300ms 1.243ms 0.766ms 2.55x 4.15x
397B/122B TP8 16384+16384 2 8 0.292ms 0.831ms 0.623ms 2.13x 2.84x
397B/122B TP8 8192+24576 2 8 0.298ms 1.241ms 0.766ms 2.57x 4.16x
397B/122B TP8 4096+28672 2 8 0.304ms 1.447ms 0.841ms 2.76x 4.75x
397B/122B TP8 12288+4096 2 8 0.180ms 0.626ms 0.391ms 2.17x 3.48x
397B/122B TP8 6144+2048 2 8 0.118ms 0.317ms 0.263ms 2.23x 2.70x
397B/122B TP8 4096+4096 2 8 0.115ms 0.215ms 0.262ms 2.27x 1.86x
397B/122B TP8 2048+6144 2 8 0.118ms 0.317ms 0.276ms 2.35x 2.70x
397B/122B TP8 1024+7168 2 8 0.120ms 0.368ms 0.263ms 2.19x 3.06x
397B/122B TP8 8192x4 2 8 0.274ms 0.401ms 0.477ms 1.74x 1.46x
397B/122B TP8 4096x8 2 8 0.203ms 0.205ms 0.481ms 2.37x 1.01x
397B/122B TP8 2048x4 2 8 0.109ms 0.110ms 0.268ms 2.46x 1.01x
397B/122B TP8 1024x8 2 8 0.065ms 0.063ms 0.273ms 4.17x 0.95x
397B/122B TP4 1x32768 4 16 0.485ms 1.655ms 1.250ms 2.58x 3.41x
397B/122B TP4 1x16384 4 16 0.292ms 0.832ms 0.630ms 2.16x 2.85x
397B/122B TP4 1x8192 4 16 0.178ms 0.420ms 0.325ms 1.83x 2.37x
397B/122B TP4 1x4096 4 16 0.115ms 0.215ms 0.290ms 2.51x 1.86x
397B/122B TP4 1x2048 4 16 0.108ms 0.114ms 0.275ms 2.54x 1.05x
397B/122B TP4 28672+4096 4 16 0.477ms 1.440ms 1.178ms 2.47x 3.02x
397B/122B TP4 24576+8192 4 16 0.474ms 1.222ms 1.100ms 2.32x 2.58x
397B/122B TP4 16384+16384 4 16 0.469ms 0.789ms 0.947ms 2.02x 1.68x
397B/122B TP4 8192+24576 4 16 0.470ms 1.192ms 1.099ms 2.34x 2.53x
397B/122B TP4 4096+28672 4 16 0.474ms 1.394ms 1.175ms 2.48x 2.94x
397B/122B TP4 12288+4096 4 16 0.281ms 0.616ms 0.549ms 1.95x 2.19x
397B/122B TP4 6144+2048 4 16 0.175ms 0.312ms 0.292ms 1.67x 1.78x
397B/122B TP4 4096+4096 4 16 0.173ms 0.207ms 0.263ms 1.52x 1.19x
397B/122B TP4 2048+6144 4 16 0.175ms 0.306ms 0.294ms 1.67x 1.75x
397B/122B TP4 1024+7168 4 16 0.177ms 0.359ms 0.307ms 1.74x 2.03x
397B/122B TP4 8192x4 4 16 0.388ms 0.400ms 0.949ms 2.45x 1.03x
397B/122B TP4 4096x8 4 16 0.311ms 0.210ms 0.954ms 3.07x 0.68x
397B/122B TP4 2048x4 4 16 0.114ms 0.110ms 0.264ms 2.32x 0.97x
397B/122B TP4 1024x8 4 16 0.097ms 0.066ms 0.280ms 2.88x 0.68x
397B/122B TP2 1x32768 8 32 0.866ms 1.567ms 1.894ms 2.19x 1.81x
397B/122B TP2 1x16384 8 32 0.468ms 0.789ms 0.926ms 1.98x 1.69x
397B/122B TP2 1x8192 8 32 0.275ms 0.401ms 0.479ms 1.74x 1.46x
397B/122B TP2 1x4096 8 32 0.175ms 0.207ms 0.273ms 1.56x 1.18x
397B/122B TP2 1x2048 8 32 0.116ms 0.110ms 0.308ms 2.66x 0.96x
397B/122B TP2 28672+4096 8 32 0.841ms 1.374ms 1.893ms 2.25x 1.63x
397B/122B TP2 24576+8192 8 32 0.846ms 1.178ms 1.891ms 2.24x 1.39x
397B/122B TP2 16384+16384 8 32 0.752ms 0.789ms 1.897ms 2.52x 1.05x
397B/122B TP2 8192+24576 8 32 0.842ms 1.182ms 1.893ms 2.25x 1.40x
397B/122B TP2 4096+28672 8 32 0.843ms 1.378ms 1.899ms 2.25x 1.63x
397B/122B TP2 12288+4096 8 32 0.468ms 0.597ms 0.928ms 1.98x 1.28x
397B/122B TP2 6144+2048 8 32 0.271ms 0.305ms 0.506ms 1.87x 1.13x
397B/122B TP2 4096+4096 8 32 0.206ms 0.206ms 0.482ms 2.34x 1.00x
397B/122B TP2 2048+6144 8 32 0.269ms 0.305ms 0.483ms 1.79x 1.13x
397B/122B TP2 1024+7168 8 32 0.277ms 0.354ms 0.480ms 1.73x 1.27x
397B/122B TP2 8192x4 8 32 0.596ms 0.405ms 1.894ms 3.18x 0.68x
397B/122B TP2 4096x8 8 32 0.601ms 0.409ms 1.904ms 3.17x 0.68x
397B/122B TP2 2048x4 8 32 0.172ms 0.115ms 0.486ms 2.83x 0.67x
397B/122B TP2 1024x8 8 32 0.180ms 0.124ms 0.504ms 2.80x 0.69x
397B/122B TP1 1x32768 16 64 1.494ms 1.563ms 3.336ms 2.23x 1.05x
397B/122B TP1 1x16384 16 64 0.762ms 0.790ms 1.608ms 2.11x 1.04x
397B/122B TP1 1x8192 16 64 0.393ms 0.397ms 0.815ms 2.07x 1.01x
397B/122B TP1 1x4096 16 64 0.209ms 0.206ms 0.430ms 2.06x 0.99x
397B/122B TP1 1x2048 16 64 0.117ms 0.111ms 0.276ms 2.37x 0.95x
397B/122B TP1 28672+4096 16 64 1.719ms 1.374ms 3.341ms 1.94x 0.80x
397B/122B TP1 24576+8192 16 64 1.530ms 1.179ms 3.341ms 2.18x 0.77x
397B/122B TP1 16384+16384 16 64 1.168ms 0.804ms 3.367ms 2.88x 0.69x
397B/122B TP1 8192+24576 16 64 1.539ms 1.232ms 3.352ms 2.18x 0.80x
397B/122B TP1 4096+28672 16 64 1.723ms 1.453ms 3.353ms 1.95x 0.84x
397B/122B TP1 12288+4096 16 64 0.781ms 0.596ms 1.606ms 2.06x 0.76x
397B/122B TP1 6144+2048 16 64 0.406ms 0.305ms 0.822ms 2.02x 0.75x
397B/122B TP1 4096+4096 16 64 0.318ms 0.209ms 0.823ms 2.59x 0.66x
397B/122B TP1 2048+6144 16 64 0.406ms 0.317ms 0.825ms 2.03x 0.78x
397B/122B TP1 1024+7168 16 64 0.454ms 0.371ms 0.825ms 1.82x 0.82x
397B/122B TP1 8192x4 16 64 1.170ms 0.800ms 3.361ms 2.87x 0.68x
397B/122B TP1 4096x8 16 64 1.191ms 0.819ms 3.411ms 2.86x 0.69x
397B/122B TP1 2048x4 16 64 0.324ms 0.219ms 0.837ms 2.58x 0.67x
397B/122B TP1 1024x8 16 64 0.344ms 0.239ms 0.868ms 2.52x 0.70x
35B/9B/4B TP1 1x32768 16 32 0.871ms 1.567ms 1.908ms 2.19x 1.80x
35B/9B/4B TP1 1x16384 16 32 0.472ms 0.790ms 0.945ms 2.00x 1.67x
35B/9B/4B TP1 1x8192 16 32 0.278ms 0.402ms 0.484ms 1.74x 1.44x
35B/9B/4B TP1 1x4096 16 32 0.176ms 0.207ms 0.266ms 1.51x 1.17x
35B/9B/4B TP1 1x2048 16 32 0.118ms 0.111ms 0.265ms 2.24x 0.94x
35B/9B/4B TP1 28672+4096 16 32 0.860ms 1.374ms 1.904ms 2.21x 1.60x
35B/9B/4B TP1 24576+8192 16 32 0.857ms 1.180ms 1.904ms 2.22x 1.38x
35B/9B/4B TP1 16384+16384 16 32 0.762ms 0.796ms 1.919ms 2.52x 1.05x
35B/9B/4B TP1 8192+24576 16 32 0.854ms 1.184ms 1.909ms 2.24x 1.39x
35B/9B/4B TP1 4096+28672 16 32 0.853ms 1.379ms 1.911ms 2.24x 1.62x
35B/9B/4B TP1 12288+4096 16 32 0.474ms 0.597ms 0.946ms 2.00x 1.26x
35B/9B/4B TP1 6144+2048 16 32 0.270ms 0.304ms 0.488ms 1.81x 1.13x
35B/9B/4B TP1 4096+4096 16 32 0.208ms 0.206ms 0.489ms 2.35x 0.99x
35B/9B/4B TP1 2048+6144 16 32 0.272ms 0.304ms 0.487ms 1.79x 1.12x
35B/9B/4B TP1 1024+7168 16 32 0.279ms 0.354ms 0.488ms 1.75x 1.27x
35B/9B/4B TP1 8192x4 16 32 0.596ms 0.406ms 1.903ms 3.19x 0.68x
35B/9B/4B TP1 4096x8 16 32 0.607ms 0.412ms 1.918ms 3.16x 0.68x
35B/9B/4B TP1 2048x4 16 32 0.171ms 0.116ms 0.495ms 2.89x 0.68x
35B/9B/4B TP1 1024x8 16 32 0.184ms 0.125ms 0.513ms 2.78x 0.68x
27B TP2 1x32768 8 24 0.657ms 1.631ms 1.576ms 2.40x 2.48x
27B TP2 1x16384 8 24 0.397ms 0.820ms 0.778ms 1.96x 2.06x
27B TP2 1x8192 8 24 0.262ms 0.416ms 0.402ms 1.53x 1.58x
27B TP2 1x4096 8 24 0.170ms 0.213ms 0.267ms 1.57x 1.25x
27B TP2 1x2048 8 24 0.111ms 0.113ms 0.266ms 2.40x 1.02x
27B TP2 28672+4096 8 24 0.655ms 1.422ms 1.495ms 2.28x 2.17x
27B TP2 24576+8192 8 24 0.651ms 1.211ms 1.425ms 2.19x 1.86x
27B TP2 16384+16384 8 24 0.657ms 0.791ms 1.579ms 2.40x 1.20x
27B TP2 8192+24576 8 24 0.656ms 1.185ms 1.576ms 2.40x 1.81x
27B TP2 4096+28672 8 24 0.655ms 1.382ms 1.577ms 2.41x 2.11x
27B TP2 12288+4096 8 24 0.392ms 0.611ms 0.707ms 1.80x 1.56x
27B TP2 6144+2048 8 24 0.258ms 0.310ms 0.372ms 1.44x 1.20x
27B TP2 4096+4096 8 24 0.189ms 0.206ms 0.406ms 2.15x 1.09x
27B TP2 2048+6144 8 24 0.261ms 0.303ms 0.406ms 1.56x 1.16x
27B TP2 1024+7168 8 24 0.261ms 0.352ms 0.407ms 1.56x 1.35x
27B TP2 8192x4 8 24 0.545ms 0.401ms 1.419ms 2.61x 0.74x
27B TP2 4096x8 8 24 0.546ms 0.410ms 1.430ms 2.62x 0.75x
27B TP2 2048x4 8 24 0.155ms 0.112ms 0.374ms 2.41x 0.73x
27B TP2 1024x8 8 24 0.164ms 0.120ms 0.383ms 2.34x 0.74x
27B TP1 1x32768 16 48 1.243ms 1.573ms 2.728ms 2.19x 1.27x
27B TP1 1x16384 16 48 0.664ms 0.792ms 1.327ms 2.00x 1.19x
27B TP1 1x8192 16 48 0.402ms 0.401ms 0.677ms 1.69x 1.00x
27B TP1 1x4096 16 48 0.193ms 0.206ms 0.355ms 1.84x 1.07x
27B TP1 1x2048 16 48 0.107ms 0.110ms 0.264ms 2.47x 1.03x
27B TP1 28672+4096 16 48 1.231ms 1.371ms 2.618ms 2.13x 1.11x
27B TP1 24576+8192 16 48 1.434ms 1.185ms 2.541ms 1.77x 0.83x
27B TP1 16384+16384 16 48 1.074ms 0.787ms 2.744ms 2.56x 0.73x
27B TP1 8192+24576 16 48 1.435ms 1.186ms 2.741ms 1.91x 0.83x
27B TP1 4096+28672 16 48 1.230ms 1.371ms 2.731ms 2.22x 1.11x
27B TP1 12288+4096 16 48 0.733ms 0.598ms 1.230ms 1.68x 0.82x
27B TP1 6144+2048 16 48 0.379ms 0.305ms 0.638ms 1.68x 0.80x
27B TP1 4096+4096 16 48 0.290ms 0.207ms 0.682ms 2.35x 0.71x
27B TP1 2048+6144 16 48 0.379ms 0.306ms 0.689ms 1.82x 0.81x
27B TP1 1024+7168 16 48 0.425ms 0.355ms 0.688ms 1.62x 0.84x
27B TP1 8192x4 16 48 1.068ms 0.804ms 2.530ms 2.37x 0.75x
27B TP1 4096x8 16 48 0.899ms 0.611ms 2.571ms 2.86x 0.68x
27B TP1 2048x4 16 48 0.296ms 0.218ms 0.645ms 2.18x 0.74x
27B TP1 1024x8 16 48 0.268ms 0.180ms 0.659ms 2.46x 0.67x
2B/0.8B TP1 1x32768 16 16 0.492ms 1.654ms 1.288ms 2.62x 3.36x
2B/0.8B TP1 1x16384 16 16 0.289ms 0.831ms 0.655ms 2.27x 2.88x
2B/0.8B TP1 1x8192 16 16 0.181ms 0.421ms 0.341ms 1.88x 2.32x
2B/0.8B TP1 1x4096 16 16 0.119ms 0.215ms 0.294ms 2.47x 1.81x
2B/0.8B TP1 1x2048 16 16 0.109ms 0.114ms 0.282ms 2.58x 1.05x
2B/0.8B TP1 28672+4096 16 16 0.484ms 1.440ms 1.215ms 2.51x 2.98x
2B/0.8B TP1 24576+8192 16 16 0.478ms 1.222ms 1.140ms 2.38x 2.55x
2B/0.8B TP1 16384+16384 16 16 0.474ms 0.789ms 0.982ms 2.07x 1.67x
2B/0.8B TP1 8192+24576 16 16 0.478ms 1.191ms 1.138ms 2.38x 2.49x
2B/0.8B TP1 4096+28672 16 16 0.481ms 1.392ms 1.217ms 2.53x 2.89x
2B/0.8B TP1 12288+4096 16 16 0.287ms 0.618ms 0.581ms 2.02x 2.15x
2B/0.8B TP1 6144+2048 16 16 0.179ms 0.312ms 0.306ms 1.71x 1.74x
2B/0.8B TP1 4096+4096 16 16 0.177ms 0.207ms 0.292ms 1.65x 1.17x
2B/0.8B TP1 2048+6144 16 16 0.180ms 0.306ms 0.304ms 1.69x 1.70x
2B/0.8B TP1 1024+7168 16 16 0.181ms 0.358ms 0.324ms 1.79x 1.98x
2B/0.8B TP1 8192x4 16 16 0.389ms 0.399ms 0.984ms 2.53x 1.03x
2B/0.8B TP1 4096x8 16 16 0.318ms 0.214ms 0.992ms 3.12x 0.67x
2B/0.8B TP1 2048x4 16 16 0.118ms 0.111ms 0.283ms 2.41x 0.95x
2B/0.8B TP1 1024x8 16 16 0.100ms 0.069ms 0.321ms 3.21x 0.69x
Sym h32 1x32768 32 32 0.874ms 1.567ms 1.961ms 2.24x 1.79x
Sym h32 1x16384 32 32 0.474ms 0.790ms 0.981ms 2.07x 1.66x
Sym h32 1x8192 32 32 0.281ms 0.402ms 0.508ms 1.81x 1.43x
Sym h32 1x4096 32 32 0.177ms 0.208ms 0.271ms 1.52x 1.17x
Sym h32 1x2048 32 32 0.117ms 0.111ms 0.296ms 2.54x 0.96x
Sym h32 28672+4096 32 32 0.861ms 1.376ms 1.964ms 2.28x 1.60x
Sym h32 24576+8192 32 32 0.857ms 1.181ms 1.964ms 2.29x 1.38x
Sym h32 16384+16384 32 32 0.771ms 0.808ms 1.974ms 2.56x 1.05x
Sym h32 8192+24576 32 32 0.855ms 1.183ms 1.966ms 2.30x 1.38x
Sym h32 4096+28672 32 32 0.857ms 1.379ms 1.965ms 2.29x 1.61x
Sym h32 12288+4096 32 32 0.475ms 0.597ms 0.985ms 2.07x 1.26x
Sym h32 6144+2048 32 32 0.270ms 0.305ms 0.512ms 1.90x 1.13x
Sym h32 4096+4096 32 32 0.215ms 0.210ms 0.513ms 2.38x 0.97x
Sym h32 2048+6144 32 32 0.271ms 0.304ms 0.510ms 1.88x 1.12x
Sym h32 1024+7168 32 32 0.281ms 0.353ms 0.510ms 1.81x 1.26x
Sym h32 8192x4 32 32 0.600ms 0.419ms 1.981ms 3.30x 0.70x
Sym h32 4096x8 32 32 0.620ms 0.426ms 1.994ms 3.22x 0.69x
Sym h32 2048x4 32 32 0.176ms 0.119ms 0.519ms 2.95x 0.67x
Sym h32 1024x8 32 32 0.183ms 0.128ms 0.535ms 2.93x 0.70x
==============================================================================================================
>>> BACKWARD BENCHMARKS (Split seq into 8 sequences)
Heads SeqLen flash_qla [bwd] FLA [bwd] Speedup
---------------------------------------------------------------
32 16k 1.262ms 2.433ms 1.93x
32 32k 2.743ms 4.904ms 1.79x
32 64k 6.554ms 9.247ms 1.41x
32 128k 9.489ms 18.195ms 1.92x
32 256k 20.939ms 38.613ms 1.84x
48 16k 1.775ms 3.549ms 2.00x
48 32k 3.715ms 7.197ms 1.94x
48 64k 7.565ms 13.864ms 1.83x
48 128k 13.777ms 28.175ms 2.05x
48 256k 31.321ms 56.344ms 1.80x
64 16k 2.213ms 4.654ms 2.10x
64 32k 4.513ms 9.237ms 2.05x
64 64k 8.024ms 18.265ms 2.28x
64 128k 15.299ms 36.572ms 2.39x
64 256k 34.310ms 74.392ms 2.17x
Benchmark Finished.

16
flash_qla/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
__version__ = "0.1.0"
from flash_qla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd,
chunk_gated_delta_rule_bwd,
chunk_gated_delta_rule,
)
__all__ = [
"chunk_gated_delta_rule_fwd",
"chunk_gated_delta_rule_bwd",
"chunk_gated_delta_rule",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .gated_delta_rule import chunk_gated_delta_rule
__all__ = ["chunk_gated_delta_rule"]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .chunk import chunk_gated_delta_rule
__all__ = ["chunk_gated_delta_rule"]

View File

@@ -0,0 +1,237 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
from flash_qla.utils import l2norm
from flash_qla.ops.utils import chunk_local_cumsum, group_reduce_vector
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
from .hopper import fused_gdr_fwd, fused_gdr_bwd, fused_gdr_h, kkt_solve
else:
raise ValueError("FlashQLA now support sm90 only.")
from .cp_context import intra_card_cp_preprocess
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
output_final_state: bool = True,
output_h: bool = False,
auto_cp: bool = True,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = kkt_solve(
k=k,
b=beta,
cu_seqlens=cu_seqlens,
)
if auto_cp:
initial_state, cu_seqlens, cp_seq_map, raw_cu_seqlens = (
intra_card_cp_preprocess(
k=k,
v=v,
a=A,
g=g,
b=beta,
raw_h0=initial_state,
raw_cu_seqlens=cu_seqlens,
)
)
else:
cp_seq_map = None
raw_cu_seqlens = None
o, h, final_state = fused_gdr_fwd(
q=q,
k=k,
v=v,
a=A,
g=g,
b=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
output_h=output_h,
output_o=True,
cu_seqlens=cu_seqlens,
cp_seq_map=cp_seq_map,
raw_cu_seqlens=raw_cu_seqlens,
)
return g, A, o, h, final_state
def chunk_gated_delta_rule_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor | None = None,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
):
h, _, _ = fused_gdr_h(
k=k,
v=v,
a=A,
g=g,
b=beta,
initial_state=initial_state,
output_final_state=False,
output_h=True,
cu_seqlens=cu_seqlens,
)
dq, dk, dv, dg, db, dh0 = fused_gdr_bwd(
q=q,
k=k,
v=v,
a=A,
g=g,
b=beta,
do=do,
dht=dht,
h=h,
scale=scale,
cu_seqlens=cu_seqlens,
)
Hg, H = k.shape[-2], v.shape[-2]
if Hg < H:
dq = group_reduce_vector(dq, Hg)
dk = group_reduce_vector(dk, Hg)
assert dg.dtype == torch.float32, "dg should be fp32"
dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens)
return dq, dk, dv, db, dg, dh0
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None,
):
q_orig = q
k_orig = k
g, A, o, _, final_state = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
output_h=False,
cu_seqlens=cu_seqlens,
)
ctx.save_for_backward(q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens)
ctx.scale = scale
return o.to(q.dtype), final_state
@staticmethod
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
q=q_orig,
k=k_orig,
v=v,
g=g,
beta=beta,
A=A,
do=do,
dht=dht,
scale=ctx.scale,
initial_state=initial_state,
cu_seqlens=cu_seqlens,
)
return (
dq.to(q_orig),
dk.to(k_orig),
dv.to(v),
dg.to(g),
db.to(beta),
None,
dh0,
None,
None,
)
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
):
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, (
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16 or float16."
)
assert not head_first, "head_first=True is not supported."
assert v.shape[2] % k.shape[2] == 0, (
"num_qk_heads must be divisible to num_v_heads."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
if use_qk_l2norm_in_kernel:
q = l2norm(q)
k = l2norm(k)
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
)
return o, final_state

View File

@@ -0,0 +1,163 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
import tilelang
from flash_qla.utils import tensor_cache
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
from .hopper import get_warmup_chunks, fused_gdr_h, correct_initial_states
else:
raise ValueError("FlashQLA now support sm90 only.")
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
@tensor_cache
def _create_cu_seqlens(
batch_size: int,
num_tokens: int,
device_idx: int,
):
return (
torch.arange((batch_size + 1), dtype=torch.int32, device=f"cuda:{device_idx}")
* num_tokens
)
@tensor_cache
def _calc_cp_seqs(
raw_cu_seqlens: torch.LongTensor,
chunk_size: int,
num_v_heads: int,
):
# TODO: tilelang kernel
device = raw_cu_seqlens.device
seqlen_dtype = raw_cu_seqlens.dtype
raw_cu_seqlens = raw_cu_seqlens.tolist()
raw_batch_size = len(raw_cu_seqlens) - 1
seqlens = [raw_cu_seqlens[i + 1] - raw_cu_seqlens[i] for i in range(raw_batch_size)]
num_chunks = [tilelang.cdiv(x, chunk_size) for x in seqlens]
# autocp
H = num_v_heads
# Latency model: T = a·L_cp + b·(B·H·Lc/P) / L_cp + c
# Minimizing T yields the theoretical optimum: L_cp* ∝ √(B·H·Lc / P), where P = MULTI_PROCESSOR_COUNT, L_cp = max_local_chunks
# Scaled by empirical factor (3) and aligned to the nearest power of 2 for optimal SM scheduling & memory alignment.
max_local_chunks = 2 ** round(
math.log2(math.sqrt(H * sum(num_chunks) / MULTI_PROCESSOR_COUNT) * 3)
)
# Set min to 4 to ensure multi-stage pipelining in fused_gdr;
max_local_chunks = max(max_local_chunks, 4)
use_cp = False
cp_cu_seqlens = []
ht_mask = []
seq_map_c2r = []
seq_map_r2c = [0]
max_local_tokens = max_local_chunks * chunk_size
for i, c in enumerate(num_chunks):
s = raw_cu_seqlens[i]
e = raw_cu_seqlens[i + 1]
if c > max_local_chunks:
while s < e:
cp_cu_seqlens.append(s)
ht_mask.append(False)
seq_map_c2r.append(i)
s += max_local_tokens
ht_mask[-1] = True
else:
cp_cu_seqlens.append(s)
ht_mask.append(True)
seq_map_c2r.append(i)
seq_map_r2c.append(len(cp_cu_seqlens))
cp_cu_seqlens.append(raw_cu_seqlens[-1])
# Disable CP when B * H naturally saturates SM occupancy.
# For varlen inputs, use `total_chunks / max_seq_chunks` as effective B,
# since CP helps accelerate highly uneven sequence lengths.
Be = sum(num_chunks) / max(num_chunks)
use_cp = Be * H <= 40 or (Be * H <= 56 and max(num_chunks) >= 128)
if use_cp:
cp_cu_seqlens = torch.tensor(
cp_cu_seqlens, dtype=seqlen_dtype, device=device, requires_grad=False
)
seq_map_c2r = torch.tensor(seq_map_c2r, dtype=seqlen_dtype, device=device)
seq_map_r2c = torch.tensor(
seq_map_r2c, dtype=seqlen_dtype, device=device, requires_grad=False
)
ht_mask = torch.tensor(
ht_mask, dtype=torch.bool, device=device, requires_grad=False
)
else:
cp_cu_seqlens, seq_map_r2c, ht_mask = None, None, None
return use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask
def intra_card_cp_preprocess(
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
raw_h0: torch.Tensor,
raw_cu_seqlens: torch.Tensor,
warmup_threshold: float = -10.0,
):
batch_size, num_tokens, num_k_heads, k_head_dim = k.shape
_, _, num_v_heads, v_head_dim = v.shape
chunk_size = a.shape[-1]
device = k.device
if batch_size > 1:
return raw_h0, raw_cu_seqlens, None, None
if raw_cu_seqlens is None:
raw_cu_seqlens = _create_cu_seqlens(batch_size, num_tokens, device.index)
use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask = _calc_cp_seqs(
raw_cu_seqlens,
chunk_size,
num_v_heads,
)
if not use_cp:
return raw_h0, raw_cu_seqlens, None, None
num_warmup_chunks, fallback_mask = get_warmup_chunks(
g=g,
cu_seqlens=cp_cu_seqlens,
ht_mask=ht_mask,
chunk_size=chunk_size,
threshold=warmup_threshold,
) # [cp_batch_size, num_v_heads]
_, ht, mt = fused_gdr_h(
k=k,
v=v,
a=a,
g=g,
b=b,
initial_state=None,
output_final_state=True,
output_h=False,
cu_seqlens=cp_cu_seqlens,
num_warmup_chunks=num_warmup_chunks,
) # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
cp_h0 = correct_initial_states(
raw_h0=raw_h0,
ht_buffer=ht,
mt_buffer=mt,
fallback_mask=fallback_mask,
seq_map_r2c=seq_map_r2c,
)
return cp_h0, cp_cu_seqlens, seq_map_c2r, raw_cu_seqlens

View File

@@ -0,0 +1,18 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .fused_fwd import fused_gdr_fwd
from .fused_bwd import fused_gdr_bwd
from .prepare_h import fused_gdr_h
from .kkt_solve import kkt_solve
from .cp_fwd import get_warmup_chunks, correct_initial_states
__all__ = [
"fused_gdr_fwd",
"fused_gdr_bwd",
"fused_gdr_h",
"kkt_solve",
"get_warmup_chunks",
"correct_initial_states",
]

View File

@@ -0,0 +1,309 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
@tilelang.jit()
def tilelang_get_warmup_chunks(
num_heads,
chunk_size,
threshold,
accum_dtype,
g_dtype,
mask_dtype,
seqlen_dtype,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_threads = tilelang.cdiv(num_heads, 32) * 32
@T.prim_func
def tilelang_get_warmup_chunks_kernel(
g: T.Tensor([1, num_tokens, num_heads], dtype=g_dtype),
ht_mask: T.Tensor([batch_size], dtype=mask_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
num_warmup_chunks: T.Tensor([batch_size, num_heads], dtype=seqlen_dtype),
fallback_mask: T.Tensor([batch_size, num_heads], dtype=mask_dtype),
):
with T.Kernel(batch_size, threads=num_threads) as (bb,):
if ht_mask[bb]:
for i_h in T.Parallel(num_heads):
num_warmup_chunks[bb, i_h] = 0
else:
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
num_iters = T.alloc_var("int32")
seq_start_idx = cu_seqlens[bb]
seq_end_idx = cu_seqlens[bb + 1]
num_iters = (seq_end_idx - seq_start_idx) // chunk_size
g_fragment = T.alloc_fragment((num_heads), dtype=accum_dtype)
g_cumsum = T.alloc_fragment((num_heads), dtype=accum_dtype)
n_fragment = T.alloc_fragment((num_heads), dtype=seqlen_dtype)
f_fragment = T.alloc_fragment((num_heads), dtype=mask_dtype)
T.clear(g_cumsum)
T.fill(n_fragment, num_iters)
T.fill(f_fragment, True)
for i_s in T.serial(num_iters):
for i_h in T.Parallel(num_heads):
g_fragment[i_h] = g[0, seq_end_idx - i_s * chunk_size - 1, i_h]
for i_h in T.Parallel(num_heads):
g_cumsum[i_h] += g_fragment[i_h]
for i_h in T.Parallel(num_heads):
if g_cumsum[i_h] < threshold and n_fragment[i_h] == num_iters:
n_fragment[i_h] = i_s + 1
f_fragment[i_h] = False
for i_h in T.Parallel(num_heads):
num_warmup_chunks[bb, i_h] = n_fragment[i_h]
for i_h in T.Parallel(num_heads):
fallback_mask[bb, i_h] = f_fragment[i_h]
return tilelang_get_warmup_chunks_kernel
def get_warmup_chunks(
g: torch.Tensor, # [1, num_total_tokens, num_v_heads]
cu_seqlens: torch.Tensor, # [cp_real_batch_size + 1]
ht_mask: torch.Tensor, # [cp_real_batch_size]
chunk_size: int = 64,
threshold: float = -10.0,
):
batch_size, num_tokens, num_heads = g.shape
real_batch_size = ht_mask.shape[0]
assert cu_seqlens.shape[0] == real_batch_size + 1
assert batch_size == 1
assert chunk_size == 64
tilelang_get_warmup_chunks_kernel = tilelang_get_warmup_chunks(
num_heads=num_heads,
chunk_size=chunk_size,
threshold=threshold,
accum_dtype="float32",
g_dtype=g.dtype,
mask_dtype=ht_mask.dtype,
seqlen_dtype=cu_seqlens.dtype,
)
num_warmup_chunks = torch.empty(
[real_batch_size, num_heads], dtype=cu_seqlens.dtype, device=cu_seqlens.device
)
fallback_mask = torch.empty(
[real_batch_size, num_heads], dtype=ht_mask.dtype, device=cu_seqlens.device
)
tilelang_get_warmup_chunks_kernel(
g, ht_mask, cu_seqlens, num_warmup_chunks, fallback_mask
)
return num_warmup_chunks, fallback_mask
@tilelang.jit()
def tilelang_correct_h0(
H,
DK,
DV,
res_dtype,
accum_dtype,
buffer_dtype,
seqlen_dtype,
mask_dtype,
use_raw_h0,
block_DV: int = 32,
):
cp_batch_size = T.dynamic("cp_batch_size")
raw_batch_size = T.dynamic("raw_batch_size")
@T.macro
def kernel_body(
bb,
bh,
bv,
seq_start_idx,
seq_end_idx,
num_iters,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
h_fragment,
):
h_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
hd_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
m_shared = T.alloc_shared((DK, DK), dtype=buffer_dtype)
T.copy(
h_fragment,
cp_h0[seq_start_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
)
for i_s in T.Pipelined(num_iters - 1, num_stages=2):
if fallback_mask[seq_start_idx + i_s, bh]:
T.copy(h_fragment, hd_shared)
T.copy(
ht_buffer[
seq_start_idx + i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
],
h_shared,
)
T.copy(h_shared, h_fragment)
if fallback_mask[seq_start_idx + i_s, bh]:
T.copy(mt_buffer[seq_start_idx + i_s, bh, 0:DK, 0:DK], m_shared)
T.gemm(m_shared, hd_shared, h_fragment, clear_accum=False)
T.copy(
h_fragment,
cp_h0[
seq_start_idx + i_s + 1,
bh,
0:DK,
bv * block_DV : (bv + 1) * block_DV,
],
)
if use_raw_h0:
@T.prim_func
def tilelang_correct_h0_kernel(
raw_h0: T.Tensor([raw_batch_size, H, DK, DV], dtype=res_dtype),
ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
) as (bbhv,):
bbh, bv = (
bbhv // T.ceildiv(DV, block_DV),
bbhv % T.ceildiv(DV, block_DV),
)
bb, bh = bbh // H, bbh % H
seq_start_idx = seq_map_r2c[bb]
seq_end_idx = seq_map_r2c[bb + 1]
num_iters = seq_end_idx - seq_start_idx
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
T.copy(
raw_h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
h_fragment,
)
kernel_body(
bb,
bh,
bv,
seq_start_idx,
seq_end_idx,
num_iters,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
h_fragment,
)
else:
@T.prim_func
def tilelang_correct_h0_kernel(
ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
) as (bbhv,):
bbh, bv = (
bbhv // T.ceildiv(DV, block_DV),
bbhv % T.ceildiv(DV, block_DV),
)
bb, bh = bbh // H, bbh % H
seq_start_idx = seq_map_r2c[bb]
seq_end_idx = seq_map_r2c[bb + 1]
num_iters = seq_end_idx - seq_start_idx
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
T.clear(h_fragment)
kernel_body(
bb,
bh,
bv,
seq_start_idx,
seq_end_idx,
num_iters,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
h_fragment,
)
return tilelang_correct_h0_kernel
def correct_initial_states(
raw_h0: torch.Tensor
| None, # [raw_batch_size, num_v_heads, k_head_dim, v_head_dim]
ht_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
mt_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, k_head_dim]
fallback_mask: torch.Tensor, # [cp_batch_size, num_v_heads]
seq_map_r2c: torch.Tensor, # [raw_batch_size + 1]
):
cp_batch_size = fallback_mask.shape[0]
_, num_heads, k_head_dim, v_head_dim = ht_buffer.shape
assert k_head_dim == v_head_dim == 128
if raw_h0 is None:
res_dtype = torch.float32
use_raw_h0 = False
else:
res_dtype = raw_h0.dtype
use_raw_h0 = True
tilelang_correct_h0_kernel = tilelang_correct_h0(
H=num_heads,
DK=k_head_dim,
DV=v_head_dim,
res_dtype=res_dtype,
accum_dtype="float32",
buffer_dtype=ht_buffer.dtype,
seqlen_dtype=seq_map_r2c.dtype,
mask_dtype=fallback_mask.dtype,
use_raw_h0=use_raw_h0,
)
cp_h0 = torch.empty(
(cp_batch_size, num_heads, k_head_dim, v_head_dim),
dtype=res_dtype,
device=ht_buffer.device,
)
if use_raw_h0:
tilelang_correct_h0_kernel(
raw_h0,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
)
else:
tilelang_correct_h0_kernel(
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
)
return cp_h0

View File

@@ -0,0 +1,985 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
@tilelang.jit(
# out_idx=[-5, -4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_DATA_RACE_CHECK: True,
},
)
def tilelang_fused_chunk_gdr_bwd(
H,
Hg,
DK,
DV,
chunk_size,
scale,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h_dtype,
o_dtype,
seqlen_dtype,
is_varlen,
use_dht,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
if is_varlen:
q_shape = (1, num_tokens, Hg, DK)
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
o_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
q_shape = (batch_size, num_tokens, Hg, DK)
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
o_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
h0_shape = (batch_size, H, DK, DV)
ht_shape = (batch_size, H, DK, DV)
@T.prim_func
def tilelang_fused_chunk_gdr_bwd_kernel(
do: T.Tensor(o_shape, dtype=o_dtype),
dht: T.Tensor(ht_shape, dtype=accum_dtype),
q: T.Tensor(q_shape, dtype=qkva_dtype),
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
dq: T.Tensor(v_shape, dtype=qkva_dtype),
dk: T.Tensor(v_shape, dtype=qkva_dtype),
dv: T.Tensor(v_shape, dtype=qkva_dtype),
dg: T.Tensor(g_shape, dtype=g_dtype),
db: T.Tensor(b_shape, dtype=b_dtype),
dh0: T.Tensor(h0_shape, dtype=accum_dtype),
):
with T.Kernel(batch_size * H, threads=512) as (bbh,):
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
num_iters = T.alloc_var("int32")
num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
# 2+2+2+2 + 1 + 4 = 13 units
do_shared = T.alloc_shared((block_S, DV), dtype=o_dtype)
q_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
h_shared = T.alloc_shared((DK, DV), dtype=h_dtype)
g_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
# 2 units
dqkv_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
dg_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
db_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
# 1+1 + 2+2+2 + 4 = 12 units
tmp_shared_1_1 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_1_2 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_1_3 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_2_1 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_2_2 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_2_3 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_4_1 = T.alloc_shared((DK, DV), dtype=qkva_dtype)
# CONSUMER_K
dk_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dv_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
odot_fragment_1 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dg_fragment_1 = T.alloc_fragment((block_S), dtype=accum_dtype)
dg_last_local_1 = T.alloc_fragment((1), dtype=accum_dtype)
# CONSUMER_A
mask_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dp_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
da_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
u_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dq_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
db_fragment = T.alloc_fragment((block_S), dtype=accum_dtype)
odot_fragment_2 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dg_fragment_2 = T.alloc_fragment((block_S), dtype=accum_dtype)
# CONSUMER_S
dh_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
_odot_fragment_3 = T.alloc_fragment((DK, DV), dtype=accum_dtype)
reduce_fragment = T.alloc_fragment((128, 2), dtype=accum_dtype)
dg_last_local_3 = T.alloc_fragment((1), dtype=accum_dtype)
g_last_local_3 = T.alloc_local((1), dtype=accum_dtype)
# 16 stages
bar_00 = T.alloc_barrier(arrive_count=448)
bar_01 = T.alloc_barrier(arrive_count=384)
bar_02 = T.alloc_barrier(arrive_count=288)
bar_03 = T.alloc_barrier(arrive_count=256)
bar_04 = T.alloc_barrier(arrive_count=416)
bar_05 = T.alloc_barrier(arrive_count=288)
bar_06 = T.alloc_barrier(arrive_count=256)
bar_07 = T.alloc_barrier(arrive_count=256)
bar_08 = T.alloc_barrier(arrive_count=384)
bar_09 = T.alloc_barrier(arrive_count=256)
bar_10 = T.alloc_barrier(arrive_count=288)
bar_11 = T.alloc_barrier(arrive_count=256)
bar_12 = T.alloc_barrier(arrive_count=128)
bar_13 = T.alloc_barrier(arrive_count=256)
bar_14 = T.alloc_barrier(arrive_count=256)
bar_15 = T.alloc_barrier(arrive_count=256)
T.annotate_layout(
{
do_shared: tilelang.layout.make_swizzled_layout(do_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
v_shared: tilelang.layout.make_swizzled_layout(v_shared),
a_shared: tilelang.layout.make_swizzled_layout(a_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dqkv_shared: tilelang.layout.make_swizzled_layout(dqkv_shared),
tmp_shared_1_1: tilelang.layout.make_swizzled_layout(
tmp_shared_1_1
),
tmp_shared_1_2: tilelang.layout.make_swizzled_layout(
tmp_shared_1_2
),
tmp_shared_1_3: tilelang.layout.make_swizzled_layout(
tmp_shared_1_3
),
tmp_shared_2_1: tilelang.layout.make_swizzled_layout(
tmp_shared_2_1
),
tmp_shared_2_2: tilelang.layout.make_swizzled_layout(
tmp_shared_2_2
),
tmp_shared_2_3: tilelang.layout.make_swizzled_layout(
tmp_shared_2_3
),
tmp_shared_4_1: tilelang.layout.make_swizzled_layout(
tmp_shared_4_1
),
}
)
# T.use_swizzle(10)
tx = T.get_thread_binding()
PRODUCER_NREG = 24
CONSUMER_K_NREG = 144
CONSUMER_A_NREG = 176
CONSUMER_S_NREG = 160
# Prefetch the last chunk of data
T.copy(
h[batch_idx, chunk_start_idx + num_iters - 1, bh, 0:DK, 0:DV], h_shared
)
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
q_shared[j_s, j_k] = q[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bhg,
j_k,
]
else:
q_shared[j_s, j_k] = 0
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
k_shared[j_s, j_k] = k[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bhg,
j_k,
]
else:
k_shared[j_s, j_k] = 0
for j_s, j_v in T.Parallel(block_S, DV):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
v_shared[j_s, j_v] = v[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_v,
]
else:
v_shared[j_s, j_v] = 0
for j_s, j_t in T.Parallel(block_S, block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
a_shared[j_s, j_t] = a[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_t,
]
else:
a_shared[j_s, j_t] = 0
for j_s, j_v in T.Parallel(block_S, DV):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
do_shared[j_s, j_v] = do[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_v,
]
else:
do_shared[j_s, j_v] = 0
for j_s in T.Parallel(block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
g_shared[j_s] = g[
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
]
else:
g_shared[j_s] = g[batch_idx, seq_end_idx - 1, bh]
for j_s in T.Parallel(block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
b_shared[j_s] = b[
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
]
else:
b_shared[j_s] = 0
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
if use_dht:
T.copy(dht[bb, bh, 0:DK, 0:DV], dh_fragment)
else:
T.clear(dh_fragment)
T.copy(dh_fragment, tmp_shared_4_1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
g_exp_shared[j_s] = T.exp2(g_shared[j_s] * 1.442695)
g_rev_exp_shared[j_s] = T.exp2(
(g_shared[block_S - 1] - g_shared[j_s]) * 1.442695
)
T.barrier_arrive(bar_01)
# 01, 02, 03
T.barrier_wait(bar_01, (i_s + 0) % 2)
g_last_local_3[0] = g_exp_shared[block_S - 1]
# dS0 = g_last * dSt
for j_k, j_v in T.Parallel(DK, DV):
dh_fragment[j_k, j_v] *= g_last_local_3[0]
T.barrier_arrive(bar_04)
# 04, 05, 06, 07
T.barrier_wait(bar_04, (i_s + 0) % 2)
# dg_last += sum(dS0 * S0)
T.clear(reduce_fragment)
for j_k, j_v in T.Parallel(DK, DV):
reduce_fragment[
j_k % 64 // 16 * 32 + j_k % 8 * 4 + j_v % 8 // 2, j_v % 2
] += dh_fragment[j_k, j_v] * h_shared[j_k, j_v]
T.barrier_arrive(bar_08)
T.barrier_wait(bar_08, (i_s + 0) % 2)
T.barrier_wait(bar_09, (i_s + 0) % 2)
# 10
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.reduce_sum(
T.reshape(reduce_fragment, (128 * 2,)),
dg_last_local_3,
dim=0,
clear=True,
)
dg_shared[block_S - 1] += dg_last_local_3[0]
T.barrier_arrive(bar_11)
# 11
T.barrier_wait(bar_11, (i_s + 0) % 2)
# dS0 += K^T @ dVg
T.gemm_v1(
tmp_shared_2_2,
tmp_shared_2_3,
dh_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_12)
T.barrier_wait(bar_12, (i_s + 0) % 2)
# 13
T.barrier_wait(bar_13, (i_s + 0) % 2)
# dOg = s * g * dO
for j_s, j_v in T.Parallel(block_S, DV):
tmp_shared_2_3[j_s, j_v] = (
scale * do_shared[j_s, j_v] * g_exp_shared[j_s]
)
T.barrier_arrive(bar_14)
# 14
T.barrier_wait(bar_14, (i_s + 0) % 2)
# dS0 += Q^T @ dOg
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_2_3,
dh_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_15)
# 15
T.barrier_wait(bar_15, (i_s + 0) % 2)
# S4[1] = dS0
T.copy(dh_fragment, tmp_shared_4_1)
if use_dht:
T.copy(dh_fragment, dh0[bb, bh, 0:DK, 0:DV])
elif tx < 256:
T.set_max_nreg(CONSUMER_K_NREG, 1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 16 == 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
# S2[S] dK
if i_s > 0:
T.copy(dk_fragment, dqkv_shared)
T.barrier_arrive(bar_01)
# 01
T.barrier_wait(bar_01, (i_s + 0) % 2)
# dV' = K @ dSt
T.gemm_v1(k_shared, tmp_shared_4_1, dv_fragment, clear_accum=True)
# dV' = g_last/g * dV'
for j_s, j_v in T.Parallel(block_S, DV):
dv_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
T.barrier_arrive(bar_02)
# 02
T.barrier_wait(bar_02, (i_s + 0) % 2)
# dV' += Pg^T @ dO
T.gemm_v1(
tmp_shared_1_1,
do_shared,
dv_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_03)
# 03
T.barrier_wait(bar_03, (i_s + 0) % 2)
# S2[1] dV'
T.copy(dv_fragment, tmp_shared_2_1)
T.barrier_arrive(bar_04)
# 04
T.barrier_wait(bar_04, (i_s + 0) % 2)
# dV = Ag^T @ dV'
T.gemm_v1(
tmp_shared_1_2,
tmp_shared_2_1,
dv_fragment,
transpose_A=True,
clear_accum=True,
)
# S2[S] dV
T.copy(dv_fragment, dqkv_shared)
T.barrier_arrive(bar_05)
# 05
T.barrier_wait(bar_05, (i_s + 0) % 2)
# dVg = -g * dV
for j_s, j_v in T.Parallel(block_S, DV):
dv_fragment[j_s, j_v] = (
-dv_fragment[j_s, j_v] * g_exp_shared[j_s]
)
# dg += sum(dVg * U)
T.copy(tmp_shared_2_3, odot_fragment_1)
for j_s, j_v in T.Parallel(block_S, DV):
odot_fragment_1[j_s, j_v] *= dv_fragment[j_s, j_v]
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
T.copy(dg_fragment_1, dg_shared)
# S2[3] dVg
T.copy(dv_fragment, tmp_shared_2_3)
T.barrier_arrive(bar_06)
# 06
T.barrier_wait(bar_06, (i_s + 0) % 2)
# S2[2] K
T.copy(k_shared, odot_fragment_1)
T.copy(odot_fragment_1, tmp_shared_2_2)
T.barrier_arrive(bar_07)
# 07
T.barrier_wait(bar_07, (i_s + 0) % 2)
# dK = V' @ dSt^T
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_4_1,
dk_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_08)
# 08
T.barrier_wait(bar_08, (i_s + 0) % 2)
# dK = g_last/g * dK
for j_s, j_k in T.Parallel(block_S, DK):
dk_fragment[j_s, j_k] *= g_rev_exp_shared[j_s]
# dg -= sum(K * dK)
for j_s, j_k in T.Parallel(block_S, DK):
odot_fragment_1[j_s, j_k] *= -dk_fragment[j_s, j_k]
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
for j_s in T.Parallel(block_S):
dg_shared[j_s] += dg_fragment_1[j_s]
# dg_last += sum(K * dK)
T.reduce_sum(dg_fragment_1, dg_last_local_1, dim=0, clear=True)
# Sg[S] dg
dg_shared[block_S - 1] -= dg_last_local_1[0]
T.barrier_arrive(bar_09)
# 09
T.barrier_wait(bar_09, (i_s + 0) % 2)
# dK += dVg @ S0^T
T.gemm_v1(
tmp_shared_2_3,
h_shared,
dk_fragment,
transpose_B=True,
clear_accum=False,
)
T.barrier_arrive(bar_10)
T.barrier_wait(bar_10, (i_s + 0) % 2)
# 12
T.barrier_wait(bar_12, (i_s + 0) % 2)
# dK += dP^T @ Q
T.gemm_v1(
tmp_shared_1_1,
tmp_shared_2_1,
dk_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_13)
T.barrier_wait(bar_13, (i_s + 0) % 2)
# 15
T.barrier_wait(bar_15, (i_s + 0) % 2)
# dK += dAs @ K
T.gemm_v1(
tmp_shared_1_2, tmp_shared_2_2, dk_fragment, clear_accum=False
)
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + j_s < seq_end_idx:
dk[batch_idx, seq_start_idx + j_s, bh, j_k] = dk_fragment[
j_s, j_k
]
elif tx < 384:
T.set_max_nreg(CONSUMER_A_NREG, 1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
# P = Q @ K^T
T.gemm_v1(
q_shared,
k_shared,
p_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_01)
# 01
T.barrier_wait(bar_01, (i_s + 0) % 2)
# G = Lower(diag(g) @ I @ diag(1/g))
for j_s, j_t in T.Parallel(block_S, block_S):
mask_fragment[j_s, j_t] = g_shared[j_s] - g_shared[j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s >= j_t:
mask_fragment[j_s, j_t] = T.exp2(
mask_fragment[j_s, j_t] * 1.442695
)
else:
mask_fragment[j_s, j_t] = 0
# Pg = s * P * G
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= scale
# S1[1] Pg
T.copy(p_fragment, tmp_shared_1_1)
T.barrier_arrive(bar_02)
# 02
T.barrier_wait(bar_02, (i_s + 0) % 2)
# Ab = Ar * b
T.copy(a_shared, a_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[j_t]
# Ag = G * Ab
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
# S1[2] Ag
T.copy(a_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_03)
# 03
T.barrier_wait(bar_03, (i_s + 0) % 2)
# U = K @ S0
T.gemm_v1(k_shared, h_shared, u_fragment, clear_accum=True)
T.barrier_arrive(bar_04)
# 04
T.barrier_wait(bar_04, (i_s + 0) % 2)
# S2[3] U
T.copy(u_fragment, tmp_shared_2_3)
# W = V - g * U
for j_s, j_v in T.Parallel(block_S, DV):
u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
for j_s, j_v in T.Parallel(block_S, DV):
u_fragment[j_s, j_v] += v_shared[j_s, j_v]
# S2[2] W
T.copy(u_fragment, tmp_shared_2_2)
T.barrier_arrive(bar_05)
# 05
T.barrier_wait(bar_05, (i_s + 0) % 2)
# dAg = dV' @ W^T
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_2_2,
da_fragment,
transpose_B=True,
clear_accum=True,
)
# V' = Ag @ W
T.gemm_v1(
tmp_shared_1_2, tmp_shared_2_2, u_fragment, clear_accum=True
)
# S2[1] V'
T.copy(u_fragment, tmp_shared_2_1)
T.barrier_arrive(bar_06)
# 06
T.barrier_wait(bar_06, (i_s + 0) % 2)
# dPg = dO @ V'^T
T.gemm_v1(
do_shared,
tmp_shared_2_1,
dp_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_07)
# 07
T.barrier_wait(bar_07, (i_s + 0) % 2)
# dAb = G * dAg
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
# dg += sum((dPg * P) - (dPg * P)^T)
T.copy(tmp_shared_1_1, p_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= dp_fragment[j_s, j_t]
T.copy(p_fragment, tmp_shared_1_1)
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] -= tmp_shared_1_1[j_t, j_s]
T.reduce_sum(p_fragment, dg_fragment_2, dim=1, clear=True)
# dP = s * G * dPg
for j_s, j_t in T.Parallel(block_S, block_S):
dp_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
dp_fragment[j_s, j_t] *= scale
# S1[1] dP
T.copy(dp_fragment, tmp_shared_1_1)
T.barrier_arrive(bar_08)
# 08
T.barrier_wait(bar_08, (i_s + 0) % 2)
# dQ = dO @ S0^T
T.gemm_v1(
do_shared,
h_shared,
dq_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_09)
# 09
T.barrier_wait(bar_09, (i_s + 0) % 2)
# dQ = s * g * dQ
for j_s, j_k in T.Parallel(block_S, DK):
dq_fragment[j_s, j_k] *= g_exp_shared[j_s]
for j_s, j_k in T.Parallel(block_S, DK):
dq_fragment[j_s, j_k] *= scale
# S2[1] Q
T.copy(q_shared, odot_fragment_2)
# dg += sum(Q * dQ)
T.copy(odot_fragment_2, tmp_shared_2_1)
for j_s, j_k in T.Parallel(block_S, DK):
odot_fragment_2[j_s, j_k] *= dq_fragment[j_s, j_k]
T.reduce_sum(odot_fragment_2, dg_fragment_2, dim=1, clear=False)
T.barrier_arrive(bar_10)
# 10
T.barrier_wait(bar_10, (i_s + 0) % 2)
# dQ += dP @ K
T.gemm_v1(
tmp_shared_1_1, tmp_shared_2_2, dq_fragment, clear_accum=False
)
# S2[S] dQ
T.copy(dq_fragment, dqkv_shared)
T.barrier_arrive(bar_11)
# 11, 12
T.barrier_wait(bar_11, (i_s + 0) % 2)
# dAb * Ar
T.copy(a_shared, a_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
T.copy(a_fragment, tmp_shared_1_3)
# dAb * Ab [ = G * dAg * Ab ]
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[j_t]
# dg += sum((dAb * Ab) - (dAb * Ab)^T)
T.copy(a_fragment, tmp_shared_1_2)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] -= tmp_shared_1_2[j_t, j_s]
T.reduce_sum(a_fragment, dg_fragment_2, dim=1, clear=False)
# Sg[S] dg
for j_s in T.Parallel(block_S):
dg_shared[j_s] += dg_fragment_2[j_s]
# db = sum((dAb * Ar)^T)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] = tmp_shared_1_3[j_t, j_s]
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=True)
# dAr = dAb * b
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= b_shared[j_t]
# S1[2] dAr
T.copy(da_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_13)
# 13
T.barrier_wait(bar_13, (i_s + 0) % 2)
# dA = -Ar^T @ dAr @ Ar^T
T.gemm_v1(
a_shared,
tmp_shared_1_2,
da_fragment,
transpose_A=True,
clear_accum=True,
)
T.copy(da_fragment, tmp_shared_1_2)
T.gemm_v1(
tmp_shared_1_2,
a_shared,
da_fragment,
transpose_B=True,
clear_accum=True,
)
# At = K @ K^T
T.gemm_v1(
tmp_shared_2_2,
tmp_shared_2_2,
a_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_14)
# 14
T.barrier_wait(bar_14, (i_s + 0) % 2)
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s <= j_t:
da_fragment[j_s, j_t] = 0
else:
da_fragment[j_s, j_t] = -da_fragment[j_s, j_t]
# db += sum(dA * At)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=False)
T.copy(db_fragment, db_shared)
# dAt = b * dA
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= b_shared[j_s]
# dAs = dAt + dAt^T
T.copy(da_fragment, tmp_shared_1_2)
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] += tmp_shared_1_2[j_t, j_s]
# S1[1] dAs
T.copy(da_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_15)
T.barrier_wait(bar_15, (i_s + 0) % 2)
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 384 + 32:
for i_s in T.serial(num_iters - 1):
chunk_idx = num_iters - i_s - 2
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
T.barrier_arrive(bar_00)
T.barrier_wait(bar_00, (i_s + 0) % 2)
T.barrier_wait(bar_03, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
g_shared[j_s] = g[batch_idx, left + j_s, bh]
T.barrier_wait(bar_05, (i_s + 0) % 2)
T.copy(v[batch_idx, left:right, bh, 0:DV], v_shared)
T.barrier_wait(bar_07, (i_s + 0) % 2)
T.copy(k[batch_idx, left:right, bhg, 0:DK], k_shared)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.copy(q[batch_idx, left:right, bhg, 0:DK], q_shared)
if num_iters > 0:
T.barrier_arrive(bar_00)
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
left = seq_start_idx + (num_iters - i_s - 1) * block_S
right = left + block_S
T.barrier_arrive(bar_00)
T.barrier_wait(bar_00, (i_s + 0) % 2)
T.barrier_wait(bar_01, (i_s + 0) % 2)
if i_s == 1:
for j_s, j_k in T.Parallel(block_S, DK):
if left + block_S + j_s < seq_end_idx:
dk[batch_idx, left + block_S + j_s, bh, j_k] = (
dqkv_shared[j_s, j_k]
)
elif i_s > 1:
T.copy(
dqkv_shared,
dk[
batch_idx,
left + block_S : right + block_S,
bh,
0:DK,
],
)
T.barrier_arrive(bar_04)
T.barrier_wait(bar_04, (i_s + 0) % 2)
T.barrier_wait(bar_05, (i_s + 0) % 2)
if i_s == 0:
for j_s, j_v in T.Parallel(block_S, DV):
if left + j_s < seq_end_idx:
dv[batch_idx, left + j_s, bh, j_v] = dqkv_shared[
j_s, j_v
]
else:
T.copy(dqkv_shared, dv[batch_idx, left:right, bh, 0:DV])
T.barrier_arrive(bar_10)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.barrier_wait(bar_11, (i_s + 0) % 2)
if i_s == 0:
for j_s, j_k in T.Parallel(block_S, DK):
if left + j_s < seq_end_idx:
dq[batch_idx, left + j_s, bh, j_k] = dqkv_shared[
j_s, j_k
]
else:
T.copy(dqkv_shared, dq[batch_idx, left:right, bh, 0:DK])
elif tx < 384 + 96:
for i_s in T.serial(num_iters - 1):
chunk_idx = num_iters - i_s - 2
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
T.barrier_arrive(bar_02)
T.barrier_wait(bar_02, (i_s + 0) % 2)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.copy(
h[batch_idx, chunk_start_idx + chunk_idx, bh, 0:DK, 0:DV],
h_shared,
)
T.barrier_wait(bar_14, (i_s + 0) % 2)
T.copy(a[batch_idx, left:right, bh, 0:block_S], a_shared)
T.copy(do[batch_idx, left:right, bh, 0:DV], do_shared)
T.barrier_wait(bar_15, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
b_shared[j_s] = b[batch_idx, left + j_s, bh]
if num_iters > 0:
T.barrier_wait(bar_00, (num_iters - 1) % 2)
T.barrier_arrive(bar_02)
else:
for i_s in T.serial(num_iters):
left = seq_start_idx + (num_iters - i_s - 1) * block_S
T.barrier_arrive(bar_05)
T.barrier_wait(bar_05, (i_s + 0) % 2)
T.barrier_wait(bar_15, (i_s + 0) % 2)
if i_s == 0:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
if (seq_end_idx - seq_start_idx) % block_S > 0:
dg[batch_idx, seq_end_idx - 1, bh] += dg_shared[
block_S - 1
]
else:
for j_s in T.Parallel(block_S):
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
if i_s == 0:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
db[batch_idx, left + j_s, bh] = db_shared[j_s]
else:
for j_s in T.Parallel(block_S):
db[batch_idx, left + j_s, bh] = db_shared[j_s]
return tilelang_fused_chunk_gdr_bwd_kernel
def fused_gdr_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
h: torch.Tensor,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H, V = v.shape
scale = scale or K ** (-0.5)
assert K == V == 128
assert chunk_size == 64
if cu_seqlens is None:
real_batch_size = batch_size
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
chunk_offsets = torch.empty(
(batch_size + 1), dtype=torch.int32, device=k.device
)
is_varlen = False
else:
real_batch_size = len(cu_seqlens) - 1
chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size)
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
is_varlen = True
use_dht = dht is not None
if dht is None:
dht = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
dq = torch.empty_like(v)
dk = torch.empty_like(v)
dv = torch.empty_like(v)
dg = torch.empty_like(g)
db = torch.empty_like(b)
dh0 = torch.empty_like(dht)
tilelang_fused_chunk_gdr_bwd_kernel = tilelang_fused_chunk_gdr_bwd(
H,
Hg,
K,
V,
chunk_size,
scale,
qkva_dtype=q.dtype,
g_dtype=g.dtype,
b_dtype=b.dtype,
h_dtype=h.dtype,
o_dtype=do.dtype,
seqlen_dtype=cu_seqlens.dtype,
accum_dtype="float32",
is_varlen=is_varlen,
use_dht=use_dht,
)
tilelang_fused_chunk_gdr_bwd_kernel(
do,
dht,
q,
k,
v,
a,
g,
b,
h,
cu_seqlens,
chunk_offsets,
dq,
dk,
dv,
dg,
db,
dh0,
)
if not use_dht:
dh0 = None
return dq, dk, dv, dg, db, dh0

View File

@@ -0,0 +1,658 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
TARGET_NUM_CTAS = int(MULTI_PROCESSOR_COUNT * 0.7)
@tilelang.jit(
# out_idx=[-3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
# tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
},
)
def tilelang_fused_chunk_gdr_fwd(
H,
Hg,
DK,
DV,
chunk_size,
scale,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h0_dtype,
ht_dtype,
h_dtype,
o_dtype,
seqlen_dtype,
use_initial_state,
store_final_state,
store_h,
store_o,
is_varlen,
is_cp,
block_DV=128,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
raw_batch_size = T.dynamic("raw_batch_size")
block_S = chunk_size
if is_varlen:
q_shape = (1, num_tokens, Hg, DK)
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
o_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
q_shape = (batch_size, num_tokens, Hg, DK)
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
o_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
h0_shape = (batch_size, H, DK, DV)
ht_shape = (raw_batch_size, H, DK, DV)
@T.prim_func
def tilelang_fused_chunk_gdr_fwd_kernel(
q: T.Tensor(q_shape, dtype=qkva_dtype),
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h0: T.Tensor(h0_shape, dtype=h0_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
cp_seq_map: T.Tensor([batch_size], dtype=seqlen_dtype),
raw_cu_seqlens: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
o: T.Tensor(o_shape, dtype=o_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
ht: T.Tensor(ht_shape, dtype=ht_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV) * batch_size * H, threads=512) as (bbhv,):
bbh, bv = bbhv // T.ceildiv(DV, block_DV), bbhv % T.ceildiv(DV, block_DV)
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
seq_split_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
chunk_split_idx = T.alloc_var("int32")
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
raw_batch_idx = T.alloc_var("int32")
raw_seq_end_idx = T.alloc_var("int32")
need_store_final_state = T.alloc_var("bool")
raw_batch_idx = cp_seq_map[bb] if is_cp else bb
raw_seq_end_idx = (
raw_cu_seqlens[raw_batch_idx + 1] if is_cp else seq_end_idx
)
need_store_final_state = store_final_state & (
raw_seq_end_idx == seq_end_idx
)
num_iters = T.alloc_var("int32")
num_unmasked_iters = T.alloc_var("int32")
num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
num_unmasked_iters = (seq_end_idx - seq_start_idx) // block_S
q_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
k_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((2, block_S, block_DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((2, block_S, block_S), dtype=qkva_dtype)
g_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
b_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
o_shared = T.alloc_shared((block_S, block_DV), dtype=o_dtype)
h_shared = T.alloc_shared((DK, block_DV), dtype=qkva_dtype)
vd_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
vn_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
p_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
o_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
v_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
u_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
g_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
g_last_local = T.alloc_local((1), dtype=accum_dtype)
data_is_ready = T.alloc_barrier(arrive_count=[96] * 2)
data_is_free = T.alloc_barrier(arrive_count=[384] * 2)
bar_o = T.alloc_barrier(arrive_count=128)
bar_0 = T.alloc_barrier(arrive_count=416)
bar_1 = T.alloc_barrier(arrive_count=256)
_bar_2 = T.alloc_barrier(arrive_count=128)
bar_3 = T.alloc_barrier(arrive_count=128)
bar_4 = T.alloc_barrier(arrive_count=128)
bar_5 = T.alloc_barrier(arrive_count=416)
T.use_swizzle(10)
tx = T.get_thread_binding()
PRODUCER_NREG = 32
CONSUMER_V_NREG = 128
CONSUMER_S_NREG = 160
CONSUMER_O_NREG = 128
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
# Initialize S
if use_initial_state:
T.copy(
h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
h_fragment,
)
else:
T.clear(h_fragment)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE 0]
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
T.barrier_arrive(bar_0)
# [STAGE 0] 0
T.barrier_wait(bar_0, i_s % 2)
# S4[S] S
T.copy(h_fragment, h_shared)
T.barrier_arrive(bar_1)
# [STAGE 0] 2, 3, 4
T.barrier_wait(bar_1, i_s % 2)
# S = g_last * S
g_last_local[0] = g_exp_shared[block_S - 1]
for j_k, j_v in T.Parallel(DK, block_DV):
h_fragment[j_k, j_v] *= g_last_local[0]
T.barrier_arrive(bar_5)
# [STAGE 0] 5
T.barrier_wait(bar_5, i_s % 2)
# S += K^T @ V'
T.gemm_v1(
k_shared[i_s % 2, :, :],
vn_shared,
h_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % 2])
# Store final S
if need_store_final_state:
T.copy(
h_fragment,
ht[
raw_batch_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
],
)
elif tx < 256:
T.set_max_nreg(CONSUMER_V_NREG, 1)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE 0]
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
T.barrier_arrive(bar_0)
# [STAGE 0] 0
T.barrier_wait(bar_0, i_s % 2)
# Precompute g, g_last/g
for j_s in T.Parallel(block_S):
g_exp_shared[j_s] = T.exp2(g_shared[i_s % 2, j_s] * 1.442695)
for j_s in T.Parallel(block_S):
g_rev_exp_shared[j_s] = T.if_then_else(
seq_start_idx + i_s * block_S + j_s < seq_end_idx,
T.exp2(
(
g_shared[i_s % 2, block_S - 1]
- g_shared[i_s % 2, j_s]
)
* 1.442695
),
0.0,
)
T.barrier_arrive(bar_1)
# [STAGE 0] 1
T.barrier_wait(bar_1, i_s % 2)
# U = K @ S
T.gemm_v1(
k_shared[i_s % 2, :, :], h_shared, u_fragment, clear_accum=True
)
# [STAGE 0] 2
# W = V - g * U
for j_s, j_v in T.Parallel(block_S, block_DV):
u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
for j_s, j_v in T.Parallel(block_S, block_DV):
u_fragment[j_s, j_v] += v_shared[i_s % 2, j_s, j_v]
# S2[V] W
for j_s, j_v in T.Parallel(block_S, block_DV):
v_shared[i_s % 2, j_s, j_v] = u_fragment[j_s, j_v]
# [STAGE 0] 3
T.barrier_wait(bar_3, i_s % 2)
# Vd = Ag @ W
T.gemm_v1(
a_shared[i_s % 2, :, :],
v_shared[i_s % 2, :, :],
v_fragment,
clear_accum=True,
)
# S2[2] Vd
T.copy(v_fragment, vd_shared)
T.barrier_arrive(bar_4)
# [STAGE 0] 4
# V' = g_last/g Vd
for j_s, j_v in T.Parallel(block_S, block_DV):
v_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
# S2[1] V'
T.copy(v_fragment, vn_shared)
T.barrier_arrive(bar_5)
T.barrier_wait(bar_5, i_s % 2)
T.barrier_arrive(data_is_free[i_s % 2])
elif tx < 384:
T.set_max_nreg(CONSUMER_O_NREG, 1)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE 0]
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
T.barrier_arrive(bar_0)
# [STAGE 0] 0
T.barrier_wait(bar_0, i_s % 2)
# P = Q K^T
T.gemm_v1(
q_shared[i_s % 2, :, :],
k_shared[i_s % 2, :, :],
p_fragment,
transpose_B=True,
clear_accum=True,
)
# [STAGE 0] 1
# G = Lower(diag(g) @ I @ diag(1/g))
for j_s, j_t in T.Parallel(block_S, block_S):
g_fragment[j_s, j_t] = (
g_shared[i_s % 2, j_s] - g_shared[i_s % 2, j_t]
)
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s >= j_t:
g_fragment[j_s, j_t] = T.exp2(
g_fragment[j_s, j_t] * 1.442695
)
else:
g_fragment[j_s, j_t] = 0
# Ag = G * Ar * b
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] = a_shared[i_s % 2, j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= g_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[i_s % 2, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
a_shared[i_s % 2, j_s, j_t] = a_fragment[j_s, j_t]
# [STAGE 0] 2
T.barrier_wait(bar_1, i_s % 2)
# O = Q @ S
T.gemm_v1(
q_shared[i_s % 2, :, :], h_shared, o_fragment, clear_accum=True
)
# [STAGE 0] 3
# Pg = s * G * P
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= scale * g_fragment[j_s, j_t]
# S1[1] Pg
T.copy(p_fragment, p_shared)
T.barrier_arrive(bar_3)
# O = s * g * O
for j_s, j_k in T.Parallel(block_S, DK):
o_fragment[j_s, j_k] *= scale * g_exp_shared[j_s]
# [STAGE 0] 4
T.barrier_wait(bar_4, i_s % 2)
# O += Pg @ Vd
T.gemm_v1(p_shared, vd_shared, o_fragment, clear_accum=False)
T.barrier_arrive(bar_5)
# [STAGE 0] 5
T.barrier_wait(bar_5, i_s % 2)
# S2[S] O
T.copy(o_fragment, o_shared)
T.barrier_arrive(data_is_free[i_s % 2])
T.barrier_arrive(bar_o)
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 384 + 32:
for i_s in T.serial(num_iters):
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load Q
T.copy(
q[batch_idx, left:right, bhg, 0:DK], q_shared[i_s % 2, :, :]
)
# Load K
T.copy(
k[batch_idx, left:right, bhg, 0:DK], k_shared[i_s % 2, :, :]
)
T.barrier_arrive(data_is_ready[i_s % 2])
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load V
T.copy(
v[
batch_idx,
left:right,
bh,
bv * block_DV : (bv + 1) * block_DV,
],
v_shared[i_s % 2, :, :],
)
# Load beta
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
b_shared[i_s % 2, j_s] = b[batch_idx, left + j_s, bh]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
b_shared[i_s % 2, j_s] = b[
batch_idx, left + j_s, bh
]
else:
b_shared[i_s % 2, j_s] = 0
T.barrier_arrive(data_is_ready[i_s % 2])
elif tx < 384 + 96:
for i_s in T.serial(num_iters):
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load A
T.copy(
a[batch_idx, left:right, bh, 0:block_S],
a_shared[i_s % 2, :, :],
)
# Load gamma
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
g_shared[i_s % 2, j_s] = g[batch_idx, left + j_s, bh]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
g_shared[i_s % 2, j_s] = g[
batch_idx, left + j_s, bh
]
else:
g_shared[i_s % 2, j_s] = g[
batch_idx, seq_end_idx - 1, bh
]
T.barrier_arrive(data_is_ready[i_s % 2])
else:
for i_s in T.serial(num_unmasked_iters):
right = seq_start_idx + i_s * block_S
left = right - block_S
T.barrier_arrive(bar_0)
T.barrier_wait(bar_0, i_s % 2)
# Store O
if i_s > 0 and store_o:
T.copy(
o_shared,
o[
batch_idx,
left:right,
bh,
bv * block_DV : (bv + 1) * block_DV,
],
)
T.barrier_arrive(bar_5)
T.barrier_wait(bar_1, i_s % 2)
# Store S
if store_h:
T.copy(
h_shared,
h[
batch_idx,
chunk_start_idx + i_s,
bh,
0:DK,
bv * block_DV : (bv + 1) * block_DV,
],
)
if num_unmasked_iters < num_iters:
seq_split_idx = seq_start_idx + num_unmasked_iters * block_S
chunk_split_idx = chunk_start_idx + num_unmasked_iters
T.barrier_arrive(bar_0)
T.barrier_wait(bar_0, num_unmasked_iters % 2)
# Store O
if num_unmasked_iters > 0 and store_o:
T.copy(
o_shared,
o[
batch_idx,
seq_split_idx - block_S : seq_split_idx,
bh,
bv * block_DV : (bv + 1) * block_DV,
],
)
T.barrier_arrive(bar_5)
T.barrier_wait(bar_1, num_unmasked_iters % 2)
# Store S
if store_h:
T.copy(
h_shared,
h[
batch_idx,
chunk_split_idx,
bh,
0:DK,
bv * block_DV : (bv + 1) * block_DV,
],
)
seq_split_idx = seq_start_idx + (num_iters - 1) * block_S
# Store O
T.barrier_wait(bar_o, 0)
if store_o:
for j_s, j_v in T.Parallel(block_S, block_DV):
with T.If(seq_split_idx + j_s < seq_end_idx):
with T.Then():
o[
batch_idx,
seq_split_idx + j_s,
bh,
bv * block_DV + j_v,
] = o_shared[j_s, j_v]
return tilelang_fused_chunk_gdr_fwd_kernel
def fused_gdr_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = True,
output_h: bool = False,
output_o: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cp_seq_map: torch.LongTensor | None = None,
raw_cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H, V = v.shape
scale = scale or K ** (-0.5)
assert K == V == 128
assert chunk_size == 64
if cu_seqlens is None:
real_batch_size = batch_size
num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
chunk_offsets = torch.empty(
(batch_size + 1), dtype=torch.int32, device=k.device
)
seqlen_dtype = torch.int32
is_varlen = False
else:
real_batch_size = len(cu_seqlens) - 1
chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
num_chunks = num_chunks if output_h else 0
seqlen_dtype = cu_seqlens.dtype
is_varlen = True
if cp_seq_map is None:
cp_seq_map = torch.empty(
(real_batch_size,), dtype=seqlen_dtype, device=k.device
)
is_cp = False
else:
is_cp = True
use_initial_state = initial_state is not None
if initial_state is None:
initial_state = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device)
if raw_cu_seqlens is None:
raw_cu_seqlens = torch.empty(
(real_batch_size + 1,), dtype=seqlen_dtype, device=k.device
)
final_state = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
else:
final_state = torch.empty(
(raw_cu_seqlens.shape[0] - 1, H, K, V), dtype=torch.float32, device=k.device
)
o = torch.empty_like(v)
grid_size = real_batch_size * H
if grid_size >= TARGET_NUM_CTAS:
block_DV = 128
elif grid_size * 2 >= TARGET_NUM_CTAS:
block_DV = 64
else:
block_DV = 32
tilelang_fused_chunk_gdr_fwd_kernel = tilelang_fused_chunk_gdr_fwd(
H,
Hg,
K,
V,
chunk_size,
scale,
qkva_dtype=q.dtype,
g_dtype=g.dtype,
b_dtype=b.dtype,
h0_dtype=initial_state.dtype,
ht_dtype=final_state.dtype,
h_dtype=h.dtype,
o_dtype=o.dtype,
seqlen_dtype=seqlen_dtype,
accum_dtype="float32",
use_initial_state=use_initial_state,
store_final_state=output_final_state,
store_h=output_h,
store_o=output_o,
is_varlen=is_varlen,
is_cp=is_cp,
block_DV=block_DV,
)
tilelang_fused_chunk_gdr_fwd_kernel(
q,
k,
v,
a,
g,
b,
initial_state,
cu_seqlens,
chunk_offsets,
cp_seq_map,
raw_cu_seqlens,
o,
h,
final_state,
)
if not output_final_state:
final_state = None
if not output_h:
h = None
if not output_o:
o = None
return o, h, final_state

View File

@@ -0,0 +1,345 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from typing import Optional
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_indices
@tilelang.jit(
# out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
# tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
# tilelang.PassConfigKey.TL_ENABLE_ASYNC_COPY: True,
},
)
def tilelang_kkt_solve(
H,
Hg,
DK,
chunk_size,
accum_dtype,
qkva_dtype,
b_dtype,
seqlen_dtype,
is_varlen,
):
data_batch_size = T.dynamic("data_batch_size")
real_batch_size = T.dynamic("real_batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
k_shape = (data_batch_size, num_tokens, Hg, DK)
a_shape = (data_batch_size, num_tokens, H, chunk_size)
b_shape = (data_batch_size, num_tokens, H)
@T.macro
def kernel_body(
bb,
bc,
bh,
bhg,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
k,
b,
a,
):
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
a64_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
a16i_row = T.alloc_fragment((4, 16), dtype=accum_dtype)
a16i_sum = T.alloc_fragment((4, 16), dtype=accum_dtype)
a16i_shared = T.alloc_shared((4, 17, 16), dtype=accum_dtype)
a16o_shared = T.alloc_shared((2, 17, 16), dtype=accum_dtype)
a16o_fragment = T.alloc_fragment((2, 16, 16), dtype=accum_dtype)
a32i_fragment = T.alloc_fragment((2, 32, 32), dtype=accum_dtype)
a32i0_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
a32i1_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
a32o_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
a32o_fragment = T.alloc_fragment((32, 32), dtype=accum_dtype)
a64_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
T.annotate_layout(
{
a16i_shared: tilelang.layout.make_linear_layout(a16i_shared),
a16o_shared: tilelang.layout.make_linear_layout(a16o_shared),
}
)
k_is_ready = T.alloc_barrier(arrive_count=32)
a_is_ready = T.alloc_barrier(arrive_count=128)
tx = T.get_thread_binding()
PRODUCER_NREG = 24
CONSUMER_NREG = 64
if tx < 128:
T.set_max_nreg(CONSUMER_NREG, 1)
# Load b
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
b_shared[j_s] = b[bb, left + j_s, bh]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
b_shared[j_s] = b[bb, left + j_s, bh]
else:
b_shared[j_s] = 0
T.barrier_wait(k_is_ready, 0)
# A = K @ K^T
T.gemm_v1(
k_shared, k_shared, a64_fragment, transpose_B=True, clear_accum=True
)
# A = b * A
for j_s, j_t in T.Parallel(block_S, block_S):
a64_fragment[j_s, j_t] *= b_shared[j_s]
# A = I + StrictLower(A)
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s < j_t:
a64_fragment[j_s, j_t] = 0
elif j_s == j_t:
a64_fragment[j_s, j_t] = 1
# Prepare inversion input
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s >= 32 and j_t < 32:
a32o_shared[j_s - 32, j_t] = -a64_fragment[j_s, j_t]
elif (j_s // 16) == (j_t // 16) + 1:
a16o_shared[j_s // 32, j_s % 16, j_t % 16] = -a64_fragment[j_s, j_t]
elif (j_s // 16) == (j_t // 16):
a16i_shared[j_s // 16, j_s % 16, j_t % 16] = a64_fragment[j_s, j_t]
# Diagonal 4x16x16
T.clear(a16i_row)
for k_s in T.unroll(1, 16):
for j_s, k_t in T.Parallel(4, 16):
if k_t < k_s:
a16i_row[j_s, k_t] = a16i_shared[j_s, k_s, k_t]
T.clear(a16i_sum)
for k_r in T.unroll(k_s):
for j_s, k_t in T.Parallel(4, 16):
a16i_sum[j_s, k_t] -= (
a16i_shared[j_s, k_r, k_t] * a16i_row[j_s, k_r]
)
for j_s, k_t in T.Parallel(4, 16):
if k_t < k_s:
a16i_shared[j_s, k_s, k_t] = a16i_sum[j_s, k_t]
# First level 2x16x16
T.clear(a16o_fragment)
for k_r in T.unroll(16):
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
a16o_fragment[j_s, k_s, k_t] += (
a16i_shared[j_s * 2 + 1, k_s, k_r] * a16o_shared[j_s, k_r, k_t]
)
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
a16o_shared[j_s, k_t, k_s] = a16o_fragment[j_s, k_s, k_t]
T.clear(a16o_fragment)
for k_r in T.unroll(16):
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
a16o_fragment[j_s, k_s, k_t] += (
a16o_shared[j_s, k_r, k_s] * a16i_shared[j_s * 2, k_r, k_t]
)
T.copy(a16o_fragment, a16o_shared[:, 0:16, 0:16])
# Second level 1x32x32
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if k_s < 16 and k_t >= 16:
a32i_fragment[j_s, k_s, k_t] = 0
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if k_s >= 16 and k_t < 16:
a32i_fragment[j_s, k_s, k_t] = a16o_shared[j_s, k_s - 16, k_t]
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if k_s // 16 == k_t // 16:
a32i_fragment[j_s, k_s, k_t] = a16i_shared[
j_s * 2 + k_s // 16, k_s % 16, k_t % 16
]
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if j_s == 0:
a32i0_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
else:
a32i1_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
T.gemm_v1(a32i1_shared, a32o_shared, a32o_fragment, clear_accum=True)
T.copy(a32o_fragment, a32o_shared)
T.gemm_v1(a32o_shared, a32i0_shared, a32o_fragment, clear_accum=True)
# Combine inversion output
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
a64_shared[j_s * 32 + k_s, j_s * 32 + k_t] = a32i_fragment[
j_s, k_s, k_t
]
for k_s, k_t in T.Parallel(32, 32):
a64_shared[32 + k_s, k_t] = a32o_fragment[k_s, k_t]
for k_s, k_t in T.Parallel(32, 32):
a64_shared[k_s, 32 + k_t] = 0
T.barrier_arrive(a_is_ready)
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 128 + 32:
# Load K
T.copy(k[bb, left:right, bhg, 0:DK], k_shared)
T.barrier_arrive(k_is_ready)
elif tx < 128 + 64:
T.barrier_wait(a_is_ready, 0)
# Save A (unmasked)
if right <= seq_end_idx:
T.copy(a64_shared, a[bb, left:right, bh, 0:block_S])
else:
T.barrier_wait(a_is_ready, 0)
# Save A (masked)
if right > seq_end_idx:
for j_s, j_t in T.Parallel(block_S, block_S):
if left + j_s < seq_end_idx:
a[bb, left + j_s, bh, j_t] = a64_shared[j_s, j_t]
if is_varlen:
@T.prim_func
def tilelang_kkt_solve_kernel(
k: T.Tensor(k_shape, dtype=qkva_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
):
with T.Kernel(num_chunks * H, threads=256) as (bch,):
bc, bh = bch // H, bch % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
chunk_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
bb = 0
batch_idx = chunk_indices[bc, 0]
chunk_idx = chunk_indices[bc, 1]
seq_start_idx = cu_seqlens[batch_idx]
seq_end_idx = cu_seqlens[batch_idx + 1]
kernel_body(
bb,
bc,
bh,
bhg,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
k,
b,
a,
)
else:
@T.prim_func
def tilelang_kkt_solve_kernel(
k: T.Tensor(k_shape, dtype=qkva_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
num_chunks: T.int32,
):
with T.Kernel(num_chunks * H, threads=256) as (bch,):
bc, bh = bch // H, bch % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
chunk_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
bb = bc % data_batch_size
batch_idx = bb
chunk_idx = bc // data_batch_size
seq_start_idx = 0
seq_end_idx = num_tokens
kernel_body(
bb,
bc,
bh,
bhg,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
k,
b,
a,
)
return tilelang_kkt_solve_kernel
def kkt_solve(
k: torch.Tensor,
b: torch.Tensor,
chunk_size: int = 64,
cu_seqlens: Optional[torch.LongTensor] = None,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H = b.shape
assert K == 128
assert chunk_size == 64
if cu_seqlens is None:
num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
seqlen_dtype = "int32"
is_varlen = False
else:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
seqlen_dtype = cu_seqlens.dtype
is_varlen = True
a = torch.empty(
(batch_size, num_tokens, H, chunk_size), dtype=k.dtype, device=k.device
)
tilelang_kkt_solve_kernel = tilelang_kkt_solve(
H,
Hg,
K,
chunk_size,
qkva_dtype=k.dtype,
b_dtype=b.dtype,
seqlen_dtype=seqlen_dtype,
accum_dtype="float32",
is_varlen=is_varlen,
)
if is_varlen:
tilelang_kkt_solve_kernel(k, b, cu_seqlens, chunk_indices, a)
else:
tilelang_kkt_solve_kernel(k, b, a, num_chunks)
return a

View File

@@ -0,0 +1,558 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def tilelang_prepare_h(
H,
Hg,
DK,
DV,
chunk_size,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h0_dtype,
ht_dtype,
h_dtype,
seqlen_dtype,
use_initial_state,
store_final_state,
store_h,
is_varlen,
is_cp,
num_stages=2,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
if is_varlen:
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
h0_shape = (batch_size, H, DK, DV)
ht_shape = (batch_size, H, DK, DV)
m_shape = (batch_size, H, DK, DK)
@T.prim_func
def tilelang_prepare_h_kernel(
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h0: T.Tensor(h0_shape, dtype=h0_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
num_warmup_chunks: T.Tensor([batch_size, H], dtype=seqlen_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
ht: T.Tensor(ht_shape, dtype=ht_dtype),
mt: T.Tensor(m_shape, dtype=ht_dtype),
):
with T.Kernel(batch_size * H, threads=512) as (bbh,):
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
_seq_split_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
_chunk_split_idx = T.alloc_var("int32")
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
num_iters = T.alloc_var("int32")
num_iters = (
num_warmup_chunks[bb, bh]
if is_cp
else T.ceildiv(seq_end_idx - seq_start_idx, block_S)
)
calc_mt = T.alloc_var("bool")
calc_mt = is_cp and num_iters >= T.ceildiv(
seq_end_idx - seq_start_idx, block_S
)
seq_start_idx = (
seq_end_idx - num_iters * block_S if is_cp else seq_start_idx
)
k_shared = T.alloc_shared((num_stages, block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((num_stages, block_S, DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((num_stages, block_S, block_S), dtype=qkva_dtype)
g_shared = T.alloc_shared(
(num_stages, block_S), dtype=accum_dtype, scope="shared"
)
b_shared = T.alloc_shared(
(num_stages, block_S), dtype=accum_dtype, scope="shared"
)
h_shared = T.alloc_shared((DK, DV), dtype=qkva_dtype)
x_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
y_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
m_shared_L = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
m_shared_R = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
z_shared_L = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
z_shared_R = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
h_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
x_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
y_fragment = T.alloc_fragment((block_S, DV), dtype=accum_dtype)
m_fragment_L = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
m_fragment_R = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
z_fragment_L = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
z_fragment_R = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
g_last_local_S = T.alloc_local((1), dtype=accum_dtype)
g_last_local_X = T.alloc_local((1), dtype=accum_dtype)
g_last_local_Y = T.alloc_local((1), dtype=accum_dtype)
g_prod_X = T.alloc_fragment((1), dtype=accum_dtype)
g_prod_Y = T.alloc_fragment((1), dtype=accum_dtype)
data_is_ready = T.alloc_barrier(arrive_count=[96] * num_stages)
data_is_free = T.alloc_barrier(arrive_count=[384] * num_stages)
bar_0 = T.alloc_barrier(arrive_count=416)
bar_1 = T.alloc_barrier(arrive_count=256)
bar_2 = T.alloc_barrier(arrive_count=384)
bar_3 = T.alloc_barrier(arrive_count=128)
T.use_swizzle(10)
tx = T.get_thread_binding()
PRODUCER_NREG = 24
CONSUMER_S_NREG = 168
CONSUMER_X_NREG = 160
CONSUMER_Y_NREG = 160
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
# Initialize S
if use_initial_state:
T.copy(h0[bb, bh, 0:DK, 0:DV], h_fragment)
else:
T.clear(h_fragment)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# S4[1] S
T.copy(h_fragment, h_shared)
T.barrier_arrive(bar_1)
# [STAGE = i_s % num_stages] 1
T.barrier_wait(bar_1, i_s % 2)
# S = g_last * S
g_last_local_S[0] = T.exp2(
g_shared[i_s % num_stages, block_S - 1] * 1.442695
)
for j_k, j_v in T.Parallel(DK, DV):
h_fragment[j_k, j_v] *= g_last_local_S[0]
T.barrier_arrive(bar_2)
# [STAGE = i_s % num_stages] 2
T.barrier_wait(bar_2, i_s % 2)
# S += X^T @ Y
T.gemm_v1(
x_shared,
y_shared,
h_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_3)
T.barrier_arrive(data_is_free[i_s % num_stages])
# Store final S
if store_final_state:
T.copy(h_fragment, ht[bb, bh, 0:DK, 0:DV])
elif tx < 256:
T.set_max_nreg(CONSUMER_X_NREG, 1)
if calc_mt:
for j_k, j_v in T.Parallel(DK, DK // 2):
if j_k == j_v + DK // 2:
m_fragment_R[j_k, j_v] = 1
else:
m_fragment_R[j_k, j_v] = 0
g_prod_X[0] = 0
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# X = A^T @ K
T.gemm_v1(
a_shared[i_s % num_stages, :, :],
k_shared[i_s % num_stages, :, :],
x_fragment,
transpose_A=True,
clear_accum=True,
)
# [STAGE = i_s % num_stages] 1
# X = - b * X
for j_s, j_k in T.Parallel(block_S, DK):
x_fragment[j_s, j_k] *= -b_shared[i_s % num_stages, j_s]
# S2[1] X
T.copy(x_fragment, x_shared)
T.barrier_arrive(bar_2)
if calc_mt:
# [STAGE = i_s % num_stages] 2
g_prod_X[0] += g_shared[i_s % num_stages, block_S - 1]
# S4[2] M
T.copy(m_fragment_R, m_shared_R)
# [STAGE = i_s % num_stages] 3
T.barrier_wait(bar_3, i_s % 2)
# Z = K @ M
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
m_shared_R,
z_fragment_R,
clear_accum=True,
)
# S4[2] Z
T.copy(z_fragment_R, z_shared_R)
# M += X^T @ Z
T.gemm_v1(
x_shared,
z_shared_R,
m_fragment_R,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % num_stages])
if calc_mt:
g_last_local_X[0] = T.exp2(g_prod_X[0] * 1.442695)
for j_k, j_v in T.Parallel(DK, DK // 2):
m_fragment_R[j_k, j_v] *= g_last_local_X[0]
T.copy(m_fragment_R, mt[bb, bh, 0:DK, DK // 2 :])
elif tx < 384:
T.set_max_nreg(CONSUMER_Y_NREG, 1)
if calc_mt:
for j_k, j_v in T.Parallel(DK, DK // 2):
if j_k == j_v:
m_fragment_L[j_k, j_v] = 1
else:
m_fragment_L[j_k, j_v] = 0
g_prod_Y[0] = 0
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# Precompute g_last/g
g_last_local_Y[0] = g_shared[i_s % num_stages, block_S - 1]
for j_s in T.Parallel(block_S):
g_rev_exp_shared[j_s] = T.exp2(
(g_last_local_Y[0] - g_shared[i_s % num_stages, j_s])
* 1.442695
)
g_last_local_Y[0] = T.exp2(g_last_local_Y[0] * 1.442695)
T.barrier_arrive(bar_1)
# [STAGE = i_s % num_stages] 1
T.barrier_wait(bar_1, i_s % 2)
# U = K @ S
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
h_shared,
y_fragment,
clear_accum=True,
)
# Y = g_last * U - g_last/g * V
for j_s, j_v in T.Parallel(block_S, DV):
y_fragment[j_s, j_v] *= g_last_local_Y[0]
for j_s, j_v in T.Parallel(block_S, DV):
y_fragment[j_s, j_v] -= (
v_shared[i_s % num_stages, j_s, j_v] * g_rev_exp_shared[j_s]
)
# S2[2] Y
T.copy(y_fragment, y_shared)
T.barrier_arrive(bar_2)
if calc_mt:
# [STAGE = i_s % num_stages] 2
g_prod_Y[0] += g_shared[i_s % num_stages, block_S - 1]
# S4[2] M
T.copy(m_fragment_L, m_shared_L)
# [STAGE = i_s % num_stages] 3
T.barrier_wait(bar_3, i_s % 2)
# Z = K @ M
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
m_shared_L,
z_fragment_L,
clear_accum=True,
)
# S4[2] Z
T.copy(z_fragment_L, z_shared_L)
# M += X^T @ Z
T.gemm_v1(
x_shared,
z_shared_L,
m_fragment_L,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % num_stages])
if calc_mt:
g_last_local_Y[0] = T.exp2(g_prod_Y[0] * 1.442695)
for j_k, j_v in T.Parallel(DK, DK // 2):
m_fragment_L[j_k, j_v] *= g_last_local_Y[0]
T.copy(m_fragment_L, mt[bb, bh, 0:DK, : DK // 2])
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 384 + 32:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load K
T.copy(
k[batch_idx, left:right, bhg, 0:DK],
k_shared[i_s % num_stages, :, :],
)
T.barrier_arrive(data_is_ready[i_s % num_stages])
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load V
T.copy(
v[batch_idx, left:right, bh, 0:DV],
v_shared[i_s % num_stages, :, :],
)
# Load A TODO: Mask A for the last chunk
T.copy(
a[batch_idx, left:right, bh, 0:block_S],
a_shared[i_s % num_stages, :, :],
)
T.barrier_arrive(data_is_ready[i_s % num_stages])
elif tx < 384 + 96:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load gamma
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
g_shared[i_s % num_stages, j_s] = g[
batch_idx, left + j_s, bh
]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
g_shared[i_s % num_stages, j_s] = g[
batch_idx, left + j_s, bh
]
else:
g_shared[i_s % num_stages, j_s] = g[
batch_idx, seq_end_idx - 1, bh
]
# Load beta
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
b_shared[i_s % num_stages, j_s] = b[
batch_idx, left + j_s, bh
]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
b_shared[i_s % num_stages, j_s] = b[
batch_idx, left + j_s, bh
]
else:
b_shared[i_s % num_stages, j_s] = 0
T.barrier_arrive(data_is_ready[i_s % num_stages])
else:
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_0)
T.barrier_wait(bar_0, i_s % 2)
T.barrier_wait(bar_1, i_s % 2)
# Store S
if store_h:
T.copy(
h_shared,
h[batch_idx, chunk_start_idx + i_s, bh, 0:DK, 0:DV],
)
return tilelang_prepare_h_kernel
def fused_gdr_h(
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
initial_state: torch.Tensor | None = None,
output_final_state: bool = True,
output_h: bool = True,
chunk_size: int = 64,
cu_seqlens: torch.LongTensor | None = None,
num_warmup_chunks: torch.LongTensor | None = None,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H, V = v.shape
assert K == V == 128
assert chunk_size == 64
if cu_seqlens is None:
assert num_warmup_chunks is None
real_batch_size = batch_size
num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
chunk_offsets = torch.empty(
(batch_size + 1), dtype=torch.int32, device=k.device
)
is_varlen = False
is_cp = False
else:
real_batch_size = len(cu_seqlens) - 1
chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
num_chunks = num_chunks if output_h else 0
is_varlen = True
if num_warmup_chunks is None:
num_warmup_chunks = torch.empty(
(real_batch_size, H), dtype=cu_seqlens.dtype, device=k.device
)
is_cp = False
else:
is_cp = True
use_initial_state = initial_state is not None
if initial_state is None:
initial_state = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device)
ht_dtype = k.dtype if is_cp else torch.float32
final_state = torch.empty(
(real_batch_size, H, K, V), dtype=ht_dtype, device=k.device
)
final_correction = torch.empty(
(real_batch_size, H, K, K), dtype=ht_dtype, device=k.device
)
tilelang_prepare_h_kernel = tilelang_prepare_h(
H,
Hg,
K,
V,
chunk_size,
qkva_dtype=k.dtype,
g_dtype=g.dtype,
b_dtype=b.dtype,
h0_dtype=initial_state.dtype,
ht_dtype=final_state.dtype,
h_dtype=h.dtype,
seqlen_dtype=cu_seqlens.dtype,
accum_dtype="float32",
use_initial_state=use_initial_state,
store_final_state=output_final_state,
store_h=output_h,
is_varlen=is_varlen,
is_cp=is_cp,
)
tilelang_prepare_h_kernel(
k,
v,
a,
g,
b,
initial_state,
cu_seqlens,
chunk_offsets,
num_warmup_chunks,
h,
final_state,
final_correction,
)
if not output_final_state:
final_state = None
final_correction = None
if not output_h:
h = None
return h, final_state, final_correction

View File

@@ -0,0 +1,6 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .sm_legacy import chunk_gated_delta_rule_fwd_legacy
__all__ = ["chunk_gated_delta_rule_fwd_legacy"]

View File

@@ -0,0 +1,348 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>
namespace {
void check_cuda(cudaError_t status, const char* context) {
if (status != cudaSuccess) {
throw std::runtime_error(std::string(context) + ": " +
cudaGetErrorString(status));
}
}
__device__ __forceinline__ float subgroup_sum_lane0(float value,
int width) {
constexpr unsigned mask = 0xffffffffU;
for (int offset = width / 2; offset > 0; offset >>= 1) {
value += __shfl_down_sync(mask, value, offset, width);
}
return value;
}
__device__ __forceinline__ float subgroup_broadcast_lane0(float value,
int width) {
return __shfl_sync(0xffffffffU, value, 0, width);
}
template <int D, int COLS, int WIDTH>
__global__ void gdn_forward_kernel(const float* __restrict__ q,
const float* __restrict__ k,
const float* __restrict__ v,
const float* __restrict__ gate,
const float* __restrict__ beta,
const float* __restrict__ initial_state,
float* __restrict__ output,
float* __restrict__ final_state,
int batch,
int tokens,
int q_heads,
int v_heads,
float scale) {
static_assert(D % (COLS * (32 / WIDTH)) == 0);
constexpr int subgroups_per_warp = 32 / WIDTH;
constexpr int rows_per_lane = (D + WIDTH - 1) / WIDTH;
const int hv = blockIdx.x;
const int b = blockIdx.y;
const int subgroup = threadIdx.x / WIDTH;
const int lane = threadIdx.x % WIDTH;
const int group_base =
(blockIdx.z * blockDim.y + threadIdx.y) * subgroups_per_warp + subgroup;
const int col_base = group_base * COLS;
const int hq = hv / (v_heads / q_heads);
float state_shard[COLS][rows_per_lane];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const int col = col_base + c;
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
float value = 0.0F;
if (row < D) {
const auto state_index =
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
value = initial_state == nullptr ? 0.0F : initial_state[state_index];
}
state_shard[c][r] = value;
}
}
for (int t = 0; t < tokens; ++t) {
const auto gate_index =
((static_cast<int64_t>(b) * tokens + t) * v_heads + hv);
float gate_value = 0.0F;
float beta_value = 0.0F;
if (threadIdx.x == 0) {
gate_value = __expf(gate[gate_index]);
beta_value = beta[gate_index];
}
gate_value = __shfl_sync(0xffffffffU, gate_value, 0);
beta_value = __shfl_sync(0xffffffffU, beta_value, 0);
float k_reg[rows_per_lane];
float q_reg[rows_per_lane];
float kv_partial[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
kv_partial[c] = 0.0F;
}
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
float q_value = 0.0F;
float k_value = 0.0F;
if (row < D) {
const auto qk_index =
(((static_cast<int64_t>(b) * tokens + t) * q_heads + hq) * D) + row;
q_value = q[qk_index];
k_value = k[qk_index];
}
q_reg[r] = q_value;
k_reg[r] = k_value;
#pragma unroll
for (int c = 0; c < COLS; ++c) {
kv_partial[c] += state_shard[c][r] * k_value;
}
}
float delta[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const float kv_col = subgroup_sum_lane0(kv_partial[c], WIDTH);
float delta_value = 0.0F;
if (lane == 0) {
const auto v_index =
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D) +
col_base + c;
delta_value = (v[v_index] - gate_value * kv_col) * beta_value;
}
delta[c] = subgroup_broadcast_lane0(delta_value, WIDTH);
}
float attn_partial[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
attn_partial[c] = 0.0F;
}
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const float new_state =
fmaf(k_reg[r], delta[c], gate_value * state_shard[c][r]);
state_shard[c][r] = new_state;
attn_partial[c] += new_state * q_reg[r];
}
}
#pragma unroll
for (int c = 0; c < COLS; ++c) {
attn_partial[c] = subgroup_sum_lane0(attn_partial[c], WIDTH);
}
if (lane == 0) {
const auto out_base =
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D);
#pragma unroll
for (int c = 0; c < COLS; ++c) {
output[out_base + col_base + c] = attn_partial[c] * scale;
}
}
}
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const int col = col_base + c;
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
if (row < D) {
const auto state_index =
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
final_state[state_index] = state_shard[c][r];
}
}
}
}
template <int D>
void launch_gdn_forward(const float* q,
const float* k,
const float* v,
const float* gate,
const float* beta,
const float* initial_state,
float* output,
float* final_state,
int batch,
int tokens,
int q_heads,
int v_heads,
float scale,
cudaStream_t stream) {
constexpr int cols = D == 128 ? 4 : 1;
constexpr int width = D == 128 ? 16 : 32;
constexpr int groups_per_warp = 32 / width;
constexpr int column_groups_per_block = 8;
const dim3 block(32, column_groups_per_block);
const int groups = D / cols;
const int z = (groups + column_groups_per_block * groups_per_warp - 1) /
(column_groups_per_block * groups_per_warp);
const dim3 grid(v_heads, batch, z);
gdn_forward_kernel<D, cols, width>
<<<grid, block, 0, stream>>>(q,
k,
v,
gate,
beta,
initial_state,
output,
final_state,
batch,
tokens,
q_heads,
v_heads,
scale);
}
void validate_tensor(const torch::Tensor& tensor,
const char* name,
int64_t dims) {
TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor");
TORCH_CHECK(tensor.scalar_type() == torch::kFloat32,
name,
" must be float32");
TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous");
TORCH_CHECK(tensor.dim() == dims, name, " has wrong rank");
}
} // namespace
std::vector<torch::Tensor> gdn_forward(torch::Tensor q,
torch::Tensor k,
torch::Tensor v,
torch::Tensor gate,
torch::Tensor beta,
c10::optional<torch::Tensor> initial_state,
double scale) {
validate_tensor(q, "q", 4);
validate_tensor(k, "k", 4);
validate_tensor(v, "v", 4);
validate_tensor(gate, "gate", 3);
validate_tensor(beta, "beta", 3);
TORCH_CHECK(q.sizes() == k.sizes(), "q and k must have the same shape");
const int batch = static_cast<int>(q.size(0));
const int tokens = static_cast<int>(q.size(1));
const int q_heads = static_cast<int>(q.size(2));
const int dim = static_cast<int>(q.size(3));
const int v_heads = static_cast<int>(v.size(2));
TORCH_CHECK(v.size(0) == batch && v.size(1) == tokens && v.size(3) == dim,
"v must have shape [B, T, Hv, D] matching q/k");
TORCH_CHECK(gate.size(0) == batch && gate.size(1) == tokens &&
gate.size(2) == v_heads,
"gate must have shape [B, T, Hv]");
TORCH_CHECK(beta.sizes() == gate.sizes(),
"beta must have the same shape as gate");
TORCH_CHECK(v_heads % q_heads == 0, "Hv must be divisible by Hq");
TORCH_CHECK(dim == 16 || dim == 32 || dim == 64 || dim == 128,
"D must be one of 16, 32, 64, or 128");
const float* initial_ptr = nullptr;
if (initial_state.has_value() && initial_state.value().defined()) {
const auto& h0 = initial_state.value();
validate_tensor(h0, "initial_state", 4);
TORCH_CHECK(h0.size(0) == batch && h0.size(1) == v_heads &&
h0.size(2) == dim && h0.size(3) == dim,
"initial_state must have shape [B, Hv, D, D]");
initial_ptr = h0.data_ptr<float>();
}
auto output = torch::empty_like(v);
auto final_state = torch::empty({batch, v_heads, dim, dim}, q.options());
const auto stream = at::cuda::getCurrentCUDAStream(q.device().index()).stream();
switch (dim) {
case 16:
launch_gdn_forward<16>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
case 32:
launch_gdn_forward<32>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
case 64:
launch_gdn_forward<64>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
case 128:
launch_gdn_forward<128>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
}
check_cuda(cudaGetLastError(), "gdn_forward launch");
return {output, final_state};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gdn_forward", &gdn_forward, "SM70/SM75 legacy GDN forward");
}

View File

@@ -0,0 +1,104 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from __future__ import annotations
import os
from pathlib import Path
import torch
from torch.utils.cpp_extension import load
_EXT = None
def _load_ext():
global _EXT
if _EXT is not None:
return _EXT
if not torch.cuda.is_available():
raise RuntimeError("SM70/SM75 legacy GDN backend requires CUDA")
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0;7.5")
src = Path(__file__).with_name("csrc") / "gdn_forward.cu"
_EXT = load(
name="flash_qla_legacy_gdn",
sources=[str(src)],
extra_cuda_cflags=["-O3"],
extra_cflags=["-O3"],
verbose=bool(int(os.environ.get("FLASH_QLA_LEGACY_VERBOSE_BUILD", "0"))),
)
return _EXT
def _check_inputs(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
initial_state: torch.Tensor | None,
) -> None:
tensors = [q, k, v, g, beta]
if initial_state is not None:
tensors.append(initial_state)
if any(not tensor.is_cuda for tensor in tensors):
raise ValueError("legacy GDN tensors must be CUDA tensors")
# if any(tensor.dtype != torch.float32 for tensor in tensors):
# raise ValueError("legacy GDN backend currently supports float32 tensors only")
if any(not tensor.is_contiguous() for tensor in tensors):
raise ValueError("legacy GDN tensors must be contiguous")
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError("q, k, and v must have shape [B, T, H, D]")
if g.ndim != 3 or beta.ndim != 3:
raise ValueError("g and beta must have shape [B, T, Hv]")
if q.shape != k.shape:
raise ValueError("q and k must have the same shape")
batch, tokens, q_heads, dim = q.shape
if v.shape[0] != batch or v.shape[1] != tokens or v.shape[3] != dim:
raise ValueError("v must have shape [B, T, Hv, D] matching q/k")
if g.shape != beta.shape or g.shape != v.shape[:3]:
raise ValueError("g and beta must have shape [B, T, Hv]")
if v.shape[2] % q_heads != 0:
raise ValueError("Hv must be divisible by Hq")
if dim not in (16, 32, 64, 128):
raise ValueError("legacy GDN backend supports D in {16, 32, 64, 128}")
if initial_state is not None and initial_state.shape != (batch, v.shape[2], dim, dim):
raise ValueError("initial_state must have shape [B, Hv, D, D]")
def chunk_gated_delta_rule_fwd_legacy(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run the experimental SM70/SM75 forward-only GDN backend.
This legacy backend is intentionally explicit. It does not replace the
Hopper/SM90 TileLang path and currently supports only contiguous float32
tensors for inference-oriented forward execution.
Shapes:
q, k: [B, T, Hq, D]
v: [B, T, Hv, D]
g, beta: [B, T, Hv]
initial_state: optional [B, Hv, D, D]
Returns:
output: [B, T, Hv, D]
final_state: [B, Hv, D, D]
"""
_check_inputs(q, k, v, g, beta, initial_state)
if scale is None:
scale = q.shape[-1] ** -0.5
ext = _load_ext()
return ext.gdn_forward(q, k, v, g, beta, initial_state, float(scale))

View File

@@ -0,0 +1,11 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .cumsum import chunk_local_cumsum
from .group_reduce import group_reduce_vector
__all__ = [
"chunk_local_cumsum",
"group_reduce_vector",
]

View File

@@ -0,0 +1,165 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_indices
@tilelang.jit(
# out_idx=[-1],
)
def tilelang_chunk_local_cumsum(
H,
chunk_size,
accum_dtype,
g_dtype,
seqlen_dtype,
is_varlen,
reverse,
):
data_batch_size = T.dynamic("data_batch_size")
real_batch_size = T.dynamic("real_batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
g_shape = (data_batch_size, num_tokens, H)
@T.macro
def kernel_body(
bb,
bc,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
g_raw,
g_cumsum,
):
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
g_fragment = T.alloc_fragment((H, block_S), dtype=accum_dtype)
gT_fragment = T.alloc_fragment((block_S, H), dtype=g_dtype)
gT_shared = T.alloc_shared((block_S, H + 1), dtype=g_dtype)
if right <= seq_end_idx:
T.copy(g_raw[bb, left:right, 0:H], gT_fragment)
else:
for j, i in T.Parallel(block_S, H):
if left + j < seq_end_idx:
gT_fragment[j, i] = g_raw[bb, left + j, i]
else:
gT_fragment[j, i] = 0
T.copy(gT_fragment, gT_shared[:, :H])
for i, j in T.Parallel(H, block_S):
g_fragment[i, j] = gT_shared[j, i]
T.cumsum(g_fragment, dim=1, reverse=reverse)
for i, j in T.Parallel(H, block_S):
gT_shared[j, i] = g_fragment[i, j]
T.copy(gT_shared[:, :H], gT_fragment)
if right <= seq_end_idx:
T.copy(gT_fragment, g_cumsum[bb, left:right, 0:H])
else:
for j, i in T.Parallel(block_S, H):
if left + j < seq_end_idx:
g_cumsum[bb, left + j, i] = gT_fragment[j, i]
if is_varlen:
@T.prim_func
def tilelang_chunk_local_cumsum_kernel(
g_raw: T.Tensor(g_shape, dtype=g_dtype),
cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
):
with T.Kernel(num_chunks, threads=128) as (bc,):
bb = 0
batch_idx = chunk_indices[bc, 0]
chunk_idx = chunk_indices[bc, 1]
seq_start_idx = cu_seqlens[batch_idx]
seq_end_idx = cu_seqlens[batch_idx + 1]
kernel_body(
bb,
bc,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
g_raw,
g_cumsum,
)
else:
@T.prim_func
def tilelang_chunk_local_cumsum_kernel(
g_raw: T.Tensor(g_shape, dtype=g_dtype),
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
num_chunks: T.int32,
):
with T.Kernel(num_chunks, threads=128) as (bc,):
bb = bc % data_batch_size
batch_idx = bb
chunk_idx = bc // data_batch_size
seq_start_idx = 0
seq_end_idx = num_tokens
kernel_body(
bb,
bc,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
g_raw,
g_cumsum,
)
return tilelang_chunk_local_cumsum_kernel
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int = 64,
cu_seqlens: torch.LongTensor | None = None,
reverse: bool = False,
):
batch_size, num_tokens, H = g.shape
assert g.stride(-1) == 1
if cu_seqlens is None:
num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
seqlen_dtype = "int32"
is_varlen = False
else:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
seqlen_dtype = cu_seqlens.dtype
is_varlen = True
g_cumsum = torch.empty_like(g)
tilelang_chunk_local_cumsum_kernel = tilelang_chunk_local_cumsum(
H,
chunk_size,
g_dtype=g.dtype,
seqlen_dtype=seqlen_dtype,
accum_dtype="float32",
is_varlen=is_varlen,
reverse=reverse,
)
if is_varlen:
tilelang_chunk_local_cumsum_kernel(g, cu_seqlens, chunk_indices, g_cumsum)
else:
tilelang_chunk_local_cumsum_kernel(g, g_cumsum, num_chunks)
return g_cumsum

View File

@@ -0,0 +1,79 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
@tilelang.jit(
# out_idx=[-1],
)
def tilelang_group_reduce_vector(
H,
Hg,
DK,
accum_dtype,
qkva_dtype,
block_size: int = 16,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
group_size = H // Hg
buffer_shape = (batch_size, num_tokens, H, DK)
dqk_shape = (batch_size, num_tokens, Hg, DK)
@T.prim_func
def tilelang_group_reduce_vector_kernel(
buffer: T.Tensor(buffer_shape, dtype=qkva_dtype),
result: T.Tensor(dqk_shape, dtype=qkva_dtype),
):
with T.Kernel(
tilelang.cdiv(num_tokens, block_size), Hg, batch_size, threads=128
) as (bt, bhg, bb):
buffer_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
result_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
T.clear(result_fragment)
for i in T.serial(group_size):
T.copy(
buffer[
bb,
bt * block_size : (bt + 1) * block_size,
bhg * group_size + i,
0:DK,
],
buffer_fragment,
)
for j, k in T.Parallel(block_size, DK):
result_fragment[j, k] += buffer_fragment[j, k]
T.copy(
result_fragment,
result[bb, bt * block_size : (bt + 1) * block_size, bhg, 0:DK],
)
return tilelang_group_reduce_vector_kernel
def group_reduce_vector(
buffer: torch.Tensor,
Hg: int,
):
batch_size, num_tokens, H, K = buffer.shape
result = torch.empty(
(batch_size, num_tokens, Hg, K), dtype=buffer.dtype, device=buffer.device
)
tilelang_group_reduce_vector_kernel = tilelang_group_reduce_vector(
H,
Hg,
K,
qkva_dtype=buffer.dtype,
accum_dtype="float32",
)
tilelang_group_reduce_vector_kernel(buffer, result)
return result

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .profiler import profile
from .pack import pad_and_reshape, pack, unpack, fill_last_chunk_of_g
from .math import l2norm
from .index import prepare_chunk_indices, prepare_chunk_offsets, tensor_cache
__all__ = [
"profile",
"pad_and_reshape",
"pack",
"unpack",
"fill_last_chunk_of_g",
"l2norm",
"prepare_chunk_indices",
"prepare_chunk_offsets",
"tensor_cache",
]

138
flash_qla/utils/index.py Normal file
View File

@@ -0,0 +1,138 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import functools
from typing import Any
from collections import OrderedDict
from collections.abc import Callable
import torch
import tilelang
import tilelang.language as T
def tensor_cache(
fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
The cache is limited to a fixed size (default is 256). When the cache is full, the oldest entry will be removed.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
cache: "OrderedDict[tuple[tuple[int, ...], tuple[tuple[str, int], ...]], tuple[tuple[Any, ...], dict[str, Any], Any]]" = OrderedDict()
cache_size = 256
def get_id(x: Any):
if (type(x) is int) or (type(x) is float) or (type(x) is str):
return x
else:
return id(x)
def make_identity_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[int, ...], tuple[tuple[str, int], ...]]:
args_key = tuple(get_id(a) for a in args)
kwargs_key = tuple(sorted((k, get_id(v)) for k, v in kwargs.items()))
return args_key, kwargs_key
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal cache, cache_size
key = make_identity_key(args, kwargs)
if key in cache:
cache.move_to_end(key, last=True)
_, _, cached_result = cache[key]
return cached_result
result = fn(*args, **kwargs)
cache[key] = (args, kwargs, result)
cache.move_to_end(key, last=True)
if len(cache) > cache_size:
cache.popitem(last=False)
return result
return wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.diff(cu_seqlens)
@tensor_cache
def prepare_chunk_indices(
cu_seqlens: torch.LongTensor,
chunk_size: int,
) -> torch.LongTensor:
# TODO: tilelang kernel
indices = torch.cat(
[
torch.arange(n)
for n in tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
]
)
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
@tilelang.jit()
def tilelang_prepare_chunk_offsets(
chunk_size,
block_size,
dtype,
):
batch_size_plus_1 = T.dynamic("batch_size_plus_1")
num_threads = min(max(block_size, 32), 128)
@T.prim_func
def tilelang_prepare_chunk_offsets_kernel(
cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype),
chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype),
):
with T.Kernel(1, threads=num_threads) as (bb,):
_batch_size = T.alloc_var("int32")
_batch_size = batch_size_plus_1 - 1
seqlen_start_fragment = T.alloc_fragment((block_size), dtype=dtype)
seqlen_end_fragment = T.alloc_fragment((block_size), dtype=dtype)
chunk_offset_fragment = T.alloc_fragment((block_size), dtype=dtype)
T.copy(cu_seqlens[: batch_size_plus_1 - 1], seqlen_start_fragment)
T.copy(cu_seqlens[1:], seqlen_end_fragment)
for i in T.Parallel(block_size):
chunk_offset_fragment[i] = (
seqlen_end_fragment[i] - seqlen_start_fragment[i]
)
chunk_offset_fragment[i] = (
chunk_offset_fragment[i] + chunk_size - 1
) // chunk_size
T.cumsum(src=chunk_offset_fragment, dim=0)
chunk_offsets[0] = 0
T.copy(chunk_offset_fragment, chunk_offsets[1:])
return tilelang_prepare_chunk_offsets_kernel
@tensor_cache
def prepare_chunk_offsets(
cu_seqlens: torch.LongTensor,
chunk_size: int,
) -> torch.LongTensor:
chunk_offsets = torch.empty_like(cu_seqlens)
tilelang_prepare_chunk_offsets_kernel = tilelang_prepare_chunk_offsets(
chunk_size=chunk_size,
block_size=tilelang.next_power_of_2(cu_seqlens.shape[0] - 1),
dtype=cu_seqlens.dtype,
)
tilelang_prepare_chunk_offsets_kernel(cu_seqlens, chunk_offsets)
return chunk_offsets, chunk_offsets[-1].item()

21
flash_qla/utils/math.py Normal file
View File

@@ -0,0 +1,21 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
@torch.compile
def l2norm_compiled(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return (x * inv_norm).to(x.dtype)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
assert dim == -1
assert x.stride(-1) == 1
raw_shape = x.shape
x = x.view((-1, raw_shape[-1]))
torch._dynamo.mark_dynamic(x, 0)
y = l2norm_compiled(x, dim, eps)
y = y.view(raw_shape)
return y

83
flash_qla/utils/pack.py Normal file
View File

@@ -0,0 +1,83 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
def unpack(
x: torch.Tensor, # [B, T, H]
cu_seqlens: torch.Tensor,
):
assert x.shape[0] == 1
assert len(cu_seqlens.shape) == 1
max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
batch_size = cu_seqlens.shape[0] - 1
y = torch.zeros((batch_size, max_len, *x.shape[2:]), dtype=x.dtype, device=x.device)
for i in range(batch_size):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
y[i, : end - start] = x[0, start:end]
return y
def pack(
x: torch.Tensor, # [B, T, H]
cu_seqlens: torch.Tensor,
):
assert len(cu_seqlens.shape) == 1
sum_len = cu_seqlens[-1].item()
batch_size = cu_seqlens.shape[0] - 1
y = torch.empty((1, sum_len, *x.shape[2:]), dtype=x.dtype, device=x.device)
for i in range(batch_size):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
y[0, start:end] = x[i, : end - start]
return y
def pad_and_reshape(
x: torch.Tensor,
dim: int,
chunk_size: int = 64,
):
sequence_length = x.shape[dim]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
zeros = [
0,
] * (2 * (len(x.shape) - 1 - dim))
padded = torch.nn.functional.pad(x, (*zeros, 0, pad_size))
return padded.reshape((*x.shape[:dim], -1, chunk_size, *x.shape[dim + 1 :]))
def fill_last_chunk_of_g(
g: torch.Tensor,
num_tokens: int,
cu_seqlens: torch.Tensor,
chunk_size: int = 64,
reverse: bool = False,
):
if cu_seqlens is None:
last_chunk_size = num_tokens % chunk_size
if last_chunk_size > 0:
if reverse:
g[:, -1, last_chunk_size - 1] += g[:, -1, -1]
else:
g[:, -1, last_chunk_size:] = g[
:, -1, last_chunk_size - 1 : last_chunk_size
]
else:
for i in range(cu_seqlens.shape[0] - 1):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
last_chunk_idx = (end - start) // chunk_size
last_chunk_size = (end - start) % chunk_size
if last_chunk_size > 0:
if reverse:
g[i, last_chunk_idx, last_chunk_size - 1] += g[
i, last_chunk_idx, -1
]
else:
g[i, last_chunk_idx, last_chunk_size:] = g[
i, last_chunk_idx, last_chunk_size - 1 : last_chunk_size
]
return g

View File

@@ -0,0 +1,25 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
def profile(func, inputs, wait: int = 50, warmup: int = 50, rep: int = 100):
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=rep),
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./tb'),
) as prof:
for idx in range(wait + warmup + rep):
func(*inputs)
prof.step()
# print(prof.key_averages().table(sort_by="cpu_time", row_limit=10))
result = {x.key: x.device_time * 1e-3 for x in prof.key_averages()}
result["total"] = tilelang.profiler.do_bench(
lambda: func(*inputs), warmup=warmup, rep=rep
)
return result

31
setup.py Normal file
View File

@@ -0,0 +1,31 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import os
import subprocess
from setuptools import setup, find_packages
this_dir = os.path.dirname(os.path.abspath(__file__))
rev = os.getenv("QLA_VERSION_SUFFIX", "")
if not rev:
try:
cmd = ["git", "rev-parse", "--short", "HEAD"]
rev = "+" + subprocess.check_output(cmd, cwd=this_dir).decode("ascii").rstrip()
except Exception:
rev = ""
setup(
name="flash_qla",
version="0.1.0" + rev,
description="FlashQLA: Fused TileLang kernels for Linear Attention",
packages=find_packages(),
license="MIT",
python_requires=">=3.10",
install_requires=[
"torch>=2.8",
"tilelang==0.1.8",
"apache-tvm-ffi==0.1.9",
],
zip_safe=False,
)

700
tests/ref_gdr.py Normal file
View File

@@ -0,0 +1,700 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
from flash_qla.utils import (
pad_and_reshape,
pack,
unpack,
fill_last_chunk_of_g,
prepare_chunk_offsets,
)
def torch_cumsum(
x: torch.Tensor, # [B, T, H]
cu_seqlens: torch.Tensor = None,
chunk_size: int = 64,
reverse: bool = False,
):
if cu_seqlens is not None:
x = unpack(x, cu_seqlens)
batch_size, num_tokens, num_heads = x.shape
x = pad_and_reshape(x, dim=1, chunk_size=chunk_size)
if reverse:
x = torch.flip(x, dims=(2,))
x = x.cumsum(dim=2)
x = torch.flip(x, dims=(2,))
else:
x = x.cumsum(dim=2)
x = x.reshape(batch_size, -1, num_heads)
x = x[:, :num_tokens]
if cu_seqlens is not None:
x = pack(x, cu_seqlens)
return x
def torch_kkt_fwd(
k: torch.Tensor, # [B, T, Hk, K]
g: torch.Tensor, # [B, T, Hv]
beta: torch.Tensor, # [B, T, Hv]
cu_seqlens: torch.Tensor = None,
chunk_size: int = 64,
):
if cu_seqlens is not None:
k = unpack(k, cu_seqlens)
g = unpack(g, cu_seqlens)
beta = unpack(beta, cu_seqlens)
batch_size, num_tokens, num_k_heads, head_dim = k.shape
num_v_heads = g.shape[-1]
if num_k_heads != num_v_heads:
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, H, K]
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, H]
beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, H]
mask = torch.triu(
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
)
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0)
# decay_mask = torch.where(mask[None, None, :, :, None], decay_mask, 0.0)
attn = torch.einsum(
"bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k
) * decay_mask.swapaxes(-2, -1) # [B, N, C, H, D]
attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens]
if cu_seqlens is not None:
attn = pack(attn, cu_seqlens)
return attn
def torch_solve(
x: torch.Tensor, # [B, T, H, D]
cu_seqlens: torch.Tensor = None,
):
if cu_seqlens is not None:
x = unpack(x, cu_seqlens)
batch_size, num_tokens, num_heads, chunk_size = x.shape
x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size).swapaxes(
2, 3
) # [B, N, H, C, D]
for i in range(1, chunk_size):
row = x[..., i, :i].clone()
sub = x[..., :i, :i].clone()
x[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
x += torch.eye(chunk_size, dtype=x.dtype, device=x.device)
x = x.swapaxes(2, 3).reshape((batch_size, -1, num_heads, chunk_size))[
:, :num_tokens
]
if cu_seqlens is not None:
x = pack(x, cu_seqlens)
return x
def torch_w_u_fwd(
k: torch.Tensor, # [B, T, Hk, K]
v: torch.Tensor, # [B, T, Hv, V]
g: torch.Tensor, # [B, T, Hv]
beta: torch.Tensor, # [B, T, Hv]
A: torch.Tensor, # [B, T, Hv, D]
cu_seqlens: torch.Tensor = None,
):
if cu_seqlens is not None:
k = unpack(k, cu_seqlens)
v = unpack(v, cu_seqlens)
A = unpack(A, cu_seqlens)
beta = unpack(beta, cu_seqlens)
g = unpack(g, cu_seqlens)
batch_size, num_tokens, _, chunk_size = A.shape
_, _, num_k_heads, head_dim_k = k.shape
_, _, num_v_heads, head_dim_v = v.shape
if num_k_heads != num_v_heads:
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k_beta = pad_and_reshape(
k * beta.unsqueeze(-1) * g.exp().unsqueeze(-1), dim=1, chunk_size=chunk_size
) # [B, N, C, Hv, K]
v_beta = pad_and_reshape(
v * beta.unsqueeze(-1), dim=1, chunk_size=chunk_size
) # [B, N, C, Hv, V]
A = pad_and_reshape(A, dim=1)
w = torch.einsum("bnchd, bndhk -> bnchk", A, k_beta).reshape(
(batch_size, -1, num_v_heads, head_dim_k)
)[:, :num_tokens]
u = torch.einsum("bnchd, bndhk -> bnchk", A, v_beta).reshape(
(batch_size, -1, num_v_heads, head_dim_v)
)[:, :num_tokens]
if cu_seqlens is not None:
w = pack(w, cu_seqlens)
u = pack(u, cu_seqlens)
return w, u
def torch_chunk_gdr_fwd(
k: torch.Tensor, # [B, T, Hk, K]
w: torch.Tensor, # [B, T, Hv, K]
u: torch.Tensor, # [B, T, Hv, V]
g: torch.Tensor, # [B, T, Hv]
initial_state: torch.Tensor = None, # [B, Hv, K, V]
cu_seqlens: torch.Tensor = None,
chunk_size: int = 64,
):
if cu_seqlens is not None:
k = unpack(k, cu_seqlens)
w = unpack(w, cu_seqlens)
u = unpack(u, cu_seqlens)
g = unpack(g, cu_seqlens)
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
_, _, num_v_heads, head_dim_v = u.shape
if num_k_heads != num_v_heads:
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
u = pad_and_reshape(u, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
if initial_state is None:
last_state = torch.zeros(
(batch_size, num_v_heads, head_dim_k, head_dim_v),
dtype=g.dtype,
device=g.device,
)
else:
last_state = initial_state.to(g.dtype, copy=True)
h, vn = [], []
for i in range(k.shape[1]):
h.append(last_state)
v_new = u[:, i] - torch.einsum("bchk, bhkv -> bchv", w[:, i], last_state)
vn.append(v_new)
last_state = last_state * g[:, i, -1, :, None, None].exp()
last_state = last_state + torch.einsum(
"bchk, bchv -> bhkv",
k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(),
v_new,
)
h = torch.stack(h, dim=1).contiguous()
vn = (
torch.stack(vn, dim=1)
.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
.contiguous()
)
if cu_seqlens is not None:
vn = pack(vn, cu_seqlens)
h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
return h, vn, last_state
def torch_chunk_o_fwd(
q: torch.Tensor, # [B, T, Hk, K]
k: torch.Tensor, # [B, T, Hk, K]
v: torch.Tensor, # [B, T, Hv, K]
h: torch.Tensor, # [B, N, Hv, K, V]
g: torch.Tensor, # [B, T, Hv]
cu_seqlens: torch.Tensor = None,
scale: float = None,
chunk_size: int = 64,
):
if cu_seqlens is not None:
q = unpack(q, cu_seqlens)
k = unpack(k, cu_seqlens)
v = unpack(v, cu_seqlens)
g = unpack(g, cu_seqlens)
h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
_, _, num_v_heads, head_dim_v = v.shape
if num_k_heads != num_v_heads:
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
scale = scale or head_dim_k ** (-0.5)
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
q = q * scale
mask = torch.triu(
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
diagonal=1,
)
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
decay_mask = decay_mask.masked_fill(
mask[None, None, :, :, None], 0.0
) # [B, N, C, D, Hv]
attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask
attn_inter = torch.einsum("bnchk, bnhkv -> bnchv", q * g.exp().unsqueeze(-1), h)
o = attn_inter + torch.einsum("bncdh, bndhv -> bnchv", attn, v)
o = o.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
if cu_seqlens is not None:
o = pack(o, cu_seqlens)
return o
def torch_chunk_dv_bwd(
q: torch.Tensor, # [B, T, Hk, K]
k: torch.Tensor, # [B, T, Hk, K]
g: torch.Tensor, # [B, T, Hv]
do: torch.Tensor, # [B, T, Hv, V]
cu_seqlens: torch.Tensor = None,
scale: float = None,
chunk_size: int = 64,
):
if cu_seqlens is not None:
q = unpack(q, cu_seqlens)
k = unpack(k, cu_seqlens)
g = unpack(g, cu_seqlens)
do = unpack(do, cu_seqlens)
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
_, _, num_v_heads, head_dim_v = do.shape
if num_k_heads != num_v_heads:
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
scale = scale or head_dim_k ** (-0.5)
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
q = q * scale
mask = torch.triu(
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
diagonal=1,
)
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
decay_mask = decay_mask.masked_fill(
mask[None, None, :, :, None], 0.0
) # [B, N, C, D, Hv]
attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask
dv = torch.einsum("bncdh, bnchv -> bndhv", attn, do)
dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
if cu_seqlens is not None:
dv = pack(dv, cu_seqlens)
return dv
def torch_chunk_gdr_bwd(
q: torch.Tensor, # [B, T, Hk, K]
k: torch.Tensor, # [B, T, Hk, K]
w: torch.Tensor, # [B, T, Hv, K]
g: torch.Tensor, # [B, T, Hv]
do: torch.Tensor, # [B, T, Hv, V]
dv: torch.Tensor, # [B, T, Hv, V]
h0: torch.Tensor = None, # [B, Hv, K, V]
dht: torch.Tensor = None, # [B, Hv, K, V]
cu_seqlens: torch.Tensor = None,
scale: float = None,
chunk_size: int = 64,
):
if cu_seqlens is not None:
q = unpack(q, cu_seqlens)
k = unpack(k, cu_seqlens)
w = unpack(w, cu_seqlens)
g = unpack(g, cu_seqlens)
do = unpack(do, cu_seqlens)
dv = unpack(dv, cu_seqlens)
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
_, _, num_v_heads, head_dim_v = do.shape
if num_k_heads != num_v_heads:
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
scale = scale or head_dim_k ** (-0.5)
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
q = q * scale
if dht is None:
dstate = torch.zeros(
(batch_size, num_v_heads, head_dim_k, head_dim_v),
dtype=g.dtype,
device=g.device,
)
else:
dstate = dht.to(g.dtype, copy=True)
dstate_inter = torch.einsum("bnchk, bnchv -> bnhkv", q * g.exp().unsqueeze(-1), do)
dh = []
for i in reversed(range(k.shape[1])):
dh.insert(0, dstate)
dv[:, i] += torch.einsum(
"bchk, bhkv -> bchv",
k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(),
dstate,
)
dstate = dstate * g[:, i, -1, :, None, None].exp()
dstate = (
dstate
+ dstate_inter[:, i]
- torch.einsum("bchk, bchv -> bhkv", w[:, i], dv[:, i])
)
dh = torch.stack(dh, dim=1).contiguous()
dh0 = None if h0 is None else dstate
dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
if cu_seqlens is not None:
dv = pack(dv, cu_seqlens)
dh = pack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
return dh, dh0, dv
def torch_chunk_dqkwg_bwd(
q: torch.Tensor, # [B, T, Hk, K]
k: torch.Tensor, # [B, T, Hk, K]
v: torch.Tensor, # [B, T, Hv, V]
w: torch.Tensor, # [B, T, Hv, K]
g: torch.Tensor, # [B, T, Hv]
h: torch.Tensor, # [B, N, Hv, K, V]
dv: torch.Tensor, # [B, T, Hv, V]
do: torch.Tensor, # [B, T, Hv, V]
dh: torch.Tensor, # [B, N, Hv, K, V]
cu_seqlens: torch.Tensor = None,
scale: float = None,
chunk_size: int = 64,
):
if cu_seqlens is not None:
q = unpack(q, cu_seqlens)
k = unpack(k, cu_seqlens)
v = unpack(v, cu_seqlens)
w = unpack(w, cu_seqlens)
g = unpack(g, cu_seqlens)
do = unpack(do, cu_seqlens)
dv = unpack(dv, cu_seqlens)
h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
dh = unpack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
_, _, num_v_heads, head_dim_v = do.shape
if num_k_heads != num_v_heads:
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
scale = scale or head_dim_k ** (-0.5)
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
mask = torch.triu(
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
diagonal=1,
)
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
decay_mask = decay_mask.masked_fill(
mask[None, None, :, :, None], 0.0
) # [B, N, C, D, Hv]
dg_last = (h * dh).sum(dim=-1).sum(dim=-1) # [B, N, Hv]
ds = torch.einsum("bnchv, bndhv -> bncdh", do, v)
dq = torch.einsum("bnchv, bnhkv -> bnchk", do, h)
dk = torch.einsum("bnchv, bnhkv -> bnchk", v, dh)
dw = -torch.einsum("bnchv, bnhkv -> bnchk", dv, h)
g_last = g[:, :, -1]
dg_last *= g_last.exp()
dq = dq * g.unsqueeze(-1).exp() * scale
dg = (q * dq).sum(dim=-1) # [B, N, C, Hv]
dk = dk * (g_last.unsqueeze(-2) - g).unsqueeze(-1).exp()
dg -= (k * dk).sum(dim=-1)
dg_last += (k * dk).sum(dim=-1).sum(dim=-2)
ds *= decay_mask * scale
ds2 = ds * torch.einsum("bnchk, bndhk -> bncdh", q, k)
dg += ds2.sum(dim=-2)
dg -= ds2.sum(dim=-3)
dq += torch.einsum("bncdh, bndhk -> bnchk", ds, k)
dk += torch.einsum("bncdh, bnchk -> bndhk", ds, q)
dg[:, :, -1] += dg_last
dg = fill_last_chunk_of_g(
dg, num_tokens, cu_seqlens, chunk_size=chunk_size, reverse=True
)
dq = dq.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
dw = dw.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
if cu_seqlens is not None:
dq = pack(dq, cu_seqlens)
dk = pack(dk, cu_seqlens)
dw = pack(dw, cu_seqlens)
dg = pack(dg, cu_seqlens)
return dq, dk, dw, dg
def torch_chunk_wy_bwd(
k: torch.Tensor, # [B, T, Hk, K]
v: torch.Tensor, # [B, T, Hv, V]
beta: torch.Tensor, # [B, T, Hv]
A: torch.Tensor, # [B, T, Hv, D]
g: torch.Tensor, # [B, T, Hv]
dw: torch.Tensor, # [B, T, Hv, K]
du: torch.Tensor, # [B, T, Hv, V]
dk1: torch.Tensor, # [B, T, Hv, K]
dg1: torch.Tensor, # [B, T, Hv]
cu_seqlens: torch.Tensor = None,
):
if cu_seqlens is not None:
k = unpack(k, cu_seqlens)
v = unpack(v, cu_seqlens)
beta = unpack(beta, cu_seqlens)
A = unpack(A, cu_seqlens)
g = unpack(g, cu_seqlens)
dw = unpack(dw, cu_seqlens)
du = unpack(du, cu_seqlens)
dk1 = unpack(dk1, cu_seqlens)
dg1 = unpack(dg1, cu_seqlens)
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
_, _, num_v_heads, head_dim_v = v.shape
chunk_size = A.shape[-1]
if num_k_heads != num_v_heads:
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
A = pad_and_reshape(A, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, D]
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
du = pad_and_reshape(du, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
dA = torch.einsum("bnchk, bndhk -> bnchd", dw, k * (beta * g.exp()).unsqueeze(-1))
dk_beta_g = torch.einsum("bnchd, bnchk -> bndhk", A, dw)
dk = dk_beta_g * (beta * g.exp()).unsqueeze(-1)
db = (dk_beta_g * k * g.exp().unsqueeze(-1)).sum(dim=-1)
dg = (dk_beta_g * k * (g.exp() * beta).unsqueeze(-1)).sum(dim=-1)
dA += torch.einsum("bnchv, bndhv -> bnchd", du, v * beta.unsqueeze(-1))
dv_beta = torch.einsum("bnchd, bnchv -> bndhv", A, du)
dv = dv_beta * beta.unsqueeze(-1)
db += (dv_beta * v).sum(dim=-1)
mask = torch.triu(
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
)
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0).swapaxes(
-2, -1
)
dA = dA.masked_fill(mask[None, None, :, None, :], 0.0)
dA = torch.einsum("bndhc, bndhe -> bnche", A, dA)
dA = torch.einsum("bnchd, bnehd -> bnche", dA, A)
dA = -dA * decay_mask
A = torch.einsum("bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k)
dk_beta = torch.einsum("bnchd, bndhk -> bnchk", dA, k)
db += (dk_beta * k).sum(dim=-1)
dk += torch.einsum("bnchd, bnchk -> bndhk", dA, k * beta.unsqueeze(-1))
dk += dk_beta * beta.unsqueeze(-1)
dk += dk1
dg += (dA * A).sum(dim=-1) - (dA * A).sum(dim=-3).swapaxes(-1, -2)
dg += dg1
# TODO: NOTE: GVA
dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
db = db.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
if cu_seqlens is not None:
dk = pack(dk, cu_seqlens)
dv = pack(dv, cu_seqlens)
db = pack(db, cu_seqlens)
dg = pack(dg, cu_seqlens)
return dk, dv, db, dg
def chunk_gated_delta_rule_fwd(
q: torch.Tensor, # [B, T, Hk, K]
k: torch.Tensor, # [B, T, Hk, K]
v: torch.Tensor, # [B, T, Hv, K]
g: torch.Tensor, # [B, T, Hv]
beta: torch.Tensor, # [B, T, Hv]
cu_seqlens: torch.Tensor = None,
initial_state: torch.Tensor = None,
scale: float = None,
chunk_size: int = 64,
):
scale = scale or q.shape[-1] ** (-0.5)
g = torch_cumsum(
x=g,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
A = torch_kkt_fwd(
k=k,
g=g,
beta=beta,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
A = torch_solve(
x=A,
cu_seqlens=cu_seqlens,
)
w, u = torch_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, vn, final_state = torch_chunk_gdr_fwd(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
o = torch_chunk_o_fwd(
q=q,
k=k,
v=vn,
h=h,
g=g,
cu_seqlens=cu_seqlens,
scale=scale,
chunk_size=chunk_size,
)
return g, o, A, h, final_state
def chunk_gated_delta_rule_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
cu_seqlens: torch.Tensor = None,
chunk_size: int = 64,
):
w, u = torch_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g=g,
cu_seqlens=cu_seqlens,
)
h, vn, _ = torch_chunk_gdr_fwd(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
)
dv = torch_chunk_dv_bwd(
q=q,
k=k,
g=g,
do=do,
scale=scale,
cu_seqlens=cu_seqlens,
)
dh, dh0, dv = torch_chunk_gdr_bwd(
q=q,
k=k,
w=w,
g=g,
h0=initial_state,
dht=dht,
do=do,
dv=dv,
scale=scale,
cu_seqlens=cu_seqlens,
)
dq, dk1, dw, dg1 = torch_chunk_dqkwg_bwd(
q=q,
k=k,
v=vn,
w=w,
g=g,
h=h,
dv=dv,
do=do,
dh=dh,
scale=scale,
cu_seqlens=cu_seqlens,
)
dk, dv, db, dg = torch_chunk_wy_bwd(
k=k,
v=v,
beta=beta,
g=g,
A=A,
dw=dw,
du=dv,
dk1=dk1,
dg1=dg1,
cu_seqlens=cu_seqlens,
)
Hg, H = k.shape[-2], v.shape[-2]
if Hg < H:
B, T, _, K = dq.shape
dq = torch.sum(dq.reshape(B, T, Hg, -1, K), dim=3)
dk = torch.sum(dk.reshape(B, T, Hg, -1, K), dim=3)
dg = torch_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens)
return dq, dk, dv, db, dg, dh0

View File

@@ -0,0 +1,2 @@
batch_size,num_tokens,varlen
1,32768,False
1 batch_size num_tokens varlen
2 1 32768 False

View File

@@ -0,0 +1,9 @@
batch_size,num_tokens,varlen,cu_seqlens
1,16384,True,0-268-1139-1179-1212-1476-2792-3202-3611-3726-3820-4096-4882-6417-8130-8192-8328-9426-10473-11002-11754-12288-14085-15370-16384
1,16384,True,0-3393-4096-4153-5636-5853-6318-6777-8192-8320-8931-9163-9494-10040-10113-10363-10561-11061-11388-11634-12288-14545-16288-16384
1,16384,True,0-4096-6111-6485-6589-7118-8192-9056-10448-12288-14032-14525-15884-16012-16384
1,16384,True,0-177-4096-8192-12288-12805-13171-13298-16055-16384
1,16384,True,0-308-1128-1678-4096-4748-8192-8506-9657-10252-12113-12288-16384
1,16384,True,0-4096-6893-7665-8192-12288-16384
1,16384,True,0-410-841-1135-2126-2512-4096-4682-5022-5375-6259-6335-6580-6648-7308-8192-10450-12058-12288-14215-15280-15701-16384
1,16384,True,0-2048-4096-6144-8192-10240-12288-14336-16384
1 batch_size num_tokens varlen cu_seqlens
2 1 16384 True 0-268-1139-1179-1212-1476-2792-3202-3611-3726-3820-4096-4882-6417-8130-8192-8328-9426-10473-11002-11754-12288-14085-15370-16384
3 1 16384 True 0-3393-4096-4153-5636-5853-6318-6777-8192-8320-8931-9163-9494-10040-10113-10363-10561-11061-11388-11634-12288-14545-16288-16384
4 1 16384 True 0-4096-6111-6485-6589-7118-8192-9056-10448-12288-14032-14525-15884-16012-16384
5 1 16384 True 0-177-4096-8192-12288-12805-13171-13298-16055-16384
6 1 16384 True 0-308-1128-1678-4096-4748-8192-8506-9657-10252-12113-12288-16384
7 1 16384 True 0-4096-6893-7665-8192-12288-16384
8 1 16384 True 0-410-841-1135-2126-2512-4096-4682-5022-5375-6259-6335-6580-6648-7308-8192-10450-12058-12288-14215-15280-15701-16384
9 1 16384 True 0-2048-4096-6144-8192-10240-12288-14336-16384

View File

@@ -0,0 +1,5 @@
batch_size,num_tokens,varlen
1,4096,False
1,8192,False
1,16384,False
1,32768,False
1 batch_size num_tokens varlen
2 1 4096 False
3 1 8192 False
4 1 16384 False
5 1 32768 False

View File

@@ -0,0 +1,7 @@
batch_size,num_tokens,varlen
11,33,False
7,4321,False
3,16789,True
5,8192,True
10,1024,True
20,512,True
1 batch_size num_tokens varlen
2 11 33 False
3 7 4321 False
4 3 16789 True
5 5 8192 True
6 10 1024 True
7 20 512 True

View File

@@ -0,0 +1,71 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
"""Static checks for autograd ``Function`` signatures.
PyTorch validates that ``Function.backward`` returns exactly as many gradients
as ``Function.forward`` received non-``ctx`` inputs; mismatches raise at
``.backward()`` time. ``tests/test_gdr.py`` invokes ``chunk_gated_delta_rule_fwd``
and ``chunk_gated_delta_rule_bwd`` directly, bypassing the autograd path, so
drift between the forward signature and the backward return tuple goes
uncaught by the existing suite.
These tests parse the source files with ``ast`` instead of importing the
modules so they run on CPU-only / non-Hopper machines.
"""
import ast
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[1]
CHUNK_INIT = "flash_qla/ops/gated_delta_rule/chunk/__init__.py"
def _parse(rel_path: str) -> ast.Module:
return ast.parse((REPO_ROOT / rel_path).read_text(encoding="utf-8"))
def _get_class(module: ast.Module, name: str) -> ast.ClassDef:
for node in module.body:
if isinstance(node, ast.ClassDef) and node.name == name:
return node
raise AssertionError(f"class {name!r} not found")
def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef:
for node in cls.body:
if isinstance(node, ast.FunctionDef) and node.name == name:
return node
raise AssertionError(f"method {name!r} not found on {cls.name}")
def test_chunk_gated_delta_rule_grad_count_matches_forward_inputs():
"""``backward`` must return one gradient per non-``ctx`` input of ``forward``."""
module = _parse(CHUNK_INIT)
cls = _get_class(module, "ChunkGatedDeltaRuleFunction")
fwd = _get_method(cls, "forward")
fwd_args = fwd.args.args
assert fwd_args and fwd_args[0].arg == "ctx", (
"forward must take `ctx` as its first argument"
)
n_inputs = len(fwd_args) - 1 # exclude ctx
bwd = _get_method(cls, "backward")
returns = [n for n in ast.walk(bwd) if isinstance(n, ast.Return)]
assert len(returns) == 1, f"expected one Return in backward, got {len(returns)}"
assert isinstance(returns[0].value, ast.Tuple), (
"backward must return a tuple literal"
)
n_grads = len(returns[0].value.elts)
assert n_inputs == n_grads, (
f"backward returns {n_grads} gradients but forward takes {n_inputs} non-ctx "
f"inputs; PyTorch will raise a count-mismatch error at .backward() time."
)
if __name__ == "__main__":
test_chunk_gated_delta_rule_grad_count_matches_forward_inputs()
print("OK")

553
tests/test_gdr.py Normal file
View File

@@ -0,0 +1,553 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import argparse
import math
import torch
import pandas as pd
# Requires flash-linear-attention==0.5.0
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_fla,
)
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_fla,
)
from flash_qla import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_qla
from flash_qla import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_qla
from flash_qla.utils import l2norm, pack, profile
from ref_gdr import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_ref
from ref_gdr import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_ref
def test_gated_delta_rule(
batch_size: int,
num_tokens: int,
num_k_heads: int,
num_v_heads: int,
head_dim_k: int,
head_dim_v: int,
varlen: bool = False,
cu_seqlens: list[int] | None = None,
use_h0: bool = False,
chunk_size: int = 64,
data_dtype: str = "bfloat16",
ref_dtype: str = "float32",
device: torch.device = "cuda",
random_seed: int = 42,
check_accuracy: bool = True,
show_speedup: bool = True,
auto_cp: bool = True,
swa_ratio: float = 0.75,
skip_bwd: bool = False,
):
data_dtype = getattr(torch, data_dtype)
ref_dtype = getattr(torch, ref_dtype)
torch.manual_seed(random_seed)
q = l2norm(
torch.randn(
(batch_size, num_tokens, num_k_heads, head_dim_k),
device=device,
dtype=data_dtype,
)
)
k = l2norm(
torch.randn(
(batch_size, num_tokens, num_k_heads, head_dim_k),
device=device,
dtype=data_dtype,
)
)
v = torch.randn(
(batch_size, num_tokens, num_v_heads, head_dim_v),
device=device,
dtype=data_dtype,
)
g = (
torch.nn.functional.logsigmoid(
torch.randn(
(batch_size, num_tokens, num_v_heads),
device=device,
dtype=torch.float32,
)
)
/ 16
)
beta = torch.randn(
(batch_size, num_tokens, num_v_heads), device=device, dtype=torch.float32
).sigmoid()
h0 = (
torch.randn(
(batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
if use_h0
else None
)
do = torch.randn_like(v)
dht = (
torch.randn(
(batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
/ 8
if use_h0
else None
)
scale = head_dim_k ** (-0.5)
print(
f"Shape: B={batch_size} Hk={num_k_heads} Hv={num_v_heads} T={num_tokens} VarLen={varlen}"
)
swa_mask = torch.zeros((num_v_heads), dtype=torch.bool, device=device)
swa_mask[: math.ceil(swa_ratio * num_v_heads)] = 1
swa_mask = swa_mask[torch.randperm(num_v_heads, device=device)]
g[:, :, ~swa_mask] = 0.0
print(f"SWA Mask: {swa_mask.to(torch.int32, copy=True).tolist()}")
if varlen:
if cu_seqlens is None:
cu_seqlens = torch.randint(
1, num_tokens, (batch_size,), device=device, dtype=torch.int32
)
cu_seqlens = torch.nn.functional.pad(
torch.cumsum(cu_seqlens, dim=-1), (1, 0)
)
q = pack(q, cu_seqlens)
k = pack(k, cu_seqlens)
v = pack(v, cu_seqlens)
g = pack(g, cu_seqlens)
beta = pack(beta, cu_seqlens)
do = pack(do, cu_seqlens)
else:
assert batch_size == 1
assert cu_seqlens[0] == 0
assert cu_seqlens[-1] == num_tokens
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
if use_h0:
real_batch_size = cu_seqlens.shape[0] - 1
h0 = torch.randn(
(real_batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
dht = (
torch.randn(
(real_batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
/ 8
)
assert (cu_seqlens[1:] - cu_seqlens[:-1]).min() > 0
else:
cu_seqlens = None
g_ref, o_ref, A_ref, h_ref, s_ref = chunk_gated_delta_rule_fwd_ref(
q=q.to(ref_dtype, copy=True),
k=k.to(ref_dtype, copy=True),
v=v.to(ref_dtype, copy=True),
g=g.to(ref_dtype, copy=True),
beta=beta.to(ref_dtype, copy=True),
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
)
g_fla, o_fla, A_fla, s_fla, _, _ = chunk_gated_delta_rule_fwd_fla(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens,
)
g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
output_final_state=True,
output_h=True,
auto_cp=auto_cp,
)
if check_accuracy:
print(
f"h_qla: {(h_qla - h_ref).abs().max().item():.4f} / {h_ref.abs().max().item():.4f}"
)
print(
f"s_fla: {(s_fla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
)
print(
f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
)
print(
f"o_fla: {(o_fla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
)
print(
f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
)
for _ in range(1000):
g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
q,
k,
v,
g,
beta,
scale,
h0,
cu_seqlens,
True,
False,
auto_cp,
)
try:
if h0 is not None:
assert (
s_qla - s_ref
).abs().max().item() <= s_ref.abs().max().item() * 0.02
assert (
o_qla - o_ref
).abs().max().item() <= o_ref.abs().max().item() * 0.02
except AssertionError as e:
print("********** ERROR **********")
if h0 is not None:
print(
f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
)
print(
f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
)
print("********** ERROR **********")
raise e
if show_speedup:
prof_fla = profile(
chunk_gated_delta_rule_fwd_fla,
[q, k, v, g, beta, scale, h0, True, cu_seqlens],
)
prof_qla = profile(
chunk_gated_delta_rule_fwd_qla,
[q, k, v, g, beta, scale, h0, cu_seqlens, True, False, auto_cp],
)
result_fla = {
"[fwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
"[fwd] solve": prof_fla["chunk_gated_delta_rule_fwd_kkt_solve_kernel"],
"[fwd] wu": prof_fla["recompute_w_u_fwd_kernel"],
"[fwd] gdr": prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
"[fwd] o": prof_fla["chunk_fwd_kernel_o"],
}
result_qla = {
"[fwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
"[fwd] solve": prof_qla["tilelang_kkt_solve_kernel_kernel"],
"[fwd] gdr": prof_qla["tilelang_fused_chunk_gdr_fwd_kernel_kernel"],
}
if "tilelang_get_warmup_chunks_kernel_kernel" in prof_qla.keys():
result_fla["[fwd] cp-w"] = None
result_fla["[fwd] cp-h"] = None
result_fla["[fwd] cp-c"] = None
result_qla["[fwd] cp-w"] = prof_qla[
"tilelang_get_warmup_chunks_kernel_kernel"
]
result_qla["[fwd] cp-h"] = prof_qla["tilelang_prepare_h_kernel_kernel"]
result_qla["[fwd] cp-c"] = prof_qla["tilelang_correct_h0_kernel_kernel"]
result_fla["total"] = prof_fla["total"]
result_qla["total"] = prof_qla["total"]
results = {
"fla": result_fla,
"flash_qla": result_qla,
}
df = pd.DataFrame(results)
print(df.round(3))
speedup = results["fla"]["total"] / results["flash_qla"]["total"]
print(f"Speed up: {speedup:.2f}x")
if skip_bwd:
return
dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd_ref(
q.to(ref_dtype, copy=True),
k.to(ref_dtype, copy=True),
v.to(ref_dtype, copy=True),
g_ref,
beta.to(ref_dtype, copy=True),
A_ref.to(ref_dtype, copy=True),
scale,
h0,
do.to(ref_dtype, copy=True),
dht,
cu_seqlens,
)
dq_fla, dk_fla, dv_fla, db_fla, dg_fla, dh0_fla, _, _ = (
chunk_gated_delta_rule_bwd_fla(
q,
k,
v,
g_fla,
beta,
A_fla,
scale,
h0,
do,
dht,
cu_seqlens,
)
)
dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = chunk_gated_delta_rule_bwd_qla(
q,
k,
v,
g_qla,
beta,
A_qla,
do,
dht,
scale,
h0,
cu_seqlens,
)
if check_accuracy:
print(
f"dq_fla: {(dq_fla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
)
print(
f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
)
print(
f"dk_fla: {(dk_fla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
)
print(
f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
)
print(
f"dv_fla: {(dv_fla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
)
print(
f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
)
if dht is not None:
print(
f"dh0_fla: {(dh0_fla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
)
print(
f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
)
print(
f"db_fla: {(db_fla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
)
print(
f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
)
print(
f"dg_fla: {(dg_fla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
)
print(
f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
)
for _ in range(1000):
dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = (
chunk_gated_delta_rule_bwd_qla(
q,
k,
v,
g_qla,
beta,
A_qla,
do,
dht,
scale,
h0,
cu_seqlens,
)
)
try:
assert (
dq_qla - dq_ref
).abs().max().item() <= dq_ref.abs().max().item() * 0.02
assert (
dk_qla - dk_ref
).abs().max().item() <= dk_ref.abs().max().item() * 0.02
assert (
dv_qla - dv_ref
).abs().max().item() <= dv_ref.abs().max().item() * 0.02
assert (
dg_qla - dg_ref
).abs().max().item() <= dg_ref.abs().max().item() * 0.02
assert (
db_qla - db_ref
).abs().max().item() <= db_ref.abs().max().item() * 0.02
if dht is not None:
assert (
dh0_qla - dh0_ref
).abs().max().item() <= dh0_ref.abs().max().item() * 0.02
except AssertionError as e:
print("********** ERROR **********")
print(
f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
)
print(
f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
)
print(
f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
)
if dht is not None:
print(
f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
)
print(
f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
)
print(
f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
)
print("********** ERROR **********")
raise e
if show_speedup:
prof_fla = profile(
chunk_gated_delta_rule_bwd_fla,
[q, k, v, g_fla, beta, A_fla, scale, h0, do, dht, cu_seqlens],
)
prof_qla = profile(
chunk_gated_delta_rule_bwd_qla,
[q, k, v, g_qla, beta, A_qla, do, dht, scale, h0, cu_seqlens],
)
result_fla = {
"[bwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
"[bwd] recom": prof_fla["recompute_w_u_fwd_kernel"]
+ prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
"[bwd] dv": prof_fla["chunk_bwd_kernel_dv_local"],
"[bwd] gdr": prof_fla["chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64"],
"[bwd] dqkwg": prof_fla["kernel_kernel"],
"[bwd] wy": prof_fla["prepare_wy_repr_bwd_kernel"],
}
result_qla = {
"[bwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
"[bwd] recom": prof_qla["tilelang_prepare_h_kernel_kernel"],
"[bwd] gdr": prof_qla["tilelang_fused_chunk_gdr_bwd_kernel_kernel"],
}
if num_k_heads < num_v_heads:
result_fla["[bwd] reduc"] = prof_fla["compress_heads_kernel"]
result_qla["[bwd] reduc"] = (
prof_qla["tilelang_group_reduce_vector_kernel_kernel"] * 2
)
result_fla["total"] = prof_fla["total"]
result_qla["total"] = prof_qla["total"]
results = {
"fla": result_fla,
"flash_qla": result_qla,
}
df = pd.DataFrame(results)
print(df.round(3))
speedup = results["fla"]["total"] / results["flash_qla"]["total"]
print(f"Speed up: {speedup:2.2f}x")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test Gated Delta Rule")
parser.add_argument(
"--set",
type=str,
default="develop",
help="Preset name (loads from settings/{set}.csv)",
)
parser.add_argument(
"--seqlen", "--num-tokens", type=int, default=16384, help="Sequence Length"
)
parser.add_argument(
"--nkh",
"--num-k-heads",
type=int,
default=0,
help="Number of K heads (num_k_heads)",
)
parser.add_argument(
"--nvh",
"--num-heads",
"--num-v-heads",
type=int,
default=64,
help="Number of V heads (num_v_heads)",
)
parser.add_argument(
"--no-h0",
action="store_true",
help="Disable initial state and gradient of final state",
)
parser.add_argument("--skip-bwd", action="store_true", help="Test forward only")
parser.add_argument(
"--no-cp",
"--disable-auto-cp",
action="store_true",
help="Disable auto intra-card CP",
)
parser.add_argument(
"--swa-ratio", type=float, default=0.75, help="Ratio of sliding-window heads"
)
parser.add_argument(
"--data-dtype",
type=str,
default="bfloat16",
help="Data type for input and output",
)
parser.add_argument(
"--ref-dtype", type=str, default="float64", help="Data type for reference"
)
parser.add_argument("--hide-acc", action="store_true", help="Do not print accuracy")
parser.add_argument("--hide-lat", action="store_true", help="Do not print latency")
parser.add_argument(
"--seed", "--random-seed", type=int, default=42, help="Random seed"
)
args = parser.parse_args()
if args.nkh <= 0:
args.nkh = args.nvh
metadata = {
"head_dim_k": 128, # MUST BE 128
"head_dim_v": 128, # MUST BE 128
"chunk_size": 64, # MUST BE 64
"num_tokens": args.seqlen,
"num_k_heads": args.nkh,
"num_v_heads": args.nvh,
"use_h0": not args.no_h0,
"data_dtype": args.data_dtype,
"ref_dtype": args.ref_dtype,
"check_accuracy": not args.hide_acc,
"show_speedup": not args.hide_lat,
"skip_bwd": args.skip_bwd,
"auto_cp": not args.no_cp,
"swa_ratio": args.swa_ratio,
"random_seed": args.seed,
"device": "cuda",
}
import os
script_dir = os.path.dirname(os.path.abspath(__file__))
preset = pd.read_csv(os.path.join(script_dir, "settings", f"{args.set}.csv"))
for i, row in preset.iterrows():
print("-" * 64)
torch.cuda.empty_cache()
data = row.to_dict()
if "cu_seqlens" in data.keys():
data["cu_seqlens"] = list(map(int, data["cu_seqlens"].split("-")))
metadata.update(data)
test_gated_delta_rule(**metadata)
print("-" * 64)

View File

@@ -0,0 +1,57 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import math
import pytest
import torch
from flash_qla.ops.gated_delta_rule.legacy import chunk_gated_delta_rule_fwd_legacy
def _reference(q, k, v, g, beta, scale=None, initial_state=None):
batch, tokens, q_heads, dim = q.shape
v_heads = v.shape[2]
scale = scale if scale is not None else dim**-0.5
state = (
initial_state.clone()
if initial_state is not None
else torch.zeros(batch, v_heads, dim, dim, device=q.device, dtype=q.dtype)
)
output = torch.empty_like(v)
for b in range(batch):
for hv in range(v_heads):
hq = hv // (v_heads // q_heads)
for t in range(tokens):
gate = torch.exp(g[b, t, hv])
delta = (v[b, t, hv] - gate * (state[b, hv].transpose(0, 1) @ k[b, t, hq])) * beta[b, t, hv]
state[b, hv] = gate * state[b, hv] + torch.outer(k[b, t, hq], delta)
output[b, t, hv] = scale * (state[b, hv].transpose(0, 1) @ q[b, t, hq])
return output, state
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("dim", [16, 32, 64, 128])
def test_legacy_sm_gdn_matches_reference(dim):
torch.manual_seed(1000 + dim)
q = torch.randn(1, 5, 2, dim, device="cuda", dtype=torch.float32).contiguous() * 0.05
k = torch.randn_like(q).contiguous() * 0.05
v = torch.randn(1, 5, 4, dim, device="cuda", dtype=torch.float32).contiguous() * 0.1
g = torch.randn(1, 5, 4, device="cuda", dtype=torch.float32).contiguous() * 0.02 - 0.04
beta = torch.rand(1, 5, 4, device="cuda", dtype=torch.float32).contiguous()
h0 = torch.randn(1, 4, dim, dim, device="cuda", dtype=torch.float32).contiguous() * 0.01
scale = 1.0 / math.sqrt(dim)
out_ref, state_ref = _reference(q, k, v, g, beta, scale, h0)
out, state = chunk_gated_delta_rule_fwd_legacy(q, k, v, g, beta, scale, h0)
torch.cuda.synchronize()
torch.testing.assert_close(out, out_ref, atol=2e-4, rtol=2e-4)
torch.testing.assert_close(state, state_ref, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
def test_legacy_sm_gdn_rejects_unsupported_dtype():
q = torch.randn(1, 1, 1, 16, device="cuda", dtype=torch.float16)
with pytest.raises(ValueError, match="float32"):
chunk_gated_delta_rule_fwd_legacy(q, q, q, torch.randn(1, 1, 1, device="cuda"), torch.randn(1, 1, 1, device="cuda"))