first commit
This commit is contained in:
71
tests/test_function_signature.py
Normal file
71
tests/test_function_signature.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
"""Static checks for autograd ``Function`` signatures.
|
||||
|
||||
PyTorch validates that ``Function.backward`` returns exactly as many gradients
|
||||
as ``Function.forward`` received non-``ctx`` inputs; mismatches raise at
|
||||
``.backward()`` time. ``tests/test_gdr.py`` invokes ``chunk_gated_delta_rule_fwd``
|
||||
and ``chunk_gated_delta_rule_bwd`` directly, bypassing the autograd path, so
|
||||
drift between the forward signature and the backward return tuple goes
|
||||
uncaught by the existing suite.
|
||||
|
||||
These tests parse the source files with ``ast`` instead of importing the
|
||||
modules so they run on CPU-only / non-Hopper machines.
|
||||
"""
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
CHUNK_INIT = "flash_qla/ops/gated_delta_rule/chunk/__init__.py"
|
||||
|
||||
|
||||
def _parse(rel_path: str) -> ast.Module:
|
||||
return ast.parse((REPO_ROOT / rel_path).read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _get_class(module: ast.Module, name: str) -> ast.ClassDef:
|
||||
for node in module.body:
|
||||
if isinstance(node, ast.ClassDef) and node.name == name:
|
||||
return node
|
||||
raise AssertionError(f"class {name!r} not found")
|
||||
|
||||
|
||||
def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef:
|
||||
for node in cls.body:
|
||||
if isinstance(node, ast.FunctionDef) and node.name == name:
|
||||
return node
|
||||
raise AssertionError(f"method {name!r} not found on {cls.name}")
|
||||
|
||||
|
||||
def test_chunk_gated_delta_rule_grad_count_matches_forward_inputs():
|
||||
"""``backward`` must return one gradient per non-``ctx`` input of ``forward``."""
|
||||
module = _parse(CHUNK_INIT)
|
||||
cls = _get_class(module, "ChunkGatedDeltaRuleFunction")
|
||||
|
||||
fwd = _get_method(cls, "forward")
|
||||
fwd_args = fwd.args.args
|
||||
assert fwd_args and fwd_args[0].arg == "ctx", (
|
||||
"forward must take `ctx` as its first argument"
|
||||
)
|
||||
n_inputs = len(fwd_args) - 1 # exclude ctx
|
||||
|
||||
bwd = _get_method(cls, "backward")
|
||||
returns = [n for n in ast.walk(bwd) if isinstance(n, ast.Return)]
|
||||
assert len(returns) == 1, f"expected one Return in backward, got {len(returns)}"
|
||||
assert isinstance(returns[0].value, ast.Tuple), (
|
||||
"backward must return a tuple literal"
|
||||
)
|
||||
n_grads = len(returns[0].value.elts)
|
||||
|
||||
assert n_inputs == n_grads, (
|
||||
f"backward returns {n_grads} gradients but forward takes {n_inputs} non-ctx "
|
||||
f"inputs; PyTorch will raise a count-mismatch error at .backward() time."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chunk_gated_delta_rule_grad_count_matches_forward_inputs()
|
||||
print("OK")
|
||||
Reference in New Issue
Block a user