first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

View File

@@ -0,0 +1,585 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
#
# Benchmark Script for FlashQLA
import argparse
import math
import gc
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any
import torch
import torch.nn.functional as F
import tilelang
# Kernel Imports
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd as fla_fwd,
chunk_gated_delta_rule_bwd as fla_bwd,
)
from flash_qla import (
chunk_gated_delta_rule_fwd as qla_fwd,
chunk_gated_delta_rule_bwd as qla_bwd,
)
from flash_qla.utils import l2norm
try:
from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_fwd
HAS_FI = True
except ImportError:
HAS_FI = False
# Default per-head feature dimension; used as the default ``head_dim`` argument
# of the tensor-preparation and benchmark functions in this file.
HEAD_DIM = 128
# Sequence count reported in the backward-benchmark banner printed by main().
BWD_SPLIT_SIZE = 8
@dataclass
class ModelConfig:
    """Head layout of one benchmarked model / parallelism setting."""

    # Name printed in the first column of the result tables.
    label: str
    # Head count used for the q and k tensors.
    h_qk: int
    # Head count used for the v, g, beta and state tensors.
    h_v: int
@dataclass
class SeqLenConfig:
    """One packed-batch layout: a display label plus its sequence lengths."""

    # Name printed in the sequence-length column of the result tables.
    label: str
    # Length of each sequence in the batch; the benchmark turns these into
    # cumulative offsets (cu_seqlens).
    seqlens: List[int]
def generate_rand_seqlens(batch_size, num_tokens, device=None):
    """Randomly split ``num_tokens`` tokens into ``batch_size`` non-empty sequences.

    Args:
        batch_size: Number of sequences to produce (>= 1).
        num_tokens: Total number of tokens to distribute (>= batch_size).
        device: Torch device on which the random cut points are drawn.
            Defaults to CUDA when it is available and to the CPU otherwise,
            so that calling this at import time does not crash on a host
            without a GPU.

    Returns:
        A list of ``batch_size`` positive ints that sum to ``num_tokens``.

    Raises:
        ValueError: If the request cannot be satisfied with non-empty
            sequences.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if num_tokens < batch_size:
        raise ValueError(
            f"num_tokens ({num_tokens}) must be >= batch_size ({batch_size})"
        )
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Draw batch_size - 1 distinct cut points in [1, num_tokens - 1]; being
    # distinct is what guarantees that every sequence is non-empty.
    shuffled = torch.randperm(num_tokens - 1, device=device, dtype=torch.int32)
    bars = torch.sort(shuffled[: batch_size - 1]).values + 1

    # Surround the cut points with 0 and num_tokens to get cumulative offsets.
    cu_seqlens = torch.nn.functional.pad(bars, (1, 1))
    cu_seqlens[-1] = num_tokens

    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    return seqlens.tolist()
# Head layouts swept by the forward benchmark, as (label, h_qk, h_v) rows.
FWD_MODEL_CONFIGS = [
    ModelConfig(label, h_qk, h_v)
    for label, h_qk, h_v in (
        ("397B/122B TP8", 2, 8),
        ("397B/122B TP4", 4, 16),
        ("397B/122B TP2", 8, 32),
        ("397B/122B TP1", 16, 64),
        ("35B/9B/4B TP1", 16, 32),
        ("27B TP2", 8, 24),
        ("27B TP1", 16, 48),
        ("2B/0.8B TP1", 16, 16),
        ("Sym h32", 32, 32),
    )
]
# Packed-batch layouts swept by the forward benchmark, as (label, seqlens)
# rows: single sequences first, then uneven two-sequence splits, then uniform
# multi-sequence batches.
FWD_SEQLEN_CONFIGS = [
    SeqLenConfig(label, lengths)
    for label, lengths in (
        ("1x32768", [32768]),
        ("1x16384", [16384]),
        ("1x8192", [8192]),
        ("1x4096", [4096]),
        ("1x2048", [2048]),
        ("28672+4096", [28672, 4096]),
        ("24576+8192", [24576, 8192]),
        ("16384+16384", [16384, 16384]),
        ("8192+24576", [8192, 24576]),
        ("4096+28672", [4096, 28672]),
        ("12288+4096", [12288, 4096]),
        ("6144+2048", [6144, 2048]),
        ("4096+4096", [4096, 4096]),
        ("2048+6144", [2048, 6144]),
        ("1024+7168", [1024, 7168]),
        ("8192x4", [8192] * 4),
        ("4096x8", [4096] * 8),
        ("2048x4", [2048] * 4),
        ("1024x8", [1024] * 8),
    )
]
# Head counts swept by the backward benchmark; the label is the head count
# itself and q/k and v use the same number of heads.
BWD_MODEL_CONFIGS = [ModelConfig(str(heads), heads, heads) for heads in (32, 48, 64)]
# Total-token budgets swept by the backward benchmark. Each budget is split at
# random into BWD_SPLIT_SIZE sequences once, at import time, so every head
# configuration is timed on the same split. Using the constant (rather than a
# repeated literal) keeps the split in sync with the banner that reports it.
BWD_SEQLEN_CONFIGS = [
    SeqLenConfig(label, generate_rand_seqlens(BWD_SPLIT_SIZE, total_tokens))
    for label, total_tokens in (
        ("16k", 16384),
        ("32k", 32768),
        ("64k", 65536),
        ("128k", 131072),
        ("256k", 262144),
    )
]
def cleanup_cuda():
    """Best-effort release of Python garbage and cached CUDA memory.

    Never raises: this is called from error handlers after a kernel has
    already failed, so a second failure here must not mask the first. Each
    step is attempted independently so that a failing ``synchronize()`` does
    not prevent garbage collection or the cache flush from being tried.
    """
    cuda_ready = False
    try:
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            # Let in-flight kernels finish before their buffers are dropped.
            torch.cuda.synchronize()
    except Exception:
        # Deliberately ignored: cleanup must not raise.
        pass

    # Drop unreachable Python objects first so the tensors they own are
    # returned to the caching allocator before it is emptied.
    gc.collect()

    if cuda_ready:
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Deliberately ignored: cleanup must not raise.
            pass
def get_lib_versions() -> Dict[str, str]:
    """Report the versions of torch, fla, flashinfer and tilelang.

    Returns:
        A dict keyed ``"torch"``, ``"fla"``, ``"flashinfer"``, ``"tilelang"``
        (in that order). A value is the library's version, ``"Installed (ver
        unknown)"`` when the module exposes no version attribute, or ``"Not
        Installed"`` when importing it raises ``ImportError``.
    """

    def _probe(module, check_plain_version):
        # ``__version__`` wins; ``version`` is only consulted for the
        # libraries that asked for it and is stringified.
        if hasattr(module, "__version__"):
            return module.__version__
        if check_plain_version and hasattr(module, "version"):
            return str(module.version)
        return "Installed (ver unknown)"

    report = {}

    try:
        report["torch"] = torch.__version__
    except Exception:
        report["torch"] = "N/A"

    # Flash Linear Attention: only ``__version__`` is looked at.
    try:
        import fla

        report["fla"] = _probe(fla, False)
    except ImportError:
        report["fla"] = "Not Installed"

    try:
        import flashinfer

        report["flashinfer"] = _probe(flashinfer, True)
    except ImportError:
        report["flashinfer"] = "Not Installed"

    try:
        import tilelang

        report["tilelang"] = _probe(tilelang, True)
    except ImportError:
        report["tilelang"] = "Not Installed"

    return report
def prepare_tensors(
    seqlens: List[int], h_qk: int, h_v: int, head_dim: int = HEAD_DIM
) -> Optional[Dict[str, Any]]:
    """Allocate random CUDA inputs for one packed variable-length batch.

    All sequences are concatenated along the token axis of a batch of size 1;
    ``cu_seqlens`` holds their cumulative int32 offsets.

    Args:
        seqlens: Length of each sequence in the batch.
        h_qk: Head count of the q and k tensors.
        h_v: Head count of the v, g, beta and state tensors.
        head_dim: Per-head feature dimension.

    Returns:
        A dict with keys ``device``, ``num_seqs``, ``total_tokens``, ``scale``,
        ``cu_seqlens``, ``q``, ``k``, ``v``, ``g``, ``beta``, ``h0``, ``do``
        and ``dht``; or ``None`` when tensor creation fails with an
        out-of-memory ``RuntimeError``.

    Raises:
        RuntimeError: Any non-OOM runtime error raised while allocating.
    """
    device = "cuda"
    num_seqs = len(seqlens)
    total_tokens = sum(seqlens)
    scale = head_dim ** (-0.5)

    # Cumulative sequence boundaries: [0, len0, len0 + len1, ...].
    boundaries = [0]
    running = 0
    for length in seqlens:
        running += length
        boundaries.append(running)
    cu_seqlens = torch.tensor(boundaries, dtype=torch.int32, device=device)

    def _randn(shape, dtype):
        return torch.randn(*shape, device=device, dtype=dtype)

    qk_shape = (1, total_tokens, h_qk, head_dim)
    v_shape = (1, total_tokens, h_v, head_dim)
    per_head_shape = (1, total_tokens, h_v)
    state_shape = (num_seqs, h_v, head_dim, head_dim)

    try:
        q = l2norm(_randn(qk_shape, torch.bfloat16))
        k = l2norm(_randn(qk_shape, torch.bfloat16))
        v = _randn(v_shape, torch.bfloat16)
        # Log-sigmoid of noise is <= 0; it is further divided by 16.
        g = F.logsigmoid(_randn(per_head_shape, torch.float32)) / 16
        beta = _randn(per_head_shape, torch.float32).sigmoid()
        h0 = _randn(state_shape, torch.float32)
        # Upstream gradients consumed by the backward benchmark.
        do = torch.randn_like(v)
        dht = _randn(state_shape, torch.float32) / 8
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise
        return None

    # A random ceil(75%) of the value heads keep their gate; g is zeroed for
    # the remaining heads.
    keep_gate = torch.zeros(h_v, dtype=torch.bool, device=device)
    keep_gate[: math.ceil(0.75 * h_v)] = True
    keep_gate = keep_gate[torch.randperm(h_v, device=device)]
    g[:, :, ~keep_gate] = 0.0

    return {
        "device": device,
        "num_seqs": num_seqs,
        "total_tokens": total_tokens,
        "scale": scale,
        "cu_seqlens": cu_seqlens,
        "q": q,
        "k": k,
        "v": v,
        "g": g,
        "beta": beta,
        "h0": h0,
        "do": do,
        "dht": dht,
    }
def bench_fwd(
    seqlens: List[int],
    h_qk: int,
    h_v: int,
    head_dim: int = HEAD_DIM,
    warmup: int = 10,
    repeats: int = 5,
) -> Tuple[float, float, float]:
    """Time the forward kernels of FlashQLA, FlashInfer and FLA on one case.

    Args:
        seqlens: Length of each sequence in the packed batch.
        h_qk: Head count of q and k.
        h_v: Head count of v, g and beta.
        head_dim: Per-head feature dimension.
        warmup: Forwarded as ``warmup=`` to ``tilelang.profiler.do_bench``.
        repeats: Forwarded as ``rep=`` to ``tilelang.profiler.do_bench``.
            NOTE(review): confirm whether do_bench treats these two values as
            iteration counts or as time budgets.

    Returns:
        ``(flash_qla, flashinfer, fla)`` timings as returned by ``do_bench``.
        An entry is NaN when that backend raised ``RuntimeError``, when
        FlashInfer is disabled or unavailable, or (all three) when the input
        tensors could not be allocated.
    """
    cleanup_cuda()
    inputs = prepare_tensors(seqlens, h_qk, h_v, head_dim)
    if inputs is None:
        return float("nan"), float("nan"), float("nan")

    q, k, v = inputs["q"], inputs["k"], inputs["v"]
    g, beta = inputs["g"], inputs["beta"]
    h0 = inputs["h0"]
    scale = inputs["scale"]
    cu_seqlens = inputs["cu_seqlens"]

    # Every backend starts out as "no result"; a successful run overwrites it.
    qla_ms = float("nan")
    fi_ms = float("nan")
    fla_ms = float("nan")

    def call_qla_fwd():
        qla_fwd(
            q,
            k,
            v,
            g,
            beta,
            scale=scale,
            initial_state=h0,
            output_final_state=True,
            output_h=False,
            cu_seqlens=cu_seqlens,
            auto_cp=True,
        )

    def call_fi_fwd():
        # FlashInfer is given tensors without the leading batch axis; the
        # views are created inside the timed call.
        fi_fwd(
            q=q.view(-1, h_qk, head_dim),
            k=k.view(-1, h_qk, head_dim),
            v=v.view(-1, h_v, head_dim),
            g=g.view(-1, h_v),
            beta=beta.view(-1, h_v),
            scale=scale,
            initial_state=h0,
            cu_seqlens=cu_seqlens,
            output_final_state=True,
        )

    def call_fla_fwd():
        fla_fwd(
            q,
            k,
            v,
            g,
            beta,
            scale=scale,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens,
        )

    try:
        qla_ms = tilelang.profiler.do_bench(call_qla_fwd, warmup=warmup, rep=repeats)
    except RuntimeError as e:
        print(f"\n[WARN] FlashQLA Fwd failed: {e}")
        cleanup_cuda()

    if HAS_FI:
        try:
            fi_ms = tilelang.profiler.do_bench(call_fi_fwd, warmup=warmup, rep=repeats)
        except RuntimeError as e:
            print(f"\n[WARN] FI Fwd failed: {e}")
            cleanup_cuda()

    try:
        fla_ms = tilelang.profiler.do_bench(call_fla_fwd, warmup=warmup, rep=repeats)
    except RuntimeError as e:
        print(f"\n[WARN] FLA Fwd failed: {e}")
        cleanup_cuda()

    # Best-effort drain of outstanding GPU work before returning.
    try:
        torch.cuda.synchronize()
    except Exception:
        pass

    return qla_ms, fi_ms, fla_ms
def bench_bwd(
    seqlens: List[int],
    h_qk: int,
    h_v: int,
    head_dim: int = HEAD_DIM,
    warmup: int = 10,
    repeats: int = 100,
) -> Tuple[float, float]:
    """Time the backward kernels of FlashQLA and FLA on one case.

    The FlashQLA forward pass is run once, untimed, to obtain the
    intermediates that both backward kernels consume.

    Args:
        seqlens: Length of each sequence in the packed batch.
        h_qk: Head count used for q, k AND v in this benchmark.
        h_v: Accepted for signature symmetry with ``bench_fwd`` but not used.
        head_dim: Per-head feature dimension.
        warmup: Forwarded as ``warmup=`` to ``tilelang.profiler.do_bench``.
        repeats: Forwarded as ``rep=`` to ``tilelang.profiler.do_bench``.

    Returns:
        ``(flash_qla, fla)`` timings as returned by ``do_bench``. An entry is
        NaN when that backend raised ``RuntimeError``; both are NaN when
        allocation or the preparatory forward pass failed.
    """
    heads = h_qk
    cleanup_cuda()
    inputs = prepare_tensors(seqlens, heads, heads, head_dim)
    if inputs is None:
        return float("nan"), float("nan")

    q, k, v = inputs["q"], inputs["k"], inputs["v"]
    g, beta = inputs["g"], inputs["beta"]
    h0 = inputs["h0"]
    scale = inputs["scale"]
    cu_seqlens = inputs["cu_seqlens"]
    do = inputs["do"]
    dht = inputs["dht"]

    # Pre-run FWD to get intermediates.
    # NOTE(review): assumes the first two items returned by qla_fwd are the
    # cumulative gate and A, in that order; confirm against flash_qla.
    try:
        fwd_out = qla_fwd(
            q,
            k,
            v,
            g,
            beta,
            scale=scale,
            initial_state=h0,
            output_final_state=True,
            output_h=False,
            cu_seqlens=cu_seqlens,
            auto_cp=True,
        )
        if not (isinstance(fwd_out, tuple) and len(fwd_out) >= 2):
            raise RuntimeError("FlashQLA FWD did not return expected intermediates")
        g_cumsum, A = fwd_out[0], fwd_out[1]
    except RuntimeError as e:
        print(f"[FWD Error] Failed at seqlens={seqlens}, heads={h_qk}. Error: {e}")
        cleanup_cuda()
        return float("nan"), float("nan")

    def call_qla_bwd():
        return qla_bwd(
            q,
            k,
            v,
            g_cumsum,
            beta,
            A,
            do,
            dht,
            scale=scale,
            initial_state=h0,
            cu_seqlens=cu_seqlens,
        )

    def call_fla_bwd():
        return fla_bwd(q, k, v, g_cumsum, beta, A, scale, h0, do, dht, cu_seqlens)

    # Each backend starts out as "no result"; a successful run overwrites it.
    qla_ms = float("nan")
    fla_ms = float("nan")

    try:
        qla_ms = tilelang.profiler.do_bench(call_qla_bwd, warmup=warmup, rep=repeats)
    except RuntimeError as e:
        print(f"\n[WARN] FlashQLA Bwd failed: {e}")
        cleanup_cuda()

    try:
        fla_ms = tilelang.profiler.do_bench(call_fla_bwd, warmup=warmup, rep=repeats)
    except RuntimeError as e:
        print(f"\n[WARN] FLA Bwd failed: {e}")
        cleanup_cuda()

    return qla_ms, fla_ms
# Column titles for the forward result table. Every title fits inside the
# width given in its format spec, so the titles line up with the data rows,
# whose timing cells are 10 characters wide and ratio cells 7. (The previous
# 15-character "flash_qla [...]" title overflowed its 10-wide slot and pushed
# all following titles 5 columns to the right of their data.)
FWD_HDR = (
    f"{'Model Config':<16} {'Seqlens':<17} {'h_qk':>5} {'h_v':>5} "
    f"{'QLA [fwd]':>10} {'FI [fwd]':>10} {'FLA [fwd]':>10} "
    f"{'vs FLA':>7} {'vs FI':>7}"
)
# Column titles for the backward result table; same fitting rule as above.
BWD_HDR = (
    f"{'Heads':<8} {'SeqLen':<15} "
    f"{'QLA [bwd]':>10} {'FLA [bwd]':>10} {'Speedup':>8}"
)
def fmt_time(ms: float) -> str:
    """Render a timing in milliseconds as a fixed 10-character table cell.

    Args:
        ms: Duration in milliseconds, or NaN for "no measurement".

    Returns:
        ``"N/A"`` centred in 10 characters for NaN; otherwise the value with
        three decimals, right-aligned in 8 characters and followed by ``ms``.
        Both forms have the same width so table columns stay aligned.
    """
    if math.isnan(ms):
        return f"{'N/A':^10}"
    return f"{ms:>8.3f}ms"
def fmt_ratio(base: float, other: float) -> str:
    """Render ``other / base`` as a fixed 7-character table cell.

    Args:
        base: Denominator (the reference timing).
        other: Numerator (the timing being compared against ``base``).

    Returns:
        ``"N/A"`` centred in 7 characters when either value is NaN or
        ``base`` is zero; otherwise the ratio with two decimals, right-aligned
        in 6 characters and followed by ``x``. Both forms have the same width
        so table columns stay aligned.
    """
    if math.isnan(base) or math.isnan(other) or base == 0:
        return f"{'N/A':^7}"
    return f"{other / base:>6.2f}x"
def _run_forward_suite(args):
    """Time every forward (model, seqlens) pair and print one table row each."""
    print("\n>>> FORWARD BENCHMARKS")
    print(FWD_HDR)
    print("-" * len(FWD_HDR))
    for idx, cfg in enumerate(FWD_MODEL_CONFIGS):
        # Blank separator line whenever the model label changes.
        if idx > 0 and cfg.label != FWD_MODEL_CONFIGS[idx - 1].label:
            print()
        for sl_cfg in FWD_SEQLEN_CONFIGS:
            try:
                qla_ms, fi_ms, fla_ms = bench_fwd(
                    sl_cfg.seqlens,
                    cfg.h_qk,
                    cfg.h_v,
                    warmup=args.warmup,
                    repeats=args.repeats,
                )
                if math.isnan(qla_ms) and math.isnan(fla_ms):
                    cleanup_cuda()
                ratio_fla = fmt_ratio(qla_ms, fla_ms)
                ratio_fi = fmt_ratio(qla_ms, fi_ms)
                print(
                    f"{cfg.label:<16} {sl_cfg.label:<17} {cfg.h_qk:>5} {cfg.h_v:>5} "
                    f"{fmt_time(qla_ms)} {fmt_time(fi_ms)} {fmt_time(fla_ms)} "
                    f"{ratio_fla} {ratio_fi}",
                    flush=True,
                )
            except Exception as e:
                # One broken case must not abort the whole sweep.
                print(
                    f"\n[ERROR] Forward Case Failed: {cfg.label} / {sl_cfg.label}"
                )
                print(f"Exception: {e}")
                cleanup_cuda()
                continue
            cleanup_cuda()


def _run_backward_suite(args):
    """Time every backward (heads, seqlens) pair and print one table row each."""
    print("\n" + "=" * 110)
    print(f"\n>>> BACKWARD BENCHMARKS (Split seq into {BWD_SPLIT_SIZE} sequences)")
    print(BWD_HDR)
    print("-" * len(BWD_HDR))
    for idx, cfg in enumerate(BWD_MODEL_CONFIGS):
        # Blank separator line whenever the model label changes.
        if idx > 0 and cfg.label != BWD_MODEL_CONFIGS[idx - 1].label:
            print()
        for sl_cfg in BWD_SEQLEN_CONFIGS:
            try:
                qla_bwd_ms, fla_bwd_ms = bench_bwd(
                    sl_cfg.seqlens,
                    cfg.h_qk,
                    cfg.h_v,
                    warmup=args.warmup,
                    repeats=args.repeats,
                )
                both_missing = math.isnan(qla_bwd_ms) and math.isnan(fla_bwd_ms)
                if both_missing:
                    speedup = " Skip "
                else:
                    speedup = fmt_ratio(qla_bwd_ms, fla_bwd_ms)
                print(
                    f"{cfg.label:<8} {sl_cfg.label:<15} "
                    f"{fmt_time(qla_bwd_ms)} {fmt_time(fla_bwd_ms)} {speedup} ",
                    flush=True,
                )
            except Exception as e:
                # One broken case must not abort the whole sweep.
                print(
                    f"\n[CRITICAL ERROR] Case Crashed: Heads={cfg.label}, TotalSeqLen={sl_cfg.label}"
                )
                print(f"Exception: {e}")
                cleanup_cuda()
                continue
            cleanup_cuda()


def main():
    """Command-line entry point: parse options and run the requested suites.

    Options: ``--warmup`` and ``--repeats`` (ints forwarded to the benchmark
    functions), ``--mode`` (``fwd``, ``bwd`` or ``all``) and ``--skip-fi``
    (disable the FlashInfer column). Prints a notice and returns immediately
    when CUDA is not available.
    """
    global HAS_FI

    parser = argparse.ArgumentParser(description="Benchmark FlashQLA Gated Delta Rule")
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--mode", choices=["fwd", "bwd", "all"], default="all")
    parser.add_argument("--skip-fi", action="store_true")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        print("CUDA not available.")
        return

    # --skip-fi turns the module-level FlashInfer switch off for this run.
    if args.skip_fi:
        HAS_FI = False

    gpu_name = torch.cuda.get_device_properties(0).name
    print(f"GPU: {gpu_name}")
    print("Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128")
    print(f"Config: Warmup={args.warmup}, Repeats={args.repeats}")

    libs = get_lib_versions()
    print("Library Versions:")
    ver_str = " | ".join([f"{k}: {v}" for k, v in libs.items()])
    print(f" {ver_str}")
    print("=" * 110)

    want_fwd = args.mode in ("fwd", "all")
    want_bwd = args.mode in ("bwd", "all")
    if want_fwd:
        _run_forward_suite(args)
    if want_bwd:
        _run_backward_suite(args)

    print("\nBenchmark Finished.")


if __name__ == "__main__":
    main()