# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang import functools from typing import Any from collections import OrderedDict from collections.abc import Callable import torch import tilelang import tilelang.language as T def tensor_cache( fn: Callable[..., torch.Tensor], ) -> Callable[..., torch.Tensor]: """ A decorator that caches the most recent results of a function with tensor inputs. This decorator will store the output of the decorated function for the most recent set of input tensors. The cache is limited to a fixed size (default is 256). When the cache is full, the oldest entry will be removed. Args: fn (Callable[..., torch.Tensor]): The function to be decorated. It should take tensor inputs and return tensor outputs. Returns: Callable[..., torch.Tensor]: A wrapped version of the input function with single-entry caching. """ cache: "OrderedDict[tuple[tuple[int, ...], tuple[tuple[str, int], ...]], tuple[tuple[Any, ...], dict[str, Any], Any]]" = OrderedDict() cache_size = 256 def get_id(x: Any): if (type(x) is int) or (type(x) is float) or (type(x) is str): return x else: return id(x) def make_identity_key( args: tuple[Any, ...], kwargs: dict[str, Any] ) -> tuple[tuple[int, ...], tuple[tuple[str, int], ...]]: args_key = tuple(get_id(a) for a in args) kwargs_key = tuple(sorted((k, get_id(v)) for k, v in kwargs.items())) return args_key, kwargs_key @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: nonlocal cache, cache_size key = make_identity_key(args, kwargs) if key in cache: cache.move_to_end(key, last=True) _, _, cached_result = cache[key] return cached_result result = fn(*args, **kwargs) cache[key] = (args, kwargs, result) cache.move_to_end(key, last=True) if len(cache) > cache_size: cache.popitem(last=False) return result return wrapper @tensor_cache def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: return torch.diff(cu_seqlens) @tensor_cache def prepare_chunk_indices( cu_seqlens: torch.LongTensor, chunk_size: int, ) -> torch.LongTensor: # TODO: tilelang kernel indices = torch.cat( [ torch.arange(n) for n in tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() ] ) return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) @tilelang.jit() def tilelang_prepare_chunk_offsets( chunk_size, block_size, dtype, ): batch_size_plus_1 = T.dynamic("batch_size_plus_1") num_threads = min(max(block_size, 32), 128) @T.prim_func def tilelang_prepare_chunk_offsets_kernel( cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype), chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype), ): with T.Kernel(1, threads=num_threads) as (bb,): _batch_size = T.alloc_var("int32") _batch_size = batch_size_plus_1 - 1 seqlen_start_fragment = T.alloc_fragment((block_size), dtype=dtype) seqlen_end_fragment = T.alloc_fragment((block_size), dtype=dtype) chunk_offset_fragment = T.alloc_fragment((block_size), dtype=dtype) T.copy(cu_seqlens[: batch_size_plus_1 - 1], seqlen_start_fragment) T.copy(cu_seqlens[1:], seqlen_end_fragment) for i in T.Parallel(block_size): chunk_offset_fragment[i] = ( seqlen_end_fragment[i] - seqlen_start_fragment[i] ) chunk_offset_fragment[i] = ( chunk_offset_fragment[i] + chunk_size - 1 ) // chunk_size T.cumsum(src=chunk_offset_fragment, dim=0) chunk_offsets[0] = 0 T.copy(chunk_offset_fragment, chunk_offsets[1:]) return tilelang_prepare_chunk_offsets_kernel @tensor_cache def prepare_chunk_offsets( cu_seqlens: torch.LongTensor, chunk_size: int, ) -> torch.LongTensor: chunk_offsets = torch.empty_like(cu_seqlens) tilelang_prepare_chunk_offsets_kernel = tilelang_prepare_chunk_offsets( chunk_size=chunk_size, block_size=tilelang.next_power_of_2(cu_seqlens.shape[0] - 1), dtype=cu_seqlens.dtype, ) tilelang_prepare_chunk_offsets_kernel(cu_seqlens, chunk_offsets) return chunk_offsets, chunk_offsets[-1].item()