Files
FlashQLA-SM70-SM75/tests/test_legacy_sm_gdn.py
2026-06-14 23:49:03 +08:00

58 lines
2.5 KiB
Python

# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import math
import pytest
import torch
from flash_qla.ops.gated_delta_rule.legacy import chunk_gated_delta_rule_fwd_legacy
def _reference(q, k, v, g, beta, scale=None, initial_state=None):
    """Token-by-token reference implementation of the gated delta rule.

    For each value head the recurrent state ``S`` is advanced one token at a
    time::

        gate  = exp(g)
        delta = beta * (v - gate * S^T k)
        S     = gate * S + outer(k, delta)
        out   = scale * S^T q

    Several value heads may share one query/key head: value head ``hv`` reads
    query/key head ``hv // (v_heads // q_heads)``.

    Args:
        q: Queries, indexed as ``[batch, token, q_head, key_dim]``.
        k: Keys, indexed the same way as ``q``.
        v: Values, indexed as ``[batch, token, v_head, value_dim]``.
        g: Per-token gate in log space, indexed as ``[batch, token, v_head]``.
        beta: Per-token update strength, indexed like ``g``.
        scale: Output scale; defaults to ``key_dim ** -0.5``.
        initial_state: Optional starting state indexed as
            ``[batch, v_head, key_dim, value_dim]``. It is cloned, so the
            caller's tensor is never modified.

    Returns:
        ``(output, state)`` where ``output`` has the shape of ``v`` and
        ``state`` is the recurrent state after the last token.

    Raises:
        ValueError: If the number of value heads is not a positive multiple
            of the number of query/key heads.
    """
    batch, tokens, q_heads, key_dim = q.shape
    v_heads, value_dim = v.shape[2], v.shape[-1]

    # Fail early with a readable message instead of a ZeroDivisionError or
    # IndexError from inside the loops below.
    if q_heads < 1 or v_heads % q_heads != 0:
        raise ValueError(
            f"number of value heads ({v_heads}) must be a positive multiple "
            f"of the number of query/key heads ({q_heads})"
        )
    heads_per_group = v_heads // q_heads

    if scale is None:
        scale = key_dim**-0.5

    if initial_state is not None:
        state = initial_state.clone()
    else:
        # The state maps keys to values, so it is key_dim x value_dim; the two
        # sizes are allowed to differ.
        state = torch.zeros(
            batch, v_heads, key_dim, value_dim, device=q.device, dtype=q.dtype
        )

    output = torch.empty_like(v)
    for b in range(batch):
        for hv in range(v_heads):
            hq = hv // heads_per_group
            for t in range(tokens):
                gate = torch.exp(g[b, t, hv])
                key = k[b, t, hq]
                # Error between the target value and what the decayed state
                # currently predicts for this key, scaled by beta.
                predicted = state[b, hv].transpose(0, 1) @ key
                delta = (v[b, t, hv] - gate * predicted) * beta[b, t, hv]
                state[b, hv] = gate * state[b, hv] + torch.outer(key, delta)
                output[b, t, hv] = scale * (state[b, hv].transpose(0, 1) @ q[b, t, hq])
    return output, state
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("dim", [16, 32, 64, 128])
def test_legacy_sm_gdn_matches_reference(dim):
    """Check the legacy forward entry point against the per-token reference.

    Inputs are small random float32 CUDA tensors with two query/key heads
    shared by four value heads, plus a non-zero initial state.
    """
    n_batch, n_tokens, n_qk_heads, n_v_heads = 1, 5, 2, 4
    fp32_cuda = {"device": "cuda", "dtype": torch.float32}

    # Seed per head size; the draw order below must stay fixed so the inputs
    # are reproducible.
    torch.manual_seed(1000 + dim)
    q = torch.randn(n_batch, n_tokens, n_qk_heads, dim, **fp32_cuda).contiguous() * 0.05
    k = torch.randn_like(q).contiguous() * 0.05
    v = torch.randn(n_batch, n_tokens, n_v_heads, dim, **fp32_cuda).contiguous() * 0.1
    # Gates are log-space values centred slightly below zero.
    g = torch.randn(n_batch, n_tokens, n_v_heads, **fp32_cuda).contiguous() * 0.02 - 0.04
    beta = torch.rand(n_batch, n_tokens, n_v_heads, **fp32_cuda).contiguous()
    h0 = torch.randn(n_batch, n_v_heads, dim, dim, **fp32_cuda).contiguous() * 0.01
    scale = 1.0 / math.sqrt(dim)

    out_ref, state_ref = _reference(q, k, v, g, beta, scale, h0)
    out, state = chunk_gated_delta_rule_fwd_legacy(q, k, v, g, beta, scale, h0)
    torch.cuda.synchronize()

    # The final state is compared with a looser tolerance than the output.
    comparisons = (
        (out, out_ref, 2e-4),
        (state, state_ref, 1e-3),
    )
    for actual, expected, tol in comparisons:
        torch.testing.assert_close(actual, expected, atol=tol, rtol=tol)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
def test_legacy_sm_gdn_rejects_unsupported_dtype():
    """float16 q/k/v must be refused with a ValueError that mentions float32."""
    half_input = torch.randn(1, 1, 1, 16, device="cuda", dtype=torch.float16)
    # Gate and beta are created with the default dtype; only q/k/v are float16.
    gate = torch.randn(1, 1, 1, device="cuda")
    beta = torch.randn(1, 1, 1, device="cuda")

    with pytest.raises(ValueError, match="float32"):
        chunk_gated_delta_rule_fwd_legacy(half_input, half_input, half_input, gate, beta)