commit 3f95e2939d1d1c589168876cd9c8b94e1dc32c2b Author: Hokori Date: Sun Jun 14 23:49:03 2026 +0800 first commit diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a55b1f5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Qwen, Alibaba + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..5a4dbd1 --- /dev/null +++ b/README.md @@ -0,0 +1,206 @@ +> [!IMPORTANT] +> This repository is an experimental SM70/SM75 fork of [QwenLM/FlashQLA](https://github.com/QwenLM/FlashQLA). +> +> It is not an official FlashQLA release and does not replace the upstream Hopper/SM90 implementation. + +# FlashQLA-SM70-SM75 + +Experimental forward-inference support for Qwen-style Gated DeltaNet on SM70/SM75-class NVIDIA GPUs. + +This fork keeps the upstream Hopper/SM90 TileLang path intact and adds an explicit legacy backend entry point for Volta/Turing inference devices. The current runtime validation target is RTX 2080 Ti / SM75. SM70 currently has compile coverage, but V100-class runtime validation is still required before making performance claims. + +## Changes in This Fork + +- Adds `flash_qla.ops.gated_delta_rule.legacy.chunk_gated_delta_rule_fwd_legacy`. +- Adds a lazy-built CUDA extension for a forward-only SM70/SM75-class Gated DeltaNet backend. +- Keeps the upstream Hopper/SM90 TileLang path unchanged. +- Keeps the legacy path explicit instead of silently replacing the upstream high-level API. +- Adds CUDA correctness tests for the supported legacy path. +- Documents the supported scope, validation status, and benchmark caveats separately from upstream Hopper results. + +## Supported Scope + +Supported: + +- forward inference only +- SM70/SM75-class CUDA devices as the intended legacy target family +- scalar-gate Gated DeltaNet +- Qwen-style grouped-query head mapping +- primary optimized shape: `D=128` +- explicit legacy API entry point + +Not supported: + +- backward kernels or training +- automatic dispatch from the upstream high-level API +- generic support for all pre-Hopper NVIDIA GPUs +- runtime performance claims for SM70 before V100-class validation +- SM80/SM86/SM89 support claims +- automatic default dispatch for non-Hopper devices + +## Current Validation + +Runtime validation was performed on RTX 2080 Ti / SM75. + +Standalone kernel timing for a Qwen-like shape: + +- `B=1, T=512, Hq=16, Hv=32, D=128` +- control recurrent path: about `1.126 ms` +- optimized legacy path on SM75: about `0.520-0.533 ms` +- GDN-stage speedup: about `2.1x` + +GGUF runtime profiling on SM75: + +- default fused GDN: `406.656 ms` +- legacy fast path: `195.105 ms` +- GDN-stage speedup: about `2.08x` + +Whole-request impact under the same server parameters: + +- prefill: `+7.17%` +- decode: `+0.61%` +- wall time: `-3.49%` + +SM70 status: + +- compile check passes +- runtime validation is pending +- V100-class benchmarking is needed before claiming SM70 performance + +Fork wrapper status: + +- Python syntax check passes +- CUDA tests are included under `tests/test_legacy_sm_gdn.py` +- CUDA PyTorch runtime validation still requires a CUDA-enabled PyTorch environment + +## Positioning + +This fork is meant to make the SM70/SM75 experiment reproducible and reviewable. It should be treated as an upstreamable experimental branch, not as a separate long-term replacement for FlashQLA. + +--- + +The original upstream README follows below. + +

+ +

+ +

|   📜 Blog   |

+ +## Introduction + +FlashQLA is a high-performance linear attention kernel library built on [TileLang](https://github.com/tile-ai/tilelang). FlashQLA applies **reasonable operator fusion and performance optimization** to the forward and backward passes of GDN Chunked Prefill, achieving **2-3× forward speedup** and **2× backward speedup** over the FLA Triton kernel across multiple scenarios on NVIDIA Hopper. The efficiency gains are particularly pronounced in pretraining scenarios and edge-side agentic inference. + +Key features: + +1.**Gate-driven automatic intra-card context parallelism**. By exploiting the exponential decay property of the GDN gate, FlashQLA automatically enables intra-card CP under TP, long-sequence, and small-head-count settings, improving GPU SM utilization. + +2.**Hardware-friendly algebraic reformulation**. We reformulate the forward and backward flows of GDN Chunked Prefill to a certain extent, effectively reducing Tensor Core, CUDA Core, and SFU overhead without sacrificing numerical precision. + +3.**TileLang fused warp-specialized kernels**. Rather than following the step-by-step decomposition into independent kernels, nor fusing the entire computation flow into a single kernel, we take CP and backward requirements into account, use TileLang to build several key fused kernels, and manually implement warpgroup specialization to overlap data movement, Tensor Core computation, and CUDA Core computation. + +## Requirements + +- SM90 or above +- CUDA 12.8 or above +- PyTorch 2.8 or above + +## Installation + +```bash +git clone https://github.com/QwenLM/FlashQLA.git +cd FlashQLA +pip install -v . +``` + +## Usage + +### High-level API + +```python +import torch +from flash_qla import chunk_gated_delta_rule + +o, final_state = chunk_gated_delta_rule( + q=q, # [B, T, H_q, K] + k=k, # [B, T, H_q, K] + v=v, # [B, T, H_v, V] + g=g, # [B, T, H_v] + beta=beta, # [B, T, H_v] + scale=scale, + initial_state=initial_state, # optional, [B, H_v, K, V] + output_final_state=True, + cu_seqlens=cu_seqlens, # optional, for variable-length sequences +) +``` + +### Low-level API + +For separate forward and backward calls: + +```python +from flash_qla import chunk_gated_delta_rule_fwd, chunk_gated_delta_rule_bwd + +# Forward +g, A, o, h, final_state = chunk_gated_delta_rule_fwd( + q, k, v, g, beta, scale=scale, initial_state=h0, cu_seqlens=cu_seqlens +) + +# Backward +dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q, k, v, g, beta, A, do, dht=dht, scale=scale, initial_state=h0, cu_seqlens=cu_seqlens +) +``` + +## Tests + +```bash +# require flash linear attention for comparison +pip install flash_linear_attention==0.5.0 + +cd tests +python test_gdr.py --set develop +python test_gdr.py --set varlen --num-heads 32 +python test_gdr.py --set profile --num-heads 32 +python test_gdr.py --set product --ref-dtype float32 --num-heads 32 +``` + +## Benchmark + +We benchmarked FlashQLA against the FLA Triton and FlashInfer baseline (FLA 0.5.0, Triton 3.5.1, FlashInfer 0.6.9, TileLang 0.1.8) on the head configurations used by the Qwen3.5 / Qwen3.6 family h_k,v \in {64, 48, 32, 24, 16, 8}, corresponding to TP1 through TP8. + +

+ +

+ +Specifically, the forward (FWD) benchmarks measure single-kernel latency for different models and TP settings under varying batch lengths, while the backward (BWD) benchmarks examine the relationship between total token count within a batch and latency during a single update step. + +More detail in [benchmark_results_H200.txt](./benchmark/benchmark_results_H200.txt). + +```bash +# require flash linear attention and flashinfer for comparison +pip install flash_linear_attention==0.5.0 flashinfer-python==0.6.9 + +cd benchmark +python bench_gated_delta_rule.py +``` + +## Acknowledge + +FlashQLA is inspired by [Flash Linear Attention](https://github.com/fla-org/flash-linear-attention), [TileLang](https://github.com/tile-ai/tilelang) and [FlashInfer](https://github.com/flashinfer-ai/flashinfer/) projects. + +## License + +FlashQLA is released under the MIT License. + +## Citation + +```bibtex +@misc{flashqla2025, + title={FlashQLA: Flash Qwen Linear Attention}, + author={Zhang, Chengruidong and Lin, Xi and Jiang, Huiqiang and Wang, Zekun and Li, Xiao and Cao, Yizhong and Zhuang, Bohan and Men, Rui and Zhang, Jianwei and Zheng, Bo and Lin, Junyang and Liu, Dayiheng and Zhou, Jingren}, + year={2026}, + publisher={GitHub}, + howpublished={\url{https://github.com/QwenLM/FlashQLA}}, +} +``` diff --git a/benchmark/bench_gated_delta_rule.py b/benchmark/bench_gated_delta_rule.py new file mode 100644 index 0000000..b9a8004 --- /dev/null +++ b/benchmark/bench_gated_delta_rule.py @@ -0,0 +1,585 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] +# +# Benchmark Script for FlashQLA + +import argparse +import math +import gc +from dataclasses import dataclass +from typing import List, Tuple, Optional, Dict, Any + +import torch +import torch.nn.functional as F + +import tilelang + +# Kernel Imports +from fla.ops.gated_delta_rule.chunk import ( + chunk_gated_delta_rule_fwd as fla_fwd, + chunk_gated_delta_rule_bwd as fla_bwd, +) + +from flash_qla import ( + chunk_gated_delta_rule_fwd as qla_fwd, + chunk_gated_delta_rule_bwd as qla_bwd, +) +from flash_qla.utils import l2norm + +try: + from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_fwd + + HAS_FI = True +except ImportError: + HAS_FI = False + +HEAD_DIM = 128 +BWD_SPLIT_SIZE = 8 + + +@dataclass +class ModelConfig: + label: str + h_qk: int + h_v: int + + +@dataclass +class SeqLenConfig: + label: str + seqlens: List[int] + + +def generate_rand_seqlens(batch_size, num_tokens): + bars = ( + torch.sort( + torch.randperm(num_tokens - 1, device="cuda", dtype=torch.int32)[ + : batch_size - 1 + ] + ).values + + 1 + ) + cu_seqlens = torch.nn.functional.pad(bars, (1, 1)) + cu_seqlens[-1] = num_tokens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + return seqlens.tolist() + + +FWD_MODEL_CONFIGS = [ + ModelConfig("397B/122B TP8", 2, 8), + ModelConfig("397B/122B TP4", 4, 16), + ModelConfig("397B/122B TP2", 8, 32), + ModelConfig("397B/122B TP1", 16, 64), + ModelConfig("35B/9B/4B TP1", 16, 32), + ModelConfig("27B TP2", 8, 24), + ModelConfig("27B TP1", 16, 48), + ModelConfig("2B/0.8B TP1", 16, 16), + ModelConfig("Sym h32", 32, 32), +] + +FWD_SEQLEN_CONFIGS = [ + SeqLenConfig("1x32768", [32768]), + SeqLenConfig("1x16384", [16384]), + SeqLenConfig("1x8192", [8192]), + SeqLenConfig("1x4096", [4096]), + SeqLenConfig("1x2048", [2048]), + SeqLenConfig("28672+4096", [28672, 4096]), + SeqLenConfig("24576+8192", [24576, 8192]), + SeqLenConfig("16384+16384", [16384, 16384]), + SeqLenConfig("8192+24576", [8192, 24576]), + SeqLenConfig("4096+28672", [4096, 28672]), + SeqLenConfig("12288+4096", [12288, 4096]), + SeqLenConfig("6144+2048", [6144, 2048]), + SeqLenConfig("4096+4096", [4096, 4096]), + SeqLenConfig("2048+6144", [2048, 6144]), + SeqLenConfig("1024+7168", [1024, 7168]), + SeqLenConfig("8192x4", [8192] * 4), + SeqLenConfig("4096x8", [4096] * 8), + SeqLenConfig("2048x4", [2048] * 4), + SeqLenConfig("1024x8", [1024] * 8), +] +BWD_MODEL_CONFIGS = [ + ModelConfig("32", 32, 32), + ModelConfig("48", 48, 48), + ModelConfig("64", 64, 64), +] + +BWD_SEQLEN_CONFIGS = [ + SeqLenConfig("16k", generate_rand_seqlens(8, 16384)), + SeqLenConfig("32k", generate_rand_seqlens(8, 32768)), + SeqLenConfig("64k", generate_rand_seqlens(8, 65536)), + SeqLenConfig("128k", generate_rand_seqlens(8, 131072)), + SeqLenConfig("256k", generate_rand_seqlens(8, 262144)), +] + + +def cleanup_cuda(): + try: + if torch.cuda.is_available(): + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + except Exception: + pass + + +def get_lib_versions() -> Dict[str, str]: + """Collects version strings for relevant libraries.""" + versions = {} + + # Torch + try: + versions["torch"] = torch.__version__ + except Exception: + versions["torch"] = "N/A" + + # Flash Linear Attention (FLA) + try: + import fla + + if hasattr(fla, "__version__"): + versions["fla"] = fla.__version__ + else: + versions["fla"] = "Installed (ver unknown)" + except ImportError: + versions["fla"] = "Not Installed" + + # FlashInfer + try: + import flashinfer + + if hasattr(flashinfer, "__version__"): + versions["flashinfer"] = flashinfer.__version__ + elif hasattr(flashinfer, "version"): + versions["flashinfer"] = str(flashinfer.version) + else: + versions["flashinfer"] = "Installed (ver unknown)" + except ImportError: + versions["flashinfer"] = "Not Installed" + + # TileLang + try: + import tilelang + + if hasattr(tilelang, "__version__"): + versions["tilelang"] = tilelang.__version__ + elif hasattr(tilelang, "version"): + versions["tilelang"] = str(tilelang.version) + else: + versions["tilelang"] = "Installed (ver unknown)" + except ImportError: + versions["tilelang"] = "Not Installed" + + return versions + + +def prepare_tensors( + seqlens: List[int], h_qk: int, h_v: int, head_dim: int = HEAD_DIM +) -> Optional[Dict[str, Any]]: + device = "cuda" + num_seqs = len(seqlens) + total_tokens = sum(seqlens) + scale = head_dim ** (-0.5) + + offsets = [0] + for s in seqlens: + offsets.append(offsets[-1] + s) + cu_seqlens = torch.tensor(offsets, dtype=torch.int32, device=device) + + try: + q = l2norm( + torch.randn( + 1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16 + ) + ) + k = l2norm( + torch.randn( + 1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16 + ) + ) + v = torch.randn( + 1, total_tokens, h_v, head_dim, device=device, dtype=torch.bfloat16 + ) + g = ( + F.logsigmoid( + torch.randn(1, total_tokens, h_v, device=device, dtype=torch.float32) + ) + / 16 + ) + beta = torch.randn( + 1, total_tokens, h_v, device=device, dtype=torch.float32 + ).sigmoid() + h0 = torch.randn( + num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32 + ) + do = torch.randn_like(v) + dht = ( + torch.randn( + num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32 + ) + / 8 + ) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + return None + raise e + + swa_ratio = 0.75 + swa_mask = torch.zeros(h_v, dtype=torch.bool, device=device) + swa_mask[: math.ceil(swa_ratio * h_v)] = True + swa_mask = swa_mask[torch.randperm(h_v, device=device)] + g[:, :, ~swa_mask] = 0.0 + + return { + "device": device, + "num_seqs": num_seqs, + "total_tokens": total_tokens, + "scale": scale, + "cu_seqlens": cu_seqlens, + "q": q, + "k": k, + "v": v, + "g": g, + "beta": beta, + "h0": h0, + "do": do, + "dht": dht, + } + + +def bench_fwd( + seqlens: List[int], + h_qk: int, + h_v: int, + head_dim: int = HEAD_DIM, + warmup: int = 10, + repeats: int = 5, +) -> Tuple[float, float, float]: + """ + Run Forward Pass Benchmark. + Returns: (qla_mean_ms, fi_mean_ms, fla_mean_ms) + """ + cleanup_cuda() + data = prepare_tensors(seqlens, h_qk, h_v, head_dim) + if data is None: + return float("nan"), float("nan"), float("nan") + + q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"] + h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"] + + results = {} + + def call_qla_fwd(): + qla_fwd( + q, + k, + v, + g, + beta, + scale=scale, + initial_state=h0, + output_final_state=True, + output_h=False, + cu_seqlens=cu_seqlens, + auto_cp=True, + ) + + try: + mean = tilelang.profiler.do_bench(call_qla_fwd, warmup=warmup, rep=repeats) + results["flash_qla"] = mean + except RuntimeError as e: + print(f"\n[WARN] FlashQLA Fwd failed: {e}") + cleanup_cuda() + results["flash_qla"] = float("nan") + + if HAS_FI: + + def call_fi_fwd(): + fi_fwd( + q=q.view(-1, h_qk, head_dim), + k=k.view(-1, h_qk, head_dim), + v=v.view(-1, h_v, head_dim), + g=g.view(-1, h_v), + beta=beta.view(-1, h_v), + scale=scale, + initial_state=h0, + cu_seqlens=cu_seqlens, + output_final_state=True, + ) + + try: + mean = tilelang.profiler.do_bench(call_fi_fwd, warmup=warmup, rep=repeats) + results["fi"] = mean + except RuntimeError as e: + print(f"\n[WARN] FI Fwd failed: {e}") + cleanup_cuda() + results["fi"] = float("nan") + else: + results["fi"] = float("nan") + + def call_fla_fwd(): + fla_fwd( + q, + k, + v, + g, + beta, + scale=scale, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + ) + + try: + mean = tilelang.profiler.do_bench(call_fla_fwd, warmup=warmup, rep=repeats) + results["fla"] = mean + except RuntimeError as e: + print(f"\n[WARN] FLA Fwd failed: {e}") + cleanup_cuda() + results["fla"] = float("nan") + + try: + torch.cuda.synchronize() + except Exception: + pass + + return ( + results.get("flash_qla", float("nan")), + results.get("fi", float("nan")), + results.get("fla", float("nan")), + ) + + +def bench_bwd( + seqlens: List[int], + h_qk: int, + h_v: int, + head_dim: int = HEAD_DIM, + warmup: int = 10, + repeats: int = 100, +) -> Tuple[float, float]: + """ + Run Backward Pass Benchmark. + Returns: (qla_mean_ms, fla_mean_ms) + """ + unified_h = h_qk + cleanup_cuda() + + data = prepare_tensors(seqlens, unified_h, unified_h, head_dim) + if data is None: + return float("nan"), float("nan") + + q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"] + h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"] + do, dht = data["do"], data["dht"] + + g_cumsum = None + A = None + + # Pre-run FWD to get intermediates + try: + result = qla_fwd( + q, + k, + v, + g, + beta, + scale=scale, + initial_state=h0, + output_final_state=True, + output_h=False, + cu_seqlens=cu_seqlens, + auto_cp=True, + ) + if isinstance(result, tuple) and len(result) >= 2: + g_cumsum, A = result[0], result[1] + else: + raise RuntimeError("FlashQLA FWD did not return expected intermediates") + except RuntimeError as e: + print(f"[FWD Error] Failed at seqlens={seqlens}, heads={h_qk}. Error: {e}") + cleanup_cuda() + return float("nan"), float("nan") + + results = {} + + def call_qla_bwd(): + return qla_bwd( + q, + k, + v, + g_cumsum, + beta, + A, + do, + dht, + scale=scale, + initial_state=h0, + cu_seqlens=cu_seqlens, + ) + + try: + mean = tilelang.profiler.do_bench(call_qla_bwd, warmup=warmup, rep=repeats) + results["flash_qla"] = mean + except RuntimeError as e: + print(f"\n[WARN] FlashQLA Bwd failed: {e}") + cleanup_cuda() + results["flash_qla"] = float("nan") + + def call_fla_bwd(): + return fla_bwd(q, k, v, g_cumsum, beta, A, scale, h0, do, dht, cu_seqlens) + + try: + mean = tilelang.profiler.do_bench(call_fla_bwd, warmup=warmup, rep=repeats) + results["fla"] = mean + except RuntimeError as e: + print(f"\n[WARN] FLA Bwd failed: {e}") + cleanup_cuda() + results["fla"] = float("nan") + + return results.get("flash_qla", float("nan")), results.get("fla", float("nan")) + + +FWD_HDR = ( + f"{'Model Config':<16} {'Seqlens':<17} {'h_qk':>5} {'h_v':>5} " + f"{'flash_qla [fwd]':>10} {'FI [fwd]':>10} {'FLA [fwd]':>10} " + f"{'vs FLA':>7} {'vs FI':>7}" +) + +BWD_HDR = ( + f"{'Heads':<8} {'SeqLen':<15} " + f"{'flash_qla [bwd]':>10} {'FLA [bwd]':>10} {'Speedup':>8}" +) + + +def fmt_time(ms: float) -> str: + if math.isnan(ms): + return " N/A " + return f"{ms:>8.3f}ms" + + +def fmt_ratio(base: float, other: float) -> str: + if math.isnan(base) or math.isnan(other) or base == 0: + return " N/A " + return f"{other / base:>6.2f}x" + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark FlashQLA Gated Delta Rule") + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--repeats", type=int, default=100) + parser.add_argument("--mode", choices=["fwd", "bwd", "all"], default="all") + parser.add_argument("--skip-fi", action="store_true") + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("CUDA not available.") + return + + global HAS_FI + if args.skip_fi: + HAS_FI = False + + gpu_name = torch.cuda.get_device_properties(0).name + print(f"GPU: {gpu_name}") + print("Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128") + print(f"Config: Warmup={args.warmup}, Repeats={args.repeats}") + + libs = get_lib_versions() + print("Library Versions:") + ver_str = " | ".join([f"{k}: {v}" for k, v in libs.items()]) + print(f" {ver_str}") + + print("=" * 110) + + # Forward + if args.mode in ("fwd", "all"): + print("\n>>> FORWARD BENCHMARKS") + print(FWD_HDR) + print("-" * len(FWD_HDR)) + + prev_model = None + for cfg in FWD_MODEL_CONFIGS: + if prev_model is not None and cfg.label != prev_model: + print() + prev_model = cfg.label + + for sl_cfg in FWD_SEQLEN_CONFIGS: + try: + qla_ms, fi_ms, fla_ms = bench_fwd( + sl_cfg.seqlens, + cfg.h_qk, + cfg.h_v, + warmup=args.warmup, + repeats=args.repeats, + ) + + if math.isnan(qla_ms) and math.isnan(fla_ms): + cleanup_cuda() + + ratio_fla = fmt_ratio(qla_ms, fla_ms) + ratio_fi = fmt_ratio(qla_ms, fi_ms) + + print( + f"{cfg.label:<16} {sl_cfg.label:<17} {cfg.h_qk:>5} {cfg.h_v:>5} " + f"{fmt_time(qla_ms)} {fmt_time(fi_ms)} {fmt_time(fla_ms)} " + f"{ratio_fla} {ratio_fi}", + flush=True, + ) + except Exception as e: + print( + f"\n[ERROR] Forward Case Failed: {cfg.label} / {sl_cfg.label}" + ) + print(f"Exception: {e}") + cleanup_cuda() + continue + + cleanup_cuda() + + # Backward + if args.mode in ("bwd", "all"): + print("\n" + "=" * 110) + print(f"\n>>> BACKWARD BENCHMARKS (Split seq into {BWD_SPLIT_SIZE} sequences)") + print(BWD_HDR) + print("-" * len(BWD_HDR)) + + prev_model = None + for cfg in BWD_MODEL_CONFIGS: + if prev_model is not None and cfg.label != prev_model: + print() + prev_model = cfg.label + + for sl_cfg in BWD_SEQLEN_CONFIGS: + try: + qla_bwd_ms, fla_bwd_ms = bench_bwd( + sl_cfg.seqlens, + cfg.h_qk, + cfg.h_v, + warmup=args.warmup, + repeats=args.repeats, + ) + + if math.isnan(qla_bwd_ms) and math.isnan(fla_bwd_ms): + speedup = " Skip " + else: + speedup = fmt_ratio(qla_bwd_ms, fla_bwd_ms) + + print( + f"{cfg.label:<8} {sl_cfg.label:<15} " + f"{fmt_time(qla_bwd_ms)} {fmt_time(fla_bwd_ms)} {speedup} ", + flush=True, + ) + except Exception as e: + print( + f"\n[CRITICAL ERROR] Case Crashed: Heads={cfg.label}, TotalSeqLen={sl_cfg.label}" + ) + print(f"Exception: {e}") + cleanup_cuda() + continue + + cleanup_cuda() + + print("\nBenchmark Finished.") + + +if __name__ == "__main__": + main() diff --git a/benchmark/benchmark_results_H200.txt b/benchmark/benchmark_results_H200.txt new file mode 100644 index 0000000..4ee59f5 --- /dev/null +++ b/benchmark/benchmark_results_H200.txt @@ -0,0 +1,214 @@ +GPU: NVIDIA H200 +Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128 +Config: Warmup=10, Repeats=100 +Library Versions: + torch: 2.9.1+cu128 | fla: 0.5.0 | flashinfer: 0.6.9 | tilelang: 0.1.8 +============================================================================================================== + +>>> FORWARD BENCHMARKS +Model Config Seqlens h_qk h_v flash_qla [fwd] FI [fwd] FLA [fwd] vs FLA vs FI +------------------------------------------------------------------------------------------------------------ +397B/122B TP8 1x32768 2 8 0.310ms 1.657ms 0.910ms 2.93x 5.34x +397B/122B TP8 1x16384 2 8 0.184ms 0.832ms 0.465ms 2.53x 4.52x +397B/122B TP8 1x8192 2 8 0.122ms 0.420ms 0.578ms 4.74x 3.45x +397B/122B TP8 1x4096 2 8 0.107ms 0.215ms 0.277ms 2.59x 2.01x +397B/122B TP8 1x2048 2 8 0.106ms 0.114ms 0.261ms 2.46x 1.07x +397B/122B TP8 28672+4096 2 8 0.304ms 1.449ms 0.841ms 2.77x 4.77x +397B/122B TP8 24576+8192 2 8 0.300ms 1.243ms 0.766ms 2.55x 4.15x +397B/122B TP8 16384+16384 2 8 0.292ms 0.831ms 0.623ms 2.13x 2.84x +397B/122B TP8 8192+24576 2 8 0.298ms 1.241ms 0.766ms 2.57x 4.16x +397B/122B TP8 4096+28672 2 8 0.304ms 1.447ms 0.841ms 2.76x 4.75x +397B/122B TP8 12288+4096 2 8 0.180ms 0.626ms 0.391ms 2.17x 3.48x +397B/122B TP8 6144+2048 2 8 0.118ms 0.317ms 0.263ms 2.23x 2.70x +397B/122B TP8 4096+4096 2 8 0.115ms 0.215ms 0.262ms 2.27x 1.86x +397B/122B TP8 2048+6144 2 8 0.118ms 0.317ms 0.276ms 2.35x 2.70x +397B/122B TP8 1024+7168 2 8 0.120ms 0.368ms 0.263ms 2.19x 3.06x +397B/122B TP8 8192x4 2 8 0.274ms 0.401ms 0.477ms 1.74x 1.46x +397B/122B TP8 4096x8 2 8 0.203ms 0.205ms 0.481ms 2.37x 1.01x +397B/122B TP8 2048x4 2 8 0.109ms 0.110ms 0.268ms 2.46x 1.01x +397B/122B TP8 1024x8 2 8 0.065ms 0.063ms 0.273ms 4.17x 0.95x + +397B/122B TP4 1x32768 4 16 0.485ms 1.655ms 1.250ms 2.58x 3.41x +397B/122B TP4 1x16384 4 16 0.292ms 0.832ms 0.630ms 2.16x 2.85x +397B/122B TP4 1x8192 4 16 0.178ms 0.420ms 0.325ms 1.83x 2.37x +397B/122B TP4 1x4096 4 16 0.115ms 0.215ms 0.290ms 2.51x 1.86x +397B/122B TP4 1x2048 4 16 0.108ms 0.114ms 0.275ms 2.54x 1.05x +397B/122B TP4 28672+4096 4 16 0.477ms 1.440ms 1.178ms 2.47x 3.02x +397B/122B TP4 24576+8192 4 16 0.474ms 1.222ms 1.100ms 2.32x 2.58x +397B/122B TP4 16384+16384 4 16 0.469ms 0.789ms 0.947ms 2.02x 1.68x +397B/122B TP4 8192+24576 4 16 0.470ms 1.192ms 1.099ms 2.34x 2.53x +397B/122B TP4 4096+28672 4 16 0.474ms 1.394ms 1.175ms 2.48x 2.94x +397B/122B TP4 12288+4096 4 16 0.281ms 0.616ms 0.549ms 1.95x 2.19x +397B/122B TP4 6144+2048 4 16 0.175ms 0.312ms 0.292ms 1.67x 1.78x +397B/122B TP4 4096+4096 4 16 0.173ms 0.207ms 0.263ms 1.52x 1.19x +397B/122B TP4 2048+6144 4 16 0.175ms 0.306ms 0.294ms 1.67x 1.75x +397B/122B TP4 1024+7168 4 16 0.177ms 0.359ms 0.307ms 1.74x 2.03x +397B/122B TP4 8192x4 4 16 0.388ms 0.400ms 0.949ms 2.45x 1.03x +397B/122B TP4 4096x8 4 16 0.311ms 0.210ms 0.954ms 3.07x 0.68x +397B/122B TP4 2048x4 4 16 0.114ms 0.110ms 0.264ms 2.32x 0.97x +397B/122B TP4 1024x8 4 16 0.097ms 0.066ms 0.280ms 2.88x 0.68x + +397B/122B TP2 1x32768 8 32 0.866ms 1.567ms 1.894ms 2.19x 1.81x +397B/122B TP2 1x16384 8 32 0.468ms 0.789ms 0.926ms 1.98x 1.69x +397B/122B TP2 1x8192 8 32 0.275ms 0.401ms 0.479ms 1.74x 1.46x +397B/122B TP2 1x4096 8 32 0.175ms 0.207ms 0.273ms 1.56x 1.18x +397B/122B TP2 1x2048 8 32 0.116ms 0.110ms 0.308ms 2.66x 0.96x +397B/122B TP2 28672+4096 8 32 0.841ms 1.374ms 1.893ms 2.25x 1.63x +397B/122B TP2 24576+8192 8 32 0.846ms 1.178ms 1.891ms 2.24x 1.39x +397B/122B TP2 16384+16384 8 32 0.752ms 0.789ms 1.897ms 2.52x 1.05x +397B/122B TP2 8192+24576 8 32 0.842ms 1.182ms 1.893ms 2.25x 1.40x +397B/122B TP2 4096+28672 8 32 0.843ms 1.378ms 1.899ms 2.25x 1.63x +397B/122B TP2 12288+4096 8 32 0.468ms 0.597ms 0.928ms 1.98x 1.28x +397B/122B TP2 6144+2048 8 32 0.271ms 0.305ms 0.506ms 1.87x 1.13x +397B/122B TP2 4096+4096 8 32 0.206ms 0.206ms 0.482ms 2.34x 1.00x +397B/122B TP2 2048+6144 8 32 0.269ms 0.305ms 0.483ms 1.79x 1.13x +397B/122B TP2 1024+7168 8 32 0.277ms 0.354ms 0.480ms 1.73x 1.27x +397B/122B TP2 8192x4 8 32 0.596ms 0.405ms 1.894ms 3.18x 0.68x +397B/122B TP2 4096x8 8 32 0.601ms 0.409ms 1.904ms 3.17x 0.68x +397B/122B TP2 2048x4 8 32 0.172ms 0.115ms 0.486ms 2.83x 0.67x +397B/122B TP2 1024x8 8 32 0.180ms 0.124ms 0.504ms 2.80x 0.69x + +397B/122B TP1 1x32768 16 64 1.494ms 1.563ms 3.336ms 2.23x 1.05x +397B/122B TP1 1x16384 16 64 0.762ms 0.790ms 1.608ms 2.11x 1.04x +397B/122B TP1 1x8192 16 64 0.393ms 0.397ms 0.815ms 2.07x 1.01x +397B/122B TP1 1x4096 16 64 0.209ms 0.206ms 0.430ms 2.06x 0.99x +397B/122B TP1 1x2048 16 64 0.117ms 0.111ms 0.276ms 2.37x 0.95x +397B/122B TP1 28672+4096 16 64 1.719ms 1.374ms 3.341ms 1.94x 0.80x +397B/122B TP1 24576+8192 16 64 1.530ms 1.179ms 3.341ms 2.18x 0.77x +397B/122B TP1 16384+16384 16 64 1.168ms 0.804ms 3.367ms 2.88x 0.69x +397B/122B TP1 8192+24576 16 64 1.539ms 1.232ms 3.352ms 2.18x 0.80x +397B/122B TP1 4096+28672 16 64 1.723ms 1.453ms 3.353ms 1.95x 0.84x +397B/122B TP1 12288+4096 16 64 0.781ms 0.596ms 1.606ms 2.06x 0.76x +397B/122B TP1 6144+2048 16 64 0.406ms 0.305ms 0.822ms 2.02x 0.75x +397B/122B TP1 4096+4096 16 64 0.318ms 0.209ms 0.823ms 2.59x 0.66x +397B/122B TP1 2048+6144 16 64 0.406ms 0.317ms 0.825ms 2.03x 0.78x +397B/122B TP1 1024+7168 16 64 0.454ms 0.371ms 0.825ms 1.82x 0.82x +397B/122B TP1 8192x4 16 64 1.170ms 0.800ms 3.361ms 2.87x 0.68x +397B/122B TP1 4096x8 16 64 1.191ms 0.819ms 3.411ms 2.86x 0.69x +397B/122B TP1 2048x4 16 64 0.324ms 0.219ms 0.837ms 2.58x 0.67x +397B/122B TP1 1024x8 16 64 0.344ms 0.239ms 0.868ms 2.52x 0.70x + +35B/9B/4B TP1 1x32768 16 32 0.871ms 1.567ms 1.908ms 2.19x 1.80x +35B/9B/4B TP1 1x16384 16 32 0.472ms 0.790ms 0.945ms 2.00x 1.67x +35B/9B/4B TP1 1x8192 16 32 0.278ms 0.402ms 0.484ms 1.74x 1.44x +35B/9B/4B TP1 1x4096 16 32 0.176ms 0.207ms 0.266ms 1.51x 1.17x +35B/9B/4B TP1 1x2048 16 32 0.118ms 0.111ms 0.265ms 2.24x 0.94x +35B/9B/4B TP1 28672+4096 16 32 0.860ms 1.374ms 1.904ms 2.21x 1.60x +35B/9B/4B TP1 24576+8192 16 32 0.857ms 1.180ms 1.904ms 2.22x 1.38x +35B/9B/4B TP1 16384+16384 16 32 0.762ms 0.796ms 1.919ms 2.52x 1.05x +35B/9B/4B TP1 8192+24576 16 32 0.854ms 1.184ms 1.909ms 2.24x 1.39x +35B/9B/4B TP1 4096+28672 16 32 0.853ms 1.379ms 1.911ms 2.24x 1.62x +35B/9B/4B TP1 12288+4096 16 32 0.474ms 0.597ms 0.946ms 2.00x 1.26x +35B/9B/4B TP1 6144+2048 16 32 0.270ms 0.304ms 0.488ms 1.81x 1.13x +35B/9B/4B TP1 4096+4096 16 32 0.208ms 0.206ms 0.489ms 2.35x 0.99x +35B/9B/4B TP1 2048+6144 16 32 0.272ms 0.304ms 0.487ms 1.79x 1.12x +35B/9B/4B TP1 1024+7168 16 32 0.279ms 0.354ms 0.488ms 1.75x 1.27x +35B/9B/4B TP1 8192x4 16 32 0.596ms 0.406ms 1.903ms 3.19x 0.68x +35B/9B/4B TP1 4096x8 16 32 0.607ms 0.412ms 1.918ms 3.16x 0.68x +35B/9B/4B TP1 2048x4 16 32 0.171ms 0.116ms 0.495ms 2.89x 0.68x +35B/9B/4B TP1 1024x8 16 32 0.184ms 0.125ms 0.513ms 2.78x 0.68x + +27B TP2 1x32768 8 24 0.657ms 1.631ms 1.576ms 2.40x 2.48x +27B TP2 1x16384 8 24 0.397ms 0.820ms 0.778ms 1.96x 2.06x +27B TP2 1x8192 8 24 0.262ms 0.416ms 0.402ms 1.53x 1.58x +27B TP2 1x4096 8 24 0.170ms 0.213ms 0.267ms 1.57x 1.25x +27B TP2 1x2048 8 24 0.111ms 0.113ms 0.266ms 2.40x 1.02x +27B TP2 28672+4096 8 24 0.655ms 1.422ms 1.495ms 2.28x 2.17x +27B TP2 24576+8192 8 24 0.651ms 1.211ms 1.425ms 2.19x 1.86x +27B TP2 16384+16384 8 24 0.657ms 0.791ms 1.579ms 2.40x 1.20x +27B TP2 8192+24576 8 24 0.656ms 1.185ms 1.576ms 2.40x 1.81x +27B TP2 4096+28672 8 24 0.655ms 1.382ms 1.577ms 2.41x 2.11x +27B TP2 12288+4096 8 24 0.392ms 0.611ms 0.707ms 1.80x 1.56x +27B TP2 6144+2048 8 24 0.258ms 0.310ms 0.372ms 1.44x 1.20x +27B TP2 4096+4096 8 24 0.189ms 0.206ms 0.406ms 2.15x 1.09x +27B TP2 2048+6144 8 24 0.261ms 0.303ms 0.406ms 1.56x 1.16x +27B TP2 1024+7168 8 24 0.261ms 0.352ms 0.407ms 1.56x 1.35x +27B TP2 8192x4 8 24 0.545ms 0.401ms 1.419ms 2.61x 0.74x +27B TP2 4096x8 8 24 0.546ms 0.410ms 1.430ms 2.62x 0.75x +27B TP2 2048x4 8 24 0.155ms 0.112ms 0.374ms 2.41x 0.73x +27B TP2 1024x8 8 24 0.164ms 0.120ms 0.383ms 2.34x 0.74x + +27B TP1 1x32768 16 48 1.243ms 1.573ms 2.728ms 2.19x 1.27x +27B TP1 1x16384 16 48 0.664ms 0.792ms 1.327ms 2.00x 1.19x +27B TP1 1x8192 16 48 0.402ms 0.401ms 0.677ms 1.69x 1.00x +27B TP1 1x4096 16 48 0.193ms 0.206ms 0.355ms 1.84x 1.07x +27B TP1 1x2048 16 48 0.107ms 0.110ms 0.264ms 2.47x 1.03x +27B TP1 28672+4096 16 48 1.231ms 1.371ms 2.618ms 2.13x 1.11x +27B TP1 24576+8192 16 48 1.434ms 1.185ms 2.541ms 1.77x 0.83x +27B TP1 16384+16384 16 48 1.074ms 0.787ms 2.744ms 2.56x 0.73x +27B TP1 8192+24576 16 48 1.435ms 1.186ms 2.741ms 1.91x 0.83x +27B TP1 4096+28672 16 48 1.230ms 1.371ms 2.731ms 2.22x 1.11x +27B TP1 12288+4096 16 48 0.733ms 0.598ms 1.230ms 1.68x 0.82x +27B TP1 6144+2048 16 48 0.379ms 0.305ms 0.638ms 1.68x 0.80x +27B TP1 4096+4096 16 48 0.290ms 0.207ms 0.682ms 2.35x 0.71x +27B TP1 2048+6144 16 48 0.379ms 0.306ms 0.689ms 1.82x 0.81x +27B TP1 1024+7168 16 48 0.425ms 0.355ms 0.688ms 1.62x 0.84x +27B TP1 8192x4 16 48 1.068ms 0.804ms 2.530ms 2.37x 0.75x +27B TP1 4096x8 16 48 0.899ms 0.611ms 2.571ms 2.86x 0.68x +27B TP1 2048x4 16 48 0.296ms 0.218ms 0.645ms 2.18x 0.74x +27B TP1 1024x8 16 48 0.268ms 0.180ms 0.659ms 2.46x 0.67x + +2B/0.8B TP1 1x32768 16 16 0.492ms 1.654ms 1.288ms 2.62x 3.36x +2B/0.8B TP1 1x16384 16 16 0.289ms 0.831ms 0.655ms 2.27x 2.88x +2B/0.8B TP1 1x8192 16 16 0.181ms 0.421ms 0.341ms 1.88x 2.32x +2B/0.8B TP1 1x4096 16 16 0.119ms 0.215ms 0.294ms 2.47x 1.81x +2B/0.8B TP1 1x2048 16 16 0.109ms 0.114ms 0.282ms 2.58x 1.05x +2B/0.8B TP1 28672+4096 16 16 0.484ms 1.440ms 1.215ms 2.51x 2.98x +2B/0.8B TP1 24576+8192 16 16 0.478ms 1.222ms 1.140ms 2.38x 2.55x +2B/0.8B TP1 16384+16384 16 16 0.474ms 0.789ms 0.982ms 2.07x 1.67x +2B/0.8B TP1 8192+24576 16 16 0.478ms 1.191ms 1.138ms 2.38x 2.49x +2B/0.8B TP1 4096+28672 16 16 0.481ms 1.392ms 1.217ms 2.53x 2.89x +2B/0.8B TP1 12288+4096 16 16 0.287ms 0.618ms 0.581ms 2.02x 2.15x +2B/0.8B TP1 6144+2048 16 16 0.179ms 0.312ms 0.306ms 1.71x 1.74x +2B/0.8B TP1 4096+4096 16 16 0.177ms 0.207ms 0.292ms 1.65x 1.17x +2B/0.8B TP1 2048+6144 16 16 0.180ms 0.306ms 0.304ms 1.69x 1.70x +2B/0.8B TP1 1024+7168 16 16 0.181ms 0.358ms 0.324ms 1.79x 1.98x +2B/0.8B TP1 8192x4 16 16 0.389ms 0.399ms 0.984ms 2.53x 1.03x +2B/0.8B TP1 4096x8 16 16 0.318ms 0.214ms 0.992ms 3.12x 0.67x +2B/0.8B TP1 2048x4 16 16 0.118ms 0.111ms 0.283ms 2.41x 0.95x +2B/0.8B TP1 1024x8 16 16 0.100ms 0.069ms 0.321ms 3.21x 0.69x + +Sym h32 1x32768 32 32 0.874ms 1.567ms 1.961ms 2.24x 1.79x +Sym h32 1x16384 32 32 0.474ms 0.790ms 0.981ms 2.07x 1.66x +Sym h32 1x8192 32 32 0.281ms 0.402ms 0.508ms 1.81x 1.43x +Sym h32 1x4096 32 32 0.177ms 0.208ms 0.271ms 1.52x 1.17x +Sym h32 1x2048 32 32 0.117ms 0.111ms 0.296ms 2.54x 0.96x +Sym h32 28672+4096 32 32 0.861ms 1.376ms 1.964ms 2.28x 1.60x +Sym h32 24576+8192 32 32 0.857ms 1.181ms 1.964ms 2.29x 1.38x +Sym h32 16384+16384 32 32 0.771ms 0.808ms 1.974ms 2.56x 1.05x +Sym h32 8192+24576 32 32 0.855ms 1.183ms 1.966ms 2.30x 1.38x +Sym h32 4096+28672 32 32 0.857ms 1.379ms 1.965ms 2.29x 1.61x +Sym h32 12288+4096 32 32 0.475ms 0.597ms 0.985ms 2.07x 1.26x +Sym h32 6144+2048 32 32 0.270ms 0.305ms 0.512ms 1.90x 1.13x +Sym h32 4096+4096 32 32 0.215ms 0.210ms 0.513ms 2.38x 0.97x +Sym h32 2048+6144 32 32 0.271ms 0.304ms 0.510ms 1.88x 1.12x +Sym h32 1024+7168 32 32 0.281ms 0.353ms 0.510ms 1.81x 1.26x +Sym h32 8192x4 32 32 0.600ms 0.419ms 1.981ms 3.30x 0.70x +Sym h32 4096x8 32 32 0.620ms 0.426ms 1.994ms 3.22x 0.69x +Sym h32 2048x4 32 32 0.176ms 0.119ms 0.519ms 2.95x 0.67x +Sym h32 1024x8 32 32 0.183ms 0.128ms 0.535ms 2.93x 0.70x + +============================================================================================================== + +>>> BACKWARD BENCHMARKS (Split seq into 8 sequences) +Heads SeqLen flash_qla [bwd] FLA [bwd] Speedup +--------------------------------------------------------------- +32 16k 1.262ms 2.433ms 1.93x +32 32k 2.743ms 4.904ms 1.79x +32 64k 6.554ms 9.247ms 1.41x +32 128k 9.489ms 18.195ms 1.92x +32 256k 20.939ms 38.613ms 1.84x + +48 16k 1.775ms 3.549ms 2.00x +48 32k 3.715ms 7.197ms 1.94x +48 64k 7.565ms 13.864ms 1.83x +48 128k 13.777ms 28.175ms 2.05x +48 256k 31.321ms 56.344ms 1.80x + +64 16k 2.213ms 4.654ms 2.10x +64 32k 4.513ms 9.237ms 2.05x +64 64k 8.024ms 18.265ms 2.28x +64 128k 15.299ms 36.572ms 2.39x +64 256k 34.310ms 74.392ms 2.17x + +Benchmark Finished. diff --git a/flash_qla/__init__.py b/flash_qla/__init__.py new file mode 100644 index 0000000..640f684 --- /dev/null +++ b/flash_qla/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +__version__ = "0.1.0" + +from flash_qla.ops.gated_delta_rule.chunk import ( + chunk_gated_delta_rule_fwd, + chunk_gated_delta_rule_bwd, + chunk_gated_delta_rule, +) + +__all__ = [ + "chunk_gated_delta_rule_fwd", + "chunk_gated_delta_rule_bwd", + "chunk_gated_delta_rule", +] diff --git a/flash_qla/ops/__init__.py b/flash_qla/ops/__init__.py new file mode 100644 index 0000000..70df152 --- /dev/null +++ b/flash_qla/ops/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .gated_delta_rule import chunk_gated_delta_rule + + +__all__ = ["chunk_gated_delta_rule"] diff --git a/flash_qla/ops/gated_delta_rule/__init__.py b/flash_qla/ops/gated_delta_rule/__init__.py new file mode 100644 index 0000000..ea07eeb --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .chunk import chunk_gated_delta_rule + + +__all__ = ["chunk_gated_delta_rule"] diff --git a/flash_qla/ops/gated_delta_rule/chunk/__init__.py b/flash_qla/ops/gated_delta_rule/chunk/__init__.py new file mode 100644 index 0000000..1db0116 --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/__init__.py @@ -0,0 +1,237 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang + +from flash_qla.utils import l2norm +from flash_qla.ops.utils import chunk_local_cumsum, group_reduce_vector + +if tilelang.contrib.nvcc.get_target_compute_version() == "9.0": + from .hopper import fused_gdr_fwd, fused_gdr_bwd, fused_gdr_h, kkt_solve +else: + raise ValueError("FlashQLA now support sm90 only.") +from .cp_context import intra_card_cp_preprocess + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + output_final_state: bool = True, + output_h: bool = False, + auto_cp: bool = True, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + A = kkt_solve( + k=k, + b=beta, + cu_seqlens=cu_seqlens, + ) + if auto_cp: + initial_state, cu_seqlens, cp_seq_map, raw_cu_seqlens = ( + intra_card_cp_preprocess( + k=k, + v=v, + a=A, + g=g, + b=beta, + raw_h0=initial_state, + raw_cu_seqlens=cu_seqlens, + ) + ) + else: + cp_seq_map = None + raw_cu_seqlens = None + o, h, final_state = fused_gdr_fwd( + q=q, + k=k, + v=v, + a=A, + g=g, + b=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + output_h=output_h, + output_o=True, + cu_seqlens=cu_seqlens, + cp_seq_map=cp_seq_map, + raw_cu_seqlens=raw_cu_seqlens, + ) + return g, A, o, h, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor | None = None, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, +): + h, _, _ = fused_gdr_h( + k=k, + v=v, + a=A, + g=g, + b=beta, + initial_state=initial_state, + output_final_state=False, + output_h=True, + cu_seqlens=cu_seqlens, + ) + dq, dk, dv, dg, db, dh0 = fused_gdr_bwd( + q=q, + k=k, + v=v, + a=A, + g=g, + b=beta, + do=do, + dht=dht, + h=h, + scale=scale, + cu_seqlens=cu_seqlens, + ) + Hg, H = k.shape[-2], v.shape[-2] + if Hg < H: + dq = group_reduce_vector(dq, Hg) + dk = group_reduce_vector(dk, Hg) + assert dg.dtype == torch.float32, "dg should be fp32" + dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + cu_seqlens: torch.LongTensor | None = None, + ): + q_orig = q + k_orig = k + + g, A, o, _, final_state = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + output_h=False, + cu_seqlens=cu_seqlens, + ) + + ctx.save_for_backward(q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens) + ctx.scale = scale + return o.to(q.dtype), final_state + + @staticmethod + @torch.amp.custom_bwd(device_type="cuda") + def backward(ctx, do: torch.Tensor, dht: torch.Tensor): + q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors + + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q_orig, + k=k_orig, + v=v, + g=g, + beta=beta, + A=A, + do=do, + dht=dht, + scale=ctx.scale, + initial_state=initial_state, + cu_seqlens=cu_seqlens, + ) + + return ( + dq.to(q_orig), + dk.to(k_orig), + dv.to(v), + dg.to(g), + db.to(beta), + None, + dh0, + None, + None, + ) + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: torch.LongTensor | None = None, + head_first: bool = False, +): + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, ( + "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16 or float16." + ) + assert not head_first, "head_first=True is not supported." + assert v.shape[2] % k.shape[2] == 0, ( + "num_qk_heads must be divisible to num_v_heads." + ) + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + + if scale is None: + scale = k.shape[-1] ** -0.5 + + if use_qk_l2norm_in_kernel: + q = l2norm(q) + k = l2norm(k) + + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + ) + + return o, final_state diff --git a/flash_qla/ops/gated_delta_rule/chunk/cp_context.py b/flash_qla/ops/gated_delta_rule/chunk/cp_context.py new file mode 100644 index 0000000..4d680c8 --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/cp_context.py @@ -0,0 +1,163 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import math + +import torch +import tilelang + +from flash_qla.utils import tensor_cache + +if tilelang.contrib.nvcc.get_target_compute_version() == "9.0": + from .hopper import get_warmup_chunks, fused_gdr_h, correct_initial_states +else: + raise ValueError("FlashQLA now support sm90 only.") + + +MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count + + +@tensor_cache +def _create_cu_seqlens( + batch_size: int, + num_tokens: int, + device_idx: int, +): + return ( + torch.arange((batch_size + 1), dtype=torch.int32, device=f"cuda:{device_idx}") + * num_tokens + ) + + +@tensor_cache +def _calc_cp_seqs( + raw_cu_seqlens: torch.LongTensor, + chunk_size: int, + num_v_heads: int, +): + # TODO: tilelang kernel + device = raw_cu_seqlens.device + seqlen_dtype = raw_cu_seqlens.dtype + raw_cu_seqlens = raw_cu_seqlens.tolist() + raw_batch_size = len(raw_cu_seqlens) - 1 + seqlens = [raw_cu_seqlens[i + 1] - raw_cu_seqlens[i] for i in range(raw_batch_size)] + num_chunks = [tilelang.cdiv(x, chunk_size) for x in seqlens] + + # autocp + H = num_v_heads + # Latency model: T = a·L_cp + b·(B·H·Lc/P) / L_cp + c + # Minimizing T yields the theoretical optimum: L_cp* ∝ √(B·H·Lc / P), where P = MULTI_PROCESSOR_COUNT, L_cp = max_local_chunks + # Scaled by empirical factor (3) and aligned to the nearest power of 2 for optimal SM scheduling & memory alignment. + + max_local_chunks = 2 ** round( + math.log2(math.sqrt(H * sum(num_chunks) / MULTI_PROCESSOR_COUNT) * 3) + ) + + # Set min to 4 to ensure multi-stage pipelining in fused_gdr; + max_local_chunks = max(max_local_chunks, 4) + + use_cp = False + cp_cu_seqlens = [] + ht_mask = [] + seq_map_c2r = [] + seq_map_r2c = [0] + max_local_tokens = max_local_chunks * chunk_size + for i, c in enumerate(num_chunks): + s = raw_cu_seqlens[i] + e = raw_cu_seqlens[i + 1] + if c > max_local_chunks: + while s < e: + cp_cu_seqlens.append(s) + ht_mask.append(False) + seq_map_c2r.append(i) + s += max_local_tokens + ht_mask[-1] = True + else: + cp_cu_seqlens.append(s) + ht_mask.append(True) + seq_map_c2r.append(i) + seq_map_r2c.append(len(cp_cu_seqlens)) + cp_cu_seqlens.append(raw_cu_seqlens[-1]) + + # Disable CP when B * H naturally saturates SM occupancy. + # For varlen inputs, use `total_chunks / max_seq_chunks` as effective B, + # since CP helps accelerate highly uneven sequence lengths. + + Be = sum(num_chunks) / max(num_chunks) + use_cp = Be * H <= 40 or (Be * H <= 56 and max(num_chunks) >= 128) + + if use_cp: + cp_cu_seqlens = torch.tensor( + cp_cu_seqlens, dtype=seqlen_dtype, device=device, requires_grad=False + ) + seq_map_c2r = torch.tensor(seq_map_c2r, dtype=seqlen_dtype, device=device) + seq_map_r2c = torch.tensor( + seq_map_r2c, dtype=seqlen_dtype, device=device, requires_grad=False + ) + ht_mask = torch.tensor( + ht_mask, dtype=torch.bool, device=device, requires_grad=False + ) + else: + cp_cu_seqlens, seq_map_r2c, ht_mask = None, None, None + + return use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask + + +def intra_card_cp_preprocess( + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + g: torch.Tensor, + b: torch.Tensor, + raw_h0: torch.Tensor, + raw_cu_seqlens: torch.Tensor, + warmup_threshold: float = -10.0, +): + batch_size, num_tokens, num_k_heads, k_head_dim = k.shape + _, _, num_v_heads, v_head_dim = v.shape + chunk_size = a.shape[-1] + device = k.device + + if batch_size > 1: + return raw_h0, raw_cu_seqlens, None, None + + if raw_cu_seqlens is None: + raw_cu_seqlens = _create_cu_seqlens(batch_size, num_tokens, device.index) + + use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask = _calc_cp_seqs( + raw_cu_seqlens, + chunk_size, + num_v_heads, + ) + + if not use_cp: + return raw_h0, raw_cu_seqlens, None, None + + num_warmup_chunks, fallback_mask = get_warmup_chunks( + g=g, + cu_seqlens=cp_cu_seqlens, + ht_mask=ht_mask, + chunk_size=chunk_size, + threshold=warmup_threshold, + ) # [cp_batch_size, num_v_heads] + _, ht, mt = fused_gdr_h( + k=k, + v=v, + a=a, + g=g, + b=b, + initial_state=None, + output_final_state=True, + output_h=False, + cu_seqlens=cp_cu_seqlens, + num_warmup_chunks=num_warmup_chunks, + ) # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim] + cp_h0 = correct_initial_states( + raw_h0=raw_h0, + ht_buffer=ht, + mt_buffer=mt, + fallback_mask=fallback_mask, + seq_map_r2c=seq_map_r2c, + ) + + return cp_h0, cp_cu_seqlens, seq_map_c2r, raw_cu_seqlens diff --git a/flash_qla/ops/gated_delta_rule/chunk/hopper/__init__.py b/flash_qla/ops/gated_delta_rule/chunk/hopper/__init__.py new file mode 100644 index 0000000..574dd4e --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/hopper/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .fused_fwd import fused_gdr_fwd +from .fused_bwd import fused_gdr_bwd +from .prepare_h import fused_gdr_h +from .kkt_solve import kkt_solve +from .cp_fwd import get_warmup_chunks, correct_initial_states + + +__all__ = [ + "fused_gdr_fwd", + "fused_gdr_bwd", + "fused_gdr_h", + "kkt_solve", + "get_warmup_chunks", + "correct_initial_states", +] diff --git a/flash_qla/ops/gated_delta_rule/chunk/hopper/cp_fwd.py b/flash_qla/ops/gated_delta_rule/chunk/hopper/cp_fwd.py new file mode 100644 index 0000000..29f0259 --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/hopper/cp_fwd.py @@ -0,0 +1,309 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang +import tilelang.language as T + + +@tilelang.jit() +def tilelang_get_warmup_chunks( + num_heads, + chunk_size, + threshold, + accum_dtype, + g_dtype, + mask_dtype, + seqlen_dtype, +): + batch_size = T.dynamic("batch_size") + num_tokens = T.dynamic("num_tokens") + num_threads = tilelang.cdiv(num_heads, 32) * 32 + + @T.prim_func + def tilelang_get_warmup_chunks_kernel( + g: T.Tensor([1, num_tokens, num_heads], dtype=g_dtype), + ht_mask: T.Tensor([batch_size], dtype=mask_dtype), + cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype), + num_warmup_chunks: T.Tensor([batch_size, num_heads], dtype=seqlen_dtype), + fallback_mask: T.Tensor([batch_size, num_heads], dtype=mask_dtype), + ): + with T.Kernel(batch_size, threads=num_threads) as (bb,): + if ht_mask[bb]: + for i_h in T.Parallel(num_heads): + num_warmup_chunks[bb, i_h] = 0 + else: + seq_start_idx = T.alloc_var("int32") + seq_end_idx = T.alloc_var("int32") + num_iters = T.alloc_var("int32") + seq_start_idx = cu_seqlens[bb] + seq_end_idx = cu_seqlens[bb + 1] + num_iters = (seq_end_idx - seq_start_idx) // chunk_size + + g_fragment = T.alloc_fragment((num_heads), dtype=accum_dtype) + g_cumsum = T.alloc_fragment((num_heads), dtype=accum_dtype) + n_fragment = T.alloc_fragment((num_heads), dtype=seqlen_dtype) + f_fragment = T.alloc_fragment((num_heads), dtype=mask_dtype) + T.clear(g_cumsum) + T.fill(n_fragment, num_iters) + T.fill(f_fragment, True) + + for i_s in T.serial(num_iters): + for i_h in T.Parallel(num_heads): + g_fragment[i_h] = g[0, seq_end_idx - i_s * chunk_size - 1, i_h] + for i_h in T.Parallel(num_heads): + g_cumsum[i_h] += g_fragment[i_h] + for i_h in T.Parallel(num_heads): + if g_cumsum[i_h] < threshold and n_fragment[i_h] == num_iters: + n_fragment[i_h] = i_s + 1 + f_fragment[i_h] = False + + for i_h in T.Parallel(num_heads): + num_warmup_chunks[bb, i_h] = n_fragment[i_h] + for i_h in T.Parallel(num_heads): + fallback_mask[bb, i_h] = f_fragment[i_h] + + return tilelang_get_warmup_chunks_kernel + + +def get_warmup_chunks( + g: torch.Tensor, # [1, num_total_tokens, num_v_heads] + cu_seqlens: torch.Tensor, # [cp_real_batch_size + 1] + ht_mask: torch.Tensor, # [cp_real_batch_size] + chunk_size: int = 64, + threshold: float = -10.0, +): + batch_size, num_tokens, num_heads = g.shape + real_batch_size = ht_mask.shape[0] + assert cu_seqlens.shape[0] == real_batch_size + 1 + assert batch_size == 1 + assert chunk_size == 64 + + tilelang_get_warmup_chunks_kernel = tilelang_get_warmup_chunks( + num_heads=num_heads, + chunk_size=chunk_size, + threshold=threshold, + accum_dtype="float32", + g_dtype=g.dtype, + mask_dtype=ht_mask.dtype, + seqlen_dtype=cu_seqlens.dtype, + ) + num_warmup_chunks = torch.empty( + [real_batch_size, num_heads], dtype=cu_seqlens.dtype, device=cu_seqlens.device + ) + fallback_mask = torch.empty( + [real_batch_size, num_heads], dtype=ht_mask.dtype, device=cu_seqlens.device + ) + tilelang_get_warmup_chunks_kernel( + g, ht_mask, cu_seqlens, num_warmup_chunks, fallback_mask + ) + + return num_warmup_chunks, fallback_mask + + +@tilelang.jit() +def tilelang_correct_h0( + H, + DK, + DV, + res_dtype, + accum_dtype, + buffer_dtype, + seqlen_dtype, + mask_dtype, + use_raw_h0, + block_DV: int = 32, +): + cp_batch_size = T.dynamic("cp_batch_size") + raw_batch_size = T.dynamic("raw_batch_size") + + @T.macro + def kernel_body( + bb, + bh, + bv, + seq_start_idx, + seq_end_idx, + num_iters, + ht_buffer, + mt_buffer, + fallback_mask, + seq_map_r2c, + cp_h0, + h_fragment, + ): + h_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype) + hd_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype) + m_shared = T.alloc_shared((DK, DK), dtype=buffer_dtype) + + T.copy( + h_fragment, + cp_h0[seq_start_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], + ) + + for i_s in T.Pipelined(num_iters - 1, num_stages=2): + if fallback_mask[seq_start_idx + i_s, bh]: + T.copy(h_fragment, hd_shared) + T.copy( + ht_buffer[ + seq_start_idx + i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV + ], + h_shared, + ) + T.copy(h_shared, h_fragment) + if fallback_mask[seq_start_idx + i_s, bh]: + T.copy(mt_buffer[seq_start_idx + i_s, bh, 0:DK, 0:DK], m_shared) + T.gemm(m_shared, hd_shared, h_fragment, clear_accum=False) + T.copy( + h_fragment, + cp_h0[ + seq_start_idx + i_s + 1, + bh, + 0:DK, + bv * block_DV : (bv + 1) * block_DV, + ], + ) + + if use_raw_h0: + + @T.prim_func + def tilelang_correct_h0_kernel( + raw_h0: T.Tensor([raw_batch_size, H, DK, DV], dtype=res_dtype), + ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype), + mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype), + fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype), + seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype), + cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype), + ): + with T.Kernel( + T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128 + ) as (bbhv,): + bbh, bv = ( + bbhv // T.ceildiv(DV, block_DV), + bbhv % T.ceildiv(DV, block_DV), + ) + bb, bh = bbh // H, bbh % H + + seq_start_idx = seq_map_r2c[bb] + seq_end_idx = seq_map_r2c[bb + 1] + num_iters = seq_end_idx - seq_start_idx + + h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + T.copy( + raw_h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], + h_fragment, + ) + + kernel_body( + bb, + bh, + bv, + seq_start_idx, + seq_end_idx, + num_iters, + ht_buffer, + mt_buffer, + fallback_mask, + seq_map_r2c, + cp_h0, + h_fragment, + ) + + else: + + @T.prim_func + def tilelang_correct_h0_kernel( + ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype), + mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype), + fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype), + seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype), + cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype), + ): + with T.Kernel( + T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128 + ) as (bbhv,): + bbh, bv = ( + bbhv // T.ceildiv(DV, block_DV), + bbhv % T.ceildiv(DV, block_DV), + ) + bb, bh = bbh // H, bbh % H + + seq_start_idx = seq_map_r2c[bb] + seq_end_idx = seq_map_r2c[bb + 1] + num_iters = seq_end_idx - seq_start_idx + + h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + T.clear(h_fragment) + + kernel_body( + bb, + bh, + bv, + seq_start_idx, + seq_end_idx, + num_iters, + ht_buffer, + mt_buffer, + fallback_mask, + seq_map_r2c, + cp_h0, + h_fragment, + ) + + return tilelang_correct_h0_kernel + + +def correct_initial_states( + raw_h0: torch.Tensor + | None, # [raw_batch_size, num_v_heads, k_head_dim, v_head_dim] + ht_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim] + mt_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, k_head_dim] + fallback_mask: torch.Tensor, # [cp_batch_size, num_v_heads] + seq_map_r2c: torch.Tensor, # [raw_batch_size + 1] +): + cp_batch_size = fallback_mask.shape[0] + _, num_heads, k_head_dim, v_head_dim = ht_buffer.shape + assert k_head_dim == v_head_dim == 128 + + if raw_h0 is None: + res_dtype = torch.float32 + use_raw_h0 = False + else: + res_dtype = raw_h0.dtype + use_raw_h0 = True + + tilelang_correct_h0_kernel = tilelang_correct_h0( + H=num_heads, + DK=k_head_dim, + DV=v_head_dim, + res_dtype=res_dtype, + accum_dtype="float32", + buffer_dtype=ht_buffer.dtype, + seqlen_dtype=seq_map_r2c.dtype, + mask_dtype=fallback_mask.dtype, + use_raw_h0=use_raw_h0, + ) + cp_h0 = torch.empty( + (cp_batch_size, num_heads, k_head_dim, v_head_dim), + dtype=res_dtype, + device=ht_buffer.device, + ) + if use_raw_h0: + tilelang_correct_h0_kernel( + raw_h0, + ht_buffer, + mt_buffer, + fallback_mask, + seq_map_r2c, + cp_h0, + ) + else: + tilelang_correct_h0_kernel( + ht_buffer, + mt_buffer, + fallback_mask, + seq_map_r2c, + cp_h0, + ) + + return cp_h0 diff --git a/flash_qla/ops/gated_delta_rule/chunk/hopper/fused_bwd.py b/flash_qla/ops/gated_delta_rule/chunk/hopper/fused_bwd.py new file mode 100644 index 0000000..e0072dd --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/hopper/fused_bwd.py @@ -0,0 +1,985 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang +import tilelang.language as T + +from flash_qla.utils import prepare_chunk_offsets + + +@tilelang.jit( + # out_idx=[-5, -4, -3, -2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_DATA_RACE_CHECK: True, + }, +) +def tilelang_fused_chunk_gdr_bwd( + H, + Hg, + DK, + DV, + chunk_size, + scale, + accum_dtype, + qkva_dtype, + g_dtype, + b_dtype, + h_dtype, + o_dtype, + seqlen_dtype, + is_varlen, + use_dht, +): + batch_size = T.dynamic("batch_size") + num_tokens = T.dynamic("num_tokens") + num_chunks = T.dynamic("num_chunks") + block_S = chunk_size + + if is_varlen: + q_shape = (1, num_tokens, Hg, DK) + k_shape = (1, num_tokens, Hg, DK) + v_shape = (1, num_tokens, H, DV) + o_shape = (1, num_tokens, H, DV) + a_shape = (1, num_tokens, H, chunk_size) + g_shape = (1, num_tokens, H) + b_shape = (1, num_tokens, H) + h_shape = (1, num_chunks, H, DK, DV) + else: + q_shape = (batch_size, num_tokens, Hg, DK) + k_shape = (batch_size, num_tokens, Hg, DK) + v_shape = (batch_size, num_tokens, H, DV) + o_shape = (batch_size, num_tokens, H, DV) + a_shape = (batch_size, num_tokens, H, chunk_size) + g_shape = (batch_size, num_tokens, H) + b_shape = (batch_size, num_tokens, H) + h_shape = (batch_size, num_chunks, H, DK, DV) + h0_shape = (batch_size, H, DK, DV) + ht_shape = (batch_size, H, DK, DV) + + @T.prim_func + def tilelang_fused_chunk_gdr_bwd_kernel( + do: T.Tensor(o_shape, dtype=o_dtype), + dht: T.Tensor(ht_shape, dtype=accum_dtype), + q: T.Tensor(q_shape, dtype=qkva_dtype), + k: T.Tensor(k_shape, dtype=qkva_dtype), + v: T.Tensor(v_shape, dtype=qkva_dtype), + a: T.Tensor(a_shape, dtype=qkva_dtype), + g: T.Tensor(g_shape, dtype=g_dtype), + b: T.Tensor(b_shape, dtype=b_dtype), + h: T.Tensor(h_shape, dtype=h_dtype), + cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype), + chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype), + dq: T.Tensor(v_shape, dtype=qkva_dtype), + dk: T.Tensor(v_shape, dtype=qkva_dtype), + dv: T.Tensor(v_shape, dtype=qkva_dtype), + dg: T.Tensor(g_shape, dtype=g_dtype), + db: T.Tensor(b_shape, dtype=b_dtype), + dh0: T.Tensor(h0_shape, dtype=accum_dtype), + ): + with T.Kernel(batch_size * H, threads=512) as (bbh,): + bb, bh = bbh // H, bbh % H + bhg = bh // (H // Hg) + + batch_idx = T.alloc_var("int32") + seq_start_idx = T.alloc_var("int32") + seq_end_idx = T.alloc_var("int32") + chunk_start_idx = T.alloc_var("int32") + batch_idx = 0 if is_varlen else bb + seq_start_idx = cu_seqlens[bb] if is_varlen else 0 + seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens + chunk_start_idx = chunk_offsets[bb] if is_varlen else 0 + + num_iters = T.alloc_var("int32") + num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S) + + # 2+2+2+2 + 1 + 4 = 13 units + do_shared = T.alloc_shared((block_S, DV), dtype=o_dtype) + q_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + v_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype) + a_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype) + h_shared = T.alloc_shared((DK, DV), dtype=h_dtype) + g_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared") + g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared") + g_rev_exp_shared = T.alloc_shared( + (block_S), dtype=accum_dtype, scope="shared" + ) + b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared") + + # 2 units + dqkv_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + dg_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared") + db_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared") + + # 1+1 + 2+2+2 + 4 = 12 units + tmp_shared_1_1 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype) + tmp_shared_1_2 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype) + tmp_shared_1_3 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype) + tmp_shared_2_1 = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + tmp_shared_2_2 = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + tmp_shared_2_3 = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + tmp_shared_4_1 = T.alloc_shared((DK, DV), dtype=qkva_dtype) + + # CONSUMER_K + dk_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + dv_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + odot_fragment_1 = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + dg_fragment_1 = T.alloc_fragment((block_S), dtype=accum_dtype) + dg_last_local_1 = T.alloc_fragment((1), dtype=accum_dtype) + + # CONSUMER_A + mask_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + dp_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + da_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + u_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + dq_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + db_fragment = T.alloc_fragment((block_S), dtype=accum_dtype) + odot_fragment_2 = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + dg_fragment_2 = T.alloc_fragment((block_S), dtype=accum_dtype) + + # CONSUMER_S + dh_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype) + _odot_fragment_3 = T.alloc_fragment((DK, DV), dtype=accum_dtype) + reduce_fragment = T.alloc_fragment((128, 2), dtype=accum_dtype) + dg_last_local_3 = T.alloc_fragment((1), dtype=accum_dtype) + g_last_local_3 = T.alloc_local((1), dtype=accum_dtype) + + # 16 stages + bar_00 = T.alloc_barrier(arrive_count=448) + bar_01 = T.alloc_barrier(arrive_count=384) + bar_02 = T.alloc_barrier(arrive_count=288) + bar_03 = T.alloc_barrier(arrive_count=256) + bar_04 = T.alloc_barrier(arrive_count=416) + bar_05 = T.alloc_barrier(arrive_count=288) + bar_06 = T.alloc_barrier(arrive_count=256) + bar_07 = T.alloc_barrier(arrive_count=256) + bar_08 = T.alloc_barrier(arrive_count=384) + bar_09 = T.alloc_barrier(arrive_count=256) + bar_10 = T.alloc_barrier(arrive_count=288) + bar_11 = T.alloc_barrier(arrive_count=256) + bar_12 = T.alloc_barrier(arrive_count=128) + bar_13 = T.alloc_barrier(arrive_count=256) + bar_14 = T.alloc_barrier(arrive_count=256) + bar_15 = T.alloc_barrier(arrive_count=256) + + T.annotate_layout( + { + do_shared: tilelang.layout.make_swizzled_layout(do_shared), + q_shared: tilelang.layout.make_swizzled_layout(q_shared), + k_shared: tilelang.layout.make_swizzled_layout(k_shared), + v_shared: tilelang.layout.make_swizzled_layout(v_shared), + a_shared: tilelang.layout.make_swizzled_layout(a_shared), + h_shared: tilelang.layout.make_swizzled_layout(h_shared), + dqkv_shared: tilelang.layout.make_swizzled_layout(dqkv_shared), + tmp_shared_1_1: tilelang.layout.make_swizzled_layout( + tmp_shared_1_1 + ), + tmp_shared_1_2: tilelang.layout.make_swizzled_layout( + tmp_shared_1_2 + ), + tmp_shared_1_3: tilelang.layout.make_swizzled_layout( + tmp_shared_1_3 + ), + tmp_shared_2_1: tilelang.layout.make_swizzled_layout( + tmp_shared_2_1 + ), + tmp_shared_2_2: tilelang.layout.make_swizzled_layout( + tmp_shared_2_2 + ), + tmp_shared_2_3: tilelang.layout.make_swizzled_layout( + tmp_shared_2_3 + ), + tmp_shared_4_1: tilelang.layout.make_swizzled_layout( + tmp_shared_4_1 + ), + } + ) + + # T.use_swizzle(10) + + tx = T.get_thread_binding() + + PRODUCER_NREG = 24 + CONSUMER_K_NREG = 144 + CONSUMER_A_NREG = 176 + CONSUMER_S_NREG = 160 + + # Prefetch the last chunk of data + T.copy( + h[batch_idx, chunk_start_idx + num_iters - 1, bh, 0:DK, 0:DV], h_shared + ) + for j_s, j_k in T.Parallel(block_S, DK): + if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx: + q_shared[j_s, j_k] = q[ + batch_idx, + seq_start_idx + (num_iters - 1) * block_S + j_s, + bhg, + j_k, + ] + else: + q_shared[j_s, j_k] = 0 + for j_s, j_k in T.Parallel(block_S, DK): + if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx: + k_shared[j_s, j_k] = k[ + batch_idx, + seq_start_idx + (num_iters - 1) * block_S + j_s, + bhg, + j_k, + ] + else: + k_shared[j_s, j_k] = 0 + for j_s, j_v in T.Parallel(block_S, DV): + if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx: + v_shared[j_s, j_v] = v[ + batch_idx, + seq_start_idx + (num_iters - 1) * block_S + j_s, + bh, + j_v, + ] + else: + v_shared[j_s, j_v] = 0 + for j_s, j_t in T.Parallel(block_S, block_S): + if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx: + a_shared[j_s, j_t] = a[ + batch_idx, + seq_start_idx + (num_iters - 1) * block_S + j_s, + bh, + j_t, + ] + else: + a_shared[j_s, j_t] = 0 + for j_s, j_v in T.Parallel(block_S, DV): + if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx: + do_shared[j_s, j_v] = do[ + batch_idx, + seq_start_idx + (num_iters - 1) * block_S + j_s, + bh, + j_v, + ] + else: + do_shared[j_s, j_v] = 0 + for j_s in T.Parallel(block_S): + if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx: + g_shared[j_s] = g[ + batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh + ] + else: + g_shared[j_s] = g[batch_idx, seq_end_idx - 1, bh] + for j_s in T.Parallel(block_S): + if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx: + b_shared[j_s] = b[ + batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh + ] + else: + b_shared[j_s] = 0 + + if tx < 128: + T.set_max_nreg(CONSUMER_S_NREG, 1) + + if use_dht: + T.copy(dht[bb, bh, 0:DK, 0:DV], dh_fragment) + else: + T.clear(dh_fragment) + T.copy(dh_fragment, tmp_shared_4_1) + + for i_s in T.serial(num_iters): + T.barrier_arrive(bar_00) + + # 00 + T.barrier_wait(bar_00, (i_s + 0) % 2) + for j_s in T.Parallel(block_S): + g_exp_shared[j_s] = T.exp2(g_shared[j_s] * 1.442695) + g_rev_exp_shared[j_s] = T.exp2( + (g_shared[block_S - 1] - g_shared[j_s]) * 1.442695 + ) + T.barrier_arrive(bar_01) + + # 01, 02, 03 + T.barrier_wait(bar_01, (i_s + 0) % 2) + g_last_local_3[0] = g_exp_shared[block_S - 1] + # dS0 = g_last * dSt + for j_k, j_v in T.Parallel(DK, DV): + dh_fragment[j_k, j_v] *= g_last_local_3[0] + T.barrier_arrive(bar_04) + + # 04, 05, 06, 07 + T.barrier_wait(bar_04, (i_s + 0) % 2) + # dg_last += sum(dS0 * S0) + T.clear(reduce_fragment) + for j_k, j_v in T.Parallel(DK, DV): + reduce_fragment[ + j_k % 64 // 16 * 32 + j_k % 8 * 4 + j_v % 8 // 2, j_v % 2 + ] += dh_fragment[j_k, j_v] * h_shared[j_k, j_v] + T.barrier_arrive(bar_08) + T.barrier_wait(bar_08, (i_s + 0) % 2) + T.barrier_wait(bar_09, (i_s + 0) % 2) + + # 10 + T.barrier_wait(bar_10, (i_s + 0) % 2) + T.reduce_sum( + T.reshape(reduce_fragment, (128 * 2,)), + dg_last_local_3, + dim=0, + clear=True, + ) + dg_shared[block_S - 1] += dg_last_local_3[0] + T.barrier_arrive(bar_11) + + # 11 + T.barrier_wait(bar_11, (i_s + 0) % 2) + # dS0 += K^T @ dVg + T.gemm_v1( + tmp_shared_2_2, + tmp_shared_2_3, + dh_fragment, + transpose_A=True, + clear_accum=False, + ) + T.barrier_arrive(bar_12) + T.barrier_wait(bar_12, (i_s + 0) % 2) + + # 13 + T.barrier_wait(bar_13, (i_s + 0) % 2) + # dOg = s * g * dO + for j_s, j_v in T.Parallel(block_S, DV): + tmp_shared_2_3[j_s, j_v] = ( + scale * do_shared[j_s, j_v] * g_exp_shared[j_s] + ) + T.barrier_arrive(bar_14) + + # 14 + T.barrier_wait(bar_14, (i_s + 0) % 2) + # dS0 += Q^T @ dOg + T.gemm_v1( + tmp_shared_2_1, + tmp_shared_2_3, + dh_fragment, + transpose_A=True, + clear_accum=False, + ) + T.barrier_arrive(bar_15) + + # 15 + T.barrier_wait(bar_15, (i_s + 0) % 2) + # S4[1] = dS0 + T.copy(dh_fragment, tmp_shared_4_1) + + if use_dht: + T.copy(dh_fragment, dh0[bb, bh, 0:DK, 0:DV]) + + elif tx < 256: + T.set_max_nreg(CONSUMER_K_NREG, 1) + + for i_s in T.serial(num_iters): + T.barrier_arrive(bar_00) + + # 16 == 00 + T.barrier_wait(bar_00, (i_s + 0) % 2) + # S2[S] dK + if i_s > 0: + T.copy(dk_fragment, dqkv_shared) + T.barrier_arrive(bar_01) + + # 01 + T.barrier_wait(bar_01, (i_s + 0) % 2) + # dV' = K @ dSt + T.gemm_v1(k_shared, tmp_shared_4_1, dv_fragment, clear_accum=True) + # dV' = g_last/g * dV' + for j_s, j_v in T.Parallel(block_S, DV): + dv_fragment[j_s, j_v] *= g_rev_exp_shared[j_s] + T.barrier_arrive(bar_02) + + # 02 + T.barrier_wait(bar_02, (i_s + 0) % 2) + # dV' += Pg^T @ dO + T.gemm_v1( + tmp_shared_1_1, + do_shared, + dv_fragment, + transpose_A=True, + clear_accum=False, + ) + T.barrier_arrive(bar_03) + + # 03 + T.barrier_wait(bar_03, (i_s + 0) % 2) + # S2[1] dV' + T.copy(dv_fragment, tmp_shared_2_1) + T.barrier_arrive(bar_04) + + # 04 + T.barrier_wait(bar_04, (i_s + 0) % 2) + # dV = Ag^T @ dV' + T.gemm_v1( + tmp_shared_1_2, + tmp_shared_2_1, + dv_fragment, + transpose_A=True, + clear_accum=True, + ) + # S2[S] dV + T.copy(dv_fragment, dqkv_shared) + T.barrier_arrive(bar_05) + + # 05 + T.barrier_wait(bar_05, (i_s + 0) % 2) + # dVg = -g * dV + for j_s, j_v in T.Parallel(block_S, DV): + dv_fragment[j_s, j_v] = ( + -dv_fragment[j_s, j_v] * g_exp_shared[j_s] + ) + # dg += sum(dVg * U) + T.copy(tmp_shared_2_3, odot_fragment_1) + for j_s, j_v in T.Parallel(block_S, DV): + odot_fragment_1[j_s, j_v] *= dv_fragment[j_s, j_v] + T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True) + T.copy(dg_fragment_1, dg_shared) + # S2[3] dVg + T.copy(dv_fragment, tmp_shared_2_3) + T.barrier_arrive(bar_06) + + # 06 + T.barrier_wait(bar_06, (i_s + 0) % 2) + # S2[2] K + T.copy(k_shared, odot_fragment_1) + T.copy(odot_fragment_1, tmp_shared_2_2) + T.barrier_arrive(bar_07) + + # 07 + T.barrier_wait(bar_07, (i_s + 0) % 2) + # dK = V' @ dSt^T + T.gemm_v1( + tmp_shared_2_1, + tmp_shared_4_1, + dk_fragment, + transpose_B=True, + clear_accum=True, + ) + T.barrier_arrive(bar_08) + + # 08 + T.barrier_wait(bar_08, (i_s + 0) % 2) + # dK = g_last/g * dK + for j_s, j_k in T.Parallel(block_S, DK): + dk_fragment[j_s, j_k] *= g_rev_exp_shared[j_s] + # dg -= sum(K * dK) + for j_s, j_k in T.Parallel(block_S, DK): + odot_fragment_1[j_s, j_k] *= -dk_fragment[j_s, j_k] + T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True) + for j_s in T.Parallel(block_S): + dg_shared[j_s] += dg_fragment_1[j_s] + # dg_last += sum(K * dK) + T.reduce_sum(dg_fragment_1, dg_last_local_1, dim=0, clear=True) + # Sg[S] dg + dg_shared[block_S - 1] -= dg_last_local_1[0] + T.barrier_arrive(bar_09) + + # 09 + T.barrier_wait(bar_09, (i_s + 0) % 2) + # dK += dVg @ S0^T + T.gemm_v1( + tmp_shared_2_3, + h_shared, + dk_fragment, + transpose_B=True, + clear_accum=False, + ) + T.barrier_arrive(bar_10) + T.barrier_wait(bar_10, (i_s + 0) % 2) + + # 12 + T.barrier_wait(bar_12, (i_s + 0) % 2) + # dK += dP^T @ Q + T.gemm_v1( + tmp_shared_1_1, + tmp_shared_2_1, + dk_fragment, + transpose_A=True, + clear_accum=False, + ) + T.barrier_arrive(bar_13) + T.barrier_wait(bar_13, (i_s + 0) % 2) + + # 15 + T.barrier_wait(bar_15, (i_s + 0) % 2) + # dK += dAs @ K + T.gemm_v1( + tmp_shared_1_2, tmp_shared_2_2, dk_fragment, clear_accum=False + ) + + for j_s, j_k in T.Parallel(block_S, DK): + if seq_start_idx + j_s < seq_end_idx: + dk[batch_idx, seq_start_idx + j_s, bh, j_k] = dk_fragment[ + j_s, j_k + ] + + elif tx < 384: + T.set_max_nreg(CONSUMER_A_NREG, 1) + + for i_s in T.serial(num_iters): + T.barrier_arrive(bar_00) + + # 00 + T.barrier_wait(bar_00, (i_s + 0) % 2) + # P = Q @ K^T + T.gemm_v1( + q_shared, + k_shared, + p_fragment, + transpose_B=True, + clear_accum=True, + ) + T.barrier_arrive(bar_01) + + # 01 + T.barrier_wait(bar_01, (i_s + 0) % 2) + # G = Lower(diag(g) @ I @ diag(1/g)) + for j_s, j_t in T.Parallel(block_S, block_S): + mask_fragment[j_s, j_t] = g_shared[j_s] - g_shared[j_t] + for j_s, j_t in T.Parallel(block_S, block_S): + if j_s >= j_t: + mask_fragment[j_s, j_t] = T.exp2( + mask_fragment[j_s, j_t] * 1.442695 + ) + else: + mask_fragment[j_s, j_t] = 0 + # Pg = s * P * G + for j_s, j_t in T.Parallel(block_S, block_S): + p_fragment[j_s, j_t] *= mask_fragment[j_s, j_t] + for j_s, j_t in T.Parallel(block_S, block_S): + p_fragment[j_s, j_t] *= scale + # S1[1] Pg + T.copy(p_fragment, tmp_shared_1_1) + T.barrier_arrive(bar_02) + + # 02 + T.barrier_wait(bar_02, (i_s + 0) % 2) + # Ab = Ar * b + T.copy(a_shared, a_fragment) + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] *= b_shared[j_t] + # Ag = G * Ab + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] *= mask_fragment[j_s, j_t] + # S1[2] Ag + T.copy(a_fragment, tmp_shared_1_2) + T.barrier_arrive(bar_03) + + # 03 + T.barrier_wait(bar_03, (i_s + 0) % 2) + # U = K @ S0 + T.gemm_v1(k_shared, h_shared, u_fragment, clear_accum=True) + T.barrier_arrive(bar_04) + + # 04 + T.barrier_wait(bar_04, (i_s + 0) % 2) + # S2[3] U + T.copy(u_fragment, tmp_shared_2_3) + # W = V - g * U + for j_s, j_v in T.Parallel(block_S, DV): + u_fragment[j_s, j_v] *= -g_exp_shared[j_s] + for j_s, j_v in T.Parallel(block_S, DV): + u_fragment[j_s, j_v] += v_shared[j_s, j_v] + # S2[2] W + T.copy(u_fragment, tmp_shared_2_2) + T.barrier_arrive(bar_05) + + # 05 + T.barrier_wait(bar_05, (i_s + 0) % 2) + # dAg = dV' @ W^T + T.gemm_v1( + tmp_shared_2_1, + tmp_shared_2_2, + da_fragment, + transpose_B=True, + clear_accum=True, + ) + # V' = Ag @ W + T.gemm_v1( + tmp_shared_1_2, tmp_shared_2_2, u_fragment, clear_accum=True + ) + # S2[1] V' + T.copy(u_fragment, tmp_shared_2_1) + T.barrier_arrive(bar_06) + + # 06 + T.barrier_wait(bar_06, (i_s + 0) % 2) + # dPg = dO @ V'^T + T.gemm_v1( + do_shared, + tmp_shared_2_1, + dp_fragment, + transpose_B=True, + clear_accum=True, + ) + T.barrier_arrive(bar_07) + + # 07 + T.barrier_wait(bar_07, (i_s + 0) % 2) + # dAb = G * dAg + for j_s, j_t in T.Parallel(block_S, block_S): + da_fragment[j_s, j_t] *= mask_fragment[j_s, j_t] + # dg += sum((dPg * P) - (dPg * P)^T) + T.copy(tmp_shared_1_1, p_fragment) + for j_s, j_t in T.Parallel(block_S, block_S): + p_fragment[j_s, j_t] *= dp_fragment[j_s, j_t] + T.copy(p_fragment, tmp_shared_1_1) + for j_s, j_t in T.Parallel(block_S, block_S): + p_fragment[j_s, j_t] -= tmp_shared_1_1[j_t, j_s] + T.reduce_sum(p_fragment, dg_fragment_2, dim=1, clear=True) + # dP = s * G * dPg + for j_s, j_t in T.Parallel(block_S, block_S): + dp_fragment[j_s, j_t] *= mask_fragment[j_s, j_t] + for j_s, j_t in T.Parallel(block_S, block_S): + dp_fragment[j_s, j_t] *= scale + # S1[1] dP + T.copy(dp_fragment, tmp_shared_1_1) + T.barrier_arrive(bar_08) + + # 08 + T.barrier_wait(bar_08, (i_s + 0) % 2) + # dQ = dO @ S0^T + T.gemm_v1( + do_shared, + h_shared, + dq_fragment, + transpose_B=True, + clear_accum=True, + ) + T.barrier_arrive(bar_09) + + # 09 + T.barrier_wait(bar_09, (i_s + 0) % 2) + # dQ = s * g * dQ + for j_s, j_k in T.Parallel(block_S, DK): + dq_fragment[j_s, j_k] *= g_exp_shared[j_s] + for j_s, j_k in T.Parallel(block_S, DK): + dq_fragment[j_s, j_k] *= scale + # S2[1] Q + T.copy(q_shared, odot_fragment_2) + # dg += sum(Q * dQ) + T.copy(odot_fragment_2, tmp_shared_2_1) + for j_s, j_k in T.Parallel(block_S, DK): + odot_fragment_2[j_s, j_k] *= dq_fragment[j_s, j_k] + T.reduce_sum(odot_fragment_2, dg_fragment_2, dim=1, clear=False) + T.barrier_arrive(bar_10) + + # 10 + T.barrier_wait(bar_10, (i_s + 0) % 2) + # dQ += dP @ K + T.gemm_v1( + tmp_shared_1_1, tmp_shared_2_2, dq_fragment, clear_accum=False + ) + # S2[S] dQ + T.copy(dq_fragment, dqkv_shared) + T.barrier_arrive(bar_11) + + # 11, 12 + T.barrier_wait(bar_11, (i_s + 0) % 2) + # dAb * Ar + T.copy(a_shared, a_fragment) + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] *= da_fragment[j_s, j_t] + T.copy(a_fragment, tmp_shared_1_3) + # dAb * Ab [ = G * dAg * Ab ] + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] *= b_shared[j_t] + # dg += sum((dAb * Ab) - (dAb * Ab)^T) + T.copy(a_fragment, tmp_shared_1_2) + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] -= tmp_shared_1_2[j_t, j_s] + T.reduce_sum(a_fragment, dg_fragment_2, dim=1, clear=False) + # Sg[S] dg + for j_s in T.Parallel(block_S): + dg_shared[j_s] += dg_fragment_2[j_s] + # db = sum((dAb * Ar)^T) + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] = tmp_shared_1_3[j_t, j_s] + T.reduce_sum(a_fragment, db_fragment, dim=1, clear=True) + # dAr = dAb * b + for j_s, j_t in T.Parallel(block_S, block_S): + da_fragment[j_s, j_t] *= b_shared[j_t] + # S1[2] dAr + T.copy(da_fragment, tmp_shared_1_2) + T.barrier_arrive(bar_13) + + # 13 + T.barrier_wait(bar_13, (i_s + 0) % 2) + # dA = -Ar^T @ dAr @ Ar^T + T.gemm_v1( + a_shared, + tmp_shared_1_2, + da_fragment, + transpose_A=True, + clear_accum=True, + ) + T.copy(da_fragment, tmp_shared_1_2) + T.gemm_v1( + tmp_shared_1_2, + a_shared, + da_fragment, + transpose_B=True, + clear_accum=True, + ) + # At = K @ K^T + T.gemm_v1( + tmp_shared_2_2, + tmp_shared_2_2, + a_fragment, + transpose_B=True, + clear_accum=True, + ) + T.barrier_arrive(bar_14) + + # 14 + T.barrier_wait(bar_14, (i_s + 0) % 2) + for j_s, j_t in T.Parallel(block_S, block_S): + if j_s <= j_t: + da_fragment[j_s, j_t] = 0 + else: + da_fragment[j_s, j_t] = -da_fragment[j_s, j_t] + # db += sum(dA * At) + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] *= da_fragment[j_s, j_t] + T.reduce_sum(a_fragment, db_fragment, dim=1, clear=False) + T.copy(db_fragment, db_shared) + # dAt = b * dA + for j_s, j_t in T.Parallel(block_S, block_S): + da_fragment[j_s, j_t] *= b_shared[j_s] + # dAs = dAt + dAt^T + T.copy(da_fragment, tmp_shared_1_2) + for j_s, j_t in T.Parallel(block_S, block_S): + da_fragment[j_s, j_t] += tmp_shared_1_2[j_t, j_s] + # S1[1] dAs + T.copy(da_fragment, tmp_shared_1_2) + T.barrier_arrive(bar_15) + T.barrier_wait(bar_15, (i_s + 0) % 2) + + else: + T.set_max_nreg(PRODUCER_NREG, 0) + + if tx < 384 + 32: + for i_s in T.serial(num_iters - 1): + chunk_idx = num_iters - i_s - 2 + left = seq_start_idx + chunk_idx * block_S + right = left + block_S + + T.barrier_arrive(bar_00) + T.barrier_wait(bar_00, (i_s + 0) % 2) + + T.barrier_wait(bar_03, (i_s + 0) % 2) + for j_s in T.Parallel(block_S): + g_shared[j_s] = g[batch_idx, left + j_s, bh] + + T.barrier_wait(bar_05, (i_s + 0) % 2) + T.copy(v[batch_idx, left:right, bh, 0:DV], v_shared) + + T.barrier_wait(bar_07, (i_s + 0) % 2) + T.copy(k[batch_idx, left:right, bhg, 0:DK], k_shared) + + T.barrier_wait(bar_10, (i_s + 0) % 2) + T.copy(q[batch_idx, left:right, bhg, 0:DK], q_shared) + + if num_iters > 0: + T.barrier_arrive(bar_00) + + elif tx < 384 + 64: + for i_s in T.serial(num_iters): + left = seq_start_idx + (num_iters - i_s - 1) * block_S + right = left + block_S + + T.barrier_arrive(bar_00) + T.barrier_wait(bar_00, (i_s + 0) % 2) + + T.barrier_wait(bar_01, (i_s + 0) % 2) + if i_s == 1: + for j_s, j_k in T.Parallel(block_S, DK): + if left + block_S + j_s < seq_end_idx: + dk[batch_idx, left + block_S + j_s, bh, j_k] = ( + dqkv_shared[j_s, j_k] + ) + elif i_s > 1: + T.copy( + dqkv_shared, + dk[ + batch_idx, + left + block_S : right + block_S, + bh, + 0:DK, + ], + ) + T.barrier_arrive(bar_04) + T.barrier_wait(bar_04, (i_s + 0) % 2) + + T.barrier_wait(bar_05, (i_s + 0) % 2) + if i_s == 0: + for j_s, j_v in T.Parallel(block_S, DV): + if left + j_s < seq_end_idx: + dv[batch_idx, left + j_s, bh, j_v] = dqkv_shared[ + j_s, j_v + ] + else: + T.copy(dqkv_shared, dv[batch_idx, left:right, bh, 0:DV]) + T.barrier_arrive(bar_10) + T.barrier_wait(bar_10, (i_s + 0) % 2) + + T.barrier_wait(bar_11, (i_s + 0) % 2) + if i_s == 0: + for j_s, j_k in T.Parallel(block_S, DK): + if left + j_s < seq_end_idx: + dq[batch_idx, left + j_s, bh, j_k] = dqkv_shared[ + j_s, j_k + ] + else: + T.copy(dqkv_shared, dq[batch_idx, left:right, bh, 0:DK]) + + elif tx < 384 + 96: + for i_s in T.serial(num_iters - 1): + chunk_idx = num_iters - i_s - 2 + left = seq_start_idx + chunk_idx * block_S + right = left + block_S + + T.barrier_arrive(bar_02) + T.barrier_wait(bar_02, (i_s + 0) % 2) + + T.barrier_wait(bar_10, (i_s + 0) % 2) + T.copy( + h[batch_idx, chunk_start_idx + chunk_idx, bh, 0:DK, 0:DV], + h_shared, + ) + + T.barrier_wait(bar_14, (i_s + 0) % 2) + T.copy(a[batch_idx, left:right, bh, 0:block_S], a_shared) + + T.copy(do[batch_idx, left:right, bh, 0:DV], do_shared) + + T.barrier_wait(bar_15, (i_s + 0) % 2) + for j_s in T.Parallel(block_S): + b_shared[j_s] = b[batch_idx, left + j_s, bh] + + if num_iters > 0: + T.barrier_wait(bar_00, (num_iters - 1) % 2) + T.barrier_arrive(bar_02) + + else: + for i_s in T.serial(num_iters): + left = seq_start_idx + (num_iters - i_s - 1) * block_S + + T.barrier_arrive(bar_05) + T.barrier_wait(bar_05, (i_s + 0) % 2) + + T.barrier_wait(bar_15, (i_s + 0) % 2) + + if i_s == 0: + for j_s in T.Parallel(block_S): + if left + j_s < seq_end_idx: + dg[batch_idx, left + j_s, bh] = dg_shared[j_s] + if (seq_end_idx - seq_start_idx) % block_S > 0: + dg[batch_idx, seq_end_idx - 1, bh] += dg_shared[ + block_S - 1 + ] + else: + for j_s in T.Parallel(block_S): + dg[batch_idx, left + j_s, bh] = dg_shared[j_s] + + if i_s == 0: + for j_s in T.Parallel(block_S): + if left + j_s < seq_end_idx: + db[batch_idx, left + j_s, bh] = db_shared[j_s] + else: + for j_s in T.Parallel(block_S): + db[batch_idx, left + j_s, bh] = db_shared[j_s] + + return tilelang_fused_chunk_gdr_bwd_kernel + + +def fused_gdr_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + g: torch.Tensor, + b: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + h: torch.Tensor, + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, +): + batch_size, num_tokens, Hg, K = k.shape + _, _, H, V = v.shape + scale = scale or K ** (-0.5) + assert K == V == 128 + assert chunk_size == 64 + + if cu_seqlens is None: + real_batch_size = batch_size + cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device) + chunk_offsets = torch.empty( + (batch_size + 1), dtype=torch.int32, device=k.device + ) + is_varlen = False + else: + real_batch_size = len(cu_seqlens) - 1 + chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size) + chunk_offsets = chunk_offsets.to(cu_seqlens.dtype) + is_varlen = True + + use_dht = dht is not None + if dht is None: + dht = torch.empty( + (real_batch_size, H, K, V), dtype=torch.float32, device=k.device + ) + dq = torch.empty_like(v) + dk = torch.empty_like(v) + dv = torch.empty_like(v) + dg = torch.empty_like(g) + db = torch.empty_like(b) + dh0 = torch.empty_like(dht) + + tilelang_fused_chunk_gdr_bwd_kernel = tilelang_fused_chunk_gdr_bwd( + H, + Hg, + K, + V, + chunk_size, + scale, + qkva_dtype=q.dtype, + g_dtype=g.dtype, + b_dtype=b.dtype, + h_dtype=h.dtype, + o_dtype=do.dtype, + seqlen_dtype=cu_seqlens.dtype, + accum_dtype="float32", + is_varlen=is_varlen, + use_dht=use_dht, + ) + tilelang_fused_chunk_gdr_bwd_kernel( + do, + dht, + q, + k, + v, + a, + g, + b, + h, + cu_seqlens, + chunk_offsets, + dq, + dk, + dv, + dg, + db, + dh0, + ) + + if not use_dht: + dh0 = None + + return dq, dk, dv, dg, db, dh0 diff --git a/flash_qla/ops/gated_delta_rule/chunk/hopper/fused_fwd.py b/flash_qla/ops/gated_delta_rule/chunk/hopper/fused_fwd.py new file mode 100644 index 0000000..2c90e3c --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/hopper/fused_fwd.py @@ -0,0 +1,658 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang +import tilelang.language as T + +from flash_qla.utils import prepare_chunk_offsets + + +MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count +TARGET_NUM_CTAS = int(MULTI_PROCESSOR_COUNT * 0.7) + + +@tilelang.jit( + # out_idx=[-3, -2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + # tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, + }, +) +def tilelang_fused_chunk_gdr_fwd( + H, + Hg, + DK, + DV, + chunk_size, + scale, + accum_dtype, + qkva_dtype, + g_dtype, + b_dtype, + h0_dtype, + ht_dtype, + h_dtype, + o_dtype, + seqlen_dtype, + use_initial_state, + store_final_state, + store_h, + store_o, + is_varlen, + is_cp, + block_DV=128, +): + batch_size = T.dynamic("batch_size") + num_tokens = T.dynamic("num_tokens") + num_chunks = T.dynamic("num_chunks") + raw_batch_size = T.dynamic("raw_batch_size") + block_S = chunk_size + + if is_varlen: + q_shape = (1, num_tokens, Hg, DK) + k_shape = (1, num_tokens, Hg, DK) + v_shape = (1, num_tokens, H, DV) + o_shape = (1, num_tokens, H, DV) + a_shape = (1, num_tokens, H, chunk_size) + g_shape = (1, num_tokens, H) + b_shape = (1, num_tokens, H) + h_shape = (1, num_chunks, H, DK, DV) + else: + q_shape = (batch_size, num_tokens, Hg, DK) + k_shape = (batch_size, num_tokens, Hg, DK) + v_shape = (batch_size, num_tokens, H, DV) + o_shape = (batch_size, num_tokens, H, DV) + a_shape = (batch_size, num_tokens, H, chunk_size) + g_shape = (batch_size, num_tokens, H) + b_shape = (batch_size, num_tokens, H) + h_shape = (batch_size, num_chunks, H, DK, DV) + h0_shape = (batch_size, H, DK, DV) + ht_shape = (raw_batch_size, H, DK, DV) + + @T.prim_func + def tilelang_fused_chunk_gdr_fwd_kernel( + q: T.Tensor(q_shape, dtype=qkva_dtype), + k: T.Tensor(k_shape, dtype=qkva_dtype), + v: T.Tensor(v_shape, dtype=qkva_dtype), + a: T.Tensor(a_shape, dtype=qkva_dtype), + g: T.Tensor(g_shape, dtype=g_dtype), + b: T.Tensor(b_shape, dtype=b_dtype), + h0: T.Tensor(h0_shape, dtype=h0_dtype), + cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype), + chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype), + cp_seq_map: T.Tensor([batch_size], dtype=seqlen_dtype), + raw_cu_seqlens: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype), + o: T.Tensor(o_shape, dtype=o_dtype), + h: T.Tensor(h_shape, dtype=h_dtype), + ht: T.Tensor(ht_shape, dtype=ht_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV) * batch_size * H, threads=512) as (bbhv,): + bbh, bv = bbhv // T.ceildiv(DV, block_DV), bbhv % T.ceildiv(DV, block_DV) + bb, bh = bbh // H, bbh % H + bhg = bh // (H // Hg) + + batch_idx = T.alloc_var("int32") + seq_start_idx = T.alloc_var("int32") + seq_end_idx = T.alloc_var("int32") + seq_split_idx = T.alloc_var("int32") + chunk_start_idx = T.alloc_var("int32") + chunk_split_idx = T.alloc_var("int32") + + batch_idx = 0 if is_varlen else bb + seq_start_idx = cu_seqlens[bb] if is_varlen else 0 + seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens + chunk_start_idx = chunk_offsets[bb] if is_varlen else 0 + + raw_batch_idx = T.alloc_var("int32") + raw_seq_end_idx = T.alloc_var("int32") + need_store_final_state = T.alloc_var("bool") + raw_batch_idx = cp_seq_map[bb] if is_cp else bb + raw_seq_end_idx = ( + raw_cu_seqlens[raw_batch_idx + 1] if is_cp else seq_end_idx + ) + need_store_final_state = store_final_state & ( + raw_seq_end_idx == seq_end_idx + ) + + num_iters = T.alloc_var("int32") + num_unmasked_iters = T.alloc_var("int32") + num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S) + num_unmasked_iters = (seq_end_idx - seq_start_idx) // block_S + + q_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype) + k_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype) + v_shared = T.alloc_shared((2, block_S, block_DV), dtype=qkva_dtype) + a_shared = T.alloc_shared((2, block_S, block_S), dtype=qkva_dtype) + g_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared") + b_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared") + + o_shared = T.alloc_shared((block_S, block_DV), dtype=o_dtype) + h_shared = T.alloc_shared((DK, block_DV), dtype=qkva_dtype) + vd_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype) + vn_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype) + p_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype) + g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared") + g_rev_exp_shared = T.alloc_shared( + (block_S), dtype=accum_dtype, scope="shared" + ) + + h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + o_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + v_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + u_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + g_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + g_last_local = T.alloc_local((1), dtype=accum_dtype) + + data_is_ready = T.alloc_barrier(arrive_count=[96] * 2) + data_is_free = T.alloc_barrier(arrive_count=[384] * 2) + + bar_o = T.alloc_barrier(arrive_count=128) + bar_0 = T.alloc_barrier(arrive_count=416) + bar_1 = T.alloc_barrier(arrive_count=256) + _bar_2 = T.alloc_barrier(arrive_count=128) + bar_3 = T.alloc_barrier(arrive_count=128) + bar_4 = T.alloc_barrier(arrive_count=128) + bar_5 = T.alloc_barrier(arrive_count=416) + + T.use_swizzle(10) + + tx = T.get_thread_binding() + + PRODUCER_NREG = 32 + CONSUMER_V_NREG = 128 + CONSUMER_S_NREG = 160 + CONSUMER_O_NREG = 128 + + if tx < 128: + T.set_max_nreg(CONSUMER_S_NREG, 1) + + # Initialize S + if use_initial_state: + T.copy( + h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], + h_fragment, + ) + else: + T.clear(h_fragment) + + # Main Loop + for i_s in T.serial(num_iters): + # [STAGE 0] + T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2) + T.barrier_arrive(bar_0) + + # [STAGE 0] 0 + T.barrier_wait(bar_0, i_s % 2) + # S4[S] S + T.copy(h_fragment, h_shared) + T.barrier_arrive(bar_1) + + # [STAGE 0] 2, 3, 4 + T.barrier_wait(bar_1, i_s % 2) + # S = g_last * S + g_last_local[0] = g_exp_shared[block_S - 1] + for j_k, j_v in T.Parallel(DK, block_DV): + h_fragment[j_k, j_v] *= g_last_local[0] + T.barrier_arrive(bar_5) + + # [STAGE 0] 5 + T.barrier_wait(bar_5, i_s % 2) + # S += K^T @ V' + T.gemm_v1( + k_shared[i_s % 2, :, :], + vn_shared, + h_fragment, + transpose_A=True, + clear_accum=False, + ) + + T.barrier_arrive(data_is_free[i_s % 2]) + + # Store final S + if need_store_final_state: + T.copy( + h_fragment, + ht[ + raw_batch_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV + ], + ) + + elif tx < 256: + T.set_max_nreg(CONSUMER_V_NREG, 1) + + # Main Loop + for i_s in T.serial(num_iters): + # [STAGE 0] + T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2) + T.barrier_arrive(bar_0) + + # [STAGE 0] 0 + T.barrier_wait(bar_0, i_s % 2) + # Precompute g, g_last/g + for j_s in T.Parallel(block_S): + g_exp_shared[j_s] = T.exp2(g_shared[i_s % 2, j_s] * 1.442695) + for j_s in T.Parallel(block_S): + g_rev_exp_shared[j_s] = T.if_then_else( + seq_start_idx + i_s * block_S + j_s < seq_end_idx, + T.exp2( + ( + g_shared[i_s % 2, block_S - 1] + - g_shared[i_s % 2, j_s] + ) + * 1.442695 + ), + 0.0, + ) + T.barrier_arrive(bar_1) + + # [STAGE 0] 1 + T.barrier_wait(bar_1, i_s % 2) + # U = K @ S + T.gemm_v1( + k_shared[i_s % 2, :, :], h_shared, u_fragment, clear_accum=True + ) + + # [STAGE 0] 2 + # W = V - g * U + for j_s, j_v in T.Parallel(block_S, block_DV): + u_fragment[j_s, j_v] *= -g_exp_shared[j_s] + for j_s, j_v in T.Parallel(block_S, block_DV): + u_fragment[j_s, j_v] += v_shared[i_s % 2, j_s, j_v] + # S2[V] W + for j_s, j_v in T.Parallel(block_S, block_DV): + v_shared[i_s % 2, j_s, j_v] = u_fragment[j_s, j_v] + + # [STAGE 0] 3 + T.barrier_wait(bar_3, i_s % 2) + # Vd = Ag @ W + T.gemm_v1( + a_shared[i_s % 2, :, :], + v_shared[i_s % 2, :, :], + v_fragment, + clear_accum=True, + ) + # S2[2] Vd + T.copy(v_fragment, vd_shared) + T.barrier_arrive(bar_4) + + # [STAGE 0] 4 + # V' = g_last/g Vd + for j_s, j_v in T.Parallel(block_S, block_DV): + v_fragment[j_s, j_v] *= g_rev_exp_shared[j_s] + # S2[1] V' + T.copy(v_fragment, vn_shared) + T.barrier_arrive(bar_5) + + T.barrier_wait(bar_5, i_s % 2) + + T.barrier_arrive(data_is_free[i_s % 2]) + + elif tx < 384: + T.set_max_nreg(CONSUMER_O_NREG, 1) + + # Main Loop + for i_s in T.serial(num_iters): + # [STAGE 0] + T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2) + T.barrier_arrive(bar_0) + + # [STAGE 0] 0 + T.barrier_wait(bar_0, i_s % 2) + # P = Q K^T + T.gemm_v1( + q_shared[i_s % 2, :, :], + k_shared[i_s % 2, :, :], + p_fragment, + transpose_B=True, + clear_accum=True, + ) + + # [STAGE 0] 1 + # G = Lower(diag(g) @ I @ diag(1/g)) + for j_s, j_t in T.Parallel(block_S, block_S): + g_fragment[j_s, j_t] = ( + g_shared[i_s % 2, j_s] - g_shared[i_s % 2, j_t] + ) + for j_s, j_t in T.Parallel(block_S, block_S): + if j_s >= j_t: + g_fragment[j_s, j_t] = T.exp2( + g_fragment[j_s, j_t] * 1.442695 + ) + else: + g_fragment[j_s, j_t] = 0 + # Ag = G * Ar * b + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] = a_shared[i_s % 2, j_s, j_t] + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] *= g_fragment[j_s, j_t] + for j_s, j_t in T.Parallel(block_S, block_S): + a_fragment[j_s, j_t] *= b_shared[i_s % 2, j_t] + for j_s, j_t in T.Parallel(block_S, block_S): + a_shared[i_s % 2, j_s, j_t] = a_fragment[j_s, j_t] + + # [STAGE 0] 2 + T.barrier_wait(bar_1, i_s % 2) + # O = Q @ S + T.gemm_v1( + q_shared[i_s % 2, :, :], h_shared, o_fragment, clear_accum=True + ) + + # [STAGE 0] 3 + # Pg = s * G * P + for j_s, j_t in T.Parallel(block_S, block_S): + p_fragment[j_s, j_t] *= scale * g_fragment[j_s, j_t] + # S1[1] Pg + T.copy(p_fragment, p_shared) + T.barrier_arrive(bar_3) + # O = s * g * O + for j_s, j_k in T.Parallel(block_S, DK): + o_fragment[j_s, j_k] *= scale * g_exp_shared[j_s] + + # [STAGE 0] 4 + T.barrier_wait(bar_4, i_s % 2) + # O += Pg @ Vd + T.gemm_v1(p_shared, vd_shared, o_fragment, clear_accum=False) + T.barrier_arrive(bar_5) + + # [STAGE 0] 5 + T.barrier_wait(bar_5, i_s % 2) + # S2[S] O + T.copy(o_fragment, o_shared) + + T.barrier_arrive(data_is_free[i_s % 2]) + + T.barrier_arrive(bar_o) + + else: + T.set_max_nreg(PRODUCER_NREG, 0) + + if tx < 384 + 32: + for i_s in T.serial(num_iters): + T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2) + left = seq_start_idx + i_s * block_S + right = left + block_S + + # Load Q + T.copy( + q[batch_idx, left:right, bhg, 0:DK], q_shared[i_s % 2, :, :] + ) + # Load K + T.copy( + k[batch_idx, left:right, bhg, 0:DK], k_shared[i_s % 2, :, :] + ) + + T.barrier_arrive(data_is_ready[i_s % 2]) + + elif tx < 384 + 64: + for i_s in T.serial(num_iters): + T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2) + left = seq_start_idx + i_s * block_S + right = left + block_S + + # Load V + T.copy( + v[ + batch_idx, + left:right, + bh, + bv * block_DV : (bv + 1) * block_DV, + ], + v_shared[i_s % 2, :, :], + ) + # Load beta + if right <= seq_end_idx: + for j_s in T.Parallel(block_S): + b_shared[i_s % 2, j_s] = b[batch_idx, left + j_s, bh] + else: + for j_s in T.Parallel(block_S): + if left + j_s < seq_end_idx: + b_shared[i_s % 2, j_s] = b[ + batch_idx, left + j_s, bh + ] + else: + b_shared[i_s % 2, j_s] = 0 + + T.barrier_arrive(data_is_ready[i_s % 2]) + + elif tx < 384 + 96: + for i_s in T.serial(num_iters): + T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2) + left = seq_start_idx + i_s * block_S + right = left + block_S + + # Load A + T.copy( + a[batch_idx, left:right, bh, 0:block_S], + a_shared[i_s % 2, :, :], + ) + # Load gamma + if right <= seq_end_idx: + for j_s in T.Parallel(block_S): + g_shared[i_s % 2, j_s] = g[batch_idx, left + j_s, bh] + else: + for j_s in T.Parallel(block_S): + if left + j_s < seq_end_idx: + g_shared[i_s % 2, j_s] = g[ + batch_idx, left + j_s, bh + ] + else: + g_shared[i_s % 2, j_s] = g[ + batch_idx, seq_end_idx - 1, bh + ] + + T.barrier_arrive(data_is_ready[i_s % 2]) + + else: + for i_s in T.serial(num_unmasked_iters): + right = seq_start_idx + i_s * block_S + left = right - block_S + + T.barrier_arrive(bar_0) + + T.barrier_wait(bar_0, i_s % 2) + # Store O + if i_s > 0 and store_o: + T.copy( + o_shared, + o[ + batch_idx, + left:right, + bh, + bv * block_DV : (bv + 1) * block_DV, + ], + ) + T.barrier_arrive(bar_5) + + T.barrier_wait(bar_1, i_s % 2) + # Store S + if store_h: + T.copy( + h_shared, + h[ + batch_idx, + chunk_start_idx + i_s, + bh, + 0:DK, + bv * block_DV : (bv + 1) * block_DV, + ], + ) + + if num_unmasked_iters < num_iters: + seq_split_idx = seq_start_idx + num_unmasked_iters * block_S + chunk_split_idx = chunk_start_idx + num_unmasked_iters + + T.barrier_arrive(bar_0) + + T.barrier_wait(bar_0, num_unmasked_iters % 2) + # Store O + if num_unmasked_iters > 0 and store_o: + T.copy( + o_shared, + o[ + batch_idx, + seq_split_idx - block_S : seq_split_idx, + bh, + bv * block_DV : (bv + 1) * block_DV, + ], + ) + T.barrier_arrive(bar_5) + + T.barrier_wait(bar_1, num_unmasked_iters % 2) + # Store S + if store_h: + T.copy( + h_shared, + h[ + batch_idx, + chunk_split_idx, + bh, + 0:DK, + bv * block_DV : (bv + 1) * block_DV, + ], + ) + + seq_split_idx = seq_start_idx + (num_iters - 1) * block_S + + # Store O + T.barrier_wait(bar_o, 0) + if store_o: + for j_s, j_v in T.Parallel(block_S, block_DV): + with T.If(seq_split_idx + j_s < seq_end_idx): + with T.Then(): + o[ + batch_idx, + seq_split_idx + j_s, + bh, + bv * block_DV + j_v, + ] = o_shared[j_s, j_v] + + return tilelang_fused_chunk_gdr_fwd_kernel + + +def fused_gdr_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + g: torch.Tensor, + b: torch.Tensor, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = True, + output_h: bool = False, + output_o: bool = True, + cu_seqlens: torch.LongTensor | None = None, + cp_seq_map: torch.LongTensor | None = None, + raw_cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, +): + batch_size, num_tokens, Hg, K = k.shape + _, _, H, V = v.shape + scale = scale or K ** (-0.5) + assert K == V == 128 + assert chunk_size == 64 + + if cu_seqlens is None: + real_batch_size = batch_size + num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0 + cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device) + chunk_offsets = torch.empty( + (batch_size + 1), dtype=torch.int32, device=k.device + ) + seqlen_dtype = torch.int32 + is_varlen = False + else: + real_batch_size = len(cu_seqlens) - 1 + chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size) + chunk_offsets = chunk_offsets.to(cu_seqlens.dtype) + num_chunks = num_chunks if output_h else 0 + seqlen_dtype = cu_seqlens.dtype + is_varlen = True + + if cp_seq_map is None: + cp_seq_map = torch.empty( + (real_batch_size,), dtype=seqlen_dtype, device=k.device + ) + is_cp = False + else: + is_cp = True + + use_initial_state = initial_state is not None + if initial_state is None: + initial_state = torch.empty( + (real_batch_size, H, K, V), dtype=torch.float32, device=k.device + ) + h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device) + if raw_cu_seqlens is None: + raw_cu_seqlens = torch.empty( + (real_batch_size + 1,), dtype=seqlen_dtype, device=k.device + ) + final_state = torch.empty( + (real_batch_size, H, K, V), dtype=torch.float32, device=k.device + ) + else: + final_state = torch.empty( + (raw_cu_seqlens.shape[0] - 1, H, K, V), dtype=torch.float32, device=k.device + ) + o = torch.empty_like(v) + + grid_size = real_batch_size * H + if grid_size >= TARGET_NUM_CTAS: + block_DV = 128 + elif grid_size * 2 >= TARGET_NUM_CTAS: + block_DV = 64 + else: + block_DV = 32 + + tilelang_fused_chunk_gdr_fwd_kernel = tilelang_fused_chunk_gdr_fwd( + H, + Hg, + K, + V, + chunk_size, + scale, + qkva_dtype=q.dtype, + g_dtype=g.dtype, + b_dtype=b.dtype, + h0_dtype=initial_state.dtype, + ht_dtype=final_state.dtype, + h_dtype=h.dtype, + o_dtype=o.dtype, + seqlen_dtype=seqlen_dtype, + accum_dtype="float32", + use_initial_state=use_initial_state, + store_final_state=output_final_state, + store_h=output_h, + store_o=output_o, + is_varlen=is_varlen, + is_cp=is_cp, + block_DV=block_DV, + ) + tilelang_fused_chunk_gdr_fwd_kernel( + q, + k, + v, + a, + g, + b, + initial_state, + cu_seqlens, + chunk_offsets, + cp_seq_map, + raw_cu_seqlens, + o, + h, + final_state, + ) + + if not output_final_state: + final_state = None + if not output_h: + h = None + if not output_o: + o = None + + return o, h, final_state diff --git a/flash_qla/ops/gated_delta_rule/chunk/hopper/kkt_solve.py b/flash_qla/ops/gated_delta_rule/chunk/hopper/kkt_solve.py new file mode 100644 index 0000000..14eca1d --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/hopper/kkt_solve.py @@ -0,0 +1,345 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from typing import Optional + +import torch +import tilelang +import tilelang.language as T + +from flash_qla.utils import prepare_chunk_indices + + +@tilelang.jit( + # out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + # tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + # tilelang.PassConfigKey.TL_ENABLE_ASYNC_COPY: True, + }, +) +def tilelang_kkt_solve( + H, + Hg, + DK, + chunk_size, + accum_dtype, + qkva_dtype, + b_dtype, + seqlen_dtype, + is_varlen, +): + data_batch_size = T.dynamic("data_batch_size") + real_batch_size = T.dynamic("real_batch_size") + num_tokens = T.dynamic("num_tokens") + num_chunks = T.dynamic("num_chunks") + block_S = chunk_size + + k_shape = (data_batch_size, num_tokens, Hg, DK) + a_shape = (data_batch_size, num_tokens, H, chunk_size) + b_shape = (data_batch_size, num_tokens, H) + + @T.macro + def kernel_body( + bb, + bc, + bh, + bhg, + batch_idx, + chunk_idx, + seq_start_idx, + seq_end_idx, + k, + b, + a, + ): + left = seq_start_idx + chunk_idx * block_S + right = left + block_S + + k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared") + a64_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) + + a16i_row = T.alloc_fragment((4, 16), dtype=accum_dtype) + a16i_sum = T.alloc_fragment((4, 16), dtype=accum_dtype) + + a16i_shared = T.alloc_shared((4, 17, 16), dtype=accum_dtype) + a16o_shared = T.alloc_shared((2, 17, 16), dtype=accum_dtype) + a16o_fragment = T.alloc_fragment((2, 16, 16), dtype=accum_dtype) + + a32i_fragment = T.alloc_fragment((2, 32, 32), dtype=accum_dtype) + a32i0_shared = T.alloc_shared((32, 32), dtype=accum_dtype) + a32i1_shared = T.alloc_shared((32, 32), dtype=accum_dtype) + a32o_shared = T.alloc_shared((32, 32), dtype=accum_dtype) + a32o_fragment = T.alloc_fragment((32, 32), dtype=accum_dtype) + + a64_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype) + + T.annotate_layout( + { + a16i_shared: tilelang.layout.make_linear_layout(a16i_shared), + a16o_shared: tilelang.layout.make_linear_layout(a16o_shared), + } + ) + + k_is_ready = T.alloc_barrier(arrive_count=32) + a_is_ready = T.alloc_barrier(arrive_count=128) + + tx = T.get_thread_binding() + + PRODUCER_NREG = 24 + CONSUMER_NREG = 64 + + if tx < 128: + T.set_max_nreg(CONSUMER_NREG, 1) + + # Load b + if right <= seq_end_idx: + for j_s in T.Parallel(block_S): + b_shared[j_s] = b[bb, left + j_s, bh] + else: + for j_s in T.Parallel(block_S): + if left + j_s < seq_end_idx: + b_shared[j_s] = b[bb, left + j_s, bh] + else: + b_shared[j_s] = 0 + + T.barrier_wait(k_is_ready, 0) + + # A = K @ K^T + T.gemm_v1( + k_shared, k_shared, a64_fragment, transpose_B=True, clear_accum=True + ) + + # A = b * A + for j_s, j_t in T.Parallel(block_S, block_S): + a64_fragment[j_s, j_t] *= b_shared[j_s] + + # A = I + StrictLower(A) + for j_s, j_t in T.Parallel(block_S, block_S): + if j_s < j_t: + a64_fragment[j_s, j_t] = 0 + elif j_s == j_t: + a64_fragment[j_s, j_t] = 1 + + # Prepare inversion input + for j_s, j_t in T.Parallel(block_S, block_S): + if j_s >= 32 and j_t < 32: + a32o_shared[j_s - 32, j_t] = -a64_fragment[j_s, j_t] + elif (j_s // 16) == (j_t // 16) + 1: + a16o_shared[j_s // 32, j_s % 16, j_t % 16] = -a64_fragment[j_s, j_t] + elif (j_s // 16) == (j_t // 16): + a16i_shared[j_s // 16, j_s % 16, j_t % 16] = a64_fragment[j_s, j_t] + + # Diagonal 4x16x16 + T.clear(a16i_row) + for k_s in T.unroll(1, 16): + for j_s, k_t in T.Parallel(4, 16): + if k_t < k_s: + a16i_row[j_s, k_t] = a16i_shared[j_s, k_s, k_t] + T.clear(a16i_sum) + for k_r in T.unroll(k_s): + for j_s, k_t in T.Parallel(4, 16): + a16i_sum[j_s, k_t] -= ( + a16i_shared[j_s, k_r, k_t] * a16i_row[j_s, k_r] + ) + for j_s, k_t in T.Parallel(4, 16): + if k_t < k_s: + a16i_shared[j_s, k_s, k_t] = a16i_sum[j_s, k_t] + + # First level 2x16x16 + T.clear(a16o_fragment) + for k_r in T.unroll(16): + for j_s, k_s, k_t in T.Parallel(2, 16, 16): + a16o_fragment[j_s, k_s, k_t] += ( + a16i_shared[j_s * 2 + 1, k_s, k_r] * a16o_shared[j_s, k_r, k_t] + ) + for j_s, k_s, k_t in T.Parallel(2, 16, 16): + a16o_shared[j_s, k_t, k_s] = a16o_fragment[j_s, k_s, k_t] + T.clear(a16o_fragment) + for k_r in T.unroll(16): + for j_s, k_s, k_t in T.Parallel(2, 16, 16): + a16o_fragment[j_s, k_s, k_t] += ( + a16o_shared[j_s, k_r, k_s] * a16i_shared[j_s * 2, k_r, k_t] + ) + T.copy(a16o_fragment, a16o_shared[:, 0:16, 0:16]) + + # Second level 1x32x32 + for j_s, k_s, k_t in T.Parallel(2, 32, 32): + if k_s < 16 and k_t >= 16: + a32i_fragment[j_s, k_s, k_t] = 0 + for j_s, k_s, k_t in T.Parallel(2, 32, 32): + if k_s >= 16 and k_t < 16: + a32i_fragment[j_s, k_s, k_t] = a16o_shared[j_s, k_s - 16, k_t] + for j_s, k_s, k_t in T.Parallel(2, 32, 32): + if k_s // 16 == k_t // 16: + a32i_fragment[j_s, k_s, k_t] = a16i_shared[ + j_s * 2 + k_s // 16, k_s % 16, k_t % 16 + ] + for j_s, k_s, k_t in T.Parallel(2, 32, 32): + if j_s == 0: + a32i0_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t] + else: + a32i1_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t] + T.gemm_v1(a32i1_shared, a32o_shared, a32o_fragment, clear_accum=True) + T.copy(a32o_fragment, a32o_shared) + T.gemm_v1(a32o_shared, a32i0_shared, a32o_fragment, clear_accum=True) + + # Combine inversion output + for j_s, k_s, k_t in T.Parallel(2, 32, 32): + a64_shared[j_s * 32 + k_s, j_s * 32 + k_t] = a32i_fragment[ + j_s, k_s, k_t + ] + for k_s, k_t in T.Parallel(32, 32): + a64_shared[32 + k_s, k_t] = a32o_fragment[k_s, k_t] + for k_s, k_t in T.Parallel(32, 32): + a64_shared[k_s, 32 + k_t] = 0 + + T.barrier_arrive(a_is_ready) + + else: + T.set_max_nreg(PRODUCER_NREG, 0) + + if tx < 128 + 32: + # Load K + T.copy(k[bb, left:right, bhg, 0:DK], k_shared) + + T.barrier_arrive(k_is_ready) + + elif tx < 128 + 64: + T.barrier_wait(a_is_ready, 0) + + # Save A (unmasked) + if right <= seq_end_idx: + T.copy(a64_shared, a[bb, left:right, bh, 0:block_S]) + + else: + T.barrier_wait(a_is_ready, 0) + + # Save A (masked) + if right > seq_end_idx: + for j_s, j_t in T.Parallel(block_S, block_S): + if left + j_s < seq_end_idx: + a[bb, left + j_s, bh, j_t] = a64_shared[j_s, j_t] + + if is_varlen: + + @T.prim_func + def tilelang_kkt_solve_kernel( + k: T.Tensor(k_shape, dtype=qkva_dtype), + b: T.Tensor(b_shape, dtype=b_dtype), + cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype), + chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype), + a: T.Tensor(a_shape, dtype=qkva_dtype), + ): + with T.Kernel(num_chunks * H, threads=256) as (bch,): + bc, bh = bch // H, bch % H + bhg = bh // (H // Hg) + + batch_idx = T.alloc_var("int32") + chunk_idx = T.alloc_var("int32") + seq_start_idx = T.alloc_var("int32") + seq_end_idx = T.alloc_var("int32") + + bb = 0 + batch_idx = chunk_indices[bc, 0] + chunk_idx = chunk_indices[bc, 1] + seq_start_idx = cu_seqlens[batch_idx] + seq_end_idx = cu_seqlens[batch_idx + 1] + + kernel_body( + bb, + bc, + bh, + bhg, + batch_idx, + chunk_idx, + seq_start_idx, + seq_end_idx, + k, + b, + a, + ) + + else: + + @T.prim_func + def tilelang_kkt_solve_kernel( + k: T.Tensor(k_shape, dtype=qkva_dtype), + b: T.Tensor(b_shape, dtype=b_dtype), + a: T.Tensor(a_shape, dtype=qkva_dtype), + num_chunks: T.int32, + ): + with T.Kernel(num_chunks * H, threads=256) as (bch,): + bc, bh = bch // H, bch % H + bhg = bh // (H // Hg) + + batch_idx = T.alloc_var("int32") + chunk_idx = T.alloc_var("int32") + seq_start_idx = T.alloc_var("int32") + seq_end_idx = T.alloc_var("int32") + + bb = bc % data_batch_size + batch_idx = bb + chunk_idx = bc // data_batch_size + seq_start_idx = 0 + seq_end_idx = num_tokens + + kernel_body( + bb, + bc, + bh, + bhg, + batch_idx, + chunk_idx, + seq_start_idx, + seq_end_idx, + k, + b, + a, + ) + + return tilelang_kkt_solve_kernel + + +def kkt_solve( + k: torch.Tensor, + b: torch.Tensor, + chunk_size: int = 64, + cu_seqlens: Optional[torch.LongTensor] = None, +): + batch_size, num_tokens, Hg, K = k.shape + _, _, H = b.shape + assert K == 128 + assert chunk_size == 64 + + if cu_seqlens is None: + num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size) + seqlen_dtype = "int32" + is_varlen = False + else: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + seqlen_dtype = cu_seqlens.dtype + is_varlen = True + + a = torch.empty( + (batch_size, num_tokens, H, chunk_size), dtype=k.dtype, device=k.device + ) + + tilelang_kkt_solve_kernel = tilelang_kkt_solve( + H, + Hg, + K, + chunk_size, + qkva_dtype=k.dtype, + b_dtype=b.dtype, + seqlen_dtype=seqlen_dtype, + accum_dtype="float32", + is_varlen=is_varlen, + ) + if is_varlen: + tilelang_kkt_solve_kernel(k, b, cu_seqlens, chunk_indices, a) + else: + tilelang_kkt_solve_kernel(k, b, a, num_chunks) + + return a diff --git a/flash_qla/ops/gated_delta_rule/chunk/hopper/prepare_h.py b/flash_qla/ops/gated_delta_rule/chunk/hopper/prepare_h.py new file mode 100644 index 0000000..d1f3828 --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/chunk/hopper/prepare_h.py @@ -0,0 +1,558 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang +import tilelang.language as T + +from flash_qla.utils import prepare_chunk_offsets + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def tilelang_prepare_h( + H, + Hg, + DK, + DV, + chunk_size, + accum_dtype, + qkva_dtype, + g_dtype, + b_dtype, + h0_dtype, + ht_dtype, + h_dtype, + seqlen_dtype, + use_initial_state, + store_final_state, + store_h, + is_varlen, + is_cp, + num_stages=2, +): + batch_size = T.dynamic("batch_size") + num_tokens = T.dynamic("num_tokens") + num_chunks = T.dynamic("num_chunks") + block_S = chunk_size + + if is_varlen: + k_shape = (1, num_tokens, Hg, DK) + v_shape = (1, num_tokens, H, DV) + a_shape = (1, num_tokens, H, chunk_size) + g_shape = (1, num_tokens, H) + b_shape = (1, num_tokens, H) + h_shape = (1, num_chunks, H, DK, DV) + else: + k_shape = (batch_size, num_tokens, Hg, DK) + v_shape = (batch_size, num_tokens, H, DV) + a_shape = (batch_size, num_tokens, H, chunk_size) + g_shape = (batch_size, num_tokens, H) + b_shape = (batch_size, num_tokens, H) + h_shape = (batch_size, num_chunks, H, DK, DV) + h0_shape = (batch_size, H, DK, DV) + ht_shape = (batch_size, H, DK, DV) + m_shape = (batch_size, H, DK, DK) + + @T.prim_func + def tilelang_prepare_h_kernel( + k: T.Tensor(k_shape, dtype=qkva_dtype), + v: T.Tensor(v_shape, dtype=qkva_dtype), + a: T.Tensor(a_shape, dtype=qkva_dtype), + g: T.Tensor(g_shape, dtype=g_dtype), + b: T.Tensor(b_shape, dtype=b_dtype), + h0: T.Tensor(h0_shape, dtype=h0_dtype), + cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype), + chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype), + num_warmup_chunks: T.Tensor([batch_size, H], dtype=seqlen_dtype), + h: T.Tensor(h_shape, dtype=h_dtype), + ht: T.Tensor(ht_shape, dtype=ht_dtype), + mt: T.Tensor(m_shape, dtype=ht_dtype), + ): + with T.Kernel(batch_size * H, threads=512) as (bbh,): + bb, bh = bbh // H, bbh % H + bhg = bh // (H // Hg) + + batch_idx = T.alloc_var("int32") + seq_start_idx = T.alloc_var("int32") + seq_end_idx = T.alloc_var("int32") + _seq_split_idx = T.alloc_var("int32") + chunk_start_idx = T.alloc_var("int32") + _chunk_split_idx = T.alloc_var("int32") + + batch_idx = 0 if is_varlen else bb + seq_start_idx = cu_seqlens[bb] if is_varlen else 0 + seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens + chunk_start_idx = chunk_offsets[bb] if is_varlen else 0 + + num_iters = T.alloc_var("int32") + num_iters = ( + num_warmup_chunks[bb, bh] + if is_cp + else T.ceildiv(seq_end_idx - seq_start_idx, block_S) + ) + + calc_mt = T.alloc_var("bool") + calc_mt = is_cp and num_iters >= T.ceildiv( + seq_end_idx - seq_start_idx, block_S + ) + seq_start_idx = ( + seq_end_idx - num_iters * block_S if is_cp else seq_start_idx + ) + + k_shared = T.alloc_shared((num_stages, block_S, DK), dtype=qkva_dtype) + v_shared = T.alloc_shared((num_stages, block_S, DV), dtype=qkva_dtype) + a_shared = T.alloc_shared((num_stages, block_S, block_S), dtype=qkva_dtype) + g_shared = T.alloc_shared( + (num_stages, block_S), dtype=accum_dtype, scope="shared" + ) + b_shared = T.alloc_shared( + (num_stages, block_S), dtype=accum_dtype, scope="shared" + ) + h_shared = T.alloc_shared((DK, DV), dtype=qkva_dtype) + x_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype) + y_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype) + m_shared_L = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype) + m_shared_R = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype) + z_shared_L = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype) + z_shared_R = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype) + g_rev_exp_shared = T.alloc_shared( + (block_S), dtype=accum_dtype, scope="shared" + ) + + h_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype) + x_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype) + y_fragment = T.alloc_fragment((block_S, DV), dtype=accum_dtype) + m_fragment_L = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype) + m_fragment_R = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype) + z_fragment_L = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype) + z_fragment_R = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype) + g_last_local_S = T.alloc_local((1), dtype=accum_dtype) + g_last_local_X = T.alloc_local((1), dtype=accum_dtype) + g_last_local_Y = T.alloc_local((1), dtype=accum_dtype) + g_prod_X = T.alloc_fragment((1), dtype=accum_dtype) + g_prod_Y = T.alloc_fragment((1), dtype=accum_dtype) + + data_is_ready = T.alloc_barrier(arrive_count=[96] * num_stages) + data_is_free = T.alloc_barrier(arrive_count=[384] * num_stages) + + bar_0 = T.alloc_barrier(arrive_count=416) + bar_1 = T.alloc_barrier(arrive_count=256) + bar_2 = T.alloc_barrier(arrive_count=384) + bar_3 = T.alloc_barrier(arrive_count=128) + + T.use_swizzle(10) + + tx = T.get_thread_binding() + + PRODUCER_NREG = 24 + CONSUMER_S_NREG = 168 + CONSUMER_X_NREG = 160 + CONSUMER_Y_NREG = 160 + + if tx < 128: + T.set_max_nreg(CONSUMER_S_NREG, 1) + + # Initialize S + if use_initial_state: + T.copy(h0[bb, bh, 0:DK, 0:DV], h_fragment) + else: + T.clear(h_fragment) + + # Main Loop + for i_s in T.serial(num_iters): + # [STAGE = i_s % num_stages] + T.barrier_wait( + data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2 + ) + T.barrier_arrive(bar_0) + + # [STAGE = i_s % num_stages] 0 + T.barrier_wait(bar_0, i_s % 2) + # S4[1] S + T.copy(h_fragment, h_shared) + T.barrier_arrive(bar_1) + + # [STAGE = i_s % num_stages] 1 + T.barrier_wait(bar_1, i_s % 2) + # S = g_last * S + g_last_local_S[0] = T.exp2( + g_shared[i_s % num_stages, block_S - 1] * 1.442695 + ) + for j_k, j_v in T.Parallel(DK, DV): + h_fragment[j_k, j_v] *= g_last_local_S[0] + T.barrier_arrive(bar_2) + + # [STAGE = i_s % num_stages] 2 + T.barrier_wait(bar_2, i_s % 2) + # S += X^T @ Y + T.gemm_v1( + x_shared, + y_shared, + h_fragment, + transpose_A=True, + clear_accum=False, + ) + T.barrier_arrive(bar_3) + + T.barrier_arrive(data_is_free[i_s % num_stages]) + + # Store final S + if store_final_state: + T.copy(h_fragment, ht[bb, bh, 0:DK, 0:DV]) + + elif tx < 256: + T.set_max_nreg(CONSUMER_X_NREG, 1) + + if calc_mt: + for j_k, j_v in T.Parallel(DK, DK // 2): + if j_k == j_v + DK // 2: + m_fragment_R[j_k, j_v] = 1 + else: + m_fragment_R[j_k, j_v] = 0 + g_prod_X[0] = 0 + + # Main Loop + for i_s in T.serial(num_iters): + # [STAGE = i_s % num_stages] + T.barrier_wait( + data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2 + ) + T.barrier_arrive(bar_0) + + # [STAGE = i_s % num_stages] 0 + T.barrier_wait(bar_0, i_s % 2) + # X = A^T @ K + T.gemm_v1( + a_shared[i_s % num_stages, :, :], + k_shared[i_s % num_stages, :, :], + x_fragment, + transpose_A=True, + clear_accum=True, + ) + + # [STAGE = i_s % num_stages] 1 + # X = - b * X + for j_s, j_k in T.Parallel(block_S, DK): + x_fragment[j_s, j_k] *= -b_shared[i_s % num_stages, j_s] + # S2[1] X + T.copy(x_fragment, x_shared) + T.barrier_arrive(bar_2) + + if calc_mt: + # [STAGE = i_s % num_stages] 2 + g_prod_X[0] += g_shared[i_s % num_stages, block_S - 1] + # S4[2] M + T.copy(m_fragment_R, m_shared_R) + + # [STAGE = i_s % num_stages] 3 + T.barrier_wait(bar_3, i_s % 2) + # Z = K @ M + T.gemm_v1( + k_shared[i_s % num_stages, :, :], + m_shared_R, + z_fragment_R, + clear_accum=True, + ) + # S4[2] Z + T.copy(z_fragment_R, z_shared_R) + # M += X^T @ Z + T.gemm_v1( + x_shared, + z_shared_R, + m_fragment_R, + transpose_A=True, + clear_accum=False, + ) + + T.barrier_arrive(data_is_free[i_s % num_stages]) + + if calc_mt: + g_last_local_X[0] = T.exp2(g_prod_X[0] * 1.442695) + for j_k, j_v in T.Parallel(DK, DK // 2): + m_fragment_R[j_k, j_v] *= g_last_local_X[0] + T.copy(m_fragment_R, mt[bb, bh, 0:DK, DK // 2 :]) + + elif tx < 384: + T.set_max_nreg(CONSUMER_Y_NREG, 1) + + if calc_mt: + for j_k, j_v in T.Parallel(DK, DK // 2): + if j_k == j_v: + m_fragment_L[j_k, j_v] = 1 + else: + m_fragment_L[j_k, j_v] = 0 + g_prod_Y[0] = 0 + + # Main Loop + for i_s in T.serial(num_iters): + # [STAGE = i_s % num_stages] + T.barrier_wait( + data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2 + ) + T.barrier_arrive(bar_0) + + # [STAGE = i_s % num_stages] 0 + T.barrier_wait(bar_0, i_s % 2) + # Precompute g_last/g + g_last_local_Y[0] = g_shared[i_s % num_stages, block_S - 1] + for j_s in T.Parallel(block_S): + g_rev_exp_shared[j_s] = T.exp2( + (g_last_local_Y[0] - g_shared[i_s % num_stages, j_s]) + * 1.442695 + ) + g_last_local_Y[0] = T.exp2(g_last_local_Y[0] * 1.442695) + T.barrier_arrive(bar_1) + + # [STAGE = i_s % num_stages] 1 + T.barrier_wait(bar_1, i_s % 2) + # U = K @ S + T.gemm_v1( + k_shared[i_s % num_stages, :, :], + h_shared, + y_fragment, + clear_accum=True, + ) + # Y = g_last * U - g_last/g * V + for j_s, j_v in T.Parallel(block_S, DV): + y_fragment[j_s, j_v] *= g_last_local_Y[0] + for j_s, j_v in T.Parallel(block_S, DV): + y_fragment[j_s, j_v] -= ( + v_shared[i_s % num_stages, j_s, j_v] * g_rev_exp_shared[j_s] + ) + # S2[2] Y + T.copy(y_fragment, y_shared) + T.barrier_arrive(bar_2) + + if calc_mt: + # [STAGE = i_s % num_stages] 2 + g_prod_Y[0] += g_shared[i_s % num_stages, block_S - 1] + # S4[2] M + T.copy(m_fragment_L, m_shared_L) + + # [STAGE = i_s % num_stages] 3 + T.barrier_wait(bar_3, i_s % 2) + # Z = K @ M + T.gemm_v1( + k_shared[i_s % num_stages, :, :], + m_shared_L, + z_fragment_L, + clear_accum=True, + ) + # S4[2] Z + T.copy(z_fragment_L, z_shared_L) + # M += X^T @ Z + T.gemm_v1( + x_shared, + z_shared_L, + m_fragment_L, + transpose_A=True, + clear_accum=False, + ) + + T.barrier_arrive(data_is_free[i_s % num_stages]) + + if calc_mt: + g_last_local_Y[0] = T.exp2(g_prod_Y[0] * 1.442695) + for j_k, j_v in T.Parallel(DK, DK // 2): + m_fragment_L[j_k, j_v] *= g_last_local_Y[0] + T.copy(m_fragment_L, mt[bb, bh, 0:DK, : DK // 2]) + + else: + T.set_max_nreg(PRODUCER_NREG, 0) + + if tx < 384 + 32: + for i_s in T.serial(num_iters): + T.barrier_wait( + data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2 + ) + left = seq_start_idx + i_s * block_S + right = left + block_S + + # Load K + T.copy( + k[batch_idx, left:right, bhg, 0:DK], + k_shared[i_s % num_stages, :, :], + ) + + T.barrier_arrive(data_is_ready[i_s % num_stages]) + + elif tx < 384 + 64: + for i_s in T.serial(num_iters): + T.barrier_wait( + data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2 + ) + left = seq_start_idx + i_s * block_S + right = left + block_S + + # Load V + T.copy( + v[batch_idx, left:right, bh, 0:DV], + v_shared[i_s % num_stages, :, :], + ) + # Load A TODO: Mask A for the last chunk + T.copy( + a[batch_idx, left:right, bh, 0:block_S], + a_shared[i_s % num_stages, :, :], + ) + + T.barrier_arrive(data_is_ready[i_s % num_stages]) + + elif tx < 384 + 96: + for i_s in T.serial(num_iters): + T.barrier_wait( + data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2 + ) + left = seq_start_idx + i_s * block_S + right = left + block_S + + # Load gamma + if right <= seq_end_idx: + for j_s in T.Parallel(block_S): + g_shared[i_s % num_stages, j_s] = g[ + batch_idx, left + j_s, bh + ] + else: + for j_s in T.Parallel(block_S): + if left + j_s < seq_end_idx: + g_shared[i_s % num_stages, j_s] = g[ + batch_idx, left + j_s, bh + ] + else: + g_shared[i_s % num_stages, j_s] = g[ + batch_idx, seq_end_idx - 1, bh + ] + # Load beta + if right <= seq_end_idx: + for j_s in T.Parallel(block_S): + b_shared[i_s % num_stages, j_s] = b[ + batch_idx, left + j_s, bh + ] + else: + for j_s in T.Parallel(block_S): + if left + j_s < seq_end_idx: + b_shared[i_s % num_stages, j_s] = b[ + batch_idx, left + j_s, bh + ] + else: + b_shared[i_s % num_stages, j_s] = 0 + + T.barrier_arrive(data_is_ready[i_s % num_stages]) + + else: + for i_s in T.serial(num_iters): + T.barrier_arrive(bar_0) + + T.barrier_wait(bar_0, i_s % 2) + T.barrier_wait(bar_1, i_s % 2) + # Store S + if store_h: + T.copy( + h_shared, + h[batch_idx, chunk_start_idx + i_s, bh, 0:DK, 0:DV], + ) + + return tilelang_prepare_h_kernel + + +def fused_gdr_h( + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + g: torch.Tensor, + b: torch.Tensor, + initial_state: torch.Tensor | None = None, + output_final_state: bool = True, + output_h: bool = True, + chunk_size: int = 64, + cu_seqlens: torch.LongTensor | None = None, + num_warmup_chunks: torch.LongTensor | None = None, +): + batch_size, num_tokens, Hg, K = k.shape + _, _, H, V = v.shape + assert K == V == 128 + assert chunk_size == 64 + + if cu_seqlens is None: + assert num_warmup_chunks is None + real_batch_size = batch_size + num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0 + cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device) + chunk_offsets = torch.empty( + (batch_size + 1), dtype=torch.int32, device=k.device + ) + is_varlen = False + is_cp = False + else: + real_batch_size = len(cu_seqlens) - 1 + chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size) + chunk_offsets = chunk_offsets.to(cu_seqlens.dtype) + num_chunks = num_chunks if output_h else 0 + is_varlen = True + if num_warmup_chunks is None: + num_warmup_chunks = torch.empty( + (real_batch_size, H), dtype=cu_seqlens.dtype, device=k.device + ) + is_cp = False + else: + is_cp = True + + use_initial_state = initial_state is not None + if initial_state is None: + initial_state = torch.empty( + (real_batch_size, H, K, V), dtype=torch.float32, device=k.device + ) + h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device) + ht_dtype = k.dtype if is_cp else torch.float32 + final_state = torch.empty( + (real_batch_size, H, K, V), dtype=ht_dtype, device=k.device + ) + final_correction = torch.empty( + (real_batch_size, H, K, K), dtype=ht_dtype, device=k.device + ) + + tilelang_prepare_h_kernel = tilelang_prepare_h( + H, + Hg, + K, + V, + chunk_size, + qkva_dtype=k.dtype, + g_dtype=g.dtype, + b_dtype=b.dtype, + h0_dtype=initial_state.dtype, + ht_dtype=final_state.dtype, + h_dtype=h.dtype, + seqlen_dtype=cu_seqlens.dtype, + accum_dtype="float32", + use_initial_state=use_initial_state, + store_final_state=output_final_state, + store_h=output_h, + is_varlen=is_varlen, + is_cp=is_cp, + ) + tilelang_prepare_h_kernel( + k, + v, + a, + g, + b, + initial_state, + cu_seqlens, + chunk_offsets, + num_warmup_chunks, + h, + final_state, + final_correction, + ) + + if not output_final_state: + final_state = None + final_correction = None + if not output_h: + h = None + + return h, final_state, final_correction diff --git a/flash_qla/ops/gated_delta_rule/legacy/__init__.py b/flash_qla/ops/gated_delta_rule/legacy/__init__.py new file mode 100644 index 0000000..fa7921f --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/legacy/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .sm_legacy import chunk_gated_delta_rule_fwd_legacy + +__all__ = ["chunk_gated_delta_rule_fwd_legacy"] diff --git a/flash_qla/ops/gated_delta_rule/legacy/csrc/gdn_forward.cu b/flash_qla/ops/gated_delta_rule/legacy/csrc/gdn_forward.cu new file mode 100644 index 0000000..088f517 --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/legacy/csrc/gdn_forward.cu @@ -0,0 +1,348 @@ +#include + +#include +#include + +#include +#include +#include +#include + +namespace { + +void check_cuda(cudaError_t status, const char* context) { + if (status != cudaSuccess) { + throw std::runtime_error(std::string(context) + ": " + + cudaGetErrorString(status)); + } +} + +__device__ __forceinline__ float subgroup_sum_lane0(float value, + int width) { + constexpr unsigned mask = 0xffffffffU; + for (int offset = width / 2; offset > 0; offset >>= 1) { + value += __shfl_down_sync(mask, value, offset, width); + } + return value; +} + +__device__ __forceinline__ float subgroup_broadcast_lane0(float value, + int width) { + return __shfl_sync(0xffffffffU, value, 0, width); +} + +template +__global__ void gdn_forward_kernel(const float* __restrict__ q, + const float* __restrict__ k, + const float* __restrict__ v, + const float* __restrict__ gate, + const float* __restrict__ beta, + const float* __restrict__ initial_state, + float* __restrict__ output, + float* __restrict__ final_state, + int batch, + int tokens, + int q_heads, + int v_heads, + float scale) { + static_assert(D % (COLS * (32 / WIDTH)) == 0); + constexpr int subgroups_per_warp = 32 / WIDTH; + constexpr int rows_per_lane = (D + WIDTH - 1) / WIDTH; + + const int hv = blockIdx.x; + const int b = blockIdx.y; + const int subgroup = threadIdx.x / WIDTH; + const int lane = threadIdx.x % WIDTH; + const int group_base = + (blockIdx.z * blockDim.y + threadIdx.y) * subgroups_per_warp + subgroup; + const int col_base = group_base * COLS; + const int hq = hv / (v_heads / q_heads); + + float state_shard[COLS][rows_per_lane]; + +#pragma unroll + for (int c = 0; c < COLS; ++c) { + const int col = col_base + c; +#pragma unroll + for (int r = 0; r < rows_per_lane; ++r) { + const int row = r * WIDTH + lane; + float value = 0.0F; + if (row < D) { + const auto state_index = + (((static_cast(b) * v_heads + hv) * D + col) * D) + row; + value = initial_state == nullptr ? 0.0F : initial_state[state_index]; + } + state_shard[c][r] = value; + } + } + + for (int t = 0; t < tokens; ++t) { + const auto gate_index = + ((static_cast(b) * tokens + t) * v_heads + hv); + float gate_value = 0.0F; + float beta_value = 0.0F; + if (threadIdx.x == 0) { + gate_value = __expf(gate[gate_index]); + beta_value = beta[gate_index]; + } + gate_value = __shfl_sync(0xffffffffU, gate_value, 0); + beta_value = __shfl_sync(0xffffffffU, beta_value, 0); + + float k_reg[rows_per_lane]; + float q_reg[rows_per_lane]; + float kv_partial[COLS]; +#pragma unroll + for (int c = 0; c < COLS; ++c) { + kv_partial[c] = 0.0F; + } + +#pragma unroll + for (int r = 0; r < rows_per_lane; ++r) { + const int row = r * WIDTH + lane; + float q_value = 0.0F; + float k_value = 0.0F; + if (row < D) { + const auto qk_index = + (((static_cast(b) * tokens + t) * q_heads + hq) * D) + row; + q_value = q[qk_index]; + k_value = k[qk_index]; + } + q_reg[r] = q_value; + k_reg[r] = k_value; +#pragma unroll + for (int c = 0; c < COLS; ++c) { + kv_partial[c] += state_shard[c][r] * k_value; + } + } + + float delta[COLS]; +#pragma unroll + for (int c = 0; c < COLS; ++c) { + const float kv_col = subgroup_sum_lane0(kv_partial[c], WIDTH); + float delta_value = 0.0F; + if (lane == 0) { + const auto v_index = + (((static_cast(b) * tokens + t) * v_heads + hv) * D) + + col_base + c; + delta_value = (v[v_index] - gate_value * kv_col) * beta_value; + } + delta[c] = subgroup_broadcast_lane0(delta_value, WIDTH); + } + + float attn_partial[COLS]; +#pragma unroll + for (int c = 0; c < COLS; ++c) { + attn_partial[c] = 0.0F; + } + +#pragma unroll + for (int r = 0; r < rows_per_lane; ++r) { +#pragma unroll + for (int c = 0; c < COLS; ++c) { + const float new_state = + fmaf(k_reg[r], delta[c], gate_value * state_shard[c][r]); + state_shard[c][r] = new_state; + attn_partial[c] += new_state * q_reg[r]; + } + } + +#pragma unroll + for (int c = 0; c < COLS; ++c) { + attn_partial[c] = subgroup_sum_lane0(attn_partial[c], WIDTH); + } + + if (lane == 0) { + const auto out_base = + (((static_cast(b) * tokens + t) * v_heads + hv) * D); +#pragma unroll + for (int c = 0; c < COLS; ++c) { + output[out_base + col_base + c] = attn_partial[c] * scale; + } + } + } + +#pragma unroll + for (int c = 0; c < COLS; ++c) { + const int col = col_base + c; +#pragma unroll + for (int r = 0; r < rows_per_lane; ++r) { + const int row = r * WIDTH + lane; + if (row < D) { + const auto state_index = + (((static_cast(b) * v_heads + hv) * D + col) * D) + row; + final_state[state_index] = state_shard[c][r]; + } + } + } +} + +template +void launch_gdn_forward(const float* q, + const float* k, + const float* v, + const float* gate, + const float* beta, + const float* initial_state, + float* output, + float* final_state, + int batch, + int tokens, + int q_heads, + int v_heads, + float scale, + cudaStream_t stream) { + constexpr int cols = D == 128 ? 4 : 1; + constexpr int width = D == 128 ? 16 : 32; + constexpr int groups_per_warp = 32 / width; + constexpr int column_groups_per_block = 8; + const dim3 block(32, column_groups_per_block); + const int groups = D / cols; + const int z = (groups + column_groups_per_block * groups_per_warp - 1) / + (column_groups_per_block * groups_per_warp); + const dim3 grid(v_heads, batch, z); + gdn_forward_kernel + <<>>(q, + k, + v, + gate, + beta, + initial_state, + output, + final_state, + batch, + tokens, + q_heads, + v_heads, + scale); +} + +void validate_tensor(const torch::Tensor& tensor, + const char* name, + int64_t dims) { + TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor"); + TORCH_CHECK(tensor.scalar_type() == torch::kFloat32, + name, + " must be float32"); + TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); + TORCH_CHECK(tensor.dim() == dims, name, " has wrong rank"); +} + +} // namespace + +std::vector gdn_forward(torch::Tensor q, + torch::Tensor k, + torch::Tensor v, + torch::Tensor gate, + torch::Tensor beta, + c10::optional initial_state, + double scale) { + validate_tensor(q, "q", 4); + validate_tensor(k, "k", 4); + validate_tensor(v, "v", 4); + validate_tensor(gate, "gate", 3); + validate_tensor(beta, "beta", 3); + + TORCH_CHECK(q.sizes() == k.sizes(), "q and k must have the same shape"); + const int batch = static_cast(q.size(0)); + const int tokens = static_cast(q.size(1)); + const int q_heads = static_cast(q.size(2)); + const int dim = static_cast(q.size(3)); + const int v_heads = static_cast(v.size(2)); + TORCH_CHECK(v.size(0) == batch && v.size(1) == tokens && v.size(3) == dim, + "v must have shape [B, T, Hv, D] matching q/k"); + TORCH_CHECK(gate.size(0) == batch && gate.size(1) == tokens && + gate.size(2) == v_heads, + "gate must have shape [B, T, Hv]"); + TORCH_CHECK(beta.sizes() == gate.sizes(), + "beta must have the same shape as gate"); + TORCH_CHECK(v_heads % q_heads == 0, "Hv must be divisible by Hq"); + TORCH_CHECK(dim == 16 || dim == 32 || dim == 64 || dim == 128, + "D must be one of 16, 32, 64, or 128"); + + const float* initial_ptr = nullptr; + if (initial_state.has_value() && initial_state.value().defined()) { + const auto& h0 = initial_state.value(); + validate_tensor(h0, "initial_state", 4); + TORCH_CHECK(h0.size(0) == batch && h0.size(1) == v_heads && + h0.size(2) == dim && h0.size(3) == dim, + "initial_state must have shape [B, Hv, D, D]"); + initial_ptr = h0.data_ptr(); + } + + auto output = torch::empty_like(v); + auto final_state = torch::empty({batch, v_heads, dim, dim}, q.options()); + + const auto stream = at::cuda::getCurrentCUDAStream(q.device().index()).stream(); + switch (dim) { + case 16: + launch_gdn_forward<16>(q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + gate.data_ptr(), + beta.data_ptr(), + initial_ptr, + output.data_ptr(), + final_state.data_ptr(), + batch, + tokens, + q_heads, + v_heads, + static_cast(scale), + stream); + break; + case 32: + launch_gdn_forward<32>(q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + gate.data_ptr(), + beta.data_ptr(), + initial_ptr, + output.data_ptr(), + final_state.data_ptr(), + batch, + tokens, + q_heads, + v_heads, + static_cast(scale), + stream); + break; + case 64: + launch_gdn_forward<64>(q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + gate.data_ptr(), + beta.data_ptr(), + initial_ptr, + output.data_ptr(), + final_state.data_ptr(), + batch, + tokens, + q_heads, + v_heads, + static_cast(scale), + stream); + break; + case 128: + launch_gdn_forward<128>(q.data_ptr(), + k.data_ptr(), + v.data_ptr(), + gate.data_ptr(), + beta.data_ptr(), + initial_ptr, + output.data_ptr(), + final_state.data_ptr(), + batch, + tokens, + q_heads, + v_heads, + static_cast(scale), + stream); + break; + } + check_cuda(cudaGetLastError(), "gdn_forward launch"); + return {output, final_state}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("gdn_forward", &gdn_forward, "SM70/SM75 legacy GDN forward"); +} diff --git a/flash_qla/ops/gated_delta_rule/legacy/sm_legacy.py b/flash_qla/ops/gated_delta_rule/legacy/sm_legacy.py new file mode 100644 index 0000000..f49c703 --- /dev/null +++ b/flash_qla/ops/gated_delta_rule/legacy/sm_legacy.py @@ -0,0 +1,104 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from __future__ import annotations + +import os +from pathlib import Path + +import torch +from torch.utils.cpp_extension import load + +_EXT = None + + +def _load_ext(): + global _EXT + if _EXT is not None: + return _EXT + + if not torch.cuda.is_available(): + raise RuntimeError("SM70/SM75 legacy GDN backend requires CUDA") + + os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0;7.5") + src = Path(__file__).with_name("csrc") / "gdn_forward.cu" + _EXT = load( + name="flash_qla_legacy_gdn", + sources=[str(src)], + extra_cuda_cflags=["-O3"], + extra_cflags=["-O3"], + verbose=bool(int(os.environ.get("FLASH_QLA_LEGACY_VERBOSE_BUILD", "0"))), + ) + return _EXT + + +def _check_inputs( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + initial_state: torch.Tensor | None, +) -> None: + tensors = [q, k, v, g, beta] + if initial_state is not None: + tensors.append(initial_state) + + if any(not tensor.is_cuda for tensor in tensors): + raise ValueError("legacy GDN tensors must be CUDA tensors") +# if any(tensor.dtype != torch.float32 for tensor in tensors): +# raise ValueError("legacy GDN backend currently supports float32 tensors only") + if any(not tensor.is_contiguous() for tensor in tensors): + raise ValueError("legacy GDN tensors must be contiguous") + if q.ndim != 4 or k.ndim != 4 or v.ndim != 4: + raise ValueError("q, k, and v must have shape [B, T, H, D]") + if g.ndim != 3 or beta.ndim != 3: + raise ValueError("g and beta must have shape [B, T, Hv]") + if q.shape != k.shape: + raise ValueError("q and k must have the same shape") + + batch, tokens, q_heads, dim = q.shape + if v.shape[0] != batch or v.shape[1] != tokens or v.shape[3] != dim: + raise ValueError("v must have shape [B, T, Hv, D] matching q/k") + if g.shape != beta.shape or g.shape != v.shape[:3]: + raise ValueError("g and beta must have shape [B, T, Hv]") + if v.shape[2] % q_heads != 0: + raise ValueError("Hv must be divisible by Hq") + if dim not in (16, 32, 64, 128): + raise ValueError("legacy GDN backend supports D in {16, 32, 64, 128}") + if initial_state is not None and initial_state.shape != (batch, v.shape[2], dim, dim): + raise ValueError("initial_state must have shape [B, Hv, D, D]") + + +def chunk_gated_delta_rule_fwd_legacy( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float | None = None, + initial_state: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Run the experimental SM70/SM75 forward-only GDN backend. + + This legacy backend is intentionally explicit. It does not replace the + Hopper/SM90 TileLang path and currently supports only contiguous float32 + tensors for inference-oriented forward execution. + + Shapes: + q, k: [B, T, Hq, D] + v: [B, T, Hv, D] + g, beta: [B, T, Hv] + initial_state: optional [B, Hv, D, D] + + Returns: + output: [B, T, Hv, D] + final_state: [B, Hv, D, D] + """ + + _check_inputs(q, k, v, g, beta, initial_state) + if scale is None: + scale = q.shape[-1] ** -0.5 + + ext = _load_ext() + return ext.gdn_forward(q, k, v, g, beta, initial_state, float(scale)) diff --git a/flash_qla/ops/utils/__init__.py b/flash_qla/ops/utils/__init__.py new file mode 100644 index 0000000..9ab1286 --- /dev/null +++ b/flash_qla/ops/utils/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .cumsum import chunk_local_cumsum +from .group_reduce import group_reduce_vector + + +__all__ = [ + "chunk_local_cumsum", + "group_reduce_vector", +] diff --git a/flash_qla/ops/utils/cumsum.py b/flash_qla/ops/utils/cumsum.py new file mode 100644 index 0000000..9977719 --- /dev/null +++ b/flash_qla/ops/utils/cumsum.py @@ -0,0 +1,165 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang +import tilelang.language as T + +from flash_qla.utils import prepare_chunk_indices + + +@tilelang.jit( + # out_idx=[-1], +) +def tilelang_chunk_local_cumsum( + H, + chunk_size, + accum_dtype, + g_dtype, + seqlen_dtype, + is_varlen, + reverse, +): + data_batch_size = T.dynamic("data_batch_size") + real_batch_size = T.dynamic("real_batch_size") + num_tokens = T.dynamic("num_tokens") + num_chunks = T.dynamic("num_chunks") + block_S = chunk_size + + g_shape = (data_batch_size, num_tokens, H) + + @T.macro + def kernel_body( + bb, + bc, + batch_idx, + chunk_idx, + seq_start_idx, + seq_end_idx, + g_raw, + g_cumsum, + ): + left = seq_start_idx + chunk_idx * block_S + right = left + block_S + + g_fragment = T.alloc_fragment((H, block_S), dtype=accum_dtype) + gT_fragment = T.alloc_fragment((block_S, H), dtype=g_dtype) + gT_shared = T.alloc_shared((block_S, H + 1), dtype=g_dtype) + + if right <= seq_end_idx: + T.copy(g_raw[bb, left:right, 0:H], gT_fragment) + else: + for j, i in T.Parallel(block_S, H): + if left + j < seq_end_idx: + gT_fragment[j, i] = g_raw[bb, left + j, i] + else: + gT_fragment[j, i] = 0 + T.copy(gT_fragment, gT_shared[:, :H]) + + for i, j in T.Parallel(H, block_S): + g_fragment[i, j] = gT_shared[j, i] + + T.cumsum(g_fragment, dim=1, reverse=reverse) + + for i, j in T.Parallel(H, block_S): + gT_shared[j, i] = g_fragment[i, j] + + T.copy(gT_shared[:, :H], gT_fragment) + if right <= seq_end_idx: + T.copy(gT_fragment, g_cumsum[bb, left:right, 0:H]) + else: + for j, i in T.Parallel(block_S, H): + if left + j < seq_end_idx: + g_cumsum[bb, left + j, i] = gT_fragment[j, i] + + if is_varlen: + + @T.prim_func + def tilelang_chunk_local_cumsum_kernel( + g_raw: T.Tensor(g_shape, dtype=g_dtype), + cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype), + chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype), + g_cumsum: T.Tensor(g_shape, dtype=g_dtype), + ): + with T.Kernel(num_chunks, threads=128) as (bc,): + bb = 0 + batch_idx = chunk_indices[bc, 0] + chunk_idx = chunk_indices[bc, 1] + seq_start_idx = cu_seqlens[batch_idx] + seq_end_idx = cu_seqlens[batch_idx + 1] + + kernel_body( + bb, + bc, + batch_idx, + chunk_idx, + seq_start_idx, + seq_end_idx, + g_raw, + g_cumsum, + ) + + else: + + @T.prim_func + def tilelang_chunk_local_cumsum_kernel( + g_raw: T.Tensor(g_shape, dtype=g_dtype), + g_cumsum: T.Tensor(g_shape, dtype=g_dtype), + num_chunks: T.int32, + ): + with T.Kernel(num_chunks, threads=128) as (bc,): + bb = bc % data_batch_size + batch_idx = bb + chunk_idx = bc // data_batch_size + seq_start_idx = 0 + seq_end_idx = num_tokens + + kernel_body( + bb, + bc, + batch_idx, + chunk_idx, + seq_start_idx, + seq_end_idx, + g_raw, + g_cumsum, + ) + + return tilelang_chunk_local_cumsum_kernel + + +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int = 64, + cu_seqlens: torch.LongTensor | None = None, + reverse: bool = False, +): + batch_size, num_tokens, H = g.shape + assert g.stride(-1) == 1 + + if cu_seqlens is None: + num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size) + seqlen_dtype = "int32" + is_varlen = False + else: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + seqlen_dtype = cu_seqlens.dtype + is_varlen = True + + g_cumsum = torch.empty_like(g) + + tilelang_chunk_local_cumsum_kernel = tilelang_chunk_local_cumsum( + H, + chunk_size, + g_dtype=g.dtype, + seqlen_dtype=seqlen_dtype, + accum_dtype="float32", + is_varlen=is_varlen, + reverse=reverse, + ) + if is_varlen: + tilelang_chunk_local_cumsum_kernel(g, cu_seqlens, chunk_indices, g_cumsum) + else: + tilelang_chunk_local_cumsum_kernel(g, g_cumsum, num_chunks) + + return g_cumsum diff --git a/flash_qla/ops/utils/group_reduce.py b/flash_qla/ops/utils/group_reduce.py new file mode 100644 index 0000000..9254083 --- /dev/null +++ b/flash_qla/ops/utils/group_reduce.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang +import tilelang.language as T + + +@tilelang.jit( + # out_idx=[-1], +) +def tilelang_group_reduce_vector( + H, + Hg, + DK, + accum_dtype, + qkva_dtype, + block_size: int = 16, +): + batch_size = T.dynamic("batch_size") + num_tokens = T.dynamic("num_tokens") + + group_size = H // Hg + + buffer_shape = (batch_size, num_tokens, H, DK) + dqk_shape = (batch_size, num_tokens, Hg, DK) + + @T.prim_func + def tilelang_group_reduce_vector_kernel( + buffer: T.Tensor(buffer_shape, dtype=qkva_dtype), + result: T.Tensor(dqk_shape, dtype=qkva_dtype), + ): + with T.Kernel( + tilelang.cdiv(num_tokens, block_size), Hg, batch_size, threads=128 + ) as (bt, bhg, bb): + buffer_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype) + result_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype) + + T.clear(result_fragment) + for i in T.serial(group_size): + T.copy( + buffer[ + bb, + bt * block_size : (bt + 1) * block_size, + bhg * group_size + i, + 0:DK, + ], + buffer_fragment, + ) + for j, k in T.Parallel(block_size, DK): + result_fragment[j, k] += buffer_fragment[j, k] + T.copy( + result_fragment, + result[bb, bt * block_size : (bt + 1) * block_size, bhg, 0:DK], + ) + + return tilelang_group_reduce_vector_kernel + + +def group_reduce_vector( + buffer: torch.Tensor, + Hg: int, +): + batch_size, num_tokens, H, K = buffer.shape + + result = torch.empty( + (batch_size, num_tokens, Hg, K), dtype=buffer.dtype, device=buffer.device + ) + + tilelang_group_reduce_vector_kernel = tilelang_group_reduce_vector( + H, + Hg, + K, + qkva_dtype=buffer.dtype, + accum_dtype="float32", + ) + tilelang_group_reduce_vector_kernel(buffer, result) + + return result diff --git a/flash_qla/utils/__init__.py b/flash_qla/utils/__init__.py new file mode 100644 index 0000000..540805a --- /dev/null +++ b/flash_qla/utils/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +from .profiler import profile +from .pack import pad_and_reshape, pack, unpack, fill_last_chunk_of_g +from .math import l2norm +from .index import prepare_chunk_indices, prepare_chunk_offsets, tensor_cache + + +__all__ = [ + "profile", + "pad_and_reshape", + "pack", + "unpack", + "fill_last_chunk_of_g", + "l2norm", + "prepare_chunk_indices", + "prepare_chunk_offsets", + "tensor_cache", +] diff --git a/flash_qla/utils/index.py b/flash_qla/utils/index.py new file mode 100644 index 0000000..ef03c32 --- /dev/null +++ b/flash_qla/utils/index.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import functools +from typing import Any +from collections import OrderedDict +from collections.abc import Callable + +import torch +import tilelang +import tilelang.language as T + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 256). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + + cache: "OrderedDict[tuple[tuple[int, ...], tuple[tuple[str, int], ...]], tuple[tuple[Any, ...], dict[str, Any], Any]]" = OrderedDict() + cache_size = 256 + + def get_id(x: Any): + if (type(x) is int) or (type(x) is float) or (type(x) is str): + return x + else: + return id(x) + + def make_identity_key( + args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[int, ...], tuple[tuple[str, int], ...]]: + args_key = tuple(get_id(a) for a in args) + kwargs_key = tuple(sorted((k, get_id(v)) for k, v in kwargs.items())) + return args_key, kwargs_key + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache, cache_size + key = make_identity_key(args, kwargs) + if key in cache: + cache.move_to_end(key, last=True) + _, _, cached_result = cache[key] + return cached_result + + result = fn(*args, **kwargs) + cache[key] = (args, kwargs, result) + cache.move_to_end(key, last=True) + if len(cache) > cache_size: + cache.popitem(last=False) + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_chunk_indices( + cu_seqlens: torch.LongTensor, + chunk_size: int, +) -> torch.LongTensor: + # TODO: tilelang kernel + indices = torch.cat( + [ + torch.arange(n) + for n in tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ] + ) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tilelang.jit() +def tilelang_prepare_chunk_offsets( + chunk_size, + block_size, + dtype, +): + batch_size_plus_1 = T.dynamic("batch_size_plus_1") + num_threads = min(max(block_size, 32), 128) + + @T.prim_func + def tilelang_prepare_chunk_offsets_kernel( + cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype), + chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype), + ): + with T.Kernel(1, threads=num_threads) as (bb,): + _batch_size = T.alloc_var("int32") + _batch_size = batch_size_plus_1 - 1 + + seqlen_start_fragment = T.alloc_fragment((block_size), dtype=dtype) + seqlen_end_fragment = T.alloc_fragment((block_size), dtype=dtype) + chunk_offset_fragment = T.alloc_fragment((block_size), dtype=dtype) + + T.copy(cu_seqlens[: batch_size_plus_1 - 1], seqlen_start_fragment) + T.copy(cu_seqlens[1:], seqlen_end_fragment) + + for i in T.Parallel(block_size): + chunk_offset_fragment[i] = ( + seqlen_end_fragment[i] - seqlen_start_fragment[i] + ) + chunk_offset_fragment[i] = ( + chunk_offset_fragment[i] + chunk_size - 1 + ) // chunk_size + T.cumsum(src=chunk_offset_fragment, dim=0) + + chunk_offsets[0] = 0 + T.copy(chunk_offset_fragment, chunk_offsets[1:]) + + return tilelang_prepare_chunk_offsets_kernel + + +@tensor_cache +def prepare_chunk_offsets( + cu_seqlens: torch.LongTensor, + chunk_size: int, +) -> torch.LongTensor: + chunk_offsets = torch.empty_like(cu_seqlens) + tilelang_prepare_chunk_offsets_kernel = tilelang_prepare_chunk_offsets( + chunk_size=chunk_size, + block_size=tilelang.next_power_of_2(cu_seqlens.shape[0] - 1), + dtype=cu_seqlens.dtype, + ) + tilelang_prepare_chunk_offsets_kernel(cu_seqlens, chunk_offsets) + return chunk_offsets, chunk_offsets[-1].item() diff --git a/flash_qla/utils/math.py b/flash_qla/utils/math.py new file mode 100644 index 0000000..b2b5b2a --- /dev/null +++ b/flash_qla/utils/math.py @@ -0,0 +1,21 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch + + +@torch.compile +def l2norm_compiled(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + return (x * inv_norm).to(x.dtype) + + +def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6): + assert dim == -1 + assert x.stride(-1) == 1 + raw_shape = x.shape + x = x.view((-1, raw_shape[-1])) + torch._dynamo.mark_dynamic(x, 0) + y = l2norm_compiled(x, dim, eps) + y = y.view(raw_shape) + return y diff --git a/flash_qla/utils/pack.py b/flash_qla/utils/pack.py new file mode 100644 index 0000000..9353428 --- /dev/null +++ b/flash_qla/utils/pack.py @@ -0,0 +1,83 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch + + +def unpack( + x: torch.Tensor, # [B, T, H] + cu_seqlens: torch.Tensor, +): + assert x.shape[0] == 1 + assert len(cu_seqlens.shape) == 1 + max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + batch_size = cu_seqlens.shape[0] - 1 + y = torch.zeros((batch_size, max_len, *x.shape[2:]), dtype=x.dtype, device=x.device) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + y[i, : end - start] = x[0, start:end] + return y + + +def pack( + x: torch.Tensor, # [B, T, H] + cu_seqlens: torch.Tensor, +): + assert len(cu_seqlens.shape) == 1 + sum_len = cu_seqlens[-1].item() + batch_size = cu_seqlens.shape[0] - 1 + y = torch.empty((1, sum_len, *x.shape[2:]), dtype=x.dtype, device=x.device) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + y[0, start:end] = x[i, : end - start] + return y + + +def pad_and_reshape( + x: torch.Tensor, + dim: int, + chunk_size: int = 64, +): + sequence_length = x.shape[dim] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + zeros = [ + 0, + ] * (2 * (len(x.shape) - 1 - dim)) + padded = torch.nn.functional.pad(x, (*zeros, 0, pad_size)) + return padded.reshape((*x.shape[:dim], -1, chunk_size, *x.shape[dim + 1 :])) + + +def fill_last_chunk_of_g( + g: torch.Tensor, + num_tokens: int, + cu_seqlens: torch.Tensor, + chunk_size: int = 64, + reverse: bool = False, +): + if cu_seqlens is None: + last_chunk_size = num_tokens % chunk_size + if last_chunk_size > 0: + if reverse: + g[:, -1, last_chunk_size - 1] += g[:, -1, -1] + else: + g[:, -1, last_chunk_size:] = g[ + :, -1, last_chunk_size - 1 : last_chunk_size + ] + else: + for i in range(cu_seqlens.shape[0] - 1): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + last_chunk_idx = (end - start) // chunk_size + last_chunk_size = (end - start) % chunk_size + if last_chunk_size > 0: + if reverse: + g[i, last_chunk_idx, last_chunk_size - 1] += g[ + i, last_chunk_idx, -1 + ] + else: + g[i, last_chunk_idx, last_chunk_size:] = g[ + i, last_chunk_idx, last_chunk_size - 1 : last_chunk_size + ] + return g diff --git a/flash_qla/utils/profiler.py b/flash_qla/utils/profiler.py new file mode 100644 index 0000000..2ceb8ef --- /dev/null +++ b/flash_qla/utils/profiler.py @@ -0,0 +1,25 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch +import tilelang + + +def profile(func, inputs, wait: int = 50, warmup: int = 50, rep: int = 100): + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=rep), + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./tb'), + ) as prof: + for idx in range(wait + warmup + rep): + func(*inputs) + prof.step() + # print(prof.key_averages().table(sort_by="cpu_time", row_limit=10)) + result = {x.key: x.device_time * 1e-3 for x in prof.key_averages()} + result["total"] = tilelang.profiler.do_bench( + lambda: func(*inputs), warmup=warmup, rep=rep + ) + return result diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2c6c516 --- /dev/null +++ b/setup.py @@ -0,0 +1,31 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import os +import subprocess +from setuptools import setup, find_packages + +this_dir = os.path.dirname(os.path.abspath(__file__)) + +rev = os.getenv("QLA_VERSION_SUFFIX", "") +if not rev: + try: + cmd = ["git", "rev-parse", "--short", "HEAD"] + rev = "+" + subprocess.check_output(cmd, cwd=this_dir).decode("ascii").rstrip() + except Exception: + rev = "" + +setup( + name="flash_qla", + version="0.1.0" + rev, + description="FlashQLA: Fused TileLang kernels for Linear Attention", + packages=find_packages(), + license="MIT", + python_requires=">=3.10", + install_requires=[ + "torch>=2.8", + "tilelang==0.1.8", + "apache-tvm-ffi==0.1.9", + ], + zip_safe=False, +) diff --git a/tests/ref_gdr.py b/tests/ref_gdr.py new file mode 100644 index 0000000..fde5335 --- /dev/null +++ b/tests/ref_gdr.py @@ -0,0 +1,700 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import torch + +from flash_qla.utils import ( + pad_and_reshape, + pack, + unpack, + fill_last_chunk_of_g, + prepare_chunk_offsets, +) + + +def torch_cumsum( + x: torch.Tensor, # [B, T, H] + cu_seqlens: torch.Tensor = None, + chunk_size: int = 64, + reverse: bool = False, +): + if cu_seqlens is not None: + x = unpack(x, cu_seqlens) + + batch_size, num_tokens, num_heads = x.shape + + x = pad_and_reshape(x, dim=1, chunk_size=chunk_size) + + if reverse: + x = torch.flip(x, dims=(2,)) + x = x.cumsum(dim=2) + x = torch.flip(x, dims=(2,)) + else: + x = x.cumsum(dim=2) + x = x.reshape(batch_size, -1, num_heads) + x = x[:, :num_tokens] + + if cu_seqlens is not None: + x = pack(x, cu_seqlens) + return x + + +def torch_kkt_fwd( + k: torch.Tensor, # [B, T, Hk, K] + g: torch.Tensor, # [B, T, Hv] + beta: torch.Tensor, # [B, T, Hv] + cu_seqlens: torch.Tensor = None, + chunk_size: int = 64, +): + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + g = unpack(g, cu_seqlens) + beta = unpack(beta, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim = k.shape + num_v_heads = g.shape[-1] + + if num_k_heads != num_v_heads: + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, H, K] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, H] + beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, H] + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device) + ) + decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0) + # decay_mask = torch.where(mask[None, None, :, :, None], decay_mask, 0.0) + attn = torch.einsum( + "bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k + ) * decay_mask.swapaxes(-2, -1) # [B, N, C, H, D] + attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens] + + if cu_seqlens is not None: + attn = pack(attn, cu_seqlens) + return attn + + +def torch_solve( + x: torch.Tensor, # [B, T, H, D] + cu_seqlens: torch.Tensor = None, +): + if cu_seqlens is not None: + x = unpack(x, cu_seqlens) + + batch_size, num_tokens, num_heads, chunk_size = x.shape + + x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size).swapaxes( + 2, 3 + ) # [B, N, H, C, D] + + for i in range(1, chunk_size): + row = x[..., i, :i].clone() + sub = x[..., :i, :i].clone() + x[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + x += torch.eye(chunk_size, dtype=x.dtype, device=x.device) + x = x.swapaxes(2, 3).reshape((batch_size, -1, num_heads, chunk_size))[ + :, :num_tokens + ] + + if cu_seqlens is not None: + x = pack(x, cu_seqlens) + return x + + +def torch_w_u_fwd( + k: torch.Tensor, # [B, T, Hk, K] + v: torch.Tensor, # [B, T, Hv, V] + g: torch.Tensor, # [B, T, Hv] + beta: torch.Tensor, # [B, T, Hv] + A: torch.Tensor, # [B, T, Hv, D] + cu_seqlens: torch.Tensor = None, +): + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + A = unpack(A, cu_seqlens) + beta = unpack(beta, cu_seqlens) + g = unpack(g, cu_seqlens) + + batch_size, num_tokens, _, chunk_size = A.shape + _, _, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = v.shape + + if num_k_heads != num_v_heads: + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + k_beta = pad_and_reshape( + k * beta.unsqueeze(-1) * g.exp().unsqueeze(-1), dim=1, chunk_size=chunk_size + ) # [B, N, C, Hv, K] + v_beta = pad_and_reshape( + v * beta.unsqueeze(-1), dim=1, chunk_size=chunk_size + ) # [B, N, C, Hv, V] + A = pad_and_reshape(A, dim=1) + + w = torch.einsum("bnchd, bndhk -> bnchk", A, k_beta).reshape( + (batch_size, -1, num_v_heads, head_dim_k) + )[:, :num_tokens] + u = torch.einsum("bnchd, bndhk -> bnchk", A, v_beta).reshape( + (batch_size, -1, num_v_heads, head_dim_v) + )[:, :num_tokens] + + if cu_seqlens is not None: + w = pack(w, cu_seqlens) + u = pack(u, cu_seqlens) + return w, u + + +def torch_chunk_gdr_fwd( + k: torch.Tensor, # [B, T, Hk, K] + w: torch.Tensor, # [B, T, Hv, K] + u: torch.Tensor, # [B, T, Hv, V] + g: torch.Tensor, # [B, T, Hv] + initial_state: torch.Tensor = None, # [B, Hv, K, V] + cu_seqlens: torch.Tensor = None, + chunk_size: int = 64, +): + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + w = unpack(w, cu_seqlens) + u = unpack(u, cu_seqlens) + g = unpack(g, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = u.shape + + if num_k_heads != num_v_heads: + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + u = pad_and_reshape(u, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + if initial_state is None: + last_state = torch.zeros( + (batch_size, num_v_heads, head_dim_k, head_dim_v), + dtype=g.dtype, + device=g.device, + ) + else: + last_state = initial_state.to(g.dtype, copy=True) + + h, vn = [], [] + for i in range(k.shape[1]): + h.append(last_state) + v_new = u[:, i] - torch.einsum("bchk, bhkv -> bchv", w[:, i], last_state) + vn.append(v_new) + last_state = last_state * g[:, i, -1, :, None, None].exp() + last_state = last_state + torch.einsum( + "bchk, bchv -> bhkv", + k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(), + v_new, + ) + h = torch.stack(h, dim=1).contiguous() + vn = ( + torch.stack(vn, dim=1) + .reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens] + .contiguous() + ) + + if cu_seqlens is not None: + vn = pack(vn, cu_seqlens) + h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size)) + + return h, vn, last_state + + +def torch_chunk_o_fwd( + q: torch.Tensor, # [B, T, Hk, K] + k: torch.Tensor, # [B, T, Hk, K] + v: torch.Tensor, # [B, T, Hv, K] + h: torch.Tensor, # [B, N, Hv, K, V] + g: torch.Tensor, # [B, T, Hv] + cu_seqlens: torch.Tensor = None, + scale: float = None, + chunk_size: int = 64, +): + if cu_seqlens is not None: + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + g = unpack(g, cu_seqlens) + h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size)) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = v.shape + + if num_k_heads != num_v_heads: + q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2) + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + scale = scale or head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + + q = q * scale + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), + diagonal=1, + ) + decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = decay_mask.masked_fill( + mask[None, None, :, :, None], 0.0 + ) # [B, N, C, D, Hv] + + attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask + attn_inter = torch.einsum("bnchk, bnhkv -> bnchv", q * g.exp().unsqueeze(-1), h) + o = attn_inter + torch.einsum("bncdh, bndhv -> bnchv", attn, v) + + o = o.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens] + if cu_seqlens is not None: + o = pack(o, cu_seqlens) + return o + + +def torch_chunk_dv_bwd( + q: torch.Tensor, # [B, T, Hk, K] + k: torch.Tensor, # [B, T, Hk, K] + g: torch.Tensor, # [B, T, Hv] + do: torch.Tensor, # [B, T, Hv, V] + cu_seqlens: torch.Tensor = None, + scale: float = None, + chunk_size: int = 64, +): + if cu_seqlens is not None: + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2) + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + scale = scale or head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + + q = q * scale + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), + diagonal=1, + ) + decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = decay_mask.masked_fill( + mask[None, None, :, :, None], 0.0 + ) # [B, N, C, D, Hv] + + attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask + dv = torch.einsum("bncdh, bnchv -> bndhv", attn, do) + + dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens] + if cu_seqlens is not None: + dv = pack(dv, cu_seqlens) + return dv + + +def torch_chunk_gdr_bwd( + q: torch.Tensor, # [B, T, Hk, K] + k: torch.Tensor, # [B, T, Hk, K] + w: torch.Tensor, # [B, T, Hv, K] + g: torch.Tensor, # [B, T, Hv] + do: torch.Tensor, # [B, T, Hv, V] + dv: torch.Tensor, # [B, T, Hv, V] + h0: torch.Tensor = None, # [B, Hv, K, V] + dht: torch.Tensor = None, # [B, Hv, K, V] + cu_seqlens: torch.Tensor = None, + scale: float = None, + chunk_size: int = 64, +): + if cu_seqlens is not None: + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + w = unpack(w, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + dv = unpack(dv, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2) + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + scale = scale or head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + q = q * scale + + if dht is None: + dstate = torch.zeros( + (batch_size, num_v_heads, head_dim_k, head_dim_v), + dtype=g.dtype, + device=g.device, + ) + else: + dstate = dht.to(g.dtype, copy=True) + dstate_inter = torch.einsum("bnchk, bnchv -> bnhkv", q * g.exp().unsqueeze(-1), do) + + dh = [] + for i in reversed(range(k.shape[1])): + dh.insert(0, dstate) + dv[:, i] += torch.einsum( + "bchk, bhkv -> bchv", + k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(), + dstate, + ) + dstate = dstate * g[:, i, -1, :, None, None].exp() + dstate = ( + dstate + + dstate_inter[:, i] + - torch.einsum("bchk, bchv -> bhkv", w[:, i], dv[:, i]) + ) + dh = torch.stack(dh, dim=1).contiguous() + + dh0 = None if h0 is None else dstate + dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens] + if cu_seqlens is not None: + dv = pack(dv, cu_seqlens) + dh = pack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size)) + return dh, dh0, dv + + +def torch_chunk_dqkwg_bwd( + q: torch.Tensor, # [B, T, Hk, K] + k: torch.Tensor, # [B, T, Hk, K] + v: torch.Tensor, # [B, T, Hv, V] + w: torch.Tensor, # [B, T, Hv, K] + g: torch.Tensor, # [B, T, Hv] + h: torch.Tensor, # [B, N, Hv, K, V] + dv: torch.Tensor, # [B, T, Hv, V] + do: torch.Tensor, # [B, T, Hv, V] + dh: torch.Tensor, # [B, N, Hv, K, V] + cu_seqlens: torch.Tensor = None, + scale: float = None, + chunk_size: int = 64, +): + if cu_seqlens is not None: + q = unpack(q, cu_seqlens) + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + w = unpack(w, cu_seqlens) + g = unpack(g, cu_seqlens) + do = unpack(do, cu_seqlens) + dv = unpack(dv, cu_seqlens) + h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size)) + dh = unpack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size)) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = do.shape + + if num_k_heads != num_v_heads: + q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2) + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + scale = scale or head_dim_k ** (-0.5) + + q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device), + diagonal=1, + ) + decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = decay_mask.masked_fill( + mask[None, None, :, :, None], 0.0 + ) # [B, N, C, D, Hv] + + dg_last = (h * dh).sum(dim=-1).sum(dim=-1) # [B, N, Hv] + ds = torch.einsum("bnchv, bndhv -> bncdh", do, v) + dq = torch.einsum("bnchv, bnhkv -> bnchk", do, h) + dk = torch.einsum("bnchv, bnhkv -> bnchk", v, dh) + dw = -torch.einsum("bnchv, bnhkv -> bnchk", dv, h) + + g_last = g[:, :, -1] + dg_last *= g_last.exp() + dq = dq * g.unsqueeze(-1).exp() * scale + dg = (q * dq).sum(dim=-1) # [B, N, C, Hv] + dk = dk * (g_last.unsqueeze(-2) - g).unsqueeze(-1).exp() + dg -= (k * dk).sum(dim=-1) + dg_last += (k * dk).sum(dim=-1).sum(dim=-2) + ds *= decay_mask * scale + ds2 = ds * torch.einsum("bnchk, bndhk -> bncdh", q, k) + dg += ds2.sum(dim=-2) + dg -= ds2.sum(dim=-3) + dq += torch.einsum("bncdh, bndhk -> bnchk", ds, k) + dk += torch.einsum("bncdh, bnchk -> bndhk", ds, q) + dg[:, :, -1] += dg_last + + dg = fill_last_chunk_of_g( + dg, num_tokens, cu_seqlens, chunk_size=chunk_size, reverse=True + ) + dq = dq.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens] + dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens] + dw = dw.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens] + dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens] + if cu_seqlens is not None: + dq = pack(dq, cu_seqlens) + dk = pack(dk, cu_seqlens) + dw = pack(dw, cu_seqlens) + dg = pack(dg, cu_seqlens) + return dq, dk, dw, dg + + +def torch_chunk_wy_bwd( + k: torch.Tensor, # [B, T, Hk, K] + v: torch.Tensor, # [B, T, Hv, V] + beta: torch.Tensor, # [B, T, Hv] + A: torch.Tensor, # [B, T, Hv, D] + g: torch.Tensor, # [B, T, Hv] + dw: torch.Tensor, # [B, T, Hv, K] + du: torch.Tensor, # [B, T, Hv, V] + dk1: torch.Tensor, # [B, T, Hv, K] + dg1: torch.Tensor, # [B, T, Hv] + cu_seqlens: torch.Tensor = None, +): + if cu_seqlens is not None: + k = unpack(k, cu_seqlens) + v = unpack(v, cu_seqlens) + beta = unpack(beta, cu_seqlens) + A = unpack(A, cu_seqlens) + g = unpack(g, cu_seqlens) + dw = unpack(dw, cu_seqlens) + du = unpack(du, cu_seqlens) + dk1 = unpack(dk1, cu_seqlens) + dg1 = unpack(dg1, cu_seqlens) + + batch_size, num_tokens, num_k_heads, head_dim_k = k.shape + _, _, num_v_heads, head_dim_v = v.shape + chunk_size = A.shape[-1] + + if num_k_heads != num_v_heads: + k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) + + k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + A = pad_and_reshape(A, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, D] + g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + du = pad_and_reshape(du, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V] + dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K] + dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size) # [B, N, C, Hv] + + dA = torch.einsum("bnchk, bndhk -> bnchd", dw, k * (beta * g.exp()).unsqueeze(-1)) + dk_beta_g = torch.einsum("bnchd, bnchk -> bndhk", A, dw) + dk = dk_beta_g * (beta * g.exp()).unsqueeze(-1) + db = (dk_beta_g * k * g.exp().unsqueeze(-1)).sum(dim=-1) + dg = (dk_beta_g * k * (g.exp() * beta).unsqueeze(-1)).sum(dim=-1) + + dA += torch.einsum("bnchv, bndhv -> bnchd", du, v * beta.unsqueeze(-1)) + dv_beta = torch.einsum("bnchd, bnchv -> bndhv", A, du) + dv = dv_beta * beta.unsqueeze(-1) + db += (dv_beta * v).sum(dim=-1) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device) + ) + decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :]) + decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0).swapaxes( + -2, -1 + ) + dA = dA.masked_fill(mask[None, None, :, None, :], 0.0) + dA = torch.einsum("bndhc, bndhe -> bnche", A, dA) + dA = torch.einsum("bnchd, bnehd -> bnche", dA, A) + dA = -dA * decay_mask + + A = torch.einsum("bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k) + dk_beta = torch.einsum("bnchd, bndhk -> bnchk", dA, k) + db += (dk_beta * k).sum(dim=-1) + dk += torch.einsum("bnchd, bnchk -> bndhk", dA, k * beta.unsqueeze(-1)) + dk += dk_beta * beta.unsqueeze(-1) + dk += dk1 + + dg += (dA * A).sum(dim=-1) - (dA * A).sum(dim=-3).swapaxes(-1, -2) + dg += dg1 + + # TODO: NOTE: GVA + dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens] + dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens] + db = db.reshape((batch_size, -1, num_v_heads))[:, :num_tokens] + dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens] + if cu_seqlens is not None: + dk = pack(dk, cu_seqlens) + dv = pack(dv, cu_seqlens) + db = pack(db, cu_seqlens) + dg = pack(dg, cu_seqlens) + return dk, dv, db, dg + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, # [B, T, Hk, K] + k: torch.Tensor, # [B, T, Hk, K] + v: torch.Tensor, # [B, T, Hv, K] + g: torch.Tensor, # [B, T, Hv] + beta: torch.Tensor, # [B, T, Hv] + cu_seqlens: torch.Tensor = None, + initial_state: torch.Tensor = None, + scale: float = None, + chunk_size: int = 64, +): + scale = scale or q.shape[-1] ** (-0.5) + g = torch_cumsum( + x=g, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + A = torch_kkt_fwd( + k=k, + g=g, + beta=beta, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + A = torch_solve( + x=A, + cu_seqlens=cu_seqlens, + ) + w, u = torch_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, vn, final_state = torch_chunk_gdr_fwd( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + o = torch_chunk_o_fwd( + q=q, + k=k, + v=vn, + h=h, + g=g, + cu_seqlens=cu_seqlens, + scale=scale, + chunk_size=chunk_size, + ) + return g, o, A, h, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + cu_seqlens: torch.Tensor = None, + chunk_size: int = 64, +): + w, u = torch_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g=g, + cu_seqlens=cu_seqlens, + ) + h, vn, _ = torch_chunk_gdr_fwd( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + cu_seqlens=cu_seqlens, + chunk_size=chunk_size, + ) + dv = torch_chunk_dv_bwd( + q=q, + k=k, + g=g, + do=do, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dh, dh0, dv = torch_chunk_gdr_bwd( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dq, dk1, dw, dg1 = torch_chunk_dqkwg_bwd( + q=q, + k=k, + v=vn, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + cu_seqlens=cu_seqlens, + ) + dk, dv, db, dg = torch_chunk_wy_bwd( + k=k, + v=v, + beta=beta, + g=g, + A=A, + dw=dw, + du=dv, + dk1=dk1, + dg1=dg1, + cu_seqlens=cu_seqlens, + ) + Hg, H = k.shape[-2], v.shape[-2] + if Hg < H: + B, T, _, K = dq.shape + dq = torch.sum(dq.reshape(B, T, Hg, -1, K), dim=3) + dk = torch.sum(dk.reshape(B, T, Hg, -1, K), dim=3) + dg = torch_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens) + return dq, dk, dv, db, dg, dh0 diff --git a/tests/settings/develop.csv b/tests/settings/develop.csv new file mode 100644 index 0000000..56b190e --- /dev/null +++ b/tests/settings/develop.csv @@ -0,0 +1,2 @@ +batch_size,num_tokens,varlen +1,32768,False diff --git a/tests/settings/product.csv b/tests/settings/product.csv new file mode 100644 index 0000000..7a79f0e --- /dev/null +++ b/tests/settings/product.csv @@ -0,0 +1,9 @@ +batch_size,num_tokens,varlen,cu_seqlens +1,16384,True,0-268-1139-1179-1212-1476-2792-3202-3611-3726-3820-4096-4882-6417-8130-8192-8328-9426-10473-11002-11754-12288-14085-15370-16384 +1,16384,True,0-3393-4096-4153-5636-5853-6318-6777-8192-8320-8931-9163-9494-10040-10113-10363-10561-11061-11388-11634-12288-14545-16288-16384 +1,16384,True,0-4096-6111-6485-6589-7118-8192-9056-10448-12288-14032-14525-15884-16012-16384 +1,16384,True,0-177-4096-8192-12288-12805-13171-13298-16055-16384 +1,16384,True,0-308-1128-1678-4096-4748-8192-8506-9657-10252-12113-12288-16384 +1,16384,True,0-4096-6893-7665-8192-12288-16384 +1,16384,True,0-410-841-1135-2126-2512-4096-4682-5022-5375-6259-6335-6580-6648-7308-8192-10450-12058-12288-14215-15280-15701-16384 +1,16384,True,0-2048-4096-6144-8192-10240-12288-14336-16384 diff --git a/tests/settings/profile.csv b/tests/settings/profile.csv new file mode 100644 index 0000000..8ed51f6 --- /dev/null +++ b/tests/settings/profile.csv @@ -0,0 +1,5 @@ +batch_size,num_tokens,varlen +1,4096,False +1,8192,False +1,16384,False +1,32768,False diff --git a/tests/settings/varlen.csv b/tests/settings/varlen.csv new file mode 100644 index 0000000..ab86abc --- /dev/null +++ b/tests/settings/varlen.csv @@ -0,0 +1,7 @@ +batch_size,num_tokens,varlen +11,33,False +7,4321,False +3,16789,True +5,8192,True +10,1024,True +20,512,True diff --git a/tests/test_function_signature.py b/tests/test_function_signature.py new file mode 100644 index 0000000..e35a714 --- /dev/null +++ b/tests/test_function_signature.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +"""Static checks for autograd ``Function`` signatures. + +PyTorch validates that ``Function.backward`` returns exactly as many gradients +as ``Function.forward`` received non-``ctx`` inputs; mismatches raise at +``.backward()`` time. ``tests/test_gdr.py`` invokes ``chunk_gated_delta_rule_fwd`` +and ``chunk_gated_delta_rule_bwd`` directly, bypassing the autograd path, so +drift between the forward signature and the backward return tuple goes +uncaught by the existing suite. + +These tests parse the source files with ``ast`` instead of importing the +modules so they run on CPU-only / non-Hopper machines. +""" + +import ast +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[1] +CHUNK_INIT = "flash_qla/ops/gated_delta_rule/chunk/__init__.py" + + +def _parse(rel_path: str) -> ast.Module: + return ast.parse((REPO_ROOT / rel_path).read_text(encoding="utf-8")) + + +def _get_class(module: ast.Module, name: str) -> ast.ClassDef: + for node in module.body: + if isinstance(node, ast.ClassDef) and node.name == name: + return node + raise AssertionError(f"class {name!r} not found") + + +def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef: + for node in cls.body: + if isinstance(node, ast.FunctionDef) and node.name == name: + return node + raise AssertionError(f"method {name!r} not found on {cls.name}") + + +def test_chunk_gated_delta_rule_grad_count_matches_forward_inputs(): + """``backward`` must return one gradient per non-``ctx`` input of ``forward``.""" + module = _parse(CHUNK_INIT) + cls = _get_class(module, "ChunkGatedDeltaRuleFunction") + + fwd = _get_method(cls, "forward") + fwd_args = fwd.args.args + assert fwd_args and fwd_args[0].arg == "ctx", ( + "forward must take `ctx` as its first argument" + ) + n_inputs = len(fwd_args) - 1 # exclude ctx + + bwd = _get_method(cls, "backward") + returns = [n for n in ast.walk(bwd) if isinstance(n, ast.Return)] + assert len(returns) == 1, f"expected one Return in backward, got {len(returns)}" + assert isinstance(returns[0].value, ast.Tuple), ( + "backward must return a tuple literal" + ) + n_grads = len(returns[0].value.elts) + + assert n_inputs == n_grads, ( + f"backward returns {n_grads} gradients but forward takes {n_inputs} non-ctx " + f"inputs; PyTorch will raise a count-mismatch error at .backward() time." + ) + + +if __name__ == "__main__": + test_chunk_gated_delta_rule_grad_count_matches_forward_inputs() + print("OK") diff --git a/tests/test_gdr.py b/tests/test_gdr.py new file mode 100644 index 0000000..0582090 --- /dev/null +++ b/tests/test_gdr.py @@ -0,0 +1,553 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import argparse +import math + +import torch +import pandas as pd + +# Requires flash-linear-attention==0.5.0 +from fla.ops.gated_delta_rule.chunk import ( + chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_fla, +) +from fla.ops.gated_delta_rule.chunk import ( + chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_fla, +) + +from flash_qla import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_qla +from flash_qla import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_qla +from flash_qla.utils import l2norm, pack, profile + +from ref_gdr import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_ref +from ref_gdr import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_ref + + +def test_gated_delta_rule( + batch_size: int, + num_tokens: int, + num_k_heads: int, + num_v_heads: int, + head_dim_k: int, + head_dim_v: int, + varlen: bool = False, + cu_seqlens: list[int] | None = None, + use_h0: bool = False, + chunk_size: int = 64, + data_dtype: str = "bfloat16", + ref_dtype: str = "float32", + device: torch.device = "cuda", + random_seed: int = 42, + check_accuracy: bool = True, + show_speedup: bool = True, + auto_cp: bool = True, + swa_ratio: float = 0.75, + skip_bwd: bool = False, +): + data_dtype = getattr(torch, data_dtype) + ref_dtype = getattr(torch, ref_dtype) + torch.manual_seed(random_seed) + q = l2norm( + torch.randn( + (batch_size, num_tokens, num_k_heads, head_dim_k), + device=device, + dtype=data_dtype, + ) + ) + k = l2norm( + torch.randn( + (batch_size, num_tokens, num_k_heads, head_dim_k), + device=device, + dtype=data_dtype, + ) + ) + v = torch.randn( + (batch_size, num_tokens, num_v_heads, head_dim_v), + device=device, + dtype=data_dtype, + ) + g = ( + torch.nn.functional.logsigmoid( + torch.randn( + (batch_size, num_tokens, num_v_heads), + device=device, + dtype=torch.float32, + ) + ) + / 16 + ) + beta = torch.randn( + (batch_size, num_tokens, num_v_heads), device=device, dtype=torch.float32 + ).sigmoid() + h0 = ( + torch.randn( + (batch_size, num_v_heads, head_dim_k, head_dim_v), + device=device, + dtype=torch.float32, + ) + if use_h0 + else None + ) + do = torch.randn_like(v) + dht = ( + torch.randn( + (batch_size, num_v_heads, head_dim_k, head_dim_v), + device=device, + dtype=torch.float32, + ) + / 8 + if use_h0 + else None + ) + scale = head_dim_k ** (-0.5) + print( + f"Shape: B={batch_size} Hk={num_k_heads} Hv={num_v_heads} T={num_tokens} VarLen={varlen}" + ) + + swa_mask = torch.zeros((num_v_heads), dtype=torch.bool, device=device) + swa_mask[: math.ceil(swa_ratio * num_v_heads)] = 1 + swa_mask = swa_mask[torch.randperm(num_v_heads, device=device)] + g[:, :, ~swa_mask] = 0.0 + print(f"SWA Mask: {swa_mask.to(torch.int32, copy=True).tolist()}") + + if varlen: + if cu_seqlens is None: + cu_seqlens = torch.randint( + 1, num_tokens, (batch_size,), device=device, dtype=torch.int32 + ) + cu_seqlens = torch.nn.functional.pad( + torch.cumsum(cu_seqlens, dim=-1), (1, 0) + ) + q = pack(q, cu_seqlens) + k = pack(k, cu_seqlens) + v = pack(v, cu_seqlens) + g = pack(g, cu_seqlens) + beta = pack(beta, cu_seqlens) + do = pack(do, cu_seqlens) + else: + assert batch_size == 1 + assert cu_seqlens[0] == 0 + assert cu_seqlens[-1] == num_tokens + cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) + if use_h0: + real_batch_size = cu_seqlens.shape[0] - 1 + h0 = torch.randn( + (real_batch_size, num_v_heads, head_dim_k, head_dim_v), + device=device, + dtype=torch.float32, + ) + dht = ( + torch.randn( + (real_batch_size, num_v_heads, head_dim_k, head_dim_v), + device=device, + dtype=torch.float32, + ) + / 8 + ) + assert (cu_seqlens[1:] - cu_seqlens[:-1]).min() > 0 + else: + cu_seqlens = None + + g_ref, o_ref, A_ref, h_ref, s_ref = chunk_gated_delta_rule_fwd_ref( + q=q.to(ref_dtype, copy=True), + k=k.to(ref_dtype, copy=True), + v=v.to(ref_dtype, copy=True), + g=g.to(ref_dtype, copy=True), + beta=beta.to(ref_dtype, copy=True), + scale=scale, + initial_state=h0, + cu_seqlens=cu_seqlens, + ) + g_fla, o_fla, A_fla, s_fla, _, _ = chunk_gated_delta_rule_fwd_fla( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + ) + g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=h0, + cu_seqlens=cu_seqlens, + output_final_state=True, + output_h=True, + auto_cp=auto_cp, + ) + + if check_accuracy: + print( + f"h_qla: {(h_qla - h_ref).abs().max().item():.4f} / {h_ref.abs().max().item():.4f}" + ) + print( + f"s_fla: {(s_fla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}" + ) + print( + f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}" + ) + print( + f"o_fla: {(o_fla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}" + ) + print( + f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}" + ) + + for _ in range(1000): + g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla( + q, + k, + v, + g, + beta, + scale, + h0, + cu_seqlens, + True, + False, + auto_cp, + ) + try: + if h0 is not None: + assert ( + s_qla - s_ref + ).abs().max().item() <= s_ref.abs().max().item() * 0.02 + assert ( + o_qla - o_ref + ).abs().max().item() <= o_ref.abs().max().item() * 0.02 + except AssertionError as e: + print("********** ERROR **********") + if h0 is not None: + print( + f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}" + ) + print( + f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}" + ) + print("********** ERROR **********") + raise e + + if show_speedup: + prof_fla = profile( + chunk_gated_delta_rule_fwd_fla, + [q, k, v, g, beta, scale, h0, True, cu_seqlens], + ) + prof_qla = profile( + chunk_gated_delta_rule_fwd_qla, + [q, k, v, g, beta, scale, h0, cu_seqlens, True, False, auto_cp], + ) + result_fla = { + "[fwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"], + "[fwd] solve": prof_fla["chunk_gated_delta_rule_fwd_kkt_solve_kernel"], + "[fwd] wu": prof_fla["recompute_w_u_fwd_kernel"], + "[fwd] gdr": prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"], + "[fwd] o": prof_fla["chunk_fwd_kernel_o"], + } + result_qla = { + "[fwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"], + "[fwd] solve": prof_qla["tilelang_kkt_solve_kernel_kernel"], + "[fwd] gdr": prof_qla["tilelang_fused_chunk_gdr_fwd_kernel_kernel"], + } + if "tilelang_get_warmup_chunks_kernel_kernel" in prof_qla.keys(): + result_fla["[fwd] cp-w"] = None + result_fla["[fwd] cp-h"] = None + result_fla["[fwd] cp-c"] = None + result_qla["[fwd] cp-w"] = prof_qla[ + "tilelang_get_warmup_chunks_kernel_kernel" + ] + result_qla["[fwd] cp-h"] = prof_qla["tilelang_prepare_h_kernel_kernel"] + result_qla["[fwd] cp-c"] = prof_qla["tilelang_correct_h0_kernel_kernel"] + result_fla["total"] = prof_fla["total"] + result_qla["total"] = prof_qla["total"] + results = { + "fla": result_fla, + "flash_qla": result_qla, + } + df = pd.DataFrame(results) + print(df.round(3)) + speedup = results["fla"]["total"] / results["flash_qla"]["total"] + print(f"Speed up: {speedup:.2f}x") + + if skip_bwd: + return + + dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd_ref( + q.to(ref_dtype, copy=True), + k.to(ref_dtype, copy=True), + v.to(ref_dtype, copy=True), + g_ref, + beta.to(ref_dtype, copy=True), + A_ref.to(ref_dtype, copy=True), + scale, + h0, + do.to(ref_dtype, copy=True), + dht, + cu_seqlens, + ) + dq_fla, dk_fla, dv_fla, db_fla, dg_fla, dh0_fla, _, _ = ( + chunk_gated_delta_rule_bwd_fla( + q, + k, + v, + g_fla, + beta, + A_fla, + scale, + h0, + do, + dht, + cu_seqlens, + ) + ) + dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = chunk_gated_delta_rule_bwd_qla( + q, + k, + v, + g_qla, + beta, + A_qla, + do, + dht, + scale, + h0, + cu_seqlens, + ) + + if check_accuracy: + print( + f"dq_fla: {(dq_fla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}" + ) + print( + f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}" + ) + print( + f"dk_fla: {(dk_fla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}" + ) + print( + f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}" + ) + print( + f"dv_fla: {(dv_fla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}" + ) + print( + f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}" + ) + if dht is not None: + print( + f"dh0_fla: {(dh0_fla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}" + ) + print( + f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}" + ) + print( + f"db_fla: {(db_fla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}" + ) + print( + f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}" + ) + print( + f"dg_fla: {(dg_fla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}" + ) + print( + f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}" + ) + + for _ in range(1000): + dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = ( + chunk_gated_delta_rule_bwd_qla( + q, + k, + v, + g_qla, + beta, + A_qla, + do, + dht, + scale, + h0, + cu_seqlens, + ) + ) + try: + assert ( + dq_qla - dq_ref + ).abs().max().item() <= dq_ref.abs().max().item() * 0.02 + assert ( + dk_qla - dk_ref + ).abs().max().item() <= dk_ref.abs().max().item() * 0.02 + assert ( + dv_qla - dv_ref + ).abs().max().item() <= dv_ref.abs().max().item() * 0.02 + assert ( + dg_qla - dg_ref + ).abs().max().item() <= dg_ref.abs().max().item() * 0.02 + assert ( + db_qla - db_ref + ).abs().max().item() <= db_ref.abs().max().item() * 0.02 + if dht is not None: + assert ( + dh0_qla - dh0_ref + ).abs().max().item() <= dh0_ref.abs().max().item() * 0.02 + except AssertionError as e: + print("********** ERROR **********") + print( + f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}" + ) + print( + f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}" + ) + print( + f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}" + ) + if dht is not None: + print( + f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}" + ) + print( + f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}" + ) + print( + f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}" + ) + print("********** ERROR **********") + raise e + + if show_speedup: + prof_fla = profile( + chunk_gated_delta_rule_bwd_fla, + [q, k, v, g_fla, beta, A_fla, scale, h0, do, dht, cu_seqlens], + ) + prof_qla = profile( + chunk_gated_delta_rule_bwd_qla, + [q, k, v, g_qla, beta, A_qla, do, dht, scale, h0, cu_seqlens], + ) + result_fla = { + "[bwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"], + "[bwd] recom": prof_fla["recompute_w_u_fwd_kernel"] + + prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"], + "[bwd] dv": prof_fla["chunk_bwd_kernel_dv_local"], + "[bwd] gdr": prof_fla["chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64"], + "[bwd] dqkwg": prof_fla["kernel_kernel"], + "[bwd] wy": prof_fla["prepare_wy_repr_bwd_kernel"], + } + result_qla = { + "[bwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"], + "[bwd] recom": prof_qla["tilelang_prepare_h_kernel_kernel"], + "[bwd] gdr": prof_qla["tilelang_fused_chunk_gdr_bwd_kernel_kernel"], + } + if num_k_heads < num_v_heads: + result_fla["[bwd] reduc"] = prof_fla["compress_heads_kernel"] + result_qla["[bwd] reduc"] = ( + prof_qla["tilelang_group_reduce_vector_kernel_kernel"] * 2 + ) + result_fla["total"] = prof_fla["total"] + result_qla["total"] = prof_qla["total"] + results = { + "fla": result_fla, + "flash_qla": result_qla, + } + df = pd.DataFrame(results) + print(df.round(3)) + speedup = results["fla"]["total"] / results["flash_qla"]["total"] + print(f"Speed up: {speedup:2.2f}x") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test Gated Delta Rule") + parser.add_argument( + "--set", + type=str, + default="develop", + help="Preset name (loads from settings/{set}.csv)", + ) + parser.add_argument( + "--seqlen", "--num-tokens", type=int, default=16384, help="Sequence Length" + ) + parser.add_argument( + "--nkh", + "--num-k-heads", + type=int, + default=0, + help="Number of K heads (num_k_heads)", + ) + parser.add_argument( + "--nvh", + "--num-heads", + "--num-v-heads", + type=int, + default=64, + help="Number of V heads (num_v_heads)", + ) + parser.add_argument( + "--no-h0", + action="store_true", + help="Disable initial state and gradient of final state", + ) + parser.add_argument("--skip-bwd", action="store_true", help="Test forward only") + parser.add_argument( + "--no-cp", + "--disable-auto-cp", + action="store_true", + help="Disable auto intra-card CP", + ) + parser.add_argument( + "--swa-ratio", type=float, default=0.75, help="Ratio of sliding-window heads" + ) + parser.add_argument( + "--data-dtype", + type=str, + default="bfloat16", + help="Data type for input and output", + ) + parser.add_argument( + "--ref-dtype", type=str, default="float64", help="Data type for reference" + ) + parser.add_argument("--hide-acc", action="store_true", help="Do not print accuracy") + parser.add_argument("--hide-lat", action="store_true", help="Do not print latency") + parser.add_argument( + "--seed", "--random-seed", type=int, default=42, help="Random seed" + ) + args = parser.parse_args() + + if args.nkh <= 0: + args.nkh = args.nvh + + metadata = { + "head_dim_k": 128, # MUST BE 128 + "head_dim_v": 128, # MUST BE 128 + "chunk_size": 64, # MUST BE 64 + "num_tokens": args.seqlen, + "num_k_heads": args.nkh, + "num_v_heads": args.nvh, + "use_h0": not args.no_h0, + "data_dtype": args.data_dtype, + "ref_dtype": args.ref_dtype, + "check_accuracy": not args.hide_acc, + "show_speedup": not args.hide_lat, + "skip_bwd": args.skip_bwd, + "auto_cp": not args.no_cp, + "swa_ratio": args.swa_ratio, + "random_seed": args.seed, + "device": "cuda", + } + + import os + + script_dir = os.path.dirname(os.path.abspath(__file__)) + preset = pd.read_csv(os.path.join(script_dir, "settings", f"{args.set}.csv")) + for i, row in preset.iterrows(): + print("-" * 64) + torch.cuda.empty_cache() + data = row.to_dict() + if "cu_seqlens" in data.keys(): + data["cu_seqlens"] = list(map(int, data["cu_seqlens"].split("-"))) + metadata.update(data) + test_gated_delta_rule(**metadata) + print("-" * 64) diff --git a/tests/test_legacy_sm_gdn.py b/tests/test_legacy_sm_gdn.py new file mode 100644 index 0000000..a730b18 --- /dev/null +++ b/tests/test_legacy_sm_gdn.py @@ -0,0 +1,57 @@ +# Copyright (c) 2026 The Qwen team, Alibaba Group. +# Licensed under The MIT License [see LICENSE for details] + +import math + +import pytest +import torch + +from flash_qla.ops.gated_delta_rule.legacy import chunk_gated_delta_rule_fwd_legacy + + +def _reference(q, k, v, g, beta, scale=None, initial_state=None): + batch, tokens, q_heads, dim = q.shape + v_heads = v.shape[2] + scale = scale if scale is not None else dim**-0.5 + state = ( + initial_state.clone() + if initial_state is not None + else torch.zeros(batch, v_heads, dim, dim, device=q.device, dtype=q.dtype) + ) + output = torch.empty_like(v) + for b in range(batch): + for hv in range(v_heads): + hq = hv // (v_heads // q_heads) + for t in range(tokens): + gate = torch.exp(g[b, t, hv]) + delta = (v[b, t, hv] - gate * (state[b, hv].transpose(0, 1) @ k[b, t, hq])) * beta[b, t, hv] + state[b, hv] = gate * state[b, hv] + torch.outer(k[b, t, hq], delta) + output[b, t, hv] = scale * (state[b, hv].transpose(0, 1) @ q[b, t, hq]) + return output, state + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dim", [16, 32, 64, 128]) +def test_legacy_sm_gdn_matches_reference(dim): + torch.manual_seed(1000 + dim) + q = torch.randn(1, 5, 2, dim, device="cuda", dtype=torch.float32).contiguous() * 0.05 + k = torch.randn_like(q).contiguous() * 0.05 + v = torch.randn(1, 5, 4, dim, device="cuda", dtype=torch.float32).contiguous() * 0.1 + g = torch.randn(1, 5, 4, device="cuda", dtype=torch.float32).contiguous() * 0.02 - 0.04 + beta = torch.rand(1, 5, 4, device="cuda", dtype=torch.float32).contiguous() + h0 = torch.randn(1, 4, dim, dim, device="cuda", dtype=torch.float32).contiguous() * 0.01 + scale = 1.0 / math.sqrt(dim) + + out_ref, state_ref = _reference(q, k, v, g, beta, scale, h0) + out, state = chunk_gated_delta_rule_fwd_legacy(q, k, v, g, beta, scale, h0) + torch.cuda.synchronize() + + torch.testing.assert_close(out, out_ref, atol=2e-4, rtol=2e-4) + torch.testing.assert_close(state, state_ref, atol=1e-3, rtol=1e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_legacy_sm_gdn_rejects_unsupported_dtype(): + q = torch.randn(1, 1, 1, 16, device="cuda", dtype=torch.float16) + with pytest.raises(ValueError, match="float32"): + chunk_gated_delta_rule_fwd_legacy(q, q, q, torch.randn(1, 1, 1, device="cuda"), torch.randn(1, 1, 1, device="cuda"))