first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

View File

@@ -0,0 +1,585 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
#
# Benchmark Script for FlashQLA
import argparse
import math
import gc
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict, Any
import torch
import torch.nn.functional as F
import tilelang
# Kernel Imports
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd as fla_fwd,
chunk_gated_delta_rule_bwd as fla_bwd,
)
from flash_qla import (
chunk_gated_delta_rule_fwd as qla_fwd,
chunk_gated_delta_rule_bwd as qla_bwd,
)
from flash_qla.utils import l2norm
try:
from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_fwd
HAS_FI = True
except ImportError:
HAS_FI = False
HEAD_DIM = 128
BWD_SPLIT_SIZE = 8
@dataclass
class ModelConfig:
label: str
h_qk: int
h_v: int
@dataclass
class SeqLenConfig:
label: str
seqlens: List[int]
def generate_rand_seqlens(batch_size, num_tokens):
bars = (
torch.sort(
torch.randperm(num_tokens - 1, device="cuda", dtype=torch.int32)[
: batch_size - 1
]
).values
+ 1
)
cu_seqlens = torch.nn.functional.pad(bars, (1, 1))
cu_seqlens[-1] = num_tokens
seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
return seqlens.tolist()
FWD_MODEL_CONFIGS = [
ModelConfig("397B/122B TP8", 2, 8),
ModelConfig("397B/122B TP4", 4, 16),
ModelConfig("397B/122B TP2", 8, 32),
ModelConfig("397B/122B TP1", 16, 64),
ModelConfig("35B/9B/4B TP1", 16, 32),
ModelConfig("27B TP2", 8, 24),
ModelConfig("27B TP1", 16, 48),
ModelConfig("2B/0.8B TP1", 16, 16),
ModelConfig("Sym h32", 32, 32),
]
FWD_SEQLEN_CONFIGS = [
SeqLenConfig("1x32768", [32768]),
SeqLenConfig("1x16384", [16384]),
SeqLenConfig("1x8192", [8192]),
SeqLenConfig("1x4096", [4096]),
SeqLenConfig("1x2048", [2048]),
SeqLenConfig("28672+4096", [28672, 4096]),
SeqLenConfig("24576+8192", [24576, 8192]),
SeqLenConfig("16384+16384", [16384, 16384]),
SeqLenConfig("8192+24576", [8192, 24576]),
SeqLenConfig("4096+28672", [4096, 28672]),
SeqLenConfig("12288+4096", [12288, 4096]),
SeqLenConfig("6144+2048", [6144, 2048]),
SeqLenConfig("4096+4096", [4096, 4096]),
SeqLenConfig("2048+6144", [2048, 6144]),
SeqLenConfig("1024+7168", [1024, 7168]),
SeqLenConfig("8192x4", [8192] * 4),
SeqLenConfig("4096x8", [4096] * 8),
SeqLenConfig("2048x4", [2048] * 4),
SeqLenConfig("1024x8", [1024] * 8),
]
BWD_MODEL_CONFIGS = [
ModelConfig("32", 32, 32),
ModelConfig("48", 48, 48),
ModelConfig("64", 64, 64),
]
BWD_SEQLEN_CONFIGS = [
SeqLenConfig("16k", generate_rand_seqlens(8, 16384)),
SeqLenConfig("32k", generate_rand_seqlens(8, 32768)),
SeqLenConfig("64k", generate_rand_seqlens(8, 65536)),
SeqLenConfig("128k", generate_rand_seqlens(8, 131072)),
SeqLenConfig("256k", generate_rand_seqlens(8, 262144)),
]
def cleanup_cuda():
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
except Exception:
pass
def get_lib_versions() -> Dict[str, str]:
"""Collects version strings for relevant libraries."""
versions = {}
# Torch
try:
versions["torch"] = torch.__version__
except Exception:
versions["torch"] = "N/A"
# Flash Linear Attention (FLA)
try:
import fla
if hasattr(fla, "__version__"):
versions["fla"] = fla.__version__
else:
versions["fla"] = "Installed (ver unknown)"
except ImportError:
versions["fla"] = "Not Installed"
# FlashInfer
try:
import flashinfer
if hasattr(flashinfer, "__version__"):
versions["flashinfer"] = flashinfer.__version__
elif hasattr(flashinfer, "version"):
versions["flashinfer"] = str(flashinfer.version)
else:
versions["flashinfer"] = "Installed (ver unknown)"
except ImportError:
versions["flashinfer"] = "Not Installed"
# TileLang
try:
import tilelang
if hasattr(tilelang, "__version__"):
versions["tilelang"] = tilelang.__version__
elif hasattr(tilelang, "version"):
versions["tilelang"] = str(tilelang.version)
else:
versions["tilelang"] = "Installed (ver unknown)"
except ImportError:
versions["tilelang"] = "Not Installed"
return versions
def prepare_tensors(
seqlens: List[int], h_qk: int, h_v: int, head_dim: int = HEAD_DIM
) -> Optional[Dict[str, Any]]:
device = "cuda"
num_seqs = len(seqlens)
total_tokens = sum(seqlens)
scale = head_dim ** (-0.5)
offsets = [0]
for s in seqlens:
offsets.append(offsets[-1] + s)
cu_seqlens = torch.tensor(offsets, dtype=torch.int32, device=device)
try:
q = l2norm(
torch.randn(
1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16
)
)
k = l2norm(
torch.randn(
1, total_tokens, h_qk, head_dim, device=device, dtype=torch.bfloat16
)
)
v = torch.randn(
1, total_tokens, h_v, head_dim, device=device, dtype=torch.bfloat16
)
g = (
F.logsigmoid(
torch.randn(1, total_tokens, h_v, device=device, dtype=torch.float32)
)
/ 16
)
beta = torch.randn(
1, total_tokens, h_v, device=device, dtype=torch.float32
).sigmoid()
h0 = torch.randn(
num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32
)
do = torch.randn_like(v)
dht = (
torch.randn(
num_seqs, h_v, head_dim, head_dim, device=device, dtype=torch.float32
)
/ 8
)
except RuntimeError as e:
if "out of memory" in str(e).lower():
return None
raise e
swa_ratio = 0.75
swa_mask = torch.zeros(h_v, dtype=torch.bool, device=device)
swa_mask[: math.ceil(swa_ratio * h_v)] = True
swa_mask = swa_mask[torch.randperm(h_v, device=device)]
g[:, :, ~swa_mask] = 0.0
return {
"device": device,
"num_seqs": num_seqs,
"total_tokens": total_tokens,
"scale": scale,
"cu_seqlens": cu_seqlens,
"q": q,
"k": k,
"v": v,
"g": g,
"beta": beta,
"h0": h0,
"do": do,
"dht": dht,
}
def bench_fwd(
seqlens: List[int],
h_qk: int,
h_v: int,
head_dim: int = HEAD_DIM,
warmup: int = 10,
repeats: int = 5,
) -> Tuple[float, float, float]:
"""
Run Forward Pass Benchmark.
Returns: (qla_mean_ms, fi_mean_ms, fla_mean_ms)
"""
cleanup_cuda()
data = prepare_tensors(seqlens, h_qk, h_v, head_dim)
if data is None:
return float("nan"), float("nan"), float("nan")
q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"]
h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]
results = {}
def call_qla_fwd():
qla_fwd(
q,
k,
v,
g,
beta,
scale=scale,
initial_state=h0,
output_final_state=True,
output_h=False,
cu_seqlens=cu_seqlens,
auto_cp=True,
)
try:
mean = tilelang.profiler.do_bench(call_qla_fwd, warmup=warmup, rep=repeats)
results["flash_qla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FlashQLA Fwd failed: {e}")
cleanup_cuda()
results["flash_qla"] = float("nan")
if HAS_FI:
def call_fi_fwd():
fi_fwd(
q=q.view(-1, h_qk, head_dim),
k=k.view(-1, h_qk, head_dim),
v=v.view(-1, h_v, head_dim),
g=g.view(-1, h_v),
beta=beta.view(-1, h_v),
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
output_final_state=True,
)
try:
mean = tilelang.profiler.do_bench(call_fi_fwd, warmup=warmup, rep=repeats)
results["fi"] = mean
except RuntimeError as e:
print(f"\n[WARN] FI Fwd failed: {e}")
cleanup_cuda()
results["fi"] = float("nan")
else:
results["fi"] = float("nan")
def call_fla_fwd():
fla_fwd(
q,
k,
v,
g,
beta,
scale=scale,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens,
)
try:
mean = tilelang.profiler.do_bench(call_fla_fwd, warmup=warmup, rep=repeats)
results["fla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FLA Fwd failed: {e}")
cleanup_cuda()
results["fla"] = float("nan")
try:
torch.cuda.synchronize()
except Exception:
pass
return (
results.get("flash_qla", float("nan")),
results.get("fi", float("nan")),
results.get("fla", float("nan")),
)
def bench_bwd(
seqlens: List[int],
h_qk: int,
h_v: int,
head_dim: int = HEAD_DIM,
warmup: int = 10,
repeats: int = 100,
) -> Tuple[float, float]:
"""
Run Backward Pass Benchmark.
Returns: (qla_mean_ms, fla_mean_ms)
"""
unified_h = h_qk
cleanup_cuda()
data = prepare_tensors(seqlens, unified_h, unified_h, head_dim)
if data is None:
return float("nan"), float("nan")
q, k, v, g, beta = data["q"], data["k"], data["v"], data["g"], data["beta"]
h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]
do, dht = data["do"], data["dht"]
g_cumsum = None
A = None
# Pre-run FWD to get intermediates
try:
result = qla_fwd(
q,
k,
v,
g,
beta,
scale=scale,
initial_state=h0,
output_final_state=True,
output_h=False,
cu_seqlens=cu_seqlens,
auto_cp=True,
)
if isinstance(result, tuple) and len(result) >= 2:
g_cumsum, A = result[0], result[1]
else:
raise RuntimeError("FlashQLA FWD did not return expected intermediates")
except RuntimeError as e:
print(f"[FWD Error] Failed at seqlens={seqlens}, heads={h_qk}. Error: {e}")
cleanup_cuda()
return float("nan"), float("nan")
results = {}
def call_qla_bwd():
return qla_bwd(
q,
k,
v,
g_cumsum,
beta,
A,
do,
dht,
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
)
try:
mean = tilelang.profiler.do_bench(call_qla_bwd, warmup=warmup, rep=repeats)
results["flash_qla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FlashQLA Bwd failed: {e}")
cleanup_cuda()
results["flash_qla"] = float("nan")
def call_fla_bwd():
return fla_bwd(q, k, v, g_cumsum, beta, A, scale, h0, do, dht, cu_seqlens)
try:
mean = tilelang.profiler.do_bench(call_fla_bwd, warmup=warmup, rep=repeats)
results["fla"] = mean
except RuntimeError as e:
print(f"\n[WARN] FLA Bwd failed: {e}")
cleanup_cuda()
results["fla"] = float("nan")
return results.get("flash_qla", float("nan")), results.get("fla", float("nan"))
FWD_HDR = (
f"{'Model Config':<16} {'Seqlens':<17} {'h_qk':>5} {'h_v':>5} "
f"{'flash_qla [fwd]':>10} {'FI [fwd]':>10} {'FLA [fwd]':>10} "
f"{'vs FLA':>7} {'vs FI':>7}"
)
BWD_HDR = (
f"{'Heads':<8} {'SeqLen':<15} "
f"{'flash_qla [bwd]':>10} {'FLA [bwd]':>10} {'Speedup':>8}"
)
def fmt_time(ms: float) -> str:
if math.isnan(ms):
return " N/A "
return f"{ms:>8.3f}ms"
def fmt_ratio(base: float, other: float) -> str:
if math.isnan(base) or math.isnan(other) or base == 0:
return " N/A "
return f"{other / base:>6.2f}x"
def main():
parser = argparse.ArgumentParser(description="Benchmark FlashQLA Gated Delta Rule")
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--repeats", type=int, default=100)
parser.add_argument("--mode", choices=["fwd", "bwd", "all"], default="all")
parser.add_argument("--skip-fi", action="store_true")
args = parser.parse_args()
if not torch.cuda.is_available():
print("CUDA not available.")
return
global HAS_FI
if args.skip_fi:
HAS_FI = False
gpu_name = torch.cuda.get_device_properties(0).name
print(f"GPU: {gpu_name}")
print("Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128")
print(f"Config: Warmup={args.warmup}, Repeats={args.repeats}")
libs = get_lib_versions()
print("Library Versions:")
ver_str = " | ".join([f"{k}: {v}" for k, v in libs.items()])
print(f" {ver_str}")
print("=" * 110)
# Forward
if args.mode in ("fwd", "all"):
print("\n>>> FORWARD BENCHMARKS")
print(FWD_HDR)
print("-" * len(FWD_HDR))
prev_model = None
for cfg in FWD_MODEL_CONFIGS:
if prev_model is not None and cfg.label != prev_model:
print()
prev_model = cfg.label
for sl_cfg in FWD_SEQLEN_CONFIGS:
try:
qla_ms, fi_ms, fla_ms = bench_fwd(
sl_cfg.seqlens,
cfg.h_qk,
cfg.h_v,
warmup=args.warmup,
repeats=args.repeats,
)
if math.isnan(qla_ms) and math.isnan(fla_ms):
cleanup_cuda()
ratio_fla = fmt_ratio(qla_ms, fla_ms)
ratio_fi = fmt_ratio(qla_ms, fi_ms)
print(
f"{cfg.label:<16} {sl_cfg.label:<17} {cfg.h_qk:>5} {cfg.h_v:>5} "
f"{fmt_time(qla_ms)} {fmt_time(fi_ms)} {fmt_time(fla_ms)} "
f"{ratio_fla} {ratio_fi}",
flush=True,
)
except Exception as e:
print(
f"\n[ERROR] Forward Case Failed: {cfg.label} / {sl_cfg.label}"
)
print(f"Exception: {e}")
cleanup_cuda()
continue
cleanup_cuda()
# Backward
if args.mode in ("bwd", "all"):
print("\n" + "=" * 110)
print(f"\n>>> BACKWARD BENCHMARKS (Split seq into {BWD_SPLIT_SIZE} sequences)")
print(BWD_HDR)
print("-" * len(BWD_HDR))
prev_model = None
for cfg in BWD_MODEL_CONFIGS:
if prev_model is not None and cfg.label != prev_model:
print()
prev_model = cfg.label
for sl_cfg in BWD_SEQLEN_CONFIGS:
try:
qla_bwd_ms, fla_bwd_ms = bench_bwd(
sl_cfg.seqlens,
cfg.h_qk,
cfg.h_v,
warmup=args.warmup,
repeats=args.repeats,
)
if math.isnan(qla_bwd_ms) and math.isnan(fla_bwd_ms):
speedup = " Skip "
else:
speedup = fmt_ratio(qla_bwd_ms, fla_bwd_ms)
print(
f"{cfg.label:<8} {sl_cfg.label:<15} "
f"{fmt_time(qla_bwd_ms)} {fmt_time(fla_bwd_ms)} {speedup} ",
flush=True,
)
except Exception as e:
print(
f"\n[CRITICAL ERROR] Case Crashed: Heads={cfg.label}, TotalSeqLen={sl_cfg.label}"
)
print(f"Exception: {e}")
cleanup_cuda()
continue
cleanup_cuda()
print("\nBenchmark Finished.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,214 @@
GPU: NVIDIA H200
Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128
Config: Warmup=10, Repeats=100
Library Versions:
torch: 2.9.1+cu128 | fla: 0.5.0 | flashinfer: 0.6.9 | tilelang: 0.1.8
==============================================================================================================
>>> FORWARD BENCHMARKS
Model Config Seqlens h_qk h_v flash_qla [fwd] FI [fwd] FLA [fwd] vs FLA vs FI
------------------------------------------------------------------------------------------------------------
397B/122B TP8 1x32768 2 8 0.310ms 1.657ms 0.910ms 2.93x 5.34x
397B/122B TP8 1x16384 2 8 0.184ms 0.832ms 0.465ms 2.53x 4.52x
397B/122B TP8 1x8192 2 8 0.122ms 0.420ms 0.578ms 4.74x 3.45x
397B/122B TP8 1x4096 2 8 0.107ms 0.215ms 0.277ms 2.59x 2.01x
397B/122B TP8 1x2048 2 8 0.106ms 0.114ms 0.261ms 2.46x 1.07x
397B/122B TP8 28672+4096 2 8 0.304ms 1.449ms 0.841ms 2.77x 4.77x
397B/122B TP8 24576+8192 2 8 0.300ms 1.243ms 0.766ms 2.55x 4.15x
397B/122B TP8 16384+16384 2 8 0.292ms 0.831ms 0.623ms 2.13x 2.84x
397B/122B TP8 8192+24576 2 8 0.298ms 1.241ms 0.766ms 2.57x 4.16x
397B/122B TP8 4096+28672 2 8 0.304ms 1.447ms 0.841ms 2.76x 4.75x
397B/122B TP8 12288+4096 2 8 0.180ms 0.626ms 0.391ms 2.17x 3.48x
397B/122B TP8 6144+2048 2 8 0.118ms 0.317ms 0.263ms 2.23x 2.70x
397B/122B TP8 4096+4096 2 8 0.115ms 0.215ms 0.262ms 2.27x 1.86x
397B/122B TP8 2048+6144 2 8 0.118ms 0.317ms 0.276ms 2.35x 2.70x
397B/122B TP8 1024+7168 2 8 0.120ms 0.368ms 0.263ms 2.19x 3.06x
397B/122B TP8 8192x4 2 8 0.274ms 0.401ms 0.477ms 1.74x 1.46x
397B/122B TP8 4096x8 2 8 0.203ms 0.205ms 0.481ms 2.37x 1.01x
397B/122B TP8 2048x4 2 8 0.109ms 0.110ms 0.268ms 2.46x 1.01x
397B/122B TP8 1024x8 2 8 0.065ms 0.063ms 0.273ms 4.17x 0.95x
397B/122B TP4 1x32768 4 16 0.485ms 1.655ms 1.250ms 2.58x 3.41x
397B/122B TP4 1x16384 4 16 0.292ms 0.832ms 0.630ms 2.16x 2.85x
397B/122B TP4 1x8192 4 16 0.178ms 0.420ms 0.325ms 1.83x 2.37x
397B/122B TP4 1x4096 4 16 0.115ms 0.215ms 0.290ms 2.51x 1.86x
397B/122B TP4 1x2048 4 16 0.108ms 0.114ms 0.275ms 2.54x 1.05x
397B/122B TP4 28672+4096 4 16 0.477ms 1.440ms 1.178ms 2.47x 3.02x
397B/122B TP4 24576+8192 4 16 0.474ms 1.222ms 1.100ms 2.32x 2.58x
397B/122B TP4 16384+16384 4 16 0.469ms 0.789ms 0.947ms 2.02x 1.68x
397B/122B TP4 8192+24576 4 16 0.470ms 1.192ms 1.099ms 2.34x 2.53x
397B/122B TP4 4096+28672 4 16 0.474ms 1.394ms 1.175ms 2.48x 2.94x
397B/122B TP4 12288+4096 4 16 0.281ms 0.616ms 0.549ms 1.95x 2.19x
397B/122B TP4 6144+2048 4 16 0.175ms 0.312ms 0.292ms 1.67x 1.78x
397B/122B TP4 4096+4096 4 16 0.173ms 0.207ms 0.263ms 1.52x 1.19x
397B/122B TP4 2048+6144 4 16 0.175ms 0.306ms 0.294ms 1.67x 1.75x
397B/122B TP4 1024+7168 4 16 0.177ms 0.359ms 0.307ms 1.74x 2.03x
397B/122B TP4 8192x4 4 16 0.388ms 0.400ms 0.949ms 2.45x 1.03x
397B/122B TP4 4096x8 4 16 0.311ms 0.210ms 0.954ms 3.07x 0.68x
397B/122B TP4 2048x4 4 16 0.114ms 0.110ms 0.264ms 2.32x 0.97x
397B/122B TP4 1024x8 4 16 0.097ms 0.066ms 0.280ms 2.88x 0.68x
397B/122B TP2 1x32768 8 32 0.866ms 1.567ms 1.894ms 2.19x 1.81x
397B/122B TP2 1x16384 8 32 0.468ms 0.789ms 0.926ms 1.98x 1.69x
397B/122B TP2 1x8192 8 32 0.275ms 0.401ms 0.479ms 1.74x 1.46x
397B/122B TP2 1x4096 8 32 0.175ms 0.207ms 0.273ms 1.56x 1.18x
397B/122B TP2 1x2048 8 32 0.116ms 0.110ms 0.308ms 2.66x 0.96x
397B/122B TP2 28672+4096 8 32 0.841ms 1.374ms 1.893ms 2.25x 1.63x
397B/122B TP2 24576+8192 8 32 0.846ms 1.178ms 1.891ms 2.24x 1.39x
397B/122B TP2 16384+16384 8 32 0.752ms 0.789ms 1.897ms 2.52x 1.05x
397B/122B TP2 8192+24576 8 32 0.842ms 1.182ms 1.893ms 2.25x 1.40x
397B/122B TP2 4096+28672 8 32 0.843ms 1.378ms 1.899ms 2.25x 1.63x
397B/122B TP2 12288+4096 8 32 0.468ms 0.597ms 0.928ms 1.98x 1.28x
397B/122B TP2 6144+2048 8 32 0.271ms 0.305ms 0.506ms 1.87x 1.13x
397B/122B TP2 4096+4096 8 32 0.206ms 0.206ms 0.482ms 2.34x 1.00x
397B/122B TP2 2048+6144 8 32 0.269ms 0.305ms 0.483ms 1.79x 1.13x
397B/122B TP2 1024+7168 8 32 0.277ms 0.354ms 0.480ms 1.73x 1.27x
397B/122B TP2 8192x4 8 32 0.596ms 0.405ms 1.894ms 3.18x 0.68x
397B/122B TP2 4096x8 8 32 0.601ms 0.409ms 1.904ms 3.17x 0.68x
397B/122B TP2 2048x4 8 32 0.172ms 0.115ms 0.486ms 2.83x 0.67x
397B/122B TP2 1024x8 8 32 0.180ms 0.124ms 0.504ms 2.80x 0.69x
397B/122B TP1 1x32768 16 64 1.494ms 1.563ms 3.336ms 2.23x 1.05x
397B/122B TP1 1x16384 16 64 0.762ms 0.790ms 1.608ms 2.11x 1.04x
397B/122B TP1 1x8192 16 64 0.393ms 0.397ms 0.815ms 2.07x 1.01x
397B/122B TP1 1x4096 16 64 0.209ms 0.206ms 0.430ms 2.06x 0.99x
397B/122B TP1 1x2048 16 64 0.117ms 0.111ms 0.276ms 2.37x 0.95x
397B/122B TP1 28672+4096 16 64 1.719ms 1.374ms 3.341ms 1.94x 0.80x
397B/122B TP1 24576+8192 16 64 1.530ms 1.179ms 3.341ms 2.18x 0.77x
397B/122B TP1 16384+16384 16 64 1.168ms 0.804ms 3.367ms 2.88x 0.69x
397B/122B TP1 8192+24576 16 64 1.539ms 1.232ms 3.352ms 2.18x 0.80x
397B/122B TP1 4096+28672 16 64 1.723ms 1.453ms 3.353ms 1.95x 0.84x
397B/122B TP1 12288+4096 16 64 0.781ms 0.596ms 1.606ms 2.06x 0.76x
397B/122B TP1 6144+2048 16 64 0.406ms 0.305ms 0.822ms 2.02x 0.75x
397B/122B TP1 4096+4096 16 64 0.318ms 0.209ms 0.823ms 2.59x 0.66x
397B/122B TP1 2048+6144 16 64 0.406ms 0.317ms 0.825ms 2.03x 0.78x
397B/122B TP1 1024+7168 16 64 0.454ms 0.371ms 0.825ms 1.82x 0.82x
397B/122B TP1 8192x4 16 64 1.170ms 0.800ms 3.361ms 2.87x 0.68x
397B/122B TP1 4096x8 16 64 1.191ms 0.819ms 3.411ms 2.86x 0.69x
397B/122B TP1 2048x4 16 64 0.324ms 0.219ms 0.837ms 2.58x 0.67x
397B/122B TP1 1024x8 16 64 0.344ms 0.239ms 0.868ms 2.52x 0.70x
35B/9B/4B TP1 1x32768 16 32 0.871ms 1.567ms 1.908ms 2.19x 1.80x
35B/9B/4B TP1 1x16384 16 32 0.472ms 0.790ms 0.945ms 2.00x 1.67x
35B/9B/4B TP1 1x8192 16 32 0.278ms 0.402ms 0.484ms 1.74x 1.44x
35B/9B/4B TP1 1x4096 16 32 0.176ms 0.207ms 0.266ms 1.51x 1.17x
35B/9B/4B TP1 1x2048 16 32 0.118ms 0.111ms 0.265ms 2.24x 0.94x
35B/9B/4B TP1 28672+4096 16 32 0.860ms 1.374ms 1.904ms 2.21x 1.60x
35B/9B/4B TP1 24576+8192 16 32 0.857ms 1.180ms 1.904ms 2.22x 1.38x
35B/9B/4B TP1 16384+16384 16 32 0.762ms 0.796ms 1.919ms 2.52x 1.05x
35B/9B/4B TP1 8192+24576 16 32 0.854ms 1.184ms 1.909ms 2.24x 1.39x
35B/9B/4B TP1 4096+28672 16 32 0.853ms 1.379ms 1.911ms 2.24x 1.62x
35B/9B/4B TP1 12288+4096 16 32 0.474ms 0.597ms 0.946ms 2.00x 1.26x
35B/9B/4B TP1 6144+2048 16 32 0.270ms 0.304ms 0.488ms 1.81x 1.13x
35B/9B/4B TP1 4096+4096 16 32 0.208ms 0.206ms 0.489ms 2.35x 0.99x
35B/9B/4B TP1 2048+6144 16 32 0.272ms 0.304ms 0.487ms 1.79x 1.12x
35B/9B/4B TP1 1024+7168 16 32 0.279ms 0.354ms 0.488ms 1.75x 1.27x
35B/9B/4B TP1 8192x4 16 32 0.596ms 0.406ms 1.903ms 3.19x 0.68x
35B/9B/4B TP1 4096x8 16 32 0.607ms 0.412ms 1.918ms 3.16x 0.68x
35B/9B/4B TP1 2048x4 16 32 0.171ms 0.116ms 0.495ms 2.89x 0.68x
35B/9B/4B TP1 1024x8 16 32 0.184ms 0.125ms 0.513ms 2.78x 0.68x
27B TP2 1x32768 8 24 0.657ms 1.631ms 1.576ms 2.40x 2.48x
27B TP2 1x16384 8 24 0.397ms 0.820ms 0.778ms 1.96x 2.06x
27B TP2 1x8192 8 24 0.262ms 0.416ms 0.402ms 1.53x 1.58x
27B TP2 1x4096 8 24 0.170ms 0.213ms 0.267ms 1.57x 1.25x
27B TP2 1x2048 8 24 0.111ms 0.113ms 0.266ms 2.40x 1.02x
27B TP2 28672+4096 8 24 0.655ms 1.422ms 1.495ms 2.28x 2.17x
27B TP2 24576+8192 8 24 0.651ms 1.211ms 1.425ms 2.19x 1.86x
27B TP2 16384+16384 8 24 0.657ms 0.791ms 1.579ms 2.40x 1.20x
27B TP2 8192+24576 8 24 0.656ms 1.185ms 1.576ms 2.40x 1.81x
27B TP2 4096+28672 8 24 0.655ms 1.382ms 1.577ms 2.41x 2.11x
27B TP2 12288+4096 8 24 0.392ms 0.611ms 0.707ms 1.80x 1.56x
27B TP2 6144+2048 8 24 0.258ms 0.310ms 0.372ms 1.44x 1.20x
27B TP2 4096+4096 8 24 0.189ms 0.206ms 0.406ms 2.15x 1.09x
27B TP2 2048+6144 8 24 0.261ms 0.303ms 0.406ms 1.56x 1.16x
27B TP2 1024+7168 8 24 0.261ms 0.352ms 0.407ms 1.56x 1.35x
27B TP2 8192x4 8 24 0.545ms 0.401ms 1.419ms 2.61x 0.74x
27B TP2 4096x8 8 24 0.546ms 0.410ms 1.430ms 2.62x 0.75x
27B TP2 2048x4 8 24 0.155ms 0.112ms 0.374ms 2.41x 0.73x
27B TP2 1024x8 8 24 0.164ms 0.120ms 0.383ms 2.34x 0.74x
27B TP1 1x32768 16 48 1.243ms 1.573ms 2.728ms 2.19x 1.27x
27B TP1 1x16384 16 48 0.664ms 0.792ms 1.327ms 2.00x 1.19x
27B TP1 1x8192 16 48 0.402ms 0.401ms 0.677ms 1.69x 1.00x
27B TP1 1x4096 16 48 0.193ms 0.206ms 0.355ms 1.84x 1.07x
27B TP1 1x2048 16 48 0.107ms 0.110ms 0.264ms 2.47x 1.03x
27B TP1 28672+4096 16 48 1.231ms 1.371ms 2.618ms 2.13x 1.11x
27B TP1 24576+8192 16 48 1.434ms 1.185ms 2.541ms 1.77x 0.83x
27B TP1 16384+16384 16 48 1.074ms 0.787ms 2.744ms 2.56x 0.73x
27B TP1 8192+24576 16 48 1.435ms 1.186ms 2.741ms 1.91x 0.83x
27B TP1 4096+28672 16 48 1.230ms 1.371ms 2.731ms 2.22x 1.11x
27B TP1 12288+4096 16 48 0.733ms 0.598ms 1.230ms 1.68x 0.82x
27B TP1 6144+2048 16 48 0.379ms 0.305ms 0.638ms 1.68x 0.80x
27B TP1 4096+4096 16 48 0.290ms 0.207ms 0.682ms 2.35x 0.71x
27B TP1 2048+6144 16 48 0.379ms 0.306ms 0.689ms 1.82x 0.81x
27B TP1 1024+7168 16 48 0.425ms 0.355ms 0.688ms 1.62x 0.84x
27B TP1 8192x4 16 48 1.068ms 0.804ms 2.530ms 2.37x 0.75x
27B TP1 4096x8 16 48 0.899ms 0.611ms 2.571ms 2.86x 0.68x
27B TP1 2048x4 16 48 0.296ms 0.218ms 0.645ms 2.18x 0.74x
27B TP1 1024x8 16 48 0.268ms 0.180ms 0.659ms 2.46x 0.67x
2B/0.8B TP1 1x32768 16 16 0.492ms 1.654ms 1.288ms 2.62x 3.36x
2B/0.8B TP1 1x16384 16 16 0.289ms 0.831ms 0.655ms 2.27x 2.88x
2B/0.8B TP1 1x8192 16 16 0.181ms 0.421ms 0.341ms 1.88x 2.32x
2B/0.8B TP1 1x4096 16 16 0.119ms 0.215ms 0.294ms 2.47x 1.81x
2B/0.8B TP1 1x2048 16 16 0.109ms 0.114ms 0.282ms 2.58x 1.05x
2B/0.8B TP1 28672+4096 16 16 0.484ms 1.440ms 1.215ms 2.51x 2.98x
2B/0.8B TP1 24576+8192 16 16 0.478ms 1.222ms 1.140ms 2.38x 2.55x
2B/0.8B TP1 16384+16384 16 16 0.474ms 0.789ms 0.982ms 2.07x 1.67x
2B/0.8B TP1 8192+24576 16 16 0.478ms 1.191ms 1.138ms 2.38x 2.49x
2B/0.8B TP1 4096+28672 16 16 0.481ms 1.392ms 1.217ms 2.53x 2.89x
2B/0.8B TP1 12288+4096 16 16 0.287ms 0.618ms 0.581ms 2.02x 2.15x
2B/0.8B TP1 6144+2048 16 16 0.179ms 0.312ms 0.306ms 1.71x 1.74x
2B/0.8B TP1 4096+4096 16 16 0.177ms 0.207ms 0.292ms 1.65x 1.17x
2B/0.8B TP1 2048+6144 16 16 0.180ms 0.306ms 0.304ms 1.69x 1.70x
2B/0.8B TP1 1024+7168 16 16 0.181ms 0.358ms 0.324ms 1.79x 1.98x
2B/0.8B TP1 8192x4 16 16 0.389ms 0.399ms 0.984ms 2.53x 1.03x
2B/0.8B TP1 4096x8 16 16 0.318ms 0.214ms 0.992ms 3.12x 0.67x
2B/0.8B TP1 2048x4 16 16 0.118ms 0.111ms 0.283ms 2.41x 0.95x
2B/0.8B TP1 1024x8 16 16 0.100ms 0.069ms 0.321ms 3.21x 0.69x
Sym h32 1x32768 32 32 0.874ms 1.567ms 1.961ms 2.24x 1.79x
Sym h32 1x16384 32 32 0.474ms 0.790ms 0.981ms 2.07x 1.66x
Sym h32 1x8192 32 32 0.281ms 0.402ms 0.508ms 1.81x 1.43x
Sym h32 1x4096 32 32 0.177ms 0.208ms 0.271ms 1.52x 1.17x
Sym h32 1x2048 32 32 0.117ms 0.111ms 0.296ms 2.54x 0.96x
Sym h32 28672+4096 32 32 0.861ms 1.376ms 1.964ms 2.28x 1.60x
Sym h32 24576+8192 32 32 0.857ms 1.181ms 1.964ms 2.29x 1.38x
Sym h32 16384+16384 32 32 0.771ms 0.808ms 1.974ms 2.56x 1.05x
Sym h32 8192+24576 32 32 0.855ms 1.183ms 1.966ms 2.30x 1.38x
Sym h32 4096+28672 32 32 0.857ms 1.379ms 1.965ms 2.29x 1.61x
Sym h32 12288+4096 32 32 0.475ms 0.597ms 0.985ms 2.07x 1.26x
Sym h32 6144+2048 32 32 0.270ms 0.305ms 0.512ms 1.90x 1.13x
Sym h32 4096+4096 32 32 0.215ms 0.210ms 0.513ms 2.38x 0.97x
Sym h32 2048+6144 32 32 0.271ms 0.304ms 0.510ms 1.88x 1.12x
Sym h32 1024+7168 32 32 0.281ms 0.353ms 0.510ms 1.81x 1.26x
Sym h32 8192x4 32 32 0.600ms 0.419ms 1.981ms 3.30x 0.70x
Sym h32 4096x8 32 32 0.620ms 0.426ms 1.994ms 3.22x 0.69x
Sym h32 2048x4 32 32 0.176ms 0.119ms 0.519ms 2.95x 0.67x
Sym h32 1024x8 32 32 0.183ms 0.128ms 0.535ms 2.93x 0.70x
==============================================================================================================
>>> BACKWARD BENCHMARKS (Split seq into 8 sequences)
Heads SeqLen flash_qla [bwd] FLA [bwd] Speedup
---------------------------------------------------------------
32 16k 1.262ms 2.433ms 1.93x
32 32k 2.743ms 4.904ms 1.79x
32 64k 6.554ms 9.247ms 1.41x
32 128k 9.489ms 18.195ms 1.92x
32 256k 20.939ms 38.613ms 1.84x
48 16k 1.775ms 3.549ms 2.00x
48 32k 3.715ms 7.197ms 1.94x
48 64k 7.565ms 13.864ms 1.83x
48 128k 13.777ms 28.175ms 2.05x
48 256k 31.321ms 56.344ms 1.80x
64 16k 2.213ms 4.654ms 2.10x
64 32k 4.513ms 9.237ms 2.05x
64 64k 8.024ms 18.265ms 2.28x
64 128k 15.299ms 36.572ms 2.39x
64 256k 34.310ms 74.392ms 2.17x
Benchmark Finished.