first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

700
tests/ref_gdr.py Normal file
View File

@@ -0,0 +1,700 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
from flash_qla.utils import (
pad_and_reshape,
pack,
unpack,
fill_last_chunk_of_g,
prepare_chunk_offsets,
)
def torch_cumsum(
    x: torch.Tensor,  # [B, T, H]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
    reverse: bool = False,
):
    """Chunk-local cumulative sum of ``x`` along the token axis.

    The token axis is split into chunks of ``chunk_size`` tokens with
    ``pad_and_reshape`` and the running sum restarts in every chunk.

    Args:
        x: Values to accumulate, laid out as ``[B, T, H]``.
        cu_seqlens: Optional cumulative sequence lengths. When given, ``x`` is
            passed through ``unpack`` first and the result through ``pack``
            (presumably packed varlen <-> padded batch; confirm against
            ``flash_qla.utils``).
        chunk_size: Number of tokens per chunk.
        reverse: When true, accumulate from the end of each chunk towards its
            start (suffix sum) instead of from the start (prefix sum).

    Returns:
        The accumulated tensor, cropped back to the original token count and
        re-packed when ``cu_seqlens`` is given.
    """
    is_varlen = cu_seqlens is not None
    dense = unpack(x, cu_seqlens) if is_varlen else x
    batch_size, num_tokens, num_heads = dense.shape

    # After chunking, axis 2 is the position inside a chunk.
    chunks = pad_and_reshape(dense, dim=1, chunk_size=chunk_size)
    if reverse:
        # Suffix sum == flip, prefix sum, flip back.
        flipped = torch.flip(chunks, dims=(2,))
        summed = torch.flip(flipped.cumsum(dim=2), dims=(2,))
    else:
        summed = chunks.cumsum(dim=2)

    # Merge the chunk axes again and drop the padded tail.
    out = summed.reshape(batch_size, -1, num_heads)[:, :num_tokens]
    return pack(out, cu_seqlens) if is_varlen else out
def torch_kkt_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Build the per-chunk, strictly lower-triangular ``beta * K K^T`` matrix.

    For two positions ``c`` and ``d`` of the same chunk the entry is
    ``beta[c] * <k[c], k[d]> * exp(g[c] - g[d])`` when ``d < c`` and zero
    otherwise (the diagonal is masked out as well).

    Args:
        k: Keys, ``[B, T, Hk, K]``. Repeated along the head axis when there
            are fewer key heads than value heads.
        g: Gate values, ``[B, T, Hv]``; the caller in this file passes the
            chunk-local cumulative sum.
        beta: Per-token scaling, ``[B, T, Hv]``.
        cu_seqlens: Optional cumulative sequence lengths; inputs go through
            ``unpack`` and the output through ``pack`` when given.
        chunk_size: Number of tokens per chunk.

    Returns:
        Tensor of shape ``[B, T, Hv, chunk_size]`` (before optional packing);
        row ``t`` holds the entries of token ``t`` against its own chunk.
    """
    if cu_seqlens is not None:
        k, g, beta = (unpack(t, cu_seqlens) for t in (k, g, beta))

    batch_size, num_tokens, num_k_heads, _ = k.shape
    num_v_heads = g.shape[-1]
    if num_k_heads != num_v_heads:
        # Share every key head across its group of value heads.
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    k_c = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, H, K]
    g_c = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, H]
    beta_c = pad_and_reshape(beta, dim=1, chunk_size=chunk_size)  # [B, N, C, H]

    # True on and above the diagonal: those entries are removed.
    upper_incl = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k_c.device)
    )
    g_diff = g_c[:, :, :, None, :] - g_c[:, :, None, :, :]  # [B, N, C, D, H]
    decay = torch.exp(g_diff).masked_fill(upper_incl[None, None, :, :, None], 0.0)

    k_beta = k_c * beta_c.unsqueeze(-1)
    gram = torch.einsum("bnchk, bndhk -> bnchd", k_beta, k_c)
    attn = gram * decay.swapaxes(-2, -1)  # [B, N, C, H, D]

    attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens]
    if cu_seqlens is not None:
        attn = pack(attn, cu_seqlens)
    return attn
def torch_solve(
    x: torch.Tensor,  # [B, T, H, D]
    cu_seqlens: torch.Tensor = None,
):
    """Row-by-row forward substitution on every ``chunk x chunk`` block.

    Each block is negated, its strictly lower part is updated one row at a
    time from the already processed rows, and the identity is added. When the
    input block ``L`` is strictly lower-triangular (as produced by
    ``torch_kkt_fwd``) this yields ``(I + L)^-1``.

    Args:
        x: Blocks laid out as ``[B, T, H, D]``; the chunk size is taken from
            the last dimension ``D``.
        cu_seqlens: Optional cumulative sequence lengths; ``x`` goes through
            ``unpack`` and the result through ``pack`` when given.

    Returns:
        Tensor with the same ``[B, T, H, D]`` layout as the (unpacked) input.
    """
    is_varlen = cu_seqlens is not None
    if is_varlen:
        x = unpack(x, cu_seqlens)
    batch_size, num_tokens, num_heads, chunk_size = x.shape

    blocks = pad_and_reshape(x, dim=1, chunk_size=chunk_size).swapaxes(
        2, 3
    )  # [B, N, H, C, D]
    inv = -blocks  # negation allocates, so the in-place writes below are safe

    for row_idx in range(1, chunk_size):
        # Clone both operands: they alias the rows being overwritten.
        cur = inv[..., row_idx, :row_idx].clone()
        solved = inv[..., :row_idx, :row_idx].clone()
        correction = (cur.unsqueeze(-1) * solved).sum(-2)
        inv[..., row_idx, :row_idx] = cur + correction

    inv += torch.eye(chunk_size, dtype=inv.dtype, device=inv.device)

    out = inv.swapaxes(2, 3).reshape((batch_size, -1, num_heads, chunk_size))
    out = out[:, :num_tokens]
    if is_varlen:
        out = pack(out, cu_seqlens)
    return out
def torch_w_u_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    A: torch.Tensor,  # [B, T, Hv, D]
    cu_seqlens: torch.Tensor = None,
):
    """Compute ``w = A @ (k * beta * exp(g))`` and ``u = A @ (v * beta)`` per chunk.

    The chunk size is not a parameter: it is read from the last dimension of
    ``A`` so that it always agrees with the matrix produced by the solve step.

    Args:
        k: Keys, ``[B, T, Hk, K]``; repeated along the head axis when there
            are fewer key heads than value heads.
        v: Values, ``[B, T, Hv, V]``.
        g: Gate values, ``[B, T, Hv]`` (exponentiated here).
        beta: Per-token scaling, ``[B, T, Hv]``.
        A: Per-chunk matrix, ``[B, T, Hv, D]`` with ``D`` the chunk size.
        cu_seqlens: Optional cumulative sequence lengths; inputs go through
            ``unpack`` and outputs through ``pack`` when given.

    Returns:
        ``(w, u)`` with shapes ``[B, T, Hv, K]`` and ``[B, T, Hv, V]``
        (before optional packing).
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        A = unpack(A, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
        g = unpack(g, cu_seqlens)
    batch_size, num_tokens, _, chunk_size = A.shape
    _, _, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    k_beta = pad_and_reshape(
        k * beta.unsqueeze(-1) * g.exp().unsqueeze(-1), dim=1, chunk_size=chunk_size
    )  # [B, N, C, Hv, K]
    v_beta = pad_and_reshape(
        v * beta.unsqueeze(-1), dim=1, chunk_size=chunk_size
    )  # [B, N, C, Hv, V]
    # Chunk A with the same chunk size as k_beta / v_beta. Relying on the
    # helper's default here would break for any A whose last dimension differs
    # from that default.
    A = pad_and_reshape(A, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, D]
    w = torch.einsum("bnchd, bndhk -> bnchk", A, k_beta).reshape(
        (batch_size, -1, num_v_heads, head_dim_k)
    )[:, :num_tokens]
    u = torch.einsum("bnchd, bndhk -> bnchk", A, v_beta).reshape(
        (batch_size, -1, num_v_heads, head_dim_v)
    )[:, :num_tokens]
    if cu_seqlens is not None:
        w = pack(w, cu_seqlens)
        u = pack(u, cu_seqlens)
    return w, u
def torch_chunk_gdr_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    w: torch.Tensor,  # [B, T, Hv, K]
    u: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    initial_state: torch.Tensor = None,  # [B, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Sequential chunk-by-chunk state recurrence of the forward pass.

    For every chunk the state seen at its start is recorded, the corrected
    values ``v_new = u - w @ state`` are computed, and the state is advanced:
    it is scaled by ``exp`` of the chunk's last gate value and then receives
    ``k^T v_new`` with each key weighted by ``exp(g_last - g)``.

    Args:
        k: Keys, ``[B, T, Hk, K]``; repeated along the head axis when there
            are fewer key heads than value heads.
        w: ``[B, T, Hv, K]`` tensor from ``torch_w_u_fwd``.
        u: ``[B, T, Hv, V]`` tensor from ``torch_w_u_fwd``.
        g: Gate values, ``[B, T, Hv]``.
        initial_state: Optional starting state ``[B, Hv, K, V]``; zeros when
            omitted. It is copied, never modified.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        chunk_size: Number of tokens per chunk.

    Returns:
        ``(h, vn, last_state)``: the per-chunk start states stacked along a
        chunk axis, the corrected values ``[B, T, Hv, V]``, and the state
        after the final chunk.
    """
    if cu_seqlens is not None:
        k, w, u, g = (unpack(t, cu_seqlens) for t in (k, w, u, g))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = u.shape
    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    # k, w: [B, N, C, Hv, K]; u: [B, N, C, Hv, V]; g: [B, N, C, Hv]
    k, w, u, g = (
        pad_and_reshape(t, dim=1, chunk_size=chunk_size) for t in (k, w, u, g)
    )
    # Helper from flash_qla.utils; presumably fixes up the gate values in the
    # padded tail of the last chunk -- confirm against its definition.
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)

    if initial_state is not None:
        state = initial_state.to(g.dtype, copy=True)
    else:
        state = torch.zeros(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            dtype=g.dtype,
            device=g.device,
        )

    chunk_states, chunk_values = [], []
    for idx in range(k.shape[1]):
        chunk_states.append(state)

        projected = torch.einsum("bchk, bhkv -> bchv", w[:, idx], state)
        v_new = u[:, idx] - projected
        chunk_values.append(v_new)

        # Decay the carried state to the end of this chunk ...
        state = state * g[:, idx, -1, :, None, None].exp()
        # ... and add this chunk's contribution, decayed to the same point.
        k_decayed = k[:, idx] * (g[:, idx, -1:, :, None] - g[:, idx, :, :, None]).exp()
        state = state + torch.einsum("bchk, bchv -> bhkv", k_decayed, v_new)

    h = torch.stack(chunk_states, dim=1).contiguous()
    vn = torch.stack(chunk_values, dim=1)
    vn = vn.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    vn = vn.contiguous()

    if cu_seqlens is not None:
        vn = pack(vn, cu_seqlens)
        h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
    return h, vn, state
def torch_chunk_o_fwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    h: torch.Tensor,  # [B, N, Hv, K, V]
    g: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Compute the attention output from chunk states and in-chunk scores.

    The output is the sum of two parts:

    * inter-chunk: ``(q * exp(g)) @ h`` using the state at the chunk start;
    * intra-chunk: causal ``q k^T`` scores (diagonal included) weighted by
      ``exp(g[c] - g[d])`` and applied to ``v``.

    Args:
        q: Queries, ``[B, T, Hk, K]``.
        k: Keys, ``[B, T, Hk, K]``.
        v: Values, ``[B, T, Hv, V]``.
        h: Per-chunk start states, ``[B, N, Hv, K, V]``.
        g: Gate values, ``[B, T, Hv]``.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        scale: Query scaling; a falsy value (``None`` or ``0``) selects
            ``K ** -0.5``.
        chunk_size: Number of tokens per chunk.

    Returns:
        Output tensor ``[B, T, Hv, V]`` (before optional packing).
    """
    if cu_seqlens is not None:
        q, k, v, g = (unpack(t, cu_seqlens) for t in (q, k, v, g))
        h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    if num_k_heads != num_v_heads:
        group = num_v_heads // num_k_heads
        q = q.repeat_interleave(group, dim=2)
        k = k.repeat_interleave(group, dim=2)
    if not scale:
        scale = head_dim_k ** (-0.5)

    # q, k: [B, N, C, Hv, K]; v: [B, N, C, Hv, V]; g: [B, N, C, Hv]
    q, k, v, g = (
        pad_and_reshape(t, dim=1, chunk_size=chunk_size) for t in (q, k, v, g)
    )
    q = q * scale

    # True strictly above the diagonal, i.e. for "future" positions d > c.
    future = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
        diagonal=1,
    )
    decay = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay = decay.masked_fill(future[None, None, :, :, None], 0.0)  # [B, N, C, D, Hv]

    intra_scores = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay
    inter_out = torch.einsum("bnchk, bnhkv -> bnchv", q * g.exp().unsqueeze(-1), h)
    intra_out = torch.einsum("bncdh, bndhv -> bnchv", intra_scores, v)
    o = inter_out + intra_out

    o = o.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    if cu_seqlens is not None:
        o = pack(o, cu_seqlens)
    return o
def torch_chunk_dv_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    g: torch.Tensor,  # [B, T, Hv]
    do: torch.Tensor,  # [B, T, Hv, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Intra-chunk part of the gradient with respect to the values.

    Rebuilds the causal, decay-weighted ``q k^T`` scores of each chunk and
    applies their transpose to the output gradient ``do``.

    Args:
        q: Queries, ``[B, T, Hk, K]``.
        k: Keys, ``[B, T, Hk, K]``.
        g: Gate values, ``[B, T, Hv]``.
        do: Gradient of the output, ``[B, T, Hv, V]``.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        scale: Query scaling; a falsy value (``None`` or ``0``) selects
            ``K ** -0.5``.
        chunk_size: Number of tokens per chunk.

    Returns:
        ``dv`` of shape ``[B, T, Hv, V]`` (before optional packing).
    """
    if cu_seqlens is not None:
        q, k, g, do = (unpack(t, cu_seqlens) for t in (q, k, g, do))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape
    if num_k_heads != num_v_heads:
        group = num_v_heads // num_k_heads
        q = q.repeat_interleave(group, dim=2)
        k = k.repeat_interleave(group, dim=2)
    if not scale:
        scale = head_dim_k ** (-0.5)

    # q, k: [B, N, C, Hv, K]; g: [B, N, C, Hv]; do: [B, N, C, Hv, V]
    q, k, g, do = (
        pad_and_reshape(t, dim=1, chunk_size=chunk_size) for t in (q, k, g, do)
    )
    q = q * scale

    # True strictly above the diagonal, i.e. for "future" positions d > c.
    future = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
        diagonal=1,
    )
    decay = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay = decay.masked_fill(future[None, None, :, :, None], 0.0)  # [B, N, C, D, Hv]

    scores = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay
    # Contract over the query position c: value d collects from every query
    # that attended to it.
    dv = torch.einsum("bncdh, bnchv -> bndhv", scores, do)

    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    if cu_seqlens is not None:
        dv = pack(dv, cu_seqlens)
    return dv
def torch_chunk_gdr_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    w: torch.Tensor,  # [B, T, Hv, K]
    g: torch.Tensor,  # [B, T, Hv]
    do: torch.Tensor,  # [B, T, Hv, V]
    dv: torch.Tensor,  # [B, T, Hv, V]
    h0: torch.Tensor = None,  # [B, Hv, K, V]
    dht: torch.Tensor = None,  # [B, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Backward pass of the chunk state recurrence, walked last chunk first.

    For every chunk (in reverse order) the incoming state gradient is
    recorded, its contribution is added to ``dv``, and the state gradient is
    propagated to the previous chunk.

    Args:
        q: Queries, ``[B, T, Hk, K]``.
        k: Keys, ``[B, T, Hk, K]``.
        w: ``[B, T, Hv, K]`` tensor from ``torch_w_u_fwd``.
        g: Gate values, ``[B, T, Hv]``.
        do: Gradient of the output, ``[B, T, Hv, V]``.
        dv: Intra-chunk value gradient, ``[B, T, Hv, V]``. It is updated with
            ``+=`` after chunking; NOTE(review): if ``pad_and_reshape``
            returns a view, the caller's tensor is modified too -- confirm.
        h0: Initial state; only used to decide whether ``dh0`` is returned.
        dht: Optional gradient of the final state ``[B, Hv, K, V]``; zeros
            when omitted. It is copied, never modified.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        scale: Query scaling; a falsy value (``None`` or ``0``) selects
            ``K ** -0.5``.
        chunk_size: Number of tokens per chunk.

    Returns:
        ``(dh, dh0, dv)``: per-chunk state gradients stacked along a chunk
        axis, the gradient of the initial state (``None`` when ``h0`` is
        ``None``), and the completed value gradient ``[B, T, Hv, V]``.
    """
    if cu_seqlens is not None:
        q, k, w, g, do, dv = (unpack(t, cu_seqlens) for t in (q, k, w, g, do, dv))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape
    if num_k_heads != num_v_heads:
        group = num_v_heads // num_k_heads
        q = q.repeat_interleave(group, dim=2)
        k = k.repeat_interleave(group, dim=2)
    if not scale:
        scale = head_dim_k ** (-0.5)

    # q, k, w: [B, N, C, Hv, K]; g: [B, N, C, Hv]; do, dv: [B, N, C, Hv, V]
    q, k, w, g, do, dv = (
        pad_and_reshape(t, dim=1, chunk_size=chunk_size) for t in (q, k, w, g, do, dv)
    )
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
    q = q * scale

    if dht is not None:
        dstate = dht.to(g.dtype, copy=True)
    else:
        dstate = torch.zeros(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            dtype=g.dtype,
            device=g.device,
        )

    # Gradient each chunk's start state receives directly from the output.
    dstate_inter = torch.einsum("bnchk, bnchv -> bnhkv", q * g.exp().unsqueeze(-1), do)

    dh_reversed = []
    num_chunks = k.shape[1]
    for idx in range(num_chunks - 1, -1, -1):
        dh_reversed.append(dstate)

        k_decayed = k[:, idx] * (g[:, idx, -1:, :, None] - g[:, idx, :, :, None]).exp()
        dv[:, idx] += torch.einsum("bchk, bhkv -> bchv", k_decayed, dstate)

        dstate = dstate * g[:, idx, -1, :, None, None].exp()
        through_w = torch.einsum("bchk, bchv -> bhkv", w[:, idx], dv[:, idx])
        dstate = dstate + dstate_inter[:, idx] - through_w

    # Collected last-to-first; restore chunk order before stacking.
    dh = torch.stack(dh_reversed[::-1], dim=1).contiguous()
    dh0 = dstate if h0 is not None else None

    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    if cu_seqlens is not None:
        dv = pack(dv, cu_seqlens)
        dh = pack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
    return dh, dh0, dv
def torch_chunk_dqkwg_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    w: torch.Tensor,  # [B, T, Hv, K]
    g: torch.Tensor,  # [B, T, Hv]
    h: torch.Tensor,  # [B, N, Hv, K, V]
    dv: torch.Tensor,  # [B, T, Hv, V]
    do: torch.Tensor,  # [B, T, Hv, V]
    dh: torch.Tensor,  # [B, N, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Gradients with respect to ``q``, ``k``, ``w`` and the gate values ``g``.

    Combines the state path (through ``h`` / ``dh``) with the intra-chunk
    score path (through the causal, decay-weighted ``q k^T`` scores).

    Args:
        q: Queries, ``[B, T, Hk, K]``.
        k: Keys, ``[B, T, Hk, K]``.
        v: Values used by the output step, ``[B, T, Hv, V]``.
        w: ``[B, T, Hv, K]``; unpacked and chunked but not otherwise read by
            the computation below.
        g: Gate values, ``[B, T, Hv]``.
        h: Per-chunk start states, ``[B, N, Hv, K, V]``.
        dv: Value gradient, ``[B, T, Hv, V]``.
        do: Gradient of the output, ``[B, T, Hv, V]``.
        dh: Per-chunk state gradients, ``[B, N, Hv, K, V]``.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        scale: Query scaling; a falsy value (``None`` or ``0``) selects
            ``K ** -0.5``.
        chunk_size: Number of tokens per chunk.

    Returns:
        ``(dq, dk, dw, dg)`` with shapes ``[B, T, Hv, K]`` (three times) and
        ``[B, T, Hv]``, before optional packing. ``dq`` and ``dk`` are per
        value head; no reduction over grouped heads happens here.
    """
    if cu_seqlens is not None:
        q, k, v, w, g, do, dv = (
            unpack(t, cu_seqlens) for t in (q, k, v, w, g, do, dv)
        )
        h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
        dh = unpack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape
    if num_k_heads != num_v_heads:
        group = num_v_heads // num_k_heads
        q = q.repeat_interleave(group, dim=2)
        k = k.repeat_interleave(group, dim=2)
    if not scale:
        scale = head_dim_k ** (-0.5)

    # q, k, w: [B, N, C, Hv, K]; v, do, dv: [B, N, C, Hv, V]; g: [B, N, C, Hv]
    q, k, v, w, g, do, dv = (
        pad_and_reshape(t, dim=1, chunk_size=chunk_size)
        for t in (q, k, v, w, g, do, dv)
    )
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)

    # True strictly above the diagonal, i.e. for "future" positions d > c.
    future = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
        diagonal=1,
    )
    decay = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay = decay.masked_fill(future[None, None, :, :, None], 0.0)  # [B, N, C, D, Hv]

    g_last = g[:, :, -1]  # [B, N, Hv]

    # ---- state path -----------------------------------------------------
    dg_last = (h * dh).sum(dim=-1).sum(dim=-1)  # [B, N, Hv]
    dg_last *= g_last.exp()

    dq = torch.einsum("bnchv, bnhkv -> bnchk", do, h)
    dq = dq * g.unsqueeze(-1).exp() * scale
    dg = (q * dq).sum(dim=-1)  # [B, N, C, Hv]

    dk = torch.einsum("bnchv, bnhkv -> bnchk", v, dh)
    dk = dk * (g_last.unsqueeze(-2) - g).unsqueeze(-1).exp()
    k_dot_dk = (k * dk).sum(dim=-1)  # [B, N, C, Hv]
    dg -= k_dot_dk
    dg_last += k_dot_dk.sum(dim=-2)

    dw = -torch.einsum("bnchv, bnhkv -> bnchk", dv, h)

    # ---- intra-chunk score path -------------------------------------------
    ds = torch.einsum("bnchv, bndhv -> bncdh", do, v)
    ds *= decay * scale
    ds_scores = ds * torch.einsum("bnchk, bndhk -> bncdh", q, k)
    dg += ds_scores.sum(dim=-2)  # as the query position c
    dg -= ds_scores.sum(dim=-3)  # as the key position d
    dq += torch.einsum("bncdh, bndhk -> bnchk", ds, k)
    dk += torch.einsum("bncdh, bnchk -> bndhk", ds, q)

    # The chunk-level gate gradient lands on the last position of each chunk.
    dg[:, :, -1] += dg_last
    dg = fill_last_chunk_of_g(
        dg, num_tokens, cu_seqlens, chunk_size=chunk_size, reverse=True
    )

    key_shape = (batch_size, -1, num_v_heads, head_dim_k)
    dq = dq.reshape(key_shape)[:, :num_tokens]
    dk = dk.reshape(key_shape)[:, :num_tokens]
    dw = dw.reshape(key_shape)[:, :num_tokens]
    dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    if cu_seqlens is not None:
        dq = pack(dq, cu_seqlens)
        dk = pack(dk, cu_seqlens)
        dw = pack(dw, cu_seqlens)
        dg = pack(dg, cu_seqlens)
    return dq, dk, dw, dg
def torch_chunk_wy_bwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    beta: torch.Tensor,  # [B, T, Hv]
    A: torch.Tensor,  # [B, T, Hv, D]
    g: torch.Tensor,  # [B, T, Hv]
    dw: torch.Tensor,  # [B, T, Hv, K]
    du: torch.Tensor,  # [B, T, Hv, V]
    dk1: torch.Tensor,  # [B, T, Hv, K]
    dg1: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
):
    """Backward pass through ``torch_w_u_fwd`` and the ``A`` construction.

    Propagates ``dw`` / ``du`` back to ``k``, ``v``, ``beta`` and ``g``, both
    directly and through the per-chunk matrix ``A``, then adds the partial
    gradients ``dk1`` / ``dg1`` computed by earlier backward steps. The chunk
    size is read from the last dimension of ``A``.

    Args:
        k: Keys, ``[B, T, Hk, K]``; repeated along the head axis when there
            are fewer key heads than value heads.
        v: Values, ``[B, T, Hv, V]``.
        beta: Per-token scaling, ``[B, T, Hv]``.
        A: Per-chunk matrix from the forward pass, ``[B, T, Hv, D]``.
        g: Gate values, ``[B, T, Hv]``.
        dw: Gradient with respect to ``w``, ``[B, T, Hv, K]``.
        du: Gradient with respect to ``u``, ``[B, T, Hv, V]``.
        dk1: Partial key gradient to accumulate, ``[B, T, Hv, K]``.
        dg1: Partial gate gradient to accumulate, ``[B, T, Hv]``.
        cu_seqlens: Optional cumulative sequence lengths for packed input.

    Returns:
        ``(dk, dv, db, dg)`` with shapes ``[B, T, Hv, K]``, ``[B, T, Hv, V]``,
        ``[B, T, Hv]`` and ``[B, T, Hv]`` (before optional packing). ``dk`` is
        per value head; grouped heads are not reduced here.
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
        A = unpack(A, cu_seqlens)
        g = unpack(g, cu_seqlens)
        dw = unpack(dw, cu_seqlens)
        du = unpack(du, cu_seqlens)
        dk1 = unpack(dk1, cu_seqlens)
        dg1 = unpack(dg1, cu_seqlens)
    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    chunk_size = A.shape[-1]
    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    v = pad_and_reshape(v, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    A = pad_and_reshape(A, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, D]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    du = pad_and_reshape(du, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]

    # Through w = A @ (k * beta * exp(g)).
    dA = torch.einsum("bnchk, bndhk -> bnchd", dw, k * (beta * g.exp()).unsqueeze(-1))
    dk_beta_g = torch.einsum("bnchd, bnchk -> bndhk", A, dw)
    dk = dk_beta_g * (beta * g.exp()).unsqueeze(-1)
    db = (dk_beta_g * k * g.exp().unsqueeze(-1)).sum(dim=-1)
    dg = (dk_beta_g * k * (g.exp() * beta).unsqueeze(-1)).sum(dim=-1)

    # Through u = A @ (v * beta).
    dA += torch.einsum("bnchv, bndhv -> bnchd", du, v * beta.unsqueeze(-1))
    dv_beta = torch.einsum("bnchd, bnchv -> bndhv", A, du)
    dv = dv_beta * beta.unsqueeze(-1)
    db += (dv_beta * v).sum(dim=-1)

    # Through A itself: keep only the strictly lower triangle of dA, apply
    # A^T on the left and on the right, negate, and weight by the decay.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0).swapaxes(
        -2, -1
    )
    dA = dA.masked_fill(mask[None, None, :, None, :], 0.0)
    dA = torch.einsum("bndhc, bndhe -> bnche", A, dA)
    dA = torch.einsum("bnchd, bnehd -> bnche", dA, A)
    dA = -dA * decay_mask

    # Undecayed beta * K K^T, needed for the gate gradient below.
    kkt = torch.einsum("bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k)
    dk_beta = torch.einsum("bnchd, bndhk -> bnchk", dA, k)
    db += (dk_beta * k).sum(dim=-1)
    dk += torch.einsum("bnchd, bnchk -> bndhk", dA, k * beta.unsqueeze(-1))
    dk += dk_beta * beta.unsqueeze(-1)
    dk += dk1
    dg += (dA * kkt).sum(dim=-1) - (dA * kkt).sum(dim=-3).swapaxes(-1, -2)
    dg += dg1

    # TODO: NOTE: GVA -- dk still has Hv heads here; chunk_gated_delta_rule_bwd
    # sums it over each group of value heads.
    dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    # dv lives in value space: reshape with head_dim_v. Using head_dim_k
    # produced a wrong shape whenever K != V.
    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    db = db.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    if cu_seqlens is not None:
        dk = pack(dk, cu_seqlens)
        dv = pack(dv, cu_seqlens)
        db = pack(db, cu_seqlens)
        dg = pack(dg, cu_seqlens)
    return dk, dv, db, dg
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    initial_state: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Reference forward pass of the chunked gated delta rule.

    Chains the stages defined in this file: chunk-local cumulative sum of the
    gates, ``A`` construction and solve, ``w`` / ``u`` computation, the state
    recurrence, and the output step.

    Args:
        q: Queries, ``[B, T, Hk, K]``.
        k: Keys, ``[B, T, Hk, K]``.
        v: Values, ``[B, T, Hv, V]``.
        g: Raw (not yet accumulated) gate values, ``[B, T, Hv]``.
        beta: Per-token scaling, ``[B, T, Hv]``.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        initial_state: Optional starting state ``[B, Hv, K, V]``.
        scale: Query scaling; a falsy value (``None`` or ``0``) selects
            ``K ** -0.5``.
        chunk_size: Number of tokens per chunk.

    Returns:
        ``(g, o, A, h, final_state)``: the accumulated gates, the output, the
        solved per-chunk matrix, the per-chunk start states and the state
        after the last chunk.
    """
    if not scale:
        scale = q.shape[-1] ** (-0.5)

    # Keyword sets shared by the stages; not every stage takes a chunk size.
    varlen_kw = {"cu_seqlens": cu_seqlens}
    chunked_kw = {"cu_seqlens": cu_seqlens, "chunk_size": chunk_size}

    g = torch_cumsum(x=g, **chunked_kw)
    A = torch_kkt_fwd(k=k, g=g, beta=beta, **chunked_kw)
    A = torch_solve(x=A, **varlen_kw)
    w, u = torch_w_u_fwd(k=k, v=v, beta=beta, A=A, g=g, **varlen_kw)
    h, vn, final_state = torch_chunk_gdr_fwd(
        k=k, w=w, u=u, g=g, initial_state=initial_state, **chunked_kw
    )
    # The output step consumes the corrected values, not the raw ``v``.
    o = torch_chunk_o_fwd(q=q, k=k, v=vn, h=h, g=g, scale=scale, **chunked_kw)
    return g, o, A, h, final_state
def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Reference backward pass of the chunked gated delta rule.

    Recomputes ``w``, ``u``, the chunk states and the corrected values from
    the forward inputs, then chains the backward stages defined in this file.
    ``chunk_size`` is forwarded to every stage that accepts it so that the
    whole pass uses one consistent chunking.

    Args:
        q: Queries, ``[B, T, Hk, K]``.
        k: Keys, ``[B, T, Hk, K]``.
        v: Values, ``[B, T, Hv, V]``.
        g: Gate values as returned by ``chunk_gated_delta_rule_fwd`` (already
            accumulated per chunk).
        beta: Per-token scaling, ``[B, T, Hv]``.
        A: Solved per-chunk matrix returned by the forward pass.
        scale: Query scaling used in the forward pass.
        initial_state: Initial state of the forward pass, or ``None``.
        do: Gradient of the output, ``[B, T, Hv, V]``.
        dht: Gradient of the final state, or ``None``.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        chunk_size: Number of tokens per chunk; must match the forward pass.

    Returns:
        ``(dq, dk, dv, db, dg, dh0)``. ``dq`` / ``dk`` are summed over each
        group of value heads when there are fewer key heads than value heads;
        ``dh0`` is ``None`` when ``initial_state`` is ``None``.
    """
    # Recompute the forward intermediates needed by the backward stages.
    w, u = torch_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g=g,
        cu_seqlens=cu_seqlens,
    )
    h, vn, _ = torch_chunk_gdr_fwd(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dv = torch_chunk_dv_bwd(
        q=q,
        k=k,
        g=g,
        do=do,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dh, dh0, dv = torch_chunk_gdr_bwd(
        q=q,
        k=k,
        w=w,
        g=g,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dq, dk1, dw, dg1 = torch_chunk_dqkwg_bwd(
        q=q,
        k=k,
        v=vn,
        w=w,
        g=g,
        h=h,
        dv=dv,
        do=do,
        dh=dh,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    # The value gradient so far is the gradient with respect to ``u``.
    dk, dv, db, dg = torch_chunk_wy_bwd(
        k=k,
        v=v,
        beta=beta,
        g=g,
        A=A,
        dw=dw,
        du=dv,
        dk1=dk1,
        dg1=dg1,
        cu_seqlens=cu_seqlens,
    )
    num_k_heads, num_v_heads = k.shape[-2], v.shape[-2]
    if num_k_heads < num_v_heads:
        # Keys/queries were repeated per value head; fold the gradients of
        # each group back onto the shared key head.
        batch, tokens, _, head_dim_k = dq.shape
        dq = torch.sum(dq.reshape(batch, tokens, num_k_heads, -1, head_dim_k), dim=3)
        dk = torch.sum(dk.reshape(batch, tokens, num_k_heads, -1, head_dim_k), dim=3)
    # Backward of the chunk-local cumulative sum is a reversed one over the
    # same chunks, so it must use the same chunk size (was hard-coded to 64).
    dg = torch_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens)
    return dq, dk, dv, db, dg, dh0

View File

@@ -0,0 +1,2 @@
batch_size,num_tokens,varlen
1,32768,False
1 batch_size num_tokens varlen
2 1 32768 False

View File

@@ -0,0 +1,9 @@
batch_size,num_tokens,varlen,cu_seqlens
1,16384,True,0-268-1139-1179-1212-1476-2792-3202-3611-3726-3820-4096-4882-6417-8130-8192-8328-9426-10473-11002-11754-12288-14085-15370-16384
1,16384,True,0-3393-4096-4153-5636-5853-6318-6777-8192-8320-8931-9163-9494-10040-10113-10363-10561-11061-11388-11634-12288-14545-16288-16384
1,16384,True,0-4096-6111-6485-6589-7118-8192-9056-10448-12288-14032-14525-15884-16012-16384
1,16384,True,0-177-4096-8192-12288-12805-13171-13298-16055-16384
1,16384,True,0-308-1128-1678-4096-4748-8192-8506-9657-10252-12113-12288-16384
1,16384,True,0-4096-6893-7665-8192-12288-16384
1,16384,True,0-410-841-1135-2126-2512-4096-4682-5022-5375-6259-6335-6580-6648-7308-8192-10450-12058-12288-14215-15280-15701-16384
1,16384,True,0-2048-4096-6144-8192-10240-12288-14336-16384
1 batch_size num_tokens varlen cu_seqlens
2 1 16384 True 0-268-1139-1179-1212-1476-2792-3202-3611-3726-3820-4096-4882-6417-8130-8192-8328-9426-10473-11002-11754-12288-14085-15370-16384
3 1 16384 True 0-3393-4096-4153-5636-5853-6318-6777-8192-8320-8931-9163-9494-10040-10113-10363-10561-11061-11388-11634-12288-14545-16288-16384
4 1 16384 True 0-4096-6111-6485-6589-7118-8192-9056-10448-12288-14032-14525-15884-16012-16384
5 1 16384 True 0-177-4096-8192-12288-12805-13171-13298-16055-16384
6 1 16384 True 0-308-1128-1678-4096-4748-8192-8506-9657-10252-12113-12288-16384
7 1 16384 True 0-4096-6893-7665-8192-12288-16384
8 1 16384 True 0-410-841-1135-2126-2512-4096-4682-5022-5375-6259-6335-6580-6648-7308-8192-10450-12058-12288-14215-15280-15701-16384
9 1 16384 True 0-2048-4096-6144-8192-10240-12288-14336-16384

View File

@@ -0,0 +1,5 @@
batch_size,num_tokens,varlen
1,4096,False
1,8192,False
1,16384,False
1,32768,False
1 batch_size num_tokens varlen
2 1 4096 False
3 1 8192 False
4 1 16384 False
5 1 32768 False

View File

@@ -0,0 +1,7 @@
batch_size,num_tokens,varlen
11,33,False
7,4321,False
3,16789,True
5,8192,True
10,1024,True
20,512,True
1 batch_size num_tokens varlen
2 11 33 False
3 7 4321 False
4 3 16789 True
5 5 8192 True
6 10 1024 True
7 20 512 True

View File

@@ -0,0 +1,71 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
"""Static checks for autograd ``Function`` signatures.
PyTorch validates that ``Function.backward`` returns exactly as many gradients
as ``Function.forward`` received non-``ctx`` inputs; mismatches raise at
``.backward()`` time. ``tests/test_gdr.py`` invokes ``chunk_gated_delta_rule_fwd``
and ``chunk_gated_delta_rule_bwd`` directly, bypassing the autograd path, so
drift between the forward signature and the backward return tuple goes
uncaught by the existing suite.
These tests parse the source files with ``ast`` instead of importing the
modules so they run on CPU-only / non-Hopper machines.
"""
import ast
from pathlib import Path
# Repository root: the directory one level above the one containing this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Repo-relative path of the module that defines the autograd Function under test.
CHUNK_INIT = "flash_qla/ops/gated_delta_rule/chunk/__init__.py"
def _parse(rel_path: str) -> ast.Module:
    """Read the file at ``rel_path`` (relative to ``REPO_ROOT``) and parse it.

    The file is only parsed, never imported or executed.
    """
    source_file = REPO_ROOT / rel_path
    source_text = source_file.read_text(encoding="utf-8")
    return ast.parse(source_text)
def _get_class(module: ast.Module, name: str) -> ast.ClassDef:
    """Return the first top-level class called ``name`` in ``module``.

    Only direct children of the module body are inspected; nested classes are
    not searched.

    Raises:
        AssertionError: If no such class exists.
    """
    candidates = (
        stmt
        for stmt in module.body
        if isinstance(stmt, ast.ClassDef) and stmt.name == name
    )
    found = next(candidates, None)
    if found is None:
        raise AssertionError(f"class {name!r} not found")
    return found
def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef:
    """Return the first plain ``def`` called ``name`` directly inside ``cls``.

    ``async def`` methods are a different node type and are not matched.

    Raises:
        AssertionError: If the class body has no such method.
    """
    candidates = (
        stmt
        for stmt in cls.body
        if isinstance(stmt, ast.FunctionDef) and stmt.name == name
    )
    found = next(candidates, None)
    if found is None:
        raise AssertionError(f"method {name!r} not found on {cls.name}")
    return found
def test_chunk_gated_delta_rule_grad_count_matches_forward_inputs():
    """Check that ``backward`` returns one gradient per non-``ctx`` ``forward`` input.

    Works purely on the syntax tree: the positional parameters of ``forward``
    are counted and compared with the number of elements in the tuple literal
    returned by ``backward``.
    """
    autograd_cls = _get_class(_parse(CHUNK_INIT), "ChunkGatedDeltaRuleFunction")

    forward_params = _get_method(autograd_cls, "forward").args.args
    assert forward_params and forward_params[0].arg == "ctx", (
        "forward must take `ctx` as its first argument"
    )
    n_inputs = len(forward_params) - 1  # ctx is not an input

    backward_def = _get_method(autograd_cls, "backward")
    # ast.walk visits the whole subtree, so a return inside a nested function
    # would be counted here too.
    return_nodes = [
        node for node in ast.walk(backward_def) if isinstance(node, ast.Return)
    ]
    assert len(return_nodes) == 1, (
        f"expected one Return in backward, got {len(return_nodes)}"
    )
    (only_return,) = return_nodes
    assert isinstance(only_return.value, ast.Tuple), (
        "backward must return a tuple literal"
    )
    n_grads = len(only_return.value.elts)

    assert n_inputs == n_grads, (
        f"backward returns {n_grads} gradients but forward takes {n_inputs} non-ctx "
        f"inputs; PyTorch will raise a count-mismatch error at .backward() time."
    )
if __name__ == "__main__":
    # Allow running this check directly as a script; prints "OK" on success.
    test_chunk_gated_delta_rule_grad_count_matches_forward_inputs()
    print("OK")

553
tests/test_gdr.py Normal file
View File

@@ -0,0 +1,553 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import argparse
import math
import torch
import pandas as pd
# Requires flash-linear-attention==0.5.0
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_fla,
)
from fla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_fla,
)
from flash_qla import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_qla
from flash_qla import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_qla
from flash_qla.utils import l2norm, pack, profile
from ref_gdr import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_ref
from ref_gdr import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_ref
def test_gated_delta_rule(
batch_size: int,
num_tokens: int,
num_k_heads: int,
num_v_heads: int,
head_dim_k: int,
head_dim_v: int,
varlen: bool = False,
cu_seqlens: list[int] | None = None,
use_h0: bool = False,
chunk_size: int = 64,
data_dtype: str = "bfloat16",
ref_dtype: str = "float32",
device: torch.device = "cuda",
random_seed: int = 42,
check_accuracy: bool = True,
show_speedup: bool = True,
auto_cp: bool = True,
swa_ratio: float = 0.75,
skip_bwd: bool = False,
):
data_dtype = getattr(torch, data_dtype)
ref_dtype = getattr(torch, ref_dtype)
torch.manual_seed(random_seed)
q = l2norm(
torch.randn(
(batch_size, num_tokens, num_k_heads, head_dim_k),
device=device,
dtype=data_dtype,
)
)
k = l2norm(
torch.randn(
(batch_size, num_tokens, num_k_heads, head_dim_k),
device=device,
dtype=data_dtype,
)
)
v = torch.randn(
(batch_size, num_tokens, num_v_heads, head_dim_v),
device=device,
dtype=data_dtype,
)
g = (
torch.nn.functional.logsigmoid(
torch.randn(
(batch_size, num_tokens, num_v_heads),
device=device,
dtype=torch.float32,
)
)
/ 16
)
beta = torch.randn(
(batch_size, num_tokens, num_v_heads), device=device, dtype=torch.float32
).sigmoid()
h0 = (
torch.randn(
(batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
if use_h0
else None
)
do = torch.randn_like(v)
dht = (
torch.randn(
(batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
/ 8
if use_h0
else None
)
scale = head_dim_k ** (-0.5)
print(
f"Shape: B={batch_size} Hk={num_k_heads} Hv={num_v_heads} T={num_tokens} VarLen={varlen}"
)
swa_mask = torch.zeros((num_v_heads), dtype=torch.bool, device=device)
swa_mask[: math.ceil(swa_ratio * num_v_heads)] = 1
swa_mask = swa_mask[torch.randperm(num_v_heads, device=device)]
g[:, :, ~swa_mask] = 0.0
print(f"SWA Mask: {swa_mask.to(torch.int32, copy=True).tolist()}")
if varlen:
if cu_seqlens is None:
cu_seqlens = torch.randint(
1, num_tokens, (batch_size,), device=device, dtype=torch.int32
)
cu_seqlens = torch.nn.functional.pad(
torch.cumsum(cu_seqlens, dim=-1), (1, 0)
)
q = pack(q, cu_seqlens)
k = pack(k, cu_seqlens)
v = pack(v, cu_seqlens)
g = pack(g, cu_seqlens)
beta = pack(beta, cu_seqlens)
do = pack(do, cu_seqlens)
else:
assert batch_size == 1
assert cu_seqlens[0] == 0
assert cu_seqlens[-1] == num_tokens
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
if use_h0:
real_batch_size = cu_seqlens.shape[0] - 1
h0 = torch.randn(
(real_batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
dht = (
torch.randn(
(real_batch_size, num_v_heads, head_dim_k, head_dim_v),
device=device,
dtype=torch.float32,
)
/ 8
)
assert (cu_seqlens[1:] - cu_seqlens[:-1]).min() > 0
else:
cu_seqlens = None
g_ref, o_ref, A_ref, h_ref, s_ref = chunk_gated_delta_rule_fwd_ref(
q=q.to(ref_dtype, copy=True),
k=k.to(ref_dtype, copy=True),
v=v.to(ref_dtype, copy=True),
g=g.to(ref_dtype, copy=True),
beta=beta.to(ref_dtype, copy=True),
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
)
g_fla, o_fla, A_fla, s_fla, _, _ = chunk_gated_delta_rule_fwd_fla(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens,
)
g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=h0,
cu_seqlens=cu_seqlens,
output_final_state=True,
output_h=True,
auto_cp=auto_cp,
)
if check_accuracy:
print(
f"h_qla: {(h_qla - h_ref).abs().max().item():.4f} / {h_ref.abs().max().item():.4f}"
)
print(
f"s_fla: {(s_fla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
)
print(
f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
)
print(
f"o_fla: {(o_fla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
)
print(
f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
)
for _ in range(1000):
g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
q,
k,
v,
g,
beta,
scale,
h0,
cu_seqlens,
True,
False,
auto_cp,
)
try:
if h0 is not None:
assert (
s_qla - s_ref
).abs().max().item() <= s_ref.abs().max().item() * 0.02
assert (
o_qla - o_ref
).abs().max().item() <= o_ref.abs().max().item() * 0.02
except AssertionError as e:
print("********** ERROR **********")
if h0 is not None:
print(
f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
)
print(
f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
)
print("********** ERROR **********")
raise e
if show_speedup:
prof_fla = profile(
chunk_gated_delta_rule_fwd_fla,
[q, k, v, g, beta, scale, h0, True, cu_seqlens],
)
prof_qla = profile(
chunk_gated_delta_rule_fwd_qla,
[q, k, v, g, beta, scale, h0, cu_seqlens, True, False, auto_cp],
)
result_fla = {
"[fwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
"[fwd] solve": prof_fla["chunk_gated_delta_rule_fwd_kkt_solve_kernel"],
"[fwd] wu": prof_fla["recompute_w_u_fwd_kernel"],
"[fwd] gdr": prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
"[fwd] o": prof_fla["chunk_fwd_kernel_o"],
}
result_qla = {
"[fwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
"[fwd] solve": prof_qla["tilelang_kkt_solve_kernel_kernel"],
"[fwd] gdr": prof_qla["tilelang_fused_chunk_gdr_fwd_kernel_kernel"],
}
if "tilelang_get_warmup_chunks_kernel_kernel" in prof_qla.keys():
result_fla["[fwd] cp-w"] = None
result_fla["[fwd] cp-h"] = None
result_fla["[fwd] cp-c"] = None
result_qla["[fwd] cp-w"] = prof_qla[
"tilelang_get_warmup_chunks_kernel_kernel"
]
result_qla["[fwd] cp-h"] = prof_qla["tilelang_prepare_h_kernel_kernel"]
result_qla["[fwd] cp-c"] = prof_qla["tilelang_correct_h0_kernel_kernel"]
result_fla["total"] = prof_fla["total"]
result_qla["total"] = prof_qla["total"]
results = {
"fla": result_fla,
"flash_qla": result_qla,
}
df = pd.DataFrame(results)
print(df.round(3))
speedup = results["fla"]["total"] / results["flash_qla"]["total"]
print(f"Speed up: {speedup:.2f}x")
if skip_bwd:
return
dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd_ref(
q.to(ref_dtype, copy=True),
k.to(ref_dtype, copy=True),
v.to(ref_dtype, copy=True),
g_ref,
beta.to(ref_dtype, copy=True),
A_ref.to(ref_dtype, copy=True),
scale,
h0,
do.to(ref_dtype, copy=True),
dht,
cu_seqlens,
)
dq_fla, dk_fla, dv_fla, db_fla, dg_fla, dh0_fla, _, _ = (
chunk_gated_delta_rule_bwd_fla(
q,
k,
v,
g_fla,
beta,
A_fla,
scale,
h0,
do,
dht,
cu_seqlens,
)
)
dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = chunk_gated_delta_rule_bwd_qla(
q,
k,
v,
g_qla,
beta,
A_qla,
do,
dht,
scale,
h0,
cu_seqlens,
)
if check_accuracy:
print(
f"dq_fla: {(dq_fla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
)
print(
f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
)
print(
f"dk_fla: {(dk_fla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
)
print(
f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
)
print(
f"dv_fla: {(dv_fla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
)
print(
f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
)
if dht is not None:
print(
f"dh0_fla: {(dh0_fla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
)
print(
f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
)
print(
f"db_fla: {(db_fla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
)
print(
f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
)
print(
f"dg_fla: {(dg_fla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
)
print(
f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
)
for _ in range(1000):
dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = (
chunk_gated_delta_rule_bwd_qla(
q,
k,
v,
g_qla,
beta,
A_qla,
do,
dht,
scale,
h0,
cu_seqlens,
)
)
try:
assert (
dq_qla - dq_ref
).abs().max().item() <= dq_ref.abs().max().item() * 0.02
assert (
dk_qla - dk_ref
).abs().max().item() <= dk_ref.abs().max().item() * 0.02
assert (
dv_qla - dv_ref
).abs().max().item() <= dv_ref.abs().max().item() * 0.02
assert (
dg_qla - dg_ref
).abs().max().item() <= dg_ref.abs().max().item() * 0.02
assert (
db_qla - db_ref
).abs().max().item() <= db_ref.abs().max().item() * 0.02
if dht is not None:
assert (
dh0_qla - dh0_ref
).abs().max().item() <= dh0_ref.abs().max().item() * 0.02
except AssertionError as e:
print("********** ERROR **********")
print(
f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
)
print(
f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
)
print(
f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
)
if dht is not None:
print(
f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
)
print(
f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
)
print(
f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
)
print("********** ERROR **********")
raise e
if show_speedup:
prof_fla = profile(
chunk_gated_delta_rule_bwd_fla,
[q, k, v, g_fla, beta, A_fla, scale, h0, do, dht, cu_seqlens],
)
prof_qla = profile(
chunk_gated_delta_rule_bwd_qla,
[q, k, v, g_qla, beta, A_qla, do, dht, scale, h0, cu_seqlens],
)
result_fla = {
"[bwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
"[bwd] recom": prof_fla["recompute_w_u_fwd_kernel"]
+ prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
"[bwd] dv": prof_fla["chunk_bwd_kernel_dv_local"],
"[bwd] gdr": prof_fla["chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64"],
"[bwd] dqkwg": prof_fla["kernel_kernel"],
"[bwd] wy": prof_fla["prepare_wy_repr_bwd_kernel"],
}
result_qla = {
"[bwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
"[bwd] recom": prof_qla["tilelang_prepare_h_kernel_kernel"],
"[bwd] gdr": prof_qla["tilelang_fused_chunk_gdr_bwd_kernel_kernel"],
}
if num_k_heads < num_v_heads:
result_fla["[bwd] reduc"] = prof_fla["compress_heads_kernel"]
result_qla["[bwd] reduc"] = (
prof_qla["tilelang_group_reduce_vector_kernel_kernel"] * 2
)
result_fla["total"] = prof_fla["total"]
result_qla["total"] = prof_qla["total"]
results = {
"fla": result_fla,
"flash_qla": result_qla,
}
df = pd.DataFrame(results)
print(df.round(3))
speedup = results["fla"]["total"] / results["flash_qla"]["total"]
print(f"Speed up: {speedup:2.2f}x")
if __name__ == "__main__":
    # Command-line driver: parse the options, build the default keyword
    # arguments, then call test_gated_delta_rule once per row of the preset CSV.
    cli = argparse.ArgumentParser(description="Test Gated Delta Rule")

    # (flags, add_argument keyword arguments), registered in this exact order
    # so that the generated --help output keeps its layout.
    option_specs = [
        (
            ("--set",),
            dict(
                type=str,
                default="develop",
                help="Preset name (loads from settings/{set}.csv)",
            ),
        ),
        (
            ("--seqlen", "--num-tokens"),
            dict(type=int, default=16384, help="Sequence Length"),
        ),
        (
            ("--nkh", "--num-k-heads"),
            dict(type=int, default=0, help="Number of K heads (num_k_heads)"),
        ),
        (
            ("--nvh", "--num-heads", "--num-v-heads"),
            dict(type=int, default=64, help="Number of V heads (num_v_heads)"),
        ),
        (
            ("--no-h0",),
            dict(
                action="store_true",
                help="Disable initial state and gradient of final state",
            ),
        ),
        (
            ("--skip-bwd",),
            dict(action="store_true", help="Test forward only"),
        ),
        (
            ("--no-cp", "--disable-auto-cp"),
            dict(action="store_true", help="Disable auto intra-card CP"),
        ),
        (
            ("--swa-ratio",),
            dict(type=float, default=0.75, help="Ratio of sliding-window heads"),
        ),
        (
            ("--data-dtype",),
            dict(
                type=str,
                default="bfloat16",
                help="Data type for input and output",
            ),
        ),
        (
            ("--ref-dtype",),
            dict(type=str, default="float64", help="Data type for reference"),
        ),
        (
            ("--hide-acc",),
            dict(action="store_true", help="Do not print accuracy"),
        ),
        (
            ("--hide-lat",),
            dict(action="store_true", help="Do not print latency"),
        ),
        (
            ("--seed", "--random-seed"),
            dict(type=int, default=42, help="Random seed"),
        ),
    ]
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)
    opts = cli.parse_args()

    # A non-positive K-head count means "use the V-head count".
    k_head_count = opts.nkh if opts.nkh > 0 else opts.nvh

    run_config = dict(
        head_dim_k=128,  # MUST BE 128
        head_dim_v=128,  # MUST BE 128
        chunk_size=64,  # MUST BE 64
        num_tokens=opts.seqlen,
        num_k_heads=k_head_count,
        num_v_heads=opts.nvh,
        use_h0=not opts.no_h0,
        data_dtype=opts.data_dtype,
        ref_dtype=opts.ref_dtype,
        check_accuracy=not opts.hide_acc,
        show_speedup=not opts.hide_lat,
        skip_bwd=opts.skip_bwd,
        auto_cp=not opts.no_cp,
        swa_ratio=opts.swa_ratio,
        random_seed=opts.seed,
        device="cuda",
    )

    import os

    # Presets are looked up next to this script, in settings/<set>.csv.
    here = os.path.dirname(os.path.abspath(__file__))
    preset_path = os.path.join(here, "settings", f"{opts.set}.csv")
    preset_table = pd.read_csv(preset_path)

    separator = "-" * 64
    for _, preset_row in preset_table.iterrows():
        print(separator)
        torch.cuda.empty_cache()
        overrides = preset_row.to_dict()
        if "cu_seqlens" in overrides:
            # The CSV stores the boundaries as one dash-separated string.
            overrides["cu_seqlens"] = [
                int(piece) for piece in overrides["cu_seqlens"].split("-")
            ]
        # update() is cumulative: values set by one row stay in effect for
        # later rows unless those rows override them again.
        run_config.update(overrides)
        test_gated_delta_rule(**run_config)
    print(separator)

View File

@@ -0,0 +1,57 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import math
import pytest
import torch
from flash_qla.ops.gated_delta_rule.legacy import chunk_gated_delta_rule_fwd_legacy
def _reference(q, k, v, g, beta, scale=None, initial_state=None):
    """Token-by-token gated delta rule used as the ground truth in these tests.

    Shapes as indexed by the code: ``q`` and ``k`` are
    ``[batch, tokens, q_heads, dim]``, ``v`` is ``[batch, tokens, v_heads, dim]``,
    ``g`` and ``beta`` are ``[batch, tokens, v_heads]``. Value head ``hv`` reads
    query/key head ``hv // (v_heads // q_heads)``.

    Args:
        scale: Output multiplier; ``dim ** -0.5`` when ``None``.
        initial_state: Optional starting state. It is cloned, so the caller's
            tensor is left untouched. When ``None`` a zero state of shape
            ``[batch, v_heads, dim, dim]`` is used.

    Returns:
        ``(output, state)`` where ``output`` has the shape of ``v`` and
        ``state`` is the state after the last token.
    """
    n_batch, n_tokens, n_q_heads, head_dim = q.shape
    n_v_heads = v.shape[2]
    if scale is None:
        scale = head_dim**-0.5

    if initial_state is None:
        state = torch.zeros(
            n_batch, n_v_heads, head_dim, head_dim, device=q.device, dtype=q.dtype
        )
    else:
        state = initial_state.clone()

    output = torch.empty_like(v)
    for bi in range(n_batch):
        for vi in range(n_v_heads):
            # Consecutive value heads share one query/key head.
            qi = vi // (n_v_heads // n_q_heads)
            for ti in range(n_tokens):
                key_t = k[bi, ti, qi]
                decay = torch.exp(g[bi, ti, vi])
                # What the decayed state currently predicts for this key.
                predicted = decay * (state[bi, vi].transpose(0, 1) @ key_t)
                correction = (v[bi, ti, vi] - predicted) * beta[bi, ti, vi]
                state[bi, vi] = decay * state[bi, vi] + torch.outer(
                    key_t, correction
                )
                readout = state[bi, vi].transpose(0, 1) @ q[bi, ti, qi]
                output[bi, ti, vi] = scale * readout
    return output, state
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("dim", [16, 32, 64, 128])
def test_legacy_sm_gdn_matches_reference(dim):
    """Check the legacy forward op against ``_reference`` on small float32 inputs.

    Inputs use 1 batch, 5 tokens, 2 query/key heads and 4 value heads, with an
    explicit initial state and scale. Both the output and the final state are
    compared.
    """
    torch.manual_seed(1000 + dim)
    cuda_f32 = {"device": "cuda", "dtype": torch.float32}
    n_batch, n_tokens, n_qk_heads, n_v_heads = 1, 5, 2, 4

    # The draw order below is fixed so the seeded random stream yields the
    # same tensors on every run.
    q = torch.randn(n_batch, n_tokens, n_qk_heads, dim, **cuda_f32).contiguous() * 0.05
    k = torch.randn_like(q).contiguous() * 0.05
    v = torch.randn(n_batch, n_tokens, n_v_heads, dim, **cuda_f32).contiguous() * 0.1
    g = torch.randn(n_batch, n_tokens, n_v_heads, **cuda_f32).contiguous() * 0.02 - 0.04
    beta = torch.rand(n_batch, n_tokens, n_v_heads, **cuda_f32).contiguous()
    h0 = torch.randn(n_batch, n_v_heads, dim, dim, **cuda_f32).contiguous() * 0.01
    scale = 1.0 / math.sqrt(dim)

    expected_out, expected_state = _reference(q, k, v, g, beta, scale, h0)
    actual_out, actual_state = chunk_gated_delta_rule_fwd_legacy(
        q, k, v, g, beta, scale, h0
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(actual_out, expected_out, atol=2e-4, rtol=2e-4)
    torch.testing.assert_close(actual_state, expected_state, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
def test_legacy_sm_gdn_rejects_unsupported_dtype():
    """Expect a ``ValueError`` mentioning "float32" when q/k/v are float16."""
    half_qkv = torch.randn(1, 1, 1, 16, device="cuda", dtype=torch.float16)
    gate = torch.randn(1, 1, 1, device="cuda")
    beta = torch.randn(1, 1, 1, device="cuda")
    with pytest.raises(ValueError, match="float32"):
        # The same float16 tensor is passed as q, k and v.
        chunk_gated_delta_rule_fwd_legacy(half_qkv, half_qkv, half_qkv, gate, beta)