first commit
This commit is contained in:
7
flash_qla/ops/__init__.py
Normal file
7
flash_qla/ops/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .gated_delta_rule import chunk_gated_delta_rule
|
||||
|
||||
|
||||
__all__ = ["chunk_gated_delta_rule"]
|
||||
7
flash_qla/ops/gated_delta_rule/__init__.py
Normal file
7
flash_qla/ops/gated_delta_rule/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
|
||||
|
||||
__all__ = ["chunk_gated_delta_rule"]
|
||||
237
flash_qla/ops/gated_delta_rule/chunk/__init__.py
Normal file
237
flash_qla/ops/gated_delta_rule/chunk/__init__.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
|
||||
from flash_qla.utils import l2norm
|
||||
from flash_qla.ops.utils import chunk_local_cumsum, group_reduce_vector
|
||||
|
||||
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
|
||||
from .hopper import fused_gdr_fwd, fused_gdr_bwd, fused_gdr_h, kkt_solve
|
||||
else:
|
||||
raise ValueError("FlashQLA now support sm90 only.")
|
||||
from .cp_context import intra_card_cp_preprocess
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
output_final_state: bool = True,
|
||||
output_h: bool = False,
|
||||
auto_cp: bool = True,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
A = kkt_solve(
|
||||
k=k,
|
||||
b=beta,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if auto_cp:
|
||||
initial_state, cu_seqlens, cp_seq_map, raw_cu_seqlens = (
|
||||
intra_card_cp_preprocess(
|
||||
k=k,
|
||||
v=v,
|
||||
a=A,
|
||||
g=g,
|
||||
b=beta,
|
||||
raw_h0=initial_state,
|
||||
raw_cu_seqlens=cu_seqlens,
|
||||
)
|
||||
)
|
||||
else:
|
||||
cp_seq_map = None
|
||||
raw_cu_seqlens = None
|
||||
o, h, final_state = fused_gdr_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
a=A,
|
||||
g=g,
|
||||
b=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
output_h=output_h,
|
||||
output_o=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cp_seq_map=cp_seq_map,
|
||||
raw_cu_seqlens=raw_cu_seqlens,
|
||||
)
|
||||
return g, A, o, h, final_state
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_bwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
do: torch.Tensor,
|
||||
dht: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
):
|
||||
h, _, _ = fused_gdr_h(
|
||||
k=k,
|
||||
v=v,
|
||||
a=A,
|
||||
g=g,
|
||||
b=beta,
|
||||
initial_state=initial_state,
|
||||
output_final_state=False,
|
||||
output_h=True,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
dq, dk, dv, dg, db, dh0 = fused_gdr_bwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
a=A,
|
||||
g=g,
|
||||
b=beta,
|
||||
do=do,
|
||||
dht=dht,
|
||||
h=h,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
Hg, H = k.shape[-2], v.shape[-2]
|
||||
if Hg < H:
|
||||
dq = group_reduce_vector(dq, Hg)
|
||||
dk = group_reduce_vector(dk, Hg)
|
||||
assert dg.dtype == torch.float32, "dg should be fp32"
|
||||
dg = chunk_local_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens)
|
||||
return dq, dk, dv, db, dg, dh0
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@torch.amp.custom_fwd(device_type="cuda")
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
):
|
||||
q_orig = q
|
||||
k_orig = k
|
||||
|
||||
g, A, o, _, final_state = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
output_h=False,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens)
|
||||
ctx.scale = scale
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
@staticmethod
|
||||
@torch.amp.custom_bwd(device_type="cuda")
|
||||
def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
|
||||
q_orig, k_orig, v, g, beta, A, initial_state, cu_seqlens = ctx.saved_tensors
|
||||
|
||||
dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd(
|
||||
q=q_orig,
|
||||
k=k_orig,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
A=A,
|
||||
do=do,
|
||||
dht=dht,
|
||||
scale=ctx.scale,
|
||||
initial_state=initial_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
return (
|
||||
dq.to(q_orig),
|
||||
dk.to(k_orig),
|
||||
dv.to(v),
|
||||
dg.to(g),
|
||||
db.to(beta),
|
||||
None,
|
||||
dh0,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
head_first: bool = False,
|
||||
):
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, (
|
||||
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16 or float16."
|
||||
)
|
||||
assert not head_first, "head_first=True is not supported."
|
||||
assert v.shape[2] % k.shape[2] == 0, (
|
||||
"num_qk_heads must be divisible to num_v_heads."
|
||||
)
|
||||
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm(q)
|
||||
k = l2norm(k)
|
||||
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
output_final_state,
|
||||
cu_seqlens,
|
||||
)
|
||||
|
||||
return o, final_state
|
||||
163
flash_qla/ops/gated_delta_rule/chunk/cp_context.py
Normal file
163
flash_qla/ops/gated_delta_rule/chunk/cp_context.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
|
||||
from flash_qla.utils import tensor_cache
|
||||
|
||||
if tilelang.contrib.nvcc.get_target_compute_version() == "9.0":
|
||||
from .hopper import get_warmup_chunks, fused_gdr_h, correct_initial_states
|
||||
else:
|
||||
raise ValueError("FlashQLA now support sm90 only.")
|
||||
|
||||
|
||||
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def _create_cu_seqlens(
|
||||
batch_size: int,
|
||||
num_tokens: int,
|
||||
device_idx: int,
|
||||
):
|
||||
return (
|
||||
torch.arange((batch_size + 1), dtype=torch.int32, device=f"cuda:{device_idx}")
|
||||
* num_tokens
|
||||
)
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def _calc_cp_seqs(
|
||||
raw_cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int,
|
||||
num_v_heads: int,
|
||||
):
|
||||
# TODO: tilelang kernel
|
||||
device = raw_cu_seqlens.device
|
||||
seqlen_dtype = raw_cu_seqlens.dtype
|
||||
raw_cu_seqlens = raw_cu_seqlens.tolist()
|
||||
raw_batch_size = len(raw_cu_seqlens) - 1
|
||||
seqlens = [raw_cu_seqlens[i + 1] - raw_cu_seqlens[i] for i in range(raw_batch_size)]
|
||||
num_chunks = [tilelang.cdiv(x, chunk_size) for x in seqlens]
|
||||
|
||||
# autocp
|
||||
H = num_v_heads
|
||||
# Latency model: T = a·L_cp + b·(B·H·Lc/P) / L_cp + c
|
||||
# Minimizing T yields the theoretical optimum: L_cp* ∝ √(B·H·Lc / P), where P = MULTI_PROCESSOR_COUNT, L_cp = max_local_chunks
|
||||
# Scaled by empirical factor (3) and aligned to the nearest power of 2 for optimal SM scheduling & memory alignment.
|
||||
|
||||
max_local_chunks = 2 ** round(
|
||||
math.log2(math.sqrt(H * sum(num_chunks) / MULTI_PROCESSOR_COUNT) * 3)
|
||||
)
|
||||
|
||||
# Set min to 4 to ensure multi-stage pipelining in fused_gdr;
|
||||
max_local_chunks = max(max_local_chunks, 4)
|
||||
|
||||
use_cp = False
|
||||
cp_cu_seqlens = []
|
||||
ht_mask = []
|
||||
seq_map_c2r = []
|
||||
seq_map_r2c = [0]
|
||||
max_local_tokens = max_local_chunks * chunk_size
|
||||
for i, c in enumerate(num_chunks):
|
||||
s = raw_cu_seqlens[i]
|
||||
e = raw_cu_seqlens[i + 1]
|
||||
if c > max_local_chunks:
|
||||
while s < e:
|
||||
cp_cu_seqlens.append(s)
|
||||
ht_mask.append(False)
|
||||
seq_map_c2r.append(i)
|
||||
s += max_local_tokens
|
||||
ht_mask[-1] = True
|
||||
else:
|
||||
cp_cu_seqlens.append(s)
|
||||
ht_mask.append(True)
|
||||
seq_map_c2r.append(i)
|
||||
seq_map_r2c.append(len(cp_cu_seqlens))
|
||||
cp_cu_seqlens.append(raw_cu_seqlens[-1])
|
||||
|
||||
# Disable CP when B * H naturally saturates SM occupancy.
|
||||
# For varlen inputs, use `total_chunks / max_seq_chunks` as effective B,
|
||||
# since CP helps accelerate highly uneven sequence lengths.
|
||||
|
||||
Be = sum(num_chunks) / max(num_chunks)
|
||||
use_cp = Be * H <= 40 or (Be * H <= 56 and max(num_chunks) >= 128)
|
||||
|
||||
if use_cp:
|
||||
cp_cu_seqlens = torch.tensor(
|
||||
cp_cu_seqlens, dtype=seqlen_dtype, device=device, requires_grad=False
|
||||
)
|
||||
seq_map_c2r = torch.tensor(seq_map_c2r, dtype=seqlen_dtype, device=device)
|
||||
seq_map_r2c = torch.tensor(
|
||||
seq_map_r2c, dtype=seqlen_dtype, device=device, requires_grad=False
|
||||
)
|
||||
ht_mask = torch.tensor(
|
||||
ht_mask, dtype=torch.bool, device=device, requires_grad=False
|
||||
)
|
||||
else:
|
||||
cp_cu_seqlens, seq_map_r2c, ht_mask = None, None, None
|
||||
|
||||
return use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask
|
||||
|
||||
|
||||
def intra_card_cp_preprocess(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
raw_h0: torch.Tensor,
|
||||
raw_cu_seqlens: torch.Tensor,
|
||||
warmup_threshold: float = -10.0,
|
||||
):
|
||||
batch_size, num_tokens, num_k_heads, k_head_dim = k.shape
|
||||
_, _, num_v_heads, v_head_dim = v.shape
|
||||
chunk_size = a.shape[-1]
|
||||
device = k.device
|
||||
|
||||
if batch_size > 1:
|
||||
return raw_h0, raw_cu_seqlens, None, None
|
||||
|
||||
if raw_cu_seqlens is None:
|
||||
raw_cu_seqlens = _create_cu_seqlens(batch_size, num_tokens, device.index)
|
||||
|
||||
use_cp, cp_cu_seqlens, seq_map_r2c, seq_map_c2r, ht_mask = _calc_cp_seqs(
|
||||
raw_cu_seqlens,
|
||||
chunk_size,
|
||||
num_v_heads,
|
||||
)
|
||||
|
||||
if not use_cp:
|
||||
return raw_h0, raw_cu_seqlens, None, None
|
||||
|
||||
num_warmup_chunks, fallback_mask = get_warmup_chunks(
|
||||
g=g,
|
||||
cu_seqlens=cp_cu_seqlens,
|
||||
ht_mask=ht_mask,
|
||||
chunk_size=chunk_size,
|
||||
threshold=warmup_threshold,
|
||||
) # [cp_batch_size, num_v_heads]
|
||||
_, ht, mt = fused_gdr_h(
|
||||
k=k,
|
||||
v=v,
|
||||
a=a,
|
||||
g=g,
|
||||
b=b,
|
||||
initial_state=None,
|
||||
output_final_state=True,
|
||||
output_h=False,
|
||||
cu_seqlens=cp_cu_seqlens,
|
||||
num_warmup_chunks=num_warmup_chunks,
|
||||
) # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
|
||||
cp_h0 = correct_initial_states(
|
||||
raw_h0=raw_h0,
|
||||
ht_buffer=ht,
|
||||
mt_buffer=mt,
|
||||
fallback_mask=fallback_mask,
|
||||
seq_map_r2c=seq_map_r2c,
|
||||
)
|
||||
|
||||
return cp_h0, cp_cu_seqlens, seq_map_c2r, raw_cu_seqlens
|
||||
18
flash_qla/ops/gated_delta_rule/chunk/hopper/__init__.py
Normal file
18
flash_qla/ops/gated_delta_rule/chunk/hopper/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .fused_fwd import fused_gdr_fwd
|
||||
from .fused_bwd import fused_gdr_bwd
|
||||
from .prepare_h import fused_gdr_h
|
||||
from .kkt_solve import kkt_solve
|
||||
from .cp_fwd import get_warmup_chunks, correct_initial_states
|
||||
|
||||
|
||||
__all__ = [
|
||||
"fused_gdr_fwd",
|
||||
"fused_gdr_bwd",
|
||||
"fused_gdr_h",
|
||||
"kkt_solve",
|
||||
"get_warmup_chunks",
|
||||
"correct_initial_states",
|
||||
]
|
||||
309
flash_qla/ops/gated_delta_rule/chunk/hopper/cp_fwd.py
Normal file
309
flash_qla/ops/gated_delta_rule/chunk/hopper/cp_fwd.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
|
||||
@tilelang.jit()
|
||||
def tilelang_get_warmup_chunks(
|
||||
num_heads,
|
||||
chunk_size,
|
||||
threshold,
|
||||
accum_dtype,
|
||||
g_dtype,
|
||||
mask_dtype,
|
||||
seqlen_dtype,
|
||||
):
|
||||
batch_size = T.dynamic("batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
num_threads = tilelang.cdiv(num_heads, 32) * 32
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_get_warmup_chunks_kernel(
|
||||
g: T.Tensor([1, num_tokens, num_heads], dtype=g_dtype),
|
||||
ht_mask: T.Tensor([batch_size], dtype=mask_dtype),
|
||||
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
|
||||
num_warmup_chunks: T.Tensor([batch_size, num_heads], dtype=seqlen_dtype),
|
||||
fallback_mask: T.Tensor([batch_size, num_heads], dtype=mask_dtype),
|
||||
):
|
||||
with T.Kernel(batch_size, threads=num_threads) as (bb,):
|
||||
if ht_mask[bb]:
|
||||
for i_h in T.Parallel(num_heads):
|
||||
num_warmup_chunks[bb, i_h] = 0
|
||||
else:
|
||||
seq_start_idx = T.alloc_var("int32")
|
||||
seq_end_idx = T.alloc_var("int32")
|
||||
num_iters = T.alloc_var("int32")
|
||||
seq_start_idx = cu_seqlens[bb]
|
||||
seq_end_idx = cu_seqlens[bb + 1]
|
||||
num_iters = (seq_end_idx - seq_start_idx) // chunk_size
|
||||
|
||||
g_fragment = T.alloc_fragment((num_heads), dtype=accum_dtype)
|
||||
g_cumsum = T.alloc_fragment((num_heads), dtype=accum_dtype)
|
||||
n_fragment = T.alloc_fragment((num_heads), dtype=seqlen_dtype)
|
||||
f_fragment = T.alloc_fragment((num_heads), dtype=mask_dtype)
|
||||
T.clear(g_cumsum)
|
||||
T.fill(n_fragment, num_iters)
|
||||
T.fill(f_fragment, True)
|
||||
|
||||
for i_s in T.serial(num_iters):
|
||||
for i_h in T.Parallel(num_heads):
|
||||
g_fragment[i_h] = g[0, seq_end_idx - i_s * chunk_size - 1, i_h]
|
||||
for i_h in T.Parallel(num_heads):
|
||||
g_cumsum[i_h] += g_fragment[i_h]
|
||||
for i_h in T.Parallel(num_heads):
|
||||
if g_cumsum[i_h] < threshold and n_fragment[i_h] == num_iters:
|
||||
n_fragment[i_h] = i_s + 1
|
||||
f_fragment[i_h] = False
|
||||
|
||||
for i_h in T.Parallel(num_heads):
|
||||
num_warmup_chunks[bb, i_h] = n_fragment[i_h]
|
||||
for i_h in T.Parallel(num_heads):
|
||||
fallback_mask[bb, i_h] = f_fragment[i_h]
|
||||
|
||||
return tilelang_get_warmup_chunks_kernel
|
||||
|
||||
|
||||
def get_warmup_chunks(
|
||||
g: torch.Tensor, # [1, num_total_tokens, num_v_heads]
|
||||
cu_seqlens: torch.Tensor, # [cp_real_batch_size + 1]
|
||||
ht_mask: torch.Tensor, # [cp_real_batch_size]
|
||||
chunk_size: int = 64,
|
||||
threshold: float = -10.0,
|
||||
):
|
||||
batch_size, num_tokens, num_heads = g.shape
|
||||
real_batch_size = ht_mask.shape[0]
|
||||
assert cu_seqlens.shape[0] == real_batch_size + 1
|
||||
assert batch_size == 1
|
||||
assert chunk_size == 64
|
||||
|
||||
tilelang_get_warmup_chunks_kernel = tilelang_get_warmup_chunks(
|
||||
num_heads=num_heads,
|
||||
chunk_size=chunk_size,
|
||||
threshold=threshold,
|
||||
accum_dtype="float32",
|
||||
g_dtype=g.dtype,
|
||||
mask_dtype=ht_mask.dtype,
|
||||
seqlen_dtype=cu_seqlens.dtype,
|
||||
)
|
||||
num_warmup_chunks = torch.empty(
|
||||
[real_batch_size, num_heads], dtype=cu_seqlens.dtype, device=cu_seqlens.device
|
||||
)
|
||||
fallback_mask = torch.empty(
|
||||
[real_batch_size, num_heads], dtype=ht_mask.dtype, device=cu_seqlens.device
|
||||
)
|
||||
tilelang_get_warmup_chunks_kernel(
|
||||
g, ht_mask, cu_seqlens, num_warmup_chunks, fallback_mask
|
||||
)
|
||||
|
||||
return num_warmup_chunks, fallback_mask
|
||||
|
||||
|
||||
@tilelang.jit()
|
||||
def tilelang_correct_h0(
|
||||
H,
|
||||
DK,
|
||||
DV,
|
||||
res_dtype,
|
||||
accum_dtype,
|
||||
buffer_dtype,
|
||||
seqlen_dtype,
|
||||
mask_dtype,
|
||||
use_raw_h0,
|
||||
block_DV: int = 32,
|
||||
):
|
||||
cp_batch_size = T.dynamic("cp_batch_size")
|
||||
raw_batch_size = T.dynamic("raw_batch_size")
|
||||
|
||||
@T.macro
|
||||
def kernel_body(
|
||||
bb,
|
||||
bh,
|
||||
bv,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
num_iters,
|
||||
ht_buffer,
|
||||
mt_buffer,
|
||||
fallback_mask,
|
||||
seq_map_r2c,
|
||||
cp_h0,
|
||||
h_fragment,
|
||||
):
|
||||
h_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
|
||||
hd_shared = T.alloc_shared((DK, block_DV), dtype=buffer_dtype)
|
||||
m_shared = T.alloc_shared((DK, DK), dtype=buffer_dtype)
|
||||
|
||||
T.copy(
|
||||
h_fragment,
|
||||
cp_h0[seq_start_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
|
||||
)
|
||||
|
||||
for i_s in T.Pipelined(num_iters - 1, num_stages=2):
|
||||
if fallback_mask[seq_start_idx + i_s, bh]:
|
||||
T.copy(h_fragment, hd_shared)
|
||||
T.copy(
|
||||
ht_buffer[
|
||||
seq_start_idx + i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
|
||||
],
|
||||
h_shared,
|
||||
)
|
||||
T.copy(h_shared, h_fragment)
|
||||
if fallback_mask[seq_start_idx + i_s, bh]:
|
||||
T.copy(mt_buffer[seq_start_idx + i_s, bh, 0:DK, 0:DK], m_shared)
|
||||
T.gemm(m_shared, hd_shared, h_fragment, clear_accum=False)
|
||||
T.copy(
|
||||
h_fragment,
|
||||
cp_h0[
|
||||
seq_start_idx + i_s + 1,
|
||||
bh,
|
||||
0:DK,
|
||||
bv * block_DV : (bv + 1) * block_DV,
|
||||
],
|
||||
)
|
||||
|
||||
if use_raw_h0:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_correct_h0_kernel(
|
||||
raw_h0: T.Tensor([raw_batch_size, H, DK, DV], dtype=res_dtype),
|
||||
ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
|
||||
mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
|
||||
fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
|
||||
seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
|
||||
cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
|
||||
):
|
||||
with T.Kernel(
|
||||
T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
|
||||
) as (bbhv,):
|
||||
bbh, bv = (
|
||||
bbhv // T.ceildiv(DV, block_DV),
|
||||
bbhv % T.ceildiv(DV, block_DV),
|
||||
)
|
||||
bb, bh = bbh // H, bbh % H
|
||||
|
||||
seq_start_idx = seq_map_r2c[bb]
|
||||
seq_end_idx = seq_map_r2c[bb + 1]
|
||||
num_iters = seq_end_idx - seq_start_idx
|
||||
|
||||
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
|
||||
T.copy(
|
||||
raw_h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
|
||||
h_fragment,
|
||||
)
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bh,
|
||||
bv,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
num_iters,
|
||||
ht_buffer,
|
||||
mt_buffer,
|
||||
fallback_mask,
|
||||
seq_map_r2c,
|
||||
cp_h0,
|
||||
h_fragment,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_correct_h0_kernel(
|
||||
ht_buffer: T.Tensor([cp_batch_size, H, DK, DV], dtype=buffer_dtype),
|
||||
mt_buffer: T.Tensor([cp_batch_size, H, DK, DK], dtype=buffer_dtype),
|
||||
fallback_mask: T.Tensor([cp_batch_size, H], dtype=mask_dtype),
|
||||
seq_map_r2c: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
|
||||
cp_h0: T.Tensor([cp_batch_size, H, DK, DV], dtype=res_dtype),
|
||||
):
|
||||
with T.Kernel(
|
||||
T.ceildiv(DV, block_DV) * H * raw_batch_size, threads=128
|
||||
) as (bbhv,):
|
||||
bbh, bv = (
|
||||
bbhv // T.ceildiv(DV, block_DV),
|
||||
bbhv % T.ceildiv(DV, block_DV),
|
||||
)
|
||||
bb, bh = bbh // H, bbh % H
|
||||
|
||||
seq_start_idx = seq_map_r2c[bb]
|
||||
seq_end_idx = seq_map_r2c[bb + 1]
|
||||
num_iters = seq_end_idx - seq_start_idx
|
||||
|
||||
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
|
||||
T.clear(h_fragment)
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bh,
|
||||
bv,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
num_iters,
|
||||
ht_buffer,
|
||||
mt_buffer,
|
||||
fallback_mask,
|
||||
seq_map_r2c,
|
||||
cp_h0,
|
||||
h_fragment,
|
||||
)
|
||||
|
||||
return tilelang_correct_h0_kernel
|
||||
|
||||
|
||||
def correct_initial_states(
|
||||
raw_h0: torch.Tensor
|
||||
| None, # [raw_batch_size, num_v_heads, k_head_dim, v_head_dim]
|
||||
ht_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, v_head_dim]
|
||||
mt_buffer: torch.Tensor, # [cp_batch_size, num_v_heads, k_head_dim, k_head_dim]
|
||||
fallback_mask: torch.Tensor, # [cp_batch_size, num_v_heads]
|
||||
seq_map_r2c: torch.Tensor, # [raw_batch_size + 1]
|
||||
):
|
||||
cp_batch_size = fallback_mask.shape[0]
|
||||
_, num_heads, k_head_dim, v_head_dim = ht_buffer.shape
|
||||
assert k_head_dim == v_head_dim == 128
|
||||
|
||||
if raw_h0 is None:
|
||||
res_dtype = torch.float32
|
||||
use_raw_h0 = False
|
||||
else:
|
||||
res_dtype = raw_h0.dtype
|
||||
use_raw_h0 = True
|
||||
|
||||
tilelang_correct_h0_kernel = tilelang_correct_h0(
|
||||
H=num_heads,
|
||||
DK=k_head_dim,
|
||||
DV=v_head_dim,
|
||||
res_dtype=res_dtype,
|
||||
accum_dtype="float32",
|
||||
buffer_dtype=ht_buffer.dtype,
|
||||
seqlen_dtype=seq_map_r2c.dtype,
|
||||
mask_dtype=fallback_mask.dtype,
|
||||
use_raw_h0=use_raw_h0,
|
||||
)
|
||||
cp_h0 = torch.empty(
|
||||
(cp_batch_size, num_heads, k_head_dim, v_head_dim),
|
||||
dtype=res_dtype,
|
||||
device=ht_buffer.device,
|
||||
)
|
||||
if use_raw_h0:
|
||||
tilelang_correct_h0_kernel(
|
||||
raw_h0,
|
||||
ht_buffer,
|
||||
mt_buffer,
|
||||
fallback_mask,
|
||||
seq_map_r2c,
|
||||
cp_h0,
|
||||
)
|
||||
else:
|
||||
tilelang_correct_h0_kernel(
|
||||
ht_buffer,
|
||||
mt_buffer,
|
||||
fallback_mask,
|
||||
seq_map_r2c,
|
||||
cp_h0,
|
||||
)
|
||||
|
||||
return cp_h0
|
||||
985
flash_qla/ops/gated_delta_rule/chunk/hopper/fused_bwd.py
Normal file
985
flash_qla/ops/gated_delta_rule/chunk/hopper/fused_bwd.py
Normal file
@@ -0,0 +1,985 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from flash_qla.utils import prepare_chunk_offsets
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
# out_idx=[-5, -4, -3, -2, -1],
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_DATA_RACE_CHECK: True,
|
||||
},
|
||||
)
|
||||
def tilelang_fused_chunk_gdr_bwd(
|
||||
H,
|
||||
Hg,
|
||||
DK,
|
||||
DV,
|
||||
chunk_size,
|
||||
scale,
|
||||
accum_dtype,
|
||||
qkva_dtype,
|
||||
g_dtype,
|
||||
b_dtype,
|
||||
h_dtype,
|
||||
o_dtype,
|
||||
seqlen_dtype,
|
||||
is_varlen,
|
||||
use_dht,
|
||||
):
|
||||
batch_size = T.dynamic("batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
num_chunks = T.dynamic("num_chunks")
|
||||
block_S = chunk_size
|
||||
|
||||
if is_varlen:
|
||||
q_shape = (1, num_tokens, Hg, DK)
|
||||
k_shape = (1, num_tokens, Hg, DK)
|
||||
v_shape = (1, num_tokens, H, DV)
|
||||
o_shape = (1, num_tokens, H, DV)
|
||||
a_shape = (1, num_tokens, H, chunk_size)
|
||||
g_shape = (1, num_tokens, H)
|
||||
b_shape = (1, num_tokens, H)
|
||||
h_shape = (1, num_chunks, H, DK, DV)
|
||||
else:
|
||||
q_shape = (batch_size, num_tokens, Hg, DK)
|
||||
k_shape = (batch_size, num_tokens, Hg, DK)
|
||||
v_shape = (batch_size, num_tokens, H, DV)
|
||||
o_shape = (batch_size, num_tokens, H, DV)
|
||||
a_shape = (batch_size, num_tokens, H, chunk_size)
|
||||
g_shape = (batch_size, num_tokens, H)
|
||||
b_shape = (batch_size, num_tokens, H)
|
||||
h_shape = (batch_size, num_chunks, H, DK, DV)
|
||||
h0_shape = (batch_size, H, DK, DV)
|
||||
ht_shape = (batch_size, H, DK, DV)
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_fused_chunk_gdr_bwd_kernel(
|
||||
do: T.Tensor(o_shape, dtype=o_dtype),
|
||||
dht: T.Tensor(ht_shape, dtype=accum_dtype),
|
||||
q: T.Tensor(q_shape, dtype=qkva_dtype),
|
||||
k: T.Tensor(k_shape, dtype=qkva_dtype),
|
||||
v: T.Tensor(v_shape, dtype=qkva_dtype),
|
||||
a: T.Tensor(a_shape, dtype=qkva_dtype),
|
||||
g: T.Tensor(g_shape, dtype=g_dtype),
|
||||
b: T.Tensor(b_shape, dtype=b_dtype),
|
||||
h: T.Tensor(h_shape, dtype=h_dtype),
|
||||
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
|
||||
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
|
||||
dq: T.Tensor(v_shape, dtype=qkva_dtype),
|
||||
dk: T.Tensor(v_shape, dtype=qkva_dtype),
|
||||
dv: T.Tensor(v_shape, dtype=qkva_dtype),
|
||||
dg: T.Tensor(g_shape, dtype=g_dtype),
|
||||
db: T.Tensor(b_shape, dtype=b_dtype),
|
||||
dh0: T.Tensor(h0_shape, dtype=accum_dtype),
|
||||
):
|
||||
with T.Kernel(batch_size * H, threads=512) as (bbh,):
|
||||
bb, bh = bbh // H, bbh % H
|
||||
bhg = bh // (H // Hg)
|
||||
|
||||
batch_idx = T.alloc_var("int32")
|
||||
seq_start_idx = T.alloc_var("int32")
|
||||
seq_end_idx = T.alloc_var("int32")
|
||||
chunk_start_idx = T.alloc_var("int32")
|
||||
batch_idx = 0 if is_varlen else bb
|
||||
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
|
||||
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
|
||||
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
|
||||
|
||||
num_iters = T.alloc_var("int32")
|
||||
num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
|
||||
|
||||
# 2+2+2+2 + 1 + 4 = 13 units
|
||||
do_shared = T.alloc_shared((block_S, DV), dtype=o_dtype)
|
||||
q_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
v_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
|
||||
a_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
|
||||
h_shared = T.alloc_shared((DK, DV), dtype=h_dtype)
|
||||
g_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
|
||||
g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
|
||||
g_rev_exp_shared = T.alloc_shared(
|
||||
(block_S), dtype=accum_dtype, scope="shared"
|
||||
)
|
||||
b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
|
||||
|
||||
# 2 units
|
||||
dqkv_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
dg_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
|
||||
db_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
|
||||
|
||||
# 1+1 + 2+2+2 + 4 = 12 units
|
||||
tmp_shared_1_1 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
|
||||
tmp_shared_1_2 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
|
||||
tmp_shared_1_3 = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
|
||||
tmp_shared_2_1 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
tmp_shared_2_2 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
tmp_shared_2_3 = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
tmp_shared_4_1 = T.alloc_shared((DK, DV), dtype=qkva_dtype)
|
||||
|
||||
# CONSUMER_K
|
||||
dk_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
|
||||
dv_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
|
||||
odot_fragment_1 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
|
||||
dg_fragment_1 = T.alloc_fragment((block_S), dtype=accum_dtype)
|
||||
dg_last_local_1 = T.alloc_fragment((1), dtype=accum_dtype)
|
||||
|
||||
# CONSUMER_A
|
||||
mask_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
dp_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
da_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
u_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
|
||||
dq_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
|
||||
db_fragment = T.alloc_fragment((block_S), dtype=accum_dtype)
|
||||
odot_fragment_2 = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
|
||||
dg_fragment_2 = T.alloc_fragment((block_S), dtype=accum_dtype)
|
||||
|
||||
# CONSUMER_S
|
||||
dh_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
|
||||
_odot_fragment_3 = T.alloc_fragment((DK, DV), dtype=accum_dtype)
|
||||
reduce_fragment = T.alloc_fragment((128, 2), dtype=accum_dtype)
|
||||
dg_last_local_3 = T.alloc_fragment((1), dtype=accum_dtype)
|
||||
g_last_local_3 = T.alloc_local((1), dtype=accum_dtype)
|
||||
|
||||
# 16 stages
|
||||
bar_00 = T.alloc_barrier(arrive_count=448)
|
||||
bar_01 = T.alloc_barrier(arrive_count=384)
|
||||
bar_02 = T.alloc_barrier(arrive_count=288)
|
||||
bar_03 = T.alloc_barrier(arrive_count=256)
|
||||
bar_04 = T.alloc_barrier(arrive_count=416)
|
||||
bar_05 = T.alloc_barrier(arrive_count=288)
|
||||
bar_06 = T.alloc_barrier(arrive_count=256)
|
||||
bar_07 = T.alloc_barrier(arrive_count=256)
|
||||
bar_08 = T.alloc_barrier(arrive_count=384)
|
||||
bar_09 = T.alloc_barrier(arrive_count=256)
|
||||
bar_10 = T.alloc_barrier(arrive_count=288)
|
||||
bar_11 = T.alloc_barrier(arrive_count=256)
|
||||
bar_12 = T.alloc_barrier(arrive_count=128)
|
||||
bar_13 = T.alloc_barrier(arrive_count=256)
|
||||
bar_14 = T.alloc_barrier(arrive_count=256)
|
||||
bar_15 = T.alloc_barrier(arrive_count=256)
|
||||
|
||||
T.annotate_layout(
|
||||
{
|
||||
do_shared: tilelang.layout.make_swizzled_layout(do_shared),
|
||||
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
|
||||
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
|
||||
v_shared: tilelang.layout.make_swizzled_layout(v_shared),
|
||||
a_shared: tilelang.layout.make_swizzled_layout(a_shared),
|
||||
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
|
||||
dqkv_shared: tilelang.layout.make_swizzled_layout(dqkv_shared),
|
||||
tmp_shared_1_1: tilelang.layout.make_swizzled_layout(
|
||||
tmp_shared_1_1
|
||||
),
|
||||
tmp_shared_1_2: tilelang.layout.make_swizzled_layout(
|
||||
tmp_shared_1_2
|
||||
),
|
||||
tmp_shared_1_3: tilelang.layout.make_swizzled_layout(
|
||||
tmp_shared_1_3
|
||||
),
|
||||
tmp_shared_2_1: tilelang.layout.make_swizzled_layout(
|
||||
tmp_shared_2_1
|
||||
),
|
||||
tmp_shared_2_2: tilelang.layout.make_swizzled_layout(
|
||||
tmp_shared_2_2
|
||||
),
|
||||
tmp_shared_2_3: tilelang.layout.make_swizzled_layout(
|
||||
tmp_shared_2_3
|
||||
),
|
||||
tmp_shared_4_1: tilelang.layout.make_swizzled_layout(
|
||||
tmp_shared_4_1
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# T.use_swizzle(10)
|
||||
|
||||
tx = T.get_thread_binding()
|
||||
|
||||
PRODUCER_NREG = 24
|
||||
CONSUMER_K_NREG = 144
|
||||
CONSUMER_A_NREG = 176
|
||||
CONSUMER_S_NREG = 160
|
||||
|
||||
# Prefetch the last chunk of data
|
||||
T.copy(
|
||||
h[batch_idx, chunk_start_idx + num_iters - 1, bh, 0:DK, 0:DV], h_shared
|
||||
)
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
|
||||
q_shared[j_s, j_k] = q[
|
||||
batch_idx,
|
||||
seq_start_idx + (num_iters - 1) * block_S + j_s,
|
||||
bhg,
|
||||
j_k,
|
||||
]
|
||||
else:
|
||||
q_shared[j_s, j_k] = 0
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
|
||||
k_shared[j_s, j_k] = k[
|
||||
batch_idx,
|
||||
seq_start_idx + (num_iters - 1) * block_S + j_s,
|
||||
bhg,
|
||||
j_k,
|
||||
]
|
||||
else:
|
||||
k_shared[j_s, j_k] = 0
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
|
||||
v_shared[j_s, j_v] = v[
|
||||
batch_idx,
|
||||
seq_start_idx + (num_iters - 1) * block_S + j_s,
|
||||
bh,
|
||||
j_v,
|
||||
]
|
||||
else:
|
||||
v_shared[j_s, j_v] = 0
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
|
||||
a_shared[j_s, j_t] = a[
|
||||
batch_idx,
|
||||
seq_start_idx + (num_iters - 1) * block_S + j_s,
|
||||
bh,
|
||||
j_t,
|
||||
]
|
||||
else:
|
||||
a_shared[j_s, j_t] = 0
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
|
||||
do_shared[j_s, j_v] = do[
|
||||
batch_idx,
|
||||
seq_start_idx + (num_iters - 1) * block_S + j_s,
|
||||
bh,
|
||||
j_v,
|
||||
]
|
||||
else:
|
||||
do_shared[j_s, j_v] = 0
|
||||
for j_s in T.Parallel(block_S):
|
||||
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
|
||||
g_shared[j_s] = g[
|
||||
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
|
||||
]
|
||||
else:
|
||||
g_shared[j_s] = g[batch_idx, seq_end_idx - 1, bh]
|
||||
for j_s in T.Parallel(block_S):
|
||||
if seq_start_idx + (num_iters - 1) * block_S + j_s < seq_end_idx:
|
||||
b_shared[j_s] = b[
|
||||
batch_idx, seq_start_idx + (num_iters - 1) * block_S + j_s, bh
|
||||
]
|
||||
else:
|
||||
b_shared[j_s] = 0
|
||||
|
||||
if tx < 128:
|
||||
T.set_max_nreg(CONSUMER_S_NREG, 1)
|
||||
|
||||
if use_dht:
|
||||
T.copy(dht[bb, bh, 0:DK, 0:DV], dh_fragment)
|
||||
else:
|
||||
T.clear(dh_fragment)
|
||||
T.copy(dh_fragment, tmp_shared_4_1)
|
||||
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_arrive(bar_00)
|
||||
|
||||
# 00
|
||||
T.barrier_wait(bar_00, (i_s + 0) % 2)
|
||||
for j_s in T.Parallel(block_S):
|
||||
g_exp_shared[j_s] = T.exp2(g_shared[j_s] * 1.442695)
|
||||
g_rev_exp_shared[j_s] = T.exp2(
|
||||
(g_shared[block_S - 1] - g_shared[j_s]) * 1.442695
|
||||
)
|
||||
T.barrier_arrive(bar_01)
|
||||
|
||||
# 01, 02, 03
|
||||
T.barrier_wait(bar_01, (i_s + 0) % 2)
|
||||
g_last_local_3[0] = g_exp_shared[block_S - 1]
|
||||
# dS0 = g_last * dSt
|
||||
for j_k, j_v in T.Parallel(DK, DV):
|
||||
dh_fragment[j_k, j_v] *= g_last_local_3[0]
|
||||
T.barrier_arrive(bar_04)
|
||||
|
||||
# 04, 05, 06, 07
|
||||
T.barrier_wait(bar_04, (i_s + 0) % 2)
|
||||
# dg_last += sum(dS0 * S0)
|
||||
T.clear(reduce_fragment)
|
||||
for j_k, j_v in T.Parallel(DK, DV):
|
||||
reduce_fragment[
|
||||
j_k % 64 // 16 * 32 + j_k % 8 * 4 + j_v % 8 // 2, j_v % 2
|
||||
] += dh_fragment[j_k, j_v] * h_shared[j_k, j_v]
|
||||
T.barrier_arrive(bar_08)
|
||||
T.barrier_wait(bar_08, (i_s + 0) % 2)
|
||||
T.barrier_wait(bar_09, (i_s + 0) % 2)
|
||||
|
||||
# 10
|
||||
T.barrier_wait(bar_10, (i_s + 0) % 2)
|
||||
T.reduce_sum(
|
||||
T.reshape(reduce_fragment, (128 * 2,)),
|
||||
dg_last_local_3,
|
||||
dim=0,
|
||||
clear=True,
|
||||
)
|
||||
dg_shared[block_S - 1] += dg_last_local_3[0]
|
||||
T.barrier_arrive(bar_11)
|
||||
|
||||
# 11
|
||||
T.barrier_wait(bar_11, (i_s + 0) % 2)
|
||||
# dS0 += K^T @ dVg
|
||||
T.gemm_v1(
|
||||
tmp_shared_2_2,
|
||||
tmp_shared_2_3,
|
||||
dh_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
T.barrier_arrive(bar_12)
|
||||
T.barrier_wait(bar_12, (i_s + 0) % 2)
|
||||
|
||||
# 13
|
||||
T.barrier_wait(bar_13, (i_s + 0) % 2)
|
||||
# dOg = s * g * dO
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
tmp_shared_2_3[j_s, j_v] = (
|
||||
scale * do_shared[j_s, j_v] * g_exp_shared[j_s]
|
||||
)
|
||||
T.barrier_arrive(bar_14)
|
||||
|
||||
# 14
|
||||
T.barrier_wait(bar_14, (i_s + 0) % 2)
|
||||
# dS0 += Q^T @ dOg
|
||||
T.gemm_v1(
|
||||
tmp_shared_2_1,
|
||||
tmp_shared_2_3,
|
||||
dh_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
T.barrier_arrive(bar_15)
|
||||
|
||||
# 15
|
||||
T.barrier_wait(bar_15, (i_s + 0) % 2)
|
||||
# S4[1] = dS0
|
||||
T.copy(dh_fragment, tmp_shared_4_1)
|
||||
|
||||
if use_dht:
|
||||
T.copy(dh_fragment, dh0[bb, bh, 0:DK, 0:DV])
|
||||
|
||||
elif tx < 256:
|
||||
T.set_max_nreg(CONSUMER_K_NREG, 1)
|
||||
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_arrive(bar_00)
|
||||
|
||||
# 16 == 00
|
||||
T.barrier_wait(bar_00, (i_s + 0) % 2)
|
||||
# S2[S] dK
|
||||
if i_s > 0:
|
||||
T.copy(dk_fragment, dqkv_shared)
|
||||
T.barrier_arrive(bar_01)
|
||||
|
||||
# 01
|
||||
T.barrier_wait(bar_01, (i_s + 0) % 2)
|
||||
# dV' = K @ dSt
|
||||
T.gemm_v1(k_shared, tmp_shared_4_1, dv_fragment, clear_accum=True)
|
||||
# dV' = g_last/g * dV'
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
dv_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
|
||||
T.barrier_arrive(bar_02)
|
||||
|
||||
# 02
|
||||
T.barrier_wait(bar_02, (i_s + 0) % 2)
|
||||
# dV' += Pg^T @ dO
|
||||
T.gemm_v1(
|
||||
tmp_shared_1_1,
|
||||
do_shared,
|
||||
dv_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
T.barrier_arrive(bar_03)
|
||||
|
||||
# 03
|
||||
T.barrier_wait(bar_03, (i_s + 0) % 2)
|
||||
# S2[1] dV'
|
||||
T.copy(dv_fragment, tmp_shared_2_1)
|
||||
T.barrier_arrive(bar_04)
|
||||
|
||||
# 04
|
||||
T.barrier_wait(bar_04, (i_s + 0) % 2)
|
||||
# dV = Ag^T @ dV'
|
||||
T.gemm_v1(
|
||||
tmp_shared_1_2,
|
||||
tmp_shared_2_1,
|
||||
dv_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
# S2[S] dV
|
||||
T.copy(dv_fragment, dqkv_shared)
|
||||
T.barrier_arrive(bar_05)
|
||||
|
||||
# 05
|
||||
T.barrier_wait(bar_05, (i_s + 0) % 2)
|
||||
# dVg = -g * dV
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
dv_fragment[j_s, j_v] = (
|
||||
-dv_fragment[j_s, j_v] * g_exp_shared[j_s]
|
||||
)
|
||||
# dg += sum(dVg * U)
|
||||
T.copy(tmp_shared_2_3, odot_fragment_1)
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
odot_fragment_1[j_s, j_v] *= dv_fragment[j_s, j_v]
|
||||
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
|
||||
T.copy(dg_fragment_1, dg_shared)
|
||||
# S2[3] dVg
|
||||
T.copy(dv_fragment, tmp_shared_2_3)
|
||||
T.barrier_arrive(bar_06)
|
||||
|
||||
# 06
|
||||
T.barrier_wait(bar_06, (i_s + 0) % 2)
|
||||
# S2[2] K
|
||||
T.copy(k_shared, odot_fragment_1)
|
||||
T.copy(odot_fragment_1, tmp_shared_2_2)
|
||||
T.barrier_arrive(bar_07)
|
||||
|
||||
# 07
|
||||
T.barrier_wait(bar_07, (i_s + 0) % 2)
|
||||
# dK = V' @ dSt^T
|
||||
T.gemm_v1(
|
||||
tmp_shared_2_1,
|
||||
tmp_shared_4_1,
|
||||
dk_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
T.barrier_arrive(bar_08)
|
||||
|
||||
# 08
|
||||
T.barrier_wait(bar_08, (i_s + 0) % 2)
|
||||
# dK = g_last/g * dK
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
dk_fragment[j_s, j_k] *= g_rev_exp_shared[j_s]
|
||||
# dg -= sum(K * dK)
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
odot_fragment_1[j_s, j_k] *= -dk_fragment[j_s, j_k]
|
||||
T.reduce_sum(odot_fragment_1, dg_fragment_1, dim=1, clear=True)
|
||||
for j_s in T.Parallel(block_S):
|
||||
dg_shared[j_s] += dg_fragment_1[j_s]
|
||||
# dg_last += sum(K * dK)
|
||||
T.reduce_sum(dg_fragment_1, dg_last_local_1, dim=0, clear=True)
|
||||
# Sg[S] dg
|
||||
dg_shared[block_S - 1] -= dg_last_local_1[0]
|
||||
T.barrier_arrive(bar_09)
|
||||
|
||||
# 09
|
||||
T.barrier_wait(bar_09, (i_s + 0) % 2)
|
||||
# dK += dVg @ S0^T
|
||||
T.gemm_v1(
|
||||
tmp_shared_2_3,
|
||||
h_shared,
|
||||
dk_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
T.barrier_arrive(bar_10)
|
||||
T.barrier_wait(bar_10, (i_s + 0) % 2)
|
||||
|
||||
# 12
|
||||
T.barrier_wait(bar_12, (i_s + 0) % 2)
|
||||
# dK += dP^T @ Q
|
||||
T.gemm_v1(
|
||||
tmp_shared_1_1,
|
||||
tmp_shared_2_1,
|
||||
dk_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
T.barrier_arrive(bar_13)
|
||||
T.barrier_wait(bar_13, (i_s + 0) % 2)
|
||||
|
||||
# 15
|
||||
T.barrier_wait(bar_15, (i_s + 0) % 2)
|
||||
# dK += dAs @ K
|
||||
T.gemm_v1(
|
||||
tmp_shared_1_2, tmp_shared_2_2, dk_fragment, clear_accum=False
|
||||
)
|
||||
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
if seq_start_idx + j_s < seq_end_idx:
|
||||
dk[batch_idx, seq_start_idx + j_s, bh, j_k] = dk_fragment[
|
||||
j_s, j_k
|
||||
]
|
||||
|
||||
elif tx < 384:
|
||||
T.set_max_nreg(CONSUMER_A_NREG, 1)
|
||||
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_arrive(bar_00)
|
||||
|
||||
# 00
|
||||
T.barrier_wait(bar_00, (i_s + 0) % 2)
|
||||
# P = Q @ K^T
|
||||
T.gemm_v1(
|
||||
q_shared,
|
||||
k_shared,
|
||||
p_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
T.barrier_arrive(bar_01)
|
||||
|
||||
# 01
|
||||
T.barrier_wait(bar_01, (i_s + 0) % 2)
|
||||
# G = Lower(diag(g) @ I @ diag(1/g))
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
mask_fragment[j_s, j_t] = g_shared[j_s] - g_shared[j_t]
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
if j_s >= j_t:
|
||||
mask_fragment[j_s, j_t] = T.exp2(
|
||||
mask_fragment[j_s, j_t] * 1.442695
|
||||
)
|
||||
else:
|
||||
mask_fragment[j_s, j_t] = 0
|
||||
# Pg = s * P * G
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
p_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
p_fragment[j_s, j_t] *= scale
|
||||
# S1[1] Pg
|
||||
T.copy(p_fragment, tmp_shared_1_1)
|
||||
T.barrier_arrive(bar_02)
|
||||
|
||||
# 02
|
||||
T.barrier_wait(bar_02, (i_s + 0) % 2)
|
||||
# Ab = Ar * b
|
||||
T.copy(a_shared, a_fragment)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] *= b_shared[j_t]
|
||||
# Ag = G * Ab
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
|
||||
# S1[2] Ag
|
||||
T.copy(a_fragment, tmp_shared_1_2)
|
||||
T.barrier_arrive(bar_03)
|
||||
|
||||
# 03
|
||||
T.barrier_wait(bar_03, (i_s + 0) % 2)
|
||||
# U = K @ S0
|
||||
T.gemm_v1(k_shared, h_shared, u_fragment, clear_accum=True)
|
||||
T.barrier_arrive(bar_04)
|
||||
|
||||
# 04
|
||||
T.barrier_wait(bar_04, (i_s + 0) % 2)
|
||||
# S2[3] U
|
||||
T.copy(u_fragment, tmp_shared_2_3)
|
||||
# W = V - g * U
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
u_fragment[j_s, j_v] += v_shared[j_s, j_v]
|
||||
# S2[2] W
|
||||
T.copy(u_fragment, tmp_shared_2_2)
|
||||
T.barrier_arrive(bar_05)
|
||||
|
||||
# 05
|
||||
T.barrier_wait(bar_05, (i_s + 0) % 2)
|
||||
# dAg = dV' @ W^T
|
||||
T.gemm_v1(
|
||||
tmp_shared_2_1,
|
||||
tmp_shared_2_2,
|
||||
da_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
# V' = Ag @ W
|
||||
T.gemm_v1(
|
||||
tmp_shared_1_2, tmp_shared_2_2, u_fragment, clear_accum=True
|
||||
)
|
||||
# S2[1] V'
|
||||
T.copy(u_fragment, tmp_shared_2_1)
|
||||
T.barrier_arrive(bar_06)
|
||||
|
||||
# 06
|
||||
T.barrier_wait(bar_06, (i_s + 0) % 2)
|
||||
# dPg = dO @ V'^T
|
||||
T.gemm_v1(
|
||||
do_shared,
|
||||
tmp_shared_2_1,
|
||||
dp_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
T.barrier_arrive(bar_07)
|
||||
|
||||
# 07
|
||||
T.barrier_wait(bar_07, (i_s + 0) % 2)
|
||||
# dAb = G * dAg
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
da_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
|
||||
# dg += sum((dPg * P) - (dPg * P)^T)
|
||||
T.copy(tmp_shared_1_1, p_fragment)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
p_fragment[j_s, j_t] *= dp_fragment[j_s, j_t]
|
||||
T.copy(p_fragment, tmp_shared_1_1)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
p_fragment[j_s, j_t] -= tmp_shared_1_1[j_t, j_s]
|
||||
T.reduce_sum(p_fragment, dg_fragment_2, dim=1, clear=True)
|
||||
# dP = s * G * dPg
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
dp_fragment[j_s, j_t] *= mask_fragment[j_s, j_t]
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
dp_fragment[j_s, j_t] *= scale
|
||||
# S1[1] dP
|
||||
T.copy(dp_fragment, tmp_shared_1_1)
|
||||
T.barrier_arrive(bar_08)
|
||||
|
||||
# 08
|
||||
T.barrier_wait(bar_08, (i_s + 0) % 2)
|
||||
# dQ = dO @ S0^T
|
||||
T.gemm_v1(
|
||||
do_shared,
|
||||
h_shared,
|
||||
dq_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
T.barrier_arrive(bar_09)
|
||||
|
||||
# 09
|
||||
T.barrier_wait(bar_09, (i_s + 0) % 2)
|
||||
# dQ = s * g * dQ
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
dq_fragment[j_s, j_k] *= g_exp_shared[j_s]
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
dq_fragment[j_s, j_k] *= scale
|
||||
# S2[1] Q
|
||||
T.copy(q_shared, odot_fragment_2)
|
||||
# dg += sum(Q * dQ)
|
||||
T.copy(odot_fragment_2, tmp_shared_2_1)
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
odot_fragment_2[j_s, j_k] *= dq_fragment[j_s, j_k]
|
||||
T.reduce_sum(odot_fragment_2, dg_fragment_2, dim=1, clear=False)
|
||||
T.barrier_arrive(bar_10)
|
||||
|
||||
# 10
|
||||
T.barrier_wait(bar_10, (i_s + 0) % 2)
|
||||
# dQ += dP @ K
|
||||
T.gemm_v1(
|
||||
tmp_shared_1_1, tmp_shared_2_2, dq_fragment, clear_accum=False
|
||||
)
|
||||
# S2[S] dQ
|
||||
T.copy(dq_fragment, dqkv_shared)
|
||||
T.barrier_arrive(bar_11)
|
||||
|
||||
# 11, 12
|
||||
T.barrier_wait(bar_11, (i_s + 0) % 2)
|
||||
# dAb * Ar
|
||||
T.copy(a_shared, a_fragment)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
|
||||
T.copy(a_fragment, tmp_shared_1_3)
|
||||
# dAb * Ab [ = G * dAg * Ab ]
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] *= b_shared[j_t]
|
||||
# dg += sum((dAb * Ab) - (dAb * Ab)^T)
|
||||
T.copy(a_fragment, tmp_shared_1_2)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] -= tmp_shared_1_2[j_t, j_s]
|
||||
T.reduce_sum(a_fragment, dg_fragment_2, dim=1, clear=False)
|
||||
# Sg[S] dg
|
||||
for j_s in T.Parallel(block_S):
|
||||
dg_shared[j_s] += dg_fragment_2[j_s]
|
||||
# db = sum((dAb * Ar)^T)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] = tmp_shared_1_3[j_t, j_s]
|
||||
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=True)
|
||||
# dAr = dAb * b
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
da_fragment[j_s, j_t] *= b_shared[j_t]
|
||||
# S1[2] dAr
|
||||
T.copy(da_fragment, tmp_shared_1_2)
|
||||
T.barrier_arrive(bar_13)
|
||||
|
||||
# 13
|
||||
T.barrier_wait(bar_13, (i_s + 0) % 2)
|
||||
# dA = -Ar^T @ dAr @ Ar^T
|
||||
T.gemm_v1(
|
||||
a_shared,
|
||||
tmp_shared_1_2,
|
||||
da_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
T.copy(da_fragment, tmp_shared_1_2)
|
||||
T.gemm_v1(
|
||||
tmp_shared_1_2,
|
||||
a_shared,
|
||||
da_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
# At = K @ K^T
|
||||
T.gemm_v1(
|
||||
tmp_shared_2_2,
|
||||
tmp_shared_2_2,
|
||||
a_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
T.barrier_arrive(bar_14)
|
||||
|
||||
# 14
|
||||
T.barrier_wait(bar_14, (i_s + 0) % 2)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
if j_s <= j_t:
|
||||
da_fragment[j_s, j_t] = 0
|
||||
else:
|
||||
da_fragment[j_s, j_t] = -da_fragment[j_s, j_t]
|
||||
# db += sum(dA * At)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] *= da_fragment[j_s, j_t]
|
||||
T.reduce_sum(a_fragment, db_fragment, dim=1, clear=False)
|
||||
T.copy(db_fragment, db_shared)
|
||||
# dAt = b * dA
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
da_fragment[j_s, j_t] *= b_shared[j_s]
|
||||
# dAs = dAt + dAt^T
|
||||
T.copy(da_fragment, tmp_shared_1_2)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
da_fragment[j_s, j_t] += tmp_shared_1_2[j_t, j_s]
|
||||
# S1[1] dAs
|
||||
T.copy(da_fragment, tmp_shared_1_2)
|
||||
T.barrier_arrive(bar_15)
|
||||
T.barrier_wait(bar_15, (i_s + 0) % 2)
|
||||
|
||||
else:
|
||||
T.set_max_nreg(PRODUCER_NREG, 0)
|
||||
|
||||
if tx < 384 + 32:
|
||||
for i_s in T.serial(num_iters - 1):
|
||||
chunk_idx = num_iters - i_s - 2
|
||||
left = seq_start_idx + chunk_idx * block_S
|
||||
right = left + block_S
|
||||
|
||||
T.barrier_arrive(bar_00)
|
||||
T.barrier_wait(bar_00, (i_s + 0) % 2)
|
||||
|
||||
T.barrier_wait(bar_03, (i_s + 0) % 2)
|
||||
for j_s in T.Parallel(block_S):
|
||||
g_shared[j_s] = g[batch_idx, left + j_s, bh]
|
||||
|
||||
T.barrier_wait(bar_05, (i_s + 0) % 2)
|
||||
T.copy(v[batch_idx, left:right, bh, 0:DV], v_shared)
|
||||
|
||||
T.barrier_wait(bar_07, (i_s + 0) % 2)
|
||||
T.copy(k[batch_idx, left:right, bhg, 0:DK], k_shared)
|
||||
|
||||
T.barrier_wait(bar_10, (i_s + 0) % 2)
|
||||
T.copy(q[batch_idx, left:right, bhg, 0:DK], q_shared)
|
||||
|
||||
if num_iters > 0:
|
||||
T.barrier_arrive(bar_00)
|
||||
|
||||
elif tx < 384 + 64:
|
||||
for i_s in T.serial(num_iters):
|
||||
left = seq_start_idx + (num_iters - i_s - 1) * block_S
|
||||
right = left + block_S
|
||||
|
||||
T.barrier_arrive(bar_00)
|
||||
T.barrier_wait(bar_00, (i_s + 0) % 2)
|
||||
|
||||
T.barrier_wait(bar_01, (i_s + 0) % 2)
|
||||
if i_s == 1:
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
if left + block_S + j_s < seq_end_idx:
|
||||
dk[batch_idx, left + block_S + j_s, bh, j_k] = (
|
||||
dqkv_shared[j_s, j_k]
|
||||
)
|
||||
elif i_s > 1:
|
||||
T.copy(
|
||||
dqkv_shared,
|
||||
dk[
|
||||
batch_idx,
|
||||
left + block_S : right + block_S,
|
||||
bh,
|
||||
0:DK,
|
||||
],
|
||||
)
|
||||
T.barrier_arrive(bar_04)
|
||||
T.barrier_wait(bar_04, (i_s + 0) % 2)
|
||||
|
||||
T.barrier_wait(bar_05, (i_s + 0) % 2)
|
||||
if i_s == 0:
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
if left + j_s < seq_end_idx:
|
||||
dv[batch_idx, left + j_s, bh, j_v] = dqkv_shared[
|
||||
j_s, j_v
|
||||
]
|
||||
else:
|
||||
T.copy(dqkv_shared, dv[batch_idx, left:right, bh, 0:DV])
|
||||
T.barrier_arrive(bar_10)
|
||||
T.barrier_wait(bar_10, (i_s + 0) % 2)
|
||||
|
||||
T.barrier_wait(bar_11, (i_s + 0) % 2)
|
||||
if i_s == 0:
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
if left + j_s < seq_end_idx:
|
||||
dq[batch_idx, left + j_s, bh, j_k] = dqkv_shared[
|
||||
j_s, j_k
|
||||
]
|
||||
else:
|
||||
T.copy(dqkv_shared, dq[batch_idx, left:right, bh, 0:DK])
|
||||
|
||||
elif tx < 384 + 96:
|
||||
for i_s in T.serial(num_iters - 1):
|
||||
chunk_idx = num_iters - i_s - 2
|
||||
left = seq_start_idx + chunk_idx * block_S
|
||||
right = left + block_S
|
||||
|
||||
T.barrier_arrive(bar_02)
|
||||
T.barrier_wait(bar_02, (i_s + 0) % 2)
|
||||
|
||||
T.barrier_wait(bar_10, (i_s + 0) % 2)
|
||||
T.copy(
|
||||
h[batch_idx, chunk_start_idx + chunk_idx, bh, 0:DK, 0:DV],
|
||||
h_shared,
|
||||
)
|
||||
|
||||
T.barrier_wait(bar_14, (i_s + 0) % 2)
|
||||
T.copy(a[batch_idx, left:right, bh, 0:block_S], a_shared)
|
||||
|
||||
T.copy(do[batch_idx, left:right, bh, 0:DV], do_shared)
|
||||
|
||||
T.barrier_wait(bar_15, (i_s + 0) % 2)
|
||||
for j_s in T.Parallel(block_S):
|
||||
b_shared[j_s] = b[batch_idx, left + j_s, bh]
|
||||
|
||||
if num_iters > 0:
|
||||
T.barrier_wait(bar_00, (num_iters - 1) % 2)
|
||||
T.barrier_arrive(bar_02)
|
||||
|
||||
else:
|
||||
for i_s in T.serial(num_iters):
|
||||
left = seq_start_idx + (num_iters - i_s - 1) * block_S
|
||||
|
||||
T.barrier_arrive(bar_05)
|
||||
T.barrier_wait(bar_05, (i_s + 0) % 2)
|
||||
|
||||
T.barrier_wait(bar_15, (i_s + 0) % 2)
|
||||
|
||||
if i_s == 0:
|
||||
for j_s in T.Parallel(block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
|
||||
if (seq_end_idx - seq_start_idx) % block_S > 0:
|
||||
dg[batch_idx, seq_end_idx - 1, bh] += dg_shared[
|
||||
block_S - 1
|
||||
]
|
||||
else:
|
||||
for j_s in T.Parallel(block_S):
|
||||
dg[batch_idx, left + j_s, bh] = dg_shared[j_s]
|
||||
|
||||
if i_s == 0:
|
||||
for j_s in T.Parallel(block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
db[batch_idx, left + j_s, bh] = db_shared[j_s]
|
||||
else:
|
||||
for j_s in T.Parallel(block_S):
|
||||
db[batch_idx, left + j_s, bh] = db_shared[j_s]
|
||||
|
||||
return tilelang_fused_chunk_gdr_bwd_kernel
|
||||
|
||||
|
||||
def fused_gdr_bwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
do: torch.Tensor,
|
||||
dht: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
):
|
||||
batch_size, num_tokens, Hg, K = k.shape
|
||||
_, _, H, V = v.shape
|
||||
scale = scale or K ** (-0.5)
|
||||
assert K == V == 128
|
||||
assert chunk_size == 64
|
||||
|
||||
if cu_seqlens is None:
|
||||
real_batch_size = batch_size
|
||||
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
|
||||
chunk_offsets = torch.empty(
|
||||
(batch_size + 1), dtype=torch.int32, device=k.device
|
||||
)
|
||||
is_varlen = False
|
||||
else:
|
||||
real_batch_size = len(cu_seqlens) - 1
|
||||
chunk_offsets, _ = prepare_chunk_offsets(cu_seqlens, chunk_size)
|
||||
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
|
||||
is_varlen = True
|
||||
|
||||
use_dht = dht is not None
|
||||
if dht is None:
|
||||
dht = torch.empty(
|
||||
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
|
||||
)
|
||||
dq = torch.empty_like(v)
|
||||
dk = torch.empty_like(v)
|
||||
dv = torch.empty_like(v)
|
||||
dg = torch.empty_like(g)
|
||||
db = torch.empty_like(b)
|
||||
dh0 = torch.empty_like(dht)
|
||||
|
||||
tilelang_fused_chunk_gdr_bwd_kernel = tilelang_fused_chunk_gdr_bwd(
|
||||
H,
|
||||
Hg,
|
||||
K,
|
||||
V,
|
||||
chunk_size,
|
||||
scale,
|
||||
qkva_dtype=q.dtype,
|
||||
g_dtype=g.dtype,
|
||||
b_dtype=b.dtype,
|
||||
h_dtype=h.dtype,
|
||||
o_dtype=do.dtype,
|
||||
seqlen_dtype=cu_seqlens.dtype,
|
||||
accum_dtype="float32",
|
||||
is_varlen=is_varlen,
|
||||
use_dht=use_dht,
|
||||
)
|
||||
tilelang_fused_chunk_gdr_bwd_kernel(
|
||||
do,
|
||||
dht,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
a,
|
||||
g,
|
||||
b,
|
||||
h,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dg,
|
||||
db,
|
||||
dh0,
|
||||
)
|
||||
|
||||
if not use_dht:
|
||||
dh0 = None
|
||||
|
||||
return dq, dk, dv, dg, db, dh0
|
||||
658
flash_qla/ops/gated_delta_rule/chunk/hopper/fused_fwd.py
Normal file
658
flash_qla/ops/gated_delta_rule/chunk/hopper/fused_fwd.py
Normal file
@@ -0,0 +1,658 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from flash_qla.utils import prepare_chunk_offsets
|
||||
|
||||
|
||||
MULTI_PROCESSOR_COUNT = torch.cuda.get_device_properties().multi_processor_count
|
||||
TARGET_NUM_CTAS = int(MULTI_PROCESSOR_COUNT * 0.7)
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
# out_idx=[-3, -2, -1],
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
|
||||
# tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
|
||||
},
|
||||
)
|
||||
def tilelang_fused_chunk_gdr_fwd(
|
||||
H,
|
||||
Hg,
|
||||
DK,
|
||||
DV,
|
||||
chunk_size,
|
||||
scale,
|
||||
accum_dtype,
|
||||
qkva_dtype,
|
||||
g_dtype,
|
||||
b_dtype,
|
||||
h0_dtype,
|
||||
ht_dtype,
|
||||
h_dtype,
|
||||
o_dtype,
|
||||
seqlen_dtype,
|
||||
use_initial_state,
|
||||
store_final_state,
|
||||
store_h,
|
||||
store_o,
|
||||
is_varlen,
|
||||
is_cp,
|
||||
block_DV=128,
|
||||
):
|
||||
batch_size = T.dynamic("batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
num_chunks = T.dynamic("num_chunks")
|
||||
raw_batch_size = T.dynamic("raw_batch_size")
|
||||
block_S = chunk_size
|
||||
|
||||
if is_varlen:
|
||||
q_shape = (1, num_tokens, Hg, DK)
|
||||
k_shape = (1, num_tokens, Hg, DK)
|
||||
v_shape = (1, num_tokens, H, DV)
|
||||
o_shape = (1, num_tokens, H, DV)
|
||||
a_shape = (1, num_tokens, H, chunk_size)
|
||||
g_shape = (1, num_tokens, H)
|
||||
b_shape = (1, num_tokens, H)
|
||||
h_shape = (1, num_chunks, H, DK, DV)
|
||||
else:
|
||||
q_shape = (batch_size, num_tokens, Hg, DK)
|
||||
k_shape = (batch_size, num_tokens, Hg, DK)
|
||||
v_shape = (batch_size, num_tokens, H, DV)
|
||||
o_shape = (batch_size, num_tokens, H, DV)
|
||||
a_shape = (batch_size, num_tokens, H, chunk_size)
|
||||
g_shape = (batch_size, num_tokens, H)
|
||||
b_shape = (batch_size, num_tokens, H)
|
||||
h_shape = (batch_size, num_chunks, H, DK, DV)
|
||||
h0_shape = (batch_size, H, DK, DV)
|
||||
ht_shape = (raw_batch_size, H, DK, DV)
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_fused_chunk_gdr_fwd_kernel(
|
||||
q: T.Tensor(q_shape, dtype=qkva_dtype),
|
||||
k: T.Tensor(k_shape, dtype=qkva_dtype),
|
||||
v: T.Tensor(v_shape, dtype=qkva_dtype),
|
||||
a: T.Tensor(a_shape, dtype=qkva_dtype),
|
||||
g: T.Tensor(g_shape, dtype=g_dtype),
|
||||
b: T.Tensor(b_shape, dtype=b_dtype),
|
||||
h0: T.Tensor(h0_shape, dtype=h0_dtype),
|
||||
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
|
||||
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
|
||||
cp_seq_map: T.Tensor([batch_size], dtype=seqlen_dtype),
|
||||
raw_cu_seqlens: T.Tensor([raw_batch_size + 1], dtype=seqlen_dtype),
|
||||
o: T.Tensor(o_shape, dtype=o_dtype),
|
||||
h: T.Tensor(h_shape, dtype=h_dtype),
|
||||
ht: T.Tensor(ht_shape, dtype=ht_dtype),
|
||||
):
|
||||
with T.Kernel(T.ceildiv(DV, block_DV) * batch_size * H, threads=512) as (bbhv,):
|
||||
bbh, bv = bbhv // T.ceildiv(DV, block_DV), bbhv % T.ceildiv(DV, block_DV)
|
||||
bb, bh = bbh // H, bbh % H
|
||||
bhg = bh // (H // Hg)
|
||||
|
||||
batch_idx = T.alloc_var("int32")
|
||||
seq_start_idx = T.alloc_var("int32")
|
||||
seq_end_idx = T.alloc_var("int32")
|
||||
seq_split_idx = T.alloc_var("int32")
|
||||
chunk_start_idx = T.alloc_var("int32")
|
||||
chunk_split_idx = T.alloc_var("int32")
|
||||
|
||||
batch_idx = 0 if is_varlen else bb
|
||||
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
|
||||
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
|
||||
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
|
||||
|
||||
raw_batch_idx = T.alloc_var("int32")
|
||||
raw_seq_end_idx = T.alloc_var("int32")
|
||||
need_store_final_state = T.alloc_var("bool")
|
||||
raw_batch_idx = cp_seq_map[bb] if is_cp else bb
|
||||
raw_seq_end_idx = (
|
||||
raw_cu_seqlens[raw_batch_idx + 1] if is_cp else seq_end_idx
|
||||
)
|
||||
need_store_final_state = store_final_state & (
|
||||
raw_seq_end_idx == seq_end_idx
|
||||
)
|
||||
|
||||
num_iters = T.alloc_var("int32")
|
||||
num_unmasked_iters = T.alloc_var("int32")
|
||||
num_iters = T.ceildiv(seq_end_idx - seq_start_idx, block_S)
|
||||
num_unmasked_iters = (seq_end_idx - seq_start_idx) // block_S
|
||||
|
||||
q_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
|
||||
k_shared = T.alloc_shared((2, block_S, DK), dtype=qkva_dtype)
|
||||
v_shared = T.alloc_shared((2, block_S, block_DV), dtype=qkva_dtype)
|
||||
a_shared = T.alloc_shared((2, block_S, block_S), dtype=qkva_dtype)
|
||||
g_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
|
||||
b_shared = T.alloc_shared((2, block_S), dtype=accum_dtype, scope="shared")
|
||||
|
||||
o_shared = T.alloc_shared((block_S, block_DV), dtype=o_dtype)
|
||||
h_shared = T.alloc_shared((DK, block_DV), dtype=qkva_dtype)
|
||||
vd_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
|
||||
vn_shared = T.alloc_shared((block_S, block_DV), dtype=qkva_dtype)
|
||||
p_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
|
||||
g_exp_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
|
||||
g_rev_exp_shared = T.alloc_shared(
|
||||
(block_S), dtype=accum_dtype, scope="shared"
|
||||
)
|
||||
|
||||
h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
|
||||
o_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
|
||||
v_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
|
||||
u_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
|
||||
p_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
a_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
g_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
g_last_local = T.alloc_local((1), dtype=accum_dtype)
|
||||
|
||||
data_is_ready = T.alloc_barrier(arrive_count=[96] * 2)
|
||||
data_is_free = T.alloc_barrier(arrive_count=[384] * 2)
|
||||
|
||||
bar_o = T.alloc_barrier(arrive_count=128)
|
||||
bar_0 = T.alloc_barrier(arrive_count=416)
|
||||
bar_1 = T.alloc_barrier(arrive_count=256)
|
||||
_bar_2 = T.alloc_barrier(arrive_count=128)
|
||||
bar_3 = T.alloc_barrier(arrive_count=128)
|
||||
bar_4 = T.alloc_barrier(arrive_count=128)
|
||||
bar_5 = T.alloc_barrier(arrive_count=416)
|
||||
|
||||
T.use_swizzle(10)
|
||||
|
||||
tx = T.get_thread_binding()
|
||||
|
||||
PRODUCER_NREG = 32
|
||||
CONSUMER_V_NREG = 128
|
||||
CONSUMER_S_NREG = 160
|
||||
CONSUMER_O_NREG = 128
|
||||
|
||||
if tx < 128:
|
||||
T.set_max_nreg(CONSUMER_S_NREG, 1)
|
||||
|
||||
# Initialize S
|
||||
if use_initial_state:
|
||||
T.copy(
|
||||
h0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV],
|
||||
h_fragment,
|
||||
)
|
||||
else:
|
||||
T.clear(h_fragment)
|
||||
|
||||
# Main Loop
|
||||
for i_s in T.serial(num_iters):
|
||||
# [STAGE 0]
|
||||
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
# [STAGE 0] 0
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
# S4[S] S
|
||||
T.copy(h_fragment, h_shared)
|
||||
T.barrier_arrive(bar_1)
|
||||
|
||||
# [STAGE 0] 2, 3, 4
|
||||
T.barrier_wait(bar_1, i_s % 2)
|
||||
# S = g_last * S
|
||||
g_last_local[0] = g_exp_shared[block_S - 1]
|
||||
for j_k, j_v in T.Parallel(DK, block_DV):
|
||||
h_fragment[j_k, j_v] *= g_last_local[0]
|
||||
T.barrier_arrive(bar_5)
|
||||
|
||||
# [STAGE 0] 5
|
||||
T.barrier_wait(bar_5, i_s % 2)
|
||||
# S += K^T @ V'
|
||||
T.gemm_v1(
|
||||
k_shared[i_s % 2, :, :],
|
||||
vn_shared,
|
||||
h_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
|
||||
T.barrier_arrive(data_is_free[i_s % 2])
|
||||
|
||||
# Store final S
|
||||
if need_store_final_state:
|
||||
T.copy(
|
||||
h_fragment,
|
||||
ht[
|
||||
raw_batch_idx, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV
|
||||
],
|
||||
)
|
||||
|
||||
elif tx < 256:
|
||||
T.set_max_nreg(CONSUMER_V_NREG, 1)
|
||||
|
||||
# Main Loop
|
||||
for i_s in T.serial(num_iters):
|
||||
# [STAGE 0]
|
||||
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
# [STAGE 0] 0
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
# Precompute g, g_last/g
|
||||
for j_s in T.Parallel(block_S):
|
||||
g_exp_shared[j_s] = T.exp2(g_shared[i_s % 2, j_s] * 1.442695)
|
||||
for j_s in T.Parallel(block_S):
|
||||
g_rev_exp_shared[j_s] = T.if_then_else(
|
||||
seq_start_idx + i_s * block_S + j_s < seq_end_idx,
|
||||
T.exp2(
|
||||
(
|
||||
g_shared[i_s % 2, block_S - 1]
|
||||
- g_shared[i_s % 2, j_s]
|
||||
)
|
||||
* 1.442695
|
||||
),
|
||||
0.0,
|
||||
)
|
||||
T.barrier_arrive(bar_1)
|
||||
|
||||
# [STAGE 0] 1
|
||||
T.barrier_wait(bar_1, i_s % 2)
|
||||
# U = K @ S
|
||||
T.gemm_v1(
|
||||
k_shared[i_s % 2, :, :], h_shared, u_fragment, clear_accum=True
|
||||
)
|
||||
|
||||
# [STAGE 0] 2
|
||||
# W = V - g * U
|
||||
for j_s, j_v in T.Parallel(block_S, block_DV):
|
||||
u_fragment[j_s, j_v] *= -g_exp_shared[j_s]
|
||||
for j_s, j_v in T.Parallel(block_S, block_DV):
|
||||
u_fragment[j_s, j_v] += v_shared[i_s % 2, j_s, j_v]
|
||||
# S2[V] W
|
||||
for j_s, j_v in T.Parallel(block_S, block_DV):
|
||||
v_shared[i_s % 2, j_s, j_v] = u_fragment[j_s, j_v]
|
||||
|
||||
# [STAGE 0] 3
|
||||
T.barrier_wait(bar_3, i_s % 2)
|
||||
# Vd = Ag @ W
|
||||
T.gemm_v1(
|
||||
a_shared[i_s % 2, :, :],
|
||||
v_shared[i_s % 2, :, :],
|
||||
v_fragment,
|
||||
clear_accum=True,
|
||||
)
|
||||
# S2[2] Vd
|
||||
T.copy(v_fragment, vd_shared)
|
||||
T.barrier_arrive(bar_4)
|
||||
|
||||
# [STAGE 0] 4
|
||||
# V' = g_last/g Vd
|
||||
for j_s, j_v in T.Parallel(block_S, block_DV):
|
||||
v_fragment[j_s, j_v] *= g_rev_exp_shared[j_s]
|
||||
# S2[1] V'
|
||||
T.copy(v_fragment, vn_shared)
|
||||
T.barrier_arrive(bar_5)
|
||||
|
||||
T.barrier_wait(bar_5, i_s % 2)
|
||||
|
||||
T.barrier_arrive(data_is_free[i_s % 2])
|
||||
|
||||
elif tx < 384:
|
||||
T.set_max_nreg(CONSUMER_O_NREG, 1)
|
||||
|
||||
# Main Loop
|
||||
for i_s in T.serial(num_iters):
|
||||
# [STAGE 0]
|
||||
T.barrier_wait(data_is_ready[i_s % 2], (i_s // 2 + 0) % 2)
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
# [STAGE 0] 0
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
# P = Q K^T
|
||||
T.gemm_v1(
|
||||
q_shared[i_s % 2, :, :],
|
||||
k_shared[i_s % 2, :, :],
|
||||
p_fragment,
|
||||
transpose_B=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
|
||||
# [STAGE 0] 1
|
||||
# G = Lower(diag(g) @ I @ diag(1/g))
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
g_fragment[j_s, j_t] = (
|
||||
g_shared[i_s % 2, j_s] - g_shared[i_s % 2, j_t]
|
||||
)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
if j_s >= j_t:
|
||||
g_fragment[j_s, j_t] = T.exp2(
|
||||
g_fragment[j_s, j_t] * 1.442695
|
||||
)
|
||||
else:
|
||||
g_fragment[j_s, j_t] = 0
|
||||
# Ag = G * Ar * b
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] = a_shared[i_s % 2, j_s, j_t]
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] *= g_fragment[j_s, j_t]
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_fragment[j_s, j_t] *= b_shared[i_s % 2, j_t]
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a_shared[i_s % 2, j_s, j_t] = a_fragment[j_s, j_t]
|
||||
|
||||
# [STAGE 0] 2
|
||||
T.barrier_wait(bar_1, i_s % 2)
|
||||
# O = Q @ S
|
||||
T.gemm_v1(
|
||||
q_shared[i_s % 2, :, :], h_shared, o_fragment, clear_accum=True
|
||||
)
|
||||
|
||||
# [STAGE 0] 3
|
||||
# Pg = s * G * P
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
p_fragment[j_s, j_t] *= scale * g_fragment[j_s, j_t]
|
||||
# S1[1] Pg
|
||||
T.copy(p_fragment, p_shared)
|
||||
T.barrier_arrive(bar_3)
|
||||
# O = s * g * O
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
o_fragment[j_s, j_k] *= scale * g_exp_shared[j_s]
|
||||
|
||||
# [STAGE 0] 4
|
||||
T.barrier_wait(bar_4, i_s % 2)
|
||||
# O += Pg @ Vd
|
||||
T.gemm_v1(p_shared, vd_shared, o_fragment, clear_accum=False)
|
||||
T.barrier_arrive(bar_5)
|
||||
|
||||
# [STAGE 0] 5
|
||||
T.barrier_wait(bar_5, i_s % 2)
|
||||
# S2[S] O
|
||||
T.copy(o_fragment, o_shared)
|
||||
|
||||
T.barrier_arrive(data_is_free[i_s % 2])
|
||||
|
||||
T.barrier_arrive(bar_o)
|
||||
|
||||
else:
|
||||
T.set_max_nreg(PRODUCER_NREG, 0)
|
||||
|
||||
if tx < 384 + 32:
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
|
||||
left = seq_start_idx + i_s * block_S
|
||||
right = left + block_S
|
||||
|
||||
# Load Q
|
||||
T.copy(
|
||||
q[batch_idx, left:right, bhg, 0:DK], q_shared[i_s % 2, :, :]
|
||||
)
|
||||
# Load K
|
||||
T.copy(
|
||||
k[batch_idx, left:right, bhg, 0:DK], k_shared[i_s % 2, :, :]
|
||||
)
|
||||
|
||||
T.barrier_arrive(data_is_ready[i_s % 2])
|
||||
|
||||
elif tx < 384 + 64:
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
|
||||
left = seq_start_idx + i_s * block_S
|
||||
right = left + block_S
|
||||
|
||||
# Load V
|
||||
T.copy(
|
||||
v[
|
||||
batch_idx,
|
||||
left:right,
|
||||
bh,
|
||||
bv * block_DV : (bv + 1) * block_DV,
|
||||
],
|
||||
v_shared[i_s % 2, :, :],
|
||||
)
|
||||
# Load beta
|
||||
if right <= seq_end_idx:
|
||||
for j_s in T.Parallel(block_S):
|
||||
b_shared[i_s % 2, j_s] = b[batch_idx, left + j_s, bh]
|
||||
else:
|
||||
for j_s in T.Parallel(block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
b_shared[i_s % 2, j_s] = b[
|
||||
batch_idx, left + j_s, bh
|
||||
]
|
||||
else:
|
||||
b_shared[i_s % 2, j_s] = 0
|
||||
|
||||
T.barrier_arrive(data_is_ready[i_s % 2])
|
||||
|
||||
elif tx < 384 + 96:
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_wait(data_is_free[i_s % 2], (i_s // 2 + 1) % 2)
|
||||
left = seq_start_idx + i_s * block_S
|
||||
right = left + block_S
|
||||
|
||||
# Load A
|
||||
T.copy(
|
||||
a[batch_idx, left:right, bh, 0:block_S],
|
||||
a_shared[i_s % 2, :, :],
|
||||
)
|
||||
# Load gamma
|
||||
if right <= seq_end_idx:
|
||||
for j_s in T.Parallel(block_S):
|
||||
g_shared[i_s % 2, j_s] = g[batch_idx, left + j_s, bh]
|
||||
else:
|
||||
for j_s in T.Parallel(block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
g_shared[i_s % 2, j_s] = g[
|
||||
batch_idx, left + j_s, bh
|
||||
]
|
||||
else:
|
||||
g_shared[i_s % 2, j_s] = g[
|
||||
batch_idx, seq_end_idx - 1, bh
|
||||
]
|
||||
|
||||
T.barrier_arrive(data_is_ready[i_s % 2])
|
||||
|
||||
else:
|
||||
for i_s in T.serial(num_unmasked_iters):
|
||||
right = seq_start_idx + i_s * block_S
|
||||
left = right - block_S
|
||||
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
# Store O
|
||||
if i_s > 0 and store_o:
|
||||
T.copy(
|
||||
o_shared,
|
||||
o[
|
||||
batch_idx,
|
||||
left:right,
|
||||
bh,
|
||||
bv * block_DV : (bv + 1) * block_DV,
|
||||
],
|
||||
)
|
||||
T.barrier_arrive(bar_5)
|
||||
|
||||
T.barrier_wait(bar_1, i_s % 2)
|
||||
# Store S
|
||||
if store_h:
|
||||
T.copy(
|
||||
h_shared,
|
||||
h[
|
||||
batch_idx,
|
||||
chunk_start_idx + i_s,
|
||||
bh,
|
||||
0:DK,
|
||||
bv * block_DV : (bv + 1) * block_DV,
|
||||
],
|
||||
)
|
||||
|
||||
if num_unmasked_iters < num_iters:
|
||||
seq_split_idx = seq_start_idx + num_unmasked_iters * block_S
|
||||
chunk_split_idx = chunk_start_idx + num_unmasked_iters
|
||||
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
T.barrier_wait(bar_0, num_unmasked_iters % 2)
|
||||
# Store O
|
||||
if num_unmasked_iters > 0 and store_o:
|
||||
T.copy(
|
||||
o_shared,
|
||||
o[
|
||||
batch_idx,
|
||||
seq_split_idx - block_S : seq_split_idx,
|
||||
bh,
|
||||
bv * block_DV : (bv + 1) * block_DV,
|
||||
],
|
||||
)
|
||||
T.barrier_arrive(bar_5)
|
||||
|
||||
T.barrier_wait(bar_1, num_unmasked_iters % 2)
|
||||
# Store S
|
||||
if store_h:
|
||||
T.copy(
|
||||
h_shared,
|
||||
h[
|
||||
batch_idx,
|
||||
chunk_split_idx,
|
||||
bh,
|
||||
0:DK,
|
||||
bv * block_DV : (bv + 1) * block_DV,
|
||||
],
|
||||
)
|
||||
|
||||
seq_split_idx = seq_start_idx + (num_iters - 1) * block_S
|
||||
|
||||
# Store O
|
||||
T.barrier_wait(bar_o, 0)
|
||||
if store_o:
|
||||
for j_s, j_v in T.Parallel(block_S, block_DV):
|
||||
with T.If(seq_split_idx + j_s < seq_end_idx):
|
||||
with T.Then():
|
||||
o[
|
||||
batch_idx,
|
||||
seq_split_idx + j_s,
|
||||
bh,
|
||||
bv * block_DV + j_v,
|
||||
] = o_shared[j_s, j_v]
|
||||
|
||||
return tilelang_fused_chunk_gdr_fwd_kernel
|
||||
|
||||
|
||||
def fused_gdr_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
output_final_state: bool = True,
|
||||
output_h: bool = False,
|
||||
output_o: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cp_seq_map: torch.LongTensor | None = None,
|
||||
raw_cu_seqlens: torch.LongTensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
):
|
||||
batch_size, num_tokens, Hg, K = k.shape
|
||||
_, _, H, V = v.shape
|
||||
scale = scale or K ** (-0.5)
|
||||
assert K == V == 128
|
||||
assert chunk_size == 64
|
||||
|
||||
if cu_seqlens is None:
|
||||
real_batch_size = batch_size
|
||||
num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
|
||||
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
|
||||
chunk_offsets = torch.empty(
|
||||
(batch_size + 1), dtype=torch.int32, device=k.device
|
||||
)
|
||||
seqlen_dtype = torch.int32
|
||||
is_varlen = False
|
||||
else:
|
||||
real_batch_size = len(cu_seqlens) - 1
|
||||
chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
|
||||
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
|
||||
num_chunks = num_chunks if output_h else 0
|
||||
seqlen_dtype = cu_seqlens.dtype
|
||||
is_varlen = True
|
||||
|
||||
if cp_seq_map is None:
|
||||
cp_seq_map = torch.empty(
|
||||
(real_batch_size,), dtype=seqlen_dtype, device=k.device
|
||||
)
|
||||
is_cp = False
|
||||
else:
|
||||
is_cp = True
|
||||
|
||||
use_initial_state = initial_state is not None
|
||||
if initial_state is None:
|
||||
initial_state = torch.empty(
|
||||
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
|
||||
)
|
||||
h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device)
|
||||
if raw_cu_seqlens is None:
|
||||
raw_cu_seqlens = torch.empty(
|
||||
(real_batch_size + 1,), dtype=seqlen_dtype, device=k.device
|
||||
)
|
||||
final_state = torch.empty(
|
||||
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
|
||||
)
|
||||
else:
|
||||
final_state = torch.empty(
|
||||
(raw_cu_seqlens.shape[0] - 1, H, K, V), dtype=torch.float32, device=k.device
|
||||
)
|
||||
o = torch.empty_like(v)
|
||||
|
||||
grid_size = real_batch_size * H
|
||||
if grid_size >= TARGET_NUM_CTAS:
|
||||
block_DV = 128
|
||||
elif grid_size * 2 >= TARGET_NUM_CTAS:
|
||||
block_DV = 64
|
||||
else:
|
||||
block_DV = 32
|
||||
|
||||
tilelang_fused_chunk_gdr_fwd_kernel = tilelang_fused_chunk_gdr_fwd(
|
||||
H,
|
||||
Hg,
|
||||
K,
|
||||
V,
|
||||
chunk_size,
|
||||
scale,
|
||||
qkva_dtype=q.dtype,
|
||||
g_dtype=g.dtype,
|
||||
b_dtype=b.dtype,
|
||||
h0_dtype=initial_state.dtype,
|
||||
ht_dtype=final_state.dtype,
|
||||
h_dtype=h.dtype,
|
||||
o_dtype=o.dtype,
|
||||
seqlen_dtype=seqlen_dtype,
|
||||
accum_dtype="float32",
|
||||
use_initial_state=use_initial_state,
|
||||
store_final_state=output_final_state,
|
||||
store_h=output_h,
|
||||
store_o=output_o,
|
||||
is_varlen=is_varlen,
|
||||
is_cp=is_cp,
|
||||
block_DV=block_DV,
|
||||
)
|
||||
tilelang_fused_chunk_gdr_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
a,
|
||||
g,
|
||||
b,
|
||||
initial_state,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
cp_seq_map,
|
||||
raw_cu_seqlens,
|
||||
o,
|
||||
h,
|
||||
final_state,
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
final_state = None
|
||||
if not output_h:
|
||||
h = None
|
||||
if not output_o:
|
||||
o = None
|
||||
|
||||
return o, h, final_state
|
||||
345
flash_qla/ops/gated_delta_rule/chunk/hopper/kkt_solve.py
Normal file
345
flash_qla/ops/gated_delta_rule/chunk/hopper/kkt_solve.py
Normal file
@@ -0,0 +1,345 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from flash_qla.utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
# out_idx=[-1],
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
|
||||
# tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
# tilelang.PassConfigKey.TL_ENABLE_ASYNC_COPY: True,
|
||||
},
|
||||
)
|
||||
def tilelang_kkt_solve(
|
||||
H,
|
||||
Hg,
|
||||
DK,
|
||||
chunk_size,
|
||||
accum_dtype,
|
||||
qkva_dtype,
|
||||
b_dtype,
|
||||
seqlen_dtype,
|
||||
is_varlen,
|
||||
):
|
||||
data_batch_size = T.dynamic("data_batch_size")
|
||||
real_batch_size = T.dynamic("real_batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
num_chunks = T.dynamic("num_chunks")
|
||||
block_S = chunk_size
|
||||
|
||||
k_shape = (data_batch_size, num_tokens, Hg, DK)
|
||||
a_shape = (data_batch_size, num_tokens, H, chunk_size)
|
||||
b_shape = (data_batch_size, num_tokens, H)
|
||||
|
||||
@T.macro
|
||||
def kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
bh,
|
||||
bhg,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
k,
|
||||
b,
|
||||
a,
|
||||
):
|
||||
left = seq_start_idx + chunk_idx * block_S
|
||||
right = left + block_S
|
||||
|
||||
k_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
b_shared = T.alloc_shared((block_S), dtype=accum_dtype, scope="shared")
|
||||
a64_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
|
||||
|
||||
a16i_row = T.alloc_fragment((4, 16), dtype=accum_dtype)
|
||||
a16i_sum = T.alloc_fragment((4, 16), dtype=accum_dtype)
|
||||
|
||||
a16i_shared = T.alloc_shared((4, 17, 16), dtype=accum_dtype)
|
||||
a16o_shared = T.alloc_shared((2, 17, 16), dtype=accum_dtype)
|
||||
a16o_fragment = T.alloc_fragment((2, 16, 16), dtype=accum_dtype)
|
||||
|
||||
a32i_fragment = T.alloc_fragment((2, 32, 32), dtype=accum_dtype)
|
||||
a32i0_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
|
||||
a32i1_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
|
||||
a32o_shared = T.alloc_shared((32, 32), dtype=accum_dtype)
|
||||
a32o_fragment = T.alloc_fragment((32, 32), dtype=accum_dtype)
|
||||
|
||||
a64_shared = T.alloc_shared((block_S, block_S), dtype=qkva_dtype)
|
||||
|
||||
T.annotate_layout(
|
||||
{
|
||||
a16i_shared: tilelang.layout.make_linear_layout(a16i_shared),
|
||||
a16o_shared: tilelang.layout.make_linear_layout(a16o_shared),
|
||||
}
|
||||
)
|
||||
|
||||
k_is_ready = T.alloc_barrier(arrive_count=32)
|
||||
a_is_ready = T.alloc_barrier(arrive_count=128)
|
||||
|
||||
tx = T.get_thread_binding()
|
||||
|
||||
PRODUCER_NREG = 24
|
||||
CONSUMER_NREG = 64
|
||||
|
||||
if tx < 128:
|
||||
T.set_max_nreg(CONSUMER_NREG, 1)
|
||||
|
||||
# Load b
|
||||
if right <= seq_end_idx:
|
||||
for j_s in T.Parallel(block_S):
|
||||
b_shared[j_s] = b[bb, left + j_s, bh]
|
||||
else:
|
||||
for j_s in T.Parallel(block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
b_shared[j_s] = b[bb, left + j_s, bh]
|
||||
else:
|
||||
b_shared[j_s] = 0
|
||||
|
||||
T.barrier_wait(k_is_ready, 0)
|
||||
|
||||
# A = K @ K^T
|
||||
T.gemm_v1(
|
||||
k_shared, k_shared, a64_fragment, transpose_B=True, clear_accum=True
|
||||
)
|
||||
|
||||
# A = b * A
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
a64_fragment[j_s, j_t] *= b_shared[j_s]
|
||||
|
||||
# A = I + StrictLower(A)
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
if j_s < j_t:
|
||||
a64_fragment[j_s, j_t] = 0
|
||||
elif j_s == j_t:
|
||||
a64_fragment[j_s, j_t] = 1
|
||||
|
||||
# Prepare inversion input
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
if j_s >= 32 and j_t < 32:
|
||||
a32o_shared[j_s - 32, j_t] = -a64_fragment[j_s, j_t]
|
||||
elif (j_s // 16) == (j_t // 16) + 1:
|
||||
a16o_shared[j_s // 32, j_s % 16, j_t % 16] = -a64_fragment[j_s, j_t]
|
||||
elif (j_s // 16) == (j_t // 16):
|
||||
a16i_shared[j_s // 16, j_s % 16, j_t % 16] = a64_fragment[j_s, j_t]
|
||||
|
||||
# Diagonal 4x16x16
|
||||
T.clear(a16i_row)
|
||||
for k_s in T.unroll(1, 16):
|
||||
for j_s, k_t in T.Parallel(4, 16):
|
||||
if k_t < k_s:
|
||||
a16i_row[j_s, k_t] = a16i_shared[j_s, k_s, k_t]
|
||||
T.clear(a16i_sum)
|
||||
for k_r in T.unroll(k_s):
|
||||
for j_s, k_t in T.Parallel(4, 16):
|
||||
a16i_sum[j_s, k_t] -= (
|
||||
a16i_shared[j_s, k_r, k_t] * a16i_row[j_s, k_r]
|
||||
)
|
||||
for j_s, k_t in T.Parallel(4, 16):
|
||||
if k_t < k_s:
|
||||
a16i_shared[j_s, k_s, k_t] = a16i_sum[j_s, k_t]
|
||||
|
||||
# First level 2x16x16
|
||||
T.clear(a16o_fragment)
|
||||
for k_r in T.unroll(16):
|
||||
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
|
||||
a16o_fragment[j_s, k_s, k_t] += (
|
||||
a16i_shared[j_s * 2 + 1, k_s, k_r] * a16o_shared[j_s, k_r, k_t]
|
||||
)
|
||||
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
|
||||
a16o_shared[j_s, k_t, k_s] = a16o_fragment[j_s, k_s, k_t]
|
||||
T.clear(a16o_fragment)
|
||||
for k_r in T.unroll(16):
|
||||
for j_s, k_s, k_t in T.Parallel(2, 16, 16):
|
||||
a16o_fragment[j_s, k_s, k_t] += (
|
||||
a16o_shared[j_s, k_r, k_s] * a16i_shared[j_s * 2, k_r, k_t]
|
||||
)
|
||||
T.copy(a16o_fragment, a16o_shared[:, 0:16, 0:16])
|
||||
|
||||
# Second level 1x32x32
|
||||
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
|
||||
if k_s < 16 and k_t >= 16:
|
||||
a32i_fragment[j_s, k_s, k_t] = 0
|
||||
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
|
||||
if k_s >= 16 and k_t < 16:
|
||||
a32i_fragment[j_s, k_s, k_t] = a16o_shared[j_s, k_s - 16, k_t]
|
||||
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
|
||||
if k_s // 16 == k_t // 16:
|
||||
a32i_fragment[j_s, k_s, k_t] = a16i_shared[
|
||||
j_s * 2 + k_s // 16, k_s % 16, k_t % 16
|
||||
]
|
||||
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
|
||||
if j_s == 0:
|
||||
a32i0_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
|
||||
else:
|
||||
a32i1_shared[k_s, k_t] = a32i_fragment[j_s, k_s, k_t]
|
||||
T.gemm_v1(a32i1_shared, a32o_shared, a32o_fragment, clear_accum=True)
|
||||
T.copy(a32o_fragment, a32o_shared)
|
||||
T.gemm_v1(a32o_shared, a32i0_shared, a32o_fragment, clear_accum=True)
|
||||
|
||||
# Combine inversion output
|
||||
for j_s, k_s, k_t in T.Parallel(2, 32, 32):
|
||||
a64_shared[j_s * 32 + k_s, j_s * 32 + k_t] = a32i_fragment[
|
||||
j_s, k_s, k_t
|
||||
]
|
||||
for k_s, k_t in T.Parallel(32, 32):
|
||||
a64_shared[32 + k_s, k_t] = a32o_fragment[k_s, k_t]
|
||||
for k_s, k_t in T.Parallel(32, 32):
|
||||
a64_shared[k_s, 32 + k_t] = 0
|
||||
|
||||
T.barrier_arrive(a_is_ready)
|
||||
|
||||
else:
|
||||
T.set_max_nreg(PRODUCER_NREG, 0)
|
||||
|
||||
if tx < 128 + 32:
|
||||
# Load K
|
||||
T.copy(k[bb, left:right, bhg, 0:DK], k_shared)
|
||||
|
||||
T.barrier_arrive(k_is_ready)
|
||||
|
||||
elif tx < 128 + 64:
|
||||
T.barrier_wait(a_is_ready, 0)
|
||||
|
||||
# Save A (unmasked)
|
||||
if right <= seq_end_idx:
|
||||
T.copy(a64_shared, a[bb, left:right, bh, 0:block_S])
|
||||
|
||||
else:
|
||||
T.barrier_wait(a_is_ready, 0)
|
||||
|
||||
# Save A (masked)
|
||||
if right > seq_end_idx:
|
||||
for j_s, j_t in T.Parallel(block_S, block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
a[bb, left + j_s, bh, j_t] = a64_shared[j_s, j_t]
|
||||
|
||||
if is_varlen:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_kkt_solve_kernel(
|
||||
k: T.Tensor(k_shape, dtype=qkva_dtype),
|
||||
b: T.Tensor(b_shape, dtype=b_dtype),
|
||||
cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
|
||||
chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
|
||||
a: T.Tensor(a_shape, dtype=qkva_dtype),
|
||||
):
|
||||
with T.Kernel(num_chunks * H, threads=256) as (bch,):
|
||||
bc, bh = bch // H, bch % H
|
||||
bhg = bh // (H // Hg)
|
||||
|
||||
batch_idx = T.alloc_var("int32")
|
||||
chunk_idx = T.alloc_var("int32")
|
||||
seq_start_idx = T.alloc_var("int32")
|
||||
seq_end_idx = T.alloc_var("int32")
|
||||
|
||||
bb = 0
|
||||
batch_idx = chunk_indices[bc, 0]
|
||||
chunk_idx = chunk_indices[bc, 1]
|
||||
seq_start_idx = cu_seqlens[batch_idx]
|
||||
seq_end_idx = cu_seqlens[batch_idx + 1]
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
bh,
|
||||
bhg,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
k,
|
||||
b,
|
||||
a,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_kkt_solve_kernel(
|
||||
k: T.Tensor(k_shape, dtype=qkva_dtype),
|
||||
b: T.Tensor(b_shape, dtype=b_dtype),
|
||||
a: T.Tensor(a_shape, dtype=qkva_dtype),
|
||||
num_chunks: T.int32,
|
||||
):
|
||||
with T.Kernel(num_chunks * H, threads=256) as (bch,):
|
||||
bc, bh = bch // H, bch % H
|
||||
bhg = bh // (H // Hg)
|
||||
|
||||
batch_idx = T.alloc_var("int32")
|
||||
chunk_idx = T.alloc_var("int32")
|
||||
seq_start_idx = T.alloc_var("int32")
|
||||
seq_end_idx = T.alloc_var("int32")
|
||||
|
||||
bb = bc % data_batch_size
|
||||
batch_idx = bb
|
||||
chunk_idx = bc // data_batch_size
|
||||
seq_start_idx = 0
|
||||
seq_end_idx = num_tokens
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
bh,
|
||||
bhg,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
k,
|
||||
b,
|
||||
a,
|
||||
)
|
||||
|
||||
return tilelang_kkt_solve_kernel
|
||||
|
||||
|
||||
def kkt_solve(
|
||||
k: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
chunk_size: int = 64,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
batch_size, num_tokens, Hg, K = k.shape
|
||||
_, _, H = b.shape
|
||||
assert K == 128
|
||||
assert chunk_size == 64
|
||||
|
||||
if cu_seqlens is None:
|
||||
num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
|
||||
seqlen_dtype = "int32"
|
||||
is_varlen = False
|
||||
else:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
seqlen_dtype = cu_seqlens.dtype
|
||||
is_varlen = True
|
||||
|
||||
a = torch.empty(
|
||||
(batch_size, num_tokens, H, chunk_size), dtype=k.dtype, device=k.device
|
||||
)
|
||||
|
||||
tilelang_kkt_solve_kernel = tilelang_kkt_solve(
|
||||
H,
|
||||
Hg,
|
||||
K,
|
||||
chunk_size,
|
||||
qkva_dtype=k.dtype,
|
||||
b_dtype=b.dtype,
|
||||
seqlen_dtype=seqlen_dtype,
|
||||
accum_dtype="float32",
|
||||
is_varlen=is_varlen,
|
||||
)
|
||||
if is_varlen:
|
||||
tilelang_kkt_solve_kernel(k, b, cu_seqlens, chunk_indices, a)
|
||||
else:
|
||||
tilelang_kkt_solve_kernel(k, b, a, num_chunks)
|
||||
|
||||
return a
|
||||
558
flash_qla/ops/gated_delta_rule/chunk/hopper/prepare_h.py
Normal file
558
flash_qla/ops/gated_delta_rule/chunk/hopper/prepare_h.py
Normal file
@@ -0,0 +1,558 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from flash_qla.utils import prepare_chunk_offsets
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
|
||||
},
|
||||
)
|
||||
def tilelang_prepare_h(
|
||||
H,
|
||||
Hg,
|
||||
DK,
|
||||
DV,
|
||||
chunk_size,
|
||||
accum_dtype,
|
||||
qkva_dtype,
|
||||
g_dtype,
|
||||
b_dtype,
|
||||
h0_dtype,
|
||||
ht_dtype,
|
||||
h_dtype,
|
||||
seqlen_dtype,
|
||||
use_initial_state,
|
||||
store_final_state,
|
||||
store_h,
|
||||
is_varlen,
|
||||
is_cp,
|
||||
num_stages=2,
|
||||
):
|
||||
batch_size = T.dynamic("batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
num_chunks = T.dynamic("num_chunks")
|
||||
block_S = chunk_size
|
||||
|
||||
if is_varlen:
|
||||
k_shape = (1, num_tokens, Hg, DK)
|
||||
v_shape = (1, num_tokens, H, DV)
|
||||
a_shape = (1, num_tokens, H, chunk_size)
|
||||
g_shape = (1, num_tokens, H)
|
||||
b_shape = (1, num_tokens, H)
|
||||
h_shape = (1, num_chunks, H, DK, DV)
|
||||
else:
|
||||
k_shape = (batch_size, num_tokens, Hg, DK)
|
||||
v_shape = (batch_size, num_tokens, H, DV)
|
||||
a_shape = (batch_size, num_tokens, H, chunk_size)
|
||||
g_shape = (batch_size, num_tokens, H)
|
||||
b_shape = (batch_size, num_tokens, H)
|
||||
h_shape = (batch_size, num_chunks, H, DK, DV)
|
||||
h0_shape = (batch_size, H, DK, DV)
|
||||
ht_shape = (batch_size, H, DK, DV)
|
||||
m_shape = (batch_size, H, DK, DK)
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_prepare_h_kernel(
|
||||
k: T.Tensor(k_shape, dtype=qkva_dtype),
|
||||
v: T.Tensor(v_shape, dtype=qkva_dtype),
|
||||
a: T.Tensor(a_shape, dtype=qkva_dtype),
|
||||
g: T.Tensor(g_shape, dtype=g_dtype),
|
||||
b: T.Tensor(b_shape, dtype=b_dtype),
|
||||
h0: T.Tensor(h0_shape, dtype=h0_dtype),
|
||||
cu_seqlens: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
|
||||
chunk_offsets: T.Tensor([batch_size + 1], dtype=seqlen_dtype),
|
||||
num_warmup_chunks: T.Tensor([batch_size, H], dtype=seqlen_dtype),
|
||||
h: T.Tensor(h_shape, dtype=h_dtype),
|
||||
ht: T.Tensor(ht_shape, dtype=ht_dtype),
|
||||
mt: T.Tensor(m_shape, dtype=ht_dtype),
|
||||
):
|
||||
with T.Kernel(batch_size * H, threads=512) as (bbh,):
|
||||
bb, bh = bbh // H, bbh % H
|
||||
bhg = bh // (H // Hg)
|
||||
|
||||
batch_idx = T.alloc_var("int32")
|
||||
seq_start_idx = T.alloc_var("int32")
|
||||
seq_end_idx = T.alloc_var("int32")
|
||||
_seq_split_idx = T.alloc_var("int32")
|
||||
chunk_start_idx = T.alloc_var("int32")
|
||||
_chunk_split_idx = T.alloc_var("int32")
|
||||
|
||||
batch_idx = 0 if is_varlen else bb
|
||||
seq_start_idx = cu_seqlens[bb] if is_varlen else 0
|
||||
seq_end_idx = cu_seqlens[bb + 1] if is_varlen else num_tokens
|
||||
chunk_start_idx = chunk_offsets[bb] if is_varlen else 0
|
||||
|
||||
num_iters = T.alloc_var("int32")
|
||||
num_iters = (
|
||||
num_warmup_chunks[bb, bh]
|
||||
if is_cp
|
||||
else T.ceildiv(seq_end_idx - seq_start_idx, block_S)
|
||||
)
|
||||
|
||||
calc_mt = T.alloc_var("bool")
|
||||
calc_mt = is_cp and num_iters >= T.ceildiv(
|
||||
seq_end_idx - seq_start_idx, block_S
|
||||
)
|
||||
seq_start_idx = (
|
||||
seq_end_idx - num_iters * block_S if is_cp else seq_start_idx
|
||||
)
|
||||
|
||||
k_shared = T.alloc_shared((num_stages, block_S, DK), dtype=qkva_dtype)
|
||||
v_shared = T.alloc_shared((num_stages, block_S, DV), dtype=qkva_dtype)
|
||||
a_shared = T.alloc_shared((num_stages, block_S, block_S), dtype=qkva_dtype)
|
||||
g_shared = T.alloc_shared(
|
||||
(num_stages, block_S), dtype=accum_dtype, scope="shared"
|
||||
)
|
||||
b_shared = T.alloc_shared(
|
||||
(num_stages, block_S), dtype=accum_dtype, scope="shared"
|
||||
)
|
||||
h_shared = T.alloc_shared((DK, DV), dtype=qkva_dtype)
|
||||
x_shared = T.alloc_shared((block_S, DK), dtype=qkva_dtype)
|
||||
y_shared = T.alloc_shared((block_S, DV), dtype=qkva_dtype)
|
||||
m_shared_L = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
|
||||
m_shared_R = T.alloc_shared((DK, DK // 2), dtype=qkva_dtype)
|
||||
z_shared_L = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
|
||||
z_shared_R = T.alloc_shared((block_S, DK // 2), dtype=qkva_dtype)
|
||||
g_rev_exp_shared = T.alloc_shared(
|
||||
(block_S), dtype=accum_dtype, scope="shared"
|
||||
)
|
||||
|
||||
h_fragment = T.alloc_fragment((DK, DV), dtype=accum_dtype)
|
||||
x_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
|
||||
y_fragment = T.alloc_fragment((block_S, DV), dtype=accum_dtype)
|
||||
m_fragment_L = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
|
||||
m_fragment_R = T.alloc_fragment((DK, DK // 2), dtype=accum_dtype)
|
||||
z_fragment_L = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
|
||||
z_fragment_R = T.alloc_fragment((block_S, DK // 2), dtype=accum_dtype)
|
||||
g_last_local_S = T.alloc_local((1), dtype=accum_dtype)
|
||||
g_last_local_X = T.alloc_local((1), dtype=accum_dtype)
|
||||
g_last_local_Y = T.alloc_local((1), dtype=accum_dtype)
|
||||
g_prod_X = T.alloc_fragment((1), dtype=accum_dtype)
|
||||
g_prod_Y = T.alloc_fragment((1), dtype=accum_dtype)
|
||||
|
||||
data_is_ready = T.alloc_barrier(arrive_count=[96] * num_stages)
|
||||
data_is_free = T.alloc_barrier(arrive_count=[384] * num_stages)
|
||||
|
||||
bar_0 = T.alloc_barrier(arrive_count=416)
|
||||
bar_1 = T.alloc_barrier(arrive_count=256)
|
||||
bar_2 = T.alloc_barrier(arrive_count=384)
|
||||
bar_3 = T.alloc_barrier(arrive_count=128)
|
||||
|
||||
T.use_swizzle(10)
|
||||
|
||||
tx = T.get_thread_binding()
|
||||
|
||||
PRODUCER_NREG = 24
|
||||
CONSUMER_S_NREG = 168
|
||||
CONSUMER_X_NREG = 160
|
||||
CONSUMER_Y_NREG = 160
|
||||
|
||||
if tx < 128:
|
||||
T.set_max_nreg(CONSUMER_S_NREG, 1)
|
||||
|
||||
# Initialize S
|
||||
if use_initial_state:
|
||||
T.copy(h0[bb, bh, 0:DK, 0:DV], h_fragment)
|
||||
else:
|
||||
T.clear(h_fragment)
|
||||
|
||||
# Main Loop
|
||||
for i_s in T.serial(num_iters):
|
||||
# [STAGE = i_s % num_stages]
|
||||
T.barrier_wait(
|
||||
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
|
||||
)
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
# [STAGE = i_s % num_stages] 0
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
# S4[1] S
|
||||
T.copy(h_fragment, h_shared)
|
||||
T.barrier_arrive(bar_1)
|
||||
|
||||
# [STAGE = i_s % num_stages] 1
|
||||
T.barrier_wait(bar_1, i_s % 2)
|
||||
# S = g_last * S
|
||||
g_last_local_S[0] = T.exp2(
|
||||
g_shared[i_s % num_stages, block_S - 1] * 1.442695
|
||||
)
|
||||
for j_k, j_v in T.Parallel(DK, DV):
|
||||
h_fragment[j_k, j_v] *= g_last_local_S[0]
|
||||
T.barrier_arrive(bar_2)
|
||||
|
||||
# [STAGE = i_s % num_stages] 2
|
||||
T.barrier_wait(bar_2, i_s % 2)
|
||||
# S += X^T @ Y
|
||||
T.gemm_v1(
|
||||
x_shared,
|
||||
y_shared,
|
||||
h_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
T.barrier_arrive(bar_3)
|
||||
|
||||
T.barrier_arrive(data_is_free[i_s % num_stages])
|
||||
|
||||
# Store final S
|
||||
if store_final_state:
|
||||
T.copy(h_fragment, ht[bb, bh, 0:DK, 0:DV])
|
||||
|
||||
elif tx < 256:
|
||||
T.set_max_nreg(CONSUMER_X_NREG, 1)
|
||||
|
||||
if calc_mt:
|
||||
for j_k, j_v in T.Parallel(DK, DK // 2):
|
||||
if j_k == j_v + DK // 2:
|
||||
m_fragment_R[j_k, j_v] = 1
|
||||
else:
|
||||
m_fragment_R[j_k, j_v] = 0
|
||||
g_prod_X[0] = 0
|
||||
|
||||
# Main Loop
|
||||
for i_s in T.serial(num_iters):
|
||||
# [STAGE = i_s % num_stages]
|
||||
T.barrier_wait(
|
||||
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
|
||||
)
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
# [STAGE = i_s % num_stages] 0
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
# X = A^T @ K
|
||||
T.gemm_v1(
|
||||
a_shared[i_s % num_stages, :, :],
|
||||
k_shared[i_s % num_stages, :, :],
|
||||
x_fragment,
|
||||
transpose_A=True,
|
||||
clear_accum=True,
|
||||
)
|
||||
|
||||
# [STAGE = i_s % num_stages] 1
|
||||
# X = - b * X
|
||||
for j_s, j_k in T.Parallel(block_S, DK):
|
||||
x_fragment[j_s, j_k] *= -b_shared[i_s % num_stages, j_s]
|
||||
# S2[1] X
|
||||
T.copy(x_fragment, x_shared)
|
||||
T.barrier_arrive(bar_2)
|
||||
|
||||
if calc_mt:
|
||||
# [STAGE = i_s % num_stages] 2
|
||||
g_prod_X[0] += g_shared[i_s % num_stages, block_S - 1]
|
||||
# S4[2] M
|
||||
T.copy(m_fragment_R, m_shared_R)
|
||||
|
||||
# [STAGE = i_s % num_stages] 3
|
||||
T.barrier_wait(bar_3, i_s % 2)
|
||||
# Z = K @ M
|
||||
T.gemm_v1(
|
||||
k_shared[i_s % num_stages, :, :],
|
||||
m_shared_R,
|
||||
z_fragment_R,
|
||||
clear_accum=True,
|
||||
)
|
||||
# S4[2] Z
|
||||
T.copy(z_fragment_R, z_shared_R)
|
||||
# M += X^T @ Z
|
||||
T.gemm_v1(
|
||||
x_shared,
|
||||
z_shared_R,
|
||||
m_fragment_R,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
|
||||
T.barrier_arrive(data_is_free[i_s % num_stages])
|
||||
|
||||
if calc_mt:
|
||||
g_last_local_X[0] = T.exp2(g_prod_X[0] * 1.442695)
|
||||
for j_k, j_v in T.Parallel(DK, DK // 2):
|
||||
m_fragment_R[j_k, j_v] *= g_last_local_X[0]
|
||||
T.copy(m_fragment_R, mt[bb, bh, 0:DK, DK // 2 :])
|
||||
|
||||
elif tx < 384:
|
||||
T.set_max_nreg(CONSUMER_Y_NREG, 1)
|
||||
|
||||
if calc_mt:
|
||||
for j_k, j_v in T.Parallel(DK, DK // 2):
|
||||
if j_k == j_v:
|
||||
m_fragment_L[j_k, j_v] = 1
|
||||
else:
|
||||
m_fragment_L[j_k, j_v] = 0
|
||||
g_prod_Y[0] = 0
|
||||
|
||||
# Main Loop
|
||||
for i_s in T.serial(num_iters):
|
||||
# [STAGE = i_s % num_stages]
|
||||
T.barrier_wait(
|
||||
data_is_ready[i_s % num_stages], (i_s // num_stages + 0) % 2
|
||||
)
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
# [STAGE = i_s % num_stages] 0
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
# Precompute g_last/g
|
||||
g_last_local_Y[0] = g_shared[i_s % num_stages, block_S - 1]
|
||||
for j_s in T.Parallel(block_S):
|
||||
g_rev_exp_shared[j_s] = T.exp2(
|
||||
(g_last_local_Y[0] - g_shared[i_s % num_stages, j_s])
|
||||
* 1.442695
|
||||
)
|
||||
g_last_local_Y[0] = T.exp2(g_last_local_Y[0] * 1.442695)
|
||||
T.barrier_arrive(bar_1)
|
||||
|
||||
# [STAGE = i_s % num_stages] 1
|
||||
T.barrier_wait(bar_1, i_s % 2)
|
||||
# U = K @ S
|
||||
T.gemm_v1(
|
||||
k_shared[i_s % num_stages, :, :],
|
||||
h_shared,
|
||||
y_fragment,
|
||||
clear_accum=True,
|
||||
)
|
||||
# Y = g_last * U - g_last/g * V
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
y_fragment[j_s, j_v] *= g_last_local_Y[0]
|
||||
for j_s, j_v in T.Parallel(block_S, DV):
|
||||
y_fragment[j_s, j_v] -= (
|
||||
v_shared[i_s % num_stages, j_s, j_v] * g_rev_exp_shared[j_s]
|
||||
)
|
||||
# S2[2] Y
|
||||
T.copy(y_fragment, y_shared)
|
||||
T.barrier_arrive(bar_2)
|
||||
|
||||
if calc_mt:
|
||||
# [STAGE = i_s % num_stages] 2
|
||||
g_prod_Y[0] += g_shared[i_s % num_stages, block_S - 1]
|
||||
# S4[2] M
|
||||
T.copy(m_fragment_L, m_shared_L)
|
||||
|
||||
# [STAGE = i_s % num_stages] 3
|
||||
T.barrier_wait(bar_3, i_s % 2)
|
||||
# Z = K @ M
|
||||
T.gemm_v1(
|
||||
k_shared[i_s % num_stages, :, :],
|
||||
m_shared_L,
|
||||
z_fragment_L,
|
||||
clear_accum=True,
|
||||
)
|
||||
# S4[2] Z
|
||||
T.copy(z_fragment_L, z_shared_L)
|
||||
# M += X^T @ Z
|
||||
T.gemm_v1(
|
||||
x_shared,
|
||||
z_shared_L,
|
||||
m_fragment_L,
|
||||
transpose_A=True,
|
||||
clear_accum=False,
|
||||
)
|
||||
|
||||
T.barrier_arrive(data_is_free[i_s % num_stages])
|
||||
|
||||
if calc_mt:
|
||||
g_last_local_Y[0] = T.exp2(g_prod_Y[0] * 1.442695)
|
||||
for j_k, j_v in T.Parallel(DK, DK // 2):
|
||||
m_fragment_L[j_k, j_v] *= g_last_local_Y[0]
|
||||
T.copy(m_fragment_L, mt[bb, bh, 0:DK, : DK // 2])
|
||||
|
||||
else:
|
||||
T.set_max_nreg(PRODUCER_NREG, 0)
|
||||
|
||||
if tx < 384 + 32:
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_wait(
|
||||
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
|
||||
)
|
||||
left = seq_start_idx + i_s * block_S
|
||||
right = left + block_S
|
||||
|
||||
# Load K
|
||||
T.copy(
|
||||
k[batch_idx, left:right, bhg, 0:DK],
|
||||
k_shared[i_s % num_stages, :, :],
|
||||
)
|
||||
|
||||
T.barrier_arrive(data_is_ready[i_s % num_stages])
|
||||
|
||||
elif tx < 384 + 64:
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_wait(
|
||||
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
|
||||
)
|
||||
left = seq_start_idx + i_s * block_S
|
||||
right = left + block_S
|
||||
|
||||
# Load V
|
||||
T.copy(
|
||||
v[batch_idx, left:right, bh, 0:DV],
|
||||
v_shared[i_s % num_stages, :, :],
|
||||
)
|
||||
# Load A TODO: Mask A for the last chunk
|
||||
T.copy(
|
||||
a[batch_idx, left:right, bh, 0:block_S],
|
||||
a_shared[i_s % num_stages, :, :],
|
||||
)
|
||||
|
||||
T.barrier_arrive(data_is_ready[i_s % num_stages])
|
||||
|
||||
elif tx < 384 + 96:
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_wait(
|
||||
data_is_free[i_s % num_stages], (i_s // num_stages + 1) % 2
|
||||
)
|
||||
left = seq_start_idx + i_s * block_S
|
||||
right = left + block_S
|
||||
|
||||
# Load gamma
|
||||
if right <= seq_end_idx:
|
||||
for j_s in T.Parallel(block_S):
|
||||
g_shared[i_s % num_stages, j_s] = g[
|
||||
batch_idx, left + j_s, bh
|
||||
]
|
||||
else:
|
||||
for j_s in T.Parallel(block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
g_shared[i_s % num_stages, j_s] = g[
|
||||
batch_idx, left + j_s, bh
|
||||
]
|
||||
else:
|
||||
g_shared[i_s % num_stages, j_s] = g[
|
||||
batch_idx, seq_end_idx - 1, bh
|
||||
]
|
||||
# Load beta
|
||||
if right <= seq_end_idx:
|
||||
for j_s in T.Parallel(block_S):
|
||||
b_shared[i_s % num_stages, j_s] = b[
|
||||
batch_idx, left + j_s, bh
|
||||
]
|
||||
else:
|
||||
for j_s in T.Parallel(block_S):
|
||||
if left + j_s < seq_end_idx:
|
||||
b_shared[i_s % num_stages, j_s] = b[
|
||||
batch_idx, left + j_s, bh
|
||||
]
|
||||
else:
|
||||
b_shared[i_s % num_stages, j_s] = 0
|
||||
|
||||
T.barrier_arrive(data_is_ready[i_s % num_stages])
|
||||
|
||||
else:
|
||||
for i_s in T.serial(num_iters):
|
||||
T.barrier_arrive(bar_0)
|
||||
|
||||
T.barrier_wait(bar_0, i_s % 2)
|
||||
T.barrier_wait(bar_1, i_s % 2)
|
||||
# Store S
|
||||
if store_h:
|
||||
T.copy(
|
||||
h_shared,
|
||||
h[batch_idx, chunk_start_idx + i_s, bh, 0:DK, 0:DV],
|
||||
)
|
||||
|
||||
return tilelang_prepare_h_kernel
|
||||
|
||||
|
||||
def fused_gdr_h(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
output_final_state: bool = True,
|
||||
output_h: bool = True,
|
||||
chunk_size: int = 64,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
num_warmup_chunks: torch.LongTensor | None = None,
|
||||
):
|
||||
batch_size, num_tokens, Hg, K = k.shape
|
||||
_, _, H, V = v.shape
|
||||
assert K == V == 128
|
||||
assert chunk_size == 64
|
||||
|
||||
if cu_seqlens is None:
|
||||
assert num_warmup_chunks is None
|
||||
real_batch_size = batch_size
|
||||
num_chunks = tilelang.cdiv(num_tokens, chunk_size) if output_h else 0
|
||||
cu_seqlens = torch.empty((batch_size + 1), dtype=torch.int32, device=k.device)
|
||||
chunk_offsets = torch.empty(
|
||||
(batch_size + 1), dtype=torch.int32, device=k.device
|
||||
)
|
||||
is_varlen = False
|
||||
is_cp = False
|
||||
else:
|
||||
real_batch_size = len(cu_seqlens) - 1
|
||||
chunk_offsets, num_chunks = prepare_chunk_offsets(cu_seqlens, chunk_size)
|
||||
chunk_offsets = chunk_offsets.to(cu_seqlens.dtype)
|
||||
num_chunks = num_chunks if output_h else 0
|
||||
is_varlen = True
|
||||
if num_warmup_chunks is None:
|
||||
num_warmup_chunks = torch.empty(
|
||||
(real_batch_size, H), dtype=cu_seqlens.dtype, device=k.device
|
||||
)
|
||||
is_cp = False
|
||||
else:
|
||||
is_cp = True
|
||||
|
||||
use_initial_state = initial_state is not None
|
||||
if initial_state is None:
|
||||
initial_state = torch.empty(
|
||||
(real_batch_size, H, K, V), dtype=torch.float32, device=k.device
|
||||
)
|
||||
h = torch.empty((batch_size, num_chunks, H, K, V), dtype=k.dtype, device=k.device)
|
||||
ht_dtype = k.dtype if is_cp else torch.float32
|
||||
final_state = torch.empty(
|
||||
(real_batch_size, H, K, V), dtype=ht_dtype, device=k.device
|
||||
)
|
||||
final_correction = torch.empty(
|
||||
(real_batch_size, H, K, K), dtype=ht_dtype, device=k.device
|
||||
)
|
||||
|
||||
tilelang_prepare_h_kernel = tilelang_prepare_h(
|
||||
H,
|
||||
Hg,
|
||||
K,
|
||||
V,
|
||||
chunk_size,
|
||||
qkva_dtype=k.dtype,
|
||||
g_dtype=g.dtype,
|
||||
b_dtype=b.dtype,
|
||||
h0_dtype=initial_state.dtype,
|
||||
ht_dtype=final_state.dtype,
|
||||
h_dtype=h.dtype,
|
||||
seqlen_dtype=cu_seqlens.dtype,
|
||||
accum_dtype="float32",
|
||||
use_initial_state=use_initial_state,
|
||||
store_final_state=output_final_state,
|
||||
store_h=output_h,
|
||||
is_varlen=is_varlen,
|
||||
is_cp=is_cp,
|
||||
)
|
||||
tilelang_prepare_h_kernel(
|
||||
k,
|
||||
v,
|
||||
a,
|
||||
g,
|
||||
b,
|
||||
initial_state,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
num_warmup_chunks,
|
||||
h,
|
||||
final_state,
|
||||
final_correction,
|
||||
)
|
||||
|
||||
if not output_final_state:
|
||||
final_state = None
|
||||
final_correction = None
|
||||
if not output_h:
|
||||
h = None
|
||||
|
||||
return h, final_state, final_correction
|
||||
6
flash_qla/ops/gated_delta_rule/legacy/__init__.py
Normal file
6
flash_qla/ops/gated_delta_rule/legacy/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .sm_legacy import chunk_gated_delta_rule_fwd_legacy
|
||||
|
||||
__all__ = ["chunk_gated_delta_rule_fwd_legacy"]
|
||||
348
flash_qla/ops/gated_delta_rule/legacy/csrc/gdn_forward.cu
Normal file
348
flash_qla/ops/gated_delta_rule/legacy/csrc/gdn_forward.cu
Normal file
@@ -0,0 +1,348 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
|
||||
void check_cuda(cudaError_t status, const char* context) {
|
||||
if (status != cudaSuccess) {
|
||||
throw std::runtime_error(std::string(context) + ": " +
|
||||
cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float subgroup_sum_lane0(float value,
|
||||
int width) {
|
||||
constexpr unsigned mask = 0xffffffffU;
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
value += __shfl_down_sync(mask, value, offset, width);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float subgroup_broadcast_lane0(float value,
|
||||
int width) {
|
||||
return __shfl_sync(0xffffffffU, value, 0, width);
|
||||
}
|
||||
|
||||
template <int D, int COLS, int WIDTH>
|
||||
__global__ void gdn_forward_kernel(const float* __restrict__ q,
|
||||
const float* __restrict__ k,
|
||||
const float* __restrict__ v,
|
||||
const float* __restrict__ gate,
|
||||
const float* __restrict__ beta,
|
||||
const float* __restrict__ initial_state,
|
||||
float* __restrict__ output,
|
||||
float* __restrict__ final_state,
|
||||
int batch,
|
||||
int tokens,
|
||||
int q_heads,
|
||||
int v_heads,
|
||||
float scale) {
|
||||
static_assert(D % (COLS * (32 / WIDTH)) == 0);
|
||||
constexpr int subgroups_per_warp = 32 / WIDTH;
|
||||
constexpr int rows_per_lane = (D + WIDTH - 1) / WIDTH;
|
||||
|
||||
const int hv = blockIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
const int subgroup = threadIdx.x / WIDTH;
|
||||
const int lane = threadIdx.x % WIDTH;
|
||||
const int group_base =
|
||||
(blockIdx.z * blockDim.y + threadIdx.y) * subgroups_per_warp + subgroup;
|
||||
const int col_base = group_base * COLS;
|
||||
const int hq = hv / (v_heads / q_heads);
|
||||
|
||||
float state_shard[COLS][rows_per_lane];
|
||||
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const int col = col_base + c;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
const int row = r * WIDTH + lane;
|
||||
float value = 0.0F;
|
||||
if (row < D) {
|
||||
const auto state_index =
|
||||
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
|
||||
value = initial_state == nullptr ? 0.0F : initial_state[state_index];
|
||||
}
|
||||
state_shard[c][r] = value;
|
||||
}
|
||||
}
|
||||
|
||||
for (int t = 0; t < tokens; ++t) {
|
||||
const auto gate_index =
|
||||
((static_cast<int64_t>(b) * tokens + t) * v_heads + hv);
|
||||
float gate_value = 0.0F;
|
||||
float beta_value = 0.0F;
|
||||
if (threadIdx.x == 0) {
|
||||
gate_value = __expf(gate[gate_index]);
|
||||
beta_value = beta[gate_index];
|
||||
}
|
||||
gate_value = __shfl_sync(0xffffffffU, gate_value, 0);
|
||||
beta_value = __shfl_sync(0xffffffffU, beta_value, 0);
|
||||
|
||||
float k_reg[rows_per_lane];
|
||||
float q_reg[rows_per_lane];
|
||||
float kv_partial[COLS];
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
kv_partial[c] = 0.0F;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
const int row = r * WIDTH + lane;
|
||||
float q_value = 0.0F;
|
||||
float k_value = 0.0F;
|
||||
if (row < D) {
|
||||
const auto qk_index =
|
||||
(((static_cast<int64_t>(b) * tokens + t) * q_heads + hq) * D) + row;
|
||||
q_value = q[qk_index];
|
||||
k_value = k[qk_index];
|
||||
}
|
||||
q_reg[r] = q_value;
|
||||
k_reg[r] = k_value;
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
kv_partial[c] += state_shard[c][r] * k_value;
|
||||
}
|
||||
}
|
||||
|
||||
float delta[COLS];
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const float kv_col = subgroup_sum_lane0(kv_partial[c], WIDTH);
|
||||
float delta_value = 0.0F;
|
||||
if (lane == 0) {
|
||||
const auto v_index =
|
||||
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D) +
|
||||
col_base + c;
|
||||
delta_value = (v[v_index] - gate_value * kv_col) * beta_value;
|
||||
}
|
||||
delta[c] = subgroup_broadcast_lane0(delta_value, WIDTH);
|
||||
}
|
||||
|
||||
float attn_partial[COLS];
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
attn_partial[c] = 0.0F;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const float new_state =
|
||||
fmaf(k_reg[r], delta[c], gate_value * state_shard[c][r]);
|
||||
state_shard[c][r] = new_state;
|
||||
attn_partial[c] += new_state * q_reg[r];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
attn_partial[c] = subgroup_sum_lane0(attn_partial[c], WIDTH);
|
||||
}
|
||||
|
||||
if (lane == 0) {
|
||||
const auto out_base =
|
||||
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D);
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
output[out_base + col_base + c] = attn_partial[c] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const int col = col_base + c;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
const int row = r * WIDTH + lane;
|
||||
if (row < D) {
|
||||
const auto state_index =
|
||||
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
|
||||
final_state[state_index] = state_shard[c][r];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int D>
|
||||
void launch_gdn_forward(const float* q,
|
||||
const float* k,
|
||||
const float* v,
|
||||
const float* gate,
|
||||
const float* beta,
|
||||
const float* initial_state,
|
||||
float* output,
|
||||
float* final_state,
|
||||
int batch,
|
||||
int tokens,
|
||||
int q_heads,
|
||||
int v_heads,
|
||||
float scale,
|
||||
cudaStream_t stream) {
|
||||
constexpr int cols = D == 128 ? 4 : 1;
|
||||
constexpr int width = D == 128 ? 16 : 32;
|
||||
constexpr int groups_per_warp = 32 / width;
|
||||
constexpr int column_groups_per_block = 8;
|
||||
const dim3 block(32, column_groups_per_block);
|
||||
const int groups = D / cols;
|
||||
const int z = (groups + column_groups_per_block * groups_per_warp - 1) /
|
||||
(column_groups_per_block * groups_per_warp);
|
||||
const dim3 grid(v_heads, batch, z);
|
||||
gdn_forward_kernel<D, cols, width>
|
||||
<<<grid, block, 0, stream>>>(q,
|
||||
k,
|
||||
v,
|
||||
gate,
|
||||
beta,
|
||||
initial_state,
|
||||
output,
|
||||
final_state,
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
scale);
|
||||
}
|
||||
|
||||
void validate_tensor(const torch::Tensor& tensor,
|
||||
const char* name,
|
||||
int64_t dims) {
|
||||
TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor");
|
||||
TORCH_CHECK(tensor.scalar_type() == torch::kFloat32,
|
||||
name,
|
||||
" must be float32");
|
||||
TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous");
|
||||
TORCH_CHECK(tensor.dim() == dims, name, " has wrong rank");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<torch::Tensor> gdn_forward(torch::Tensor q,
|
||||
torch::Tensor k,
|
||||
torch::Tensor v,
|
||||
torch::Tensor gate,
|
||||
torch::Tensor beta,
|
||||
c10::optional<torch::Tensor> initial_state,
|
||||
double scale) {
|
||||
validate_tensor(q, "q", 4);
|
||||
validate_tensor(k, "k", 4);
|
||||
validate_tensor(v, "v", 4);
|
||||
validate_tensor(gate, "gate", 3);
|
||||
validate_tensor(beta, "beta", 3);
|
||||
|
||||
TORCH_CHECK(q.sizes() == k.sizes(), "q and k must have the same shape");
|
||||
const int batch = static_cast<int>(q.size(0));
|
||||
const int tokens = static_cast<int>(q.size(1));
|
||||
const int q_heads = static_cast<int>(q.size(2));
|
||||
const int dim = static_cast<int>(q.size(3));
|
||||
const int v_heads = static_cast<int>(v.size(2));
|
||||
TORCH_CHECK(v.size(0) == batch && v.size(1) == tokens && v.size(3) == dim,
|
||||
"v must have shape [B, T, Hv, D] matching q/k");
|
||||
TORCH_CHECK(gate.size(0) == batch && gate.size(1) == tokens &&
|
||||
gate.size(2) == v_heads,
|
||||
"gate must have shape [B, T, Hv]");
|
||||
TORCH_CHECK(beta.sizes() == gate.sizes(),
|
||||
"beta must have the same shape as gate");
|
||||
TORCH_CHECK(v_heads % q_heads == 0, "Hv must be divisible by Hq");
|
||||
TORCH_CHECK(dim == 16 || dim == 32 || dim == 64 || dim == 128,
|
||||
"D must be one of 16, 32, 64, or 128");
|
||||
|
||||
const float* initial_ptr = nullptr;
|
||||
if (initial_state.has_value() && initial_state.value().defined()) {
|
||||
const auto& h0 = initial_state.value();
|
||||
validate_tensor(h0, "initial_state", 4);
|
||||
TORCH_CHECK(h0.size(0) == batch && h0.size(1) == v_heads &&
|
||||
h0.size(2) == dim && h0.size(3) == dim,
|
||||
"initial_state must have shape [B, Hv, D, D]");
|
||||
initial_ptr = h0.data_ptr<float>();
|
||||
}
|
||||
|
||||
auto output = torch::empty_like(v);
|
||||
auto final_state = torch::empty({batch, v_heads, dim, dim}, q.options());
|
||||
|
||||
const auto stream = at::cuda::getCurrentCUDAStream(q.device().index()).stream();
|
||||
switch (dim) {
|
||||
case 16:
|
||||
launch_gdn_forward<16>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_gdn_forward<32>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
case 64:
|
||||
launch_gdn_forward<64>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
case 128:
|
||||
launch_gdn_forward<128>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
}
|
||||
check_cuda(cudaGetLastError(), "gdn_forward launch");
|
||||
return {output, final_state};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("gdn_forward", &gdn_forward, "SM70/SM75 legacy GDN forward");
|
||||
}
|
||||
104
flash_qla/ops/gated_delta_rule/legacy/sm_legacy.py
Normal file
104
flash_qla/ops/gated_delta_rule/legacy/sm_legacy.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_EXT = None
|
||||
|
||||
|
||||
def _load_ext():
|
||||
global _EXT
|
||||
if _EXT is not None:
|
||||
return _EXT
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("SM70/SM75 legacy GDN backend requires CUDA")
|
||||
|
||||
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0;7.5")
|
||||
src = Path(__file__).with_name("csrc") / "gdn_forward.cu"
|
||||
_EXT = load(
|
||||
name="flash_qla_legacy_gdn",
|
||||
sources=[str(src)],
|
||||
extra_cuda_cflags=["-O3"],
|
||||
extra_cflags=["-O3"],
|
||||
verbose=bool(int(os.environ.get("FLASH_QLA_LEGACY_VERBOSE_BUILD", "0"))),
|
||||
)
|
||||
return _EXT
|
||||
|
||||
|
||||
def _check_inputs(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor | None,
|
||||
) -> None:
|
||||
tensors = [q, k, v, g, beta]
|
||||
if initial_state is not None:
|
||||
tensors.append(initial_state)
|
||||
|
||||
if any(not tensor.is_cuda for tensor in tensors):
|
||||
raise ValueError("legacy GDN tensors must be CUDA tensors")
|
||||
# if any(tensor.dtype != torch.float32 for tensor in tensors):
|
||||
# raise ValueError("legacy GDN backend currently supports float32 tensors only")
|
||||
if any(not tensor.is_contiguous() for tensor in tensors):
|
||||
raise ValueError("legacy GDN tensors must be contiguous")
|
||||
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
|
||||
raise ValueError("q, k, and v must have shape [B, T, H, D]")
|
||||
if g.ndim != 3 or beta.ndim != 3:
|
||||
raise ValueError("g and beta must have shape [B, T, Hv]")
|
||||
if q.shape != k.shape:
|
||||
raise ValueError("q and k must have the same shape")
|
||||
|
||||
batch, tokens, q_heads, dim = q.shape
|
||||
if v.shape[0] != batch or v.shape[1] != tokens or v.shape[3] != dim:
|
||||
raise ValueError("v must have shape [B, T, Hv, D] matching q/k")
|
||||
if g.shape != beta.shape or g.shape != v.shape[:3]:
|
||||
raise ValueError("g and beta must have shape [B, T, Hv]")
|
||||
if v.shape[2] % q_heads != 0:
|
||||
raise ValueError("Hv must be divisible by Hq")
|
||||
if dim not in (16, 32, 64, 128):
|
||||
raise ValueError("legacy GDN backend supports D in {16, 32, 64, 128}")
|
||||
if initial_state is not None and initial_state.shape != (batch, v.shape[2], dim, dim):
|
||||
raise ValueError("initial_state must have shape [B, Hv, D, D]")
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_legacy(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Run the experimental SM70/SM75 forward-only GDN backend.
|
||||
|
||||
This legacy backend is intentionally explicit. It does not replace the
|
||||
Hopper/SM90 TileLang path and currently supports only contiguous float32
|
||||
tensors for inference-oriented forward execution.
|
||||
|
||||
Shapes:
|
||||
q, k: [B, T, Hq, D]
|
||||
v: [B, T, Hv, D]
|
||||
g, beta: [B, T, Hv]
|
||||
initial_state: optional [B, Hv, D, D]
|
||||
|
||||
Returns:
|
||||
output: [B, T, Hv, D]
|
||||
final_state: [B, Hv, D, D]
|
||||
"""
|
||||
|
||||
_check_inputs(q, k, v, g, beta, initial_state)
|
||||
if scale is None:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
|
||||
ext = _load_ext()
|
||||
return ext.gdn_forward(q, k, v, g, beta, initial_state, float(scale))
|
||||
11
flash_qla/ops/utils/__init__.py
Normal file
11
flash_qla/ops/utils/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .group_reduce import group_reduce_vector
|
||||
|
||||
|
||||
__all__ = [
|
||||
"chunk_local_cumsum",
|
||||
"group_reduce_vector",
|
||||
]
|
||||
165
flash_qla/ops/utils/cumsum.py
Normal file
165
flash_qla/ops/utils/cumsum.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from flash_qla.utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
# out_idx=[-1],
|
||||
)
|
||||
def tilelang_chunk_local_cumsum(
|
||||
H,
|
||||
chunk_size,
|
||||
accum_dtype,
|
||||
g_dtype,
|
||||
seqlen_dtype,
|
||||
is_varlen,
|
||||
reverse,
|
||||
):
|
||||
data_batch_size = T.dynamic("data_batch_size")
|
||||
real_batch_size = T.dynamic("real_batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
num_chunks = T.dynamic("num_chunks")
|
||||
block_S = chunk_size
|
||||
|
||||
g_shape = (data_batch_size, num_tokens, H)
|
||||
|
||||
@T.macro
|
||||
def kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
g_raw,
|
||||
g_cumsum,
|
||||
):
|
||||
left = seq_start_idx + chunk_idx * block_S
|
||||
right = left + block_S
|
||||
|
||||
g_fragment = T.alloc_fragment((H, block_S), dtype=accum_dtype)
|
||||
gT_fragment = T.alloc_fragment((block_S, H), dtype=g_dtype)
|
||||
gT_shared = T.alloc_shared((block_S, H + 1), dtype=g_dtype)
|
||||
|
||||
if right <= seq_end_idx:
|
||||
T.copy(g_raw[bb, left:right, 0:H], gT_fragment)
|
||||
else:
|
||||
for j, i in T.Parallel(block_S, H):
|
||||
if left + j < seq_end_idx:
|
||||
gT_fragment[j, i] = g_raw[bb, left + j, i]
|
||||
else:
|
||||
gT_fragment[j, i] = 0
|
||||
T.copy(gT_fragment, gT_shared[:, :H])
|
||||
|
||||
for i, j in T.Parallel(H, block_S):
|
||||
g_fragment[i, j] = gT_shared[j, i]
|
||||
|
||||
T.cumsum(g_fragment, dim=1, reverse=reverse)
|
||||
|
||||
for i, j in T.Parallel(H, block_S):
|
||||
gT_shared[j, i] = g_fragment[i, j]
|
||||
|
||||
T.copy(gT_shared[:, :H], gT_fragment)
|
||||
if right <= seq_end_idx:
|
||||
T.copy(gT_fragment, g_cumsum[bb, left:right, 0:H])
|
||||
else:
|
||||
for j, i in T.Parallel(block_S, H):
|
||||
if left + j < seq_end_idx:
|
||||
g_cumsum[bb, left + j, i] = gT_fragment[j, i]
|
||||
|
||||
if is_varlen:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_chunk_local_cumsum_kernel(
|
||||
g_raw: T.Tensor(g_shape, dtype=g_dtype),
|
||||
cu_seqlens: T.Tensor([real_batch_size + 1], dtype=seqlen_dtype),
|
||||
chunk_indices: T.Tensor([num_chunks, 2], dtype=seqlen_dtype),
|
||||
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
|
||||
):
|
||||
with T.Kernel(num_chunks, threads=128) as (bc,):
|
||||
bb = 0
|
||||
batch_idx = chunk_indices[bc, 0]
|
||||
chunk_idx = chunk_indices[bc, 1]
|
||||
seq_start_idx = cu_seqlens[batch_idx]
|
||||
seq_end_idx = cu_seqlens[batch_idx + 1]
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
g_raw,
|
||||
g_cumsum,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_chunk_local_cumsum_kernel(
|
||||
g_raw: T.Tensor(g_shape, dtype=g_dtype),
|
||||
g_cumsum: T.Tensor(g_shape, dtype=g_dtype),
|
||||
num_chunks: T.int32,
|
||||
):
|
||||
with T.Kernel(num_chunks, threads=128) as (bc,):
|
||||
bb = bc % data_batch_size
|
||||
batch_idx = bb
|
||||
chunk_idx = bc // data_batch_size
|
||||
seq_start_idx = 0
|
||||
seq_end_idx = num_tokens
|
||||
|
||||
kernel_body(
|
||||
bb,
|
||||
bc,
|
||||
batch_idx,
|
||||
chunk_idx,
|
||||
seq_start_idx,
|
||||
seq_end_idx,
|
||||
g_raw,
|
||||
g_cumsum,
|
||||
)
|
||||
|
||||
return tilelang_chunk_local_cumsum_kernel
|
||||
|
||||
|
||||
def chunk_local_cumsum(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int = 64,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
reverse: bool = False,
|
||||
):
|
||||
batch_size, num_tokens, H = g.shape
|
||||
assert g.stride(-1) == 1
|
||||
|
||||
if cu_seqlens is None:
|
||||
num_chunks = batch_size * tilelang.cdiv(num_tokens, chunk_size)
|
||||
seqlen_dtype = "int32"
|
||||
is_varlen = False
|
||||
else:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
seqlen_dtype = cu_seqlens.dtype
|
||||
is_varlen = True
|
||||
|
||||
g_cumsum = torch.empty_like(g)
|
||||
|
||||
tilelang_chunk_local_cumsum_kernel = tilelang_chunk_local_cumsum(
|
||||
H,
|
||||
chunk_size,
|
||||
g_dtype=g.dtype,
|
||||
seqlen_dtype=seqlen_dtype,
|
||||
accum_dtype="float32",
|
||||
is_varlen=is_varlen,
|
||||
reverse=reverse,
|
||||
)
|
||||
if is_varlen:
|
||||
tilelang_chunk_local_cumsum_kernel(g, cu_seqlens, chunk_indices, g_cumsum)
|
||||
else:
|
||||
tilelang_chunk_local_cumsum_kernel(g, g_cumsum, num_chunks)
|
||||
|
||||
return g_cumsum
|
||||
79
flash_qla/ops/utils/group_reduce.py
Normal file
79
flash_qla/ops/utils/group_reduce.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
# out_idx=[-1],
|
||||
)
|
||||
def tilelang_group_reduce_vector(
|
||||
H,
|
||||
Hg,
|
||||
DK,
|
||||
accum_dtype,
|
||||
qkva_dtype,
|
||||
block_size: int = 16,
|
||||
):
|
||||
batch_size = T.dynamic("batch_size")
|
||||
num_tokens = T.dynamic("num_tokens")
|
||||
|
||||
group_size = H // Hg
|
||||
|
||||
buffer_shape = (batch_size, num_tokens, H, DK)
|
||||
dqk_shape = (batch_size, num_tokens, Hg, DK)
|
||||
|
||||
@T.prim_func
|
||||
def tilelang_group_reduce_vector_kernel(
|
||||
buffer: T.Tensor(buffer_shape, dtype=qkva_dtype),
|
||||
result: T.Tensor(dqk_shape, dtype=qkva_dtype),
|
||||
):
|
||||
with T.Kernel(
|
||||
tilelang.cdiv(num_tokens, block_size), Hg, batch_size, threads=128
|
||||
) as (bt, bhg, bb):
|
||||
buffer_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
|
||||
result_fragment = T.alloc_fragment((block_size, DK), dtype=accum_dtype)
|
||||
|
||||
T.clear(result_fragment)
|
||||
for i in T.serial(group_size):
|
||||
T.copy(
|
||||
buffer[
|
||||
bb,
|
||||
bt * block_size : (bt + 1) * block_size,
|
||||
bhg * group_size + i,
|
||||
0:DK,
|
||||
],
|
||||
buffer_fragment,
|
||||
)
|
||||
for j, k in T.Parallel(block_size, DK):
|
||||
result_fragment[j, k] += buffer_fragment[j, k]
|
||||
T.copy(
|
||||
result_fragment,
|
||||
result[bb, bt * block_size : (bt + 1) * block_size, bhg, 0:DK],
|
||||
)
|
||||
|
||||
return tilelang_group_reduce_vector_kernel
|
||||
|
||||
|
||||
def group_reduce_vector(
|
||||
buffer: torch.Tensor,
|
||||
Hg: int,
|
||||
):
|
||||
batch_size, num_tokens, H, K = buffer.shape
|
||||
|
||||
result = torch.empty(
|
||||
(batch_size, num_tokens, Hg, K), dtype=buffer.dtype, device=buffer.device
|
||||
)
|
||||
|
||||
tilelang_group_reduce_vector_kernel = tilelang_group_reduce_vector(
|
||||
H,
|
||||
Hg,
|
||||
K,
|
||||
qkva_dtype=buffer.dtype,
|
||||
accum_dtype="float32",
|
||||
)
|
||||
tilelang_group_reduce_vector_kernel(buffer, result)
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user