first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

View File

@@ -0,0 +1,237 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
from flash_qla.utils import l2norm
from flash_qla.ops.utils import chunk_local_cumsum, group_reduce_vector
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
from .hopper import fused_gdr_fwd, fused_gdr_bwd, fused_gdr_h, kkt_solve
else:
raise ValueError("FlashQLA now support sm90 only.")
from .cp_context import intra_card_cp_preprocess
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
output_final_state: bool = True,
output_h: bool = False,
auto_cp: bool = True,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
A = kkt_solve(
k=k,
b=beta,
cu_seqlens=cu_seqlens,
)
if auto_cp:
initial_state, cu_seqlens, cp_seq_map, raw_cu_seqlens = (
intra_card_cp_preprocess(
k=k,
v=v,
a=A,
g=g,
b=beta,
raw_h0=initial_state,
raw_cu_seqlens=cu_seqlens,
)
)
else:
cp_seq_map = None
raw_cu_seqlens = None
o, h, final_state = fused_gdr_fwd(
q=q,
k=k,
v=v,
a=A,
g=g,
b=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
output_h=output_h,
output_o=True,
cu_seqlens=cu_seqlens,
cp_seq_map=cp_seq_map,
raw_cu_seqlens=raw_cu_seqlens,
)
return g, A, o, h, final_state
def chunk_gated_delta_rule_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor | None = None,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
):
h, _, _ = fused_gdr_h(
k=k,
v=v,
a=A,
g=g,
b=beta,
initial_state=initial_state,
output_final_state=False,
output_h=True,
cu_seqlens=cu_seqlens,
)
dq, dk, dv, dg, db, dh0 = fused_gdr_bwd(
q=q,
k=k,
v=v,
a=A,
g=g,
b=beta,
do=do,
dht=dht,
h=h,
scale=scale,
cu_seqlens=cu_seqlens,
)
Hg, H = k.shape[-2], v.shape[-2]
if Hg < H:
dq = group_reduce_vector(dq, Hg)
dk = group_reduce_vector(dk, Hg)
assert dg.dtype == torch.float32, "dg should be fp32"
dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens)
return dq, dk, dv, db, dg, dh0
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None,
):
q_orig = q
k_orig = k
g, A, o, _, final_state = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
output_h=False,
cu_seqlens=cu_seqlens,
)
ctx.save_for_backward(q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens)
ctx.scale = scale
return o.to(q.dtype), final_state
@staticmethod
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
q=q_orig,
k=k_orig,
v=v,
g=g,
beta=beta,
A=A,
do=do,
dht=dht,
scale=ctx.scale,
initial_state=initial_state,
cu_seqlens=cu_seqlens,
)
return (
dq.to(q_orig),
dk.to(k_orig),
dv.to(v),
dg.to(g),
db.to(beta),
None,
dh0,
None,
None,
)
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
):
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, (
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16 or float16."
)
assert not head_first, "head_first=True is not supported."
assert v.shape[2] % k.shape[2] == 0, (
"num_qk_heads must be divisible to num_v_heads."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
if use_qk_l2norm_in_kernel:
q = l2norm(q)
k = l2norm(k)
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
)
return o, final_state

View File

@@ -0,0 +1,163 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
import tilelang
from flash_qla.utils import tensor_cache
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
from .hopper import get_warmup_chunks, fused_gdr_h, correct_initial_states
else:
raise ValueError("FlashQLA now support sm90 only.")
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
@tensor_cache
def _create_cu_seqlens(
batch_size: int,
num_tokens: int,
device_idx: int,
):
return (
torch.arange((batch_size + 1), dtype=torch.int32, device=f"cuda:{device_idx}")
* num_tokens
)
@tensor_cache
def _calc_cp_seqs(
raw_cu_seqlens: torch.LongTensor,
chunk_size: int,
num_v_heads: int,
):
# TODO: tilelang kernel
device = raw_cu_seqlens.device
seqlen_dtype = raw_cu_seqlens.dtype
raw_cu_seqlens = raw_cu_seqlens.tolist()
raw_batch_size = len(raw_cu_seqlens) - 1
seqlens = [raw_cu_seqlens[i + 1] - raw_cu_seqlens[i] for i in range(raw_batch_size)]
num_chunks = [tilelang.cdiv(x, chunk_size) for x in seqlens]
# autocp
H = num_v_heads
# Latency model: T = a·L_cp + b·(B·H·Lc/P) / L_cp + c
# Minimizing T yields the theoretical optimum: L_cp* ∝ √(B·H·Lc / P), where P = MULTI_PROCESSOR_COUNT, L_cp = max_local_chunks
# Scaled by empirical factor (3) and aligned to the nearest power of 2 for optimal SM scheduling & memory alignment.
max_local_chunks = 2 ** round(
math.log2(math.sqrt(H * sum(num_chunks) / MULTI_PROCESSOR_COUNT) * 3)
)
# Set min to 4 to ensure multi-stage pipelining in fused_gdr;
max_local_chunks = max(max_local_chunks, 4)
use_cp = False
cp_cu_seqlens = []
ht_mask = []
seq_map_c2r = []
seq_map_r2c = [0]
max_local_tokens = max_local_chunks * chunk_size
for i, c in enumerate(num_chunks):
s = raw_cu_seqlens[i]
e = raw_cu_seqlens[i + 1]
if c > max_local_chunks:
while s < e:
cp_cu_seqlens.append(s)
ht_mask.append(False)
seq_map_c2r.append(i)
s += max_local_tokens
ht_mask[-1] = True
else:
cp_cu_seqlens.append(s)
ht_mask.append(True)
seq_map_c2r.append(i)
seq_map_r2c.append(len(cp_cu_seqlens))
cp_cu_seqlens.append(raw_cu_seqlens[-1])
# Disable CP when B * H naturally saturates SM occupancy.
# For varlen inputs, use `total_chunks / max_seq_chunks` as effective B,
# since CP helps accelerate highly uneven sequence lengths.
Be = sum(num_chunks) / max(num_chunks)
use_cp = Be * H <= 40 or (Be * H <= 56 and max(num_chunks) >= 128)
if use_cp:
cp_cu_seqlens = torch.tensor(
cp_cu_seqlens, dtype=seqlen_dtype, device=device, requires_grad=False
)
seq_map_c2r = torch.tensor(seq_map_c2r, dtype=seqlen_dtype, device=device)
seq_map_r2c = torch.tensor(
seq_map_r2c, dtype=seqlen_dtype, device=device, requires_grad=False
)
ht_mask = torch.tensor(
ht_mask, dtype=torch.bool, device=device, requires_grad=False
)
else:
cp_cu_seqlens, seq_map_r2c, ht_mask = None, None, None
return use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask
def intra_card_cp_preprocess(
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
raw_h0: torch.Tensor,
raw_cu_seqlens: torch.Tensor,
warmup_threshold: float = -10.0,
):
batch_size, num_tokens, num_k_heads, k_head_dim = k.shape
_, _, num_v_heads, v_head_dim = v.shape
chunk_size = a.shape[-1]
device = k.device
if batch_size > 1:
return raw_h0, raw_cu_seqlens, None, None
if raw_cu_seqlens is None:
raw_cu_seqlens = _create_cu_seqlens(batch_size, num_tokens, device.index)
use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask = _calc_cp_seqs(
raw_cu_seqlens,
chunk_size,
num_v_heads,
)
if not use_cp:
return raw_h0, raw_cu_seqlens, None, None
num_warmup_chunks, fallback_mask = get_warmup_chunks(
g=g,
cu_seqlens=cp_cu_seqlens,
ht_mask=ht_mask,
chunk_size=chunk_size,
threshold=warmup_threshold,
) # [cp_batch_size, num_v_heads]
_, ht, mt = fused_gdr_h(
k=k,
v=v,
a=a,
g=g,
b=b,
initial_state=None,
output_final_state=True,
output_h=False,
cu_seqlens=cp_cu_seqlens,
num_warmup_chunks=num_warmup_chunks,
) # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
cp_h0 = correct_initial_states(
raw_h0=raw_h0,
ht_buffer=ht,
mt_buffer=mt,
fallback_mask=fallback_mask,
seq_map_r2c=seq_map_r2c,
)
return cp_h0, cp_cu_seqlens, seq_map_c2r, raw_cu_seqlens

View File

@@ -0,0 +1,18 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .fused_fwd import fused_gdr_fwd
from .fused_bwd import fused_gdr_bwd
from .prepare_h import fused_gdr_h
from .kkt_solve import kkt_solve
from .cp_fwd import get_warmup_chunks, correct_initial_states
__all__ = [
"fused_gdr_fwd",
"fused_gdr_bwd",
"fused_gdr_h",
"kkt_solve",
"get_warmup_chunks",
"correct_initial_states",
]

View File

@@ -0,0 +1,309 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
@tilelang.jit()
def tilelang_get_warmup_chunks(
num_heads,
chunk_size,
threshold,
accum_dtype,
g_dtype,
mask_dtype,
seqlen_dtype,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_threads = tilelang.cdiv(num_heads, 32) * 32
@T.prim_func
def tilelang_get_warmup_chunks_kernel(
g: T.Tensor([1, num_tokens, num_heads], dtype=g_dtype),
ht_mask: T.Tensor([batch_size], dtype=mask_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
num_warmup_chunks: T.Tensor([batch_size, num_heads], dtype=seqlen_dtype),
fallback_mask: T.Tensor([batch_size, num_heads], dtype=mask_dtype),
):
with T.Kernel(batch_size, threads=num_threads) as (bb,):
if ht_mask[bb]:
for i_h in T.Parallel(num_heads):
num_warmup_chunks[bb, i_h] = 0
else:
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
num_iters = T.alloc_var("int32")
seq_start_idx = cu_seqlens[bb]
seq_end_idx = cu_seqlens[bb + 1]
num_iters = (seq_end_idx - seq_start_idx) // chunk_size
g_fragment = T.alloc_fragment((num_heads), dtype=accum_dtype)
g_cumsum = T.alloc_fragment((num_heads), dtype=accum_dtype)
n_fragment = T.alloc_fragment((num_heads), dtype=seqlen_dtype)
f_fragment = T.alloc_fragment((num_heads), dtype=mask_dtype)
T.clear(g_cumsum)
T.fill(n_fragment, num_iters)
T.fill(f_fragment, True)
for i_s in T.serial(num_iters):
for i_h in T.Parallel(num_heads):
g_fragment[i_h] = g[0, seq_end_idx - i_s * chunk_size - 1, i_h]
for i_h in T.Parallel(num_heads):
g_cumsum[i_h] += g_fragment[i_h]
for i_h in T.Parallel(num_heads):
if g_cumsum[i_h] < threshold and n_fragment[i_h] == num_iters:
n_fragment[i_h] = i_s + 1
f_fragment[i_h] = False
for i_h in T.Parallel(num_heads):
num_warmup_chunks[bb, i_h] = n_fragment[i_h]
for i_h in T.Parallel(num_heads):
fallback_mask[bb, i_h] = f_fragment[i_h]
return tilelang_get_warmup_chunks_kernel
def get_warmup_chunks(
g: torch.Tensor, # [1, num_total_tokens, num_v_heads]
cu_seqlens: torch.Tensor, # [cp_real_batch_size + 1]
ht_mask: torch.Tensor, # [cp_real_batch_size]
chunk_size: int = 64,
threshold: float = -10.0,
):
batch_size, num_tokens, num_heads = g.shape
real_batch_size = ht_mask.shape[0]
assert cu_seqlens.shape[0] == real_batch_size + 1
assert batch_size == 1
assert chunk_size == 64
tilelang_get_warmup_chunks_kernel = tilelang_get_warmup_chunks(
num_heads=num_heads,
chunk_size=chunk_size,
threshold=threshold,
accum_dtype="float32",
g_dtype=g.dtype,
mask_dtype=ht_mask.dtype,
seqlen_dtype=cu_seqlens.dtype,
)
num_warmup_chunks = torch.empty(
[real_batch_size, num_heads], dtype=cu_seqlens.dtype, device=cu_seqlens.device
)
fallback_mask = torch.empty(
[real_batch_size, num_heads], dtype=ht_mask.dtype, device=cu_seqlens.device
)
tilelang_get_warmup_chunks_kernel(
g, ht_mask, cu_seqlens, num_warmup_chunks, fallback_mask
)
return num_warmup_chunks, fallback_mask
@tilelang.jit()
def tilelang_correct_h0(
H,
DK,
DV,
res_dtype,
accum_dtype,
buffer_dtype,
seqlen_dtype,
mask_dtype,
use_raw_h0,
block_DV: int = 32,
):
cp_batch_size = T.dynamic("cp_batch_size")
raw_batch_size = T.dynamic("raw_batch_size")
@T.macro
def kernel_body(
bb,
bh,
bv,
seq_start_idx,
seq_end_idx,
num_iters,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
h_fragment,
):
h_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
hd_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
m_shared = T.alloc_shared((DK, DK), dtype=buffer_dtype)
T.copy(
h_fragment,
cp_h0[seq_start_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
)
for i_s in T.Pipelined(num_iters - 1, num_stages=2):
if fallback_mask[seq_start_idx + i_s, bh]:
T.copy(h_fragment, hd_shared)
T.copy(
ht_buffer[
seq_start_idx + i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
],
h_shared,
)
T.copy(h_shared, h_fragment)
if fallback_mask[seq_start_idx + i_s, bh]:
T.copy(mt_buffer[seq_start_idx + i_s, bh, 0:DK, 0:DK], m_shared)
T.gemm(m_shared, hd_shared, h_fragment, clear_accum=False)
T.copy(
h_fragment,
cp_h0[
seq_start_idx + i_s + 1,
bh,
0:DK,
bv * block_DV : (bv + 1) * block_DV,
],
)
if use_raw_h0:
@T.prim_func
def tilelang_correct_h0_kernel(
raw_h0: T.Tensor([raw_batch_size, H, DK, DV], dtype=res_dtype),
ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
) as (bbhv,):
bbh, bv = (
bbhv // T.ceildiv(DV, block_DV),
bbhv % T.ceildiv(DV, block_DV),
)
bb, bh = bbh // H, bbh % H
seq_start_idx = seq_map_r2c[bb]
seq_end_idx = seq_map_r2c[bb + 1]
num_iters = seq_end_idx - seq_start_idx
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
T.copy(
raw_h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
h_fragment,
)
kernel_body(
bb,
bh,
bv,
seq_start_idx,
seq_end_idx,
num_iters,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
h_fragment,
)
else:
@T.prim_func
def tilelang_correct_h0_kernel(
ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
) as (bbhv,):
bbh, bv = (
bbhv // T.ceildiv(DV, block_DV),
bbhv % T.ceildiv(DV, block_DV),
)
bb, bh = bbh // H, bbh % H
seq_start_idx = seq_map_r2c[bb]
seq_end_idx = seq_map_r2c[bb + 1]
num_iters = seq_end_idx - seq_start_idx
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
T.clear(h_fragment)
kernel_body(
bb,
bh,
bv,
seq_start_idx,
seq_end_idx,
num_iters,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
h_fragment,
)
return tilelang_correct_h0_kernel
def correct_initial_states(
raw_h0: torch.Tensor
| None, # [raw_batch_size, num_v_heads, k_head_dim, v_head_dim]
ht_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
mt_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, k_head_dim]
fallback_mask: torch.Tensor, # [cp_batch_size, num_v_heads]
seq_map_r2c: torch.Tensor, # [raw_batch_size + 1]
):
cp_batch_size = fallback_mask.shape[0]
_, num_heads, k_head_dim, v_head_dim = ht_buffer.shape
assert k_head_dim == v_head_dim == 128
if raw_h0 is None:
res_dtype = torch.float32
use_raw_h0 = False
else:
res_dtype = raw_h0.dtype
use_raw_h0 = True
tilelang_correct_h0_kernel = tilelang_correct_h0(
H=num_heads,
DK=k_head_dim,
DV=v_head_dim,
res_dtype=res_dtype,
accum_dtype="float32",
buffer_dtype=ht_buffer.dtype,
seqlen_dtype=seq_map_r2c.dtype,
mask_dtype=fallback_mask.dtype,
use_raw_h0=use_raw_h0,
)
cp_h0 = torch.empty(
(cp_batch_size, num_heads, k_head_dim, v_head_dim),
dtype=res_dtype,
device=ht_buffer.device,
)
if use_raw_h0:
tilelang_correct_h0_kernel(
raw_h0,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
)
else:
tilelang_correct_h0_kernel(
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
)
return cp_h0

View File

@@ -0,0 +1,985 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
@tilelang.jit(
# out_idx=[-5, -4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_DATA_RACE_CHECK: True,
},
)
def tilelang_fused_chunk_gdr_bwd(
H,
Hg,
DK,
DV,
chunk_size,
scale,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h_dtype,
o_dtype,
seqlen_dtype,
is_varlen,
use_dht,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
if is_varlen:
q_shape = (1, num_tokens, Hg, DK)
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
o_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
q_shape = (batch_size, num_tokens, Hg, DK)
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
o_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
h0_shape = (batch_size, H, DK, DV)
ht_shape = (batch_size, H, DK, DV)
@T.prim_func
def tilelang_fused_chunk_gdr_bwd_kernel(
do: T.Tensor(o_shape, dtype=o_dtype),
dht: T.Tensor(ht_shape, dtype=accum_dtype),
q: T.Tensor(q_shape, dtype=qkva_dtype),
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
dq: T.Tensor(v_shape, dtype=qkva_dtype),
dk: T.Tensor(v_shape, dtype=qkva_dtype),
dv: T.Tensor(v_shape, dtype=qkva_dtype),
dg: T.Tensor(g_shape, dtype=g_dtype),
db: T.Tensor(b_shape, dtype=b_dtype),
dh0: T.Tensor(h0_shape, dtype=accum_dtype),
):
with T.Kernel(batch_size * H, threads=512) as (bbh,):
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
num_iters = T.alloc_var("int32")
num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
# 2+2+2+2 + 1 + 4 = 13 units
do_shared = T.alloc_shared((block_S, DV), dtype=o_dtype)
q_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
h_shared = T.alloc_shared((DK, DV), dtype=h_dtype)
g_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
# 2 units
dqkv_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
dg_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
db_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
# 1+1 + 2+2+2 + 4 = 12 units
tmp_shared_1_1 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_1_2 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_1_3 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_2_1 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_2_2 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_2_3 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_4_1 = T.alloc_shared((DK, DV), dtype=qkva_dtype)
# CONSUMER_K
dk_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dv_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
odot_fragment_1 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dg_fragment_1 = T.alloc_fragment((block_S), dtype=accum_dtype)
dg_last_local_1 = T.alloc_fragment((1), dtype=accum_dtype)
# CONSUMER_A
mask_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dp_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
da_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
u_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dq_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
db_fragment = T.alloc_fragment((block_S), dtype=accum_dtype)
odot_fragment_2 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dg_fragment_2 = T.alloc_fragment((block_S), dtype=accum_dtype)
# CONSUMER_S
dh_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
_odot_fragment_3 = T.alloc_fragment((DK, DV), dtype=accum_dtype)
reduce_fragment = T.alloc_fragment((128, 2), dtype=accum_dtype)
dg_last_local_3 = T.alloc_fragment((1), dtype=accum_dtype)
g_last_local_3 = T.alloc_local((1), dtype=accum_dtype)
# 16 stages
bar_00 = T.alloc_barrier(arrive_count=448)
bar_01 = T.alloc_barrier(arrive_count=384)
bar_02 = T.alloc_barrier(arrive_count=288)
bar_03 = T.alloc_barrier(arrive_count=256)
bar_04 = T.alloc_barrier(arrive_count=416)
bar_05 = T.alloc_barrier(arrive_count=288)
bar_06 = T.alloc_barrier(arrive_count=256)
bar_07 = T.alloc_barrier(arrive_count=256)
bar_08 = T.alloc_barrier(arrive_count=384)
bar_09 = T.alloc_barrier(arrive_count=256)
bar_10 = T.alloc_barrier(arrive_count=288)
bar_11 = T.alloc_barrier(arrive_count=256)
bar_12 = T.alloc_barrier(arrive_count=128)
bar_13 = T.alloc_barrier(arrive_count=256)
bar_14 = T.alloc_barrier(arrive_count=256)
bar_15 = T.alloc_barrier(arrive_count=256)
T.annotate_layout(
{
do_shared: tilelang.layout.make_swizzled_layout(do_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
v_shared: tilelang.layout.make_swizzled_layout(v_shared),
a_shared: tilelang.layout.make_swizzled_layout(a_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dqkv_shared: tilelang.layout.make_swizzled_layout(dqkv_shared),
tmp_shared_1_1: tilelang.layout.make_swizzled_layout(
tmp_shared_1_1
),
tmp_shared_1_2: tilelang.layout.make_swizzled_layout(
tmp_shared_1_2
),
tmp_shared_1_3: tilelang.layout.make_swizzled_layout(
tmp_shared_1_3
),
tmp_shared_2_1: tilelang.layout.make_swizzled_layout(
tmp_shared_2_1
),
tmp_shared_2_2: tilelang.layout.make_swizzled_layout(
tmp_shared_2_2
),
tmp_shared_2_3: tilelang.layout.make_swizzled_layout(
tmp_shared_2_3
),
tmp_shared_4_1: tilelang.layout.make_swizzled_layout(
tmp_shared_4_1
),
}
)
# T.use_swizzle(10)
tx = T.get_thread_binding()
PRODUCER_NREG = 24
CONSUMER_K_NREG = 144
CONSUMER_A_NREG = 176
CONSUMER_S_NREG = 160
# Prefetch the last chunk of data
T.copy(
h[batch_idx, chunk_start_idx + num_iters - 1, bh, 0:DK, 0:DV], h_shared
)
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
q_shared[j_s, j_k] = q[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bhg,
j_k,
]
else:
q_shared[j_s, j_k] = 0
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
k_shared[j_s, j_k] = k[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bhg,
j_k,
]
else:
k_shared[j_s, j_k] = 0
for j_s, j_v in T.Parallel(block_S, DV):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
v_shared[j_s, j_v] = v[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_v,
]
else:
v_shared[j_s, j_v] = 0
for j_s, j_t in T.Parallel(block_S, block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
a_shared[j_s, j_t] = a[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_t,
]
else:
a_shared[j_s, j_t] = 0
for j_s, j_v in T.Parallel(block_S, DV):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
do_shared[j_s, j_v] = do[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_v,
]
else:
do_shared[j_s, j_v] = 0
for j_s in T.Parallel(block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
g_shared[j_s] = g[
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
]
else:
g_shared[j_s] = g[batch_idx, seq_end_idx - 1, bh]
for j_s in T.Parallel(block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
b_shared[j_s] = b[
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
]
else:
b_shared[j_s] = 0
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
if use_dht:
T.copy(dht[bb, bh, 0:DK, 0:DV], dh_fragment)
else:
T.clear(dh_fragment)
T.copy(dh_fragment, tmp_shared_4_1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
g_exp_shared[j_s] = T.exp2(g_shared[j_s] * 1.442695)
g_rev_exp_shared[j_s] = T.exp2(
(g_shared[block_S - 1] - g_shared[j_s]) * 1.442695
)
T.barrier_arrive(bar_01)
# 01, 02, 03
T.barrier_wait(bar_01, (i_s + 0) % 2)
g_last_local_3[0] = g_exp_shared[block_S - 1]
# dS0 = g_last * dSt
for j_k, j_v in T.Parallel(DK, DV):
dh_fragment[j_k, j_v] *= g_last_local_3[0]
T.barrier_arrive(bar_04)
# 04, 05, 06, 07
T.barrier_wait(bar_04, (i_s + 0) % 2)
# dg_last += sum(dS0 * S0)
T.clear(reduce_fragment)
for j_k, j_v in T.Parallel(DK, DV):
reduce_fragment[
j_k % 64 // 16 * 32 + j_k % 8 * 4 + j_v % 8 // 2, j_v % 2
] += dh_fragment[j_k, j_v] * h_shared[j_k, j_v]
T.barrier_arrive(bar_08)
T.barrier_wait(bar_08, (i_s + 0) % 2)
T.barrier_wait(bar_09, (i_s + 0) % 2)
# 10
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.reduce_sum(
T.reshape(reduce_fragment, (128 * 2,)),
dg_last_local_3,
dim=0,
clear=True,
)
dg_shared[block_S - 1] += dg_last_local_3[0]
T.barrier_arrive(bar_11)
# 11
T.barrier_wait(bar_11, (i_s + 0) % 2)
# dS0 += K^T @ dVg
T.gemm_v1(
tmp_shared_2_2,
tmp_shared_2_3,
dh_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_12)
T.barrier_wait(bar_12, (i_s + 0) % 2)
# 13
T.barrier_wait(bar_13, (i_s + 0) % 2)
# dOg = s * g * dO
for j_s, j_v in T.Parallel(block_S, DV):
tmp_shared_2_3[j_s, j_v] = (
scale * do_shared[j_s, j_v] * g_exp_shared[j_s]
)
T.barrier_arrive(bar_14)
# 14
T.barrier_wait(bar_14, (i_s + 0) % 2)
# dS0 += Q^T @ dOg
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_2_3,
dh_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_15)
# 15
T.barrier_wait(bar_15, (i_s + 0) % 2)
# S4[1] = dS0
T.copy(dh_fragment, tmp_shared_4_1)
if use_dht:
T.copy(dh_fragment, dh0[bb, bh, 0:DK, 0:DV])
elif tx < 256:
T.set_max_nreg(CONSUMER_K_NREG, 1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 16 == 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
# S2[S] dK
if i_s > 0:
T.copy(dk_fragment, dqkv_shared)
T.barrier_arrive(bar_01)
# 01
T.barrier_wait(bar_01, (i_s + 0) % 2)
# dV' = K @ dSt
T.gemm_v1(k_shared, tmp_shared_4_1, dv_fragment, clear_accum=True)
# dV' = g_last/g * dV'
for j_s, j_v in T.Parallel(block_S, DV):
dv_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
T.barrier_arrive(bar_02)
# 02
T.barrier_wait(bar_02, (i_s + 0) % 2)
# dV' += Pg^T @ dO
T.gemm_v1(
tmp_shared_1_1,
do_shared,
dv_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_03)
# 03
T.barrier_wait(bar_03, (i_s + 0) % 2)
# S2[1] dV'
T.copy(dv_fragment, tmp_shared_2_1)
T.barrier_arrive(bar_04)
# 04
T.barrier_wait(bar_04, (i_s + 0) % 2)
# dV = Ag^T @ dV'
T.gemm_v1(
tmp_shared_1_2,
tmp_shared_2_1,
dv_fragment,
transpose_A=True,
clear_accum=True,
)
# S2[S] dV
T.copy(dv_fragment, dqkv_shared)
T.barrier_arrive(bar_05)
# 05
T.barrier_wait(bar_05, (i_s + 0) % 2)
# dVg = -g * dV
for j_s, j_v in T.Parallel(block_S, DV):
dv_fragment[j_s, j_v] = (
-dv_fragment[j_s, j_v] * g_exp_shared[j_s]
)
# dg += sum(dVg * U)
T.copy(tmp_shared_2_3, odot_fragment_1)
for j_s, j_v in T.Parallel(block_S, DV):
odot_fragment_1[j_s, j_v] *= dv_fragment[j_s, j_v]
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
T.copy(dg_fragment_1, dg_shared)
# S2[3] dVg
T.copy(dv_fragment, tmp_shared_2_3)
T.barrier_arrive(bar_06)
# 06
T.barrier_wait(bar_06, (i_s + 0) % 2)
# S2[2] K
T.copy(k_shared, odot_fragment_1)
T.copy(odot_fragment_1, tmp_shared_2_2)
T.barrier_arrive(bar_07)
# 07
T.barrier_wait(bar_07, (i_s + 0) % 2)
# dK = V' @ dSt^T
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_4_1,
dk_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_08)
# 08
T.barrier_wait(bar_08, (i_s + 0) % 2)
# dK = g_last/g * dK
for j_s, j_k in T.Parallel(block_S, DK):
dk_fragment[j_s, j_k] *= g_rev_exp_shared[j_s]
# dg -= sum(K * dK)
for j_s, j_k in T.Parallel(block_S, DK):
odot_fragment_1[j_s, j_k] *= -dk_fragment[j_s, j_k]
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
for j_s in T.Parallel(block_S):
dg_shared[j_s] += dg_fragment_1[j_s]
# dg_last += sum(K * dK)
T.reduce_sum(dg_fragment_1, dg_last_local_1, dim=0, clear=True)
# Sg[S] dg
dg_shared[block_S - 1] -= dg_last_local_1[0]
T.barrier_arrive(bar_09)
# 09
T.barrier_wait(bar_09, (i_s + 0) % 2)
# dK += dVg @ S0^T
T.gemm_v1(
tmp_shared_2_3,
h_shared,
dk_fragment,
transpose_B=True,
clear_accum=False,
)
T.barrier_arrive(bar_10)
T.barrier_wait(bar_10, (i_s + 0) % 2)
# 12
T.barrier_wait(bar_12, (i_s + 0) % 2)
# dK += dP^T @ Q
T.gemm_v1(
tmp_shared_1_1,
tmp_shared_2_1,
dk_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_13)
T.barrier_wait(bar_13, (i_s + 0) % 2)
# 15
T.barrier_wait(bar_15, (i_s + 0) % 2)
# dK += dAs @ K
T.gemm_v1(
tmp_shared_1_2, tmp_shared_2_2, dk_fragment, clear_accum=False
)
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + j_s < seq_end_idx:
dk[batch_idx, seq_start_idx + j_s, bh, j_k] = dk_fragment[
j_s, j_k
]
elif tx < 384:
T.set_max_nreg(CONSUMER_A_NREG, 1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
# P = Q @ K^T
T.gemm_v1(
q_shared,
k_shared,
p_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_01)
# 01
T.barrier_wait(bar_01, (i_s + 0) % 2)
# G = Lower(diag(g) @ I @ diag(1/g))
for j_s, j_t in T.Parallel(block_S, block_S):
mask_fragment[j_s, j_t] = g_shared[j_s] - g_shared[j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s >= j_t:
mask_fragment[j_s, j_t] = T.exp2(
mask_fragment[j_s, j_t] * 1.442695
)
else:
mask_fragment[j_s, j_t] = 0
# Pg = s * P * G
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= scale
# S1[1] Pg
T.copy(p_fragment, tmp_shared_1_1)
T.barrier_arrive(bar_02)
# 02
T.barrier_wait(bar_02, (i_s + 0) % 2)
# Ab = Ar * b
T.copy(a_shared, a_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[j_t]
# Ag = G * Ab
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
# S1[2] Ag
T.copy(a_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_03)
# 03
T.barrier_wait(bar_03, (i_s + 0) % 2)
# U = K @ S0
T.gemm_v1(k_shared, h_shared, u_fragment, clear_accum=True)
T.barrier_arrive(bar_04)
# 04
T.barrier_wait(bar_04, (i_s + 0) % 2)
# S2[3] U
T.copy(u_fragment, tmp_shared_2_3)
# W = V - g * U
for j_s, j_v in T.Parallel(block_S, DV):
u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
for j_s, j_v in T.Parallel(block_S, DV):
u_fragment[j_s, j_v] += v_shared[j_s, j_v]
# S2[2] W
T.copy(u_fragment, tmp_shared_2_2)
T.barrier_arrive(bar_05)
# 05
T.barrier_wait(bar_05, (i_s + 0) % 2)
# dAg = dV' @ W^T
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_2_2,
da_fragment,
transpose_B=True,
clear_accum=True,
)
# V' = Ag @ W
T.gemm_v1(
tmp_shared_1_2, tmp_shared_2_2, u_fragment, clear_accum=True
)
# S2[1] V'
T.copy(u_fragment, tmp_shared_2_1)
T.barrier_arrive(bar_06)
# 06
T.barrier_wait(bar_06, (i_s + 0) % 2)
# dPg = dO @ V'^T
T.gemm_v1(
do_shared,
tmp_shared_2_1,
dp_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_07)
# 07
T.barrier_wait(bar_07, (i_s + 0) % 2)
# dAb = G * dAg
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
# dg += sum((dPg * P) - (dPg * P)^T)
T.copy(tmp_shared_1_1, p_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= dp_fragment[j_s, j_t]
T.copy(p_fragment, tmp_shared_1_1)
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] -= tmp_shared_1_1[j_t, j_s]
T.reduce_sum(p_fragment, dg_fragment_2, dim=1, clear=True)
# dP = s * G * dPg
for j_s, j_t in T.Parallel(block_S, block_S):
dp_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
dp_fragment[j_s, j_t] *= scale
# S1[1] dP
T.copy(dp_fragment, tmp_shared_1_1)
T.barrier_arrive(bar_08)
# 08
T.barrier_wait(bar_08, (i_s + 0) % 2)
# dQ = dO @ S0^T
T.gemm_v1(
do_shared,
h_shared,
dq_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_09)
# 09
T.barrier_wait(bar_09, (i_s + 0) % 2)
# dQ = s * g * dQ
for j_s, j_k in T.Parallel(block_S, DK):
dq_fragment[j_s, j_k] *= g_exp_shared[j_s]
for j_s, j_k in T.Parallel(block_S, DK):
dq_fragment[j_s, j_k] *= scale
# S2[1] Q
T.copy(q_shared, odot_fragment_2)
# dg += sum(Q * dQ)
T.copy(odot_fragment_2, tmp_shared_2_1)
for j_s, j_k in T.Parallel(block_S, DK):
odot_fragment_2[j_s, j_k] *= dq_fragment[j_s, j_k]
T.reduce_sum(odot_fragment_2, dg_fragment_2, dim=1, clear=False)
T.barrier_arrive(bar_10)
# 10
T.barrier_wait(bar_10, (i_s + 0) % 2)
# dQ += dP @ K
T.gemm_v1(
tmp_shared_1_1, tmp_shared_2_2, dq_fragment, clear_accum=False
)
# S2[S] dQ
T.copy(dq_fragment, dqkv_shared)
T.barrier_arrive(bar_11)
# 11, 12
T.barrier_wait(bar_11, (i_s + 0) % 2)
# dAb * Ar
T.copy(a_shared, a_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
T.copy(a_fragment, tmp_shared_1_3)
# dAb * Ab [ = G * dAg * Ab ]
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[j_t]
# dg += sum((dAb * Ab) - (dAb * Ab)^T)
T.copy(a_fragment, tmp_shared_1_2)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] -= tmp_shared_1_2[j_t, j_s]
T.reduce_sum(a_fragment, dg_fragment_2, dim=1, clear=False)
# Sg[S] dg
for j_s in T.Parallel(block_S):
dg_shared[j_s] += dg_fragment_2[j_s]
# db = sum((dAb * Ar)^T)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] = tmp_shared_1_3[j_t, j_s]
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=True)
# dAr = dAb * b
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= b_shared[j_t]
# S1[2] dAr
T.copy(da_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_13)
# 13
T.barrier_wait(bar_13, (i_s + 0) % 2)
# dA = -Ar^T @ dAr @ Ar^T
T.gemm_v1(
a_shared,
tmp_shared_1_2,
da_fragment,
transpose_A=True,
clear_accum=True,
)
T.copy(da_fragment, tmp_shared_1_2)
T.gemm_v1(
tmp_shared_1_2,
a_shared,
da_fragment,
transpose_B=True,
clear_accum=True,
)
# At = K @ K^T
T.gemm_v1(
tmp_shared_2_2,
tmp_shared_2_2,
a_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_14)
# 14
T.barrier_wait(bar_14, (i_s + 0) % 2)
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s <= j_t:
da_fragment[j_s, j_t] = 0
else:
da_fragment[j_s, j_t] = -da_fragment[j_s, j_t]
# db += sum(dA * At)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=False)
T.copy(db_fragment, db_shared)
# dAt = b * dA
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= b_shared[j_s]
# dAs = dAt + dAt^T
T.copy(da_fragment, tmp_shared_1_2)
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] += tmp_shared_1_2[j_t, j_s]
# S1[1] dAs
T.copy(da_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_15)
T.barrier_wait(bar_15, (i_s + 0) % 2)
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 384 + 32:
for i_s in T.serial(num_iters - 1):
chunk_idx = num_iters - i_s - 2
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
T.barrier_arrive(bar_00)
T.barrier_wait(bar_00, (i_s + 0) % 2)
T.barrier_wait(bar_03, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
g_shared[j_s] = g[batch_idx, left + j_s, bh]
T.barrier_wait(bar_05, (i_s + 0) % 2)
T.copy(v[batch_idx, left:right, bh, 0:DV], v_shared)
T.barrier_wait(bar_07, (i_s + 0) % 2)
T.copy(k[batch_idx, left:right, bhg, 0:DK], k_shared)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.copy(q[batch_idx, left:right, bhg, 0:DK], q_shared)
if num_iters > 0:
T.barrier_arrive(bar_00)
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
left = seq_start_idx + (num_iters - i_s - 1) * block_S
right = left + block_S
T.barrier_arrive(bar_00)
T.barrier_wait(bar_00, (i_s + 0) % 2)
T.barrier_wait(bar_01, (i_s + 0) % 2)
if i_s == 1:
for j_s, j_k in T.Parallel(block_S, DK):
if left + block_S + j_s < seq_end_idx:
dk[batch_idx, left + block_S + j_s, bh, j_k] = (
dqkv_shared[j_s, j_k]
)
elif i_s > 1:
T.copy(
dqkv_shared,
dk[
batch_idx,
left + block_S : right + block_S,
bh,
0:DK,
],
)
T.barrier_arrive(bar_04)
T.barrier_wait(bar_04, (i_s + 0) % 2)
T.barrier_wait(bar_05, (i_s + 0) % 2)
if i_s == 0:
for j_s, j_v in T.Parallel(block_S, DV):
if left + j_s < seq_end_idx:
dv[batch_idx, left + j_s, bh, j_v] = dqkv_shared[
j_s, j_v
]
else:
T.copy(dqkv_shared, dv[batch_idx, left:right, bh, 0:DV])
T.barrier_arrive(bar_10)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.barrier_wait(bar_11, (i_s + 0) % 2)
if i_s == 0:
for j_s, j_k in T.Parallel(block_S, DK):
if left + j_s < seq_end_idx:
dq[batch_idx, left + j_s, bh, j_k] = dqkv_shared[
j_s, j_k
]
else:
T.copy(dqkv_shared, dq[batch_idx, left:right, bh, 0:DK])
elif tx < 384 + 96:
for i_s in T.serial(num_iters - 1):
chunk_idx = num_iters - i_s - 2
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
T.barrier_arrive(bar_02)
T.barrier_wait(bar_02, (i_s + 0) % 2)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.copy(
h[batch_idx, chunk_start_idx + chunk_idx, bh, 0:DK, 0:DV],
h_shared,
)
T.barrier_wait(bar_14, (i_s + 0) % 2)
T.copy(a[batch_idx, left:right, bh, 0:block_S], a_shared)
T.copy(do[batch_idx, left:right, bh, 0:DV], do_shared)
T.barrier_wait(bar_15, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
b_shared[j_s] = b[batch_idx, left + j_s, bh]
if num_iters > 0:
T.barrier_wait(bar_00, (num_iters - 1) % 2)
T.barrier_arrive(bar_02)
else:
for i_s in T.serial(num_iters):
left = seq_start_idx + (num_iters - i_s - 1) * block_S
T.barrier_arrive(bar_05)
T.barrier_wait(bar_05, (i_s + 0) % 2)
T.barrier_wait(bar_15, (i_s + 0) % 2)
if i_s == 0:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
if (seq_end_idx - seq_start_idx) % block_S > 0:
dg[batch_idx, seq_end_idx - 1, bh] += dg_shared[
block_S - 1
]
else:
for j_s in T.Parallel(block_S):
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
if i_s == 0:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
db[batch_idx, left + j_s, bh] = db_shared[j_s]
else:
for j_s in T.Parallel(block_S):
db[batch_idx, left + j_s, bh] = db_shared[j_s]
return tilelang_fused_chunk_gdr_bwd_kernel
def fused_gdr_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
h: torch.Tensor,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H, V = v.shape
scale = scale or K ** (-0.5)
assert K == V == 128
assert chunk_size == 64
if cu_seqlens is None:
real_batch_size = batch_size
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
chunk_offsets = torch.empty(
(batch_size + 1), dtype=torch.int32, device=k.device
)
is_varlen = False
else:
real_batch_size = len(cu_seqlens) - 1
chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size)
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
is_varlen = True
use_dht = dht is not None
if dht is None:
dht = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
dq = torch.empty_like(v)
dk = torch.empty_like(v)
dv = torch.empty_like(v)
dg = torch.empty_like(g)
db = torch.empty_like(b)
dh0 = torch.empty_like(dht)
tilelang_fused_chunk_gdr_bwd_kernel = tilelang_fused_chunk_gdr_bwd(
H,
Hg,
K,
V,
chunk_size,
scale,
qkva_dtype=q.dtype,
g_dtype=g.dtype,
b_dtype=b.dtype,
h_dtype=h.dtype,
o_dtype=do.dtype,
seqlen_dtype=cu_seqlens.dtype,
accum_dtype="float32",
is_varlen=is_varlen,
use_dht=use_dht,
)
tilelang_fused_chunk_gdr_bwd_kernel(
do,
dht,
q,
k,
v,
a,
g,
b,
h,
cu_seqlens,
chunk_offsets,
dq,
dk,
dv,
dg,
db,
dh0,
)
if not use_dht:
dh0 = None
return dq, dk, dv, dg, db, dh0

View File

@@ -0,0 +1,658 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
TARGET_NUM_CTAS = int(MULTI_PROCESSOR_COUNT * 0.7)
@tilelang.jit(
# out_idx=[-3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
# tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
},
)
def tilelang_fused_chunk_gdr_fwd(
H,
Hg,
DK,
DV,
chunk_size,
scale,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h0_dtype,
ht_dtype,
h_dtype,
o_dtype,
seqlen_dtype,
use_initial_state,
store_final_state,
store_h,
store_o,
is_varlen,
is_cp,
block_DV=128,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
raw_batch_size = T.dynamic("raw_batch_size")
block_S = chunk_size
if is_varlen:
q_shape = (1, num_tokens, Hg, DK)
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
o_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
q_shape = (batch_size, num_tokens, Hg, DK)
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
o_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
h0_shape = (batch_size, H, DK, DV)
ht_shape = (raw_batch_size, H, DK, DV)
@T.prim_func
def tilelang_fused_chunk_gdr_fwd_kernel(
q: T.Tensor(q_shape, dtype=qkva_dtype),
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h0: T.Tensor(h0_shape, dtype=h0_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
cp_seq_map: T.Tensor([batch_size], dtype=seqlen_dtype),
raw_cu_seqlens: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
o: T.Tensor(o_shape, dtype=o_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
ht: T.Tensor(ht_shape, dtype=ht_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV) * batch_size * H, threads=512) as (bbhv,):
bbh, bv = bbhv // T.ceildiv(DV, block_DV), bbhv % T.ceildiv(DV, block_DV)
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
seq_split_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
chunk_split_idx = T.alloc_var("int32")
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
raw_batch_idx = T.alloc_var("int32")
raw_seq_end_idx = T.alloc_var("int32")
need_store_final_state = T.alloc_var("bool")
raw_batch_idx = cp_seq_map[bb] if is_cp else bb
raw_seq_end_idx = (
raw_cu_seqlens[raw_batch_idx + 1] if is_cp else seq_end_idx
)
need_store_final_state = store_final_state & (
raw_seq_end_idx == seq_end_idx
)
num_iters = T.alloc_var("int32")
num_unmasked_iters = T.alloc_var("int32")
num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
num_unmasked_iters = (seq_end_idx - seq_start_idx) // block_S
q_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
k_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((2, block_S, block_DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((2, block_S, block_S), dtype=qkva_dtype)
g_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
b_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
o_shared = T.alloc_shared((block_S, block_DV), dtype=o_dtype)
h_shared = T.alloc_shared((DK, block_DV), dtype=qkva_dtype)
vd_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
vn_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
p_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
o_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
v_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
u_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
g_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
g_last_local = T.alloc_local((1), dtype=accum_dtype)
data_is_ready = T.alloc_barrier(arrive_count=[96] * 2)
data_is_free = T.alloc_barrier(arrive_count=[384] * 2)
bar_o = T.alloc_barrier(arrive_count=128)
bar_0 = T.alloc_barrier(arrive_count=416)
bar_1 = T.alloc_barrier(arrive_count=256)
_bar_2 = T.alloc_barrier(arrive_count=128)
bar_3 = T.alloc_barrier(arrive_count=128)
bar_4 = T.alloc_barrier(arrive_count=128)
bar_5 = T.alloc_barrier(arrive_count=416)
T.use_swizzle(10)
tx = T.get_thread_binding()
PRODUCER_NREG = 32
CONSUMER_V_NREG = 128
CONSUMER_S_NREG = 160
CONSUMER_O_NREG = 128
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
# Initialize S
if use_initial_state:
T.copy(
h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
h_fragment,
)
else:
T.clear(h_fragment)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE 0]
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
T.barrier_arrive(bar_0)
# [STAGE 0] 0
T.barrier_wait(bar_0, i_s % 2)
# S4[S] S
T.copy(h_fragment, h_shared)
T.barrier_arrive(bar_1)
# [STAGE 0] 2, 3, 4
T.barrier_wait(bar_1, i_s % 2)
# S = g_last * S
g_last_local[0] = g_exp_shared[block_S - 1]
for j_k, j_v in T.Parallel(DK, block_DV):
h_fragment[j_k, j_v] *= g_last_local[0]
T.barrier_arrive(bar_5)
# [STAGE 0] 5
T.barrier_wait(bar_5, i_s % 2)
# S += K^T @ V'
T.gemm_v1(
k_shared[i_s % 2, :, :],
vn_shared,
h_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % 2])
# Store final S
if need_store_final_state:
T.copy(
h_fragment,
ht[
raw_batch_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
],
)
elif tx < 256:
T.set_max_nreg(CONSUMER_V_NREG, 1)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE 0]
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
T.barrier_arrive(bar_0)
# [STAGE 0] 0
T.barrier_wait(bar_0, i_s % 2)
# Precompute g, g_last/g
for j_s in T.Parallel(block_S):
g_exp_shared[j_s] = T.exp2(g_shared[i_s % 2, j_s] * 1.442695)
for j_s in T.Parallel(block_S):
g_rev_exp_shared[j_s] = T.if_then_else(
seq_start_idx + i_s * block_S + j_s < seq_end_idx,
T.exp2(
(
g_shared[i_s % 2, block_S - 1]
- g_shared[i_s % 2, j_s]
)
* 1.442695
),
0.0,
)
T.barrier_arrive(bar_1)
# [STAGE 0] 1
T.barrier_wait(bar_1, i_s % 2)
# U = K @ S
T.gemm_v1(
k_shared[i_s % 2, :, :], h_shared, u_fragment, clear_accum=True
)
# [STAGE 0] 2
# W = V - g * U
for j_s, j_v in T.Parallel(block_S, block_DV):
u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
for j_s, j_v in T.Parallel(block_S, block_DV):
u_fragment[j_s, j_v] += v_shared[i_s % 2, j_s, j_v]
# S2[V] W
for j_s, j_v in T.Parallel(block_S, block_DV):
v_shared[i_s % 2, j_s, j_v] = u_fragment[j_s, j_v]
# [STAGE 0] 3
T.barrier_wait(bar_3, i_s % 2)
# Vd = Ag @ W
T.gemm_v1(
a_shared[i_s % 2, :, :],
v_shared[i_s % 2, :, :],
v_fragment,
clear_accum=True,
)
# S2[2] Vd
T.copy(v_fragment, vd_shared)
T.barrier_arrive(bar_4)
# [STAGE 0] 4
# V' = g_last/g Vd
for j_s, j_v in T.Parallel(block_S, block_DV):
v_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
# S2[1] V'
T.copy(v_fragment, vn_shared)
T.barrier_arrive(bar_5)
T.barrier_wait(bar_5, i_s % 2)
T.barrier_arrive(data_is_free[i_s % 2])
elif tx < 384:
T.set_max_nreg(CONSUMER_O_NREG, 1)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE 0]
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
T.barrier_arrive(bar_0)
# [STAGE 0] 0
T.barrier_wait(bar_0, i_s % 2)
# P = Q K^T
T.gemm_v1(
q_shared[i_s % 2, :, :],
k_shared[i_s % 2, :, :],
p_fragment,
transpose_B=True,
clear_accum=True,
)
# [STAGE 0] 1
# G = Lower(diag(g) @ I @ diag(1/g))
for j_s, j_t in T.Parallel(block_S, block_S):
g_fragment[j_s, j_t] = (
g_shared[i_s % 2, j_s] - g_shared[i_s % 2, j_t]
)
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s >= j_t:
g_fragment[j_s, j_t] = T.exp2(
g_fragment[j_s, j_t] * 1.442695
)
else:
g_fragment[j_s, j_t] = 0
# Ag = G * Ar * b
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] = a_shared[i_s % 2, j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= g_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[i_s % 2, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
a_shared[i_s % 2, j_s, j_t] = a_fragment[j_s, j_t]
# [STAGE 0] 2
T.barrier_wait(bar_1, i_s % 2)
# O = Q @ S
T.gemm_v1(
q_shared[i_s % 2, :, :], h_shared, o_fragment, clear_accum=True
)
# [STAGE 0] 3
# Pg = s * G * P
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= scale * g_fragment[j_s, j_t]
# S1[1] Pg
T.copy(p_fragment, p_shared)
T.barrier_arrive(bar_3)
# O = s * g * O
for j_s, j_k in T.Parallel(block_S, DK):
o_fragment[j_s, j_k] *= scale * g_exp_shared[j_s]
# [STAGE 0] 4
T.barrier_wait(bar_4, i_s % 2)
# O += Pg @ Vd
T.gemm_v1(p_shared, vd_shared, o_fragment, clear_accum=False)
T.barrier_arrive(bar_5)
# [STAGE 0] 5
T.barrier_wait(bar_5, i_s % 2)
# S2[S] O
T.copy(o_fragment, o_shared)
T.barrier_arrive(data_is_free[i_s % 2])
T.barrier_arrive(bar_o)
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 384 + 32:
for i_s in T.serial(num_iters):
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load Q
T.copy(
q[batch_idx, left:right, bhg, 0:DK], q_shared[i_s % 2, :, :]
)
# Load K
T.copy(
k[batch_idx, left:right, bhg, 0:DK], k_shared[i_s % 2, :, :]
)
T.barrier_arrive(data_is_ready[i_s % 2])
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load V
T.copy(
v[
batch_idx,
left:right,
bh,
bv * block_DV : (bv + 1) * block_DV,
],
v_shared[i_s % 2, :, :],
)
# Load beta
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
b_shared[i_s % 2, j_s] = b[batch_idx, left + j_s, bh]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
b_shared[i_s % 2, j_s] = b[
batch_idx, left + j_s, bh
]
else:
b_shared[i_s % 2, j_s] = 0
T.barrier_arrive(data_is_ready[i_s % 2])
elif tx < 384 + 96:
for i_s in T.serial(num_iters):
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load A
T.copy(
a[batch_idx, left:right, bh, 0:block_S],
a_shared[i_s % 2, :, :],
)
# Load gamma
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
g_shared[i_s % 2, j_s] = g[batch_idx, left + j_s, bh]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
g_shared[i_s % 2, j_s] = g[
batch_idx, left + j_s, bh
]
else:
g_shared[i_s % 2, j_s] = g[
batch_idx, seq_end_idx - 1, bh
]
T.barrier_arrive(data_is_ready[i_s % 2])
else:
for i_s in T.serial(num_unmasked_iters):
right = seq_start_idx + i_s * block_S
left = right - block_S
T.barrier_arrive(bar_0)
T.barrier_wait(bar_0, i_s % 2)
# Store O
if i_s > 0 and store_o:
T.copy(
o_shared,
o[
batch_idx,
left:right,
bh,
bv * block_DV : (bv + 1) * block_DV,
],
)
T.barrier_arrive(bar_5)
T.barrier_wait(bar_1, i_s % 2)
# Store S
if store_h:
T.copy(
h_shared,
h[
batch_idx,
chunk_start_idx + i_s,
bh,
0:DK,
bv * block_DV : (bv + 1) * block_DV,
],
)
if num_unmasked_iters < num_iters:
seq_split_idx = seq_start_idx + num_unmasked_iters * block_S
chunk_split_idx = chunk_start_idx + num_unmasked_iters
T.barrier_arrive(bar_0)
T.barrier_wait(bar_0, num_unmasked_iters % 2)
# Store O
if num_unmasked_iters > 0 and store_o:
T.copy(
o_shared,
o[
batch_idx,
seq_split_idx - block_S : seq_split_idx,
bh,
bv * block_DV : (bv + 1) * block_DV,
],
)
T.barrier_arrive(bar_5)
T.barrier_wait(bar_1, num_unmasked_iters % 2)
# Store S
if store_h:
T.copy(
h_shared,
h[
batch_idx,
chunk_split_idx,
bh,
0:DK,
bv * block_DV : (bv + 1) * block_DV,
],
)
seq_split_idx = seq_start_idx + (num_iters - 1) * block_S
# Store O
T.barrier_wait(bar_o, 0)
if store_o:
for j_s, j_v in T.Parallel(block_S, block_DV):
with T.If(seq_split_idx + j_s < seq_end_idx):
with T.Then():
o[
batch_idx,
seq_split_idx + j_s,
bh,
bv * block_DV + j_v,
] = o_shared[j_s, j_v]
return tilelang_fused_chunk_gdr_fwd_kernel
def fused_gdr_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = True,
output_h: bool = False,
output_o: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cp_seq_map: torch.LongTensor | None = None,
raw_cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H, V = v.shape
scale = scale or K ** (-0.5)
assert K == V == 128
assert chunk_size == 64
if cu_seqlens is None:
real_batch_size = batch_size
num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
chunk_offsets = torch.empty(
(batch_size + 1), dtype=torch.int32, device=k.device
)
seqlen_dtype = torch.int32
is_varlen = False
else:
real_batch_size = len(cu_seqlens) - 1
chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
num_chunks = num_chunks if output_h else 0
seqlen_dtype = cu_seqlens.dtype
is_varlen = True
if cp_seq_map is None:
cp_seq_map = torch.empty(
(real_batch_size,), dtype=seqlen_dtype, device=k.device
)
is_cp = False
else:
is_cp = True
use_initial_state = initial_state is not None
if initial_state is None:
initial_state = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device)
if raw_cu_seqlens is None:
raw_cu_seqlens = torch.empty(
(real_batch_size + 1,), dtype=seqlen_dtype, device=k.device
)
final_state = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
else:
final_state = torch.empty(
(raw_cu_seqlens.shape[0] - 1, H, K, V), dtype=torch.float32, device=k.device
)
o = torch.empty_like(v)
grid_size = real_batch_size * H
if grid_size >= TARGET_NUM_CTAS:
block_DV = 128
elif grid_size * 2 >= TARGET_NUM_CTAS:
block_DV = 64
else:
block_DV = 32
tilelang_fused_chunk_gdr_fwd_kernel = tilelang_fused_chunk_gdr_fwd(
H,
Hg,
K,
V,
chunk_size,
scale,
qkva_dtype=q.dtype,
g_dtype=g.dtype,
b_dtype=b.dtype,
h0_dtype=initial_state.dtype,
ht_dtype=final_state.dtype,
h_dtype=h.dtype,
o_dtype=o.dtype,
seqlen_dtype=seqlen_dtype,
accum_dtype="float32",
use_initial_state=use_initial_state,
store_final_state=output_final_state,
store_h=output_h,
store_o=output_o,
is_varlen=is_varlen,
is_cp=is_cp,
block_DV=block_DV,
)
tilelang_fused_chunk_gdr_fwd_kernel(
q,
k,
v,
a,
g,
b,
initial_state,
cu_seqlens,
chunk_offsets,
cp_seq_map,
raw_cu_seqlens,
o,
h,
final_state,
)
if not output_final_state:
final_state = None
if not output_h:
h = None
if not output_o:
o = None
return o, h, final_state

View File

@@ -0,0 +1,345 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from typing import Optional
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_indices
@tilelang.jit(
# out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
# tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
# tilelang.PassConfigKey.TL_ENABLE_ASYNC_COPY: True,
},
)
def tilelang_kkt_solve(
H,
Hg,
DK,
chunk_size,
accum_dtype,
qkva_dtype,
b_dtype,
seqlen_dtype,
is_varlen,
):
data_batch_size = T.dynamic("data_batch_size")
real_batch_size = T.dynamic("real_batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
k_shape = (data_batch_size, num_tokens, Hg, DK)
a_shape = (data_batch_size, num_tokens, H, chunk_size)
b_shape = (data_batch_size, num_tokens, H)
@T.macro
def kernel_body(
bb,
bc,
bh,
bhg,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
k,
b,
a,
):
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
a64_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
a16i_row = T.alloc_fragment((4, 16), dtype=accum_dtype)
a16i_sum = T.alloc_fragment((4, 16), dtype=accum_dtype)
a16i_shared = T.alloc_shared((4, 17, 16), dtype=accum_dtype)
a16o_shared = T.alloc_shared((2, 17, 16), dtype=accum_dtype)
a16o_fragment = T.alloc_fragment((2, 16, 16), dtype=accum_dtype)
a32i_fragment = T.alloc_fragment((2, 32, 32), dtype=accum_dtype)
a32i0_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
a32i1_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
a32o_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
a32o_fragment = T.alloc_fragment((32, 32), dtype=accum_dtype)
a64_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
T.annotate_layout(
{
a16i_shared: tilelang.layout.make_linear_layout(a16i_shared),
a16o_shared: tilelang.layout.make_linear_layout(a16o_shared),
}
)
k_is_ready = T.alloc_barrier(arrive_count=32)
a_is_ready = T.alloc_barrier(arrive_count=128)
tx = T.get_thread_binding()
PRODUCER_NREG = 24
CONSUMER_NREG = 64
if tx < 128:
T.set_max_nreg(CONSUMER_NREG, 1)
# Load b
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
b_shared[j_s] = b[bb, left + j_s, bh]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
b_shared[j_s] = b[bb, left + j_s, bh]
else:
b_shared[j_s] = 0
T.barrier_wait(k_is_ready, 0)
# A = K @ K^T
T.gemm_v1(
k_shared, k_shared, a64_fragment, transpose_B=True, clear_accum=True
)
# A = b * A
for j_s, j_t in T.Parallel(block_S, block_S):
a64_fragment[j_s, j_t] *= b_shared[j_s]
# A = I + StrictLower(A)
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s < j_t:
a64_fragment[j_s, j_t] = 0
elif j_s == j_t:
a64_fragment[j_s, j_t] = 1
# Prepare inversion input
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s >= 32 and j_t < 32:
a32o_shared[j_s - 32, j_t] = -a64_fragment[j_s, j_t]
elif (j_s // 16) == (j_t // 16) + 1:
a16o_shared[j_s // 32, j_s % 16, j_t % 16] = -a64_fragment[j_s, j_t]
elif (j_s // 16) == (j_t // 16):
a16i_shared[j_s // 16, j_s % 16, j_t % 16] = a64_fragment[j_s, j_t]
# Diagonal 4x16x16
T.clear(a16i_row)
for k_s in T.unroll(1, 16):
for j_s, k_t in T.Parallel(4, 16):
if k_t < k_s:
a16i_row[j_s, k_t] = a16i_shared[j_s, k_s, k_t]
T.clear(a16i_sum)
for k_r in T.unroll(k_s):
for j_s, k_t in T.Parallel(4, 16):
a16i_sum[j_s, k_t] -= (
a16i_shared[j_s, k_r, k_t] * a16i_row[j_s, k_r]
)
for j_s, k_t in T.Parallel(4, 16):
if k_t < k_s:
a16i_shared[j_s, k_s, k_t] = a16i_sum[j_s, k_t]
# First level 2x16x16
T.clear(a16o_fragment)
for k_r in T.unroll(16):
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
a16o_fragment[j_s, k_s, k_t] += (
a16i_shared[j_s * 2 + 1, k_s, k_r] * a16o_shared[j_s, k_r, k_t]
)
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
a16o_shared[j_s, k_t, k_s] = a16o_fragment[j_s, k_s, k_t]
T.clear(a16o_fragment)
for k_r in T.unroll(16):
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
a16o_fragment[j_s, k_s, k_t] += (
a16o_shared[j_s, k_r, k_s] * a16i_shared[j_s * 2, k_r, k_t]
)
T.copy(a16o_fragment, a16o_shared[:, 0:16, 0:16])
# Second level 1x32x32
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if k_s < 16 and k_t >= 16:
a32i_fragment[j_s, k_s, k_t] = 0
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if k_s >= 16 and k_t < 16:
a32i_fragment[j_s, k_s, k_t] = a16o_shared[j_s, k_s - 16, k_t]
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if k_s // 16 == k_t // 16:
a32i_fragment[j_s, k_s, k_t] = a16i_shared[
j_s * 2 + k_s // 16, k_s % 16, k_t % 16
]
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
if j_s == 0:
a32i0_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
else:
a32i1_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
T.gemm_v1(a32i1_shared, a32o_shared, a32o_fragment, clear_accum=True)
T.copy(a32o_fragment, a32o_shared)
T.gemm_v1(a32o_shared, a32i0_shared, a32o_fragment, clear_accum=True)
# Combine inversion output
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
a64_shared[j_s * 32 + k_s, j_s * 32 + k_t] = a32i_fragment[
j_s, k_s, k_t
]
for k_s, k_t in T.Parallel(32, 32):
a64_shared[32 + k_s, k_t] = a32o_fragment[k_s, k_t]
for k_s, k_t in T.Parallel(32, 32):
a64_shared[k_s, 32 + k_t] = 0
T.barrier_arrive(a_is_ready)
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 128 + 32:
# Load K
T.copy(k[bb, left:right, bhg, 0:DK], k_shared)
T.barrier_arrive(k_is_ready)
elif tx < 128 + 64:
T.barrier_wait(a_is_ready, 0)
# Save A (unmasked)
if right <= seq_end_idx:
T.copy(a64_shared, a[bb, left:right, bh, 0:block_S])
else:
T.barrier_wait(a_is_ready, 0)
# Save A (masked)
if right > seq_end_idx:
for j_s, j_t in T.Parallel(block_S, block_S):
if left + j_s < seq_end_idx:
a[bb, left + j_s, bh, j_t] = a64_shared[j_s, j_t]
if is_varlen:
@T.prim_func
def tilelang_kkt_solve_kernel(
k: T.Tensor(k_shape, dtype=qkva_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
):
with T.Kernel(num_chunks * H, threads=256) as (bch,):
bc, bh = bch // H, bch % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
chunk_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
bb = 0
batch_idx = chunk_indices[bc, 0]
chunk_idx = chunk_indices[bc, 1]
seq_start_idx = cu_seqlens[batch_idx]
seq_end_idx = cu_seqlens[batch_idx + 1]
kernel_body(
bb,
bc,
bh,
bhg,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
k,
b,
a,
)
else:
@T.prim_func
def tilelang_kkt_solve_kernel(
k: T.Tensor(k_shape, dtype=qkva_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
num_chunks: T.int32,
):
with T.Kernel(num_chunks * H, threads=256) as (bch,):
bc, bh = bch // H, bch % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
chunk_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
bb = bc % data_batch_size
batch_idx = bb
chunk_idx = bc // data_batch_size
seq_start_idx = 0
seq_end_idx = num_tokens
kernel_body(
bb,
bc,
bh,
bhg,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
k,
b,
a,
)
return tilelang_kkt_solve_kernel
def kkt_solve(
k: torch.Tensor,
b: torch.Tensor,
chunk_size: int = 64,
cu_seqlens: Optional[torch.LongTensor] = None,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H = b.shape
assert K == 128
assert chunk_size == 64
if cu_seqlens is None:
num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
seqlen_dtype = "int32"
is_varlen = False
else:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
seqlen_dtype = cu_seqlens.dtype
is_varlen = True
a = torch.empty(
(batch_size, num_tokens, H, chunk_size), dtype=k.dtype, device=k.device
)
tilelang_kkt_solve_kernel = tilelang_kkt_solve(
H,
Hg,
K,
chunk_size,
qkva_dtype=k.dtype,
b_dtype=b.dtype,
seqlen_dtype=seqlen_dtype,
accum_dtype="float32",
is_varlen=is_varlen,
)
if is_varlen:
tilelang_kkt_solve_kernel(k, b, cu_seqlens, chunk_indices, a)
else:
tilelang_kkt_solve_kernel(k, b, a, num_chunks)
return a

View File

@@ -0,0 +1,558 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def tilelang_prepare_h(
H,
Hg,
DK,
DV,
chunk_size,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h0_dtype,
ht_dtype,
h_dtype,
seqlen_dtype,
use_initial_state,
store_final_state,
store_h,
is_varlen,
is_cp,
num_stages=2,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
if is_varlen:
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
h0_shape = (batch_size, H, DK, DV)
ht_shape = (batch_size, H, DK, DV)
m_shape = (batch_size, H, DK, DK)
@T.prim_func
def tilelang_prepare_h_kernel(
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h0: T.Tensor(h0_shape, dtype=h0_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
num_warmup_chunks: T.Tensor([batch_size, H], dtype=seqlen_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
ht: T.Tensor(ht_shape, dtype=ht_dtype),
mt: T.Tensor(m_shape, dtype=ht_dtype),
):
with T.Kernel(batch_size * H, threads=512) as (bbh,):
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
_seq_split_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
_chunk_split_idx = T.alloc_var("int32")
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
num_iters = T.alloc_var("int32")
num_iters = (
num_warmup_chunks[bb, bh]
if is_cp
else T.ceildiv(seq_end_idx - seq_start_idx, block_S)
)
calc_mt = T.alloc_var("bool")
calc_mt = is_cp and num_iters >= T.ceildiv(
seq_end_idx - seq_start_idx, block_S
)
seq_start_idx = (
seq_end_idx - num_iters * block_S if is_cp else seq_start_idx
)
k_shared = T.alloc_shared((num_stages, block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((num_stages, block_S, DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((num_stages, block_S, block_S), dtype=qkva_dtype)
g_shared = T.alloc_shared(
(num_stages, block_S), dtype=accum_dtype, scope="shared"
)
b_shared = T.alloc_shared(
(num_stages, block_S), dtype=accum_dtype, scope="shared"
)
h_shared = T.alloc_shared((DK, DV), dtype=qkva_dtype)
x_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
y_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
m_shared_L = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
m_shared_R = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
z_shared_L = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
z_shared_R = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
h_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
x_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
y_fragment = T.alloc_fragment((block_S, DV), dtype=accum_dtype)
m_fragment_L = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
m_fragment_R = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
z_fragment_L = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
z_fragment_R = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
g_last_local_S = T.alloc_local((1), dtype=accum_dtype)
g_last_local_X = T.alloc_local((1), dtype=accum_dtype)
g_last_local_Y = T.alloc_local((1), dtype=accum_dtype)
g_prod_X = T.alloc_fragment((1), dtype=accum_dtype)
g_prod_Y = T.alloc_fragment((1), dtype=accum_dtype)
data_is_ready = T.alloc_barrier(arrive_count=[96] * num_stages)
data_is_free = T.alloc_barrier(arrive_count=[384] * num_stages)
bar_0 = T.alloc_barrier(arrive_count=416)
bar_1 = T.alloc_barrier(arrive_count=256)
bar_2 = T.alloc_barrier(arrive_count=384)
bar_3 = T.alloc_barrier(arrive_count=128)
T.use_swizzle(10)
tx = T.get_thread_binding()
PRODUCER_NREG = 24
CONSUMER_S_NREG = 168
CONSUMER_X_NREG = 160
CONSUMER_Y_NREG = 160
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
# Initialize S
if use_initial_state:
T.copy(h0[bb, bh, 0:DK, 0:DV], h_fragment)
else:
T.clear(h_fragment)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# S4[1] S
T.copy(h_fragment, h_shared)
T.barrier_arrive(bar_1)
# [STAGE = i_s % num_stages] 1
T.barrier_wait(bar_1, i_s % 2)
# S = g_last * S
g_last_local_S[0] = T.exp2(
g_shared[i_s % num_stages, block_S - 1] * 1.442695
)
for j_k, j_v in T.Parallel(DK, DV):
h_fragment[j_k, j_v] *= g_last_local_S[0]
T.barrier_arrive(bar_2)
# [STAGE = i_s % num_stages] 2
T.barrier_wait(bar_2, i_s % 2)
# S += X^T @ Y
T.gemm_v1(
x_shared,
y_shared,
h_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_3)
T.barrier_arrive(data_is_free[i_s % num_stages])
# Store final S
if store_final_state:
T.copy(h_fragment, ht[bb, bh, 0:DK, 0:DV])
elif tx < 256:
T.set_max_nreg(CONSUMER_X_NREG, 1)
if calc_mt:
for j_k, j_v in T.Parallel(DK, DK // 2):
if j_k == j_v + DK // 2:
m_fragment_R[j_k, j_v] = 1
else:
m_fragment_R[j_k, j_v] = 0
g_prod_X[0] = 0
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# X = A^T @ K
T.gemm_v1(
a_shared[i_s % num_stages, :, :],
k_shared[i_s % num_stages, :, :],
x_fragment,
transpose_A=True,
clear_accum=True,
)
# [STAGE = i_s % num_stages] 1
# X = - b * X
for j_s, j_k in T.Parallel(block_S, DK):
x_fragment[j_s, j_k] *= -b_shared[i_s % num_stages, j_s]
# S2[1] X
T.copy(x_fragment, x_shared)
T.barrier_arrive(bar_2)
if calc_mt:
# [STAGE = i_s % num_stages] 2
g_prod_X[0] += g_shared[i_s % num_stages, block_S - 1]
# S4[2] M
T.copy(m_fragment_R, m_shared_R)
# [STAGE = i_s % num_stages] 3
T.barrier_wait(bar_3, i_s % 2)
# Z = K @ M
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
m_shared_R,
z_fragment_R,
clear_accum=True,
)
# S4[2] Z
T.copy(z_fragment_R, z_shared_R)
# M += X^T @ Z
T.gemm_v1(
x_shared,
z_shared_R,
m_fragment_R,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % num_stages])
if calc_mt:
g_last_local_X[0] = T.exp2(g_prod_X[0] * 1.442695)
for j_k, j_v in T.Parallel(DK, DK // 2):
m_fragment_R[j_k, j_v] *= g_last_local_X[0]
T.copy(m_fragment_R, mt[bb, bh, 0:DK, DK // 2 :])
elif tx < 384:
T.set_max_nreg(CONSUMER_Y_NREG, 1)
if calc_mt:
for j_k, j_v in T.Parallel(DK, DK // 2):
if j_k == j_v:
m_fragment_L[j_k, j_v] = 1
else:
m_fragment_L[j_k, j_v] = 0
g_prod_Y[0] = 0
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# Precompute g_last/g
g_last_local_Y[0] = g_shared[i_s % num_stages, block_S - 1]
for j_s in T.Parallel(block_S):
g_rev_exp_shared[j_s] = T.exp2(
(g_last_local_Y[0] - g_shared[i_s % num_stages, j_s])
* 1.442695
)
g_last_local_Y[0] = T.exp2(g_last_local_Y[0] * 1.442695)
T.barrier_arrive(bar_1)
# [STAGE = i_s % num_stages] 1
T.barrier_wait(bar_1, i_s % 2)
# U = K @ S
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
h_shared,
y_fragment,
clear_accum=True,
)
# Y = g_last * U - g_last/g * V
for j_s, j_v in T.Parallel(block_S, DV):
y_fragment[j_s, j_v] *= g_last_local_Y[0]
for j_s, j_v in T.Parallel(block_S, DV):
y_fragment[j_s, j_v] -= (
v_shared[i_s % num_stages, j_s, j_v] * g_rev_exp_shared[j_s]
)
# S2[2] Y
T.copy(y_fragment, y_shared)
T.barrier_arrive(bar_2)
if calc_mt:
# [STAGE = i_s % num_stages] 2
g_prod_Y[0] += g_shared[i_s % num_stages, block_S - 1]
# S4[2] M
T.copy(m_fragment_L, m_shared_L)
# [STAGE = i_s % num_stages] 3
T.barrier_wait(bar_3, i_s % 2)
# Z = K @ M
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
m_shared_L,
z_fragment_L,
clear_accum=True,
)
# S4[2] Z
T.copy(z_fragment_L, z_shared_L)
# M += X^T @ Z
T.gemm_v1(
x_shared,
z_shared_L,
m_fragment_L,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % num_stages])
if calc_mt:
g_last_local_Y[0] = T.exp2(g_prod_Y[0] * 1.442695)
for j_k, j_v in T.Parallel(DK, DK // 2):
m_fragment_L[j_k, j_v] *= g_last_local_Y[0]
T.copy(m_fragment_L, mt[bb, bh, 0:DK, : DK // 2])
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 384 + 32:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load K
T.copy(
k[batch_idx, left:right, bhg, 0:DK],
k_shared[i_s % num_stages, :, :],
)
T.barrier_arrive(data_is_ready[i_s % num_stages])
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load V
T.copy(
v[batch_idx, left:right, bh, 0:DV],
v_shared[i_s % num_stages, :, :],
)
# Load A TODO: Mask A for the last chunk
T.copy(
a[batch_idx, left:right, bh, 0:block_S],
a_shared[i_s % num_stages, :, :],
)
T.barrier_arrive(data_is_ready[i_s % num_stages])
elif tx < 384 + 96:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load gamma
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
g_shared[i_s % num_stages, j_s] = g[
batch_idx, left + j_s, bh
]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
g_shared[i_s % num_stages, j_s] = g[
batch_idx, left + j_s, bh
]
else:
g_shared[i_s % num_stages, j_s] = g[
batch_idx, seq_end_idx - 1, bh
]
# Load beta
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
b_shared[i_s % num_stages, j_s] = b[
batch_idx, left + j_s, bh
]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
b_shared[i_s % num_stages, j_s] = b[
batch_idx, left + j_s, bh
]
else:
b_shared[i_s % num_stages, j_s] = 0
T.barrier_arrive(data_is_ready[i_s % num_stages])
else:
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_0)
T.barrier_wait(bar_0, i_s % 2)
T.barrier_wait(bar_1, i_s % 2)
# Store S
if store_h:
T.copy(
h_shared,
h[batch_idx, chunk_start_idx + i_s, bh, 0:DK, 0:DV],
)
return tilelang_prepare_h_kernel
def fused_gdr_h(
k: torch.Tensor,
v: torch.Tensor,
a: torch.Tensor,
g: torch.Tensor,
b: torch.Tensor,
initial_state: torch.Tensor | None = None,
output_final_state: bool = True,
output_h: bool = True,
chunk_size: int = 64,
cu_seqlens: torch.LongTensor | None = None,
num_warmup_chunks: torch.LongTensor | None = None,
):
batch_size, num_tokens, Hg, K = k.shape
_, _, H, V = v.shape
assert K == V == 128
assert chunk_size == 64
if cu_seqlens is None:
assert num_warmup_chunks is None
real_batch_size = batch_size
num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
chunk_offsets = torch.empty(
(batch_size + 1), dtype=torch.int32, device=k.device
)
is_varlen = False
is_cp = False
else:
real_batch_size = len(cu_seqlens) - 1
chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
num_chunks = num_chunks if output_h else 0
is_varlen = True
if num_warmup_chunks is None:
num_warmup_chunks = torch.empty(
(real_batch_size, H), dtype=cu_seqlens.dtype, device=k.device
)
is_cp = False
else:
is_cp = True
use_initial_state = initial_state is not None
if initial_state is None:
initial_state = torch.empty(
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
)
h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device)
ht_dtype = k.dtype if is_cp else torch.float32
final_state = torch.empty(
(real_batch_size, H, K, V), dtype=ht_dtype, device=k.device
)
final_correction = torch.empty(
(real_batch_size, H, K, K), dtype=ht_dtype, device=k.device
)
tilelang_prepare_h_kernel = tilelang_prepare_h(
H,
Hg,
K,
V,
chunk_size,
qkva_dtype=k.dtype,
g_dtype=g.dtype,
b_dtype=b.dtype,
h0_dtype=initial_state.dtype,
ht_dtype=final_state.dtype,
h_dtype=h.dtype,
seqlen_dtype=cu_seqlens.dtype,
accum_dtype="float32",
use_initial_state=use_initial_state,
store_final_state=output_final_state,
store_h=output_h,
is_varlen=is_varlen,
is_cp=is_cp,
)
tilelang_prepare_h_kernel(
k,
v,
a,
g,
b,
initial_state,
cu_seqlens,
chunk_offsets,
num_warmup_chunks,
h,
final_state,
final_correction,
)
if not output_final_state:
final_state = None
final_correction = None
if not output_h:
h = None
return h, final_state, final_correction