72 lines
2.6 KiB
Python
72 lines
2.6 KiB
Python
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
"""Static checks for autograd ``Function`` signatures.
|
|
|
|
PyTorch validates that ``Function.backward`` returns exactly as many gradients
|
|
as ``Function.forward`` received non-``ctx`` inputs; mismatches raise at
|
|
``.backward()`` time. ``tests/test_gdr.py`` invokes ``chunk_gated_delta_rule_fwd``
|
|
and ``chunk_gated_delta_rule_bwd`` directly, bypassing the autograd path, so
|
|
drift between the forward signature and the backward return tuple goes
|
|
uncaught by the existing suite.
|
|
|
|
These tests parse the source files with ``ast`` instead of importing the
|
|
modules so they run on CPU-only / non-Hopper machines.
|
|
"""
|
|
|
|
import ast
|
|
from pathlib import Path
|
|
|
|
|
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
CHUNK_INIT = "flash_qla/ops/gated_delta_rule/chunk/__init__.py"
|
|
|
|
|
|
def _parse(rel_path: str) -> ast.Module:
|
|
return ast.parse((REPO_ROOT / rel_path).read_text(encoding="utf-8"))
|
|
|
|
|
|
def _get_class(module: ast.Module, name: str) -> ast.ClassDef:
|
|
for node in module.body:
|
|
if isinstance(node, ast.ClassDef) and node.name == name:
|
|
return node
|
|
raise AssertionError(f"class {name!r} not found")
|
|
|
|
|
|
def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef:
|
|
for node in cls.body:
|
|
if isinstance(node, ast.FunctionDef) and node.name == name:
|
|
return node
|
|
raise AssertionError(f"method {name!r} not found on {cls.name}")
|
|
|
|
|
|
def test_chunk_gated_delta_rule_grad_count_matches_forward_inputs():
|
|
"""``backward`` must return one gradient per non-``ctx`` input of ``forward``."""
|
|
module = _parse(CHUNK_INIT)
|
|
cls = _get_class(module, "ChunkGatedDeltaRuleFunction")
|
|
|
|
fwd = _get_method(cls, "forward")
|
|
fwd_args = fwd.args.args
|
|
assert fwd_args and fwd_args[0].arg == "ctx", (
|
|
"forward must take `ctx` as its first argument"
|
|
)
|
|
n_inputs = len(fwd_args) - 1 # exclude ctx
|
|
|
|
bwd = _get_method(cls, "backward")
|
|
returns = [n for n in ast.walk(bwd) if isinstance(n, ast.Return)]
|
|
assert len(returns) == 1, f"expected one Return in backward, got {len(returns)}"
|
|
assert isinstance(returns[0].value, ast.Tuple), (
|
|
"backward must return a tuple literal"
|
|
)
|
|
n_grads = len(returns[0].value.elts)
|
|
|
|
assert n_inputs == n_grads, (
|
|
f"backward returns {n_grads} gradients but forward takes {n_inputs} non-ctx "
|
|
f"inputs; PyTorch will raise a count-mismatch error at .backward() time."
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_chunk_gated_delta_rule_grad_count_matches_forward_inputs()
|
|
print("OK")
|