58 lines
2.5 KiB
Python
58 lines
2.5 KiB
Python
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
import math
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from flash_qla.ops.gated_delta_rule.legacy import chunk_gated_delta_rule_fwd_legacy
|
|
|
|
|
|
def _reference(q, k, v, g, beta, scale=None, initial_state=None):
|
|
batch, tokens, q_heads, dim = q.shape
|
|
v_heads = v.shape[2]
|
|
scale = scale if scale is not None else dim**-0.5
|
|
state = (
|
|
initial_state.clone()
|
|
if initial_state is not None
|
|
else torch.zeros(batch, v_heads, dim, dim, device=q.device, dtype=q.dtype)
|
|
)
|
|
output = torch.empty_like(v)
|
|
for b in range(batch):
|
|
for hv in range(v_heads):
|
|
hq = hv // (v_heads // q_heads)
|
|
for t in range(tokens):
|
|
gate = torch.exp(g[b, t, hv])
|
|
delta = (v[b, t, hv] - gate * (state[b, hv].transpose(0, 1) @ k[b, t, hq])) * beta[b, t, hv]
|
|
state[b, hv] = gate * state[b, hv] + torch.outer(k[b, t, hq], delta)
|
|
output[b, t, hv] = scale * (state[b, hv].transpose(0, 1) @ q[b, t, hq])
|
|
return output, state
|
|
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("dim", [16, 32, 64, 128])
def test_legacy_sm_gdn_matches_reference(dim):
    """Check the legacy kernel against the naive ``_reference`` recurrence.

    The problem is one sequence of five tokens with two query/key heads and
    four value heads, all float32 on CUDA. It is seeded per ``dim`` so each
    parametrization is reproducible. Both the per-token output and the final
    recurrent state are compared.
    """
    torch.manual_seed(1000 + dim)
    cuda_f32 = {"device": "cuda", "dtype": torch.float32}
    n_tokens, qk_heads, v_heads = 5, 2, 4

    # Inputs are drawn in a fixed order. Reordering them would change the
    # values produced under the seed above.
    q = torch.randn(1, n_tokens, qk_heads, dim, **cuda_f32) * 0.05
    k = torch.randn_like(q) * 0.05
    v = torch.randn(1, n_tokens, v_heads, dim, **cuda_f32) * 0.1
    # Log-gates centred at -0.04, so exp(g) is typically slightly below 1.
    g = torch.randn(1, n_tokens, v_heads, **cuda_f32) * 0.02 - 0.04
    beta = torch.rand(1, n_tokens, v_heads, **cuda_f32)
    h0 = torch.randn(1, v_heads, dim, dim, **cuda_f32) * 0.01
    scale = 1.0 / math.sqrt(dim)

    expected_out, expected_state = _reference(q, k, v, g, beta, scale=scale, initial_state=h0)
    out, state = chunk_gated_delta_rule_fwd_legacy(q, k, v, g, beta, scale, h0)
    torch.cuda.synchronize()

    torch.testing.assert_close(out, expected_out, atol=2e-4, rtol=2e-4)
    torch.testing.assert_close(state, expected_state, atol=1e-3, rtol=1e-3)
|
|
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
def test_legacy_sm_gdn_rejects_unsupported_dtype():
    """Require a ValueError matching "float32" for float16 q/k/v inputs.

    The same float16 tensor is passed as q, k and v. The gate and beta
    tensors use the default float32 dtype.
    """
    half_input = torch.randn(1, 1, 1, 16, device="cuda", dtype=torch.float16)
    gate = torch.randn(1, 1, 1, device="cuda")
    beta = torch.randn(1, 1, 1, device="cuda")

    with pytest.raises(ValueError, match="float32"):
        chunk_gated_delta_rule_fwd_legacy(half_input, half_input, half_input, gate, beta)
|