first commit
This commit is contained in:
585
benchmark/bench_gated_delta_rule.py
Normal file
585
benchmark/bench_gated_delta_rule.py
Normal file
@@ -0,0 +1,585 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
#
|
||||
# Benchmark Script for FlashQLA
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import gc
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Optional, Dict, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import tilelang
|
||||
|
||||
# Kernel Imports
|
||||
from fla.ops.gated_delta_rule.chunk import (
|
||||
chunk_gated_delta_rule_fwd as fla_fwd,
|
||||
chunk_gated_delta_rule_bwd as fla_bwd,
|
||||
)
|
||||
|
||||
from flash_qla import (
|
||||
chunk_gated_delta_rule_fwd as qla_fwd,
|
||||
chunk_gated_delta_rule_bwd as qla_bwd,
|
||||
)
|
||||
from flash_qla.utils import l2norm
|
||||
|
||||
try:
|
||||
from flashinfer.gdn_prefill import chunk_gated_delta_rule as fi_fwd
|
||||
|
||||
HAS_FI = True
|
||||
except ImportError:
|
||||
HAS_FI = False
|
||||
|
||||
# Default per-head feature dimension for q/k/v; used as the default
# ``head_dim`` of the tensor-preparation and benchmark functions.
HEAD_DIM = 128
# Number of sequences reported in the backward-benchmark table header.
BWD_SPLIT_SIZE = 8
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Head layout of one benchmarked model variant.

    ``label`` is the name printed in the result tables. ``h_qk`` and ``h_v``
    are forwarded to the benchmark functions, which use them as the head
    counts of the query/key tensors and of the value tensor respectively.
    """

    label: str  # row label shown in the benchmark tables
    h_qk: int  # number of query/key heads
    h_v: int  # number of value heads
|
||||
|
||||
|
||||
@dataclass
class SeqLenConfig:
    """One variable-length batch layout used as a benchmark workload.

    ``label`` is the name printed in the result tables; ``seqlens`` holds the
    length of each sequence in the packed batch.
    """

    label: str  # row label shown in the benchmark tables
    seqlens: List[int]  # per-sequence token counts, in batch order
|
||||
|
||||
|
||||
def generate_rand_seqlens(batch_size, num_tokens, device="cpu"):
    """Randomly split ``num_tokens`` tokens into ``batch_size`` non-empty sequences.

    ``batch_size - 1`` distinct cut points are drawn from ``[1, num_tokens - 1]``
    and the gaps between consecutive cuts become the sequence lengths.

    The sampling runs on ``device`` (CPU by default). This function is called
    while the module is being imported, so sampling on the CPU keeps the
    import from touching CUDA; the result is a plain Python list either way.

    Args:
        batch_size: number of sequences to produce (>= 1).
        num_tokens: total number of tokens to distribute (>= ``batch_size``).
        device: torch device used for the random draw.

    Returns:
        A list of ``batch_size`` positive ints that sum to ``num_tokens``.

    Raises:
        ValueError: if ``batch_size < 1`` or ``num_tokens < batch_size`` (a
            split into non-empty sequences is impossible).
    """
    if batch_size < 1 or num_tokens < batch_size:
        raise ValueError(
            f"cannot split {num_tokens} tokens into {batch_size} non-empty sequences"
        )

    # randperm yields distinct values in [0, num_tokens - 2]; the +1 shifts
    # them into [1, num_tokens - 1] so no sequence can be empty.
    cuts = torch.randperm(num_tokens - 1, device=device, dtype=torch.int32)[
        : batch_size - 1
    ]
    bars = torch.sort(cuts).values + 1

    # Pad with a leading 0 and a trailing slot that is set to num_tokens,
    # giving the cumulative boundaries of every sequence.
    cu_seqlens = torch.nn.functional.pad(bars, (1, 1))
    cu_seqlens[-1] = num_tokens

    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    return seqlens.tolist()
|
||||
|
||||
|
||||
# Head layouts exercised by the forward benchmark, as (label, h_qk, h_v).
FWD_MODEL_CONFIGS = [
    ModelConfig(label=name, h_qk=qk_heads, h_v=v_heads)
    for name, qk_heads, v_heads in (
        ("397B/122B TP8", 2, 8),
        ("397B/122B TP4", 4, 16),
        ("397B/122B TP2", 8, 32),
        ("397B/122B TP1", 16, 64),
        ("35B/9B/4B TP1", 16, 32),
        ("27B TP2", 8, 24),
        ("27B TP1", 16, 48),
        ("2B/0.8B TP1", 16, 16),
        ("Sym h32", 32, 32),
    )
]
|
||||
|
||||
# Packed-batch layouts exercised by the forward benchmark, as
# (label, per-sequence lengths): single sequences, uneven two-sequence
# splits, and uniform multi-sequence batches.
FWD_SEQLEN_CONFIGS = [
    SeqLenConfig(label=name, seqlens=lengths)
    for name, lengths in (
        ("1x32768", [32768]),
        ("1x16384", [16384]),
        ("1x8192", [8192]),
        ("1x4096", [4096]),
        ("1x2048", [2048]),
        ("28672+4096", [28672, 4096]),
        ("24576+8192", [24576, 8192]),
        ("16384+16384", [16384, 16384]),
        ("8192+24576", [8192, 24576]),
        ("4096+28672", [4096, 28672]),
        ("12288+4096", [12288, 4096]),
        ("6144+2048", [6144, 2048]),
        ("4096+4096", [4096, 4096]),
        ("2048+6144", [2048, 6144]),
        ("1024+7168", [1024, 7168]),
        ("8192x4", [8192] * 4),
        ("4096x8", [4096] * 8),
        ("2048x4", [2048] * 4),
        ("1024x8", [1024] * 8),
    )
]
|
||||
# Head counts exercised by the backward benchmark; the label is the head
# count itself and q/k and v use the same number of heads.
BWD_MODEL_CONFIGS = [
    ModelConfig(label=str(heads), h_qk=heads, h_v=heads) for heads in (32, 48, 64)
]
|
||||
|
||||
# Backward-benchmark workloads: each total token budget is randomly
# partitioned into BWD_SPLIT_SIZE variable-length sequences. Using the
# constant (rather than a literal 8) keeps the generated batches in step
# with the sequence count printed in the backward table header.
BWD_SEQLEN_CONFIGS = [
    SeqLenConfig(label, generate_rand_seqlens(BWD_SPLIT_SIZE, total_tokens))
    for label, total_tokens in (
        ("16k", 16384),
        ("32k", 32768),
        ("64k", 65536),
        ("128k", 131072),
        ("256k", 262144),
    )
]
|
||||
|
||||
|
||||
def cleanup_cuda():
    """Best-effort release of cached GPU memory between benchmark cases.

    Waits for outstanding CUDA work when CUDA is available, runs the Python
    garbage collector, then asks torch to empty its CUDA cache. Any
    exception raised along the way is deliberately swallowed so that a
    failed cleanup never aborts the benchmark run.
    """
    try:
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()
    except Exception:
        # Cleanup is opportunistic; the caller carries on regardless.
        return
|
||||
|
||||
|
||||
def get_lib_versions() -> Dict[str, str]:
    """Collect version strings for the libraries involved in the benchmark.

    Returns:
        A dict with the keys ``"torch"``, ``"fla"``, ``"flashinfer"`` and
        ``"tilelang"`` (in that order). Each value is the library's version,
        ``"Installed (ver unknown)"`` when the package imports but exposes no
        probed version attribute, or ``"Not Installed"`` when importing it
        raises ``ImportError``.
    """
    found: Dict[str, str] = {}

    try:
        found["torch"] = torch.__version__
    except Exception:
        found["torch"] = "N/A"

    # (package name, attributes probed in order). ``__version__`` is stored
    # as-is; the fallback ``version`` attribute is passed through ``str``.
    probes = (
        ("fla", ("__version__",)),
        ("flashinfer", ("__version__", "version")),
        ("tilelang", ("__version__", "version")),
    )
    for package, attrs in probes:
        try:
            module = __import__(package)
            for attr in attrs:
                if hasattr(module, attr):
                    value = getattr(module, attr)
                    found[package] = value if attr == "__version__" else str(value)
                    break
            else:
                found[package] = "Installed (ver unknown)"
        except ImportError:
            found[package] = "Not Installed"

    return found
|
||||
|
||||
|
||||
def prepare_tensors(
    seqlens: List[int], h_qk: int, h_v: int, head_dim: int = HEAD_DIM
) -> Optional[Dict[str, Any]]:
    """Allocate random CUDA inputs for one packed variable-length batch.

    All sequences are concatenated along the token axis of a batch-of-one
    tensor; ``cu_seqlens`` holds the cumulative sequence boundaries.

    Args:
        seqlens: length of every sequence in the batch.
        h_qk: number of heads of the query/key tensors.
        h_v: number of heads of the value, gate and beta tensors.
        head_dim: per-head feature dimension.

    Returns:
        A dict with the tensors ``q``, ``k``, ``v``, ``g``, ``beta``, ``h0``,
        ``do``, ``dht`` and ``cu_seqlens`` plus the scalars ``device``,
        ``num_seqs``, ``total_tokens`` and ``scale``; or ``None`` when a
        tensor allocation fails with an out-of-memory ``RuntimeError``.

    Raises:
        RuntimeError: any allocation failure that is not an out-of-memory
            error is re-raised.
    """
    device = "cuda"
    num_seqs = len(seqlens)
    total_tokens = sum(seqlens)
    scale = head_dim ** (-0.5)

    # Cumulative boundaries: [0, len0, len0 + len1, ...].
    running = 0
    boundaries = [0]
    for length in seqlens:
        running += length
        boundaries.append(running)
    cu_seqlens = torch.tensor(boundaries, dtype=torch.int32, device=device)

    qk_shape = (1, total_tokens, h_qk, head_dim)
    v_shape = (1, total_tokens, h_v, head_dim)
    gate_shape = (1, total_tokens, h_v)
    state_shape = (num_seqs, h_v, head_dim, head_dim)

    try:
        # The draw order below is kept fixed so the random stream is
        # consumed exactly as before.
        q = l2norm(torch.randn(qk_shape, device=device, dtype=torch.bfloat16))
        k = l2norm(torch.randn(qk_shape, device=device, dtype=torch.bfloat16))
        v = torch.randn(v_shape, device=device, dtype=torch.bfloat16)
        # Log-sigmoid keeps the gate non-positive; it is then scaled by 1/16.
        g = F.logsigmoid(torch.randn(gate_shape, device=device, dtype=torch.float32)) / 16
        beta = torch.randn(gate_shape, device=device, dtype=torch.float32).sigmoid()
        h0 = torch.randn(state_shape, device=device, dtype=torch.float32)
        do = torch.randn_like(v)
        dht = torch.randn(state_shape, device=device, dtype=torch.float32) / 8
    except RuntimeError as err:
        if "out of memory" not in str(err).lower():
            raise err
        return None

    # Keep the gate on a randomly chosen ceil(75%) of the value heads and
    # zero it on the remaining heads.
    swa_ratio = 0.75
    keep_gate = torch.zeros(h_v, dtype=torch.bool, device=device)
    keep_gate[: math.ceil(swa_ratio * h_v)] = True
    keep_gate = keep_gate[torch.randperm(h_v, device=device)]
    g[:, :, ~keep_gate] = 0.0

    return {
        "device": device,
        "num_seqs": num_seqs,
        "total_tokens": total_tokens,
        "scale": scale,
        "cu_seqlens": cu_seqlens,
        "q": q,
        "k": k,
        "v": v,
        "g": g,
        "beta": beta,
        "h0": h0,
        "do": do,
        "dht": dht,
    }
|
||||
|
||||
|
||||
def bench_fwd(
    seqlens: List[int],
    h_qk: int,
    h_v: int,
    head_dim: int = HEAD_DIM,
    warmup: int = 10,
    repeats: int = 5,
) -> Tuple[float, float, float]:
    """Time the forward kernels of FlashQLA, FlashInfer and FLA on one workload.

    Args:
        seqlens: length of every sequence in the packed batch.
        h_qk: number of query/key heads.
        h_v: number of value heads.
        head_dim: per-head feature dimension.
        warmup: forwarded to ``tilelang.profiler.do_bench`` as ``warmup``.
        repeats: forwarded to ``tilelang.profiler.do_bench`` as ``rep``.
            NOTE(review): confirm whether do_bench treats these two values
            as iteration counts or as time budgets.

    Returns:
        ``(qla, fi, fla)`` timings as returned by ``do_bench`` (callers print
        them as milliseconds). An entry is NaN when that kernel raised
        ``RuntimeError``, when FlashInfer is disabled or unavailable (``fi``
        only), or, for all three, when the input tensors could not be
        allocated.
    """
    nan = float("nan")

    cleanup_cuda()
    data = prepare_tensors(seqlens, h_qk, h_v, head_dim)
    if data is None:
        return nan, nan, nan

    q, k, v = data["q"], data["k"], data["v"]
    g, beta = data["g"], data["beta"]
    h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]

    def timed(tag, kernel_call):
        # Benchmark one kernel; a RuntimeError is reported and mapped to NaN
        # so the remaining kernels still get measured.
        try:
            return tilelang.profiler.do_bench(kernel_call, warmup=warmup, rep=repeats)
        except RuntimeError as err:
            print(f"\n[WARN] {tag} Fwd failed: {err}")
            cleanup_cuda()
            return nan

    def call_qla_fwd():
        qla_fwd(
            q,
            k,
            v,
            g,
            beta,
            scale=scale,
            initial_state=h0,
            output_final_state=True,
            output_h=False,
            cu_seqlens=cu_seqlens,
            auto_cp=True,
        )

    def call_fi_fwd():
        # FlashInfer is fed the packed tensors without the leading batch axis.
        fi_fwd(
            q=q.view(-1, h_qk, head_dim),
            k=k.view(-1, h_qk, head_dim),
            v=v.view(-1, h_v, head_dim),
            g=g.view(-1, h_v),
            beta=beta.view(-1, h_v),
            scale=scale,
            initial_state=h0,
            cu_seqlens=cu_seqlens,
            output_final_state=True,
        )

    def call_fla_fwd():
        fla_fwd(
            q,
            k,
            v,
            g,
            beta,
            scale=scale,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens,
        )

    qla_ms = timed("FlashQLA", call_qla_fwd)
    if HAS_FI:
        fi_ms = timed("FI", call_fi_fwd)
    else:
        fi_ms = nan
    fla_ms = timed("FLA", call_fla_fwd)

    # Drain any outstanding work before returning; failures here are ignored.
    try:
        torch.cuda.synchronize()
    except Exception:
        pass

    return qla_ms, fi_ms, fla_ms
|
||||
|
||||
|
||||
def bench_bwd(
    seqlens: List[int],
    h_qk: int,
    h_v: int,
    head_dim: int = HEAD_DIM,
    warmup: int = 10,
    repeats: int = 100,
) -> Tuple[float, float]:
    """Time the backward kernels of FlashQLA and FLA on one workload.

    A FlashQLA forward pass is run once beforehand; the first two elements
    of its result are taken as ``g_cumsum`` and ``A`` and fed to both
    backward kernels.

    Args:
        seqlens: length of every sequence in the packed batch.
        h_qk: head count used for q, k *and* v in this benchmark.
        h_v: accepted for signature symmetry with the forward benchmark; the
            value is not used.
        head_dim: per-head feature dimension.
        warmup: forwarded to ``tilelang.profiler.do_bench`` as ``warmup``.
        repeats: forwarded to ``tilelang.profiler.do_bench`` as ``rep``.
            NOTE(review): confirm whether do_bench treats these two values
            as iteration counts or as time budgets.

    Returns:
        ``(qla, fla)`` timings as returned by ``do_bench`` (callers print
        them as milliseconds). An entry is NaN when that kernel raised
        ``RuntimeError``; both are NaN when tensor allocation or the
        preparatory forward pass failed.
    """
    nan = float("nan")

    unified_h = h_qk
    cleanup_cuda()

    data = prepare_tensors(seqlens, unified_h, unified_h, head_dim)
    if data is None:
        return nan, nan

    q, k, v = data["q"], data["k"], data["v"]
    g, beta = data["g"], data["beta"]
    h0, scale, cu_seqlens = data["h0"], data["scale"], data["cu_seqlens"]
    do, dht = data["do"], data["dht"]

    # Run the forward pass once to obtain the intermediates the backward
    # kernels consume.
    try:
        fwd_out = qla_fwd(
            q,
            k,
            v,
            g,
            beta,
            scale=scale,
            initial_state=h0,
            output_final_state=True,
            output_h=False,
            cu_seqlens=cu_seqlens,
            auto_cp=True,
        )
        if not (isinstance(fwd_out, tuple) and len(fwd_out) >= 2):
            raise RuntimeError("FlashQLA FWD did not return expected intermediates")
        g_cumsum, A = fwd_out[0], fwd_out[1]
    except RuntimeError as err:
        print(f"[FWD Error] Failed at seqlens={seqlens}, heads={h_qk}. Error: {err}")
        cleanup_cuda()
        return nan, nan

    def timed(tag, kernel_call):
        # Benchmark one kernel; a RuntimeError is reported and mapped to NaN
        # so the other kernel still gets measured.
        try:
            return tilelang.profiler.do_bench(kernel_call, warmup=warmup, rep=repeats)
        except RuntimeError as err:
            print(f"\n[WARN] {tag} Bwd failed: {err}")
            cleanup_cuda()
            return nan

    def call_qla_bwd():
        return qla_bwd(
            q,
            k,
            v,
            g_cumsum,
            beta,
            A,
            do,
            dht,
            scale=scale,
            initial_state=h0,
            cu_seqlens=cu_seqlens,
        )

    def call_fla_bwd():
        return fla_bwd(q, k, v, g_cumsum, beta, A, scale, h0, do, dht, cu_seqlens)

    qla_ms = timed("FlashQLA", call_qla_bwd)
    fla_ms = timed("FLA", call_fla_bwd)
    return qla_ms, fla_ms
|
||||
|
||||
|
||||
# Column headers of the forward table: (title, format spec) per column,
# joined by single spaces.
FWD_HDR = " ".join(
    f"{title:{spec}}"
    for title, spec in (
        ("Model Config", "<16"),
        ("Seqlens", "<17"),
        ("h_qk", ">5"),
        ("h_v", ">5"),
        ("flash_qla [fwd]", ">10"),
        ("FI [fwd]", ">10"),
        ("FLA [fwd]", ">10"),
        ("vs FLA", ">7"),
        ("vs FI", ">7"),
    )
)

# Column headers of the backward table, built the same way.
BWD_HDR = " ".join(
    f"{title:{spec}}"
    for title, spec in (
        ("Heads", "<8"),
        ("SeqLen", "<15"),
        ("flash_qla [bwd]", ">10"),
        ("FLA [bwd]", ">10"),
        ("Speedup", ">8"),
    )
)
|
||||
|
||||
|
||||
def fmt_time(ms: float) -> str:
    """Format a latency in milliseconds as a fixed-width table cell.

    Args:
        ms: latency in milliseconds, or NaN when no measurement exists.

    Returns:
        A 10-character string: the value with three decimals followed by
        ``"ms"``, or a centred ``"N/A"`` for NaN. Both forms share one width
        so a missing measurement does not shift the columns after it.
    """
    if math.isnan(ms):
        return f"{'N/A':^10}"
    return f"{ms:>8.3f}ms"
|
||||
|
||||
|
||||
def fmt_ratio(base: float, other: float) -> str:
    """Format ``other / base`` as a fixed-width speed-up cell.

    Args:
        base: the reference timing (denominator).
        other: the timing being compared against it (numerator).

    Returns:
        A 7-character string: the ratio with two decimals followed by
        ``"x"``, or a centred ``"N/A"`` when either timing is NaN or ``base``
        is zero. Both forms share one width so a missing ratio does not
        shift the columns after it.
    """
    if math.isnan(base) or math.isnan(other) or base == 0:
        return f"{'N/A':^7}"
    return f"{other / base:>6.2f}x"
|
||||
|
||||
|
||||
def _print_banner(args):
    """Print the GPU name, the run configuration and the library versions."""
    gpu_name = torch.cuda.get_device_properties(0).name
    print(f"GPU: {gpu_name}")
    print("Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128")
    print(f"Config: Warmup={args.warmup}, Repeats={args.repeats}")

    libs = get_lib_versions()
    print("Library Versions:")
    ver_str = " | ".join(f"{k}: {v}" for k, v in libs.items())
    print(f" {ver_str}")

    print("=" * 110)


def _run_forward_benchmarks(args):
    """Print the forward table: one row per (model config, seqlen config)."""
    print("\n>>> FORWARD BENCHMARKS")
    print(FWD_HDR)
    print("-" * len(FWD_HDR))

    prev_model = None
    for cfg in FWD_MODEL_CONFIGS:
        # Blank line between groups of rows that belong to different models.
        if prev_model is not None and cfg.label != prev_model:
            print()
        prev_model = cfg.label

        for sl_cfg in FWD_SEQLEN_CONFIGS:
            try:
                qla_ms, fi_ms, fla_ms = bench_fwd(
                    sl_cfg.seqlens,
                    cfg.h_qk,
                    cfg.h_v,
                    warmup=args.warmup,
                    repeats=args.repeats,
                )

                if math.isnan(qla_ms) and math.isnan(fla_ms):
                    cleanup_cuda()

                ratio_fla = fmt_ratio(qla_ms, fla_ms)
                ratio_fi = fmt_ratio(qla_ms, fi_ms)

                print(
                    f"{cfg.label:<16} {sl_cfg.label:<17} {cfg.h_qk:>5} {cfg.h_v:>5} "
                    f"{fmt_time(qla_ms)} {fmt_time(fi_ms)} {fmt_time(fla_ms)} "
                    f"{ratio_fla} {ratio_fi}",
                    flush=True,
                )
            except Exception as e:
                # Top-level boundary: report the failing case and move on so
                # one bad configuration does not abort the whole sweep.
                print(f"\n[ERROR] Forward Case Failed: {cfg.label} / {sl_cfg.label}")
                print(f"Exception: {e}")
                cleanup_cuda()
                continue

    cleanup_cuda()


def _run_backward_benchmarks(args):
    """Print the backward table: one row per (head count, seqlen config)."""
    print("\n" + "=" * 110)
    print(f"\n>>> BACKWARD BENCHMARKS (Split seq into {BWD_SPLIT_SIZE} sequences)")
    print(BWD_HDR)
    print("-" * len(BWD_HDR))

    prev_model = None
    for cfg in BWD_MODEL_CONFIGS:
        # Blank line between groups of rows that belong to different models.
        if prev_model is not None and cfg.label != prev_model:
            print()
        prev_model = cfg.label

        for sl_cfg in BWD_SEQLEN_CONFIGS:
            try:
                qla_bwd_ms, fla_bwd_ms = bench_bwd(
                    sl_cfg.seqlens,
                    cfg.h_qk,
                    cfg.h_v,
                    warmup=args.warmup,
                    repeats=args.repeats,
                )

                if math.isnan(qla_bwd_ms) and math.isnan(fla_bwd_ms):
                    speedup = " Skip "
                else:
                    speedup = fmt_ratio(qla_bwd_ms, fla_bwd_ms)

                print(
                    f"{cfg.label:<8} {sl_cfg.label:<15} "
                    f"{fmt_time(qla_bwd_ms)} {fmt_time(fla_bwd_ms)} {speedup} ",
                    flush=True,
                )
            except Exception as e:
                # Top-level boundary: report the failing case and move on.
                print(
                    f"\n[CRITICAL ERROR] Case Crashed: Heads={cfg.label}, TotalSeqLen={sl_cfg.label}"
                )
                print(f"Exception: {e}")
                cleanup_cuda()
                continue

    cleanup_cuda()


def main():
    """Command-line entry point of the benchmark script.

    Options:
        --warmup / --repeats: forwarded to every benchmark call.
        --mode: ``fwd``, ``bwd`` or ``all`` (default) selects which tables run.
        --skip-fi: disable the FlashInfer column even if FlashInfer imports.

    Prints a message and returns without benchmarking when CUDA is not
    available.
    """
    parser = argparse.ArgumentParser(description="Benchmark FlashQLA Gated Delta Rule")
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--repeats", type=int, default=100)
    parser.add_argument("--mode", choices=["fwd", "bwd", "all"], default="all")
    parser.add_argument("--skip-fi", action="store_true")
    args = parser.parse_args()

    if not torch.cuda.is_available():
        print("CUDA not available.")
        return

    # bench_fwd reads this module-level flag to decide whether to time
    # FlashInfer, so the CLI switch has to override the global.
    global HAS_FI
    if args.skip_fi:
        HAS_FI = False

    _print_banner(args)

    if args.mode in ("fwd", "all"):
        _run_forward_benchmarks(args)

    if args.mode in ("bwd", "all"):
        _run_backward_benchmarks(args)

    print("\nBenchmark Finished.")


if __name__ == "__main__":
    main()
|
||||
214
benchmark/benchmark_results_H200.txt
Normal file
214
benchmark/benchmark_results_H200.txt
Normal file
@@ -0,0 +1,214 @@
|
||||
GPU: NVIDIA H200
|
||||
Models: Qwen3.5 family (397B, 122B, 35B, 27B, 9B, 4B, 2B, 0.8B), d=128
|
||||
Config: Warmup=10, Repeats=100
|
||||
Library Versions:
|
||||
torch: 2.9.1+cu128 | fla: 0.5.0 | flashinfer: 0.6.9 | tilelang: 0.1.8
|
||||
==============================================================================================================
|
||||
|
||||
>>> FORWARD BENCHMARKS
|
||||
Model Config Seqlens h_qk h_v flash_qla [fwd] FI [fwd] FLA [fwd] vs FLA vs FI
|
||||
------------------------------------------------------------------------------------------------------------
|
||||
397B/122B TP8 1x32768 2 8 0.310ms 1.657ms 0.910ms 2.93x 5.34x
|
||||
397B/122B TP8 1x16384 2 8 0.184ms 0.832ms 0.465ms 2.53x 4.52x
|
||||
397B/122B TP8 1x8192 2 8 0.122ms 0.420ms 0.578ms 4.74x 3.45x
|
||||
397B/122B TP8 1x4096 2 8 0.107ms 0.215ms 0.277ms 2.59x 2.01x
|
||||
397B/122B TP8 1x2048 2 8 0.106ms 0.114ms 0.261ms 2.46x 1.07x
|
||||
397B/122B TP8 28672+4096 2 8 0.304ms 1.449ms 0.841ms 2.77x 4.77x
|
||||
397B/122B TP8 24576+8192 2 8 0.300ms 1.243ms 0.766ms 2.55x 4.15x
|
||||
397B/122B TP8 16384+16384 2 8 0.292ms 0.831ms 0.623ms 2.13x 2.84x
|
||||
397B/122B TP8 8192+24576 2 8 0.298ms 1.241ms 0.766ms 2.57x 4.16x
|
||||
397B/122B TP8 4096+28672 2 8 0.304ms 1.447ms 0.841ms 2.76x 4.75x
|
||||
397B/122B TP8 12288+4096 2 8 0.180ms 0.626ms 0.391ms 2.17x 3.48x
|
||||
397B/122B TP8 6144+2048 2 8 0.118ms 0.317ms 0.263ms 2.23x 2.70x
|
||||
397B/122B TP8 4096+4096 2 8 0.115ms 0.215ms 0.262ms 2.27x 1.86x
|
||||
397B/122B TP8 2048+6144 2 8 0.118ms 0.317ms 0.276ms 2.35x 2.70x
|
||||
397B/122B TP8 1024+7168 2 8 0.120ms 0.368ms 0.263ms 2.19x 3.06x
|
||||
397B/122B TP8 8192x4 2 8 0.274ms 0.401ms 0.477ms 1.74x 1.46x
|
||||
397B/122B TP8 4096x8 2 8 0.203ms 0.205ms 0.481ms 2.37x 1.01x
|
||||
397B/122B TP8 2048x4 2 8 0.109ms 0.110ms 0.268ms 2.46x 1.01x
|
||||
397B/122B TP8 1024x8 2 8 0.065ms 0.063ms 0.273ms 4.17x 0.95x
|
||||
|
||||
397B/122B TP4 1x32768 4 16 0.485ms 1.655ms 1.250ms 2.58x 3.41x
|
||||
397B/122B TP4 1x16384 4 16 0.292ms 0.832ms 0.630ms 2.16x 2.85x
|
||||
397B/122B TP4 1x8192 4 16 0.178ms 0.420ms 0.325ms 1.83x 2.37x
|
||||
397B/122B TP4 1x4096 4 16 0.115ms 0.215ms 0.290ms 2.51x 1.86x
|
||||
397B/122B TP4 1x2048 4 16 0.108ms 0.114ms 0.275ms 2.54x 1.05x
|
||||
397B/122B TP4 28672+4096 4 16 0.477ms 1.440ms 1.178ms 2.47x 3.02x
|
||||
397B/122B TP4 24576+8192 4 16 0.474ms 1.222ms 1.100ms 2.32x 2.58x
|
||||
397B/122B TP4 16384+16384 4 16 0.469ms 0.789ms 0.947ms 2.02x 1.68x
|
||||
397B/122B TP4 8192+24576 4 16 0.470ms 1.192ms 1.099ms 2.34x 2.53x
|
||||
397B/122B TP4 4096+28672 4 16 0.474ms 1.394ms 1.175ms 2.48x 2.94x
|
||||
397B/122B TP4 12288+4096 4 16 0.281ms 0.616ms 0.549ms 1.95x 2.19x
|
||||
397B/122B TP4 6144+2048 4 16 0.175ms 0.312ms 0.292ms 1.67x 1.78x
|
||||
397B/122B TP4 4096+4096 4 16 0.173ms 0.207ms 0.263ms 1.52x 1.19x
|
||||
397B/122B TP4 2048+6144 4 16 0.175ms 0.306ms 0.294ms 1.67x 1.75x
|
||||
397B/122B TP4 1024+7168 4 16 0.177ms 0.359ms 0.307ms 1.74x 2.03x
|
||||
397B/122B TP4 8192x4 4 16 0.388ms 0.400ms 0.949ms 2.45x 1.03x
|
||||
397B/122B TP4 4096x8 4 16 0.311ms 0.210ms 0.954ms 3.07x 0.68x
|
||||
397B/122B TP4 2048x4 4 16 0.114ms 0.110ms 0.264ms 2.32x 0.97x
|
||||
397B/122B TP4 1024x8 4 16 0.097ms 0.066ms 0.280ms 2.88x 0.68x
|
||||
|
||||
397B/122B TP2 1x32768 8 32 0.866ms 1.567ms 1.894ms 2.19x 1.81x
|
||||
397B/122B TP2 1x16384 8 32 0.468ms 0.789ms 0.926ms 1.98x 1.69x
|
||||
397B/122B TP2 1x8192 8 32 0.275ms 0.401ms 0.479ms 1.74x 1.46x
|
||||
397B/122B TP2 1x4096 8 32 0.175ms 0.207ms 0.273ms 1.56x 1.18x
|
||||
397B/122B TP2 1x2048 8 32 0.116ms 0.110ms 0.308ms 2.66x 0.96x
|
||||
397B/122B TP2 28672+4096 8 32 0.841ms 1.374ms 1.893ms 2.25x 1.63x
|
||||
397B/122B TP2 24576+8192 8 32 0.846ms 1.178ms 1.891ms 2.24x 1.39x
|
||||
397B/122B TP2 16384+16384 8 32 0.752ms 0.789ms 1.897ms 2.52x 1.05x
|
||||
397B/122B TP2 8192+24576 8 32 0.842ms 1.182ms 1.893ms 2.25x 1.40x
|
||||
397B/122B TP2 4096+28672 8 32 0.843ms 1.378ms 1.899ms 2.25x 1.63x
|
||||
397B/122B TP2 12288+4096 8 32 0.468ms 0.597ms 0.928ms 1.98x 1.28x
|
||||
397B/122B TP2 6144+2048 8 32 0.271ms 0.305ms 0.506ms 1.87x 1.13x
|
||||
397B/122B TP2 4096+4096 8 32 0.206ms 0.206ms 0.482ms 2.34x 1.00x
|
||||
397B/122B TP2 2048+6144 8 32 0.269ms 0.305ms 0.483ms 1.79x 1.13x
|
||||
397B/122B TP2 1024+7168 8 32 0.277ms 0.354ms 0.480ms 1.73x 1.27x
|
||||
397B/122B TP2 8192x4 8 32 0.596ms 0.405ms 1.894ms 3.18x 0.68x
|
||||
397B/122B TP2 4096x8 8 32 0.601ms 0.409ms 1.904ms 3.17x 0.68x
|
||||
397B/122B TP2 2048x4 8 32 0.172ms 0.115ms 0.486ms 2.83x 0.67x
|
||||
397B/122B TP2 1024x8 8 32 0.180ms 0.124ms 0.504ms 2.80x 0.69x
|
||||
|
||||
397B/122B TP1 1x32768 16 64 1.494ms 1.563ms 3.336ms 2.23x 1.05x
|
||||
397B/122B TP1 1x16384 16 64 0.762ms 0.790ms 1.608ms 2.11x 1.04x
|
||||
397B/122B TP1 1x8192 16 64 0.393ms 0.397ms 0.815ms 2.07x 1.01x
|
||||
397B/122B TP1 1x4096 16 64 0.209ms 0.206ms 0.430ms 2.06x 0.99x
|
||||
397B/122B TP1 1x2048 16 64 0.117ms 0.111ms 0.276ms 2.37x 0.95x
|
||||
397B/122B TP1 28672+4096 16 64 1.719ms 1.374ms 3.341ms 1.94x 0.80x
|
||||
397B/122B TP1 24576+8192 16 64 1.530ms 1.179ms 3.341ms 2.18x 0.77x
|
||||
397B/122B TP1 16384+16384 16 64 1.168ms 0.804ms 3.367ms 2.88x 0.69x
|
||||
397B/122B TP1 8192+24576 16 64 1.539ms 1.232ms 3.352ms 2.18x 0.80x
|
||||
397B/122B TP1 4096+28672 16 64 1.723ms 1.453ms 3.353ms 1.95x 0.84x
|
||||
397B/122B TP1 12288+4096 16 64 0.781ms 0.596ms 1.606ms 2.06x 0.76x
|
||||
397B/122B TP1 6144+2048 16 64 0.406ms 0.305ms 0.822ms 2.02x 0.75x
|
||||
397B/122B TP1 4096+4096 16 64 0.318ms 0.209ms 0.823ms 2.59x 0.66x
|
||||
397B/122B TP1 2048+6144 16 64 0.406ms 0.317ms 0.825ms 2.03x 0.78x
|
||||
397B/122B TP1 1024+7168 16 64 0.454ms 0.371ms 0.825ms 1.82x 0.82x
|
||||
397B/122B TP1 8192x4 16 64 1.170ms 0.800ms 3.361ms 2.87x 0.68x
|
||||
397B/122B TP1 4096x8 16 64 1.191ms 0.819ms 3.411ms 2.86x 0.69x
|
||||
397B/122B TP1 2048x4 16 64 0.324ms 0.219ms 0.837ms 2.58x 0.67x
|
||||
397B/122B TP1 1024x8 16 64 0.344ms 0.239ms 0.868ms 2.52x 0.70x
|
||||
|
||||
35B/9B/4B TP1 1x32768 16 32 0.871ms 1.567ms 1.908ms 2.19x 1.80x
|
||||
35B/9B/4B TP1 1x16384 16 32 0.472ms 0.790ms 0.945ms 2.00x 1.67x
|
||||
35B/9B/4B TP1 1x8192 16 32 0.278ms 0.402ms 0.484ms 1.74x 1.44x
|
||||
35B/9B/4B TP1 1x4096 16 32 0.176ms 0.207ms 0.266ms 1.51x 1.17x
|
||||
35B/9B/4B TP1 1x2048 16 32 0.118ms 0.111ms 0.265ms 2.24x 0.94x
|
||||
35B/9B/4B TP1 28672+4096 16 32 0.860ms 1.374ms 1.904ms 2.21x 1.60x
|
||||
35B/9B/4B TP1 24576+8192 16 32 0.857ms 1.180ms 1.904ms 2.22x 1.38x
|
||||
35B/9B/4B TP1 16384+16384 16 32 0.762ms 0.796ms 1.919ms 2.52x 1.05x
|
||||
35B/9B/4B TP1 8192+24576 16 32 0.854ms 1.184ms 1.909ms 2.24x 1.39x
|
||||
35B/9B/4B TP1 4096+28672 16 32 0.853ms 1.379ms 1.911ms 2.24x 1.62x
|
||||
35B/9B/4B TP1 12288+4096 16 32 0.474ms 0.597ms 0.946ms 2.00x 1.26x
|
||||
35B/9B/4B TP1 6144+2048 16 32 0.270ms 0.304ms 0.488ms 1.81x 1.13x
|
||||
35B/9B/4B TP1 4096+4096 16 32 0.208ms 0.206ms 0.489ms 2.35x 0.99x
|
||||
35B/9B/4B TP1 2048+6144 16 32 0.272ms 0.304ms 0.487ms 1.79x 1.12x
|
||||
35B/9B/4B TP1 1024+7168 16 32 0.279ms 0.354ms 0.488ms 1.75x 1.27x
|
||||
35B/9B/4B TP1 8192x4 16 32 0.596ms 0.406ms 1.903ms 3.19x 0.68x
|
||||
35B/9B/4B TP1 4096x8 16 32 0.607ms 0.412ms 1.918ms 3.16x 0.68x
|
||||
35B/9B/4B TP1 2048x4 16 32 0.171ms 0.116ms 0.495ms 2.89x 0.68x
|
||||
35B/9B/4B TP1 1024x8 16 32 0.184ms 0.125ms 0.513ms 2.78x 0.68x
|
||||
|
||||
27B TP2 1x32768 8 24 0.657ms 1.631ms 1.576ms 2.40x 2.48x
|
||||
27B TP2 1x16384 8 24 0.397ms 0.820ms 0.778ms 1.96x 2.06x
|
||||
27B TP2 1x8192 8 24 0.262ms 0.416ms 0.402ms 1.53x 1.58x
|
||||
27B TP2 1x4096 8 24 0.170ms 0.213ms 0.267ms 1.57x 1.25x
|
||||
27B TP2 1x2048 8 24 0.111ms 0.113ms 0.266ms 2.40x 1.02x
|
||||
27B TP2 28672+4096 8 24 0.655ms 1.422ms 1.495ms 2.28x 2.17x
|
||||
27B TP2 24576+8192 8 24 0.651ms 1.211ms 1.425ms 2.19x 1.86x
|
||||
27B TP2 16384+16384 8 24 0.657ms 0.791ms 1.579ms 2.40x 1.20x
|
||||
27B TP2 8192+24576 8 24 0.656ms 1.185ms 1.576ms 2.40x 1.81x
|
||||
27B TP2 4096+28672 8 24 0.655ms 1.382ms 1.577ms 2.41x 2.11x
|
||||
27B TP2 12288+4096 8 24 0.392ms 0.611ms 0.707ms 1.80x 1.56x
|
||||
27B TP2 6144+2048 8 24 0.258ms 0.310ms 0.372ms 1.44x 1.20x
|
||||
27B TP2 4096+4096 8 24 0.189ms 0.206ms 0.406ms 2.15x 1.09x
|
||||
27B TP2 2048+6144 8 24 0.261ms 0.303ms 0.406ms 1.56x 1.16x
|
||||
27B TP2 1024+7168 8 24 0.261ms 0.352ms 0.407ms 1.56x 1.35x
|
||||
27B TP2 8192x4 8 24 0.545ms 0.401ms 1.419ms 2.61x 0.74x
|
||||
27B TP2 4096x8 8 24 0.546ms 0.410ms 1.430ms 2.62x 0.75x
|
||||
27B TP2 2048x4 8 24 0.155ms 0.112ms 0.374ms 2.41x 0.73x
|
||||
27B TP2 1024x8 8 24 0.164ms 0.120ms 0.383ms 2.34x 0.74x
|
||||
|
||||
27B TP1 1x32768 16 48 1.243ms 1.573ms 2.728ms 2.19x 1.27x
|
||||
27B TP1 1x16384 16 48 0.664ms 0.792ms 1.327ms 2.00x 1.19x
|
||||
27B TP1 1x8192 16 48 0.402ms 0.401ms 0.677ms 1.69x 1.00x
|
||||
27B TP1 1x4096 16 48 0.193ms 0.206ms 0.355ms 1.84x 1.07x
|
||||
27B TP1 1x2048 16 48 0.107ms 0.110ms 0.264ms 2.47x 1.03x
|
||||
27B TP1 28672+4096 16 48 1.231ms 1.371ms 2.618ms 2.13x 1.11x
|
||||
27B TP1 24576+8192 16 48 1.434ms 1.185ms 2.541ms 1.77x 0.83x
|
||||
27B TP1 16384+16384 16 48 1.074ms 0.787ms 2.744ms 2.56x 0.73x
|
||||
27B TP1 8192+24576 16 48 1.435ms 1.186ms 2.741ms 1.91x 0.83x
|
||||
27B TP1 4096+28672 16 48 1.230ms 1.371ms 2.731ms 2.22x 1.11x
|
||||
27B TP1 12288+4096 16 48 0.733ms 0.598ms 1.230ms 1.68x 0.82x
|
||||
27B TP1 6144+2048 16 48 0.379ms 0.305ms 0.638ms 1.68x 0.80x
|
||||
27B TP1 4096+4096 16 48 0.290ms 0.207ms 0.682ms 2.35x 0.71x
|
||||
27B TP1 2048+6144 16 48 0.379ms 0.306ms 0.689ms 1.82x 0.81x
|
||||
27B TP1 1024+7168 16 48 0.425ms 0.355ms 0.688ms 1.62x 0.84x
|
||||
27B TP1 8192x4 16 48 1.068ms 0.804ms 2.530ms 2.37x 0.75x
|
||||
27B TP1 4096x8 16 48 0.899ms 0.611ms 2.571ms 2.86x 0.68x
|
||||
27B TP1 2048x4 16 48 0.296ms 0.218ms 0.645ms 2.18x 0.74x
|
||||
27B TP1 1024x8 16 48 0.268ms 0.180ms 0.659ms 2.46x 0.67x
|
||||
|
||||
2B/0.8B TP1 1x32768 16 16 0.492ms 1.654ms 1.288ms 2.62x 3.36x
|
||||
2B/0.8B TP1 1x16384 16 16 0.289ms 0.831ms 0.655ms 2.27x 2.88x
|
||||
2B/0.8B TP1 1x8192 16 16 0.181ms 0.421ms 0.341ms 1.88x 2.32x
|
||||
2B/0.8B TP1 1x4096 16 16 0.119ms 0.215ms 0.294ms 2.47x 1.81x
|
||||
2B/0.8B TP1 1x2048 16 16 0.109ms 0.114ms 0.282ms 2.58x 1.05x
|
||||
2B/0.8B TP1 28672+4096 16 16 0.484ms 1.440ms 1.215ms 2.51x 2.98x
|
||||
2B/0.8B TP1 24576+8192 16 16 0.478ms 1.222ms 1.140ms 2.38x 2.55x
|
||||
2B/0.8B TP1 16384+16384 16 16 0.474ms 0.789ms 0.982ms 2.07x 1.67x
|
||||
2B/0.8B TP1 8192+24576 16 16 0.478ms 1.191ms 1.138ms 2.38x 2.49x
|
||||
2B/0.8B TP1 4096+28672 16 16 0.481ms 1.392ms 1.217ms 2.53x 2.89x
|
||||
2B/0.8B TP1 12288+4096 16 16 0.287ms 0.618ms 0.581ms 2.02x 2.15x
|
||||
2B/0.8B TP1 6144+2048 16 16 0.179ms 0.312ms 0.306ms 1.71x 1.74x
|
||||
2B/0.8B TP1 4096+4096 16 16 0.177ms 0.207ms 0.292ms 1.65x 1.17x
|
||||
2B/0.8B TP1 2048+6144 16 16 0.180ms 0.306ms 0.304ms 1.69x 1.70x
|
||||
2B/0.8B TP1 1024+7168 16 16 0.181ms 0.358ms 0.324ms 1.79x 1.98x
|
||||
2B/0.8B TP1 8192x4 16 16 0.389ms 0.399ms 0.984ms 2.53x 1.03x
|
||||
2B/0.8B TP1 4096x8 16 16 0.318ms 0.214ms 0.992ms 3.12x 0.67x
|
||||
2B/0.8B TP1 2048x4 16 16 0.118ms 0.111ms 0.283ms 2.41x 0.95x
|
||||
2B/0.8B TP1 1024x8 16 16 0.100ms 0.069ms 0.321ms 3.21x 0.69x
|
||||
|
||||
Sym h32 1x32768 32 32 0.874ms 1.567ms 1.961ms 2.24x 1.79x
|
||||
Sym h32 1x16384 32 32 0.474ms 0.790ms 0.981ms 2.07x 1.66x
|
||||
Sym h32 1x8192 32 32 0.281ms 0.402ms 0.508ms 1.81x 1.43x
|
||||
Sym h32 1x4096 32 32 0.177ms 0.208ms 0.271ms 1.52x 1.17x
|
||||
Sym h32 1x2048 32 32 0.117ms 0.111ms 0.296ms 2.54x 0.96x
|
||||
Sym h32 28672+4096 32 32 0.861ms 1.376ms 1.964ms 2.28x 1.60x
|
||||
Sym h32 24576+8192 32 32 0.857ms 1.181ms 1.964ms 2.29x 1.38x
|
||||
Sym h32 16384+16384 32 32 0.771ms 0.808ms 1.974ms 2.56x 1.05x
|
||||
Sym h32 8192+24576 32 32 0.855ms 1.183ms 1.966ms 2.30x 1.38x
|
||||
Sym h32 4096+28672 32 32 0.857ms 1.379ms 1.965ms 2.29x 1.61x
|
||||
Sym h32 12288+4096 32 32 0.475ms 0.597ms 0.985ms 2.07x 1.26x
|
||||
Sym h32 6144+2048 32 32 0.270ms 0.305ms 0.512ms 1.90x 1.13x
|
||||
Sym h32 4096+4096 32 32 0.215ms 0.210ms 0.513ms 2.38x 0.97x
|
||||
Sym h32 2048+6144 32 32 0.271ms 0.304ms 0.510ms 1.88x 1.12x
|
||||
Sym h32 1024+7168 32 32 0.281ms 0.353ms 0.510ms 1.81x 1.26x
|
||||
Sym h32 8192x4 32 32 0.600ms 0.419ms 1.981ms 3.30x 0.70x
|
||||
Sym h32 4096x8 32 32 0.620ms 0.426ms 1.994ms 3.22x 0.69x
|
||||
Sym h32 2048x4 32 32 0.176ms 0.119ms 0.519ms 2.95x 0.67x
|
||||
Sym h32 1024x8 32 32 0.183ms 0.128ms 0.535ms 2.93x 0.70x
|
||||
|
||||
==============================================================================================================
|
||||
|
||||
>>> BACKWARD BENCHMARKS (Split seq into 8 sequences)
|
||||
Heads SeqLen flash_qla [bwd] FLA [bwd] Speedup
|
||||
---------------------------------------------------------------
|
||||
32 16k 1.262ms 2.433ms 1.93x
|
||||
32 32k 2.743ms 4.904ms 1.79x
|
||||
32 64k 6.554ms 9.247ms 1.41x
|
||||
32 128k 9.489ms 18.195ms 1.92x
|
||||
32 256k 20.939ms 38.613ms 1.84x
|
||||
|
||||
48 16k 1.775ms 3.549ms 2.00x
|
||||
48 32k 3.715ms 7.197ms 1.94x
|
||||
48 64k 7.565ms 13.864ms 1.83x
|
||||
48 128k 13.777ms 28.175ms 2.05x
|
||||
48 256k 31.321ms 56.344ms 1.80x
|
||||
|
||||
64 16k 2.213ms 4.654ms 2.10x
|
||||
64 32k 4.513ms 9.237ms 2.05x
|
||||
64 64k 8.024ms 18.265ms 2.28x
|
||||
64 128k 15.299ms 36.572ms 2.39x
|
||||
64 256k 34.310ms 74.392ms 2.17x
|
||||
|
||||
Benchmark Finished.
|
||||
Reference in New Issue
Block a user