Files
FlashQLA-SM70-SM75/flash_qla/utils/math.py
2026-06-14 23:49:03 +08:00

22 lines
618 B
Python

# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
@torch.compile
def l2norm_compiled(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """Scale ``x`` by the reciprocal square root of its sum of squares along ``dim``.

    Computes ``x * rsqrt(sum(x * x, dim) + eps)``. Note that ``eps`` is added
    to the *squared* norm (inside the square root), not to the norm itself.

    Args:
        x: Input tensor.
        dim: Dimension to reduce over.
        eps: Constant added to the sum of squares before ``rsqrt`` so that an
            all-zero slice yields zeros instead of NaN/inf.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    # Half-precision squares overflow quickly (fp16 saturates for |x| > ~255),
    # which would turn the sum into inf and the whole output into zeros.
    # Accumulate in float32 for those dtypes; other dtypes are left untouched.
    if x.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = x.dtype
    x_acc = x.to(compute_dtype)
    inv_norm = torch.rsqrt((x_acc * x_acc).sum(dim=dim, keepdim=True) + eps)
    # Cast back so callers always receive the dtype they passed in.
    return (x_acc * inv_norm).to(x.dtype)
def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    """L2-normalize ``x`` along its last dimension via the compiled kernel.

    The input is flattened to 2D ``(rows, x.shape[-1])``, passed to
    ``l2norm_compiled``, and the result is reshaped back to the input shape.

    Args:
        x: Input tensor whose last dimension has stride 1.
        dim: Must be ``-1``; only last-dimension normalization is supported.
        eps: Forwarded to ``l2norm_compiled``.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        AssertionError: If ``dim != -1`` or the last dimension of ``x`` is not
            unit-stride.
    """
    assert dim == -1, "l2norm only supports dim=-1"
    assert x.stride(-1) == 1, "l2norm requires a unit-stride last dimension"
    raw_shape = x.shape
    # reshape() returns a view whenever view() would have succeeded, and only
    # copies for layouts that cannot be flattened in place (e.g. a slice of
    # the last dimension), which previously raised.
    flat = x.reshape(-1, raw_shape[-1])
    # Mark the row count as dynamic so differing batch sizes can share one
    # compiled graph. Sizes 0 and 1 are skipped: Dynamo specializes on them by
    # default, and marking such a dim dynamic may raise a constraint violation.
    if flat.shape[0] > 1:
        torch._dynamo.mark_dynamic(flat, 0)
    out = l2norm_compiled(flat, dim, eps)
    return out.reshape(raw_shape)