first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .profiler import profile
from .pack import pad_and_reshape, pack, unpack, fill_last_chunk_of_g
from .math import l2norm
from .index import prepare_chunk_indices, prepare_chunk_offsets, tensor_cache
__all__ = [
"profile",
"pad_and_reshape",
"pack",
"unpack",
"fill_last_chunk_of_g",
"l2norm",
"prepare_chunk_indices",
"prepare_chunk_offsets",
"tensor_cache",
]

138
flash_qla/utils/index.py Normal file
View File

@@ -0,0 +1,138 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import functools
from typing import Any
from collections import OrderedDict
from collections.abc import Callable
import torch
import tilelang
import tilelang.language as T
def tensor_cache(
fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
The cache is limited to a fixed size (default is 256). When the cache is full, the oldest entry will be removed.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
cache: "OrderedDict[tuple[tuple[int, ...], tuple[tuple[str, int], ...]], tuple[tuple[Any, ...], dict[str, Any], Any]]" = OrderedDict()
cache_size = 256
def get_id(x: Any):
if (type(x) is int) or (type(x) is float) or (type(x) is str):
return x
else:
return id(x)
def make_identity_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[tuple[int, ...], tuple[tuple[str, int], ...]]:
args_key = tuple(get_id(a) for a in args)
kwargs_key = tuple(sorted((k, get_id(v)) for k, v in kwargs.items()))
return args_key, kwargs_key
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal cache, cache_size
key = make_identity_key(args, kwargs)
if key in cache:
cache.move_to_end(key, last=True)
_, _, cached_result = cache[key]
return cached_result
result = fn(*args, **kwargs)
cache[key] = (args, kwargs, result)
cache.move_to_end(key, last=True)
if len(cache) > cache_size:
cache.popitem(last=False)
return result
return wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.diff(cu_seqlens)
@tensor_cache
def prepare_chunk_indices(
cu_seqlens: torch.LongTensor,
chunk_size: int,
) -> torch.LongTensor:
# TODO: tilelang kernel
indices = torch.cat(
[
torch.arange(n)
for n in tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
]
)
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
@tilelang.jit()
def tilelang_prepare_chunk_offsets(
chunk_size,
block_size,
dtype,
):
batch_size_plus_1 = T.dynamic("batch_size_plus_1")
num_threads = min(max(block_size, 32), 128)
@T.prim_func
def tilelang_prepare_chunk_offsets_kernel(
cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype),
chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype),
):
with T.Kernel(1, threads=num_threads) as (bb,):
_batch_size = T.alloc_var("int32")
_batch_size = batch_size_plus_1 - 1
seqlen_start_fragment = T.alloc_fragment((block_size), dtype=dtype)
seqlen_end_fragment = T.alloc_fragment((block_size), dtype=dtype)
chunk_offset_fragment = T.alloc_fragment((block_size), dtype=dtype)
T.copy(cu_seqlens[: batch_size_plus_1 - 1], seqlen_start_fragment)
T.copy(cu_seqlens[1:], seqlen_end_fragment)
for i in T.Parallel(block_size):
chunk_offset_fragment[i] = (
seqlen_end_fragment[i] - seqlen_start_fragment[i]
)
chunk_offset_fragment[i] = (
chunk_offset_fragment[i] + chunk_size - 1
) // chunk_size
T.cumsum(src=chunk_offset_fragment, dim=0)
chunk_offsets[0] = 0
T.copy(chunk_offset_fragment, chunk_offsets[1:])
return tilelang_prepare_chunk_offsets_kernel
@tensor_cache
def prepare_chunk_offsets(
cu_seqlens: torch.LongTensor,
chunk_size: int,
) -> torch.LongTensor:
chunk_offsets = torch.empty_like(cu_seqlens)
tilelang_prepare_chunk_offsets_kernel = tilelang_prepare_chunk_offsets(
chunk_size=chunk_size,
block_size=tilelang.next_power_of_2(cu_seqlens.shape[0] - 1),
dtype=cu_seqlens.dtype,
)
tilelang_prepare_chunk_offsets_kernel(cu_seqlens, chunk_offsets)
return chunk_offsets, chunk_offsets[-1].item()

21
flash_qla/utils/math.py Normal file
View File

@@ -0,0 +1,21 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
@torch.compile
def l2norm_compiled(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return (x * inv_norm).to(x.dtype)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
assert dim == -1
assert x.stride(-1) == 1
raw_shape = x.shape
x = x.view((-1, raw_shape[-1]))
torch._dynamo.mark_dynamic(x, 0)
y = l2norm_compiled(x, dim, eps)
y = y.view(raw_shape)
return y

83
flash_qla/utils/pack.py Normal file
View File

@@ -0,0 +1,83 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
def unpack(
x: torch.Tensor, # [B, T, H]
cu_seqlens: torch.Tensor,
):
assert x.shape[0] == 1
assert len(cu_seqlens.shape) == 1
max_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
batch_size = cu_seqlens.shape[0] - 1
y = torch.zeros((batch_size, max_len, *x.shape[2:]), dtype=x.dtype, device=x.device)
for i in range(batch_size):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
y[i, : end - start] = x[0, start:end]
return y
def pack(
x: torch.Tensor, # [B, T, H]
cu_seqlens: torch.Tensor,
):
assert len(cu_seqlens.shape) == 1
sum_len = cu_seqlens[-1].item()
batch_size = cu_seqlens.shape[0] - 1
y = torch.empty((1, sum_len, *x.shape[2:]), dtype=x.dtype, device=x.device)
for i in range(batch_size):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
y[0, start:end] = x[i, : end - start]
return y
def pad_and_reshape(
x: torch.Tensor,
dim: int,
chunk_size: int = 64,
):
sequence_length = x.shape[dim]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
zeros = [
0,
] * (2 * (len(x.shape) - 1 - dim))
padded = torch.nn.functional.pad(x, (*zeros, 0, pad_size))
return padded.reshape((*x.shape[:dim], -1, chunk_size, *x.shape[dim + 1 :]))
def fill_last_chunk_of_g(
g: torch.Tensor,
num_tokens: int,
cu_seqlens: torch.Tensor,
chunk_size: int = 64,
reverse: bool = False,
):
if cu_seqlens is None:
last_chunk_size = num_tokens % chunk_size
if last_chunk_size > 0:
if reverse:
g[:, -1, last_chunk_size - 1] += g[:, -1, -1]
else:
g[:, -1, last_chunk_size:] = g[
:, -1, last_chunk_size - 1 : last_chunk_size
]
else:
for i in range(cu_seqlens.shape[0] - 1):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
last_chunk_idx = (end - start) // chunk_size
last_chunk_size = (end - start) % chunk_size
if last_chunk_size > 0:
if reverse:
g[i, last_chunk_idx, last_chunk_size - 1] += g[
i, last_chunk_idx, -1
]
else:
g[i, last_chunk_idx, last_chunk_size:] = g[
i, last_chunk_idx, last_chunk_size - 1 : last_chunk_size
]
return g

View File

@@ -0,0 +1,25 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
def profile(func, inputs, wait: int = 50, warmup: int = 50, rep: int = 100):
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=rep),
# on_trace_ready=torch.profiler.tensorboard_trace_handler('./tb'),
) as prof:
for idx in range(wait + warmup + rep):
func(*inputs)
prof.step()
# print(prof.key_averages().table(sort_by="cpu_time", row_limit=10))
result = {x.key: x.device_time * 1e-3 for x in prof.key_averages()}
result["total"] = tilelang.profiler.do_bench(
lambda: func(*inputs), warmup=warmup, rep=rep
)
return result