# Files
# FlashQLA-SM70-SM75/tests/test_gdr.py
# 2026-06-14 23:49:03 +08:00
#
# 554 lines
# 18 KiB
# Python

# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import argparse
import math
import torch
import pandas as pd
# Requires flash-linear-attention==0.5.0
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_fla,
)
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_fla,
)
from flash_qla import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_qla
from flash_qla import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_qla
from flash_qla.utils import l2norm, pack, profile
from ref_gdr import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_ref
from ref_gdr import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_ref
def _max_abs_err(actual: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Return ``max|actual - ref|`` as a 0-d tensor.

    The result stays on the tensors' device so many runs can be folded
    together with ``torch.maximum`` and synced to the host only once.
    """
    return (actual - ref).abs().max()


def _report(name: str, actual: torch.Tensor, ref: torch.Tensor) -> None:
    """Print ``name: max|actual - ref| / max|ref|`` with four decimals."""
    print(
        f"{name}: {_max_abs_err(actual, ref).item():.4f} / {ref.abs().max().item():.4f}"
    )


def _update_worst(
    worst: dict[str, torch.Tensor],
    outputs: dict[str, torch.Tensor],
    refs: dict[str, torch.Tensor],
) -> None:
    """Fold the errors of ``outputs[name]`` vs ``refs[name]`` into ``worst`` in place.

    Only names present in ``refs`` are read from ``outputs``.
    """
    for name, ref in refs.items():
        err = _max_abs_err(outputs[name], ref)
        # torch.maximum propagates NaN, so one NaN run can never be masked.
        worst[name] = err if name not in worst else torch.maximum(worst[name], err)


def _assert_within_tol(
    worst: dict[str, torch.Tensor],
    refs: dict[str, torch.Tensor],
    rtol: float,
) -> None:
    """Raise AssertionError unless every worst-case error is <= rtol * max|ref|.

    On failure, prints every tracked error between two banner lines first.
    """
    scales = {name: ref.abs().max().item() for name, ref in refs.items()}
    errs = {name: err.item() for name, err in worst.items()}
    # `not (err <= tol)` rather than `err > tol` so a NaN error also fails.
    failed = [name for name, err in errs.items() if not err <= scales[name] * rtol]
    if not failed:
        return
    print("********** ERROR **********")
    for name, err in errs.items():
        print(f"{name}: {err:.4f} / {scales[name]:.4f}")
    print("********** ERROR **********")
    raise AssertionError(
        f"{', '.join(failed)}: worst-case error exceeds {rtol} * max|ref|"
    )


def _print_comparison(result_fla: dict, result_qla: dict) -> None:
    """Print a per-kernel latency table (FLA vs FlashQLA) and the total speedup."""
    print(pd.DataFrame({"fla": result_fla, "flash_qla": result_qla}).round(3))
    speedup = result_fla["total"] / result_qla["total"]
    print(f"Speed up: {speedup:.2f}x")


def test_gated_delta_rule(
    batch_size: int,
    num_tokens: int,
    num_k_heads: int,
    num_v_heads: int,
    head_dim_k: int,
    head_dim_v: int,
    varlen: bool = False,
    cu_seqlens: list[int] | None = None,
    use_h0: bool = False,
    chunk_size: int = 64,
    data_dtype: str = "bfloat16",
    ref_dtype: str = "float32",
    device: torch.device | str = "cuda",
    random_seed: int = 42,
    check_accuracy: bool = True,
    show_speedup: bool = True,
    auto_cp: bool = True,
    swa_ratio: float = 0.75,
    skip_bwd: bool = False,
    rtol: float = 0.02,
    num_reruns: int = 1000,
):
    """Compare FlashQLA's chunked gated-delta-rule kernels with FLA and a reference.

    Builds random inputs, then runs the forward pass through the reference
    implementation, FLA and FlashQLA. Unless ``skip_bwd`` is set, it does the
    same for the backward pass. Errors against the reference are printed, and
    FlashQLA is asserted to stay within tolerance. A per-kernel latency table
    can optionally be printed.

    Args:
        batch_size: Number of batch rows drawn before any packing.
        num_tokens: Tokens per batch row.
        num_k_heads: Query/key heads.
        num_v_heads: Value heads; gates, betas and states are per value head.
        head_dim_k: Query/key head size; the attention scale is its ``**-0.5``.
        head_dim_v: Value head size.
        varlen: Test the variable-length path. With ``cu_seqlens=None``, each
            row gets a random length in ``[1, num_tokens)`` and the rows are
            packed with ``pack``.
        cu_seqlens: Explicit cumulative sequence lengths. This requires
            ``batch_size == 1``, ``cu_seqlens[0] == 0`` and
            ``cu_seqlens[-1] == num_tokens``.
        use_h0: Supply a random initial state and final-state gradient.
        chunk_size: Not used by this function.
        data_dtype: Name of the ``torch`` dtype for kernel inputs.
        ref_dtype: Name of the ``torch`` dtype for the reference computation.
        device: Device on which all tensors are created.
        random_seed: Seed passed to ``torch.manual_seed``.
        check_accuracy: Print errors and assert FlashQLA is within ``rtol``.
        show_speedup: Profile both libraries and print latencies.
        auto_cp: Forwarded to the FlashQLA forward kernel.
        swa_ratio: Fraction (rounded up) of value heads, chosen at random,
            that keep their gates. All other heads get ``g = 0``.
        skip_bwd: Stop after the forward pass.
        rtol: An output passes when its max abs error is
            ``<= rtol * max|reference|``.
        num_reruns: Extra FlashQLA calls per pass. The worst error over all
            calls is asserted, so intermittent mismatches are caught.

    Raises:
        AssertionError: If a FlashQLA output exceeds the tolerance in any run,
            or if explicit ``cu_seqlens`` are malformed.
    """
    data_dtype = getattr(torch, data_dtype)
    ref_dtype = getattr(torch, ref_dtype)
    torch.manual_seed(random_seed)

    # Inputs are drawn in a fixed order so a given seed is reproducible.
    qk_shape = (batch_size, num_tokens, num_k_heads, head_dim_k)
    gate_shape = (batch_size, num_tokens, num_v_heads)
    state_shape = (batch_size, num_v_heads, head_dim_k, head_dim_v)
    q = l2norm(torch.randn(qk_shape, device=device, dtype=data_dtype))
    k = l2norm(torch.randn(qk_shape, device=device, dtype=data_dtype))
    v = torch.randn(
        (batch_size, num_tokens, num_v_heads, head_dim_v),
        device=device,
        dtype=data_dtype,
    )
    g = (
        torch.nn.functional.logsigmoid(
            torch.randn(gate_shape, device=device, dtype=torch.float32)
        )
        / 16
    )
    beta = torch.randn(gate_shape, device=device, dtype=torch.float32).sigmoid()
    h0 = (
        torch.randn(state_shape, device=device, dtype=torch.float32)
        if use_h0
        else None
    )
    do = torch.randn_like(v)
    dht = (
        torch.randn(state_shape, device=device, dtype=torch.float32) / 8
        if use_h0
        else None
    )
    scale = head_dim_k ** (-0.5)
    print(
        f"Shape: B={batch_size} Hk={num_k_heads} Hv={num_v_heads} T={num_tokens} VarLen={varlen}"
    )

    # ceil(swa_ratio * Hv) randomly chosen value heads keep their gates;
    # every head outside swa_mask gets g = 0.
    swa_mask = torch.zeros((num_v_heads), dtype=torch.bool, device=device)
    swa_mask[: math.ceil(swa_ratio * num_v_heads)] = 1
    swa_mask = swa_mask[torch.randperm(num_v_heads, device=device)]
    g[:, :, ~swa_mask] = 0.0
    print(f"SWA Mask: {swa_mask.to(torch.int32, copy=True).tolist()}")

    if varlen:
        if cu_seqlens is None:
            seqlens = torch.randint(
                1, num_tokens, (batch_size,), device=device, dtype=torch.int32
            )
            # cumsum promotes integer inputs to int64 by default; pin int32 so
            # this path matches the explicit-cu_seqlens path below.
            cu_seqlens = torch.nn.functional.pad(
                torch.cumsum(seqlens, dim=-1, dtype=torch.int32), (1, 0)
            )
            q, k, v, g, beta, do = (
                pack(t, cu_seqlens) for t in (q, k, v, g, beta, do)
            )
        else:
            assert batch_size == 1
            assert cu_seqlens[0] == 0
            assert cu_seqlens[-1] == num_tokens
            cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
        if use_h0:
            # One state per packed sequence rather than per batch row.
            state_shape = (
                cu_seqlens.shape[0] - 1,
                num_v_heads,
                head_dim_k,
                head_dim_v,
            )
            h0 = torch.randn(state_shape, device=device, dtype=torch.float32)
            dht = torch.randn(state_shape, device=device, dtype=torch.float32) / 8
        assert (cu_seqlens[1:] - cu_seqlens[:-1]).min() > 0
    else:
        cu_seqlens = None

    # ---------------------------------------------------------------- forward
    g_ref, o_ref, A_ref, h_ref, s_ref = chunk_gated_delta_rule_fwd_ref(
        q=q.to(ref_dtype, copy=True),
        k=k.to(ref_dtype, copy=True),
        v=v.to(ref_dtype, copy=True),
        g=g.to(ref_dtype, copy=True),
        beta=beta.to(ref_dtype, copy=True),
        scale=scale,
        initial_state=h0,
        cu_seqlens=cu_seqlens,
    )
    g_fla, o_fla, A_fla, s_fla, _, _ = chunk_gated_delta_rule_fwd_fla(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
        cu_seqlens=cu_seqlens,
    )
    g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        initial_state=h0,
        cu_seqlens=cu_seqlens,
        output_final_state=True,
        output_h=True,
        auto_cp=auto_cp,
    )
    if check_accuracy:
        _report("h_qla", h_qla, h_ref)
        _report("s_fla", s_fla, s_ref)
        _report("s_qla", s_qla, s_ref)
        _report("o_fla", o_fla, o_ref)
        _report("o_qla", o_qla, o_ref)
        # The final state is only asserted when an initial state is supplied.
        fwd_refs = (
            {"s_qla": s_ref, "o_qla": o_ref} if h0 is not None else {"o_qla": o_ref}
        )
        worst = {}
        _update_worst(worst, {"s_qla": s_qla, "o_qla": o_qla}, fwd_refs)
        # Keep the worst error over every rerun so an intermittent failure
        # cannot hide behind a correct final run. g_qla / A_qla from the last
        # rerun feed the backward pass below.
        for _ in range(num_reruns):
            g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
                q,
                k,
                v,
                g,
                beta,
                scale,
                h0,
                cu_seqlens,
                True,
                False,
                auto_cp,
            )
            _update_worst(worst, {"s_qla": s_qla, "o_qla": o_qla}, fwd_refs)
        _assert_within_tol(worst, fwd_refs, rtol)
    if show_speedup:
        prof_fla = profile(
            chunk_gated_delta_rule_fwd_fla,
            [q, k, v, g, beta, scale, h0, True, cu_seqlens],
        )
        prof_qla = profile(
            chunk_gated_delta_rule_fwd_qla,
            [q, k, v, g, beta, scale, h0, cu_seqlens, True, False, auto_cp],
        )
        result_fla = {
            "[fwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
            "[fwd] solve": prof_fla["chunk_gated_delta_rule_fwd_kkt_solve_kernel"],
            "[fwd] wu": prof_fla["recompute_w_u_fwd_kernel"],
            "[fwd] gdr": prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
            "[fwd] o": prof_fla["chunk_fwd_kernel_o"],
        }
        result_qla = {
            "[fwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
            "[fwd] solve": prof_qla["tilelang_kkt_solve_kernel_kernel"],
            "[fwd] gdr": prof_qla["tilelang_fused_chunk_gdr_fwd_kernel_kernel"],
        }
        # Context-parallel kernels only appear when FlashQLA chose to use them.
        if "tilelang_get_warmup_chunks_kernel_kernel" in prof_qla.keys():
            result_fla["[fwd] cp-w"] = None
            result_fla["[fwd] cp-h"] = None
            result_fla["[fwd] cp-c"] = None
            result_qla["[fwd] cp-w"] = prof_qla[
                "tilelang_get_warmup_chunks_kernel_kernel"
            ]
            result_qla["[fwd] cp-h"] = prof_qla["tilelang_prepare_h_kernel_kernel"]
            result_qla["[fwd] cp-c"] = prof_qla["tilelang_correct_h0_kernel_kernel"]
        result_fla["total"] = prof_fla["total"]
        result_qla["total"] = prof_qla["total"]
        _print_comparison(result_fla, result_qla)
    if skip_bwd:
        return

    # --------------------------------------------------------------- backward
    dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd_ref(
        q.to(ref_dtype, copy=True),
        k.to(ref_dtype, copy=True),
        v.to(ref_dtype, copy=True),
        g_ref,
        beta.to(ref_dtype, copy=True),
        A_ref.to(ref_dtype, copy=True),
        scale,
        h0,
        do.to(ref_dtype, copy=True),
        dht,
        cu_seqlens,
    )
    dq_fla, dk_fla, dv_fla, db_fla, dg_fla, dh0_fla, _, _ = (
        chunk_gated_delta_rule_bwd_fla(
            q,
            k,
            v,
            g_fla,
            beta,
            A_fla,
            scale,
            h0,
            do,
            dht,
            cu_seqlens,
        )
    )
    # Names of the FlashQLA backward outputs, in the order the kernel returns them.
    qla_grad_names = ("dq_qla", "dk_qla", "dv_qla", "db_qla", "dg_qla", "dh0_qla")
    qla_grads = chunk_gated_delta_rule_bwd_qla(
        q,
        k,
        v,
        g_qla,
        beta,
        A_qla,
        do,
        dht,
        scale,
        h0,
        cu_seqlens,
    )
    dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = qla_grads
    if check_accuracy:
        grads = [
            ("dq", dq_fla, dq_qla, dq_ref),
            ("dk", dk_fla, dk_qla, dk_ref),
            ("dv", dv_fla, dv_qla, dv_ref),
        ]
        # dh0 is only meaningful when a final-state gradient was supplied.
        if dht is not None:
            grads.append(("dh0", dh0_fla, dh0_qla, dh0_ref))
        grads += [
            ("db", db_fla, db_qla, db_ref),
            ("dg", dg_fla, dg_qla, dg_ref),
        ]
        for name, fla_grad, qla_grad, ref_grad in grads:
            _report(f"{name}_fla", fla_grad, ref_grad)
            _report(f"{name}_qla", qla_grad, ref_grad)
        bwd_refs = {f"{name}_qla": ref_grad for name, _, _, ref_grad in grads}
        worst = {}
        _update_worst(worst, dict(zip(qla_grad_names, qla_grads)), bwd_refs)
        for _ in range(num_reruns):
            qla_grads = chunk_gated_delta_rule_bwd_qla(
                q,
                k,
                v,
                g_qla,
                beta,
                A_qla,
                do,
                dht,
                scale,
                h0,
                cu_seqlens,
            )
            _update_worst(worst, dict(zip(qla_grad_names, qla_grads)), bwd_refs)
        _assert_within_tol(worst, bwd_refs, rtol)
    if show_speedup:
        prof_fla = profile(
            chunk_gated_delta_rule_bwd_fla,
            [q, k, v, g_fla, beta, A_fla, scale, h0, do, dht, cu_seqlens],
        )
        prof_qla = profile(
            chunk_gated_delta_rule_bwd_qla,
            [q, k, v, g_qla, beta, A_qla, do, dht, scale, h0, cu_seqlens],
        )
        result_fla = {
            "[bwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
            "[bwd] recom": prof_fla["recompute_w_u_fwd_kernel"]
            + prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
            "[bwd] dv": prof_fla["chunk_bwd_kernel_dv_local"],
            "[bwd] gdr": prof_fla["chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64"],
            "[bwd] dqkwg": prof_fla["kernel_kernel"],
            "[bwd] wy": prof_fla["prepare_wy_repr_bwd_kernel"],
        }
        result_qla = {
            "[bwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
            "[bwd] recom": prof_qla["tilelang_prepare_h_kernel_kernel"],
            "[bwd] gdr": prof_qla["tilelang_fused_chunk_gdr_bwd_kernel_kernel"],
        }
        # With grouped heads (Hk < Hv), gradients are reduced across the group.
        if num_k_heads < num_v_heads:
            result_fla["[bwd] reduc"] = prof_fla["compress_heads_kernel"]
            result_qla["[bwd] reduc"] = (
                prof_qla["tilelang_group_reduce_vector_kernel_kernel"] * 2
            )
        result_fla["total"] = prof_fla["total"]
        result_qla["total"] = prof_qla["total"]
        _print_comparison(result_fla, result_qla)
def load_preset_rows(csv_path) -> list[dict]:
    """Load a preset CSV into a list of keyword-argument dicts, one per row.

    Unlike ``DataFrame.iterrows`` (which upcasts a row whose columns are all
    numeric to float, e.g. ``batch_size`` -> ``2.0``), each value keeps its
    column's dtype. Values are also converted to plain Python scalars. Empty
    cells are dropped so the caller's defaults apply. A ``cu_seqlens`` cell
    such as ``"0-100-300"`` is parsed into ``[0, 100, 300]``.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or buffer).

    Returns:
        One dict per CSV row, mapping column name to value.
    """
    rows = []
    for record in pd.read_csv(csv_path).to_dict(orient="records"):
        row = {}
        for key, value in record.items():
            if pd.isna(value):
                # Empty cell: fall back to the caller's default for this key.
                continue
            if hasattr(value, "item"):
                # numpy scalar -> native Python int/float/bool.
                value = value.item()
            if key == "cu_seqlens":
                value = [int(x) for x in str(value).split("-")]
            row[key] = value
        rows.append(row)
    return rows


if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(description="Test Gated Delta Rule")
    parser.add_argument(
        "--set",
        type=str,
        default="develop",
        help="Preset name (loads from settings/{set}.csv)",
    )
    parser.add_argument(
        "--seqlen", "--num-tokens", type=int, default=16384, help="Sequence Length"
    )
    parser.add_argument(
        "--nkh",
        "--num-k-heads",
        type=int,
        default=0,
        help="Number of K heads (num_k_heads)",
    )
    parser.add_argument(
        "--nvh",
        "--num-heads",
        "--num-v-heads",
        type=int,
        default=64,
        help="Number of V heads (num_v_heads)",
    )
    parser.add_argument(
        "--no-h0",
        action="store_true",
        help="Disable initial state and gradient of final state",
    )
    parser.add_argument("--skip-bwd", action="store_true", help="Test forward only")
    parser.add_argument(
        "--no-cp",
        "--disable-auto-cp",
        action="store_true",
        help="Disable auto intra-card CP",
    )
    parser.add_argument(
        "--swa-ratio", type=float, default=0.75, help="Ratio of sliding-window heads"
    )
    parser.add_argument(
        "--data-dtype",
        type=str,
        default="bfloat16",
        help="Data type for input and output",
    )
    parser.add_argument(
        "--ref-dtype", type=str, default="float64", help="Data type for reference"
    )
    parser.add_argument("--hide-acc", action="store_true", help="Do not print accuracy")
    parser.add_argument("--hide-lat", action="store_true", help="Do not print latency")
    parser.add_argument(
        "--seed", "--random-seed", type=int, default=42, help="Random seed"
    )
    args = parser.parse_args()
    # A non-positive K-head count means "same as V heads".
    if args.nkh <= 0:
        args.nkh = args.nvh
    # Defaults shared by every preset row; CSV columns override these per row.
    metadata = {
        "head_dim_k": 128,  # MUST BE 128
        "head_dim_v": 128,  # MUST BE 128
        "chunk_size": 64,  # MUST BE 64
        "num_tokens": args.seqlen,
        "num_k_heads": args.nkh,
        "num_v_heads": args.nvh,
        "use_h0": not args.no_h0,
        "data_dtype": args.data_dtype,
        "ref_dtype": args.ref_dtype,
        "check_accuracy": not args.hide_acc,
        "show_speedup": not args.hide_lat,
        "skip_bwd": args.skip_bwd,
        "auto_cp": not args.no_cp,
        "swa_ratio": args.swa_ratio,
        "random_seed": args.seed,
        "device": "cuda",
    }
    script_dir = os.path.dirname(os.path.abspath(__file__))
    preset_rows = load_preset_rows(
        os.path.join(script_dir, "settings", f"{args.set}.csv")
    )
    for row in preset_rows:
        print("-" * 64)
        torch.cuda.empty_cache()
        # Merge into a fresh dict so one row's overrides never leak into the next.
        test_gated_delta_rule(**{**metadata, **row})
    print("-" * 64)