# 166 lines, 4.7 KiB, Python
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
import torch
|
|
import tilelang
|
|
import tilelang.language as T
|
|
|
|
from flash_qla.utils import prepare_chunk_indices
|
|
|
|
|
|
# Factory for the chunk-local cumulative-sum kernel.
#
# Compile-time parameters:
#   H            -- extent of the last axis of the input/output tensors.
#   chunk_size   -- number of tokens per chunk; one kernel block handles one chunk.
#   accum_dtype  -- dtype of the fragment on which T.cumsum is executed.
#   g_dtype      -- dtype of the input and output tensors.
#   seqlen_dtype -- dtype of cu_seqlens / chunk_indices (used by the varlen kernel).
#   is_varlen    -- selects which of the two kernel signatures is built.
#   reverse      -- forwarded unchanged to T.cumsum.
#
# Returns the prim_func. out_idx is not set on the jit decorator, so the
# output buffer (g_cumsum) is an explicit argument supplied by the caller.
@tilelang.jit()
def tilelang_chunk_local_cumsum(
    H,
    chunk_size,
    accum_dtype,
    g_dtype,
    seqlen_dtype,
    is_varlen,
    reverse,
):
    # Symbolic extents; they appear in the tensor shapes and grid size below.
    data_batch_size = T.dynamic("data_batch_size")
    real_batch_size = T.dynamic("real_batch_size")
    num_tokens = T.dynamic("num_tokens")
    num_chunks = T.dynamic("num_chunks")
    tile_len = chunk_size

    io_shape = (data_batch_size, num_tokens, H)

    # Shared body of both kernels: cumsum of one chunk of one sequence.
    # block_id and seq_id are accepted but not read inside the macro.
    @T.macro
    def cumsum_one_chunk(
        row,
        block_id,
        seq_id,
        tile_id,
        seq_begin,
        seq_end,
        src,
        dst,
    ):
        # Token range [tile_lo, tile_hi) covered by this chunk.
        tile_lo = seq_begin + tile_id * tile_len
        tile_hi = tile_lo + tile_len

        # acc is (H, tokens) in accum_dtype; stage is (tokens, H) in g_dtype.
        acc = T.alloc_fragment((H, tile_len), dtype=accum_dtype)
        stage = T.alloc_fragment((tile_len, H), dtype=g_dtype)
        # One extra column is allocated; only [:, :H] is ever read or written.
        # Presumably padding to stagger rows in shared memory -- TODO confirm.
        transpose_buf = T.alloc_shared((tile_len, H + 1), dtype=g_dtype)

        # Load: bulk copy for a full chunk, otherwise an element-wise guarded
        # load that fills positions past the end of the sequence with zero.
        if tile_hi <= seq_end:
            T.copy(src[row, tile_lo:tile_hi, 0:H], stage)
        else:
            for tok, col in T.Parallel(tile_len, H):
                if tile_lo + tok < seq_end:
                    stage[tok, col] = src[row, tile_lo + tok, col]
                else:
                    stage[tok, col] = 0
        T.copy(stage, transpose_buf[:, :H])

        # Transpose through the shared buffer so that tokens lie along dim 1.
        for col, tok in T.Parallel(H, tile_len):
            acc[col, tok] = transpose_buf[tok, col]

        # Cumulative sum over the token axis of this chunk.
        T.cumsum(acc, dim=1, reverse=reverse)

        # Transpose back into (tokens, H) layout.
        for col, tok in T.Parallel(H, tile_len):
            transpose_buf[tok, col] = acc[col, tok]

        T.copy(transpose_buf[:, :H], stage)
        # Store: mirrors the load; positions past seq_end are never written.
        if tile_hi <= seq_end:
            T.copy(stage, dst[row, tile_lo:tile_hi, 0:H])
        else:
            for tok, col in T.Parallel(tile_len, H):
                if tile_lo + tok < seq_end:
                    dst[row, tile_lo + tok, col] = stage[tok, col]

    if not is_varlen:

        # Fixed-length kernel: the grid size is passed as a scalar argument.
        @T.prim_func
        def tilelang_chunk_local_cumsum_kernel(
            g_raw: T.Tensor(io_shape, dtype=g_dtype),
            g_cumsum: T.Tensor(io_shape, dtype=g_dtype),
            num_chunks: T.int32,
        ):
            with T.Kernel(num_chunks, threads=128) as (block_id,):
                # The block index enumerates (chunk, batch) pairs with the
                # batch index varying fastest.
                row = block_id % data_batch_size
                seq_id = row
                tile_id = block_id // data_batch_size
                # Every batch row is one sequence spanning all tokens.
                seq_begin = 0
                seq_end = num_tokens

                cumsum_one_chunk(
                    row,
                    block_id,
                    seq_id,
                    tile_id,
                    seq_begin,
                    seq_end,
                    g_raw,
                    g_cumsum,
                )

    else:

        # Variable-length kernel: one block per row of chunk_indices.
        @T.prim_func
        def tilelang_chunk_local_cumsum_kernel(
            g_raw: T.Tensor(io_shape, dtype=g_dtype),
            cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
            chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
            g_cumsum: T.Tensor(io_shape, dtype=g_dtype),
        ):
            with T.Kernel(num_chunks, threads=128) as (block_id,):
                # Only batch row 0 of g_raw / g_cumsum is addressed here.
                row = 0
                # chunk_indices[b] = (sequence index, chunk index in sequence).
                seq_id = chunk_indices[block_id, 0]
                tile_id = chunk_indices[block_id, 1]
                # Token bounds of that sequence.
                seq_begin = cu_seqlens[seq_id]
                seq_end = cu_seqlens[seq_id + 1]

                cumsum_one_chunk(
                    row,
                    block_id,
                    seq_id,
                    tile_id,
                    seq_begin,
                    seq_end,
                    g_raw,
                    g_cumsum,
                )

    return tilelang_chunk_local_cumsum_kernel
|
|
|
|
|
|
def chunk_local_cumsum(
|
|
g: torch.Tensor,
|
|
chunk_size: int = 64,
|
|
cu_seqlens: torch.LongTensor | None = None,
|
|
reverse: bool = False,
|
|
):
|
|
batch_size, num_tokens, H = g.shape
|
|
assert g.stride(-1) == 1
|
|
|
|
if cu_seqlens is None:
|
|
num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
|
|
seqlen_dtype = "int32"
|
|
is_varlen = False
|
|
else:
|
|
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
|
seqlen_dtype = cu_seqlens.dtype
|
|
is_varlen = True
|
|
|
|
g_cumsum = torch.empty_like(g)
|
|
|
|
tilelang_chunk_local_cumsum_kernel = tilelang_chunk_local_cumsum(
|
|
H,
|
|
chunk_size,
|
|
g_dtype=g.dtype,
|
|
seqlen_dtype=seqlen_dtype,
|
|
accum_dtype="float32",
|
|
is_varlen=is_varlen,
|
|
reverse=reverse,
|
|
)
|
|
if is_varlen:
|
|
tilelang_chunk_local_cumsum_kernel(g, cu_seqlens, chunk_indices, g_cumsum)
|
|
else:
|
|
tilelang_chunk_local_cumsum_kernel(g, g_cumsum, num_chunks)
|
|
|
|
return g_cumsum
|