first commit
This commit is contained in:
11
flash_qla/ops/utils/__init__.py
Normal file
11
flash_qla/ops/utils/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .group_reduce import group_reduce_vector
|
||||
|
||||
|
||||
__all__ = [
|
||||
"chunk_local_cumsum",
|
||||
"group_reduce_vector",
|
||||
]
|
||||
165
flash_qla/ops/utils/cumsum.py
Normal file
165
flash_qla/ops/utils/cumsum.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from flash_qla.utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
# out_idx=[-1],
|
||||
)
|
||||
def tilelang_chunk_local_cumsum(
|
||||
H,
|
||||
chunk_size,
|
||||
accum_dtype,
|
||||
g_dtype,
|
||||
seqlen_dtype,
|
||||
is_varlen,
|
||||
reverse,
|
||||
):
|
||||
data_batch_size = T.dynamic("data_batch_size")
|
||||
real_batch_size = T.dynamic("real_batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
num_chunks = T.dynamic("num_chunks")
|
||||
block_S = chunk_size
|
||||
|
||||
g_shape = (data_batch_size, num_tokens, H)
|
||||
|
||||
@T.macro
|
||||
def kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
g_raw,
|
||||
g_cumsum,
|
||||
):
|
||||
left = seq_start_idx + chunk_idx * block_S
|
||||
right = left + block_S
|
||||
|
||||
g_fragment = T.alloc_fragment((H, block_S), dtype=accum_dtype)
|
||||
gT_fragment = T.alloc_fragment((block_S, H), dtype=g_dtype)
|
||||
gT_shared = T.alloc_shared((block_S, H + 1), dtype=g_dtype)
|
||||
|
||||
if right <= seq_end_idx:
|
||||
T.copy(g_raw[bb, left:right, 0:H], gT_fragment)
|
||||
else:
|
||||
for j, i in T.Parallel(block_S, H):
|
||||
if left + j < seq_end_idx:
|
||||
gT_fragment[j, i] = g_raw[bb, left + j, i]
|
||||
else:
|
||||
gT_fragment[j, i] = 0
|
||||
T.copy(gT_fragment, gT_shared[:, :H])
|
||||
|
||||
for i, j in T.Parallel(H, block_S):
|
||||
g_fragment[i, j] = gT_shared[j, i]
|
||||
|
||||
T.cumsum(g_fragment, dim=1, reverse=reverse)
|
||||
|
||||
for i, j in T.Parallel(H, block_S):
|
||||
gT_shared[j, i] = g_fragment[i, j]
|
||||
|
||||
T.copy(gT_shared[:, :H], gT_fragment)
|
||||
if right <= seq_end_idx:
|
||||
T.copy(gT_fragment, g_cumsum[bb, left:right, 0:H])
|
||||
else:
|
||||
for j, i in T.Parallel(block_S, H):
|
||||
if left + j < seq_end_idx:
|
||||
g_cumsum[bb, left + j, i] = gT_fragment[j, i]
|
||||
|
||||
if is_varlen:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_chunk_local_cumsum_kernel(
|
||||
g_raw: T.Tensor(g_shape, dtype=g_dtype),
|
||||
cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
|
||||
chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
|
||||
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
|
||||
):
|
||||
with T.Kernel(num_chunks, threads=128) as (bc,):
|
||||
bb = 0
|
||||
batch_idx = chunk_indices[bc, 0]
|
||||
chunk_idx = chunk_indices[bc, 1]
|
||||
seq_start_idx = cu_seqlens[batch_idx]
|
||||
seq_end_idx = cu_seqlens[batch_idx + 1]
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
g_raw,
|
||||
g_cumsum,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_chunk_local_cumsum_kernel(
|
||||
g_raw: T.Tensor(g_shape, dtype=g_dtype),
|
||||
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
|
||||
num_chunks: T.int32,
|
||||
):
|
||||
with T.Kernel(num_chunks, threads=128) as (bc,):
|
||||
bb = bc % data_batch_size
|
||||
batch_idx = bb
|
||||
chunk_idx = bc // data_batch_size
|
||||
seq_start_idx = 0
|
||||
seq_end_idx = num_tokens
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
g_raw,
|
||||
g_cumsum,
|
||||
)
|
||||
|
||||
return tilelang_chunk_local_cumsum_kernel
|
||||
|
||||
|
||||
def chunk_local_cumsum(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int = 64,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
reverse: bool = False,
|
||||
):
|
||||
batch_size, num_tokens, H = g.shape
|
||||
assert g.stride(-1) == 1
|
||||
|
||||
if cu_seqlens is None:
|
||||
num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
|
||||
seqlen_dtype = "int32"
|
||||
is_varlen = False
|
||||
else:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
seqlen_dtype = cu_seqlens.dtype
|
||||
is_varlen = True
|
||||
|
||||
g_cumsum = torch.empty_like(g)
|
||||
|
||||
tilelang_chunk_local_cumsum_kernel = tilelang_chunk_local_cumsum(
|
||||
H,
|
||||
chunk_size,
|
||||
g_dtype=g.dtype,
|
||||
seqlen_dtype=seqlen_dtype,
|
||||
accum_dtype="float32",
|
||||
is_varlen=is_varlen,
|
||||
reverse=reverse,
|
||||
)
|
||||
if is_varlen:
|
||||
tilelang_chunk_local_cumsum_kernel(g, cu_seqlens, chunk_indices, g_cumsum)
|
||||
else:
|
||||
tilelang_chunk_local_cumsum_kernel(g, g_cumsum, num_chunks)
|
||||
|
||||
return g_cumsum
|
||||
79
flash_qla/ops/utils/group_reduce.py
Normal file
79
flash_qla/ops/utils/group_reduce.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
# out_idx=[-1],
|
||||
)
|
||||
def tilelang_group_reduce_vector(
|
||||
H,
|
||||
Hg,
|
||||
DK,
|
||||
accum_dtype,
|
||||
qkva_dtype,
|
||||
block_size: int = 16,
|
||||
):
|
||||
batch_size = T.dynamic("batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
|
||||
group_size = H // Hg
|
||||
|
||||
buffer_shape = (batch_size, num_tokens, H, DK)
|
||||
dqk_shape = (batch_size, num_tokens, Hg, DK)
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_group_reduce_vector_kernel(
|
||||
buffer: T.Tensor(buffer_shape, dtype=qkva_dtype),
|
||||
result: T.Tensor(dqk_shape, dtype=qkva_dtype),
|
||||
):
|
||||
with T.Kernel(
|
||||
tilelang.cdiv(num_tokens, block_size), Hg, batch_size, threads=128
|
||||
) as (bt, bhg, bb):
|
||||
buffer_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
|
||||
result_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
|
||||
|
||||
T.clear(result_fragment)
|
||||
for i in T.serial(group_size):
|
||||
T.copy(
|
||||
buffer[
|
||||
bb,
|
||||
bt * block_size : (bt + 1) * block_size,
|
||||
bhg * group_size + i,
|
||||
0:DK,
|
||||
],
|
||||
buffer_fragment,
|
||||
)
|
||||
for j, k in T.Parallel(block_size, DK):
|
||||
result_fragment[j, k] += buffer_fragment[j, k]
|
||||
T.copy(
|
||||
result_fragment,
|
||||
result[bb, bt * block_size : (bt + 1) * block_size, bhg, 0:DK],
|
||||
)
|
||||
|
||||
return tilelang_group_reduce_vector_kernel
|
||||
|
||||
|
||||
def group_reduce_vector(
|
||||
buffer: torch.Tensor,
|
||||
Hg: int,
|
||||
):
|
||||
batch_size, num_tokens, H, K = buffer.shape
|
||||
|
||||
result = torch.empty(
|
||||
(batch_size, num_tokens, Hg, K), dtype=buffer.dtype, device=buffer.device
|
||||
)
|
||||
|
||||
tilelang_group_reduce_vector_kernel = tilelang_group_reduce_vector(
|
||||
H,
|
||||
Hg,
|
||||
K,
|
||||
qkva_dtype=buffer.dtype,
|
||||
accum_dtype="float32",
|
||||
)
|
||||
tilelang_group_reduce_vector_kernel(buffer, result)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user