first commit
This commit is contained in:
585
benchmark/bench_gated_delta_rule.py
Normal file
585
benchmark/bench_gated_delta_rule.py
Normal file
@@ -0,0 +1,585 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
#
|
||||
# Benchmark Script for FlashQLA
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import gc
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import tilelang
|
||||
|
||||
# Kernel Imports
|
||||
from fla.ops.gated_delta_rule.chunk import (
|
||||
chunk_gated_delta_rule_fwd as fla_fwd,
|
||||
chunk_gated_delta_rule_bwd as fla_bwd,
|
||||
)
|
||||
|
||||
from flash_qla import (
|
||||
chunk_gated_delta_rule_fwd as qla_fwd,
|
||||
chunk_gated_delta_rule_bwd as qla_bwd,
|
||||
)
|
||||
from flash_qla.utils import l2norm
|
||||
|
||||
try:
|
||||
from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_fwd
|
||||
|
||||
HAS_FI = True
|
||||
except ImportError:
|
||||
HAS_FI = False
|
||||
|
||||
HEAD_DIM = 128
|
||||
BWD_SPLIT_SIZE = 8
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
label: str
|
||||
h_qk: int
|
||||
h_v: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SeqLenConfig:
|
||||
label: str
|
||||
seqlens: List[int]
|
||||
|
||||
|
||||
def generate_rand_seqlens(batch_size, num_tokens):
|
||||
bars = (
|
||||
torch.sort(
|
||||
torch.randperm(num_tokens - 1, device="cuda", dtype=torch.int32)[
|
||||
: batch_size - 1
|
||||
]
|
||||
).values
|
||||
+ 1
|
||||
)
|
||||
cu_seqlens = torch.nn.functional.pad(bars, (1, 1))
|
||||
cu_seqlens[-1] = num_tokens
|
||||
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
return seqlens.tolist()
|
||||
|
||||
|
||||
FWD_MODEL_CONFIGS = [
|
||||
ModelConfig("397B/122B TP8", 2, 8),
|
||||
ModelConfig("397B/122B TP4", 4, 16),
|
||||
ModelConfig("397B/122B TP2", 8, 32),
|
||||
ModelConfig("397B/122B TP1", 16, 64),
|
||||
ModelConfig("35B/9B/4B TP1", 16, 32),
|
||||
ModelConfig("27B TP2", 8, 24),
|
||||
ModelConfig("27B TP1", 16, 48),
|
||||
ModelConfig("2B/0.8B TP1", 16, 16),
|
||||
ModelConfig("Sym h32", 32, 32),
|
||||
]
|
||||
|
||||
FWD_SEQLEN_CONFIGS = [
|
||||
SeqLenConfig("1x32768", [32768]),
|
||||
SeqLenConfig("1x16384", [16384]),
|
||||
SeqLenConfig("1x8192", [8192]),
|
||||
SeqLenConfig("1x4096", [4096]),
|
||||
SeqLenConfig("1x2048", [2048]),
|
||||
SeqLenConfig("28672+4096", [28672, 4096]),
|
||||
SeqLenConfig("24576+8192", [24576, 8192]),
|
||||
SeqLenConfig("16384+16384", [16384, 16384]),
|
||||
SeqLenConfig("8192+24576", [8192, 24576]),
|
||||
SeqLenConfig("4096+28672", [4096, 28672]),
|
||||
SeqLenConfig("12288+4096", [12288, 4096]),
|
||||
SeqLenConfig("6144+2048", [6144, 2048]),
|
||||
SeqLenConfig("4096+4096", [4096, 4096]),
|
||||
SeqLenConfig("2048+6144", [2048, 6144]),
|
||||
SeqLenConfig("1024+7168", [1024, 7168]),
|
||||
SeqLenConfig("8192x4", [8192] * 4),
|
||||
SeqLenConfig("4096x8", [4096] * 8),
|
||||
SeqLenConfig("2048x4", [2048] * 4),
|
||||
SeqLenConfig("1024x8", [1024] * 8),
|
||||
]
|
||||
BWD_MODEL_CONFIGS = [
|
||||
ModelConfig("32", 32, 32),
|
||||
ModelConfig("48", 48, 48),
|
||||
ModelConfig("64", 64, 64),
|
||||
]
|
||||
|
||||
BWD_SEQLEN_CONFIGS = [
|
||||
SeqLenConfig("16k", generate_rand_seqlens(8, 16384)),
|
||||
SeqLenConfig("32k", generate_rand_seqlens(8, 32768)),
|
||||
SeqLenConfig("64k", generate_rand_seqlens(8, 65536)),
|
||||
SeqLenConfig("128k", generate_rand_seqlens(8, 131072)),
|
||||
SeqLenConfig("256k", generate_rand_seqlens(8, 262144)),
|
||||
]
|
||||
|
||||
|
||||
def cleanup_cuda():
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_lib_versions() -> Dict[str, str]:
|
||||
"""Collects version strings for relevant libraries."""
|
||||
versions = {}
|
||||
|
||||
# Torch
|
||||
try:
|
||||
versions["torch"] = torch.__version__
|
||||
except Exception:
|
||||
versions["torch"] = "N/A"
|
||||
|
||||
# Flash Linear Attention (FLA)
|
||||
try:
|
||||
import fla
|
||||
|
||||
if hasattr(fla, "__version__"):
|
||||
versions["fla"] = fla.__version__
|
||||
else:
|
||||
versions["fla"] = "Installed (ver unknown)"
|
||||
except ImportError:
|
||||
versions["fla"] = "Not Installed"
|
||||
|
||||
# FlashInfer
|
||||
try:
|
||||
import flashinfer
|
||||
|
||||
if hasattr(flashinfer, "__version__"):
|
||||
versions["flashinfer"] = flashinfer.__version__
|
||||
elif hasattr(flashinfer, "version"):
|
||||
versions["flashinfer"] = str(flashinfer.version)
|
||||
else:
|
||||
versions["flashinfer"] = "Installed (ver unknown)"
|
||||
except ImportError:
|
||||
versions["flashinfer"] = "Not Installed"
|
||||
|
||||
# TileLang
|
||||
try:
|
||||
import tilelang
|
||||
|
||||
if hasattr(tilelang, "__version__"):
|
||||
versions["tilelang"] = tilelang.__version__
|
||||
elif hasattr(tilelang, "version"):
|
||||
versions["tilelang"] = str(tilelang.version)
|
||||
else:
|
||||
versions["tilelang"] = "Installed (ver unknown)"
|
||||
except ImportError:
|
||||
versions["tilelang"] = "Not Installed"
|
||||
|
||||
return versions
|
||||
|
||||
|
||||
def prepare_tensors(
|
||||
seqlens: List[int], h_qk: int, h_v: int, head_dim: int = HEAD_DIM
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
device = "cuda"
|
||||
num_seqs = len(seqlens)
|
||||
total_tokens = sum(seqlens)
|
||||
scale = head_dim ** (-0.5)
|
||||
|
||||
offsets = [0]
|
||||
for s in seqlens:
|
||||
offsets.append(offsets[-1] + s)
|
||||
cu_seqlens = torch.tensor(offsets, dtype=torch.int32, device=device)
|
||||
|
||||
try:
|
||||
q = l2norm(
|
||||
torch.randn(
|
||||
1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
)
|
||||
k = l2norm(
|
||||
torch.randn(
|
||||
1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
)
|
||||
v = torch.randn(
|
||||
1, total_tokens, h_v, head_dim, device=device, dtype=torch.bfloat16
|
||||
)
|
||||
g = (
|
||||
F.logsigmoid(
|
||||
torch.randn(1, total_tokens, h_v, device=device, dtype=torch.float32)
|
||||
)
|
||||
/ 16
|
||||
)
|
||||
beta = torch.randn(
|
||||
1, total_tokens, h_v, device=device, dtype=torch.float32
|
||||
).sigmoid()
|
||||
h0 = torch.randn(
|
||||
num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32
|
||||
)
|
||||
do = torch.randn_like(v)
|
||||
dht = (
|
||||
torch.randn(
|
||||
num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32
|
||||
)
|
||||
/ 8
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "out of memory" in str(e).lower():
|
||||
return None
|
||||
raise e
|
||||
|
||||
swa_ratio = 0.75
|
||||
swa_mask = torch.zeros(h_v, dtype=torch.bool, device=device)
|
||||
swa_mask[: math.ceil(swa_ratio * h_v)] = True
|
||||
swa_mask = swa_mask[torch.randperm(h_v, device=device)]
|
||||
g[:, :, ~swa_mask] = 0.0
|
||||
|
||||
return {
|
||||
"device": device,
|
||||
"num_seqs": num_seqs,
|
||||
"total_tokens": total_tokens,
|
||||
"scale": scale,
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"q": q,
|
||||
"k": k,
|
||||
"v": v,
|
||||
"g": g,
|
||||
"beta": beta,
|
||||
"h0": h0,
|
||||
"do": do,
|
||||
"dht": dht,
|
||||
}
|
||||
|
||||
|
||||
def bench_fwd(
|
||||
seqlens: List[int],
|
||||
h_qk: int,
|
||||
h_v: int,
|
||||
head_dim: int = HEAD_DIM,
|
||||
warmup: int = 10,
|
||||
repeats: int = 5,
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Run Forward Pass Benchmark.
|
||||
Returns: (qla_mean_ms, fi_mean_ms, fla_mean_ms)
|
||||
"""
|
||||
cleanup_cuda()
|
||||
data = prepare_tensors(seqlens, h_qk, h_v, head_dim)
|
||||
if data is None:
|
||||
return float("nan"), float("nan"), float("nan")
|
||||
|
||||
q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"]
|
||||
h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]
|
||||
|
||||
results = {}
|
||||
|
||||
def call_qla_fwd():
|
||||
qla_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale=scale,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
output_h=False,
|
||||
cu_seqlens=cu_seqlens,
|
||||
auto_cp=True,
|
||||
)
|
||||
|
||||
try:
|
||||
mean = tilelang.profiler.do_bench(call_qla_fwd, warmup=warmup, rep=repeats)
|
||||
results["flash_qla"] = mean
|
||||
except RuntimeError as e:
|
||||
print(f"\n[WARN] FlashQLA Fwd failed: {e}")
|
||||
cleanup_cuda()
|
||||
results["flash_qla"] = float("nan")
|
||||
|
||||
if HAS_FI:
|
||||
|
||||
def call_fi_fwd():
|
||||
fi_fwd(
|
||||
q=q.view(-1, h_qk, head_dim),
|
||||
k=k.view(-1, h_qk, head_dim),
|
||||
v=v.view(-1, h_v, head_dim),
|
||||
g=g.view(-1, h_v),
|
||||
beta=beta.view(-1, h_v),
|
||||
scale=scale,
|
||||
initial_state=h0,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_final_state=True,
|
||||
)
|
||||
|
||||
try:
|
||||
mean = tilelang.profiler.do_bench(call_fi_fwd, warmup=warmup, rep=repeats)
|
||||
results["fi"] = mean
|
||||
except RuntimeError as e:
|
||||
print(f"\n[WARN] FI Fwd failed: {e}")
|
||||
cleanup_cuda()
|
||||
results["fi"] = float("nan")
|
||||
else:
|
||||
results["fi"] = float("nan")
|
||||
|
||||
def call_fla_fwd():
|
||||
fla_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale=scale,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
try:
|
||||
mean = tilelang.profiler.do_bench(call_fla_fwd, warmup=warmup, rep=repeats)
|
||||
results["fla"] = mean
|
||||
except RuntimeError as e:
|
||||
print(f"\n[WARN] FLA Fwd failed: {e}")
|
||||
cleanup_cuda()
|
||||
results["fla"] = float("nan")
|
||||
|
||||
try:
|
||||
torch.cuda.synchronize()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return (
|
||||
results.get("flash_qla", float("nan")),
|
||||
results.get("fi", float("nan")),
|
||||
results.get("fla", float("nan")),
|
||||
)
|
||||
|
||||
|
||||
def bench_bwd(
|
||||
seqlens: List[int],
|
||||
h_qk: int,
|
||||
h_v: int,
|
||||
head_dim: int = HEAD_DIM,
|
||||
warmup: int = 10,
|
||||
repeats: int = 100,
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Run Backward Pass Benchmark.
|
||||
Returns: (qla_mean_ms, fla_mean_ms)
|
||||
"""
|
||||
unified_h = h_qk
|
||||
cleanup_cuda()
|
||||
|
||||
data = prepare_tensors(seqlens, unified_h, unified_h, head_dim)
|
||||
if data is None:
|
||||
return float("nan"), float("nan")
|
||||
|
||||
q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"]
|
||||
h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]
|
||||
do, dht = data["do"], data["dht"]
|
||||
|
||||
g_cumsum = None
|
||||
A = None
|
||||
|
||||
# Pre-run FWD to get intermediates
|
||||
try:
|
||||
result = qla_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale=scale,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
output_h=False,
|
||||
cu_seqlens=cu_seqlens,
|
||||
auto_cp=True,
|
||||
)
|
||||
if isinstance(result, tuple) and len(result) >= 2:
|
||||
g_cumsum, A = result[0], result[1]
|
||||
else:
|
||||
raise RuntimeError("FlashQLA FWD did not return expected intermediates")
|
||||
except RuntimeError as e:
|
||||
print(f"[FWD Error] Failed at seqlens={seqlens}, heads={h_qk}. Error: {e}")
|
||||
cleanup_cuda()
|
||||
return float("nan"), float("nan")
|
||||
|
||||
results = {}
|
||||
|
||||
def call_qla_bwd():
|
||||
return qla_bwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g_cumsum,
|
||||
beta,
|
||||
A,
|
||||
do,
|
||||
dht,
|
||||
scale=scale,
|
||||
initial_state=h0,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
try:
|
||||
mean = tilelang.profiler.do_bench(call_qla_bwd, warmup=warmup, rep=repeats)
|
||||
results["flash_qla"] = mean
|
||||
except RuntimeError as e:
|
||||
print(f"\n[WARN] FlashQLA Bwd failed: {e}")
|
||||
cleanup_cuda()
|
||||
results["flash_qla"] = float("nan")
|
||||
|
||||
def call_fla_bwd():
|
||||
return fla_bwd(q, k, v, g_cumsum, beta, A, scale, h0, do, dht, cu_seqlens)
|
||||
|
||||
try:
|
||||
mean = tilelang.profiler.do_bench(call_fla_bwd, warmup=warmup, rep=repeats)
|
||||
results["fla"] = mean
|
||||
except RuntimeError as e:
|
||||
print(f"\n[WARN] FLA Bwd failed: {e}")
|
||||
cleanup_cuda()
|
||||
results["fla"] = float("nan")
|
||||
|
||||
return results.get("flash_qla", float("nan")), results.get("fla", float("nan"))
|
||||
|
||||
|
||||
FWD_HDR = (
|
||||
f"{'Model Config':<16} {'Seqlens':<17} {'h_qk':>5} {'h_v':>5} "
|
||||
f"{'flash_qla [fwd]':>10} {'FI [fwd]':>10} {'FLA [fwd]':>10} "
|
||||
f"{'vs FLA':>7} {'vs FI':>7}"
|
||||
)
|
||||
|
||||
BWD_HDR = (
|
||||
f"{'Heads':<8} {'SeqLen':<15} "
|
||||
f"{'flash_qla [bwd]':>10} {'FLA [bwd]':>10} {'Speedup':>8}"
|
||||
)
|
||||
|
||||
|
||||
def fmt_time(ms: float) -> str:
|
||||
if math.isnan(ms):
|
||||
return " N/A "
|
||||
return f"{ms:>8.3f}ms"
|
||||
|
||||
|
||||
def fmt_ratio(base: float, other: float) -> str:
|
||||
if math.isnan(base) or math.isnan(other) or base == 0:
|
||||
return " N/A "
|
||||
return f"{other / base:>6.2f}x"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Benchmark FlashQLA Gated Delta Rule")
|
||||
parser.add_argument("--warmup", type=int, default=10)
|
||||
parser.add_argument("--repeats", type=int, default=100)
|
||||
parser.add_argument("--mode", choices=["fwd", "bwd", "all"], default="all")
|
||||
parser.add_argument("--skip-fi", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available.")
|
||||
return
|
||||
|
||||
global HAS_FI
|
||||
if args.skip_fi:
|
||||
HAS_FI = False
|
||||
|
||||
gpu_name = torch.cuda.get_device_properties(0).name
|
||||
print(f"GPU: {gpu_name}")
|
||||
print("Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128")
|
||||
print(f"Config: Warmup={args.warmup}, Repeats={args.repeats}")
|
||||
|
||||
libs = get_lib_versions()
|
||||
print("Library Versions:")
|
||||
ver_str = " | ".join([f"{k}: {v}" for k, v in libs.items()])
|
||||
print(f" {ver_str}")
|
||||
|
||||
print("=" * 110)
|
||||
|
||||
# Forward
|
||||
if args.mode in ("fwd", "all"):
|
||||
print("\n>>> FORWARD BENCHMARKS")
|
||||
print(FWD_HDR)
|
||||
print("-" * len(FWD_HDR))
|
||||
|
||||
prev_model = None
|
||||
for cfg in FWD_MODEL_CONFIGS:
|
||||
if prev_model is not None and cfg.label != prev_model:
|
||||
print()
|
||||
prev_model = cfg.label
|
||||
|
||||
for sl_cfg in FWD_SEQLEN_CONFIGS:
|
||||
try:
|
||||
qla_ms, fi_ms, fla_ms = bench_fwd(
|
||||
sl_cfg.seqlens,
|
||||
cfg.h_qk,
|
||||
cfg.h_v,
|
||||
warmup=args.warmup,
|
||||
repeats=args.repeats,
|
||||
)
|
||||
|
||||
if math.isnan(qla_ms) and math.isnan(fla_ms):
|
||||
cleanup_cuda()
|
||||
|
||||
ratio_fla = fmt_ratio(qla_ms, fla_ms)
|
||||
ratio_fi = fmt_ratio(qla_ms, fi_ms)
|
||||
|
||||
print(
|
||||
f"{cfg.label:<16} {sl_cfg.label:<17} {cfg.h_qk:>5} {cfg.h_v:>5} "
|
||||
f"{fmt_time(qla_ms)} {fmt_time(fi_ms)} {fmt_time(fla_ms)} "
|
||||
f"{ratio_fla} {ratio_fi}",
|
||||
flush=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"\n[ERROR] Forward Case Failed: {cfg.label} / {sl_cfg.label}"
|
||||
)
|
||||
print(f"Exception: {e}")
|
||||
cleanup_cuda()
|
||||
continue
|
||||
|
||||
cleanup_cuda()
|
||||
|
||||
# Backward
|
||||
if args.mode in ("bwd", "all"):
|
||||
print("\n" + "=" * 110)
|
||||
print(f"\n>>> BACKWARD BENCHMARKS (Split seq into {BWD_SPLIT_SIZE} sequences)")
|
||||
print(BWD_HDR)
|
||||
print("-" * len(BWD_HDR))
|
||||
|
||||
prev_model = None
|
||||
for cfg in BWD_MODEL_CONFIGS:
|
||||
if prev_model is not None and cfg.label != prev_model:
|
||||
print()
|
||||
prev_model = cfg.label
|
||||
|
||||
for sl_cfg in BWD_SEQLEN_CONFIGS:
|
||||
try:
|
||||
qla_bwd_ms, fla_bwd_ms = bench_bwd(
|
||||
sl_cfg.seqlens,
|
||||
cfg.h_qk,
|
||||
cfg.h_v,
|
||||
warmup=args.warmup,
|
||||
repeats=args.repeats,
|
||||
)
|
||||
|
||||
if math.isnan(qla_bwd_ms) and math.isnan(fla_bwd_ms):
|
||||
speedup = " Skip "
|
||||
else:
|
||||
speedup = fmt_ratio(qla_bwd_ms, fla_bwd_ms)
|
||||
|
||||
print(
|
||||
f"{cfg.label:<8} {sl_cfg.label:<15} "
|
||||
f"{fmt_time(qla_bwd_ms)} {fmt_time(fla_bwd_ms)} {speedup} ",
|
||||
flush=True,
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"\n[CRITICAL ERROR] Case Crashed: Heads={cfg.label}, TotalSeqLen={sl_cfg.label}"
|
||||
)
|
||||
print(f"Exception: {e}")
|
||||
cleanup_cuda()
|
||||
continue
|
||||
|
||||
cleanup_cuda()
|
||||
|
||||
print("\nBenchmark Finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user