first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

16
flash_qla/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
__version__ = "0.1.0"
from flash_qla.ops.gated_delta_rule.chunk import (
chunk_gated_delta_rule_fwd,
chunk_gated_delta_rule_bwd,
chunk_gated_delta_rule,
)
__all__ = [
"chunk_gated_delta_rule_fwd",
"chunk_gated_delta_rule_bwd",
"chunk_gated_delta_rule",
]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .gated_delta_rule import chunk_gated_delta_rule
__all__ = ["chunk_gated_delta_rule"]

View File

@@ -0,0 +1,7 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .chunk import chunk_gated_delta_rule
__all__ = ["chunk_gated_delta_rule"]

View File

@@ -0,0 +1,237 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
from flash_qla.utils import l2norm
from flash_qla.ops.utils import chunk_local_cumsum, group_reduce_vector
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
from .hopper import fused_gdr_fwd, fused_gdr_bwd, fused_gdr_h, kkt_solve
else:
raise ValueError("FlashQLA now support sm90 only.")
from .cp_context import intra_card_cp_preprocess
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    output_final_state: bool = True,
    output_h: bool = False,
    auto_cp: bool = True,
):
    """Run the forward pass of the chunked gated delta rule.

    Steps, in order:

    1. ``g`` is passed through ``chunk_local_cumsum`` with a chunk size of 64.
    2. ``kkt_solve`` builds ``A`` from ``k`` and ``beta``.
    3. When ``auto_cp`` is set, ``intra_card_cp_preprocess`` may replace the
       initial state and ``cu_seqlens`` with a context-parallel split and
       supplies the mapping needed by the fused kernel.
    4. ``fused_gdr_fwd`` produces the output, the optional per-chunk states
       and the optional final state.

    Args:
        q, k, v: Query / key / value tensors forwarded to the fused kernel.
        g: Gate tensor (before the chunk-local cumulative sum).
        beta: Per-token ``beta`` tensor.
        scale: Scale forwarded unchanged to ``fused_gdr_fwd``.
        initial_state: Optional initial recurrent state.
        cu_seqlens: Optional cumulative sequence lengths for packed input.
        output_final_state: Forwarded to ``fused_gdr_fwd``.
        output_h: Forwarded to ``fused_gdr_fwd``.
        auto_cp: Whether to attempt the intra-card context-parallel split.

    Returns:
        ``(g, A, o, h, final_state)`` where ``g`` is the cumulative-summed gate
        and ``A`` the result of ``kkt_solve``; both are what the backward pass
        expects to receive.
    """
    g_cumsum = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    A = kkt_solve(k=k, b=beta, cu_seqlens=cu_seqlens)

    # Without auto-CP the fused kernel receives no segment mapping at all.
    cp_seq_map = raw_cu_seqlens = None
    if auto_cp:
        cp_plan = intra_card_cp_preprocess(
            k=k,
            v=v,
            a=A,
            g=g_cumsum,
            b=beta,
            raw_h0=initial_state,
            raw_cu_seqlens=cu_seqlens,
        )
        # The preprocess step hands back the (possibly split) state/seqlens
        # that must be used for the fused call below.
        initial_state, cu_seqlens, cp_seq_map, raw_cu_seqlens = cp_plan

    o, h, final_state = fused_gdr_fwd(
        q=q,
        k=k,
        v=v,
        a=A,
        g=g_cumsum,
        b=beta,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        output_h=output_h,
        output_o=True,
        cu_seqlens=cu_seqlens,
        cp_seq_map=cp_seq_map,
        raw_cu_seqlens=raw_cu_seqlens,
    )
    return g_cumsum, A, o, h, final_state
def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor | None = None,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Run the backward pass of the chunked gated delta rule.

    The per-chunk states ``h`` are recomputed with ``fused_gdr_h`` and then fed,
    together with the upstream gradients, to ``fused_gdr_bwd``.

    Args:
        q, k, v: Forward inputs.
        g: Gate as returned by the forward function (already cumulative-summed).
        beta: Per-token ``beta`` tensor.
        A: Matrix returned by the forward function.
        do: Gradient with respect to the forward output.
        dht: Optional gradient with respect to the final state.
        scale: Scale used in the forward pass.
        initial_state: Optional initial recurrent state used in the forward pass.
        cu_seqlens: Optional cumulative sequence lengths for packed input.

    Returns:
        ``(dq, dk, dv, db, dg, dh0)``. Note the order: ``db`` precedes ``dg``.
    """
    # Arguments common to both kernel launches.
    shared = dict(k=k, v=v, a=A, g=g, b=beta, cu_seqlens=cu_seqlens)

    h, _, _ = fused_gdr_h(
        initial_state=initial_state,
        output_final_state=False,
        output_h=True,
        **shared,
    )
    dq, dk, dv, dg, db, dh0 = fused_gdr_bwd(
        q=q,
        do=do,
        dht=dht,
        h=h,
        scale=scale,
        **shared,
    )

    # When there are fewer q/k heads than v heads, fold the per-v-head
    # gradients back onto the q/k heads.
    num_qk_heads = k.shape[-2]
    num_v_heads = v.shape[-2]
    if num_qk_heads < num_v_heads:
        dq = group_reduce_vector(dq, num_qk_heads)
        dk = group_reduce_vector(dk, num_qk_heads)

    assert dg.dtype == torch.float32, "dg should be fp32"
    # Reverse chunk-local cumsum: the counterpart of the forward cumsum on g.
    dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens)
    return dq, dk, dv, db, dg, dh0
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked gated-delta-rule forward/backward.

    ``forward`` takes nine inputs after ``ctx``; ``backward`` therefore returns
    nine gradients, with ``None`` for the non-differentiable ones (``scale``,
    ``output_final_state`` and ``cu_seqlens``).
    """

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float | None = None,
        initial_state: torch.Tensor | None = None,
        output_final_state: bool = False,
        cu_seqlens: torch.LongTensor | None = None,
    ):
        """Compute the output and (optionally) the final state.

        Returns:
            ``(o, final_state)`` with ``o`` cast to ``q.dtype``.
        """
        # `g` is rebound to the cumulative-summed gate, which is what the
        # backward function expects; the per-chunk states are not requested.
        g, A, o, _, final_state = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            output_h=False,
            cu_seqlens=cu_seqlens,
        )
        # The caller's initial_state / cu_seqlens are saved (not any split
        # version the forward helper may have derived internally).
        ctx.save_for_backward(q, k, v, g, beta, A, initial_state, cu_seqlens)
        ctx.scale = scale
        return o.to(q.dtype), final_state

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, do: torch.Tensor, dht: torch.Tensor | None):
        """Compute gradients for ``q, k, v, g, beta`` and ``initial_state``."""
        q, k, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
        dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            A=A,
            do=do,
            dht=dht,
            scale=ctx.scale,
            initial_state=initial_state,
            cu_seqlens=cu_seqlens,
        )
        # Autograd raises if a non-None gradient is returned for a forward
        # input that was not a tensor, so never report dh0 without an
        # initial state.
        if initial_state is None:
            dh0 = None
        # Tensor.to(other) matches the dtype and device of `other`.
        return (
            dq.to(q),
            dk.to(k),
            dv.to(v),
            dg.to(g),
            db.to(beta),
            None,  # scale
            dh0,  # initial_state
            None,  # output_final_state
            None,  # cu_seqlens
        )
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
):
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, (
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16 or float16."
)
assert not head_first, "head_first=True is not supported."
assert v.shape[2] % k.shape[2] == 0, (
"num_qk_heads must be divisible to num_v_heads."
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
if use_qk_l2norm_in_kernel:
q = l2norm(q)
k = l2norm(k)
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
)
return o, final_state

View File

@@ -0,0 +1,163 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import math
import torch
import tilelang
from flash_qla.utils import tensor_cache
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
from .hopper import get_warmup_chunks, fused_gdr_h, correct_initial_states
else:
raise ValueError("FlashQLA now support sm90 only.")
# Multiprocessor count, read once at import time; it is the `P` term of the
# auto-CP latency heuristic in `_calc_cp_seqs`.
# NOTE(review): queried without a device argument, so this presumably reflects
# the current CUDA device at import — confirm for multi-GPU processes.
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
@tensor_cache
def _create_cu_seqlens(
    batch_size: int,
    num_tokens: int,
    device_idx: int,
):
    """Return int32 sequence boundaries for equal-length sequences.

    The result is ``[0, num_tokens, 2 * num_tokens, ..., batch_size * num_tokens]``
    on CUDA device ``device_idx``.
    """
    boundary_ids = torch.arange(
        (batch_size + 1), dtype=torch.int32, device=f"cuda:{device_idx}"
    )
    return boundary_ids * num_tokens
@tensor_cache
def _calc_cp_seqs(
    raw_cu_seqlens: torch.LongTensor,
    chunk_size: int,
    num_v_heads: int,
):
    """Plan the intra-card context-parallel (CP) split of a packed batch.

    Sequences with more than ``max_local_chunks`` chunks are cut into
    segments of ``max_local_chunks * chunk_size`` tokens; shorter sequences
    are kept whole.

    Args:
        raw_cu_seqlens: 1-D tensor of cumulative sequence boundaries.
        chunk_size: Number of tokens per chunk.
        num_v_heads: Number of value heads (``H`` in the heuristic).

    Returns:
        ``(use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask)``:

        * ``use_cp``: whether the split should be used.
        * ``cp_cu_seqlens``: boundaries of the CP segments.
        * ``seq_map_r2c``: for raw sequence ``i``, its segments are
          ``seq_map_r2c[i]:seq_map_r2c[i + 1]``.
        * ``seq_map_c2r``: raw sequence index of every segment.
        * ``ht_mask``: ``True`` for the last segment of each raw sequence.

        When ``use_cp`` is ``False``, ``cp_cu_seqlens``, ``seq_map_r2c`` and
        ``ht_mask`` are ``None`` and ``seq_map_c2r`` is left as a plain list.
    """
    # TODO: tilelang kernel
    device = raw_cu_seqlens.device
    seqlen_dtype = raw_cu_seqlens.dtype
    raw_cu_seqlens = raw_cu_seqlens.tolist()
    num_chunks = [
        tilelang.cdiv(end - start, chunk_size)
        for start, end in zip(raw_cu_seqlens, raw_cu_seqlens[1:])
    ]
    total_chunks = sum(num_chunks)
    if total_chunks == 0:
        # No tokens at all: log2(0) and max()/division below would raise, and
        # there is nothing to split anyway. Nothing is split, so segment i
        # maps to raw sequence i, as on the regular non-CP return path.
        return False, None, None, list(range(len(num_chunks))), None

    # autocp
    H = num_v_heads
    # Latency model: T = a·L_cp + b·(B·H·Lc/P) / L_cp + c
    # Minimizing T gives L_cp* ∝ √(B·H·Lc / P), with P = MULTI_PROCESSOR_COUNT
    # and L_cp = max_local_chunks. The optimum is scaled by an empirical
    # factor of 3 and rounded to the nearest power of two.
    max_local_chunks = 2 ** round(
        math.log2(math.sqrt(H * total_chunks / MULTI_PROCESSOR_COUNT) * 3)
    )
    # Set min to 4 to ensure multi-stage pipelining in fused_gdr;
    max_local_chunks = max(max_local_chunks, 4)

    cp_cu_seqlens = []
    ht_mask = []
    seq_map_c2r = []
    seq_map_r2c = [0]
    max_local_tokens = max_local_chunks * chunk_size
    for i, c in enumerate(num_chunks):
        s = raw_cu_seqlens[i]
        e = raw_cu_seqlens[i + 1]
        if c > max_local_chunks:
            # Long sequence: emit one segment every max_local_tokens tokens.
            while s < e:
                cp_cu_seqlens.append(s)
                ht_mask.append(False)
                seq_map_c2r.append(i)
                s += max_local_tokens
            # Only the last segment ends the raw sequence.
            ht_mask[-1] = True
        else:
            cp_cu_seqlens.append(s)
            ht_mask.append(True)
            seq_map_c2r.append(i)
        seq_map_r2c.append(len(cp_cu_seqlens))
    cp_cu_seqlens.append(raw_cu_seqlens[-1])

    # Disable CP when B * H naturally saturates SM occupancy.
    # For varlen inputs, use `total_chunks / max_seq_chunks` as effective B,
    # since CP helps accelerate highly uneven sequence lengths.
    longest = max(num_chunks)
    Be = total_chunks / longest
    use_cp = Be * H <= 40 or (Be * H <= 56 and longest >= 128)
    if use_cp:
        cp_cu_seqlens = torch.tensor(
            cp_cu_seqlens, dtype=seqlen_dtype, device=device, requires_grad=False
        )
        seq_map_c2r = torch.tensor(seq_map_c2r, dtype=seqlen_dtype, device=device)
        seq_map_r2c = torch.tensor(
            seq_map_r2c, dtype=seqlen_dtype, device=device, requires_grad=False
        )
        ht_mask = torch.tensor(
            ht_mask, dtype=torch.bool, device=device, requires_grad=False
        )
    else:
        cp_cu_seqlens, seq_map_r2c, ht_mask = None, None, None
    return use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask
def intra_card_cp_preprocess(
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    g: torch.Tensor,
    b: torch.Tensor,
    raw_h0: torch.Tensor,
    raw_cu_seqlens: torch.Tensor,
    warmup_threshold: float = -10.0,
):
    """Prepare initial states and boundaries for intra-card context parallelism.

    Args:
        k: 4-D key tensor ``(batch, tokens, heads, dim)``.
        v: 4-D value tensor; dimension 2 is the number of value heads.
        a: Tensor whose last dimension is the chunk size.
        g: Gate tensor passed on to the warm-up and state kernels.
        b: ``beta`` tensor passed on to the state kernel.
        raw_h0: Initial state of the unsplit sequences, or ``None``.
        raw_cu_seqlens: Boundaries of the unsplit sequences, or ``None``.
        warmup_threshold: Threshold handed to ``get_warmup_chunks``.

    Returns:
        ``(h0, cu_seqlens, seq_map_c2r, raw_cu_seqlens)``. When CP is not used
        the last two entries are ``None`` and the first two are the inputs
        (``cu_seqlens`` being the synthesized one if none was given and the
        batch size is 1).
    """
    # Unpacking doubles as a rank check on k and v.
    batch_size, num_tokens, _, _ = k.shape
    _, _, num_v_heads, _ = v.shape
    chunk_size = a.shape[-1]

    # CP is only attempted for a single (possibly packed) batch row.
    if batch_size > 1:
        return raw_h0, raw_cu_seqlens, None, None

    if raw_cu_seqlens is None:
        raw_cu_seqlens = _create_cu_seqlens(batch_size, num_tokens, k.device.index)

    cp_plan = _calc_cp_seqs(raw_cu_seqlens, chunk_size, num_v_heads)
    use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask = cp_plan
    if not use_cp:
        return raw_h0, raw_cu_seqlens, None, None

    # [cp_batch_size, num_v_heads]
    num_warmup_chunks, fallback_mask = get_warmup_chunks(
        g=g,
        cu_seqlens=cp_cu_seqlens,
        ht_mask=ht_mask,
        chunk_size=chunk_size,
        threshold=warmup_threshold,
    )

    # Per-segment states computed from a zero initial state:
    # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
    _, ht, mt = fused_gdr_h(
        k=k,
        v=v,
        a=a,
        g=g,
        b=b,
        initial_state=None,
        output_final_state=True,
        output_h=False,
        cu_seqlens=cp_cu_seqlens,
        num_warmup_chunks=num_warmup_chunks,
    )

    # Chain the per-segment results into a valid initial state per segment.
    cp_h0 = correct_initial_states(
        raw_h0=raw_h0,
        ht_buffer=ht,
        mt_buffer=mt,
        fallback_mask=fallback_mask,
        seq_map_r2c=seq_map_r2c,
    )
    return cp_h0, cp_cu_seqlens, seq_map_c2r, raw_cu_seqlens

View File

@@ -0,0 +1,18 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .fused_fwd import fused_gdr_fwd
from .fused_bwd import fused_gdr_bwd
from .prepare_h import fused_gdr_h
from .kkt_solve import kkt_solve
from .cp_fwd import get_warmup_chunks, correct_initial_states
__all__ = [
"fused_gdr_fwd",
"fused_gdr_bwd",
"fused_gdr_h",
"kkt_solve",
"get_warmup_chunks",
"correct_initial_states",
]

View File

@@ -0,0 +1,309 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
@tilelang.jit()
def tilelang_get_warmup_chunks(
    num_heads,
    chunk_size,
    threshold,
    accum_dtype,
    g_dtype,
    mask_dtype,
    seqlen_dtype,
):
    """Build the kernel that picks a warm-up chunk count per segment and head.

    For every segment ``bb`` (one kernel block each) and head ``i_h`` the kernel
    writes:

    * ``num_warmup_chunks[bb, i_h]``: ``0`` when ``ht_mask[bb]`` is set;
      otherwise the 1-based index, counted from the end of the segment, of the
      first chunk at which the accumulated ``g`` falls below ``threshold``, or
      the total number of whole chunks if that never happens.
    * ``fallback_mask[bb, i_h]``: ``True`` when the threshold was never
      crossed, ``False`` when it was. It is NOT written for segments with
      ``ht_mask[bb]`` set.

    NOTE(review): the scan reads ``g`` only at the last token of each chunk,
    which presumably relies on ``g`` being a chunk-local cumulative sum —
    confirm against the caller.
    """
    batch_size = T.dynamic("batch_size")
    num_tokens = T.dynamic("num_tokens")
    # Thread count: num_heads rounded up to a multiple of 32.
    num_threads = tilelang.cdiv(num_heads, 32) * 32

    @T.prim_func
    def tilelang_get_warmup_chunks_kernel(
        g: T.Tensor([1, num_tokens, num_heads], dtype=g_dtype),
        ht_mask: T.Tensor([batch_size], dtype=mask_dtype),
        cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
        num_warmup_chunks: T.Tensor([batch_size, num_heads], dtype=seqlen_dtype),
        fallback_mask: T.Tensor([batch_size, num_heads], dtype=mask_dtype),
    ):
        with T.Kernel(batch_size, threads=num_threads) as (bb,):
            if ht_mask[bb]:
                # Masked segment: no warm-up; fallback_mask row is left as-is.
                for i_h in T.Parallel(num_heads):
                    num_warmup_chunks[bb, i_h] = 0
            else:
                seq_start_idx = T.alloc_var("int32")
                seq_end_idx = T.alloc_var("int32")
                num_iters = T.alloc_var("int32")
                seq_start_idx = cu_seqlens[bb]
                seq_end_idx = cu_seqlens[bb + 1]
                # Number of whole chunks in this segment (floor division).
                num_iters = (seq_end_idx - seq_start_idx) // chunk_size
                g_fragment = T.alloc_fragment((num_heads), dtype=accum_dtype)
                g_cumsum = T.alloc_fragment((num_heads), dtype=accum_dtype)
                n_fragment = T.alloc_fragment((num_heads), dtype=seqlen_dtype)
                f_fragment = T.alloc_fragment((num_heads), dtype=mask_dtype)
                T.clear(g_cumsum)
                # Defaults: use every chunk and mark the head as fallback.
                T.fill(n_fragment, num_iters)
                T.fill(f_fragment, True)
                # Walk the chunks from the end of the segment backwards.
                for i_s in T.serial(num_iters):
                    for i_h in T.Parallel(num_heads):
                        # g at the last token of the i_s-th chunk from the end.
                        g_fragment[i_h] = g[0, seq_end_idx - i_s * chunk_size - 1, i_h]
                    for i_h in T.Parallel(num_heads):
                        g_cumsum[i_h] += g_fragment[i_h]
                    for i_h in T.Parallel(num_heads):
                        # `n == num_iters` doubles as the "not found yet"
                        # sentinel so only the first crossing is recorded.
                        if g_cumsum[i_h] < threshold and n_fragment[i_h] == num_iters:
                            n_fragment[i_h] = i_s + 1
                            f_fragment[i_h] = False
                for i_h in T.Parallel(num_heads):
                    num_warmup_chunks[bb, i_h] = n_fragment[i_h]
                for i_h in T.Parallel(num_heads):
                    fallback_mask[bb, i_h] = f_fragment[i_h]

    return tilelang_get_warmup_chunks_kernel
def get_warmup_chunks(
    g: torch.Tensor,  # [1, num_total_tokens, num_v_heads]
    cu_seqlens: torch.Tensor,  # [cp_real_batch_size + 1]
    ht_mask: torch.Tensor,  # [cp_real_batch_size]
    chunk_size: int = 64,
    threshold: float = -10.0,
):
    """Launch the warm-up selection kernel on packed input.

    Args:
        g: Gate tensor with a leading batch dimension of 1.
        cu_seqlens: Segment boundaries, one more entry than ``ht_mask``.
        ht_mask: Per-segment mask handed to the kernel.
        chunk_size: Must be 64.
        threshold: Threshold compiled into the kernel.

    Returns:
        ``(num_warmup_chunks, fallback_mask)``, both of shape
        ``[num_segments, num_heads]`` on ``cu_seqlens.device``, with the dtypes
        of ``cu_seqlens`` and ``ht_mask`` respectively.
    """
    batch_size, _, num_heads = g.shape
    num_segments = ht_mask.shape[0]
    assert cu_seqlens.shape[0] == num_segments + 1
    assert batch_size == 1
    assert chunk_size == 64

    kernel = tilelang_get_warmup_chunks(
        num_heads=num_heads,
        chunk_size=chunk_size,
        threshold=threshold,
        accum_dtype="float32",
        g_dtype=g.dtype,
        mask_dtype=ht_mask.dtype,
        seqlen_dtype=cu_seqlens.dtype,
    )

    # Both outputs share shape and device; only the dtype differs.
    out_shape = [num_segments, num_heads]
    out_device = cu_seqlens.device
    num_warmup_chunks = torch.empty(
        out_shape, dtype=cu_seqlens.dtype, device=out_device
    )
    fallback_mask = torch.empty(out_shape, dtype=ht_mask.dtype, device=out_device)

    kernel(g, ht_mask, cu_seqlens, num_warmup_chunks, fallback_mask)
    return num_warmup_chunks, fallback_mask
@tilelang.jit()
def tilelang_correct_h0(
    H,
    DK,
    DV,
    res_dtype,
    accum_dtype,
    buffer_dtype,
    seqlen_dtype,
    mask_dtype,
    use_raw_h0,
    block_DV: int = 32,
):
    """Build the kernel that chains per-segment states into initial states.

    For each raw sequence the kernel walks its segments
    ``seq_map_r2c[bb]:seq_map_r2c[bb + 1]`` in order. The running state starts
    from ``raw_h0[bb]`` (or zeros when ``use_raw_h0`` is false) and is written
    to ``cp_h0`` as the initial state of the first segment. For each following
    segment the state becomes ``ht_buffer[seg]``, plus
    ``mt_buffer[seg] @ previous_state`` when ``fallback_mask[seg, head]`` is
    set, and is written to ``cp_h0`` of the next segment.

    The two generated kernels differ only in whether ``raw_h0`` is an argument.
    Work is tiled over ``DV`` in blocks of ``block_DV`` columns; one kernel
    block handles one (sequence, head, DV-tile) triple.

    NOTE(review): ``mt_buffer`` presumably holds the per-segment transition
    matrix of the recurrence — confirm against ``fused_gdr_h``.
    """
    cp_batch_size = T.dynamic("cp_batch_size")
    raw_batch_size = T.dynamic("raw_batch_size")

    # Shared loop body of both kernel variants. `bb`, `seq_end_idx` and
    # `seq_map_r2c` are accepted but not referenced inside the macro.
    @T.macro
    def kernel_body(
        bb,
        bh,
        bv,
        seq_start_idx,
        seq_end_idx,
        num_iters,
        ht_buffer,
        mt_buffer,
        fallback_mask,
        seq_map_r2c,
        cp_h0,
        h_fragment,
    ):
        h_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
        hd_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
        m_shared = T.alloc_shared((DK, DK), dtype=buffer_dtype)
        # The first segment starts from the incoming state unchanged.
        T.copy(
            h_fragment,
            cp_h0[seq_start_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
        )
        # num_iters - 1 hand-overs: segment i_s feeds segment i_s + 1.
        for i_s in T.Pipelined(num_iters - 1, num_stages=2):
            if fallback_mask[seq_start_idx + i_s, bh]:
                # Keep the previous state; it is needed for the gemm below.
                T.copy(h_fragment, hd_shared)
            T.copy(
                ht_buffer[
                    seq_start_idx + i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
                ],
                h_shared,
            )
            # Running state := state of segment i_s as stored in ht_buffer.
            T.copy(h_shared, h_fragment)
            if fallback_mask[seq_start_idx + i_s, bh]:
                # Accumulate mt_buffer[seg] @ previous_state on top of it.
                T.copy(mt_buffer[seq_start_idx + i_s, bh, 0:DK, 0:DK], m_shared)
                T.gemm(m_shared, hd_shared, h_fragment, clear_accum=False)
            T.copy(
                h_fragment,
                cp_h0[
                    seq_start_idx + i_s + 1,
                    bh,
                    0:DK,
                    bv * block_DV : (bv + 1) * block_DV,
                ],
            )

    if use_raw_h0:

        @T.prim_func
        def tilelang_correct_h0_kernel(
            raw_h0: T.Tensor([raw_batch_size, H, DK, DV], dtype=res_dtype),
            ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
            mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
            fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
            seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
            cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
        ):
            with T.Kernel(
                T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
            ) as (bbhv,):
                # Decompose the flat block index into (sequence, head, DV tile).
                bbh, bv = (
                    bbhv // T.ceildiv(DV, block_DV),
                    bbhv % T.ceildiv(DV, block_DV),
                )
                bb, bh = bbh // H, bbh % H
                seq_start_idx = seq_map_r2c[bb]
                seq_end_idx = seq_map_r2c[bb + 1]
                # Number of segments belonging to raw sequence bb.
                num_iters = seq_end_idx - seq_start_idx
                h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
                # Start from the caller-provided initial state.
                T.copy(
                    raw_h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
                    h_fragment,
                )
                kernel_body(
                    bb,
                    bh,
                    bv,
                    seq_start_idx,
                    seq_end_idx,
                    num_iters,
                    ht_buffer,
                    mt_buffer,
                    fallback_mask,
                    seq_map_r2c,
                    cp_h0,
                    h_fragment,
                )
    else:

        @T.prim_func
        def tilelang_correct_h0_kernel(
            ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
            mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
            fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
            seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
            cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
        ):
            with T.Kernel(
                T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
            ) as (bbhv,):
                # Decompose the flat block index into (sequence, head, DV tile).
                bbh, bv = (
                    bbhv // T.ceildiv(DV, block_DV),
                    bbhv % T.ceildiv(DV, block_DV),
                )
                bb, bh = bbh // H, bbh % H
                seq_start_idx = seq_map_r2c[bb]
                seq_end_idx = seq_map_r2c[bb + 1]
                # Number of segments belonging to raw sequence bb.
                num_iters = seq_end_idx - seq_start_idx
                h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
                # No initial state supplied: start from zeros.
                T.clear(h_fragment)
                kernel_body(
                    bb,
                    bh,
                    bv,
                    seq_start_idx,
                    seq_end_idx,
                    num_iters,
                    ht_buffer,
                    mt_buffer,
                    fallback_mask,
                    seq_map_r2c,
                    cp_h0,
                    h_fragment,
                )

    return tilelang_correct_h0_kernel
def correct_initial_states(
raw_h0: torch.Tensor
| None, # [raw_batch_size, num_v_heads, k_head_dim, v_head_dim]
ht_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
mt_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, k_head_dim]
fallback_mask: torch.Tensor, # [cp_batch_size, num_v_heads]
seq_map_r2c: torch.Tensor, # [raw_batch_size + 1]
):
cp_batch_size = fallback_mask.shape[0]
_, num_heads, k_head_dim, v_head_dim = ht_buffer.shape
assert k_head_dim == v_head_dim == 128
if raw_h0 is None:
res_dtype = torch.float32
use_raw_h0 = False
else:
res_dtype = raw_h0.dtype
use_raw_h0 = True
tilelang_correct_h0_kernel = tilelang_correct_h0(
H=num_heads,
DK=k_head_dim,
DV=v_head_dim,
res_dtype=res_dtype,
accum_dtype="float32",
buffer_dtype=ht_buffer.dtype,
seqlen_dtype=seq_map_r2c.dtype,
mask_dtype=fallback_mask.dtype,
use_raw_h0=use_raw_h0,
)
cp_h0 = torch.empty(
(cp_batch_size, num_heads, k_head_dim, v_head_dim),
dtype=res_dtype,
device=ht_buffer.device,
)
if use_raw_h0:
tilelang_correct_h0_kernel(
raw_h0,
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
)
else:
tilelang_correct_h0_kernel(
ht_buffer,
mt_buffer,
fallback_mask,
seq_map_r2c,
cp_h0,
)
return cp_h0

View File

@@ -0,0 +1,985 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
@tilelang.jit(
# out_idx=[-5, -4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_DATA_RACE_CHECK: True,
},
)
def tilelang_fused_chunk_gdr_bwd(
H,
Hg,
DK,
DV,
chunk_size,
scale,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h_dtype,
o_dtype,
seqlen_dtype,
is_varlen,
use_dht,
):
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
if is_varlen:
q_shape = (1, num_tokens, Hg, DK)
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
o_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
q_shape = (batch_size, num_tokens, Hg, DK)
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
o_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
h0_shape = (batch_size, H, DK, DV)
ht_shape = (batch_size, H, DK, DV)
@T.prim_func
def tilelang_fused_chunk_gdr_bwd_kernel(
do: T.Tensor(o_shape, dtype=o_dtype),
dht: T.Tensor(ht_shape, dtype=accum_dtype),
q: T.Tensor(q_shape, dtype=qkva_dtype),
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
dq: T.Tensor(v_shape, dtype=qkva_dtype),
dk: T.Tensor(v_shape, dtype=qkva_dtype),
dv: T.Tensor(v_shape, dtype=qkva_dtype),
dg: T.Tensor(g_shape, dtype=g_dtype),
db: T.Tensor(b_shape, dtype=b_dtype),
dh0: T.Tensor(h0_shape, dtype=accum_dtype),
):
with T.Kernel(batch_size * H, threads=512) as (bbh,):
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
num_iters = T.alloc_var("int32")
num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
# 2+2+2+2 + 1 + 4 = 13 units
do_shared = T.alloc_shared((block_S, DV), dtype=o_dtype)
q_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
h_shared = T.alloc_shared((DK, DV), dtype=h_dtype)
g_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
# 2 units
dqkv_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
dg_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
db_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
# 1+1 + 2+2+2 + 4 = 12 units
tmp_shared_1_1 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_1_2 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_1_3 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
tmp_shared_2_1 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_2_2 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_2_3 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
tmp_shared_4_1 = T.alloc_shared((DK, DV), dtype=qkva_dtype)
# CONSUMER_K
dk_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dv_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
odot_fragment_1 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dg_fragment_1 = T.alloc_fragment((block_S), dtype=accum_dtype)
dg_last_local_1 = T.alloc_fragment((1), dtype=accum_dtype)
# CONSUMER_A
mask_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dp_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
da_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
u_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dq_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
db_fragment = T.alloc_fragment((block_S), dtype=accum_dtype)
odot_fragment_2 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
dg_fragment_2 = T.alloc_fragment((block_S), dtype=accum_dtype)
# CONSUMER_S
dh_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
_odot_fragment_3 = T.alloc_fragment((DK, DV), dtype=accum_dtype)
reduce_fragment = T.alloc_fragment((128, 2), dtype=accum_dtype)
dg_last_local_3 = T.alloc_fragment((1), dtype=accum_dtype)
g_last_local_3 = T.alloc_local((1), dtype=accum_dtype)
# 16 stages
bar_00 = T.alloc_barrier(arrive_count=448)
bar_01 = T.alloc_barrier(arrive_count=384)
bar_02 = T.alloc_barrier(arrive_count=288)
bar_03 = T.alloc_barrier(arrive_count=256)
bar_04 = T.alloc_barrier(arrive_count=416)
bar_05 = T.alloc_barrier(arrive_count=288)
bar_06 = T.alloc_barrier(arrive_count=256)
bar_07 = T.alloc_barrier(arrive_count=256)
bar_08 = T.alloc_barrier(arrive_count=384)
bar_09 = T.alloc_barrier(arrive_count=256)
bar_10 = T.alloc_barrier(arrive_count=288)
bar_11 = T.alloc_barrier(arrive_count=256)
bar_12 = T.alloc_barrier(arrive_count=128)
bar_13 = T.alloc_barrier(arrive_count=256)
bar_14 = T.alloc_barrier(arrive_count=256)
bar_15 = T.alloc_barrier(arrive_count=256)
T.annotate_layout(
{
do_shared: tilelang.layout.make_swizzled_layout(do_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
v_shared: tilelang.layout.make_swizzled_layout(v_shared),
a_shared: tilelang.layout.make_swizzled_layout(a_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dqkv_shared: tilelang.layout.make_swizzled_layout(dqkv_shared),
tmp_shared_1_1: tilelang.layout.make_swizzled_layout(
tmp_shared_1_1
),
tmp_shared_1_2: tilelang.layout.make_swizzled_layout(
tmp_shared_1_2
),
tmp_shared_1_3: tilelang.layout.make_swizzled_layout(
tmp_shared_1_3
),
tmp_shared_2_1: tilelang.layout.make_swizzled_layout(
tmp_shared_2_1
),
tmp_shared_2_2: tilelang.layout.make_swizzled_layout(
tmp_shared_2_2
),
tmp_shared_2_3: tilelang.layout.make_swizzled_layout(
tmp_shared_2_3
),
tmp_shared_4_1: tilelang.layout.make_swizzled_layout(
tmp_shared_4_1
),
}
)
# T.use_swizzle(10)
tx = T.get_thread_binding()
PRODUCER_NREG = 24
CONSUMER_K_NREG = 144
CONSUMER_A_NREG = 176
CONSUMER_S_NREG = 160
# Prefetch the last chunk of data
T.copy(
h[batch_idx, chunk_start_idx + num_iters - 1, bh, 0:DK, 0:DV], h_shared
)
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
q_shared[j_s, j_k] = q[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bhg,
j_k,
]
else:
q_shared[j_s, j_k] = 0
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
k_shared[j_s, j_k] = k[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bhg,
j_k,
]
else:
k_shared[j_s, j_k] = 0
for j_s, j_v in T.Parallel(block_S, DV):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
v_shared[j_s, j_v] = v[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_v,
]
else:
v_shared[j_s, j_v] = 0
for j_s, j_t in T.Parallel(block_S, block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
a_shared[j_s, j_t] = a[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_t,
]
else:
a_shared[j_s, j_t] = 0
for j_s, j_v in T.Parallel(block_S, DV):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
do_shared[j_s, j_v] = do[
batch_idx,
seq_start_idx + (num_iters - 1) * block_S + j_s,
bh,
j_v,
]
else:
do_shared[j_s, j_v] = 0
for j_s in T.Parallel(block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
g_shared[j_s] = g[
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
]
else:
g_shared[j_s] = g[batch_idx, seq_end_idx - 1, bh]
for j_s in T.Parallel(block_S):
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
b_shared[j_s] = b[
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
]
else:
b_shared[j_s] = 0
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
if use_dht:
T.copy(dht[bb, bh, 0:DK, 0:DV], dh_fragment)
else:
T.clear(dh_fragment)
T.copy(dh_fragment, tmp_shared_4_1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
g_exp_shared[j_s] = T.exp2(g_shared[j_s] * 1.442695)
g_rev_exp_shared[j_s] = T.exp2(
(g_shared[block_S - 1] - g_shared[j_s]) * 1.442695
)
T.barrier_arrive(bar_01)
# 01, 02, 03
T.barrier_wait(bar_01, (i_s + 0) % 2)
g_last_local_3[0] = g_exp_shared[block_S - 1]
# dS0 = g_last * dSt
for j_k, j_v in T.Parallel(DK, DV):
dh_fragment[j_k, j_v] *= g_last_local_3[0]
T.barrier_arrive(bar_04)
# 04, 05, 06, 07
T.barrier_wait(bar_04, (i_s + 0) % 2)
# dg_last += sum(dS0 * S0)
T.clear(reduce_fragment)
for j_k, j_v in T.Parallel(DK, DV):
reduce_fragment[
j_k % 64 // 16 * 32 + j_k % 8 * 4 + j_v % 8 // 2, j_v % 2
] += dh_fragment[j_k, j_v] * h_shared[j_k, j_v]
T.barrier_arrive(bar_08)
T.barrier_wait(bar_08, (i_s + 0) % 2)
T.barrier_wait(bar_09, (i_s + 0) % 2)
# 10
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.reduce_sum(
T.reshape(reduce_fragment, (128 * 2,)),
dg_last_local_3,
dim=0,
clear=True,
)
dg_shared[block_S - 1] += dg_last_local_3[0]
T.barrier_arrive(bar_11)
# 11
T.barrier_wait(bar_11, (i_s + 0) % 2)
# dS0 += K^T @ dVg
T.gemm_v1(
tmp_shared_2_2,
tmp_shared_2_3,
dh_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_12)
T.barrier_wait(bar_12, (i_s + 0) % 2)
# 13
T.barrier_wait(bar_13, (i_s + 0) % 2)
# dOg = s * g * dO
for j_s, j_v in T.Parallel(block_S, DV):
tmp_shared_2_3[j_s, j_v] = (
scale * do_shared[j_s, j_v] * g_exp_shared[j_s]
)
T.barrier_arrive(bar_14)
# 14
T.barrier_wait(bar_14, (i_s + 0) % 2)
# dS0 += Q^T @ dOg
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_2_3,
dh_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_15)
# 15
T.barrier_wait(bar_15, (i_s + 0) % 2)
# S4[1] = dS0
T.copy(dh_fragment, tmp_shared_4_1)
if use_dht:
T.copy(dh_fragment, dh0[bb, bh, 0:DK, 0:DV])
elif tx < 256:
T.set_max_nreg(CONSUMER_K_NREG, 1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 16 == 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
# S2[S] dK
if i_s > 0:
T.copy(dk_fragment, dqkv_shared)
T.barrier_arrive(bar_01)
# 01
T.barrier_wait(bar_01, (i_s + 0) % 2)
# dV' = K @ dSt
T.gemm_v1(k_shared, tmp_shared_4_1, dv_fragment, clear_accum=True)
# dV' = g_last/g * dV'
for j_s, j_v in T.Parallel(block_S, DV):
dv_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
T.barrier_arrive(bar_02)
# 02
T.barrier_wait(bar_02, (i_s + 0) % 2)
# dV' += Pg^T @ dO
T.gemm_v1(
tmp_shared_1_1,
do_shared,
dv_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_03)
# 03
T.barrier_wait(bar_03, (i_s + 0) % 2)
# S2[1] dV'
T.copy(dv_fragment, tmp_shared_2_1)
T.barrier_arrive(bar_04)
# 04
T.barrier_wait(bar_04, (i_s + 0) % 2)
# dV = Ag^T @ dV'
T.gemm_v1(
tmp_shared_1_2,
tmp_shared_2_1,
dv_fragment,
transpose_A=True,
clear_accum=True,
)
# S2[S] dV
T.copy(dv_fragment, dqkv_shared)
T.barrier_arrive(bar_05)
# 05
T.barrier_wait(bar_05, (i_s + 0) % 2)
# dVg = -g * dV
for j_s, j_v in T.Parallel(block_S, DV):
dv_fragment[j_s, j_v] = (
-dv_fragment[j_s, j_v] * g_exp_shared[j_s]
)
# dg += sum(dVg * U)
T.copy(tmp_shared_2_3, odot_fragment_1)
for j_s, j_v in T.Parallel(block_S, DV):
odot_fragment_1[j_s, j_v] *= dv_fragment[j_s, j_v]
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
T.copy(dg_fragment_1, dg_shared)
# S2[3] dVg
T.copy(dv_fragment, tmp_shared_2_3)
T.barrier_arrive(bar_06)
# 06
T.barrier_wait(bar_06, (i_s + 0) % 2)
# S2[2] K
T.copy(k_shared, odot_fragment_1)
T.copy(odot_fragment_1, tmp_shared_2_2)
T.barrier_arrive(bar_07)
# 07
T.barrier_wait(bar_07, (i_s + 0) % 2)
# dK = V' @ dSt^T
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_4_1,
dk_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_08)
# 08
T.barrier_wait(bar_08, (i_s + 0) % 2)
# dK = g_last/g * dK
for j_s, j_k in T.Parallel(block_S, DK):
dk_fragment[j_s, j_k] *= g_rev_exp_shared[j_s]
# dg -= sum(K * dK)
for j_s, j_k in T.Parallel(block_S, DK):
odot_fragment_1[j_s, j_k] *= -dk_fragment[j_s, j_k]
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
for j_s in T.Parallel(block_S):
dg_shared[j_s] += dg_fragment_1[j_s]
# dg_last += sum(K * dK)
T.reduce_sum(dg_fragment_1, dg_last_local_1, dim=0, clear=True)
# Sg[S] dg
dg_shared[block_S - 1] -= dg_last_local_1[0]
T.barrier_arrive(bar_09)
# 09
T.barrier_wait(bar_09, (i_s + 0) % 2)
# dK += dVg @ S0^T
T.gemm_v1(
tmp_shared_2_3,
h_shared,
dk_fragment,
transpose_B=True,
clear_accum=False,
)
T.barrier_arrive(bar_10)
T.barrier_wait(bar_10, (i_s + 0) % 2)
# 12
T.barrier_wait(bar_12, (i_s + 0) % 2)
# dK += dP^T @ Q
T.gemm_v1(
tmp_shared_1_1,
tmp_shared_2_1,
dk_fragment,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(bar_13)
T.barrier_wait(bar_13, (i_s + 0) % 2)
# 15
T.barrier_wait(bar_15, (i_s + 0) % 2)
# dK += dAs @ K
T.gemm_v1(
tmp_shared_1_2, tmp_shared_2_2, dk_fragment, clear_accum=False
)
for j_s, j_k in T.Parallel(block_S, DK):
if seq_start_idx + j_s < seq_end_idx:
dk[batch_idx, seq_start_idx + j_s, bh, j_k] = dk_fragment[
j_s, j_k
]
elif tx < 384:
T.set_max_nreg(CONSUMER_A_NREG, 1)
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_00)
# 00
T.barrier_wait(bar_00, (i_s + 0) % 2)
# P = Q @ K^T
T.gemm_v1(
q_shared,
k_shared,
p_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_01)
# 01
T.barrier_wait(bar_01, (i_s + 0) % 2)
# G = Lower(diag(g) @ I @ diag(1/g))
for j_s, j_t in T.Parallel(block_S, block_S):
mask_fragment[j_s, j_t] = g_shared[j_s] - g_shared[j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s >= j_t:
mask_fragment[j_s, j_t] = T.exp2(
mask_fragment[j_s, j_t] * 1.442695
)
else:
mask_fragment[j_s, j_t] = 0
# Pg = s * P * G
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= scale
# S1[1] Pg
T.copy(p_fragment, tmp_shared_1_1)
T.barrier_arrive(bar_02)
# 02
T.barrier_wait(bar_02, (i_s + 0) % 2)
# Ab = Ar * b
T.copy(a_shared, a_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[j_t]
# Ag = G * Ab
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
# S1[2] Ag
T.copy(a_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_03)
# 03
T.barrier_wait(bar_03, (i_s + 0) % 2)
# U = K @ S0
T.gemm_v1(k_shared, h_shared, u_fragment, clear_accum=True)
T.barrier_arrive(bar_04)
# 04
T.barrier_wait(bar_04, (i_s + 0) % 2)
# S2[3] U
T.copy(u_fragment, tmp_shared_2_3)
# W = V - g * U
for j_s, j_v in T.Parallel(block_S, DV):
u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
for j_s, j_v in T.Parallel(block_S, DV):
u_fragment[j_s, j_v] += v_shared[j_s, j_v]
# S2[2] W
T.copy(u_fragment, tmp_shared_2_2)
T.barrier_arrive(bar_05)
# 05
T.barrier_wait(bar_05, (i_s + 0) % 2)
# dAg = dV' @ W^T
T.gemm_v1(
tmp_shared_2_1,
tmp_shared_2_2,
da_fragment,
transpose_B=True,
clear_accum=True,
)
# V' = Ag @ W
T.gemm_v1(
tmp_shared_1_2, tmp_shared_2_2, u_fragment, clear_accum=True
)
# S2[1] V'
T.copy(u_fragment, tmp_shared_2_1)
T.barrier_arrive(bar_06)
# 06
T.barrier_wait(bar_06, (i_s + 0) % 2)
# dPg = dO @ V'^T
T.gemm_v1(
do_shared,
tmp_shared_2_1,
dp_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_07)
# 07
T.barrier_wait(bar_07, (i_s + 0) % 2)
# dAb = G * dAg
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
# dg += sum((dPg * P) - (dPg * P)^T)
T.copy(tmp_shared_1_1, p_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] *= dp_fragment[j_s, j_t]
T.copy(p_fragment, tmp_shared_1_1)
for j_s, j_t in T.Parallel(block_S, block_S):
p_fragment[j_s, j_t] -= tmp_shared_1_1[j_t, j_s]
T.reduce_sum(p_fragment, dg_fragment_2, dim=1, clear=True)
# dP = s * G * dPg
for j_s, j_t in T.Parallel(block_S, block_S):
dp_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
for j_s, j_t in T.Parallel(block_S, block_S):
dp_fragment[j_s, j_t] *= scale
# S1[1] dP
T.copy(dp_fragment, tmp_shared_1_1)
T.barrier_arrive(bar_08)
# 08
T.barrier_wait(bar_08, (i_s + 0) % 2)
# dQ = dO @ S0^T
T.gemm_v1(
do_shared,
h_shared,
dq_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_09)
# 09
T.barrier_wait(bar_09, (i_s + 0) % 2)
# dQ = s * g * dQ
for j_s, j_k in T.Parallel(block_S, DK):
dq_fragment[j_s, j_k] *= g_exp_shared[j_s]
for j_s, j_k in T.Parallel(block_S, DK):
dq_fragment[j_s, j_k] *= scale
# S2[1] Q
T.copy(q_shared, odot_fragment_2)
# dg += sum(Q * dQ)
T.copy(odot_fragment_2, tmp_shared_2_1)
for j_s, j_k in T.Parallel(block_S, DK):
odot_fragment_2[j_s, j_k] *= dq_fragment[j_s, j_k]
T.reduce_sum(odot_fragment_2, dg_fragment_2, dim=1, clear=False)
T.barrier_arrive(bar_10)
# 10
T.barrier_wait(bar_10, (i_s + 0) % 2)
# dQ += dP @ K
T.gemm_v1(
tmp_shared_1_1, tmp_shared_2_2, dq_fragment, clear_accum=False
)
# S2[S] dQ
T.copy(dq_fragment, dqkv_shared)
T.barrier_arrive(bar_11)
# 11, 12
T.barrier_wait(bar_11, (i_s + 0) % 2)
# dAb * Ar
T.copy(a_shared, a_fragment)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
T.copy(a_fragment, tmp_shared_1_3)
# dAb * Ab [ = G * dAg * Ab ]
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= b_shared[j_t]
# dg += sum((dAb * Ab) - (dAb * Ab)^T)
T.copy(a_fragment, tmp_shared_1_2)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] -= tmp_shared_1_2[j_t, j_s]
T.reduce_sum(a_fragment, dg_fragment_2, dim=1, clear=False)
# Sg[S] dg
for j_s in T.Parallel(block_S):
dg_shared[j_s] += dg_fragment_2[j_s]
# db = sum((dAb * Ar)^T)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] = tmp_shared_1_3[j_t, j_s]
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=True)
# dAr = dAb * b
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= b_shared[j_t]
# S1[2] dAr
T.copy(da_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_13)
# 13
T.barrier_wait(bar_13, (i_s + 0) % 2)
# dA = -Ar^T @ dAr @ Ar^T
T.gemm_v1(
a_shared,
tmp_shared_1_2,
da_fragment,
transpose_A=True,
clear_accum=True,
)
T.copy(da_fragment, tmp_shared_1_2)
T.gemm_v1(
tmp_shared_1_2,
a_shared,
da_fragment,
transpose_B=True,
clear_accum=True,
)
# At = K @ K^T
T.gemm_v1(
tmp_shared_2_2,
tmp_shared_2_2,
a_fragment,
transpose_B=True,
clear_accum=True,
)
T.barrier_arrive(bar_14)
# 14
T.barrier_wait(bar_14, (i_s + 0) % 2)
for j_s, j_t in T.Parallel(block_S, block_S):
if j_s <= j_t:
da_fragment[j_s, j_t] = 0
else:
da_fragment[j_s, j_t] = -da_fragment[j_s, j_t]
# db += sum(dA * At)
for j_s, j_t in T.Parallel(block_S, block_S):
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=False)
T.copy(db_fragment, db_shared)
# dAt = b * dA
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] *= b_shared[j_s]
# dAs = dAt + dAt^T
T.copy(da_fragment, tmp_shared_1_2)
for j_s, j_t in T.Parallel(block_S, block_S):
da_fragment[j_s, j_t] += tmp_shared_1_2[j_t, j_s]
# S1[1] dAs
T.copy(da_fragment, tmp_shared_1_2)
T.barrier_arrive(bar_15)
T.barrier_wait(bar_15, (i_s + 0) % 2)
else:
T.set_max_nreg(PRODUCER_NREG, 0)
if tx < 384 + 32:
for i_s in T.serial(num_iters - 1):
chunk_idx = num_iters - i_s - 2
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
T.barrier_arrive(bar_00)
T.barrier_wait(bar_00, (i_s + 0) % 2)
T.barrier_wait(bar_03, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
g_shared[j_s] = g[batch_idx, left + j_s, bh]
T.barrier_wait(bar_05, (i_s + 0) % 2)
T.copy(v[batch_idx, left:right, bh, 0:DV], v_shared)
T.barrier_wait(bar_07, (i_s + 0) % 2)
T.copy(k[batch_idx, left:right, bhg, 0:DK], k_shared)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.copy(q[batch_idx, left:right, bhg, 0:DK], q_shared)
if num_iters > 0:
T.barrier_arrive(bar_00)
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
left = seq_start_idx + (num_iters - i_s - 1) * block_S
right = left + block_S
T.barrier_arrive(bar_00)
T.barrier_wait(bar_00, (i_s + 0) % 2)
T.barrier_wait(bar_01, (i_s + 0) % 2)
if i_s == 1:
for j_s, j_k in T.Parallel(block_S, DK):
if left + block_S + j_s < seq_end_idx:
dk[batch_idx, left + block_S + j_s, bh, j_k] = (
dqkv_shared[j_s, j_k]
)
elif i_s > 1:
T.copy(
dqkv_shared,
dk[
batch_idx,
left + block_S : right + block_S,
bh,
0:DK,
],
)
T.barrier_arrive(bar_04)
T.barrier_wait(bar_04, (i_s + 0) % 2)
T.barrier_wait(bar_05, (i_s + 0) % 2)
if i_s == 0:
for j_s, j_v in T.Parallel(block_S, DV):
if left + j_s < seq_end_idx:
dv[batch_idx, left + j_s, bh, j_v] = dqkv_shared[
j_s, j_v
]
else:
T.copy(dqkv_shared, dv[batch_idx, left:right, bh, 0:DV])
T.barrier_arrive(bar_10)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.barrier_wait(bar_11, (i_s + 0) % 2)
if i_s == 0:
for j_s, j_k in T.Parallel(block_S, DK):
if left + j_s < seq_end_idx:
dq[batch_idx, left + j_s, bh, j_k] = dqkv_shared[
j_s, j_k
]
else:
T.copy(dqkv_shared, dq[batch_idx, left:right, bh, 0:DK])
elif tx < 384 + 96:
for i_s in T.serial(num_iters - 1):
chunk_idx = num_iters - i_s - 2
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
T.barrier_arrive(bar_02)
T.barrier_wait(bar_02, (i_s + 0) % 2)
T.barrier_wait(bar_10, (i_s + 0) % 2)
T.copy(
h[batch_idx, chunk_start_idx + chunk_idx, bh, 0:DK, 0:DV],
h_shared,
)
T.barrier_wait(bar_14, (i_s + 0) % 2)
T.copy(a[batch_idx, left:right, bh, 0:block_S], a_shared)
T.copy(do[batch_idx, left:right, bh, 0:DV], do_shared)
T.barrier_wait(bar_15, (i_s + 0) % 2)
for j_s in T.Parallel(block_S):
b_shared[j_s] = b[batch_idx, left + j_s, bh]
if num_iters > 0:
T.barrier_wait(bar_00, (num_iters - 1) % 2)
T.barrier_arrive(bar_02)
else:
for i_s in T.serial(num_iters):
left = seq_start_idx + (num_iters - i_s - 1) * block_S
T.barrier_arrive(bar_05)
T.barrier_wait(bar_05, (i_s + 0) % 2)
T.barrier_wait(bar_15, (i_s + 0) % 2)
if i_s == 0:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
if (seq_end_idx - seq_start_idx) % block_S > 0:
dg[batch_idx, seq_end_idx - 1, bh] += dg_shared[
block_S - 1
]
else:
for j_s in T.Parallel(block_S):
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
if i_s == 0:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
db[batch_idx, left + j_s, bh] = db_shared[j_s]
else:
for j_s in T.Parallel(block_S):
db[batch_idx, left + j_s, bh] = db_shared[j_s]
return tilelang_fused_chunk_gdr_bwd_kernel
def fused_gdr_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    g: torch.Tensor,
    b: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor | None,
    h: torch.Tensor,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
):
    """Allocate gradient buffers and launch the fused chunked backward kernel.

    Args:
        q: Query tensor, passed to the kernel unchanged.
        k: Key tensor of shape ``(batch, tokens, Hg, K)``.
        v: Value tensor of shape ``(batch, tokens, H, V)``.
        a: Per-chunk matrix forwarded to the kernel (presumably the output
            of the KKT solve -- confirm against the caller).
        g: Gate tensor; ``dg`` is allocated with the same shape and dtype.
        b: Beta tensor; ``db`` is allocated with the same shape and dtype.
        do: Gradient with respect to the forward output.
        dht: Gradient with respect to the final state, or ``None``.
        h: Per-chunk states forwarded to the kernel.
        scale: Attention scale. ``None`` selects ``K ** -0.5``.
        cu_seqlens: Cumulative sequence lengths for variable-length input,
            or ``None`` for fixed-length batches.
        chunk_size: Chunk length; only 64 is supported.

    Returns:
        ``(dq, dk, dv, dg, db, dh0)``. ``dq`` and ``dk`` are allocated with
        ``v``'s shape (``H`` heads); this function does not reduce them to
        ``Hg`` heads. ``dh0`` is ``None`` when ``dht`` is ``None``.
    """
    batch_size, _, Hg, K = k.shape
    _, _, H, V = v.shape
    # Test against None explicitly: `scale or default` would silently discard
    # an explicit scale of 0.0.
    if scale is None:
        scale = K ** (-0.5)
    assert K == V == 128
    assert chunk_size == 64

    if cu_seqlens is None:
        real_batch_size = batch_size
        # Placeholder tensors: the kernel signature always takes these
        # arguments, and they are left uninitialised here.
        cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
        chunk_offsets = torch.empty(
            (batch_size + 1), dtype=torch.int32, device=k.device
        )
        is_varlen = False
    else:
        real_batch_size = len(cu_seqlens) - 1
        chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size)
        chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
        is_varlen = True

    use_dht = dht is not None
    if dht is None:
        # Placeholder so the kernel still receives a tensor argument.
        dht = torch.empty(
            (real_batch_size, H, K, V), dtype=torch.float32, device=k.device
        )

    # K == V is asserted above, so v's shape also fits dq and dk.
    dq = torch.empty_like(v)
    dk = torch.empty_like(v)
    dv = torch.empty_like(v)
    dg = torch.empty_like(g)
    db = torch.empty_like(b)
    dh0 = torch.empty_like(dht)

    tilelang_fused_chunk_gdr_bwd_kernel = tilelang_fused_chunk_gdr_bwd(
        H,
        Hg,
        K,
        V,
        chunk_size,
        scale,
        qkva_dtype=q.dtype,
        g_dtype=g.dtype,
        b_dtype=b.dtype,
        h_dtype=h.dtype,
        o_dtype=do.dtype,
        seqlen_dtype=cu_seqlens.dtype,
        accum_dtype="float32",
        is_varlen=is_varlen,
        use_dht=use_dht,
    )
    tilelang_fused_chunk_gdr_bwd_kernel(
        do,
        dht,
        q,
        k,
        v,
        a,
        g,
        b,
        h,
        cu_seqlens,
        chunk_offsets,
        dq,
        dk,
        dv,
        dg,
        db,
        dh0,
    )
    if not use_dht:
        dh0 = None
    return dq, dk, dv, dg, db, dh0

View File

@@ -0,0 +1,658 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
# Multiprocessor count reported by torch.cuda.get_device_properties(), called
# with no explicit device argument. Evaluated once, when the module is imported.
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
# 70% of the multiprocessor count, truncated to an int. fused_gdr_fwd compares
# its launch grid size against this value to choose block_DV.
TARGET_NUM_CTAS = int(MULTI_PROCESSOR_COUNT * 0.7)
@tilelang.jit(
    # out_idx=[-3, -2, -1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        # tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
    },
)
def tilelang_fused_chunk_gdr_fwd(
    H,
    Hg,
    DK,
    DV,
    chunk_size,
    scale,
    accum_dtype,
    qkva_dtype,
    g_dtype,
    b_dtype,
    h0_dtype,
    ht_dtype,
    h_dtype,
    o_dtype,
    seqlen_dtype,
    use_initial_state,
    store_final_state,
    store_h,
    store_o,
    is_varlen,
    is_cp,
    block_DV=128,
):
    """Build the fused chunked forward kernel for the given configuration.

    Each thread block handles one (sequence, head, value-tile) triple and
    walks the sequence chunk by chunk with 512 threads split by thread index:

    * threads 0-127 keep the running state ``S`` in ``h_fragment``;
    * threads 128-255 compute ``W``, ``Vd`` and ``V'``;
    * threads 256-383 compute the output tile ``O``;
    * threads 384-479 (three groups of 32) load q/k, v/beta and a/g;
    * threads 480-511 store ``O`` and, if requested, the per-chunk state.

    Args:
        H: Number of value heads.
        Hg: Number of query/key heads; head ``bh`` reads q/k head
            ``bh // (H // Hg)``.
        DK: Key head dimension.
        DV: Value head dimension.
        chunk_size: Tokens per chunk (used as ``block_S``).
        scale: Scale applied to ``Q @ K^T`` and ``Q @ S``.
        accum_dtype: dtype of fragments and of the g/beta staging buffers.
        qkva_dtype, g_dtype, b_dtype, h0_dtype, ht_dtype, h_dtype, o_dtype,
        seqlen_dtype: dtypes of the corresponding kernel arguments.
        use_initial_state: Read ``h0`` instead of starting from zero.
        store_final_state: Write the last state to ``ht``.
        store_h: Write the state at the start of every chunk to ``h``.
        store_o: Write the output ``o``.
        is_varlen: Sequences are packed along the token axis and delimited
            by ``cu_seqlens``.
        is_cp: Use ``cp_seq_map`` / ``raw_cu_seqlens`` to pick the ``ht``
            row and to decide whether this block stores the final state.
        block_DV: Width of the value tile handled by one thread block.

    Returns:
        The ``T.prim_func`` kernel.
    """
    batch_size = T.dynamic("batch_size")
    num_tokens = T.dynamic("num_tokens")
    num_chunks = T.dynamic("num_chunks")
    raw_batch_size = T.dynamic("raw_batch_size")
    block_S = chunk_size

    if is_varlen:
        q_shape = (1, num_tokens, Hg, DK)
        k_shape = (1, num_tokens, Hg, DK)
        v_shape = (1, num_tokens, H, DV)
        o_shape = (1, num_tokens, H, DV)
        a_shape = (1, num_tokens, H, chunk_size)
        g_shape = (1, num_tokens, H)
        b_shape = (1, num_tokens, H)
        h_shape = (1, num_chunks, H, DK, DV)
    else:
        q_shape = (batch_size, num_tokens, Hg, DK)
        k_shape = (batch_size, num_tokens, Hg, DK)
        v_shape = (batch_size, num_tokens, H, DV)
        o_shape = (batch_size, num_tokens, H, DV)
        a_shape = (batch_size, num_tokens, H, chunk_size)
        g_shape = (batch_size, num_tokens, H)
        b_shape = (batch_size, num_tokens, H)
        h_shape = (batch_size, num_chunks, H, DK, DV)
    h0_shape = (batch_size, H, DK, DV)
    ht_shape = (raw_batch_size, H, DK, DV)

    @T.prim_func
    def tilelang_fused_chunk_gdr_fwd_kernel(
        q: T.Tensor(q_shape, dtype=qkva_dtype),
        k: T.Tensor(k_shape, dtype=qkva_dtype),
        v: T.Tensor(v_shape, dtype=qkva_dtype),
        a: T.Tensor(a_shape, dtype=qkva_dtype),
        g: T.Tensor(g_shape, dtype=g_dtype),
        b: T.Tensor(b_shape, dtype=b_dtype),
        h0: T.Tensor(h0_shape, dtype=h0_dtype),
        cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
        chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
        cp_seq_map: T.Tensor([batch_size], dtype=seqlen_dtype),
        raw_cu_seqlens: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
        o: T.Tensor(o_shape, dtype=o_dtype),
        h: T.Tensor(h_shape, dtype=h_dtype),
        ht: T.Tensor(ht_shape, dtype=ht_dtype),
    ):
        with T.Kernel(T.ceildiv(DV, block_DV) * batch_size * H, threads=512) as (bbhv,):
            # Decompose the flat block index into (sequence, head, value tile).
            bbh, bv = bbhv // T.ceildiv(DV, block_DV), bbhv % T.ceildiv(DV, block_DV)
            bb, bh = bbh // H, bbh % H
            bhg = bh // (H // Hg)

            batch_idx = T.alloc_var("int32")
            seq_start_idx = T.alloc_var("int32")
            seq_end_idx = T.alloc_var("int32")
            seq_split_idx = T.alloc_var("int32")
            chunk_start_idx = T.alloc_var("int32")
            chunk_split_idx = T.alloc_var("int32")
            batch_idx = 0 if is_varlen else bb
            seq_start_idx = cu_seqlens[bb] if is_varlen else 0
            seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
            chunk_start_idx = chunk_offsets[bb] if is_varlen else 0

            raw_batch_idx = T.alloc_var("int32")
            raw_seq_end_idx = T.alloc_var("int32")
            need_store_final_state = T.alloc_var("bool")
            raw_batch_idx = cp_seq_map[bb] if is_cp else bb
            raw_seq_end_idx = (
                raw_cu_seqlens[raw_batch_idx + 1] if is_cp else seq_end_idx
            )
            # The final state is stored only when this block's sequence ends
            # where its raw sequence ends.
            need_store_final_state = store_final_state & (
                raw_seq_end_idx == seq_end_idx
            )

            num_iters = T.alloc_var("int32")
            num_unmasked_iters = T.alloc_var("int32")
            # num_iters counts every chunk; num_unmasked_iters only full ones.
            num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
            num_unmasked_iters = (seq_end_idx - seq_start_idx) // block_S

            # Input staging buffers have a leading dimension of 2 and are
            # indexed by i_s % 2.
            q_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
            k_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
            v_shared = T.alloc_shared((2, block_S, block_DV), dtype=qkva_dtype)
            a_shared = T.alloc_shared((2, block_S, block_S), dtype=qkva_dtype)
            g_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
            b_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
            o_shared = T.alloc_shared((block_S, block_DV), dtype=o_dtype)
            h_shared = T.alloc_shared((DK, block_DV), dtype=qkva_dtype)
            vd_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
            vn_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
            p_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
            g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
            g_rev_exp_shared = T.alloc_shared(
                (block_S), dtype=accum_dtype, scope="shared"
            )

            h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
            o_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
            v_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
            u_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
            p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
            a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
            g_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
            g_last_local = T.alloc_local((1), dtype=accum_dtype)

            # data_is_ready: 3 loader groups x 32 threads arrive per slot.
            # data_is_free: 3 consumer groups x 128 threads arrive per slot.
            data_is_ready = T.alloc_barrier(arrive_count=[96] * 2)
            data_is_free = T.alloc_barrier(arrive_count=[384] * 2)
            bar_o = T.alloc_barrier(arrive_count=128)
            bar_0 = T.alloc_barrier(arrive_count=416)
            bar_1 = T.alloc_barrier(arrive_count=256)
            _bar_2 = T.alloc_barrier(arrive_count=128)
            bar_3 = T.alloc_barrier(arrive_count=128)
            bar_4 = T.alloc_barrier(arrive_count=128)
            bar_5 = T.alloc_barrier(arrive_count=416)

            T.use_swizzle(10)
            tx = T.get_thread_binding()
            PRODUCER_NREG = 32
            CONSUMER_V_NREG = 128
            CONSUMER_S_NREG = 160
            CONSUMER_O_NREG = 128

            if tx < 128:
                # ---- Threads 0-127: running state S ----
                T.set_max_nreg(CONSUMER_S_NREG, 1)
                # Initialize S
                if use_initial_state:
                    T.copy(
                        h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
                        h_fragment,
                    )
                else:
                    T.clear(h_fragment)
                # Main Loop
                for i_s in T.serial(num_iters):
                    # [STAGE 0]
                    T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
                    T.barrier_arrive(bar_0)
                    # [STAGE 0] 0
                    T.barrier_wait(bar_0, i_s % 2)
                    # S4[S] S
                    T.copy(h_fragment, h_shared)
                    T.barrier_arrive(bar_1)
                    # [STAGE 0] 2, 3, 4
                    T.barrier_wait(bar_1, i_s % 2)
                    # S = g_last * S
                    g_last_local[0] = g_exp_shared[block_S - 1]
                    for j_k, j_v in T.Parallel(DK, block_DV):
                        h_fragment[j_k, j_v] *= g_last_local[0]
                    T.barrier_arrive(bar_5)
                    # [STAGE 0] 5
                    T.barrier_wait(bar_5, i_s % 2)
                    # S += K^T @ V'
                    T.gemm_v1(
                        k_shared[i_s % 2, :, :],
                        vn_shared,
                        h_fragment,
                        transpose_A=True,
                        clear_accum=False,
                    )
                    T.barrier_arrive(data_is_free[i_s % 2])
                # Store final S
                if need_store_final_state:
                    T.copy(
                        h_fragment,
                        ht[
                            raw_batch_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
                        ],
                    )
            elif tx < 256:
                # ---- Threads 128-255: W, Vd and V' ----
                T.set_max_nreg(CONSUMER_V_NREG, 1)
                # Main Loop
                for i_s in T.serial(num_iters):
                    # [STAGE 0]
                    T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
                    T.barrier_arrive(bar_0)
                    # [STAGE 0] 0
                    T.barrier_wait(bar_0, i_s % 2)
                    # Precompute g, g_last/g
                    # 1.442695 ~= log2(e), so exp2(x * 1.442695) evaluates exp(x).
                    for j_s in T.Parallel(block_S):
                        g_exp_shared[j_s] = T.exp2(g_shared[i_s % 2, j_s] * 1.442695)
                    # Rows past the end of the sequence get a factor of 0.
                    for j_s in T.Parallel(block_S):
                        g_rev_exp_shared[j_s] = T.if_then_else(
                            seq_start_idx + i_s * block_S + j_s < seq_end_idx,
                            T.exp2(
                                (
                                    g_shared[i_s % 2, block_S - 1]
                                    - g_shared[i_s % 2, j_s]
                                )
                                * 1.442695
                            ),
                            0.0,
                        )
                    T.barrier_arrive(bar_1)
                    # [STAGE 0] 1
                    T.barrier_wait(bar_1, i_s % 2)
                    # U = K @ S
                    T.gemm_v1(
                        k_shared[i_s % 2, :, :], h_shared, u_fragment, clear_accum=True
                    )
                    # [STAGE 0] 2
                    # W = V - g * U
                    for j_s, j_v in T.Parallel(block_S, block_DV):
                        u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
                    for j_s, j_v in T.Parallel(block_S, block_DV):
                        u_fragment[j_s, j_v] += v_shared[i_s % 2, j_s, j_v]
                    # S2[V] W
                    for j_s, j_v in T.Parallel(block_S, block_DV):
                        v_shared[i_s % 2, j_s, j_v] = u_fragment[j_s, j_v]
                    # [STAGE 0] 3
                    T.barrier_wait(bar_3, i_s % 2)
                    # Vd = Ag @ W
                    T.gemm_v1(
                        a_shared[i_s % 2, :, :],
                        v_shared[i_s % 2, :, :],
                        v_fragment,
                        clear_accum=True,
                    )
                    # S2[2] Vd
                    T.copy(v_fragment, vd_shared)
                    T.barrier_arrive(bar_4)
                    # [STAGE 0] 4
                    # V' = g_last/g Vd
                    for j_s, j_v in T.Parallel(block_S, block_DV):
                        v_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
                    # S2[1] V'
                    T.copy(v_fragment, vn_shared)
                    T.barrier_arrive(bar_5)
                    T.barrier_wait(bar_5, i_s % 2)
                    T.barrier_arrive(data_is_free[i_s % 2])
            elif tx < 384:
                # ---- Threads 256-383: output tile O ----
                T.set_max_nreg(CONSUMER_O_NREG, 1)
                # Main Loop
                for i_s in T.serial(num_iters):
                    # [STAGE 0]
                    T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
                    T.barrier_arrive(bar_0)
                    # [STAGE 0] 0
                    T.barrier_wait(bar_0, i_s % 2)
                    # P = Q K^T
                    T.gemm_v1(
                        q_shared[i_s % 2, :, :],
                        k_shared[i_s % 2, :, :],
                        p_fragment,
                        transpose_B=True,
                        clear_accum=True,
                    )
                    # [STAGE 0] 1
                    # G = Lower(diag(g) @ I @ diag(1/g))
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        g_fragment[j_s, j_t] = (
                            g_shared[i_s % 2, j_s] - g_shared[i_s % 2, j_t]
                        )
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        if j_s >= j_t:
                            g_fragment[j_s, j_t] = T.exp2(
                                g_fragment[j_s, j_t] * 1.442695
                            )
                        else:
                            g_fragment[j_s, j_t] = 0
                    # Ag = G * Ar * b
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        a_fragment[j_s, j_t] = a_shared[i_s % 2, j_s, j_t]
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        a_fragment[j_s, j_t] *= g_fragment[j_s, j_t]
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        a_fragment[j_s, j_t] *= b_shared[i_s % 2, j_t]
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        a_shared[i_s % 2, j_s, j_t] = a_fragment[j_s, j_t]
                    # [STAGE 0] 2
                    T.barrier_wait(bar_1, i_s % 2)
                    # O = Q @ S
                    T.gemm_v1(
                        q_shared[i_s % 2, :, :], h_shared, o_fragment, clear_accum=True
                    )
                    # [STAGE 0] 3
                    # Pg = s * G * P
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        p_fragment[j_s, j_t] *= scale * g_fragment[j_s, j_t]
                    # S1[1] Pg
                    T.copy(p_fragment, p_shared)
                    T.barrier_arrive(bar_3)
                    # O = s * g * O
                    # o_fragment is (block_S, block_DV), so the column extent
                    # must be block_DV: iterating over DK would run past the
                    # fragment whenever block_DV < DK.
                    for j_s, j_v in T.Parallel(block_S, block_DV):
                        o_fragment[j_s, j_v] *= scale * g_exp_shared[j_s]
                    # [STAGE 0] 4
                    T.barrier_wait(bar_4, i_s % 2)
                    # O += Pg @ Vd
                    T.gemm_v1(p_shared, vd_shared, o_fragment, clear_accum=False)
                    T.barrier_arrive(bar_5)
                    # [STAGE 0] 5
                    T.barrier_wait(bar_5, i_s % 2)
                    # S2[S] O
                    T.copy(o_fragment, o_shared)
                    T.barrier_arrive(data_is_free[i_s % 2])
                # Arrive once after the loop; the store threads wait on bar_o
                # before writing the last chunk of O.
                T.barrier_arrive(bar_o)
            else:
                # ---- Threads 384-511: loads and stores ----
                T.set_max_nreg(PRODUCER_NREG, 0)
                if tx < 384 + 32:
                    for i_s in T.serial(num_iters):
                        T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
                        left = seq_start_idx + i_s * block_S
                        right = left + block_S
                        # Load Q
                        T.copy(
                            q[batch_idx, left:right, bhg, 0:DK], q_shared[i_s % 2, :, :]
                        )
                        # Load K
                        T.copy(
                            k[batch_idx, left:right, bhg, 0:DK], k_shared[i_s % 2, :, :]
                        )
                        T.barrier_arrive(data_is_ready[i_s % 2])
                elif tx < 384 + 64:
                    for i_s in T.serial(num_iters):
                        T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
                        left = seq_start_idx + i_s * block_S
                        right = left + block_S
                        # Load V
                        T.copy(
                            v[
                                batch_idx,
                                left:right,
                                bh,
                                bv * block_DV : (bv + 1) * block_DV,
                            ],
                            v_shared[i_s % 2, :, :],
                        )
                        # Load beta
                        # Rows past the end of the sequence get beta = 0.
                        if right <= seq_end_idx:
                            for j_s in T.Parallel(block_S):
                                b_shared[i_s % 2, j_s] = b[batch_idx, left + j_s, bh]
                        else:
                            for j_s in T.Parallel(block_S):
                                if left + j_s < seq_end_idx:
                                    b_shared[i_s % 2, j_s] = b[
                                        batch_idx, left + j_s, bh
                                    ]
                                else:
                                    b_shared[i_s % 2, j_s] = 0
                        T.barrier_arrive(data_is_ready[i_s % 2])
                elif tx < 384 + 96:
                    for i_s in T.serial(num_iters):
                        T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
                        left = seq_start_idx + i_s * block_S
                        right = left + block_S
                        # Load A
                        T.copy(
                            a[batch_idx, left:right, bh, 0:block_S],
                            a_shared[i_s % 2, :, :],
                        )
                        # Load gamma
                        # Rows past the end of the sequence repeat the gate of
                        # the last valid token.
                        if right <= seq_end_idx:
                            for j_s in T.Parallel(block_S):
                                g_shared[i_s % 2, j_s] = g[batch_idx, left + j_s, bh]
                        else:
                            for j_s in T.Parallel(block_S):
                                if left + j_s < seq_end_idx:
                                    g_shared[i_s % 2, j_s] = g[
                                        batch_idx, left + j_s, bh
                                    ]
                                else:
                                    g_shared[i_s % 2, j_s] = g[
                                        batch_idx, seq_end_idx - 1, bh
                                    ]
                        T.barrier_arrive(data_is_ready[i_s % 2])
                else:
                    # Store threads. In iteration i_s they write the O tile of
                    # chunk i_s - 1 ([left, right) ends where chunk i_s starts)
                    # and the state snapshot of chunk i_s.
                    for i_s in T.serial(num_unmasked_iters):
                        right = seq_start_idx + i_s * block_S
                        left = right - block_S
                        T.barrier_arrive(bar_0)
                        T.barrier_wait(bar_0, i_s % 2)
                        # Store O
                        if i_s > 0 and store_o:
                            T.copy(
                                o_shared,
                                o[
                                    batch_idx,
                                    left:right,
                                    bh,
                                    bv * block_DV : (bv + 1) * block_DV,
                                ],
                            )
                        T.barrier_arrive(bar_5)
                        T.barrier_wait(bar_1, i_s % 2)
                        # Store S
                        if store_h:
                            T.copy(
                                h_shared,
                                h[
                                    batch_idx,
                                    chunk_start_idx + i_s,
                                    bh,
                                    0:DK,
                                    bv * block_DV : (bv + 1) * block_DV,
                                ],
                            )
                    # One extra iteration when the sequence ends in a partial
                    # chunk.
                    if num_unmasked_iters < num_iters:
                        seq_split_idx = seq_start_idx + num_unmasked_iters * block_S
                        chunk_split_idx = chunk_start_idx + num_unmasked_iters
                        T.barrier_arrive(bar_0)
                        T.barrier_wait(bar_0, num_unmasked_iters % 2)
                        # Store O
                        if num_unmasked_iters > 0 and store_o:
                            T.copy(
                                o_shared,
                                o[
                                    batch_idx,
                                    seq_split_idx - block_S : seq_split_idx,
                                    bh,
                                    bv * block_DV : (bv + 1) * block_DV,
                                ],
                            )
                        T.barrier_arrive(bar_5)
                        T.barrier_wait(bar_1, num_unmasked_iters % 2)
                        # Store S
                        if store_h:
                            T.copy(
                                h_shared,
                                h[
                                    batch_idx,
                                    chunk_split_idx,
                                    bh,
                                    0:DK,
                                    bv * block_DV : (bv + 1) * block_DV,
                                ],
                            )
                    # The last chunk's O is written element-wise so rows past
                    # seq_end_idx are skipped.
                    seq_split_idx = seq_start_idx + (num_iters - 1) * block_S
                    # Store O
                    T.barrier_wait(bar_o, 0)
                    if store_o:
                        for j_s, j_v in T.Parallel(block_S, block_DV):
                            with T.If(seq_split_idx + j_s < seq_end_idx):
                                with T.Then():
                                    o[
                                        batch_idx,
                                        seq_split_idx + j_s,
                                        bh,
                                        bv * block_DV + j_v,
                                    ] = o_shared[j_s, j_v]

    return tilelang_fused_chunk_gdr_fwd_kernel
def fused_gdr_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    g: torch.Tensor,
    b: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = True,
    output_h: bool = False,
    output_o: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    cp_seq_map: torch.LongTensor | None = None,
    raw_cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
):
    """Allocate outputs and launch the fused chunked forward kernel.

    Args:
        q: Query tensor, passed to the kernel unchanged.
        k: Key tensor of shape ``(batch, tokens, Hg, K)``.
        v: Value tensor of shape ``(batch, tokens, H, V)``.
        a: Per-chunk matrix with ``chunk_size`` columns per token.
        g: Gate tensor, one value per token and head.
        b: Beta tensor, one value per token and head.
        scale: Attention scale. ``None`` selects ``K ** -0.5``.
        initial_state: Optional starting state; when ``None`` the kernel
            starts from zeros.
        output_final_state: Return the final state.
        output_h: Return the per-chunk states.
        output_o: Return the output tensor.
        cu_seqlens: Cumulative sequence lengths for variable-length input,
            or ``None`` for fixed-length batches.
        cp_seq_map: Optional map from each sequence to its row in the final
            state; passing it enables the kernel's ``is_cp`` mode.
        raw_cu_seqlens: Optional cumulative lengths whose length also sets
            the number of rows of the final state.
        chunk_size: Chunk length; only 64 is supported.

    Returns:
        ``(o, h, final_state)``; each entry is ``None`` when its
        ``output_*`` flag is false.
    """
    batch_size, num_tokens, Hg, K = k.shape
    _, _, H, V = v.shape
    # Test against None explicitly: `scale or default` would silently discard
    # an explicit scale of 0.0.
    if scale is None:
        scale = K ** (-0.5)
    assert K == V == 128
    assert chunk_size == 64

    if cu_seqlens is None:
        real_batch_size = batch_size
        num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
        # Placeholder tensors: the kernel signature always takes these
        # arguments, and they are left uninitialised here.
        cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
        chunk_offsets = torch.empty(
            (batch_size + 1), dtype=torch.int32, device=k.device
        )
        seqlen_dtype = torch.int32
        is_varlen = False
    else:
        real_batch_size = len(cu_seqlens) - 1
        chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
        chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
        num_chunks = num_chunks if output_h else 0
        seqlen_dtype = cu_seqlens.dtype
        is_varlen = True

    if cp_seq_map is None:
        # Placeholder; the kernel only reads it when is_cp is set.
        cp_seq_map = torch.empty(
            (real_batch_size,), dtype=seqlen_dtype, device=k.device
        )
        is_cp = False
    else:
        is_cp = True

    use_initial_state = initial_state is not None
    if initial_state is None:
        # Placeholder; the kernel clears its state instead of reading this.
        initial_state = torch.empty(
            (real_batch_size, H, K, V), dtype=torch.float32, device=k.device
        )

    # num_chunks is 0 when output_h is false, so h is then an empty tensor.
    h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device)

    if raw_cu_seqlens is None:
        raw_cu_seqlens = torch.empty(
            (real_batch_size + 1,), dtype=seqlen_dtype, device=k.device
        )
        final_state = torch.empty(
            (real_batch_size, H, K, V), dtype=torch.float32, device=k.device
        )
    else:
        final_state = torch.empty(
            (raw_cu_seqlens.shape[0] - 1, H, K, V), dtype=torch.float32, device=k.device
        )
    o = torch.empty_like(v)

    # Use a narrower value tile when one block per (sequence, head) would
    # launch fewer blocks than TARGET_NUM_CTAS.
    grid_size = real_batch_size * H
    if grid_size >= TARGET_NUM_CTAS:
        block_DV = 128
    elif grid_size * 2 >= TARGET_NUM_CTAS:
        block_DV = 64
    else:
        block_DV = 32

    tilelang_fused_chunk_gdr_fwd_kernel = tilelang_fused_chunk_gdr_fwd(
        H,
        Hg,
        K,
        V,
        chunk_size,
        scale,
        qkva_dtype=q.dtype,
        g_dtype=g.dtype,
        b_dtype=b.dtype,
        h0_dtype=initial_state.dtype,
        ht_dtype=final_state.dtype,
        h_dtype=h.dtype,
        o_dtype=o.dtype,
        seqlen_dtype=seqlen_dtype,
        accum_dtype="float32",
        use_initial_state=use_initial_state,
        store_final_state=output_final_state,
        store_h=output_h,
        store_o=output_o,
        is_varlen=is_varlen,
        is_cp=is_cp,
        block_DV=block_DV,
    )
    tilelang_fused_chunk_gdr_fwd_kernel(
        q,
        k,
        v,
        a,
        g,
        b,
        initial_state,
        cu_seqlens,
        chunk_offsets,
        cp_seq_map,
        raw_cu_seqlens,
        o,
        h,
        final_state,
    )

    if not output_final_state:
        final_state = None
    if not output_h:
        h = None
    if not output_o:
        o = None
    return o, h, final_state

View File

@@ -0,0 +1,345 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from typing import Optional
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_indices
@tilelang.jit(
    # out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
        # tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        # tilelang.PassConfigKey.TL_ENABLE_ASYNC_COPY: True,
    },
)
def tilelang_kkt_solve(
    H,
    Hg,
    DK,
    chunk_size,
    accum_dtype,
    qkva_dtype,
    b_dtype,
    seqlen_dtype,
    is_varlen,
):
    """Build the per-chunk kernel that forms ``I + StrictLower(b * K K^T)``
    and writes a 64x64 lower-triangular result to ``a``.

    The result is assembled hierarchically from four 16x16 diagonal blocks,
    then two 32x32 blocks, then the full 64x64 chunk (see the inline
    comments). NOTE(review): the stage comments call this an inversion, so
    the result is presumably the inverse of that matrix -- confirm against
    the reference implementation.

    One thread block of 256 threads handles one (chunk, head) pair:
    threads 0-127 compute, threads 128-159 load K, threads 160-191 store
    full chunks, and threads 192-255 store the partial chunk at the end of
    a sequence.

    Returns the variable-length prim_func (``cu_seqlens`` and
    ``chunk_indices`` arguments) when ``is_varlen`` is true, otherwise the
    fixed-length prim_func, which takes ``num_chunks`` as a scalar.
    """
    data_batch_size = T.dynamic("data_batch_size")
    real_batch_size = T.dynamic("real_batch_size")
    num_tokens = T.dynamic("num_tokens")
    num_chunks = T.dynamic("num_chunks")
    block_S = chunk_size

    k_shape = (data_batch_size, num_tokens, Hg, DK)
    a_shape = (data_batch_size, num_tokens, H, chunk_size)
    b_shape = (data_batch_size, num_tokens, H)

    # Body shared by both kernel variants. `bb` indexes the data batch axis of
    # k/b/a; the chunk covers tokens [left, right) of the sequence
    # [seq_start_idx, seq_end_idx).
    @T.macro
    def kernel_body(
        bb,
        bc,
        bh,
        bhg,
        batch_idx,
        chunk_idx,
        seq_start_idx,
        seq_end_idx,
        k,
        b,
        a,
    ):
        left = seq_start_idx + chunk_idx * block_S
        right = left + block_S

        k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
        b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
        a64_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
        a16i_row = T.alloc_fragment((4, 16), dtype=accum_dtype)
        a16i_sum = T.alloc_fragment((4, 16), dtype=accum_dtype)
        # The 16x16 staging buffers are allocated with 17 rows; only rows
        # 0..15 are written below (presumably padding -- confirm).
        a16i_shared = T.alloc_shared((4, 17, 16), dtype=accum_dtype)
        a16o_shared = T.alloc_shared((2, 17, 16), dtype=accum_dtype)
        a16o_fragment = T.alloc_fragment((2, 16, 16), dtype=accum_dtype)
        a32i_fragment = T.alloc_fragment((2, 32, 32), dtype=accum_dtype)
        a32i0_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
        a32i1_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
        a32o_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
        a32o_fragment = T.alloc_fragment((32, 32), dtype=accum_dtype)
        a64_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)

        T.annotate_layout(
            {
                a16i_shared: tilelang.layout.make_linear_layout(a16i_shared),
                a16o_shared: tilelang.layout.make_linear_layout(a16o_shared),
            }
        )

        # k_is_ready: the 32 K-loading threads arrive.
        # a_is_ready: the 128 compute threads arrive.
        k_is_ready = T.alloc_barrier(arrive_count=32)
        a_is_ready = T.alloc_barrier(arrive_count=128)

        tx = T.get_thread_binding()
        PRODUCER_NREG = 24
        CONSUMER_NREG = 64

        if tx < 128:
            T.set_max_nreg(CONSUMER_NREG, 1)
            # Load b
            # Rows past the end of the sequence get b = 0.
            if right <= seq_end_idx:
                for j_s in T.Parallel(block_S):
                    b_shared[j_s] = b[bb, left + j_s, bh]
            else:
                for j_s in T.Parallel(block_S):
                    if left + j_s < seq_end_idx:
                        b_shared[j_s] = b[bb, left + j_s, bh]
                    else:
                        b_shared[j_s] = 0
            T.barrier_wait(k_is_ready, 0)
            # A = K @ K^T
            T.gemm_v1(
                k_shared, k_shared, a64_fragment, transpose_B=True, clear_accum=True
            )
            # A = b * A
            for j_s, j_t in T.Parallel(block_S, block_S):
                a64_fragment[j_s, j_t] *= b_shared[j_s]
            # A = I + StrictLower(A)
            for j_s, j_t in T.Parallel(block_S, block_S):
                if j_s < j_t:
                    a64_fragment[j_s, j_t] = 0
                elif j_s == j_t:
                    a64_fragment[j_s, j_t] = 1
            # Prepare inversion input
            # - lower-left 32x32 quadrant, negated        -> a32o_shared
            # - 16x16 blocks just below the diagonal that
            #   are outside that quadrant, negated        -> a16o_shared
            # - the four 16x16 diagonal blocks            -> a16i_shared
            for j_s, j_t in T.Parallel(block_S, block_S):
                if j_s >= 32 and j_t < 32:
                    a32o_shared[j_s - 32, j_t] = -a64_fragment[j_s, j_t]
                elif (j_s // 16) == (j_t // 16) + 1:
                    a16o_shared[j_s // 32, j_s % 16, j_t % 16] = -a64_fragment[j_s, j_t]
                elif (j_s // 16) == (j_t // 16):
                    a16i_shared[j_s // 16, j_s % 16, j_t % 16] = a64_fragment[j_s, j_t]
            # Diagonal 4x16x16
            # Row k_s of each diagonal block is rewritten in place from rows
            # 0..k_s-1, which have already been rewritten.
            T.clear(a16i_row)
            for k_s in T.unroll(1, 16):
                for j_s, k_t in T.Parallel(4, 16):
                    if k_t < k_s:
                        a16i_row[j_s, k_t] = a16i_shared[j_s, k_s, k_t]
                T.clear(a16i_sum)
                for k_r in T.unroll(k_s):
                    for j_s, k_t in T.Parallel(4, 16):
                        a16i_sum[j_s, k_t] -= (
                            a16i_shared[j_s, k_r, k_t] * a16i_row[j_s, k_r]
                        )
                for j_s, k_t in T.Parallel(4, 16):
                    if k_t < k_s:
                        a16i_shared[j_s, k_s, k_t] = a16i_sum[j_s, k_t]
            # First level 2x16x16
            # Left-multiply each off-diagonal block by the diagonal block
            # below it (index 2*j_s + 1) ...
            T.clear(a16o_fragment)
            for k_r in T.unroll(16):
                for j_s, k_s, k_t in T.Parallel(2, 16, 16):
                    a16o_fragment[j_s, k_s, k_t] += (
                        a16i_shared[j_s * 2 + 1, k_s, k_r] * a16o_shared[j_s, k_r, k_t]
                    )
            # ... store that product transposed ...
            for j_s, k_s, k_t in T.Parallel(2, 16, 16):
                a16o_shared[j_s, k_t, k_s] = a16o_fragment[j_s, k_s, k_t]
            # ... then right-multiply by the diagonal block above it
            # (index 2*j_s).
            T.clear(a16o_fragment)
            for k_r in T.unroll(16):
                for j_s, k_s, k_t in T.Parallel(2, 16, 16):
                    a16o_fragment[j_s, k_s, k_t] += (
                        a16o_shared[j_s, k_r, k_s] * a16i_shared[j_s * 2, k_r, k_t]
                    )
            T.copy(a16o_fragment, a16o_shared[:, 0:16, 0:16])
            # Second level 1x32x32
            # Assemble the two 32x32 diagonal blocks: zero upper-right,
            # first-level product lower-left, 16x16 results on the diagonal.
            for j_s, k_s, k_t in T.Parallel(2, 32, 32):
                if k_s < 16 and k_t >= 16:
                    a32i_fragment[j_s, k_s, k_t] = 0
            for j_s, k_s, k_t in T.Parallel(2, 32, 32):
                if k_s >= 16 and k_t < 16:
                    a32i_fragment[j_s, k_s, k_t] = a16o_shared[j_s, k_s - 16, k_t]
            for j_s, k_s, k_t in T.Parallel(2, 32, 32):
                if k_s // 16 == k_t // 16:
                    a32i_fragment[j_s, k_s, k_t] = a16i_shared[
                        j_s * 2 + k_s // 16, k_s % 16, k_t % 16
                    ]
            for j_s, k_s, k_t in T.Parallel(2, 32, 32):
                if j_s == 0:
                    a32i0_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
                else:
                    a32i1_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
            # Lower-left quadrant: a32i1 @ (negated quadrant) @ a32i0.
            T.gemm_v1(a32i1_shared, a32o_shared, a32o_fragment, clear_accum=True)
            T.copy(a32o_fragment, a32o_shared)
            T.gemm_v1(a32o_shared, a32i0_shared, a32o_fragment, clear_accum=True)
            # Combine inversion output
            for j_s, k_s, k_t in T.Parallel(2, 32, 32):
                a64_shared[j_s * 32 + k_s, j_s * 32 + k_t] = a32i_fragment[
                    j_s, k_s, k_t
                ]
            for k_s, k_t in T.Parallel(32, 32):
                a64_shared[32 + k_s, k_t] = a32o_fragment[k_s, k_t]
            for k_s, k_t in T.Parallel(32, 32):
                a64_shared[k_s, 32 + k_t] = 0
            T.barrier_arrive(a_is_ready)
        else:
            T.set_max_nreg(PRODUCER_NREG, 0)
            if tx < 128 + 32:
                # Load K
                T.copy(k[bb, left:right, bhg, 0:DK], k_shared)
                T.barrier_arrive(k_is_ready)
            elif tx < 128 + 64:
                T.barrier_wait(a_is_ready, 0)
                # Save A (unmasked)
                if right <= seq_end_idx:
                    T.copy(a64_shared, a[bb, left:right, bh, 0:block_S])
            else:
                T.barrier_wait(a_is_ready, 0)
                # Save A (masked)
                # Element-wise store that skips rows past the end of the
                # sequence.
                if right > seq_end_idx:
                    for j_s, j_t in T.Parallel(block_S, block_S):
                        if left + j_s < seq_end_idx:
                            a[bb, left + j_s, bh, j_t] = a64_shared[j_s, j_t]

    if is_varlen:

        # Variable-length variant: chunk_indices[bc] holds the
        # (sequence index, chunk index within that sequence) pair for block bc.
        @T.prim_func
        def tilelang_kkt_solve_kernel(
            k: T.Tensor(k_shape, dtype=qkva_dtype),
            b: T.Tensor(b_shape, dtype=b_dtype),
            cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
            chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
            a: T.Tensor(a_shape, dtype=qkva_dtype),
        ):
            with T.Kernel(num_chunks * H, threads=256) as (bch,):
                bc, bh = bch // H, bch % H
                bhg = bh // (H // Hg)

                batch_idx = T.alloc_var("int32")
                chunk_idx = T.alloc_var("int32")
                seq_start_idx = T.alloc_var("int32")
                seq_end_idx = T.alloc_var("int32")

                # Data batch axis is always indexed at 0 in this variant.
                bb = 0
                batch_idx = chunk_indices[bc, 0]
                chunk_idx = chunk_indices[bc, 1]
                seq_start_idx = cu_seqlens[batch_idx]
                seq_end_idx = cu_seqlens[batch_idx + 1]

                kernel_body(
                    bb,
                    bc,
                    bh,
                    bhg,
                    batch_idx,
                    chunk_idx,
                    seq_start_idx,
                    seq_end_idx,
                    k,
                    b,
                    a,
                )

    else:

        # Fixed-length variant: the flat chunk index bc is split into a batch
        # index (bc % data_batch_size) and a chunk index (bc // data_batch_size).
        @T.prim_func
        def tilelang_kkt_solve_kernel(
            k: T.Tensor(k_shape, dtype=qkva_dtype),
            b: T.Tensor(b_shape, dtype=b_dtype),
            a: T.Tensor(a_shape, dtype=qkva_dtype),
            num_chunks: T.int32,
        ):
            with T.Kernel(num_chunks * H, threads=256) as (bch,):
                bc, bh = bch // H, bch % H
                bhg = bh // (H // Hg)

                batch_idx = T.alloc_var("int32")
                chunk_idx = T.alloc_var("int32")
                seq_start_idx = T.alloc_var("int32")
                seq_end_idx = T.alloc_var("int32")

                bb = bc % data_batch_size
                batch_idx = bb
                chunk_idx = bc // data_batch_size
                seq_start_idx = 0
                seq_end_idx = num_tokens

                kernel_body(
                    bb,
                    bc,
                    bh,
                    bhg,
                    batch_idx,
                    chunk_idx,
                    seq_start_idx,
                    seq_end_idx,
                    k,
                    b,
                    a,
                )

    return tilelang_kkt_solve_kernel
def kkt_solve(
    k: torch.Tensor,
    b: torch.Tensor,
    chunk_size: int = 64,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Launch the TileLang ``tilelang_kkt_solve`` kernel and return its output.

    Args:
        k: Tensor unpacked as ``(batch_size, num_tokens, Hg, K)``; ``K`` must
            be 128.
        b: Tensor unpacked as ``(_, _, H)``.
        chunk_size: Chunk length along the token axis; only 64 is accepted.
        cu_seqlens: Optional cumulative sequence lengths. When given, the
            variable-length kernel variant is used and chunk indices are
            derived with ``prepare_chunk_indices``.

    Returns:
        A freshly allocated tensor of shape
        ``(batch_size, num_tokens, H, chunk_size)`` with ``k``'s dtype and
        device, filled in by the kernel.
    """
    batch_size, num_tokens, Hg, K = k.shape
    _, _, H = b.shape
    assert K == 128
    assert chunk_size == 64

    is_varlen = cu_seqlens is not None
    if is_varlen:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        seqlen_dtype = cu_seqlens.dtype
    else:
        # Dense layout: every batch row contributes the same number of chunks.
        num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
        seqlen_dtype = "int32"

    out_shape = (batch_size, num_tokens, H, chunk_size)
    a = torch.empty(out_shape, dtype=k.dtype, device=k.device)

    kernel = tilelang_kkt_solve(
        H,
        Hg,
        K,
        chunk_size,
        qkva_dtype=k.dtype,
        b_dtype=b.dtype,
        seqlen_dtype=seqlen_dtype,
        accum_dtype="float32",
        is_varlen=is_varlen,
    )

    # The two kernel variants take different argument lists.
    if is_varlen:
        kernel(k, b, cu_seqlens, chunk_indices, a)
    else:
        kernel(k, b, a, num_chunks)
    return a

View File

@@ -0,0 +1,558 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_offsets
# Build (and JIT-compile through tilelang.jit, with fast-math enabled) the
# fused kernel that walks the chunks of each sequence and maintains one
# DK x DV state per (batch, head).
#
# Compile-time arguments:
#   H, Hg             - head count of v/a/g/b and of k; head bh reads k head
#                       bh // (H // Hg).
#   DK, DV            - key / value head dimensions.
#   chunk_size        - tokens per chunk (used as block_S).
#   accum_dtype       - dtype of the register fragments and of the g / beta
#                       staging buffers.
#   qkva_dtype, g_dtype, b_dtype, h0_dtype, ht_dtype, h_dtype, seqlen_dtype
#                     - element dtypes of the corresponding kernel tensors.
#   use_initial_state - start from h0 instead of zeros.
#   store_final_state - write the last state to ht.
#   store_h           - write the state seen at the start of every chunk to h.
#   is_varlen         - sequences are packed along the token axis and
#                       delimited by cu_seqlens / chunk_offsets.
#   is_cp             - iterate only the last num_warmup_chunks[b, h] chunks
#                       and, when that covers the whole sequence, also
#                       produce mt.
#   num_stages        - depth of the shared-memory ring used for loaded chunks.
#
# Returns the T.prim_func, whose arguments are
#   (k, v, a, g, b, h0, cu_seqlens, chunk_offsets, num_warmup_chunks, h, ht, mt).
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def tilelang_prepare_h(
H,
Hg,
DK,
DV,
chunk_size,
accum_dtype,
qkva_dtype,
g_dtype,
b_dtype,
h0_dtype,
ht_dtype,
h_dtype,
seqlen_dtype,
use_initial_state,
store_final_state,
store_h,
is_varlen,
is_cp,
num_stages=2,
):
# Symbolic (runtime) extents shared by the tensor shapes below.
batch_size = T.dynamic("batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
# Varlen inputs use a leading batch dimension of 1 (all sequences packed
# along the token axis); dense inputs keep a real batch dimension.
if is_varlen:
k_shape = (1, num_tokens, Hg, DK)
v_shape = (1, num_tokens, H, DV)
a_shape = (1, num_tokens, H, chunk_size)
g_shape = (1, num_tokens, H)
b_shape = (1, num_tokens, H)
h_shape = (1, num_chunks, H, DK, DV)
else:
k_shape = (batch_size, num_tokens, Hg, DK)
v_shape = (batch_size, num_tokens, H, DV)
a_shape = (batch_size, num_tokens, H, chunk_size)
g_shape = (batch_size, num_tokens, H)
b_shape = (batch_size, num_tokens, H)
h_shape = (batch_size, num_chunks, H, DK, DV)
# Per-sequence state tensors are always indexed by batch_size.
h0_shape = (batch_size, H, DK, DV)
ht_shape = (batch_size, H, DK, DV)
m_shape = (batch_size, H, DK, DK)
@T.prim_func
def tilelang_prepare_h_kernel(
k: T.Tensor(k_shape, dtype=qkva_dtype),
v: T.Tensor(v_shape, dtype=qkva_dtype),
a: T.Tensor(a_shape, dtype=qkva_dtype),
g: T.Tensor(g_shape, dtype=g_dtype),
b: T.Tensor(b_shape, dtype=b_dtype),
h0: T.Tensor(h0_shape, dtype=h0_dtype),
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
num_warmup_chunks: T.Tensor([batch_size, H], dtype=seqlen_dtype),
h: T.Tensor(h_shape, dtype=h_dtype),
ht: T.Tensor(ht_shape, dtype=ht_dtype),
mt: T.Tensor(m_shape, dtype=ht_dtype),
):
# One 512-thread block per (sequence, head) pair.
with T.Kernel(batch_size * H, threads=512) as (bbh,):
bb, bh = bbh // H, bbh % H
bhg = bh // (H // Hg)
batch_idx = T.alloc_var("int32")
seq_start_idx = T.alloc_var("int32")
seq_end_idx = T.alloc_var("int32")
# NOTE(review): _seq_split_idx and _chunk_split_idx are allocated but not
# referenced anywhere else in this kernel.
_seq_split_idx = T.alloc_var("int32")
chunk_start_idx = T.alloc_var("int32")
_chunk_split_idx = T.alloc_var("int32")
# is_varlen is a Python bool, so these conditionals are resolved when the
# kernel is traced: cu_seqlens / chunk_offsets are only read for varlen.
batch_idx = 0 if is_varlen else bb
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
# Number of chunks to process: every chunk of the sequence, or (is_cp) the
# per-(sequence, head) count supplied in num_warmup_chunks.
num_iters = T.alloc_var("int32")
num_iters = (
num_warmup_chunks[bb, bh]
if is_cp
else T.ceildiv(seq_end_idx - seq_start_idx, block_S)
)
# mt is only computed when is_cp and the iteration count covers the whole
# sequence.
calc_mt = T.alloc_var("bool")
calc_mt = is_cp and num_iters >= T.ceildiv(
seq_end_idx - seq_start_idx, block_S
)
# With is_cp the loop covers the LAST num_iters chunks, so the start is
# moved to seq_end_idx - num_iters * block_S.
seq_start_idx = (
seq_end_idx - num_iters * block_S if is_cp else seq_start_idx
)
# Ring buffers (num_stages deep) filled by the loader threads.
k_shared = T.alloc_shared((num_stages, block_S, DK), dtype=qkva_dtype)
v_shared = T.alloc_shared((num_stages, block_S, DV), dtype=qkva_dtype)
a_shared = T.alloc_shared((num_stages, block_S, block_S), dtype=qkva_dtype)
g_shared = T.alloc_shared(
(num_stages, block_S), dtype=accum_dtype, scope="shared"
)
b_shared = T.alloc_shared(
(num_stages, block_S), dtype=accum_dtype, scope="shared"
)
# Single-buffered hand-off tiles exchanged between the consumer groups.
h_shared = T.alloc_shared((DK, DV), dtype=qkva_dtype)
x_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
y_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
# M (DK x DK) is split into a left and a right DK x DK/2 half; the Y group
# owns the left half and the X group the right half.
m_shared_L = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
m_shared_R = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
z_shared_L = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
z_shared_R = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
g_rev_exp_shared = T.alloc_shared(
(block_S), dtype=accum_dtype, scope="shared"
)
# Accumulators kept in accum_dtype.
h_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
x_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
y_fragment = T.alloc_fragment((block_S, DV), dtype=accum_dtype)
m_fragment_L = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
m_fragment_R = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
z_fragment_L = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
z_fragment_R = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
g_last_local_S = T.alloc_local((1), dtype=accum_dtype)
g_last_local_X = T.alloc_local((1), dtype=accum_dtype)
g_last_local_Y = T.alloc_local((1), dtype=accum_dtype)
g_prod_X = T.alloc_fragment((1), dtype=accum_dtype)
g_prod_Y = T.alloc_fragment((1), dtype=accum_dtype)
# Barrier arrival counts match the threads that arrive on each barrier:
#   data_is_ready: 96  = the three 32-thread loader groups (K, V+A, g+beta)
#   data_is_free : 384 = the three 128-thread consumer groups (S, X, Y)
#   bar_0        : 416 = S + X + Y groups plus the 32-thread h-store group
#   bar_1        : 256 = S + Y groups
#   bar_2        : 384 = S + X + Y groups
#   bar_3        : 128 = S group
data_is_ready = T.alloc_barrier(arrive_count=[96] * num_stages)
data_is_free = T.alloc_barrier(arrive_count=[384] * num_stages)
bar_0 = T.alloc_barrier(arrive_count=416)
bar_1 = T.alloc_barrier(arrive_count=256)
bar_2 = T.alloc_barrier(arrive_count=384)
bar_3 = T.alloc_barrier(arrive_count=128)
T.use_swizzle(10)
tx = T.get_thread_binding()
# Register budgets requested per role through T.set_max_nreg.
PRODUCER_NREG = 24
CONSUMER_S_NREG = 168
CONSUMER_X_NREG = 160
CONSUMER_Y_NREG = 160
# Thread roles by tx:
#   [0, 128)   S group - owns the running state h_fragment
#   [128, 256) X group - builds X and the right half of M
#   [256, 384) Y group - builds Y and the left half of M
#   [384, 512) four 32-thread groups: load K, load V + A, load g + beta,
#              store h
if tx < 128:
T.set_max_nreg(CONSUMER_S_NREG, 1)
# Initialize S
if use_initial_state:
T.copy(h0[bb, bh, 0:DK, 0:DV], h_fragment)
else:
T.clear(h_fragment)
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
# Wait until the loaders have filled this ring slot; the expected parity
# flips each time the ring wraps around.
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# S4[1] S
# Publish the state as it is BEFORE this chunk is applied; the Y group and
# the h-store group read it from h_shared.
T.copy(h_fragment, h_shared)
T.barrier_arrive(bar_1)
# [STAGE = i_s % num_stages] 1
T.barrier_wait(bar_1, i_s % 2)
# S = g_last * S
# exp2(x * 1.442695) == exp(x), since 1.442695 ~= log2(e).
g_last_local_S[0] = T.exp2(
g_shared[i_s % num_stages, block_S - 1] * 1.442695
)
for j_k, j_v in T.Parallel(DK, DV):
h_fragment[j_k, j_v] *= g_last_local_S[0]
T.barrier_arrive(bar_2)
# [STAGE = i_s % num_stages] 2
# bar_2 completes once the X and Y groups have written x_shared / y_shared.
T.barrier_wait(bar_2, i_s % 2)
# S += X^T @ Y
T.gemm_v1(
x_shared,
y_shared,
h_fragment,
transpose_A=True,
clear_accum=False,
)
# bar_3 tells the X / Y groups (calc_mt path) that this GEMM is done;
# data_is_free hands the ring slot back to the loaders.
T.barrier_arrive(bar_3)
T.barrier_arrive(data_is_free[i_s % num_stages])
# Store final S
if store_final_state:
T.copy(h_fragment, ht[bb, bh, 0:DK, 0:DV])
elif tx < 256:
T.set_max_nreg(CONSUMER_X_NREG, 1)
if calc_mt:
# Start the right half of M as the right half of a DK x DK identity.
for j_k, j_v in T.Parallel(DK, DK // 2):
if j_k == j_v + DK // 2:
m_fragment_R[j_k, j_v] = 1
else:
m_fragment_R[j_k, j_v] = 0
g_prod_X[0] = 0
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# X = A^T @ K
T.gemm_v1(
a_shared[i_s % num_stages, :, :],
k_shared[i_s % num_stages, :, :],
x_fragment,
transpose_A=True,
clear_accum=True,
)
# [STAGE = i_s % num_stages] 1
# X = - b * X
for j_s, j_k in T.Parallel(block_S, DK):
x_fragment[j_s, j_k] *= -b_shared[i_s % num_stages, j_s]
# S2[1] X
T.copy(x_fragment, x_shared)
T.barrier_arrive(bar_2)
if calc_mt:
# [STAGE = i_s % num_stages] 2
# Accumulate the chunk-final g values; the sum is exponentiated once
# after the loop.
g_prod_X[0] += g_shared[i_s % num_stages, block_S - 1]
# S4[2] M
T.copy(m_fragment_R, m_shared_R)
# [STAGE = i_s % num_stages] 3
T.barrier_wait(bar_3, i_s % 2)
# Z = K @ M
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
m_shared_R,
z_fragment_R,
clear_accum=True,
)
# S4[2] Z
T.copy(z_fragment_R, z_shared_R)
# M += X^T @ Z
T.gemm_v1(
x_shared,
z_shared_R,
m_fragment_R,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % num_stages])
if calc_mt:
# Scale by exp(sum of chunk-final g) and write columns [DK // 2, DK) of mt.
g_last_local_X[0] = T.exp2(g_prod_X[0] * 1.442695)
for j_k, j_v in T.Parallel(DK, DK // 2):
m_fragment_R[j_k, j_v] *= g_last_local_X[0]
T.copy(m_fragment_R, mt[bb, bh, 0:DK, DK // 2 :])
elif tx < 384:
T.set_max_nreg(CONSUMER_Y_NREG, 1)
if calc_mt:
# Start the left half of M as the left half of a DK x DK identity.
for j_k, j_v in T.Parallel(DK, DK // 2):
if j_k == j_v:
m_fragment_L[j_k, j_v] = 1
else:
m_fragment_L[j_k, j_v] = 0
g_prod_Y[0] = 0
# Main Loop
for i_s in T.serial(num_iters):
# [STAGE = i_s % num_stages]
T.barrier_wait(
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
)
T.barrier_arrive(bar_0)
# [STAGE = i_s % num_stages] 0
T.barrier_wait(bar_0, i_s % 2)
# Precompute g_last/g
# g_rev_exp_shared[j] = exp(g_last - g[j]); g_last_local_Y then becomes
# exp(g_last).
g_last_local_Y[0] = g_shared[i_s % num_stages, block_S - 1]
for j_s in T.Parallel(block_S):
g_rev_exp_shared[j_s] = T.exp2(
(g_last_local_Y[0] - g_shared[i_s % num_stages, j_s])
* 1.442695
)
g_last_local_Y[0] = T.exp2(g_last_local_Y[0] * 1.442695)
T.barrier_arrive(bar_1)
# [STAGE = i_s % num_stages] 1
# After bar_1 the S group has published the pre-update state in h_shared.
T.barrier_wait(bar_1, i_s % 2)
# U = K @ S
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
h_shared,
y_fragment,
clear_accum=True,
)
# Y = g_last * U - g_last/g * V
for j_s, j_v in T.Parallel(block_S, DV):
y_fragment[j_s, j_v] *= g_last_local_Y[0]
for j_s, j_v in T.Parallel(block_S, DV):
y_fragment[j_s, j_v] -= (
v_shared[i_s % num_stages, j_s, j_v] * g_rev_exp_shared[j_s]
)
# S2[2] Y
T.copy(y_fragment, y_shared)
T.barrier_arrive(bar_2)
if calc_mt:
# [STAGE = i_s % num_stages] 2
g_prod_Y[0] += g_shared[i_s % num_stages, block_S - 1]
# S4[2] M
T.copy(m_fragment_L, m_shared_L)
# [STAGE = i_s % num_stages] 3
T.barrier_wait(bar_3, i_s % 2)
# Z = K @ M
T.gemm_v1(
k_shared[i_s % num_stages, :, :],
m_shared_L,
z_fragment_L,
clear_accum=True,
)
# S4[2] Z
T.copy(z_fragment_L, z_shared_L)
# M += X^T @ Z
T.gemm_v1(
x_shared,
z_shared_L,
m_fragment_L,
transpose_A=True,
clear_accum=False,
)
T.barrier_arrive(data_is_free[i_s % num_stages])
if calc_mt:
# Scale by exp(sum of chunk-final g) and write columns [0, DK // 2) of mt.
g_last_local_Y[0] = T.exp2(g_prod_Y[0] * 1.442695)
for j_k, j_v in T.Parallel(DK, DK // 2):
m_fragment_L[j_k, j_v] *= g_last_local_Y[0]
T.copy(m_fragment_L, mt[bb, bh, 0:DK, : DK // 2])
else:
T.set_max_nreg(PRODUCER_NREG, 0)
# Loader groups wait on data_is_free with the opposite parity of the
# consumers' data_is_ready wait - presumably so the first pass over the
# ring does not block before any consumer has released a slot (TODO
# confirm against the barrier semantics).
if tx < 384 + 32:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load K
T.copy(
k[batch_idx, left:right, bhg, 0:DK],
k_shared[i_s % num_stages, :, :],
)
T.barrier_arrive(data_is_ready[i_s % num_stages])
elif tx < 384 + 64:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load V
T.copy(
v[batch_idx, left:right, bh, 0:DV],
v_shared[i_s % num_stages, :, :],
)
# Load A TODO: Mask A for the last chunk
T.copy(
a[batch_idx, left:right, bh, 0:block_S],
a_shared[i_s % num_stages, :, :],
)
T.barrier_arrive(data_is_ready[i_s % num_stages])
elif tx < 384 + 96:
for i_s in T.serial(num_iters):
T.barrier_wait(
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
)
left = seq_start_idx + i_s * block_S
right = left + block_S
# Load gamma
# Rows past the end of the sequence repeat the value of the last valid
# token, so g_shared[..., block_S - 1] still holds the chunk-final g.
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
g_shared[i_s % num_stages, j_s] = g[
batch_idx, left + j_s, bh
]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
g_shared[i_s % num_stages, j_s] = g[
batch_idx, left + j_s, bh
]
else:
g_shared[i_s % num_stages, j_s] = g[
batch_idx, seq_end_idx - 1, bh
]
# Load beta
# Rows past the end of the sequence get beta = 0, which zeroes their rows
# of X (X is scaled by -beta).
if right <= seq_end_idx:
for j_s in T.Parallel(block_S):
b_shared[i_s % num_stages, j_s] = b[
batch_idx, left + j_s, bh
]
else:
for j_s in T.Parallel(block_S):
if left + j_s < seq_end_idx:
b_shared[i_s % num_stages, j_s] = b[
batch_idx, left + j_s, bh
]
else:
b_shared[i_s % num_stages, j_s] = 0
T.barrier_arrive(data_is_ready[i_s % num_stages])
else:
# Last 32 threads: take part in bar_0, then wait for bar_1 (h_shared
# published by the S group) and optionally copy it to global memory.
for i_s in T.serial(num_iters):
T.barrier_arrive(bar_0)
T.barrier_wait(bar_0, i_s % 2)
T.barrier_wait(bar_1, i_s % 2)
# Store S
if store_h:
T.copy(
h_shared,
h[batch_idx, chunk_start_idx + i_s, bh, 0:DK, 0:DV],
)
return tilelang_prepare_h_kernel
def fused_gdr_h(
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    g: torch.Tensor,
    b: torch.Tensor,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = True,
    output_h: bool = True,
    chunk_size: int = 64,
    cu_seqlens: torch.LongTensor | None = None,
    num_warmup_chunks: torch.LongTensor | None = None,
):
    """Allocate buffers for, and launch, the ``tilelang_prepare_h`` kernel.

    Args:
        k: Unpacked as ``(batch_size, num_tokens, Hg, K)``; ``K`` must be 128.
        v: Unpacked as ``(_, _, H, V)``; ``V`` must be 128.
        a, g, b: Forwarded to the kernel unchanged.
        initial_state: Optional starting state. When omitted, an
            uninitialised float32 placeholder is passed and the kernel is
            built with ``use_initial_state=False``.
        output_final_state: Return the final state (and correction) instead
            of ``None``.
        output_h: Return the per-chunk states instead of ``None``.
        chunk_size: Only 64 is accepted.
        cu_seqlens: Cumulative sequence lengths; selects the variable-length
            kernel variant when given.
        num_warmup_chunks: Only allowed together with ``cu_seqlens``; when
            given the kernel is built with ``is_cp=True``.

    Returns:
        ``(h, final_state, final_correction)``. ``h`` is ``None`` unless
        ``output_h``; the other two are ``None`` unless
        ``output_final_state``. The state buffers are float32, or ``k``'s
        dtype when ``is_cp`` is active.
    """
    batch_size, num_tokens, Hg, K = k.shape
    _, _, H, V = v.shape
    assert K == V == 128
    assert chunk_size == 64
    device = k.device

    is_varlen = cu_seqlens is not None
    if is_varlen:
        real_batch_size = len(cu_seqlens) - 1
        chunk_offsets, total_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
        chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
        num_chunks = total_chunks if output_h else 0
        is_cp = num_warmup_chunks is not None
        if not is_cp:
            # Placeholder only: the kernel does not read it when is_cp is False.
            num_warmup_chunks = torch.empty(
                (real_batch_size, H), dtype=cu_seqlens.dtype, device=device
            )
    else:
        assert num_warmup_chunks is None
        real_batch_size = batch_size
        is_cp = False
        num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
        # Placeholders that only satisfy the kernel signature in the dense path.
        cu_seqlens = torch.empty(batch_size + 1, dtype=torch.int32, device=device)
        chunk_offsets = torch.empty(batch_size + 1, dtype=torch.int32, device=device)

    use_initial_state = initial_state is not None
    if not use_initial_state:
        initial_state = torch.empty(
            (real_batch_size, H, K, V), dtype=torch.float32, device=device
        )

    h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=device)
    ht_dtype = k.dtype if is_cp else torch.float32
    final_state = torch.empty(
        (real_batch_size, H, K, V), dtype=ht_dtype, device=device
    )
    final_correction = torch.empty(
        (real_batch_size, H, K, K), dtype=ht_dtype, device=device
    )

    kernel = tilelang_prepare_h(
        H,
        Hg,
        K,
        V,
        chunk_size,
        qkva_dtype=k.dtype,
        g_dtype=g.dtype,
        b_dtype=b.dtype,
        h0_dtype=initial_state.dtype,
        ht_dtype=final_state.dtype,
        h_dtype=h.dtype,
        seqlen_dtype=cu_seqlens.dtype,
        accum_dtype="float32",
        use_initial_state=use_initial_state,
        store_final_state=output_final_state,
        store_h=output_h,
        is_varlen=is_varlen,
        is_cp=is_cp,
    )
    kernel(
        k,
        v,
        a,
        g,
        b,
        initial_state,
        cu_seqlens,
        chunk_offsets,
        num_warmup_chunks,
        h,
        final_state,
        final_correction,
    )

    # Hide buffers the caller did not ask for.
    return (
        h if output_h else None,
        final_state if output_final_state else None,
        final_correction if output_final_state else None,
    )

View File

@@ -0,0 +1,6 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .sm_legacy import chunk_gated_delta_rule_fwd_legacy
__all__ = ["chunk_gated_delta_rule_fwd_legacy"]

View File

@@ -0,0 +1,348 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>
namespace {
// Throws std::runtime_error("<context>: <CUDA error string>") unless `status`
// is cudaSuccess.
void check_cuda(cudaError_t status, const char* context) {
  if (status == cudaSuccess) {
    return;
  }
  std::string message(context);
  message += ": ";
  message += cudaGetErrorString(status);
  throw std::runtime_error(message);
}
// Adds `value` across each `width`-lane subgroup of the warp using
// __shfl_down_sync with halving offsets. Only lane 0 of every subgroup is
// guaranteed to end up with the complete sum.
__device__ __forceinline__ float subgroup_sum_lane0(float value,
                                                    int width) {
  constexpr unsigned kFullWarpMask = 0xffffffffU;
  int stride = width / 2;
  while (stride > 0) {
    value += __shfl_down_sync(kFullWarpMask, value, stride, width);
    stride >>= 1;
  }
  return value;
}
// Returns, in every lane, the `value` held by lane 0 of that lane's
// `width`-lane subgroup.
__device__ __forceinline__ float subgroup_broadcast_lane0(float value,
                                                          int width) {
  constexpr unsigned kFullWarpMask = 0xffffffffU;
  constexpr int kSourceLane = 0;
  return __shfl_sync(kFullWarpMask, value, kSourceLane, width);
}
// Token-by-token gated delta rule forward pass in float32.
//
// Work split (see launch_gdn_forward for the launch shape):
//   blockIdx.x = value head (hv), blockIdx.y = batch index (b);
//   (blockIdx.z, threadIdx.y, threadIdx.x / WIDTH) selects a subgroup of WIDTH
//   lanes that owns COLS consecutive state columns starting at col_base.
//   Inside a subgroup, lane L keeps rows L, L + WIDTH, ... of those columns in
//   registers (state_shard) for the whole sequence.
//
// Per token t, for every owned column c:
//   kv    = sum_row state[c][row] * k[row]
//   delta = (v[c] - exp(gate) * kv) * beta
//   state[c][row] = exp(gate) * state[c][row] + k[row] * delta
//   output[c]     = scale * sum_row state[c][row] * q[row]
//
// Memory layout read/written here:
//   q, k          : [b][t][hq][row]
//   v, output     : [b][t][hv][col]
//   gate, beta    : [b][t][hv]
//   initial_state,
//   final_state   : [b][hv][col][row]
// `initial_state` may be nullptr (the state then starts at zero). The `batch`
// argument is not read inside the kernel.
template <int D, int COLS, int WIDTH>
__global__ void gdn_forward_kernel(const float* __restrict__ q,
const float* __restrict__ k,
const float* __restrict__ v,
const float* __restrict__ gate,
const float* __restrict__ beta,
const float* __restrict__ initial_state,
float* __restrict__ output,
float* __restrict__ final_state,
int batch,
int tokens,
int q_heads,
int v_heads,
float scale) {
// The COLS * (32 / WIDTH) columns covered by one warp must tile D exactly.
static_assert(D % (COLS * (32 / WIDTH)) == 0);
constexpr int subgroups_per_warp = 32 / WIDTH;
// ceil(D / WIDTH): number of state rows each lane holds per column.
constexpr int rows_per_lane = (D + WIDTH - 1) / WIDTH;
const int hv = blockIdx.x;
const int b = blockIdx.y;
const int subgroup = threadIdx.x / WIDTH;
const int lane = threadIdx.x % WIDTH;
// Global subgroup index across the z-blocks and the warps (threadIdx.y) of a block.
const int group_base =
(blockIdx.z * blockDim.y + threadIdx.y) * subgroups_per_warp + subgroup;
const int col_base = group_base * COLS;
// Query/key head shared by v_heads / q_heads consecutive value heads.
const int hq = hv / (v_heads / q_heads);
// Register-resident slice of the state: [owned column][row slot of this lane].
float state_shard[COLS][rows_per_lane];
// Load the initial state, or zeros when no initial state was supplied.
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const int col = col_base + c;
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
float value = 0.0F;
// Lanes whose row falls outside D (possible when WIDTH > D) keep 0.
if (row < D) {
const auto state_index =
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
value = initial_state == nullptr ? 0.0F : initial_state[state_index];
}
state_shard[c][r] = value;
}
}
// Sequential recurrence over the tokens of this (batch, head).
for (int t = 0; t < tokens; ++t) {
const auto gate_index =
((static_cast<int64_t>(b) * tokens + t) * v_heads + hv);
float gate_value = 0.0F;
float beta_value = 0.0F;
// Lane 0 of each warp loads the per-token scalars (the gate is
// exponentiated here) and broadcasts them to the rest of the warp.
if (threadIdx.x == 0) {
gate_value = __expf(gate[gate_index]);
beta_value = beta[gate_index];
}
gate_value = __shfl_sync(0xffffffffU, gate_value, 0);
beta_value = __shfl_sync(0xffffffffU, beta_value, 0);
float k_reg[rows_per_lane];
float q_reg[rows_per_lane];
float kv_partial[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
kv_partial[c] = 0.0F;
}
// Load this lane's rows of q and k and accumulate its share of state^T k.
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
float q_value = 0.0F;
float k_value = 0.0F;
if (row < D) {
const auto qk_index =
(((static_cast<int64_t>(b) * tokens + t) * q_heads + hq) * D) + row;
q_value = q[qk_index];
k_value = k[qk_index];
}
q_reg[r] = q_value;
k_reg[r] = k_value;
#pragma unroll
for (int c = 0; c < COLS; ++c) {
kv_partial[c] += state_shard[c][r] * k_value;
}
}
float delta[COLS];
// Reduce the partial dot products inside the subgroup; lane 0 forms the
// delta for each column and broadcasts it back to the subgroup.
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const float kv_col = subgroup_sum_lane0(kv_partial[c], WIDTH);
float delta_value = 0.0F;
if (lane == 0) {
const auto v_index =
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D) +
col_base + c;
delta_value = (v[v_index] - gate_value * kv_col) * beta_value;
}
delta[c] = subgroup_broadcast_lane0(delta_value, WIDTH);
}
float attn_partial[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
attn_partial[c] = 0.0F;
}
// Update the state in place and accumulate this lane's share of state^T q
// using the UPDATED state.
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const float new_state =
fmaf(k_reg[r], delta[c], gate_value * state_shard[c][r]);
state_shard[c][r] = new_state;
attn_partial[c] += new_state * q_reg[r];
}
}
#pragma unroll
for (int c = 0; c < COLS; ++c) {
attn_partial[c] = subgroup_sum_lane0(attn_partial[c], WIDTH);
}
// Only lane 0 of the subgroup holds the full sums, so it writes the outputs.
if (lane == 0) {
const auto out_base =
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D);
#pragma unroll
for (int c = 0; c < COLS; ++c) {
output[out_base + col_base + c] = attn_partial[c] * scale;
}
}
}
// Write the final state back using the same [b][hv][col][row] layout.
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const int col = col_base + c;
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
if (row < D) {
const auto state_index =
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
final_state[state_index] = state_shard[c][r];
}
}
}
}
template <int D>
void launch_gdn_forward(const float* q,
const float* k,
const float* v,
const float* gate,
const float* beta,
const float* initial_state,
float* output,
float* final_state,
int batch,
int tokens,
int q_heads,
int v_heads,
float scale,
cudaStream_t stream) {
constexpr int cols = D == 128 ? 4 : 1;
constexpr int width = D == 128 ? 16 : 32;
constexpr int groups_per_warp = 32 / width;
constexpr int column_groups_per_block = 8;
const dim3 block(32, column_groups_per_block);
const int groups = D / cols;
const int z = (groups + column_groups_per_block * groups_per_warp - 1) /
(column_groups_per_block * groups_per_warp);
const dim3 grid(v_heads, batch, z);
gdn_forward_kernel<D, cols, width>
<<<grid, block, 0, stream>>>(q,
k,
v,
gate,
beta,
initial_state,
output,
final_state,
batch,
tokens,
q_heads,
v_heads,
scale);
}
// Rejects (via TORCH_CHECK) any tensor that is not a contiguous float32 CUDA
// tensor of rank `dims`. `name` prefixes the error message. The checks run in
// the order device, dtype, contiguity, rank.
void validate_tensor(const torch::Tensor& tensor,
                     const char* name,
                     int64_t dims) {
  const bool on_cuda = tensor.is_cuda();
  TORCH_CHECK(on_cuda, name, " must be a CUDA tensor");

  const bool is_fp32 = tensor.scalar_type() == torch::kFloat32;
  TORCH_CHECK(is_fp32, name, " must be float32");

  const bool is_dense = tensor.is_contiguous();
  TORCH_CHECK(is_dense, name, " must be contiguous");

  const bool rank_ok = tensor.dim() == dims;
  TORCH_CHECK(rank_ok, name, " has wrong rank");
}
} // namespace
// Host entry point of the legacy GDN forward pass.
//
// Inputs (all contiguous float32 CUDA tensors on one device):
//   q, k          : [B, T, Hq, D]
//   v             : [B, T, Hv, D]
//   gate, beta    : [B, T, Hv]   (the kernel exponentiates `gate`)
//   initial_state : optional [B, Hv, D, D]
//   scale         : multiplier applied to the output
// D must be 16, 32, 64 or 128 and Hv must be a multiple of Hq.
//
// Returns {output [B, T, Hv, D], final_state [B, Hv, D, D]}.
//
// The kernel is enqueued on the current CUDA stream of q's device. Only the
// launch itself is checked for errors; the call does not synchronize.
std::vector<torch::Tensor> gdn_forward(torch::Tensor q,
                                       torch::Tensor k,
                                       torch::Tensor v,
                                       torch::Tensor gate,
                                       torch::Tensor beta,
                                       c10::optional<torch::Tensor> initial_state,
                                       double scale) {
  validate_tensor(q, "q", 4);
  validate_tensor(k, "k", 4);
  validate_tensor(v, "v", 4);
  validate_tensor(gate, "gate", 3);
  validate_tensor(beta, "beta", 3);
  TORCH_CHECK(q.sizes() == k.sizes(), "q and k must have the same shape");

  // Raw pointers from different GPUs must never be mixed in a single launch.
  const auto device = q.device();
  TORCH_CHECK(k.device() == device && v.device() == device &&
                  gate.device() == device && beta.device() == device,
              "q, k, v, gate, and beta must be on the same CUDA device");

  const int batch = static_cast<int>(q.size(0));
  const int tokens = static_cast<int>(q.size(1));
  const int q_heads = static_cast<int>(q.size(2));
  const int dim = static_cast<int>(q.size(3));
  const int v_heads = static_cast<int>(v.size(2));
  TORCH_CHECK(v.size(0) == batch && v.size(1) == tokens && v.size(3) == dim,
              "v must have shape [B, T, Hv, D] matching q/k");
  TORCH_CHECK(gate.size(0) == batch && gate.size(1) == tokens &&
                  gate.size(2) == v_heads,
              "gate must have shape [B, T, Hv]");
  TORCH_CHECK(beta.sizes() == gate.sizes(),
              "beta must have the same shape as gate");
  // q_heads > 0 is tested first so the modulo cannot divide by zero.
  TORCH_CHECK(q_heads > 0 && v_heads % q_heads == 0,
              "Hv must be divisible by Hq");
  TORCH_CHECK(dim == 16 || dim == 32 || dim == 64 || dim == 128,
              "D must be one of 16, 32, 64, or 128");

  const float* initial_ptr = nullptr;
  if (initial_state.has_value() && initial_state.value().defined()) {
    const auto& h0 = initial_state.value();
    validate_tensor(h0, "initial_state", 4);
    TORCH_CHECK(h0.device() == device,
                "initial_state must be on the same CUDA device as q");
    TORCH_CHECK(h0.size(0) == batch && h0.size(1) == v_heads &&
                    h0.size(2) == dim && h0.size(3) == dim,
                "initial_state must have shape [B, Hv, D, D]");
    initial_ptr = h0.data_ptr<float>();
  }

  // Make q's device current so the launch and the stream fetched below refer
  // to the same GPU even when the caller's current device is a different one.
  const c10::cuda::CUDAGuard device_guard(device);

  auto output = torch::empty_like(v);
  auto final_state = torch::empty({batch, v_heads, dim, dim}, q.options());

  // A grid with a zero-sized dimension is an invalid launch configuration.
  // With B == 0 or Hv == 0 both outputs are empty, so nothing has to run.
  // (T == 0 still launches: the kernel then copies the initial state, or
  // zeros, into final_state.)
  if (batch == 0 || v_heads == 0) {
    return {output, final_state};
  }

  const auto stream = at::cuda::getCurrentCUDAStream(device.index()).stream();

  const float* q_ptr = q.data_ptr<float>();
  const float* k_ptr = k.data_ptr<float>();
  const float* v_ptr = v.data_ptr<float>();
  const float* gate_ptr = gate.data_ptr<float>();
  const float* beta_ptr = beta.data_ptr<float>();
  float* output_ptr = output.data_ptr<float>();
  float* final_state_ptr = final_state.data_ptr<float>();
  const float scale_f = static_cast<float>(scale);

  // The head dimension is a template parameter, so dispatch on its value.
  switch (dim) {
    case 16:
      launch_gdn_forward<16>(q_ptr, k_ptr, v_ptr, gate_ptr, beta_ptr,
                             initial_ptr, output_ptr, final_state_ptr, batch,
                             tokens, q_heads, v_heads, scale_f, stream);
      break;
    case 32:
      launch_gdn_forward<32>(q_ptr, k_ptr, v_ptr, gate_ptr, beta_ptr,
                             initial_ptr, output_ptr, final_state_ptr, batch,
                             tokens, q_heads, v_heads, scale_f, stream);
      break;
    case 64:
      launch_gdn_forward<64>(q_ptr, k_ptr, v_ptr, gate_ptr, beta_ptr,
                             initial_ptr, output_ptr, final_state_ptr, batch,
                             tokens, q_heads, v_heads, scale_f, stream);
      break;
    case 128:
      launch_gdn_forward<128>(q_ptr, k_ptr, v_ptr, gate_ptr, beta_ptr,
                              initial_ptr, output_ptr, final_state_ptr, batch,
                              tokens, q_heads, v_heads, scale_f, stream);
      break;
    default:
      // Unreachable after the TORCH_CHECK on `dim` above; kept as a safeguard
      // so an unsupported size can never return uninitialised outputs.
      TORCH_CHECK(false, "D must be one of 16, 32, 64, or 128");
  }
  check_cuda(cudaGetLastError(), "gdn_forward launch");
  return {output, final_state};
}
// Python bindings: registers gdn_forward as `gdn_forward` on the extension
// module, whose name is supplied through the TORCH_EXTENSION_NAME macro.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gdn_forward", &gdn_forward, "SM70/SM75 legacy GDN forward");
}

View File

@@ -0,0 +1,104 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from __future__ import annotations
import os
from pathlib import Path
import torch
from torch.utils.cpp_extension import load
_EXT = None
def _load_ext():
    """Return the JIT-built legacy CUDA extension, compiling it on first use.

    The loaded module is memoised in the module-level ``_EXT``. On the first
    call this sets ``TORCH_CUDA_ARCH_LIST`` to ``"7.0;7.5"`` unless the
    variable is already defined, then builds ``csrc/gdn_forward.cu`` (located
    next to this file) with ``torch.utils.cpp_extension.load``. Setting
    ``FLASH_QLA_LEGACY_VERBOSE_BUILD`` to a non-zero integer enables verbose
    build output.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    global _EXT
    if _EXT is None:
        if not torch.cuda.is_available():
            raise RuntimeError("SM70/SM75 legacy GDN backend requires CUDA")
        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0;7.5")
        source_path = Path(__file__).with_name("csrc") / "gdn_forward.cu"
        optimisation_flags = ["-O3"]
        _EXT = load(
            name="flash_qla_legacy_gdn",
            sources=[str(source_path)],
            extra_cuda_cflags=list(optimisation_flags),
            extra_cflags=list(optimisation_flags),
            verbose=bool(
                int(os.environ.get("FLASH_QLA_LEGACY_VERBOSE_BUILD", "0"))
            ),
        )
    return _EXT
def _check_inputs(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
initial_state: torch.Tensor | None,
) -> None:
tensors = [q, k, v, g, beta]
if initial_state is not None:
tensors.append(initial_state)
if any(not tensor.is_cuda for tensor in tensors):
raise ValueError("legacy GDN tensors must be CUDA tensors")
# if any(tensor.dtype != torch.float32 for tensor in tensors):
# raise ValueError("legacy GDN backend currently supports float32 tensors only")
if any(not tensor.is_contiguous() for tensor in tensors):
raise ValueError("legacy GDN tensors must be contiguous")
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError("q, k, and v must have shape [B, T, H, D]")
if g.ndim != 3 or beta.ndim != 3:
raise ValueError("g and beta must have shape [B, T, Hv]")
if q.shape != k.shape:
raise ValueError("q and k must have the same shape")
batch, tokens, q_heads, dim = q.shape
if v.shape[0] != batch or v.shape[1] != tokens or v.shape[3] != dim:
raise ValueError("v must have shape [B, T, Hv, D] matching q/k")
if g.shape != beta.shape or g.shape != v.shape[:3]:
raise ValueError("g and beta must have shape [B, T, Hv]")
if v.shape[2] % q_heads != 0:
raise ValueError("Hv must be divisible by Hq")
if dim not in (16, 32, 64, 128):
raise ValueError("legacy GDN backend supports D in {16, 32, 64, 128}")
if initial_state is not None and initial_state.shape != (batch, v.shape[2], dim, dim):
raise ValueError("initial_state must have shape [B, Hv, D, D]")
def chunk_gated_delta_rule_fwd_legacy(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the experimental SM70/SM75 forward-only GDN backend.

    This legacy backend is intentionally explicit. It does not replace the
    Hopper/SM90 TileLang path and currently supports only contiguous float32
    CUDA tensors for inference-oriented forward execution. The float32
    requirement is enforced by the CUDA extension.

    Shapes:
        q, k: [B, T, Hq, D]
        v: [B, T, Hv, D]
        g, beta: [B, T, Hv]
        initial_state: optional [B, Hv, D, D]

    Args:
        q, k, v: Query, key and value tensors.
        g: Per-token gate; the kernel applies ``exp`` to it.
        beta: Per-token beta factor.
        scale: Output multiplier; defaults to ``D ** -0.5``.
        initial_state: Optional starting state; zeros are used when omitted.

    Returns:
        A tuple ``(output, final_state)`` with
        output: [B, T, Hv, D]
        final_state: [B, Hv, D, D]

    Raises:
        ValueError: If the inputs fail ``_check_inputs``.
        RuntimeError: If CUDA is unavailable or the extension rejects the
            inputs.
    """
    _check_inputs(q, k, v, g, beta, initial_state)
    if scale is None:
        scale = q.shape[-1] ** -0.5
    ext = _load_ext()
    # The extension returns std::vector<torch::Tensor>, which pybind11 hands
    # back as a Python list. Unpack it so the declared tuple contract holds
    # and an unexpected number of results fails loudly here.
    output, final_state = ext.gdn_forward(
        q, k, v, g, beta, initial_state, float(scale)
    )
    return output, final_state

View File

@@ -0,0 +1,11 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .cumsum import chunk_local_cumsum
from .group_reduce import group_reduce_vector
__all__ = [
"chunk_local_cumsum",
"group_reduce_vector",
]

View File

@@ -0,0 +1,165 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
from flash_qla.utils import prepare_chunk_indices
# Factory for the chunk-local cumulative-sum kernel (JIT-compiled through
# tilelang.jit).
#
# Compile-time arguments:
#   H            - number of heads (last axis of g).
#   chunk_size   - tokens per chunk; the cumulative sum restarts at every chunk.
#   accum_dtype  - dtype in which the cumulative sum is evaluated.
#   g_dtype      - dtype of the input and output tensors.
#   seqlen_dtype - dtype of cu_seqlens / chunk_indices (varlen variant only).
#   is_varlen    - selects the packed variable-length variant.
#   reverse      - forwarded to T.cumsum.
#
# Returns a T.prim_func taking
#   (g_raw, cu_seqlens, chunk_indices, g_cumsum) when is_varlen, otherwise
#   (g_raw, g_cumsum, num_chunks).
@tilelang.jit(
# out_idx=[-1],
)
def tilelang_chunk_local_cumsum(
H,
chunk_size,
accum_dtype,
g_dtype,
seqlen_dtype,
is_varlen,
reverse,
):
# Symbolic (runtime) extents.
data_batch_size = T.dynamic("data_batch_size")
real_batch_size = T.dynamic("real_batch_size")
num_tokens = T.dynamic("num_tokens")
num_chunks = T.dynamic("num_chunks")
block_S = chunk_size
g_shape = (data_batch_size, num_tokens, H)
# Shared body: processes one chunk (block_S tokens, all H heads) of one
# sequence. bc and batch_idx are accepted but not read inside the body.
@T.macro
def kernel_body(
bb,
bc,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
g_raw,
g_cumsum,
):
# Token range [left, right) of this chunk along the token axis.
left = seq_start_idx + chunk_idx * block_S
right = left + block_S
g_fragment = T.alloc_fragment((H, block_S), dtype=accum_dtype)
gT_fragment = T.alloc_fragment((block_S, H), dtype=g_dtype)
# One extra column of padding per row - presumably to avoid shared-memory
# bank conflicts during the transposed accesses below (TODO confirm).
gT_shared = T.alloc_shared((block_S, H + 1), dtype=g_dtype)
# Load the (token, head) tile; tokens past the end of the sequence read as 0.
if right <= seq_end_idx:
T.copy(g_raw[bb, left:right, 0:H], gT_fragment)
else:
for j, i in T.Parallel(block_S, H):
if left + j < seq_end_idx:
gT_fragment[j, i] = g_raw[bb, left + j, i]
else:
gT_fragment[j, i] = 0
# Transpose through shared memory to (head, token) and run the cumulative
# sum along the token axis in accum_dtype.
T.copy(gT_fragment, gT_shared[:, :H])
for i, j in T.Parallel(H, block_S):
g_fragment[i, j] = gT_shared[j, i]
T.cumsum(g_fragment, dim=1, reverse=reverse)
# Transpose back to (token, head), converting to g_dtype on the way.
for i, j in T.Parallel(H, block_S):
gT_shared[j, i] = g_fragment[i, j]
T.copy(gT_shared[:, :H], gT_fragment)
# Store; for a partial last chunk only the valid tokens are written.
if right <= seq_end_idx:
T.copy(gT_fragment, g_cumsum[bb, left:right, 0:H])
else:
for j, i in T.Parallel(block_S, H):
if left + j < seq_end_idx:
g_cumsum[bb, left + j, i] = gT_fragment[j, i]
if is_varlen:
# Varlen: one program per row of chunk_indices, which holds
# (sequence index, chunk index within that sequence). All sequences live in
# batch row 0 and are delimited by cu_seqlens.
@T.prim_func
def tilelang_chunk_local_cumsum_kernel(
g_raw: T.Tensor(g_shape, dtype=g_dtype),
cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
):
with T.Kernel(num_chunks, threads=128) as (bc,):
bb = 0
batch_idx = chunk_indices[bc, 0]
chunk_idx = chunk_indices[bc, 1]
seq_start_idx = cu_seqlens[batch_idx]
seq_end_idx = cu_seqlens[batch_idx + 1]
kernel_body(
bb,
bc,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
g_raw,
g_cumsum,
)
else:
# Dense: num_chunks is passed as a scalar argument; program bc maps to batch
# row bc % data_batch_size and chunk bc // data_batch_size.
@T.prim_func
def tilelang_chunk_local_cumsum_kernel(
g_raw: T.Tensor(g_shape, dtype=g_dtype),
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
num_chunks: T.int32,
):
with T.Kernel(num_chunks, threads=128) as (bc,):
bb = bc % data_batch_size
batch_idx = bb
chunk_idx = bc // data_batch_size
seq_start_idx = 0
seq_end_idx = num_tokens
kernel_body(
bb,
bc,
batch_idx,
chunk_idx,
seq_start_idx,
seq_end_idx,
g_raw,
g_cumsum,
)
return tilelang_chunk_local_cumsum_kernel
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int = 64,
    cu_seqlens: torch.LongTensor | None = None,
    reverse: bool = False,
):
    """Compute a cumulative sum of ``g`` that restarts at every chunk.

    Args:
        g: Tensor unpacked as ``(batch_size, num_tokens, H)``; its last
            dimension must have stride 1.
        chunk_size: Number of tokens per chunk.
        cu_seqlens: Optional cumulative sequence lengths. When given, chunks
            are laid out per sequence via ``prepare_chunk_indices``.
        reverse: Forwarded to the kernel's cumulative sum.

    Returns:
        A new tensor created with ``torch.empty_like(g)`` and filled by the
        kernel.
    """
    batch_size, num_tokens, H = g.shape
    assert g.stride(-1) == 1

    is_varlen = cu_seqlens is not None
    if is_varlen:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        seqlen_dtype = cu_seqlens.dtype
    else:
        # Dense layout: one kernel program per (chunk, batch row).
        num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
        seqlen_dtype = "int32"

    g_cumsum = torch.empty_like(g)
    kernel = tilelang_chunk_local_cumsum(
        H,
        chunk_size,
        g_dtype=g.dtype,
        seqlen_dtype=seqlen_dtype,
        accum_dtype="float32",
        is_varlen=is_varlen,
        reverse=reverse,
    )

    # The two kernel variants take different argument lists.
    if is_varlen:
        kernel(g, cu_seqlens, chunk_indices, g_cumsum)
    else:
        kernel(g, g_cumsum, num_chunks)
    return g_cumsum

View File

@@ -0,0 +1,79 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
import tilelang.language as T
# NOTE: out_idx is left unset, so the caller allocates the output tensor and
# passes it to the kernel explicitly.
@tilelang.jit()
def tilelang_group_reduce_vector(
    H,
    Hg,
    DK,
    accum_dtype,
    qkva_dtype,
    block_size: int = 16,
):
    """Build the kernel that sums each group of ``H // Hg`` heads into one.

    The returned prim_func reads ``buffer`` of shape
    ``(batch_size, num_tokens, H, DK)`` and writes ``result`` of shape
    ``(batch_size, num_tokens, Hg, DK)``. Sums are accumulated in
    ``accum_dtype``. Each program handles ``block_size`` tokens of one
    (batch, group) pair.
    """
    batch_size = T.dynamic("batch_size")
    num_tokens = T.dynamic("num_tokens")
    heads_per_group = H // Hg
    src_shape = (batch_size, num_tokens, H, DK)
    dst_shape = (batch_size, num_tokens, Hg, DK)

    @T.prim_func
    def tilelang_group_reduce_vector_kernel(
        buffer: T.Tensor(src_shape, dtype=qkva_dtype),
        result: T.Tensor(dst_shape, dtype=qkva_dtype),
    ):
        with T.Kernel(
            tilelang.cdiv(num_tokens, block_size), Hg, batch_size, threads=128
        ) as (bt, bhg, bb):
            head_tile = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
            group_sum = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
            T.clear(group_sum)
            # Visit the heads of this group one at a time and accumulate them.
            for member in T.serial(heads_per_group):
                T.copy(
                    buffer[
                        bb,
                        bt * block_size : (bt + 1) * block_size,
                        bhg * heads_per_group + member,
                        0:DK,
                    ],
                    head_tile,
                )
                for row, col in T.Parallel(block_size, DK):
                    group_sum[row, col] += head_tile[row, col]
            T.copy(
                group_sum,
                result[bb, bt * block_size : (bt + 1) * block_size, bhg, 0:DK],
            )

    return tilelang_group_reduce_vector_kernel
def group_reduce_vector(
    buffer: torch.Tensor,
    Hg: int,
):
    """Reduce a per-head tensor to per-group heads with the TileLang kernel.

    Args:
        buffer: 4-D tensor unpacked as ``(batch_size, num_tokens, H, K)``.
        Hg: Number of head groups in the output.

    Returns:
        A new ``(batch_size, num_tokens, Hg, K)`` tensor with ``buffer``'s
        dtype and device, filled by the kernel (float32 accumulation).
    """
    batch_size, num_tokens, H, K = buffer.shape
    kernel = tilelang_group_reduce_vector(
        H,
        Hg,
        K,
        accum_dtype="float32",
        qkva_dtype=buffer.dtype,
    )
    # The kernel writes into a caller-provided output tensor.
    result = buffer.new_empty((batch_size, num_tokens, Hg, K))
    kernel(buffer, result)
    return result

View File

@@ -0,0 +1,20 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .profiler import profile
from .pack import pad_and_reshape, pack, unpack, fill_last_chunk_of_g
from .math import l2norm
from .index import prepare_chunk_indices, prepare_chunk_offsets, tensor_cache
# Names re-exported from the submodules imported above (.profiler, .pack,
# .math, .index); this list is what a star-import of the package exposes.
__all__ = [
    "profile",
    "pad_and_reshape",
    "pack",
    "unpack",
    "fill_last_chunk_of_g",
    "l2norm",
    "prepare_chunk_indices",
    "prepare_chunk_offsets",
    "tensor_cache",
]

138
flash_qla/utils/index.py Normal file
View File

@@ -0,0 +1,138 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import functools
from typing import Any
from collections import OrderedDict
from collections.abc import Callable
import torch
import tilelang
import tilelang.language as T
def tensor_cache(
    fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
    """
    Memoize ``fn`` in a small LRU cache keyed on its call arguments.

    ``int``, ``float`` and ``str`` arguments are keyed by ``(type, value)``;
    every other argument (tensors included) is keyed by object identity.
    The cache holds at most 256 entries and evicts the least recently used
    one when full.

    Notes:
        * Each entry keeps strong references to the call's arguments, so an
          ``id()`` stored in a key cannot be recycled by a different object
          while the entry is alive.
        * Because objects are keyed by identity, mutating a tensor in place
          and calling again returns the previously cached result.
        * Positional and keyword spellings of the same call produce
          different keys (an extra miss, never a wrong hit).

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of ``fn`` with LRU caching.
    """
    cache_size = 256
    # key -> (args, kwargs, result); args/kwargs are stored only to pin ids.
    cache: "OrderedDict[tuple, tuple[tuple[Any, ...], dict[str, Any], Any]]" = (
        OrderedDict()
    )

    def get_id(x: Any) -> tuple:
        # Tag the key component so that 1 and 1.0 (equal, same hash) stay
        # distinct, and so a plain int can never alias an object's id().
        if type(x) in (int, float, str):
            return ("value", type(x), x)
        return ("id", id(x))

    def make_identity_key(
        args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[tuple, tuple]:
        args_key = tuple(get_id(a) for a in args)
        # Sort by name only: keyword order must not affect the key.
        kwargs_key = tuple((name, get_id(kwargs[name])) for name in sorted(kwargs))
        return args_key, kwargs_key

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        key = make_identity_key(args, kwargs)
        if key in cache:
            cache.move_to_end(key, last=True)
            return cache[key][2]
        result = fn(*args, **kwargs)
        # A fresh key is appended at the most-recently-used end.
        cache[key] = (args, kwargs, result)
        if len(cache) > cache_size:
            cache.popitem(last=False)
        return result

    return wrapper
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Return the differences between consecutive entries of ``cu_seqlens``.

    For cumulative sequence boundaries this is the per-sequence length.
    Results are memoized by ``tensor_cache`` on the identity of ``cu_seqlens``.
    """
    seq_lens = cu_seqlens.diff()
    return seq_lens
@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor,
    chunk_size: int,
) -> torch.LongTensor:
    """Enumerate every chunk of a packed batch as ``(sequence id, chunk id)``.

    Args:
        cu_seqlens: 1-D cumulative sequence boundaries.
        chunk_size: Number of tokens per chunk.

    Returns:
        A ``(total_chunks, 2)`` tensor with ``cu_seqlens``'s dtype and device.
        Column 0 is the index of the sequence the chunk belongs to, column 1
        is the chunk's position inside that sequence. Zero-length sequences
        contribute no rows; an empty batch yields a ``(0, 2)`` tensor.
    """
    # TODO: tilelang kernel
    chunk_counts = tilelang.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    if not chunk_counts:
        # torch.cat() rejects an empty list, so handle the empty batch here.
        return cu_seqlens.new_empty((0, 2))
    chunk_ids = torch.cat([torch.arange(n) for n in chunk_counts])
    # Derive sequence ids from the counts themselves. Counting the zeros in
    # `chunk_ids` would skip zero-length sequences and shift the id of every
    # sequence that follows one.
    seq_ids = torch.repeat_interleave(
        torch.arange(len(chunk_counts)), torch.tensor(chunk_counts)
    )
    return torch.stack([seq_ids, chunk_ids], 1).to(cu_seqlens)
# Kernel factory: builds a single-block TileLang kernel that converts
# cumulative sequence boundaries into cumulative chunk counts.
#   chunk_size : tokens per chunk (compile-time constant)
#   block_size : fragment length; the host wrapper passes the next power of
#                two of the batch size
#   dtype      : element type of both `cu_seqlens` and `chunk_offsets`
@tilelang.jit()
def tilelang_prepare_chunk_offsets(
    chunk_size,
    block_size,
    dtype,
):
    batch_size_plus_1 = T.dynamic("batch_size_plus_1")
    # Clamp the thread count to the range [32, 128].
    num_threads = min(max(block_size, 32), 128)

    @T.prim_func
    def tilelang_prepare_chunk_offsets_kernel(
        cu_seqlens: T.Tensor([batch_size_plus_1], dtype=dtype),
        chunk_offsets: T.Tensor([batch_size_plus_1], dtype=dtype),
    ):
        with T.Kernel(1, threads=num_threads) as (bb,):
            # NOTE(review): `_batch_size` is assigned but never read below.
            _batch_size = T.alloc_var("int32")
            _batch_size = batch_size_plus_1 - 1
            seqlen_start_fragment = T.alloc_fragment((block_size), dtype=dtype)
            seqlen_end_fragment = T.alloc_fragment((block_size), dtype=dtype)
            chunk_offset_fragment = T.alloc_fragment((block_size), dtype=dtype)
            # Start and end boundary of each sequence: cu_seqlens[i], cu_seqlens[i+1].
            # NOTE(review): block_size may exceed the batch size; presumably
            # T.copy handles the fragment tail -- confirm.
            T.copy(cu_seqlens[: batch_size_plus_1 - 1], seqlen_start_fragment)
            T.copy(cu_seqlens[1:], seqlen_end_fragment)
            for i in T.Parallel(block_size):
                # Sequence length ...
                chunk_offset_fragment[i] = (
                    seqlen_end_fragment[i] - seqlen_start_fragment[i]
                )
                # ... rounded up to a whole number of chunks.
                chunk_offset_fragment[i] = (
                    chunk_offset_fragment[i] + chunk_size - 1
                ) // chunk_size
            # No destination is passed; presumably the running sum is written
            # back into the fragment -- TODO confirm against TileLang docs.
            T.cumsum(src=chunk_offset_fragment, dim=0)
            # Output slot 0 is fixed at zero; the sums fill slots 1..batch.
            chunk_offsets[0] = 0
            T.copy(chunk_offset_fragment, chunk_offsets[1:])

    return tilelang_prepare_chunk_offsets_kernel
@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor,
    chunk_size: int,
) -> tuple[torch.LongTensor, int]:
    """Compute per-sequence chunk offsets with the TileLang kernel.

    Args:
        cu_seqlens: 1-D cumulative sequence boundaries.
        chunk_size: Number of tokens per chunk.

    Returns:
        A pair ``(chunk_offsets, total)``: ``chunk_offsets`` has the same
        shape, dtype and device as ``cu_seqlens`` and is filled by the kernel;
        ``total`` is its last element as a Python scalar (reading it forces a
        device-to-host transfer). The result is memoized by ``tensor_cache``.
    """
    num_sequences = cu_seqlens.shape[0] - 1
    kernel = tilelang_prepare_chunk_offsets(
        chunk_size=chunk_size,
        # The kernel works on one fragment sized to a power of two.
        block_size=tilelang.next_power_of_2(num_sequences),
        dtype=cu_seqlens.dtype,
    )
    chunk_offsets = torch.empty_like(cu_seqlens)
    kernel(cu_seqlens, chunk_offsets)
    return chunk_offsets, chunk_offsets[-1].item()

21
flash_qla/utils/math.py Normal file
View File

@@ -0,0 +1,21 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
@torch.compile
def l2norm_compiled(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
return (x * inv_norm).to(x.dtype)
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """L2-normalize ``x`` along its last dimension via the compiled helper.

    Only ``dim == -1`` is supported and the last dimension must be contiguous.
    The input is viewed as 2-D ``(-1, last_dim)`` for the compiled call and
    the result is viewed back to the original shape.
    """
    assert dim == -1
    assert x.stride(-1) == 1
    original_shape = x.shape
    flat = x.view((-1, original_shape[-1]))
    # Mark the flattened row count as dynamic for torch.compile.
    torch._dynamo.mark_dynamic(flat, 0)
    return l2norm_compiled(flat, dim, eps).view(original_shape)

83
flash_qla/utils/pack.py Normal file
View File

@@ -0,0 +1,83 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
def unpack(
    x: torch.Tensor,  # [1, T, ...] packed along dim 1
    cu_seqlens: torch.Tensor,
):
    """Split a packed batch into a zero-padded ``[B, max_len, ...]`` tensor.

    Args:
        x: Packed tensor whose first dimension must be 1; sequence ``i``
            occupies ``x[0, cu_seqlens[i]:cu_seqlens[i + 1]]``.
        cu_seqlens: 1-D cumulative sequence boundaries.

    Returns:
        A new tensor of shape ``(B, max_len, *x.shape[2:])`` with ``x``'s
        dtype and device, where ``B = len(cu_seqlens) - 1``; positions past a
        sequence's length are zero. An empty batch yields ``max_len == 0``.
    """
    assert x.shape[0] == 1
    assert len(cu_seqlens.shape) == 1
    # One host transfer instead of two `.item()` syncs per sequence.
    bounds = cu_seqlens.tolist()
    spans = list(zip(bounds[:-1], bounds[1:]))
    # `default=0` keeps an empty batch from raising.
    max_len = max((end - start for start, end in spans), default=0)
    y = torch.zeros(
        (len(spans), max_len, *x.shape[2:]), dtype=x.dtype, device=x.device
    )
    for i, (start, end) in enumerate(spans):
        y[i, : end - start] = x[0, start:end]
    return y
def pack(
    x: torch.Tensor,  # [B, T, ...] padded batch
    cu_seqlens: torch.Tensor,
):
    """Concatenate the valid prefix of each padded sequence into one row.

    Args:
        x: Padded tensor; sequence ``i`` is read from
            ``x[i, :cu_seqlens[i + 1] - cu_seqlens[i]]``.
        cu_seqlens: 1-D cumulative sequence boundaries.

    Returns:
        A new tensor of shape ``(1, cu_seqlens[-1], *x.shape[2:])`` with
        ``x``'s dtype and device. The output is allocated uninitialized, so
        any range not covered by ``cu_seqlens`` (e.g. when it does not start
        at 0) is left undefined.
    """
    assert len(cu_seqlens.shape) == 1
    # One host transfer instead of `.item()` syncs inside the loop.
    bounds = cu_seqlens.tolist()
    y = torch.empty((1, bounds[-1], *x.shape[2:]), dtype=x.dtype, device=x.device)
    for i, (start, end) in enumerate(zip(bounds[:-1], bounds[1:])):
        y[0, start:end] = x[i, : end - start]
    return y
def pad_and_reshape(
    x: torch.Tensor,
    dim: int,
    chunk_size: int = 64,
):
    """Zero-pad ``x`` along ``dim`` to a multiple of ``chunk_size`` and chunk it.

    Args:
        x: Input tensor.
        dim: Dimension to pad and split; negative values count from the end.
        chunk_size: Chunk length.

    Returns:
        A tensor in which dimension ``dim`` is replaced by two dimensions
        ``(num_chunks, chunk_size)``; padding is appended at the end with zeros.

    Raises:
        IndexError: If ``dim`` is out of range for ``x``.
    """
    sequence_length = x.shape[dim]  # also validates `dim`
    # Normalize so the arithmetic and slicing below are valid for negative dims.
    dim = dim % x.ndim
    pad_size = -sequence_length % chunk_size
    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # every dimension after `dim` needs an explicit (0, 0) pair first.
    trailing_pairs = [0, 0] * (x.ndim - 1 - dim)
    padded = torch.nn.functional.pad(x, (*trailing_pairs, 0, pad_size))
    return padded.reshape((*x.shape[:dim], -1, chunk_size, *x.shape[dim + 1 :]))
def fill_last_chunk_of_g(
g: torch.Tensor,
num_tokens: int,
cu_seqlens: torch.Tensor,
chunk_size: int = 64,
reverse: bool = False,
):
if cu_seqlens is None:
last_chunk_size = num_tokens % chunk_size
if last_chunk_size > 0:
if reverse:
g[:, -1, last_chunk_size - 1] += g[:, -1, -1]
else:
g[:, -1, last_chunk_size:] = g[
:, -1, last_chunk_size - 1 : last_chunk_size
]
else:
for i in range(cu_seqlens.shape[0] - 1):
start = cu_seqlens[i].item()
end = cu_seqlens[i + 1].item()
last_chunk_idx = (end - start) // chunk_size
last_chunk_size = (end - start) % chunk_size
if last_chunk_size > 0:
if reverse:
g[i, last_chunk_idx, last_chunk_size - 1] += g[
i, last_chunk_idx, -1
]
else:
g[i, last_chunk_idx, last_chunk_size:] = g[
i, last_chunk_idx, last_chunk_size - 1 : last_chunk_size
]
return g

View File

@@ -0,0 +1,25 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
import torch
import tilelang
def profile(func, inputs, wait: int = 50, warmup: int = 50, rep: int = 100):
    """Profile ``func(*inputs)`` and return per-event and total timings.

    The callable is run ``wait + warmup + rep`` times under
    ``torch.profiler`` (CPU and CUDA activities) with a matching
    wait/warmup/active schedule, then benchmarked once more with
    ``tilelang.profiler.do_bench``.

    Returns:
        A dict mapping each profiler event key to its averaged
        ``device_time`` multiplied by 1e-3 (presumably microseconds converted
        to milliseconds -- confirm against the installed torch version), plus
        a ``"total"`` entry holding the ``do_bench`` result.
    """
    total_steps = wait + warmup + rep
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=rep)
    with torch.profiler.profile(activities=activities, schedule=schedule) as prof:
        for _ in range(total_steps):
            func(*inputs)
            # Advance the profiler schedule once per invocation.
            prof.step()

    result = {}
    for event in prof.key_averages():
        result[event.key] = event.device_time * 1e-3
    result["total"] = tilelang.profiler.do_bench(
        lambda: func(*inputs), warmup=warmup, rep=rep
    )
    return result