first commit
This commit is contained in:
20
flash_qla/utils/__init__.py
Normal file
20
flash_qla/utils/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .profiler import profile
|
||||
from .pack import pad_and_reshape, pack, unpack, fill_last_chunk_of_g
|
||||
from .math import l2norm
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets, tensor_cache
|
||||
|
||||
|
||||
# Public API of the utils package, grouped by the submodule each name comes from.
__all__ = [
    # .profiler
    "profile",
    # .pack
    "pad_and_reshape",
    "pack",
    "unpack",
    "fill_last_chunk_of_g",
    # .math
    "l2norm",
    # .index
    "prepare_chunk_indices",
    "prepare_chunk_offsets",
    "tensor_cache",
]
|
||||
138
flash_qla/utils/index.py
Normal file
138
flash_qla/utils/index.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
|
||||
def tensor_cache(
    fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
    """
    A decorator that memoizes a function of tensor inputs in a small LRU cache.

    Arguments that are exactly ``int``, ``float`` or ``str`` are keyed on their
    type and value; every other argument (tensors in particular) is keyed on
    its object identity (``id``). The cache holds at most 256 entries; when it
    overflows, the least recently used entry is evicted.

    Because tensors are keyed on identity rather than content, mutating a
    cached input tensor in place does NOT invalidate its cached result.
    Positional and keyword spellings of the same call are cached separately.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with LRU caching.
    """

    cache: "OrderedDict[tuple[Any, ...], tuple[tuple[Any, ...], dict[str, Any], Any]]" = OrderedDict()
    cache_size = 256

    def get_id(x: Any) -> tuple[Any, Any]:
        # Tag every key component so that (a) equal values of different types
        # (1 == 1.0, and they hash equal) do not share an entry and (b) a plain
        # int argument can never collide with the id() of another object.
        kind = type(x)
        if kind is int or kind is float or kind is str:
            return kind, x
        return None, id(x)

    def make_identity_key(
        args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple[Any, ...], tuple[Any, ...]]:
        args_key = tuple(get_id(a) for a in args)
        # Sort by name so the key does not depend on keyword order.
        kwargs_key = tuple((name, get_id(kwargs[name])) for name in sorted(kwargs))
        return args_key, kwargs_key

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        key = make_identity_key(args, kwargs)
        if key in cache:
            # Refresh recency on a hit.
            cache.move_to_end(key, last=True)
            return cache[key][2]

        result = fn(*args, **kwargs)
        # Keep args/kwargs referenced alongside the result: while an entry is
        # cached its inputs stay alive, so their id() cannot be reused by a
        # different object.
        cache[key] = (args, kwargs, result)
        if len(cache) > cache_size:
            cache.popitem(last=False)
        return result

    return wrapper
|
||||
|
||||
|
||||
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return the differences between consecutive entries of ``cu_seqlens`` (cached)."""
    lens = cu_seqlens.diff()
    return lens
|
||||
|
||||
|
||||
@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor,
    chunk_size: int,
) -> torch.LongTensor:
    """
    Enumerate every chunk of every sequence as a ``(sequence index, chunk index)`` row.

    Args:
        cu_seqlens: cumulative sequence boundaries; consecutive differences are
            taken as the per-sequence lengths.
        chunk_size: number of tokens per chunk.

    Returns:
        A 2-column tensor with the dtype/device of ``cu_seqlens``: column 0 is
        the sequence index, column 1 the chunk index within that sequence.
    """
    # TODO: tilelang kernel
    chunks_per_seq = tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    # 0..n-1 for each sequence, laid end to end.
    chunk_ids = torch.cat([torch.arange(count) for count in chunks_per_seq])
    # Each sequence's run starts at chunk 0, so counting zeros (minus one)
    # yields the owning sequence of every chunk.
    seq_ids = chunk_ids.eq(0).cumsum(0) - 1
    pairs = torch.stack([seq_ids, chunk_ids], 1)
    return pairs.to(cu_seqlens)
|
||||
|
||||
|
||||
@tilelang.jit()
def tilelang_prepare_chunk_offsets(
    chunk_size,
    block_size,
    dtype,
):
    """
    Build the TileLang kernel that converts cumulative sequence lengths into
    cumulative chunk counts.

    Args:
        chunk_size: number of tokens per chunk; each sequence length is
            ceil-divided by it.
        block_size: length of the working fragments. The visible caller passes
            the next power of two of the batch size.
        dtype: element type of both kernel tensors.

    Returns:
        The prim_func ``(cu_seqlens, chunk_offsets)``; it writes 0 to
        ``chunk_offsets[0]`` and the running sum of per-sequence chunk counts
        to ``chunk_offsets[1:]``.
    """
    batch_size_plus_1 = T.dynamic("batch_size_plus_1")
    # Clamp the thread count to [32, 128].
    num_threads = max(32, min(block_size, 128))

    @T.prim_func
    def tilelang_prepare_chunk_offsets_kernel(
        cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype),
        chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype),
    ):
        # A single block handles the whole batch.
        with T.Kernel(1, threads=num_threads) as (bb,):
            # NOTE(review): _batch_size is assigned but never read below.
            _batch_size = T.alloc_var("int32")
            _batch_size = batch_size_plus_1 - 1

            seq_begin = T.alloc_fragment((block_size), dtype=dtype)
            seq_end = T.alloc_fragment((block_size), dtype=dtype)
            num_chunks = T.alloc_fragment((block_size), dtype=dtype)

            # Start/end boundary of every sequence.
            # NOTE(review): the fragments are block_size long while the slices
            # are batch-size long; this relies on T.copy coping with the
            # mismatch -- confirm.
            T.copy(cu_seqlens[: batch_size_plus_1 - 1], seq_begin)
            T.copy(cu_seqlens[1:], seq_end)

            for i in T.Parallel(block_size):
                # Sequence length ...
                num_chunks[i] = (
                    seq_end[i] - seq_begin[i]
                )
                # ... ceil-divided by chunk_size.
                num_chunks[i] = (
                    num_chunks[i] + chunk_size - 1
                ) // chunk_size
            T.cumsum(src=num_chunks, dim=0)

            chunk_offsets[0] = 0
            T.copy(num_chunks, chunk_offsets[1:])

    return tilelang_prepare_chunk_offsets_kernel
|
||||
|
||||
|
||||
@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor,
    chunk_size: int,
) -> tuple[torch.LongTensor, int]:
    """
    Compute cumulative chunk offsets for a batch of variable-length sequences.

    Args:
        cu_seqlens: 1-D cumulative sequence boundaries (``batch + 1`` entries).
        chunk_size: number of tokens per chunk.

    Returns:
        A pair ``(chunk_offsets, total_chunks)``: ``chunk_offsets`` has the
        shape/dtype/device of ``cu_seqlens`` and is filled by the TileLang
        kernel; ``total_chunks`` is its last element as a Python number.
        Reading it with ``.item()`` copies the value to the host. The whole
        pair is cached by ``tensor_cache``.
    """
    chunk_offsets = torch.empty_like(cu_seqlens)
    kernel = tilelang_prepare_chunk_offsets(
        chunk_size=chunk_size,
        # One fragment slot per sequence, rounded up to a power of two.
        block_size=tilelang.next_power_of_2(cu_seqlens.shape[0] - 1),
        dtype=cu_seqlens.dtype,
    )
    kernel(cu_seqlens, chunk_offsets)
    return chunk_offsets, chunk_offsets[-1].item()
|
||||
21
flash_qla/utils/math.py
Normal file
21
flash_qla/utils/math.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@torch.compile
def l2norm_compiled(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """Scale ``x`` by ``rsqrt(sum(x * x) + eps)`` taken along ``dim``; the result keeps ``x.dtype``."""
    squared_sum = torch.sum(x * x, dim=dim, keepdim=True)
    scale = torch.rsqrt(squared_sum + eps)
    normalized = x * scale
    return normalized.to(x.dtype)
|
||||
|
||||
|
||||
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """
    L2-normalize ``x`` along its last dimension using the compiled kernel.

    Only ``dim == -1`` on a tensor whose last stride is 1 is supported (both
    are asserted). The input is viewed as 2-D with all leading dimensions
    flattened, and the result is viewed back to the input shape.
    """
    assert dim == -1
    assert x.stride(-1) == 1
    original_shape = x.shape
    rows = x.view(-1, original_shape[-1])
    # Mark the flattened leading dimension as dynamic for torch.compile.
    torch._dynamo.mark_dynamic(rows, 0)
    return l2norm_compiled(rows, dim, eps).view(original_shape)
|
||||
83
flash_qla/utils/pack.py
Normal file
83
flash_qla/utils/pack.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def unpack(
    x: torch.Tensor,  # [B, T, H]
    cu_seqlens: torch.Tensor,
):
    """
    Split a packed tensor into a zero-padded batch.

    ``x`` must have a leading dimension of 1; ``cu_seqlens`` is a 1-D tensor of
    cumulative boundaries. Sequence ``i`` is ``x[0, cu_seqlens[i]:cu_seqlens[i + 1]]``
    and is copied to row ``i`` of the result, whose second dimension is the
    longest sequence length; unused positions stay zero.
    """
    assert x.shape[0] == 1
    assert len(cu_seqlens.shape) == 1
    longest = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    num_seqs = cu_seqlens.shape[0] - 1
    out = torch.zeros((num_seqs, longest, *x.shape[2:]), dtype=x.dtype, device=x.device)
    bounds = cu_seqlens.tolist()
    for row, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:])):
        out[row, : hi - lo] = x[0, lo:hi]
    return out
|
||||
|
||||
|
||||
def pack(
    x: torch.Tensor,  # [B, T, H]
    cu_seqlens: torch.Tensor,
):
    """
    Concatenate the valid prefix of every batch row into one packed tensor.

    ``cu_seqlens`` is a 1-D tensor of cumulative boundaries. The first
    ``cu_seqlens[i + 1] - cu_seqlens[i]`` positions of ``x[i]`` are written to
    ``y[0, cu_seqlens[i]:cu_seqlens[i + 1]]``; the result has shape
    ``[1, cu_seqlens[-1], ...]`` and is allocated uninitialized.
    """
    assert len(cu_seqlens.shape) == 1
    bounds = cu_seqlens.tolist()
    total = bounds[-1]
    out = torch.empty((1, total, *x.shape[2:]), dtype=x.dtype, device=x.device)
    for row, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:])):
        out[0, lo:hi] = x[row, : hi - lo]
    return out
|
||||
|
||||
|
||||
def pad_and_reshape(
    x: torch.Tensor,
    dim: int,
    chunk_size: int = 64,
):
    """
    Zero-pad ``x`` along ``dim`` to a multiple of ``chunk_size`` and split that
    dimension into ``(num_chunks, chunk_size)``.

    Args:
        x: input tensor.
        dim: dimension to chunk; negative values count from the end.
        chunk_size: length of every chunk.

    Returns:
        A tensor of shape ``(*x.shape[:dim], num_chunks, chunk_size, *x.shape[dim + 1:])``
        where padding is appended at the end of ``dim``.

    Raises:
        IndexError: if ``dim`` is out of range for ``x``.
    """
    sequence_length = x.shape[dim]  # also validates dim
    if dim < 0:
        # Normalize so the slice/pad arithmetic below sees a non-negative axis.
        dim += x.dim()
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    # F.pad reads (left, right) pairs starting from the LAST dimension, so the
    # dimensions after `dim` each need an explicit (0, 0) pair first.
    trailing_dims = x.dim() - 1 - dim
    padded = torch.nn.functional.pad(x, (0, 0) * trailing_dims + (0, pad_size))
    # Spell out the chunk count instead of -1: reshape cannot infer -1 when
    # another dimension is empty.
    num_chunks = (sequence_length + pad_size) // chunk_size
    return padded.reshape(
        (*x.shape[:dim], num_chunks, chunk_size, *x.shape[dim + 1 :])
    )
|
||||
|
||||
|
||||
def fill_last_chunk_of_g(
    g: torch.Tensor,
    num_tokens: int,
    cu_seqlens: torch.Tensor,
    chunk_size: int = 64,
    reverse: bool = False,
):
    """
    Patch, in place, the partially filled last chunk of ``g`` and return ``g``.

    ``g`` is indexed as ``g[batch, chunk, position_in_chunk]``. A chunk is
    patched only when its valid length (``length % chunk_size``) is non-zero:

    * ``reverse=False``: every position past the last valid one is overwritten
      with the value at the last valid position.
    * ``reverse=True``: the value at the chunk's final position is added onto
      the last valid position.

    With ``cu_seqlens=None`` the last chunk (index -1) of all batch rows is
    patched using ``num_tokens``; otherwise sequence ``i`` (length taken from
    consecutive ``cu_seqlens`` entries) patches chunk ``length // chunk_size``
    of row ``i``.
    """
    # Collect (batch selector, chunk index, valid length) for every tail chunk.
    if cu_seqlens is None:
        tails = [(slice(None), -1, num_tokens % chunk_size)]
    else:
        bounds = cu_seqlens.tolist()
        tails = [
            (row, (hi - lo) // chunk_size, (hi - lo) % chunk_size)
            for row, (lo, hi) in enumerate(zip(bounds[:-1], bounds[1:]))
        ]

    for rows, chunk_idx, valid in tails:
        if valid <= 0:
            # The sequence ends exactly on a chunk boundary: nothing to patch.
            continue
        if reverse:
            g[rows, chunk_idx, valid - 1] += g[rows, chunk_idx, -1]
        else:
            # A length-1 slice broadcasts over the remaining positions.
            g[rows, chunk_idx, valid:] = g[rows, chunk_idx, valid - 1 : valid]
    return g
|
||||
25
flash_qla/utils/profiler.py
Normal file
25
flash_qla/utils/profiler.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
|
||||
|
||||
def profile(func, inputs, wait: int = 50, warmup: int = 50, rep: int = 100):
    """
    Profile ``func(*inputs)`` and report per-event device times plus a total.

    ``func`` is called ``wait + warmup + rep`` times under ``torch.profiler``
    (CPU and CUDA activities) with a matching wait/warmup/active schedule, and
    is then benchmarked once more with ``tilelang.profiler.do_bench``.

    Returns:
        A dict mapping each profiler event key to its averaged ``device_time``
        scaled by 1e-3 (presumably microseconds to milliseconds -- confirm),
        plus a ``"total"`` entry holding the ``do_bench`` result.
    """
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=rep)
    with torch.profiler.profile(activities=activities, schedule=schedule) as prof:
        for _ in range(wait + warmup + rep):
            func(*inputs)
            # Advance the profiler schedule by one step per call.
            prof.step()

    result = {}
    for event in prof.key_averages():
        result[event.key] = event.device_time * 1e-3
    result["total"] = tilelang.profiler.do_bench(
        lambda: func(*inputs), warmup=warmup, rep=rep
    )
    return result
|
||||
Reference in New Issue
Block a user