first commit
This commit is contained in:
700
tests/ref_gdr.py
Normal file
700
tests/ref_gdr.py
Normal file
@@ -0,0 +1,700 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
|
||||
from flash_qla.utils import (
|
||||
pad_and_reshape,
|
||||
pack,
|
||||
unpack,
|
||||
fill_last_chunk_of_g,
|
||||
prepare_chunk_offsets,
|
||||
)
|
||||
|
||||
|
||||
def torch_cumsum(
    x: torch.Tensor,  # [B, T, H]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
    reverse: bool = False,
):
    """Chunk-local cumulative sum along the token axis (pure PyTorch reference).

    The input is split into chunks with ``pad_and_reshape`` and the cumulative
    sum is taken over dimension 2 of the chunked tensor, so the running sum
    restarts at every chunk boundary.

    Args:
        x: Values to accumulate; the file annotates this as ``[B, T, H]``.
        cu_seqlens: When given, ``x`` is first passed through ``unpack`` and
            the result is passed through ``pack`` before returning.
        chunk_size: Chunk length forwarded to ``pad_and_reshape``.
        reverse: If true, accumulate from the end of each chunk towards its
            start (flip, cumsum, flip back).

    Returns:
        A tensor with the padding rows dropped again, re-packed when
        ``cu_seqlens`` was supplied.
    """
    is_varlen = cu_seqlens is not None
    if is_varlen:
        x = unpack(x, cu_seqlens)

    batch_size, num_tokens, num_heads = x.shape
    chunks = pad_and_reshape(x, dim=1, chunk_size=chunk_size)

    if reverse:
        # Suffix sum == flip, prefix sum, flip back.
        flipped = torch.flip(chunks, dims=(2,))
        summed = torch.flip(flipped.cumsum(dim=2), dims=(2,))
    else:
        summed = chunks.cumsum(dim=2)

    # Merge the chunk axes again and discard the padded tail.
    out = summed.reshape(batch_size, -1, num_heads)[:, :num_tokens]

    return pack(out, cu_seqlens) if is_varlen else out
|
||||
|
||||
|
||||
def torch_kkt_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Reference for the per-chunk, decay-weighted ``(beta * k) @ k^T`` matrix.

    Within every chunk the product between ``beta``-scaled keys and keys is
    formed, multiplied by ``exp(g_row - g_col)`` and zeroed on and above the
    diagonal, so only strictly-lower-triangular entries survive.

    Args:
        k: Keys. If it has fewer heads than ``g``, heads are repeated with
            ``repeat_interleave`` to match.
        g: Gate values; the callers in this file pass the chunk-local
            cumulative sum produced by ``torch_cumsum``.
        beta: Per-token scaling applied to the row-side keys.
        cu_seqlens: When given, inputs are ``unpack``-ed first and the result
            is ``pack``-ed before returning.
        chunk_size: Chunk length.

    Returns:
        Tensor with trailing shape ``(num_v_heads, chunk_size)`` per token,
        truncated to the original token count.
    """
    packed = cu_seqlens is not None
    if packed:
        k, g, beta = (unpack(t, cu_seqlens) for t in (k, g, beta))

    batch_size, num_tokens, num_k_heads, _ = k.shape
    num_v_heads = g.shape[-1]

    if num_k_heads != num_v_heads:
        # Broadcast each key head over its group of value heads.
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    k_c = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, H, K]
    g_c = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, H]
    beta_c = pad_and_reshape(beta, dim=1, chunk_size=chunk_size)  # [B, N, C, H]

    # True on and above the diagonal: those entries are removed.
    upper = torch.ones(
        chunk_size, chunk_size, dtype=torch.bool, device=k_c.device
    ).triu()
    decay = (g_c[:, :, :, None, :] - g_c[:, :, None, :, :]).exp()
    decay = decay.masked_fill(upper[None, None, :, :, None], 0.0)

    gram = torch.einsum("bnchk, bndhk -> bnchd", k_c * beta_c.unsqueeze(-1), k_c)
    attn = gram * decay.swapaxes(-2, -1)  # [B, N, C, H, D]
    attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens]

    return pack(attn, cu_seqlens) if packed else attn
|
||||
|
||||
|
||||
def torch_solve(
    x: torch.Tensor,  # [B, T, H, D]
    cu_seqlens: torch.Tensor = None,
):
    """Row-by-row forward substitution on each ``chunk_size x chunk_size`` block.

    The chunk size is taken from the last dimension of ``x``. Each block is
    negated, its rows are updated in order using the already-updated rows
    above them, and the identity is added at the end. When a block ``A`` is
    strictly lower triangular (which is what ``torch_kkt_fwd`` in this file
    produces), the result for that block is ``(I + A)^{-1}``.

    Args:
        x: Per-token rows of the block matrices.
        cu_seqlens: When given, ``x`` is ``unpack``-ed first and the result is
            ``pack``-ed before returning.

    Returns:
        Tensor with the same trailing layout as the (unpacked) input,
        truncated to the original number of tokens.
    """
    if cu_seqlens is not None:
        x = unpack(x, cu_seqlens)

    batch_size, num_tokens, num_heads, chunk_size = x.shape

    # Unary minus allocates a new tensor, so the in-place writes below do not
    # touch the caller's data. Rows (C) and columns (D) become the last two axes.
    x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size).swapaxes(
        2, 3
    )  # [B, N, H, C, D]

    # Order matters: row i reads rows 0..i-1 after they have been updated.
    for i in range(1, chunk_size):
        # Clones snapshot the operands before row i is overwritten.
        row = x[..., i, :i].clone()
        sub = x[..., :i, :i].clone()
        # new_row[c] = row[c] + sum_j row[j] * sub[j, c]
        x[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    # The loop only touches entries left of the diagonal; add the unit diagonal.
    x += torch.eye(chunk_size, dtype=x.dtype, device=x.device)
    # Back to token-major layout, then drop the padded tail.
    x = x.swapaxes(2, 3).reshape((batch_size, -1, num_heads, chunk_size))[
        :, :num_tokens
    ]

    if cu_seqlens is not None:
        x = pack(x, cu_seqlens)
    return x
|
||||
|
||||
|
||||
def torch_w_u_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    A: torch.Tensor,  # [B, T, Hv, D]
    cu_seqlens: torch.Tensor = None,
):
    """Reference computation of ``w = A @ (k * beta * exp(g))`` and ``u = A @ (v * beta)``.

    Both products are taken independently inside every chunk; the chunk size
    is inferred from the last dimension of ``A``.

    Args:
        k: Keys; repeated over heads with ``repeat_interleave`` when it has
            fewer heads than ``v``.
        v: Values.
        g: Gate values, exponentiated and multiplied into the keys.
        beta: Per-token scaling applied to both keys and values.
        A: Per-chunk square matrices stored one row per token.
        cu_seqlens: When given, inputs are ``unpack``-ed first and both
            outputs are ``pack``-ed before returning.

    Returns:
        ``(w, u)`` with trailing shapes ``(Hv, K)`` and ``(Hv, V)`` per token.
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        A = unpack(A, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
        g = unpack(g, cu_seqlens)

    batch_size, num_tokens, _, chunk_size = A.shape
    _, _, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape

    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    k_beta = pad_and_reshape(
        k * beta.unsqueeze(-1) * g.exp().unsqueeze(-1), dim=1, chunk_size=chunk_size
    )  # [B, N, C, Hv, K]
    v_beta = pad_and_reshape(
        v * beta.unsqueeze(-1), dim=1, chunk_size=chunk_size
    )  # [B, N, C, Hv, V]
    # Chunk A with the same size as k_beta / v_beta. Previously this call fell
    # back to pad_and_reshape's default chunk size, which only matches when
    # A.shape[-1] happens to equal that default.
    A = pad_and_reshape(A, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, D]

    w = torch.einsum("bnchd, bndhk -> bnchk", A, k_beta).reshape(
        (batch_size, -1, num_v_heads, head_dim_k)
    )[:, :num_tokens]
    u = torch.einsum("bnchd, bndhv -> bnchv", A, v_beta).reshape(
        (batch_size, -1, num_v_heads, head_dim_v)
    )[:, :num_tokens]

    if cu_seqlens is not None:
        w = pack(w, cu_seqlens)
        u = pack(u, cu_seqlens)
    return w, u
|
||||
|
||||
|
||||
def torch_chunk_gdr_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    w: torch.Tensor,  # [B, T, Hv, K]
    u: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    initial_state: torch.Tensor = None,  # [B, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Reference chunk-by-chunk state recurrence of the forward pass.

    For each chunk, in order: record the state entering the chunk, compute
    ``v_new = u - w @ state``, decay the state by ``exp`` of the chunk's last
    gate value, and add ``(k * exp(g_last - g))^T @ v_new``.

    Args:
        k: Keys; repeated over heads when it has fewer heads than ``u``.
        w, u: Per-token tensors consumed by the recurrence as described above.
        g: Gate values; after chunking they are post-processed by
            ``fill_last_chunk_of_g``.
        initial_state: Optional starting state; zeros when omitted. It is
            copied (``copy=True``), so the caller's tensor is not modified.
        cu_seqlens: When given, inputs are ``unpack``-ed and ``vn`` / ``h``
            are ``pack``-ed (``h`` using ``prepare_chunk_offsets``).
        chunk_size: Chunk length.

    Returns:
        ``(h, vn, last_state)``: the stacked per-chunk entry states, the
        per-token ``v_new`` values, and the state after the final chunk.
    """
    packed = cu_seqlens is not None
    if packed:
        k, w, u, g = (unpack(t, cu_seqlens) for t in (k, w, u, g))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    num_v_heads, head_dim_v = u.shape[2:]

    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    w = pad_and_reshape(w, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    u = pad_and_reshape(u, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)

    if initial_state is not None:
        state = initial_state.to(g.dtype, copy=True)
    else:
        state = torch.zeros(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            dtype=g.dtype,
            device=g.device,
        )

    entry_states = []
    new_values = []
    num_chunks = k.shape[1]
    for n in range(num_chunks):
        k_n, w_n, u_n, g_n = k[:, n], w[:, n], u[:, n], g[:, n]

        entry_states.append(state)
        v_new = u_n - torch.einsum("bchk, bhkv -> bchv", w_n, state)
        new_values.append(v_new)

        # Decay the carried state to the end of the chunk ...
        state = state * g_n[:, -1, :, None, None].exp()
        # ... and add this chunk's contribution, each token decayed to the end.
        k_decayed = k_n * (g_n[:, -1:, :, None] - g_n[:, :, :, None]).exp()
        state = state + torch.einsum("bchk, bchv -> bhkv", k_decayed, v_new)

    h = torch.stack(entry_states, dim=1).contiguous()
    vn = torch.stack(new_values, dim=1)
    vn = vn.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    vn = vn.contiguous()

    if packed:
        vn = pack(vn, cu_seqlens)
        h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))

    return h, vn, state
|
||||
|
||||
|
||||
def torch_chunk_o_fwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    h: torch.Tensor,  # [B, N, Hv, K, V]
    g: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Reference output computation: inter-chunk plus intra-chunk contribution.

    The inter-chunk term is ``(q * scale * exp(g)) @ h`` and the intra-chunk
    term is the causally masked (diagonal kept), decay-weighted
    ``(q * scale) @ k^T`` applied to ``v``.

    Args:
        q, k: Queries and keys; both are repeated over heads when they have
            fewer heads than ``v``.
        v: Values (the caller in this file passes the ``vn`` output of
            ``torch_chunk_gdr_fwd``).
        h: Per-chunk states.
        g: Gate values.
        cu_seqlens: When given, inputs are ``unpack``-ed (``h`` using
            ``prepare_chunk_offsets``) and the output is ``pack``-ed.
        scale: Query scaling; a falsy value selects ``head_dim_k ** -0.5``.
        chunk_size: Chunk length.

    Returns:
        Output tensor with trailing shape ``(Hv, V)`` per token.
    """
    packed = cu_seqlens is not None
    if packed:
        q = unpack(q, cu_seqlens)
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        g = unpack(g, cu_seqlens)
        h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    num_v_heads, head_dim_v = v.shape[2:]

    if num_k_heads != num_v_heads:
        group = num_v_heads // num_k_heads
        q = q.repeat_interleave(group, dim=2)
        k = k.repeat_interleave(group, dim=2)

    if not scale:
        scale = head_dim_k ** (-0.5)

    q_c = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k_c = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    v_c = pad_and_reshape(v, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    g_c = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    q_c = q_c * scale

    # Strictly-upper entries are masked out; the diagonal is kept.
    future = torch.ones(
        chunk_size, chunk_size, dtype=torch.bool, device=k_c.device
    ).triu(diagonal=1)
    decay = (g_c[:, :, :, None, :] - g_c[:, :, None, :, :]).exp()
    decay = decay.masked_fill(future[None, None, :, :, None], 0.0)  # [B, N, C, D, Hv]

    scores = torch.einsum("bnchk, bndhk -> bncdh", q_c, k_c) * decay
    inter = torch.einsum("bnchk, bnhkv -> bnchv", q_c * g_c.exp().unsqueeze(-1), h)
    intra = torch.einsum("bncdh, bndhv -> bnchv", scores, v_c)
    o = inter + intra

    o = o.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    return pack(o, cu_seqlens) if packed else o
|
||||
|
||||
|
||||
def torch_chunk_dv_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    g: torch.Tensor,  # [B, T, Hv]
    do: torch.Tensor,  # [B, T, Hv, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Reference intra-chunk gradient w.r.t. the values.

    Rebuilds the causally masked, decay-weighted ``(q * scale) @ k^T`` scores
    of each chunk and contracts them with ``do`` over the query position.

    Args:
        q, k: Queries and keys; repeated over heads when they have fewer heads
            than ``do``.
        g: Gate values.
        do: Gradient of the output.
        cu_seqlens: When given, inputs are ``unpack``-ed first and the result
            is ``pack``-ed before returning.
        scale: Query scaling; a falsy value selects ``head_dim_k ** -0.5``.
        chunk_size: Chunk length.

    Returns:
        ``dv`` with trailing shape ``(Hv, V)`` per token.
    """
    packed = cu_seqlens is not None
    if packed:
        q, k, g, do = (unpack(t, cu_seqlens) for t in (q, k, g, do))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    num_v_heads, head_dim_v = do.shape[2:]

    if num_k_heads != num_v_heads:
        group = num_v_heads // num_k_heads
        q = q.repeat_interleave(group, dim=2)
        k = k.repeat_interleave(group, dim=2)

    if not scale:
        scale = head_dim_k ** (-0.5)

    q_c = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k_c = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    g_c = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    do_c = pad_and_reshape(do, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    q_c = q_c * scale

    # Strictly-upper entries are masked out; the diagonal is kept.
    future = torch.ones(
        chunk_size, chunk_size, dtype=torch.bool, device=k_c.device
    ).triu(diagonal=1)
    decay = (g_c[:, :, :, None, :] - g_c[:, :, None, :, :]).exp()
    decay = decay.masked_fill(future[None, None, :, :, None], 0.0)  # [B, N, C, D, Hv]

    scores = torch.einsum("bnchk, bndhk -> bncdh", q_c, k_c) * decay
    dv = torch.einsum("bncdh, bnchv -> bndhv", scores, do_c)

    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    return pack(dv, cu_seqlens) if packed else dv
|
||||
|
||||
|
||||
def torch_chunk_gdr_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    w: torch.Tensor,  # [B, T, Hv, K]
    g: torch.Tensor,  # [B, T, Hv]
    do: torch.Tensor,  # [B, T, Hv, V]
    dv: torch.Tensor,  # [B, T, Hv, V]
    h0: torch.Tensor = None,  # [B, Hv, K, V]
    dht: torch.Tensor = None,  # [B, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Reference reverse-time recurrence for the state gradient.

    Walks the chunks from last to first, carrying the gradient of the state.
    For each chunk it records the incoming state gradient, adds the state's
    contribution to ``dv``, and then propagates the state gradient one chunk
    back.

    Args:
        q, k: Queries and keys; repeated over heads when they have fewer heads
            than ``do``.
        w: Per-token tensor from the forward pass, contracted with ``dv``.
        g: Gate values; post-processed by ``fill_last_chunk_of_g``.
        do: Gradient of the output.
        dv: Intra-chunk value gradient to be completed here.
        h0: Only used to decide whether ``dh0`` is returned; its values are
            not read.
        dht: Optional gradient of the final state; zeros when omitted.
        cu_seqlens: When given, inputs are ``unpack``-ed and ``dv`` / ``dh``
            are ``pack``-ed (``dh`` using ``prepare_chunk_offsets``).
        scale: Query scaling; a falsy value selects ``head_dim_k ** -0.5``.
        chunk_size: Chunk length.

    Returns:
        ``(dh, dh0, dv)``: stacked per-chunk state gradients, the gradient of
        the initial state (``None`` when ``h0`` is ``None``), and the updated
        value gradient.
    """
    if cu_seqlens is not None:
        q = unpack(q, cu_seqlens)
        k = unpack(k, cu_seqlens)
        w = unpack(w, cu_seqlens)
        g = unpack(g, cu_seqlens)
        do = unpack(do, cu_seqlens)
        dv = unpack(dv, cu_seqlens)

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape

    if num_k_heads != num_v_heads:
        q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    scale = scale or head_dim_k ** (-0.5)

    q = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    w = pad_and_reshape(w, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    do = pad_and_reshape(do, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)

    q = q * scale

    if dht is None:
        dstate = torch.zeros(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            dtype=g.dtype,
            device=g.device,
        )
    else:
        # copy=True keeps the caller's dht untouched.
        dstate = dht.to(g.dtype, copy=True)
    # Per-chunk direct contribution of the output gradient to the state gradient.
    dstate_inter = torch.einsum("bnchk, bnchv -> bnhkv", q * g.exp().unsqueeze(-1), do)

    dh = []
    for i in reversed(range(k.shape[1])):
        # Prepending keeps dh in forward chunk order despite the reversed loop.
        dh.insert(0, dstate)
        # In-place update: the completed dv[:, i] is read again two statements
        # below, so the order of these statements matters.
        # NOTE(review): if pad_and_reshape returns a view, this may also write
        # into the caller's dv tensor - confirm against the helper.
        dv[:, i] += torch.einsum(
            "bchk, bhkv -> bchv",
            k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(),
            dstate,
        )
        dstate = dstate * g[:, i, -1, :, None, None].exp()
        dstate = (
            dstate
            + dstate_inter[:, i]
            - torch.einsum("bchk, bchv -> bhkv", w[:, i], dv[:, i])
        )
    dh = torch.stack(dh, dim=1).contiguous()

    # After the loop dstate is the gradient w.r.t. the state entering chunk 0.
    dh0 = None if h0 is None else dstate
    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    if cu_seqlens is not None:
        dv = pack(dv, cu_seqlens)
        dh = pack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
    return dh, dh0, dv
|
||||
|
||||
|
||||
def torch_chunk_dqkwg_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    w: torch.Tensor,  # [B, T, Hv, K]
    g: torch.Tensor,  # [B, T, Hv]
    h: torch.Tensor,  # [B, N, Hv, K, V]
    dv: torch.Tensor,  # [B, T, Hv, V]
    do: torch.Tensor,  # [B, T, Hv, V]
    dh: torch.Tensor,  # [B, N, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Reference gradients w.r.t. ``q``, ``k``, ``w`` and the gate ``g``.

    Combines the inter-chunk terms (through the states ``h`` / ``dh``) with
    the intra-chunk terms (through the masked, decayed score matrix). The gate
    gradient is accumulated in several in-place steps whose order matters.

    Args:
        q, k: Queries and keys; repeated over heads when they have fewer heads
            than ``do``. Unlike the other functions in this file, ``q`` is not
            pre-multiplied by ``scale`` here; ``scale`` is applied to ``dq``
            and ``ds`` instead.
        v: Values (the caller in this file passes ``vn``).
        w: Only padded/reshaped here; its values are not read.
        g: Gate values; post-processed by ``fill_last_chunk_of_g``.
        h, dh: Per-chunk states and their gradients.
        dv, do: Value and output gradients.
        cu_seqlens: When given, inputs are ``unpack``-ed (``h`` / ``dh`` using
            ``prepare_chunk_offsets``) and all outputs are ``pack``-ed.
        scale: Scaling factor; a falsy value selects ``head_dim_k ** -0.5``.
        chunk_size: Chunk length.

    Returns:
        ``(dq, dk, dw, dg)``. ``dq`` / ``dk`` / ``dw`` have ``num_v_heads``
        heads; no reduction over head groups is performed here.
    """
    if cu_seqlens is not None:
        q = unpack(q, cu_seqlens)
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        w = unpack(w, cu_seqlens)
        g = unpack(g, cu_seqlens)
        do = unpack(do, cu_seqlens)
        dv = unpack(dv, cu_seqlens)
        h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
        dh = unpack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape

    if num_k_heads != num_v_heads:
        q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    scale = scale or head_dim_k ** (-0.5)

    q = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    v = pad_and_reshape(v, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    w = pad_and_reshape(w, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    do = pad_and_reshape(do, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)

    # Strictly-upper entries are masked out; the diagonal is kept.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
        diagonal=1,
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(
        mask[None, None, :, :, None], 0.0
    )  # [B, N, C, D, Hv]

    # Raw (not yet decayed / scaled) contractions.
    dg_last = (h * dh).sum(dim=-1).sum(dim=-1)  # [B, N, Hv]
    ds = torch.einsum("bnchv, bndhv -> bncdh", do, v)
    dq = torch.einsum("bnchv, bnhkv -> bnchk", do, h)
    dk = torch.einsum("bnchv, bnhkv -> bnchk", v, dh)
    dw = -torch.einsum("bnchv, bnhkv -> bnchk", dv, h)

    # Gate value at the last position of every chunk.
    g_last = g[:, :, -1]
    dg_last *= g_last.exp()
    # Inter-chunk part of dq; dg is seeded from it before the intra-chunk part
    # is added to dq further below.
    dq = dq * g.unsqueeze(-1).exp() * scale
    dg = (q * dq).sum(dim=-1)  # [B, N, C, Hv]
    # Inter-chunk part of dk, decayed from each token to the end of its chunk.
    dk = dk * (g_last.unsqueeze(-2) - g).unsqueeze(-1).exp()
    dg -= (k * dk).sum(dim=-1)
    dg_last += (k * dk).sum(dim=-1).sum(dim=-2)
    # Intra-chunk score gradient.
    ds *= decay_mask * scale
    ds2 = ds * torch.einsum("bnchk, bndhk -> bncdh", q, k)
    # Positive for the row (query) position, negative for the column (key) one.
    dg += ds2.sum(dim=-2)
    dg -= ds2.sum(dim=-3)
    dq += torch.einsum("bncdh, bndhk -> bnchk", ds, k)
    dk += torch.einsum("bncdh, bnchk -> bndhk", ds, q)
    # Terms that depend on the chunk-final gate are credited to the last slot.
    dg[:, :, -1] += dg_last

    # NOTE(review): presumably the reverse counterpart of the forward
    # fill_last_chunk_of_g call above - confirm against flash_qla.utils.
    dg = fill_last_chunk_of_g(
        dg, num_tokens, cu_seqlens, chunk_size=chunk_size, reverse=True
    )
    dq = dq.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    dw = dw.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    if cu_seqlens is not None:
        dq = pack(dq, cu_seqlens)
        dk = pack(dk, cu_seqlens)
        dw = pack(dw, cu_seqlens)
        dg = pack(dg, cu_seqlens)
    return dq, dk, dw, dg
|
||||
|
||||
|
||||
def torch_chunk_wy_bwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    beta: torch.Tensor,  # [B, T, Hv]
    A: torch.Tensor,  # [B, T, Hv, D]
    g: torch.Tensor,  # [B, T, Hv]
    dw: torch.Tensor,  # [B, T, Hv, K]
    du: torch.Tensor,  # [B, T, Hv, V]
    dk1: torch.Tensor,  # [B, T, Hv, K]
    dg1: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
):
    """Reference backward pass through ``w``/``u`` and the solved matrix ``A``.

    Propagates ``dw`` and ``du`` back to ``k``, ``v``, ``beta`` and ``g``,
    first through ``w = A @ (k * beta * exp(g))`` and ``u = A @ (v * beta)``,
    then through ``A`` itself to the decay-masked ``(beta * k) @ k^T`` matrix.
    The chunk size is inferred from the last dimension of ``A``.

    Args:
        k: Keys; repeated over heads when it has fewer heads than ``v``.
        v: Values.
        beta: Per-token scaling.
        A: Solved per-chunk matrices, one row per token.
        g: Gate values.
        dw, du: Gradients of ``w`` and ``u``.
        dk1, dg1: Previously computed key / gate gradients that are added to
            the results.
        cu_seqlens: When given, inputs are ``unpack``-ed first and all outputs
            are ``pack``-ed before returning.

    Returns:
        ``(dk, dv, db, dg)``. ``dk`` has ``num_v_heads`` heads; the reduction
        over head groups is left to the caller.
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
        A = unpack(A, cu_seqlens)
        g = unpack(g, cu_seqlens)
        dw = unpack(dw, cu_seqlens)
        du = unpack(du, cu_seqlens)
        dk1 = unpack(dk1, cu_seqlens)
        dg1 = unpack(dg1, cu_seqlens)

    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    chunk_size = A.shape[-1]

    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)

    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    v = pad_and_reshape(v, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    A = pad_and_reshape(A, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, D]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    du = pad_and_reshape(du, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]

    # Backward through w = A @ (k * beta * exp(g)).
    dA = torch.einsum("bnchk, bndhk -> bnchd", dw, k * (beta * g.exp()).unsqueeze(-1))
    dk_beta_g = torch.einsum("bnchd, bnchk -> bndhk", A, dw)
    dk = dk_beta_g * (beta * g.exp()).unsqueeze(-1)
    db = (dk_beta_g * k * g.exp().unsqueeze(-1)).sum(dim=-1)
    dg = (dk_beta_g * k * (g.exp() * beta).unsqueeze(-1)).sum(dim=-1)

    # Backward through u = A @ (v * beta).
    dA += torch.einsum("bnchv, bndhv -> bnchd", du, v * beta.unsqueeze(-1))
    dv_beta = torch.einsum("bnchd, bnchv -> bndhv", A, du)
    dv = dv_beta * beta.unsqueeze(-1)
    db += (dv_beta * v).sum(dim=-1)

    # Backward through the solve: dA <- -(A^T @ dA @ A^T), restricted to the
    # strictly lower triangle and weighted by the decay factors.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0).swapaxes(
        -2, -1
    )
    dA = dA.masked_fill(mask[None, None, :, None, :], 0.0)
    dA = torch.einsum("bndhc, bndhe -> bnche", A, dA)
    dA = torch.einsum("bnchd, bnehd -> bnche", dA, A)
    dA = -dA * decay_mask

    # From here on A is the raw (undecayed) (beta * k) @ k^T product.
    A = torch.einsum("bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k)
    dk_beta = torch.einsum("bnchd, bndhk -> bnchk", dA, k)
    db += (dk_beta * k).sum(dim=-1)
    dk += torch.einsum("bnchd, bnchk -> bndhk", dA, k * beta.unsqueeze(-1))
    dk += dk_beta * beta.unsqueeze(-1)
    dk += dk1

    dg += (dA * A).sum(dim=-1) - (dA * A).sum(dim=-3).swapaxes(-1, -2)
    dg += dg1

    # TODO: NOTE: GVA
    dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    # dv lives in value space: reshape with head_dim_v. Using head_dim_k here
    # silently scrambled the result whenever head_dim_k != head_dim_v.
    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    db = db.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    if cu_seqlens is not None:
        dk = pack(dk, cu_seqlens)
        dv = pack(dv, cu_seqlens)
        db = pack(db, cu_seqlens)
        dg = pack(dg, cu_seqlens)
    return dk, dv, db, dg
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    initial_state: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Pure-PyTorch reference forward pass of the chunked gated delta rule.

    Chains the reference stages defined in this file: chunk-local cumulative
    sum of the gates, the decay-masked key product, its solve, the ``w``/``u``
    projections, the state recurrence and finally the output.

    Args:
        q, k, v: Queries, keys and values.
        g: Raw (not yet accumulated) gate values.
        beta: Per-token scaling.
        cu_seqlens: Optional cumulative sequence lengths, forwarded to every
            stage.
        initial_state: Optional starting state for the recurrence.
        scale: Query scaling; a falsy value selects ``q.shape[-1] ** -0.5``.
        chunk_size: Chunk length.

    Returns:
        ``(g, o, A, h, final_state)`` where ``g`` is the accumulated gate,
        ``A`` the solved matrix and ``h`` the per-chunk states.
    """
    if not scale:
        scale = q.shape[-1] ** (-0.5)

    g_cum = torch_cumsum(x=g, cu_seqlens=cu_seqlens, chunk_size=chunk_size)

    kkt = torch_kkt_fwd(
        k=k, g=g_cum, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size
    )
    A_solved = torch_solve(x=kkt, cu_seqlens=cu_seqlens)

    w, u = torch_w_u_fwd(
        k=k, v=v, beta=beta, A=A_solved, g=g_cum, cu_seqlens=cu_seqlens
    )

    h, v_new, final_state = torch_chunk_gdr_fwd(
        k=k,
        w=w,
        u=u,
        g=g_cum,
        initial_state=initial_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )

    o = torch_chunk_o_fwd(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g_cum,
        cu_seqlens=cu_seqlens,
        scale=scale,
        chunk_size=chunk_size,
    )
    return g_cum, o, A_solved, h, final_state
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Pure-PyTorch reference backward pass of the chunked gated delta rule.

    Recomputes ``w``, ``u``, ``h`` and ``vn`` from the saved forward results
    and then chains the reference backward stages defined in this file.

    Args:
        q, k, v: Forward inputs.
        g: Accumulated gate as returned by ``chunk_gated_delta_rule_fwd``.
        beta: Per-token scaling.
        A: Solved matrix as returned by ``chunk_gated_delta_rule_fwd``.
        scale: Query scaling used in the forward pass.
        initial_state: Initial state used in the forward pass (may be ``None``).
        do: Gradient of the output.
        dht: Gradient of the final state (may be ``None``).
        cu_seqlens: Optional cumulative sequence lengths, forwarded to every
            stage.
        chunk_size: Chunk length. It is forwarded to every stage that takes
            one, so values other than 64 are honoured consistently.

    Returns:
        ``(dq, dk, dv, db, dg, dh0)``. ``dq`` / ``dk`` are summed over each
        head group when ``k`` has fewer heads than ``v``; ``dg`` is the
        gradient w.r.t. the raw gate (reverse chunk-local cumulative sum).
    """
    w, u = torch_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g=g,
        cu_seqlens=cu_seqlens,
    )
    h, vn, _ = torch_chunk_gdr_fwd(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dv = torch_chunk_dv_bwd(
        q=q,
        k=k,
        g=g,
        do=do,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dh, dh0, dv = torch_chunk_gdr_bwd(
        q=q,
        k=k,
        w=w,
        g=g,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dq, dk1, dw, dg1 = torch_chunk_dqkwg_bwd(
        q=q,
        k=k,
        v=vn,
        w=w,
        g=g,
        h=h,
        dv=dv,
        do=do,
        dh=dh,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dk, dv, db, dg = torch_chunk_wy_bwd(
        k=k,
        v=v,
        beta=beta,
        g=g,
        A=A,
        dw=dw,
        du=dv,
        dk1=dk1,
        dg1=dg1,
        cu_seqlens=cu_seqlens,
    )
    # Grouped heads: the stages above work with one key/query head per value
    # head, so fold the gradients back onto the original key heads.
    Hg, H = k.shape[-2], v.shape[-2]
    if Hg < H:
        B, T, _, K = dq.shape
        dq = torch.sum(dq.reshape(B, T, Hg, -1, K), dim=3)
        dk = torch.sum(dk.reshape(B, T, Hg, -1, K), dim=3)
    # The forward pass accumulated g chunk-locally, so its gradient is the
    # reverse chunk-local cumulative sum with the same chunk size.
    dg = torch_cumsum(
        dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens
    )
    return dq, dk, dv, db, dg, dh0
|
||||
2
tests/settings/develop.csv
Normal file
2
tests/settings/develop.csv
Normal file
@@ -0,0 +1,2 @@
|
||||
batch_size,num_tokens,varlen
|
||||
1,32768,False
|
||||
|
9
tests/settings/product.csv
Normal file
9
tests/settings/product.csv
Normal file
@@ -0,0 +1,9 @@
|
||||
batch_size,num_tokens,varlen,cu_seqlens
|
||||
1,16384,True,0-268-1139-1179-1212-1476-2792-3202-3611-3726-3820-4096-4882-6417-8130-8192-8328-9426-10473-11002-11754-12288-14085-15370-16384
|
||||
1,16384,True,0-3393-4096-4153-5636-5853-6318-6777-8192-8320-8931-9163-9494-10040-10113-10363-10561-11061-11388-11634-12288-14545-16288-16384
|
||||
1,16384,True,0-4096-6111-6485-6589-7118-8192-9056-10448-12288-14032-14525-15884-16012-16384
|
||||
1,16384,True,0-177-4096-8192-12288-12805-13171-13298-16055-16384
|
||||
1,16384,True,0-308-1128-1678-4096-4748-8192-8506-9657-10252-12113-12288-16384
|
||||
1,16384,True,0-4096-6893-7665-8192-12288-16384
|
||||
1,16384,True,0-410-841-1135-2126-2512-4096-4682-5022-5375-6259-6335-6580-6648-7308-8192-10450-12058-12288-14215-15280-15701-16384
|
||||
1,16384,True,0-2048-4096-6144-8192-10240-12288-14336-16384
|
||||
|
5
tests/settings/profile.csv
Normal file
5
tests/settings/profile.csv
Normal file
@@ -0,0 +1,5 @@
|
||||
batch_size,num_tokens,varlen
|
||||
1,4096,False
|
||||
1,8192,False
|
||||
1,16384,False
|
||||
1,32768,False
|
||||
|
7
tests/settings/varlen.csv
Normal file
7
tests/settings/varlen.csv
Normal file
@@ -0,0 +1,7 @@
|
||||
batch_size,num_tokens,varlen
|
||||
11,33,False
|
||||
7,4321,False
|
||||
3,16789,True
|
||||
5,8192,True
|
||||
10,1024,True
|
||||
20,512,True
|
||||
|
71
tests/test_function_signature.py
Normal file
71
tests/test_function_signature.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
"""Static checks for autograd ``Function`` signatures.
|
||||
|
||||
PyTorch validates that ``Function.backward`` returns exactly as many gradients
|
||||
as ``Function.forward`` received non-``ctx`` inputs; mismatches raise at
|
||||
``.backward()`` time. ``tests/test_gdr.py`` invokes ``chunk_gated_delta_rule_fwd``
|
||||
and ``chunk_gated_delta_rule_bwd`` directly, bypassing the autograd path, so
|
||||
drift between the forward signature and the backward return tuple goes
|
||||
uncaught by the existing suite.
|
||||
|
||||
These tests parse the source files with ``ast`` instead of importing the
|
||||
modules so they run on CPU-only / non-Hopper machines.
|
||||
"""
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Directory one level above the one containing this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Source file to inspect, as a path relative to REPO_ROOT.
CHUNK_INIT = "flash_qla/ops/gated_delta_rule/chunk/__init__.py"
|
||||
|
||||
|
||||
def _parse(rel_path: str) -> ast.Module:
    """Read ``rel_path`` (relative to ``REPO_ROOT``) as UTF-8 and return its AST."""
    source_path = REPO_ROOT / rel_path
    source_text = source_path.read_text(encoding="utf-8")
    return ast.parse(source_text)
|
||||
|
||||
|
||||
def _get_class(module: ast.Module, name: str) -> ast.ClassDef:
    """Return the first top-level class called ``name`` in ``module``.

    Raises:
        AssertionError: If the module body contains no such class.
    """
    candidates = (
        node
        for node in module.body
        if isinstance(node, ast.ClassDef) and node.name == name
    )
    found = next(candidates, None)
    if found is None:
        raise AssertionError(f"class {name!r} not found")
    return found
|
||||
|
||||
|
||||
def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef:
    """Return the first plain (non-async) function called ``name`` in the body of ``cls``.

    Raises:
        AssertionError: If the class body contains no such function.
    """
    candidates = (
        node
        for node in cls.body
        if isinstance(node, ast.FunctionDef) and node.name == name
    )
    found = next(candidates, None)
    if found is None:
        raise AssertionError(f"method {name!r} not found on {cls.name}")
    return found
|
||||
|
||||
|
||||
def test_chunk_gated_delta_rule_grad_count_matches_forward_inputs():
    """Check that ``backward`` returns one gradient per non-``ctx`` ``forward`` input.

    The check is purely syntactic: it counts the positional parameters of
    ``ChunkGatedDeltaRuleFunction.forward`` and the elements of the tuple
    literal returned by its ``backward``.
    """
    function_cls = _get_class(_parse(CHUNK_INIT), "ChunkGatedDeltaRuleFunction")

    # Count the forward inputs, leaving out the leading ``ctx``.
    forward_params = _get_method(function_cls, "forward").args.args
    assert forward_params and forward_params[0].arg == "ctx", (
        "forward must take `ctx` as its first argument"
    )
    n_inputs = len(forward_params) - 1

    # Count the gradients in the single ``return`` of backward.
    backward_def = _get_method(function_cls, "backward")
    return_nodes = [
        node for node in ast.walk(backward_def) if isinstance(node, ast.Return)
    ]
    assert len(return_nodes) == 1, (
        f"expected one Return in backward, got {len(return_nodes)}"
    )
    returned = return_nodes[0].value
    assert isinstance(returned, ast.Tuple), (
        "backward must return a tuple literal"
    )
    n_grads = len(returned.elts)

    assert n_inputs == n_grads, (
        f"backward returns {n_grads} gradients but forward takes {n_inputs} non-ctx "
        f"inputs; PyTorch will raise a count-mismatch error at .backward() time."
    )
|
||||
|
||||
|
||||
# Allow running this file directly as a script, without pytest.
if __name__ == "__main__":
    test_chunk_gated_delta_rule_grad_count_matches_forward_inputs()
    # Only reached when the check above did not raise.
    print("OK")
|
||||
553
tests/test_gdr.py
Normal file
553
tests/test_gdr.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
|
||||
# Requires flash-linear-attention==0.5.0
|
||||
from fla.ops.gated_delta_rule.chunk import (
|
||||
chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_fla,
|
||||
)
|
||||
from fla.ops.gated_delta_rule.chunk import (
|
||||
chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_fla,
|
||||
)
|
||||
|
||||
from flash_qla import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_qla
|
||||
from flash_qla import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_qla
|
||||
from flash_qla.utils import l2norm, pack, profile
|
||||
|
||||
from ref_gdr import chunk_gated_delta_rule_fwd as chunk_gated_delta_rule_fwd_ref
|
||||
from ref_gdr import chunk_gated_delta_rule_bwd as chunk_gated_delta_rule_bwd_ref
|
||||
|
||||
|
||||
def test_gated_delta_rule(
    batch_size: int,
    num_tokens: int,
    num_k_heads: int,
    num_v_heads: int,
    head_dim_k: int,
    head_dim_v: int,
    varlen: bool = False,
    cu_seqlens: list[int] | None = None,
    use_h0: bool = False,
    chunk_size: int = 64,
    data_dtype: str = "bfloat16",
    ref_dtype: str = "float32",
    device: torch.device | str = "cuda",
    random_seed: int = 42,
    check_accuracy: bool = True,
    show_speedup: bool = True,
    auto_cp: bool = True,
    swa_ratio: float = 0.75,
    skip_bwd: bool = False,
) -> None:
    """Compare the flash_qla gated-delta-rule kernels with a reference and FLA.

    Random inputs are generated from ``random_seed``, then the forward pass is
    run through three implementations: the ``ref_gdr`` reference (on copies
    cast to ``ref_dtype``), flash-linear-attention (``*_fla``) and flash_qla
    (``*_qla``). Unless ``skip_bwd`` is set, the same is done for the backward
    pass. Errors are reported as ``max|x - x_ref| / max|x_ref|`` and the
    flash_qla results are asserted to stay within 2% of the reference's
    maximum magnitude.

    NOTE(review): the indentation of this function was reconstructed. The
    nesting of ``if use_h0`` (inside the explicit-``cu_seqlens`` branch), of
    the 1000-iteration loops (inside ``if check_accuracy``) and of the
    ``try``/``except`` blocks (inside those loops) should be confirmed against
    the upstream file.

    Args:
        batch_size: Leading dimension of the generated tensors. Must be 1 when
            an explicit ``cu_seqlens`` list is given.
        num_tokens: Second dimension of the generated tensors.
        num_k_heads: Number of heads for ``q`` and ``k``.
        num_v_heads: Number of heads for ``v``, ``g``, ``beta`` and the state.
        head_dim_k: Last dimension of ``q``/``k``; also sets
            ``scale = head_dim_k ** -0.5``.
        head_dim_v: Last dimension of ``v``.
        varlen: When true, run with ``cu_seqlens`` (random if none is given).
        cu_seqlens: Optional explicit boundaries; must start at 0 and end at
            ``num_tokens``. Only read when ``varlen`` is true.
        use_h0: When true, pass a random initial state ``h0`` and a random
            final-state gradient ``dht``.
        chunk_size: Not referenced in the function body.
        data_dtype: Name of the torch dtype used for ``q``, ``k`` and ``v``.
        ref_dtype: Name of the torch dtype the reference inputs are cast to.
        device: Device on which all tensors are created.
        random_seed: Seed passed to ``torch.manual_seed``.
        check_accuracy: Enables the error printout (and, as reconstructed, the
            repeated-run assertions).
        show_speedup: Enables profiling and the per-kernel timing tables.
        auto_cp: Forwarded to the flash_qla forward kernel.
        swa_ratio: Fraction of value heads that keep their gate; ``g`` is
            zeroed for the remaining heads.
        skip_bwd: Return after the forward section.

    Raises:
        AssertionError: If the ``cu_seqlens`` sanity checks fail or a
            flash_qla result exceeds the 2% tolerance.
    """
    # Resolve dtype names such as "bfloat16" to the torch dtype objects.
    data_dtype = getattr(torch, data_dtype)
    ref_dtype = getattr(torch, ref_dtype)
    torch.manual_seed(random_seed)
    # q and k are passed through l2norm; v stays raw Gaussian noise.
    q = l2norm(
        torch.randn(
            (batch_size, num_tokens, num_k_heads, head_dim_k),
            device=device,
            dtype=data_dtype,
        )
    )
    k = l2norm(
        torch.randn(
            (batch_size, num_tokens, num_k_heads, head_dim_k),
            device=device,
            dtype=data_dtype,
        )
    )
    v = torch.randn(
        (batch_size, num_tokens, num_v_heads, head_dim_v),
        device=device,
        dtype=data_dtype,
    )
    # g = logsigmoid(noise) / 16, kept in float32; logsigmoid is never
    # positive, so neither is g.
    g = (
        torch.nn.functional.logsigmoid(
            torch.randn(
                (batch_size, num_tokens, num_v_heads),
                device=device,
                dtype=torch.float32,
            )
        )
        / 16
    )
    # beta = sigmoid(noise), i.e. values in (0, 1), kept in float32.
    beta = torch.randn(
        (batch_size, num_tokens, num_v_heads), device=device, dtype=torch.float32
    ).sigmoid()
    # Optional float32 initial state, one (head_dim_k, head_dim_v) matrix per
    # batch entry and value head.
    h0 = (
        torch.randn(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            device=device,
            dtype=torch.float32,
        )
        if use_h0
        else None
    )
    # Upstream gradient w.r.t. the output, and (optionally) w.r.t. the final
    # state; the latter is scaled down by 8.
    do = torch.randn_like(v)
    dht = (
        torch.randn(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            device=device,
            dtype=torch.float32,
        )
        / 8
        if use_h0
        else None
    )
    scale = head_dim_k ** (-0.5)
    print(
        f"Shape: B={batch_size} Hk={num_k_heads} Hv={num_v_heads} T={num_tokens} VarLen={varlen}"
    )

    # Randomly choose ceil(swa_ratio * num_v_heads) value heads that keep
    # their gate; g is set to zero for every other head.
    swa_mask = torch.zeros((num_v_heads), dtype=torch.bool, device=device)
    swa_mask[: math.ceil(swa_ratio * num_v_heads)] = 1
    swa_mask = swa_mask[torch.randperm(num_v_heads, device=device)]
    g[:, :, ~swa_mask] = 0.0
    print(f"SWA Mask: {swa_mask.to(torch.int32, copy=True).tolist()}")

    if varlen:
        if cu_seqlens is None:
            # Draw one random length in [1, num_tokens) per batch entry and
            # turn the lengths into cumulative offsets with a leading 0.
            cu_seqlens = torch.randint(
                1, num_tokens, (batch_size,), device=device, dtype=torch.int32
            )
            cu_seqlens = torch.nn.functional.pad(
                torch.cumsum(cu_seqlens, dim=-1), (1, 0)
            )
            # Repack every per-token tensor according to cu_seqlens.
            q = pack(q, cu_seqlens)
            k = pack(k, cu_seqlens)
            v = pack(v, cu_seqlens)
            g = pack(g, cu_seqlens)
            beta = pack(beta, cu_seqlens)
            do = pack(do, cu_seqlens)
        else:
            # Explicit boundaries: the tensors above are used as-is (no
            # pack), so the list has to span exactly [0, num_tokens].
            assert batch_size == 1
            assert cu_seqlens[0] == 0
            assert cu_seqlens[-1] == num_tokens
            cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32)
            if use_h0:
                # Regenerate the states with one entry per sequence in
                # cu_seqlens instead of per batch_size (which is 1 here).
                real_batch_size = cu_seqlens.shape[0] - 1
                h0 = torch.randn(
                    (real_batch_size, num_v_heads, head_dim_k, head_dim_v),
                    device=device,
                    dtype=torch.float32,
                )
                dht = (
                    torch.randn(
                        (real_batch_size, num_v_heads, head_dim_k, head_dim_v),
                        device=device,
                        dtype=torch.float32,
                    )
                    / 8
                )
        # Every sequence must contain at least one token.
        assert (cu_seqlens[1:] - cu_seqlens[:-1]).min() > 0
    else:
        cu_seqlens = None

    # ---- Forward: reference on ref_dtype copies, FLA and QLA on raw inputs.
    # Note that the three calls are unpacked in different orders.
    g_ref, o_ref, A_ref, h_ref, s_ref = chunk_gated_delta_rule_fwd_ref(
        q=q.to(ref_dtype, copy=True),
        k=k.to(ref_dtype, copy=True),
        v=v.to(ref_dtype, copy=True),
        g=g.to(ref_dtype, copy=True),
        beta=beta.to(ref_dtype, copy=True),
        scale=scale,
        initial_state=h0,
        cu_seqlens=cu_seqlens,
    )
    g_fla, o_fla, A_fla, s_fla, _, _ = chunk_gated_delta_rule_fwd_fla(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        initial_state=h0,
        output_final_state=True,
        cu_seqlens=cu_seqlens,
    )
    g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        scale=scale,
        initial_state=h0,
        cu_seqlens=cu_seqlens,
        output_final_state=True,
        output_h=True,
        auto_cp=auto_cp,
    )

    if check_accuracy:
        # Each line prints "max abs error / max abs reference value".
        print(
            f"h_qla: {(h_qla - h_ref).abs().max().item():.4f} / {h_ref.abs().max().item():.4f}"
        )
        print(
            f"s_fla: {(s_fla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
        )
        print(
            f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
        )
        print(
            f"o_fla: {(o_fla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
        )
        print(
            f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
        )

        # Re-run the QLA forward 1000 times and re-check the tolerance after
        # each run. The positional True/False presumably map to
        # output_final_state/output_h, following the keyword order of the
        # call above; verify against the flash_qla signature.
        for _ in range(1000):
            g_qla, A_qla, o_qla, h_qla, s_qla = chunk_gated_delta_rule_fwd_qla(
                q,
                k,
                v,
                g,
                beta,
                scale,
                h0,
                cu_seqlens,
                True,
                False,
                auto_cp,
            )
            try:
                # Tolerance: max abs error <= 2% of the reference's max abs
                # value. The final state is only checked when h0 was given.
                if h0 is not None:
                    assert (
                        s_qla - s_ref
                    ).abs().max().item() <= s_ref.abs().max().item() * 0.02
                assert (
                    o_qla - o_ref
                ).abs().max().item() <= o_ref.abs().max().item() * 0.02
            except AssertionError as e:
                # Print the offending numbers before re-raising the failure.
                print("********** ERROR **********")
                if h0 is not None:
                    print(
                        f"s_qla: {(s_qla - s_ref).abs().max().item():.4f} / {s_ref.abs().max().item():.4f}"
                    )
                print(
                    f"o_qla: {(o_qla - o_ref).abs().max().item():.4f} / {o_ref.abs().max().item():.4f}"
                )
                print("********** ERROR **********")
                raise e

    if show_speedup:
        # profile() is indexed below by kernel name and by "total".
        prof_fla = profile(
            chunk_gated_delta_rule_fwd_fla,
            [q, k, v, g, beta, scale, h0, True, cu_seqlens],
        )
        prof_qla = profile(
            chunk_gated_delta_rule_fwd_qla,
            [q, k, v, g, beta, scale, h0, cu_seqlens, True, False, auto_cp],
        )
        # Group the kernel timings under shared row labels so both
        # implementations line up in one table.
        result_fla = {
            "[fwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
            "[fwd] solve": prof_fla["chunk_gated_delta_rule_fwd_kkt_solve_kernel"],
            "[fwd] wu": prof_fla["recompute_w_u_fwd_kernel"],
            "[fwd] gdr": prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
            "[fwd] o": prof_fla["chunk_fwd_kernel_o"],
        }
        result_qla = {
            "[fwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
            "[fwd] solve": prof_qla["tilelang_kkt_solve_kernel_kernel"],
            "[fwd] gdr": prof_qla["tilelang_fused_chunk_gdr_fwd_kernel_kernel"],
        }
        # The cp-* rows are added only when the warmup-chunks kernel shows up
        # in the QLA profile; FLA gets None placeholders for those rows.
        if "tilelang_get_warmup_chunks_kernel_kernel" in prof_qla.keys():
            result_fla["[fwd] cp-w"] = None
            result_fla["[fwd] cp-h"] = None
            result_fla["[fwd] cp-c"] = None
            result_qla["[fwd] cp-w"] = prof_qla[
                "tilelang_get_warmup_chunks_kernel_kernel"
            ]
            result_qla["[fwd] cp-h"] = prof_qla["tilelang_prepare_h_kernel_kernel"]
            result_qla["[fwd] cp-c"] = prof_qla["tilelang_correct_h0_kernel_kernel"]
        result_fla["total"] = prof_fla["total"]
        result_qla["total"] = prof_qla["total"]
        results = {
            "fla": result_fla,
            "flash_qla": result_qla,
        }
        df = pd.DataFrame(results)
        print(df.round(3))
        speedup = results["fla"]["total"] / results["flash_qla"]["total"]
        print(f"Speed up: {speedup:.2f}x")

    if skip_bwd:
        return

    # ---- Backward: each implementation consumes its own forward g and A.
    # The QLA kernel takes (do, dht) before (scale, h0); FLA and the
    # reference take them after.
    dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd_ref(
        q.to(ref_dtype, copy=True),
        k.to(ref_dtype, copy=True),
        v.to(ref_dtype, copy=True),
        g_ref,
        beta.to(ref_dtype, copy=True),
        A_ref.to(ref_dtype, copy=True),
        scale,
        h0,
        do.to(ref_dtype, copy=True),
        dht,
        cu_seqlens,
    )
    dq_fla, dk_fla, dv_fla, db_fla, dg_fla, dh0_fla, _, _ = (
        chunk_gated_delta_rule_bwd_fla(
            q,
            k,
            v,
            g_fla,
            beta,
            A_fla,
            scale,
            h0,
            do,
            dht,
            cu_seqlens,
        )
    )
    dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = chunk_gated_delta_rule_bwd_qla(
        q,
        k,
        v,
        g_qla,
        beta,
        A_qla,
        do,
        dht,
        scale,
        h0,
        cu_seqlens,
    )

    if check_accuracy:
        # Same "max abs error / max abs reference value" report per gradient.
        print(
            f"dq_fla: {(dq_fla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
        )
        print(
            f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
        )
        print(
            f"dk_fla: {(dk_fla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
        )
        print(
            f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
        )
        print(
            f"dv_fla: {(dv_fla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
        )
        print(
            f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
        )
        # dh0 is only compared when a final-state gradient was supplied.
        if dht is not None:
            print(
                f"dh0_fla: {(dh0_fla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
            )
            print(
                f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
            )
        print(
            f"db_fla: {(db_fla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
        )
        print(
            f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
        )
        print(
            f"dg_fla: {(dg_fla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
        )
        print(
            f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
        )

        # Re-run the QLA backward 1000 times and re-check every gradient
        # against the 2% tolerance after each run.
        for _ in range(1000):
            dq_qla, dk_qla, dv_qla, db_qla, dg_qla, dh0_qla = (
                chunk_gated_delta_rule_bwd_qla(
                    q,
                    k,
                    v,
                    g_qla,
                    beta,
                    A_qla,
                    do,
                    dht,
                    scale,
                    h0,
                    cu_seqlens,
                )
            )
            try:
                assert (
                    dq_qla - dq_ref
                ).abs().max().item() <= dq_ref.abs().max().item() * 0.02
                assert (
                    dk_qla - dk_ref
                ).abs().max().item() <= dk_ref.abs().max().item() * 0.02
                assert (
                    dv_qla - dv_ref
                ).abs().max().item() <= dv_ref.abs().max().item() * 0.02
                assert (
                    dg_qla - dg_ref
                ).abs().max().item() <= dg_ref.abs().max().item() * 0.02
                assert (
                    db_qla - db_ref
                ).abs().max().item() <= db_ref.abs().max().item() * 0.02
                if dht is not None:
                    assert (
                        dh0_qla - dh0_ref
                    ).abs().max().item() <= dh0_ref.abs().max().item() * 0.02
            except AssertionError as e:
                # Dump all QLA gradient errors so the failing one is visible.
                print("********** ERROR **********")
                print(
                    f"dq_qla: {(dq_qla - dq_ref).abs().max().item():.4f} / {dq_ref.abs().max().item():.4f}"
                )
                print(
                    f"dk_qla: {(dk_qla - dk_ref).abs().max().item():.4f} / {dk_ref.abs().max().item():.4f}"
                )
                print(
                    f"dv_qla: {(dv_qla - dv_ref).abs().max().item():.4f} / {dv_ref.abs().max().item():.4f}"
                )
                if dht is not None:
                    print(
                        f"dh0_qla: {(dh0_qla - dh0_ref).abs().max().item():.4f} / {dh0_ref.abs().max().item():.4f}"
                    )
                print(
                    f"db_qla: {(db_qla - db_ref).abs().max().item():.4f} / {db_ref.abs().max().item():.4f}"
                )
                print(
                    f"dg_qla: {(dg_qla - dg_ref).abs().max().item():.4f} / {dg_ref.abs().max().item():.4f}"
                )
                print("********** ERROR **********")
                raise e

    if show_speedup:
        prof_fla = profile(
            chunk_gated_delta_rule_bwd_fla,
            [q, k, v, g_fla, beta, A_fla, scale, h0, do, dht, cu_seqlens],
        )
        prof_qla = profile(
            chunk_gated_delta_rule_bwd_qla,
            [q, k, v, g_qla, beta, A_qla, do, dht, scale, h0, cu_seqlens],
        )
        # FLA's "recom" row is the sum of two kernels' entries.
        result_fla = {
            "[bwd] csum": prof_fla["chunk_local_cumsum_scalar_kernel"],
            "[bwd] recom": prof_fla["recompute_w_u_fwd_kernel"]
            + prof_fla["chunk_gated_delta_rule_fwd_kernel_h_blockdim64"],
            "[bwd] dv": prof_fla["chunk_bwd_kernel_dv_local"],
            "[bwd] gdr": prof_fla["chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64"],
            "[bwd] dqkwg": prof_fla["kernel_kernel"],
            "[bwd] wy": prof_fla["prepare_wy_repr_bwd_kernel"],
        }
        result_qla = {
            "[bwd] csum": prof_qla["tilelang_chunk_local_cumsum_kernel_kernel"],
            "[bwd] recom": prof_qla["tilelang_prepare_h_kernel_kernel"],
            "[bwd] gdr": prof_qla["tilelang_fused_chunk_gdr_bwd_kernel_kernel"],
        }
        # A head-reduction row is reported only when there are fewer k heads
        # than v heads; the QLA entry is doubled.
        # NOTE(review): presumably the kernel runs twice per backward; confirm.
        if num_k_heads < num_v_heads:
            result_fla["[bwd] reduc"] = prof_fla["compress_heads_kernel"]
            result_qla["[bwd] reduc"] = (
                prof_qla["tilelang_group_reduce_vector_kernel_kernel"] * 2
            )
        result_fla["total"] = prof_fla["total"]
        result_qla["total"] = prof_qla["total"]
        results = {
            "fla": result_fla,
            "flash_qla": result_qla,
        }
        df = pd.DataFrame(results)
        print(df.round(3))
        speedup = results["fla"]["total"] / results["flash_qla"]["total"]
        print(f"Speed up: {speedup:2.2f}x")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line driver: parse the flags, then run test_gated_delta_rule once
    # per row of the selected CSV preset.
    # NOTE(review): indentation was reconstructed; the closing separator is
    # assumed to be printed once, after the loop.
    import os

    cli = argparse.ArgumentParser(description="Test Gated Delta Rule")
    cli.add_argument(
        "--set",
        default="develop",
        type=str,
        help="Preset name (loads from settings/{set}.csv)",
    )
    cli.add_argument(
        "--seqlen",
        "--num-tokens",
        default=16384,
        type=int,
        help="Sequence Length",
    )
    cli.add_argument(
        "--nkh",
        "--num-k-heads",
        default=0,
        type=int,
        help="Number of K heads (num_k_heads)",
    )
    cli.add_argument(
        "--nvh",
        "--num-heads",
        "--num-v-heads",
        default=64,
        type=int,
        help="Number of V heads (num_v_heads)",
    )
    cli.add_argument(
        "--no-h0",
        action="store_true",
        help="Disable initial state and gradient of final state",
    )
    cli.add_argument("--skip-bwd", action="store_true", help="Test forward only")
    cli.add_argument(
        "--no-cp",
        "--disable-auto-cp",
        action="store_true",
        help="Disable auto intra-card CP",
    )
    cli.add_argument(
        "--swa-ratio",
        default=0.75,
        type=float,
        help="Ratio of sliding-window heads",
    )
    cli.add_argument(
        "--data-dtype",
        default="bfloat16",
        type=str,
        help="Data type for input and output",
    )
    cli.add_argument(
        "--ref-dtype",
        default="float64",
        type=str,
        help="Data type for reference",
    )
    cli.add_argument("--hide-acc", action="store_true", help="Do not print accuracy")
    cli.add_argument("--hide-lat", action="store_true", help="Do not print latency")
    cli.add_argument(
        "--seed",
        "--random-seed",
        default=42,
        type=int,
        help="Random seed",
    )
    opts = cli.parse_args()

    # A non-positive --nkh means "use as many K heads as V heads".
    k_heads = opts.nkh if opts.nkh > 0 else opts.nvh

    # Keyword arguments shared by every preset row; each row may override them.
    run_kwargs = dict(
        head_dim_k=128,  # MUST BE 128
        head_dim_v=128,  # MUST BE 128
        chunk_size=64,  # MUST BE 64
        num_tokens=opts.seqlen,
        num_k_heads=k_heads,
        num_v_heads=opts.nvh,
        use_h0=not opts.no_h0,
        data_dtype=opts.data_dtype,
        ref_dtype=opts.ref_dtype,
        check_accuracy=not opts.hide_acc,
        show_speedup=not opts.hide_lat,
        skip_bwd=opts.skip_bwd,
        auto_cp=not opts.no_cp,
        swa_ratio=opts.swa_ratio,
        random_seed=opts.seed,
        device="cuda",
    )

    # Presets live in a "settings" directory next to this script.
    here = os.path.dirname(os.path.abspath(__file__))
    preset_rows = pd.read_csv(os.path.join(here, "settings", f"{opts.set}.csv"))
    separator = "-" * 64
    for _, row in preset_rows.iterrows():
        print(separator)
        torch.cuda.empty_cache()
        overrides = row.to_dict()
        if "cu_seqlens" in overrides:
            # The CSV stores boundaries as a dash-separated string, e.g. "0-5-9".
            overrides["cu_seqlens"] = [
                int(bound) for bound in overrides["cu_seqlens"].split("-")
            ]
        run_kwargs.update(overrides)
        test_gated_delta_rule(**run_kwargs)
    print(separator)
|
||||
57
tests/test_legacy_sm_gdn.py
Normal file
57
tests/test_legacy_sm_gdn.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from flash_qla.ops.gated_delta_rule.legacy import chunk_gated_delta_rule_fwd_legacy
|
||||
|
||||
|
||||
def _reference(q, k, v, g, beta, scale=None, initial_state=None):
    """Token-by-token reference implementation of the gated delta rule.

    For every batch entry and value head, the state ``S`` (a ``dim_k x dim_v``
    matrix) is updated one token at a time::

        gate  = exp(g_t)
        delta = (v_t - gate * S^T k_t) * beta_t
        S     = gate * S + outer(k_t, delta)
        o_t   = scale * S^T q_t

    Args:
        q: Queries, unpacked as ``(batch, tokens, q_heads, dim_k)``.
        k: Keys, indexed with the same head index as ``q``.
        v: Values, ``(batch, tokens, v_heads, dim_v)``. ``v_heads`` must be a
            multiple of ``q_heads``; consecutive value heads share one q/k head.
        g: Log-space gate, indexed as ``g[b, t, hv]``.
        beta: Update strength, indexed as ``beta[b, t, hv]``.
        scale: Output scale; defaults to ``dim_k ** -0.5``.
        initial_state: Optional starting state, indexed as
            ``initial_state[b, hv]``. It is cloned, never modified in place.

    Returns:
        ``(output, state)``: ``output`` has the shape and dtype of ``v``;
        ``state`` is the state after the last token.

    Raises:
        ValueError: If ``v_heads`` is not a positive multiple of ``q_heads``.
    """
    batch, tokens, q_heads, dim = q.shape
    v_heads, v_dim = v.shape[2], v.shape[3]
    if q_heads <= 0 or v_heads % q_heads != 0:
        raise ValueError(
            f"v_heads ({v_heads}) must be a positive multiple of q_heads ({q_heads})"
        )
    scale = scale if scale is not None else dim**-0.5
    if initial_state is not None:
        state = initial_state.clone()
    else:
        # The state maps key space to value space, so it is (dim_k, dim_v);
        # this equals the old (dim, dim) shape whenever both dims agree.
        state = torch.zeros(
            batch, v_heads, dim, v_dim, device=q.device, dtype=q.dtype
        )
    output = torch.empty_like(v)
    heads_per_group = v_heads // q_heads
    for b in range(batch):
        for hv in range(v_heads):
            # Grouped heads: several value heads read the same q/k head.
            hq = hv // heads_per_group
            for t in range(tokens):
                gate = torch.exp(g[b, t, hv])
                k_t = k[b, t, hq]
                delta = (
                    v[b, t, hv] - gate * (state[b, hv].transpose(0, 1) @ k_t)
                ) * beta[b, t, hv]
                state[b, hv] = gate * state[b, hv] + torch.outer(k_t, delta)
                output[b, t, hv] = scale * (
                    state[b, hv].transpose(0, 1) @ q[b, t, hq]
                )
    return output, state
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("dim", [16, 32, 64, 128])
def test_legacy_sm_gdn_matches_reference(dim):
    """Check the legacy forward kernel against `_reference` for one head dim.

    Inputs are small float32 CUDA tensors: 1 batch entry, 5 tokens, 2 q/k
    heads and 4 value heads, with an initial state. Output and final state
    must match the reference within the tolerances below.
    """
    torch.manual_seed(1000 + dim)
    on_gpu = {"device": "cuda", "dtype": torch.float32}
    n_batch, n_tokens, n_qk_heads, n_v_heads = 1, 5, 2, 4

    # The draw order below is part of the test: it fixes the seeded values.
    q = torch.randn(n_batch, n_tokens, n_qk_heads, dim, **on_gpu) * 0.05
    k = torch.randn_like(q) * 0.05
    v = torch.randn(n_batch, n_tokens, n_v_heads, dim, **on_gpu) * 0.1
    g = torch.randn(n_batch, n_tokens, n_v_heads, **on_gpu) * 0.02 - 0.04
    beta = torch.rand(n_batch, n_tokens, n_v_heads, **on_gpu)
    h0 = torch.randn(n_batch, n_v_heads, dim, dim, **on_gpu) * 0.01
    scale = 1.0 / math.sqrt(dim)

    expected_out, expected_state = _reference(q, k, v, g, beta, scale, h0)
    actual_out, actual_state = chunk_gated_delta_rule_fwd_legacy(
        q, k, v, g, beta, scale, h0
    )
    torch.cuda.synchronize()

    # The state is compared with a looser tolerance than the output.
    comparisons = (
        (actual_out, expected_out, 2e-4),
        (actual_state, expected_state, 1e-3),
    )
    for actual, expected, tol in comparisons:
        torch.testing.assert_close(actual, expected, atol=tol, rtol=tol)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
def test_legacy_sm_gdn_rejects_unsupported_dtype():
    """float16 q/k/v must be rejected with a ValueError that mentions float32."""
    half_qkv = torch.randn(1, 1, 1, 16, device="cuda", dtype=torch.float16)
    gate = torch.randn(1, 1, 1, device="cuda")
    strength = torch.randn(1, 1, 1, device="cuda")
    with pytest.raises(ValueError, match="float32"):
        chunk_gated_delta_rule_fwd_legacy(
            half_qkv, half_qkv, half_qkv, gate, strength
        )
|
||||
Reference in New Issue
Block a user