first commit
This commit is contained in:
138
flash_qla/utils/index.py
Normal file
138
flash_qla/utils/index.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
|
||||
def tensor_cache(
|
||||
fn: Callable[..., torch.Tensor],
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator that caches the most recent results of a function with tensor inputs.
|
||||
|
||||
This decorator will store the output of the decorated function for the most recent set of input tensors.
|
||||
The cache is limited to a fixed size (default is 256). When the cache is full, the oldest entry will be removed.
|
||||
|
||||
Args:
|
||||
fn (Callable[..., torch.Tensor]):
|
||||
The function to be decorated. It should take tensor inputs and return tensor outputs.
|
||||
|
||||
Returns:
|
||||
Callable[..., torch.Tensor]:
|
||||
A wrapped version of the input function with single-entry caching.
|
||||
"""
|
||||
|
||||
cache: "OrderedDict[tuple[tuple[int, ...], tuple[tuple[str, int], ...]], tuple[tuple[Any, ...], dict[str, Any], Any]]" = OrderedDict()
|
||||
cache_size = 256
|
||||
|
||||
def get_id(x: Any):
|
||||
if (type(x) is int) or (type(x) is float) or (type(x) is str):
|
||||
return x
|
||||
else:
|
||||
return id(x)
|
||||
|
||||
def make_identity_key(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> tuple[tuple[int, ...], tuple[tuple[str, int], ...]]:
|
||||
args_key = tuple(get_id(a) for a in args)
|
||||
kwargs_key = tuple(sorted((k, get_id(v)) for k, v in kwargs.items()))
|
||||
return args_key, kwargs_key
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal cache, cache_size
|
||||
key = make_identity_key(args, kwargs)
|
||||
if key in cache:
|
||||
cache.move_to_end(key, last=True)
|
||||
_, _, cached_result = cache[key]
|
||||
return cached_result
|
||||
|
||||
result = fn(*args, **kwargs)
|
||||
cache[key] = (args, kwargs, result)
|
||||
cache.move_to_end(key, last=True)
|
||||
if len(cache) > cache_size:
|
||||
cache.popitem(last=False)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return torch.diff(cu_seqlens)
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int,
|
||||
) -> torch.LongTensor:
|
||||
# TODO: tilelang kernel
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
for n in tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
]
|
||||
)
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
|
||||
|
||||
|
||||
@tilelang.jit()
|
||||
def tilelang_prepare_chunk_offsets(
|
||||
chunk_size,
|
||||
block_size,
|
||||
dtype,
|
||||
):
|
||||
batch_size_plus_1 = T.dynamic("batch_size_plus_1")
|
||||
num_threads = min(max(block_size, 32), 128)
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_prepare_chunk_offsets_kernel(
|
||||
cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype),
|
||||
chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype),
|
||||
):
|
||||
with T.Kernel(1, threads=num_threads) as (bb,):
|
||||
_batch_size = T.alloc_var("int32")
|
||||
_batch_size = batch_size_plus_1 - 1
|
||||
|
||||
seqlen_start_fragment = T.alloc_fragment((block_size), dtype=dtype)
|
||||
seqlen_end_fragment = T.alloc_fragment((block_size), dtype=dtype)
|
||||
chunk_offset_fragment = T.alloc_fragment((block_size), dtype=dtype)
|
||||
|
||||
T.copy(cu_seqlens[: batch_size_plus_1 - 1], seqlen_start_fragment)
|
||||
T.copy(cu_seqlens[1:], seqlen_end_fragment)
|
||||
|
||||
for i in T.Parallel(block_size):
|
||||
chunk_offset_fragment[i] = (
|
||||
seqlen_end_fragment[i] - seqlen_start_fragment[i]
|
||||
)
|
||||
chunk_offset_fragment[i] = (
|
||||
chunk_offset_fragment[i] + chunk_size - 1
|
||||
) // chunk_size
|
||||
T.cumsum(src=chunk_offset_fragment, dim=0)
|
||||
|
||||
chunk_offsets[0] = 0
|
||||
T.copy(chunk_offset_fragment, chunk_offsets[1:])
|
||||
|
||||
return tilelang_prepare_chunk_offsets_kernel
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(
|
||||
cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int,
|
||||
) -> torch.LongTensor:
|
||||
chunk_offsets = torch.empty_like(cu_seqlens)
|
||||
tilelang_prepare_chunk_offsets_kernel = tilelang_prepare_chunk_offsets(
|
||||
chunk_size=chunk_size,
|
||||
block_size=tilelang.next_power_of_2(cu_seqlens.shape[0] - 1),
|
||||
dtype=cu_seqlens.dtype,
|
||||
)
|
||||
tilelang_prepare_chunk_offsets_kernel(cu_seqlens, chunk_offsets)
|
||||
return chunk_offsets, chunk_offsets[-1].item()
|
||||
Reference in New Issue
Block a user