Files
2026-06-14 23:49:03 +08:00

139 lines
4.5 KiB
Python

# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import functools
from typing import Any
from collections import OrderedDict
from collections.abc import Callable
import torch
import tilelang
import tilelang.language as T
def tensor_cache(
    fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches recent results of a function with tensor inputs.

    Calls are keyed on the *identity* (``id``) of each argument, except for plain
    ``int``/``float``/``str`` arguments, which are keyed on their value. Every key
    component is tagged with its kind, so ``1`` and ``1.0``, or an ``int`` that
    happens to equal some object's ``id``, never share a cache entry. Positional
    and keyword spellings of the same call are cached separately.

    The cache holds up to 256 entries in least-recently-used order: a hit moves
    the entry to the most-recent end, and when the cache overflows the least
    recently used entry is evicted. Each entry keeps a reference to its
    arguments, so an ``id`` used in a live key cannot be recycled by a new object.

    Because tensors are keyed by identity, mutating a tensor in place after a
    call does NOT invalidate the cached result for that tensor.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with LRU caching.
    """
    cache: "OrderedDict[tuple[tuple[tuple[str, Any], ...], tuple[tuple[str, tuple[str, Any]], ...]], tuple[tuple[Any, ...], dict[str, Any], Any]]" = OrderedDict()
    cache_size = 256
    # Exact-type membership (not isinstance): bools, tensors and other
    # subclasses are keyed by identity, exactly as before.
    by_value_types = (int, float, str)

    def get_id(x: Any) -> tuple[str, Any]:
        # Tag with the kind of key so value keys of different types, and value
        # keys vs. identity keys, can never compare equal.
        if type(x) in by_value_types:
            return (type(x).__name__, x)
        return ("id", id(x))

    def make_identity_key(
        args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple[tuple[str, Any], ...], tuple[tuple[str, tuple[str, Any]], ...]]:
        args_key = tuple(get_id(a) for a in args)
        # kwargs names are unique, so sorting only ever compares the names.
        kwargs_key = tuple(sorted((k, get_id(v)) for k, v in kwargs.items()))
        return args_key, kwargs_key

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        key = make_identity_key(args, kwargs)
        entry = cache.get(key)
        if entry is not None:
            cache.move_to_end(key, last=True)
            return entry[2]
        result = fn(*args, **kwargs)
        # Store args/kwargs too: holding them keeps their ids from being reused.
        cache[key] = (args, kwargs, result)
        cache.move_to_end(key, last=True)
        if len(cache) > cache_size:
            cache.popitem(last=False)
        return result

    return wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return consecutive differences of ``cu_seqlens`` (the per-sequence lengths).

    Differences are taken along the last dimension, so for 1-D input element
    ``i`` of the result is ``cu_seqlens[i + 1] - cu_seqlens[i]``.
    """
    lengths = cu_seqlens.diff()
    return lengths
@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor,
    chunk_size: int,
) -> torch.LongTensor:
    """Enumerate every ``chunk_size``-token chunk of every packed sequence.

    Args:
        cu_seqlens: cumulative sequence offsets; ``prepare_lens`` turns them
            into per-sequence lengths.
        chunk_size: number of tokens per chunk.

    Returns:
        A ``(num_chunks, 2)`` tensor converted to the dtype/device of
        ``cu_seqlens``. Row ``r`` is ``(sequence index, chunk index within that
        sequence)``, ordered by sequence then chunk. A sequence of length ``L``
        contributes ``ceil(L / chunk_size)`` rows, so zero-length sequences
        contribute none; with no sequences the result has shape ``(0, 2)``.
    """
    # TODO: tilelang kernel
    # Chunks per sequence (ceil division of each length by chunk_size).
    counts = tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).long()
    device = counts.device
    # Each chunk gets its sequence's real position. Repeating arange(batch)
    # (instead of counting where the in-sequence index restarts at 0) stays
    # correct when some sequence has zero chunks.
    seq_ids = torch.repeat_interleave(
        torch.arange(counts.numel(), device=device), counts
    )
    # Global position of each sequence's first chunk (exclusive prefix sum).
    first_chunk = counts.cumsum(0) - counts
    chunk_ids = torch.arange(seq_ids.numel(), device=device) - torch.repeat_interleave(
        first_chunk, counts
    )
    return torch.stack([seq_ids, chunk_ids], 1).to(cu_seqlens)
@tilelang.jit()
def tilelang_prepare_chunk_offsets(
    chunk_size,
    block_size,
    dtype,
):
    """Build a TileLang kernel that turns ``cu_seqlens`` into chunk offsets.

    The kernel reads ``cu_seqlens`` (length ``batch_size + 1``) and writes
    ``chunk_offsets`` (same length). It sets ``chunk_offsets[0] = 0`` and fills
    ``chunk_offsets[1:]`` with a running sum of per-sequence chunk counts
    ``ceil((cu_seqlens[j + 1] - cu_seqlens[j]) / chunk_size)``. This assumes
    ``T.cumsum`` is an inclusive scan; TODO confirm.

    Args:
        chunk_size: tokens per chunk; captured by the kernel as a compile-time constant.
        block_size: length of the per-block fragments. The whole batch is handled
            by a single block, so this presumably must be >= batch_size. The caller
            passes the next power of two of batch_size; NOTE(review) confirm.
        dtype: element dtype used for both tensors and the fragments.

    Returns:
        The ``prim_func``. ``prepare_chunk_offsets`` calls the result of this
        jit-decorated builder as ``kernel(cu_seqlens, chunk_offsets)``.
    """
    # Symbolic length of cu_seqlens, so one compiled kernel serves any batch size.
    batch_size_plus_1 = T.dynamic("batch_size_plus_1")
    # Thread count clamped to the range [32, 128].
    num_threads = min(max(block_size, 32), 128)

    @T.prim_func
    def tilelang_prepare_chunk_offsets_kernel(
        cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype),
        chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype),
    ):
        # Grid of a single block: the entire batch is processed together.
        with T.Kernel(1, threads=num_threads) as (bb,):
            # NOTE(review): `_batch_size` is never read below; looks removable — confirm.
            _batch_size = T.alloc_var("int32")
            _batch_size = batch_size_plus_1 - 1
            seqlen_start_fragment = T.alloc_fragment((block_size), dtype=dtype)
            seqlen_end_fragment = T.alloc_fragment((block_size), dtype=dtype)
            chunk_offset_fragment = T.alloc_fragment((block_size), dtype=dtype)
            # Sequence starts are cu_seqlens[:-1]; ends are cu_seqlens[1:].
            # NOTE(review): the fragments hold block_size elements while these slices
            # hold batch_size; presumably TileLang guards out-of-range global
            # accesses and the tail lanes are ignored. Confirm that no out-of-bounds
            # read or write occurs when block_size > batch_size.
            T.copy(cu_seqlens[: batch_size_plus_1 - 1], seqlen_start_fragment)
            T.copy(cu_seqlens[1:], seqlen_end_fragment)
            for i in T.Parallel(block_size):
                # Length of sequence i ...
                chunk_offset_fragment[i] = (
                    seqlen_end_fragment[i] - seqlen_start_fragment[i]
                )
                # ... rounded up to a whole number of chunks (ceil division).
                chunk_offset_fragment[i] = (
                    chunk_offset_fragment[i] + chunk_size - 1
                ) // chunk_size
            # Prefix sum of chunk counts. No dst is given, so this presumably
            # runs in place on the fragment.
            T.cumsum(src=chunk_offset_fragment, dim=0)
            # The first sequence starts at chunk 0; the running sums fill the rest.
            chunk_offsets[0] = 0
            T.copy(chunk_offset_fragment, chunk_offsets[1:])

    return tilelang_prepare_chunk_offsets_kernel
@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor,
    chunk_size: int,
) -> tuple[torch.LongTensor, int]:
    """Compute each packed sequence's first global chunk index.

    Args:
        cu_seqlens: cumulative sequence offsets of length ``batch_size + 1``.
        chunk_size: tokens per chunk.

    Returns:
        ``(chunk_offsets, total_chunks)``:
        - ``chunk_offsets`` has the shape, dtype and device of ``cu_seqlens``.
          ``chunk_offsets[0] == 0``, and each later entry adds the previous
          sequence's ``ceil(length / chunk_size)``.
        - ``total_chunks`` is ``chunk_offsets[-1]`` as a Python number. Reading it
          with ``.item()`` copies the value back to the host.
    """
    batch_size = cu_seqlens.shape[0] - 1
    if batch_size == 0:
        # No sequences: the answer is just [0]. Skip compiling and launching a
        # kernel whose block size (next_power_of_2(0)) would be degenerate.
        return torch.zeros_like(cu_seqlens), 0
    chunk_offsets = torch.empty_like(cu_seqlens)
    kernel = tilelang_prepare_chunk_offsets(
        chunk_size=chunk_size,
        # One block covers the whole batch; round its width up to a power of two.
        block_size=tilelang.next_power_of_2(batch_size),
        dtype=cu_seqlens.dtype,
    )
    kernel(cu_seqlens, chunk_offsets)
    return chunk_offsets, chunk_offsets[-1].item()