first commit

This commit is contained in:
2026-06-14 23:49:03 +08:00
commit 3f95e2939d
35 changed files with 6764 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from .sm_legacy import chunk_gated_delta_rule_fwd_legacy
__all__ = ["chunk_gated_delta_rule_fwd_legacy"]

View File

@@ -0,0 +1,348 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>
namespace {
void check_cuda(cudaError_t status, const char* context) {
if (status != cudaSuccess) {
throw std::runtime_error(std::string(context) + ": " +
cudaGetErrorString(status));
}
}
__device__ __forceinline__ float subgroup_sum_lane0(float value,
int width) {
constexpr unsigned mask = 0xffffffffU;
for (int offset = width / 2; offset > 0; offset >>= 1) {
value += __shfl_down_sync(mask, value, offset, width);
}
return value;
}
__device__ __forceinline__ float subgroup_broadcast_lane0(float value,
int width) {
return __shfl_sync(0xffffffffU, value, 0, width);
}
template <int D, int COLS, int WIDTH>
__global__ void gdn_forward_kernel(const float* __restrict__ q,
const float* __restrict__ k,
const float* __restrict__ v,
const float* __restrict__ gate,
const float* __restrict__ beta,
const float* __restrict__ initial_state,
float* __restrict__ output,
float* __restrict__ final_state,
int batch,
int tokens,
int q_heads,
int v_heads,
float scale) {
static_assert(D % (COLS * (32 / WIDTH)) == 0);
constexpr int subgroups_per_warp = 32 / WIDTH;
constexpr int rows_per_lane = (D + WIDTH - 1) / WIDTH;
const int hv = blockIdx.x;
const int b = blockIdx.y;
const int subgroup = threadIdx.x / WIDTH;
const int lane = threadIdx.x % WIDTH;
const int group_base =
(blockIdx.z * blockDim.y + threadIdx.y) * subgroups_per_warp + subgroup;
const int col_base = group_base * COLS;
const int hq = hv / (v_heads / q_heads);
float state_shard[COLS][rows_per_lane];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const int col = col_base + c;
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
float value = 0.0F;
if (row < D) {
const auto state_index =
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
value = initial_state == nullptr ? 0.0F : initial_state[state_index];
}
state_shard[c][r] = value;
}
}
for (int t = 0; t < tokens; ++t) {
const auto gate_index =
((static_cast<int64_t>(b) * tokens + t) * v_heads + hv);
float gate_value = 0.0F;
float beta_value = 0.0F;
if (threadIdx.x == 0) {
gate_value = __expf(gate[gate_index]);
beta_value = beta[gate_index];
}
gate_value = __shfl_sync(0xffffffffU, gate_value, 0);
beta_value = __shfl_sync(0xffffffffU, beta_value, 0);
float k_reg[rows_per_lane];
float q_reg[rows_per_lane];
float kv_partial[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
kv_partial[c] = 0.0F;
}
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
float q_value = 0.0F;
float k_value = 0.0F;
if (row < D) {
const auto qk_index =
(((static_cast<int64_t>(b) * tokens + t) * q_heads + hq) * D) + row;
q_value = q[qk_index];
k_value = k[qk_index];
}
q_reg[r] = q_value;
k_reg[r] = k_value;
#pragma unroll
for (int c = 0; c < COLS; ++c) {
kv_partial[c] += state_shard[c][r] * k_value;
}
}
float delta[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const float kv_col = subgroup_sum_lane0(kv_partial[c], WIDTH);
float delta_value = 0.0F;
if (lane == 0) {
const auto v_index =
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D) +
col_base + c;
delta_value = (v[v_index] - gate_value * kv_col) * beta_value;
}
delta[c] = subgroup_broadcast_lane0(delta_value, WIDTH);
}
float attn_partial[COLS];
#pragma unroll
for (int c = 0; c < COLS; ++c) {
attn_partial[c] = 0.0F;
}
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const float new_state =
fmaf(k_reg[r], delta[c], gate_value * state_shard[c][r]);
state_shard[c][r] = new_state;
attn_partial[c] += new_state * q_reg[r];
}
}
#pragma unroll
for (int c = 0; c < COLS; ++c) {
attn_partial[c] = subgroup_sum_lane0(attn_partial[c], WIDTH);
}
if (lane == 0) {
const auto out_base =
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D);
#pragma unroll
for (int c = 0; c < COLS; ++c) {
output[out_base + col_base + c] = attn_partial[c] * scale;
}
}
}
#pragma unroll
for (int c = 0; c < COLS; ++c) {
const int col = col_base + c;
#pragma unroll
for (int r = 0; r < rows_per_lane; ++r) {
const int row = r * WIDTH + lane;
if (row < D) {
const auto state_index =
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
final_state[state_index] = state_shard[c][r];
}
}
}
}
template <int D>
void launch_gdn_forward(const float* q,
const float* k,
const float* v,
const float* gate,
const float* beta,
const float* initial_state,
float* output,
float* final_state,
int batch,
int tokens,
int q_heads,
int v_heads,
float scale,
cudaStream_t stream) {
constexpr int cols = D == 128 ? 4 : 1;
constexpr int width = D == 128 ? 16 : 32;
constexpr int groups_per_warp = 32 / width;
constexpr int column_groups_per_block = 8;
const dim3 block(32, column_groups_per_block);
const int groups = D / cols;
const int z = (groups + column_groups_per_block * groups_per_warp - 1) /
(column_groups_per_block * groups_per_warp);
const dim3 grid(v_heads, batch, z);
gdn_forward_kernel<D, cols, width>
<<<grid, block, 0, stream>>>(q,
k,
v,
gate,
beta,
initial_state,
output,
final_state,
batch,
tokens,
q_heads,
v_heads,
scale);
}
void validate_tensor(const torch::Tensor& tensor,
const char* name,
int64_t dims) {
TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor");
TORCH_CHECK(tensor.scalar_type() == torch::kFloat32,
name,
" must be float32");
TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous");
TORCH_CHECK(tensor.dim() == dims, name, " has wrong rank");
}
} // namespace
std::vector<torch::Tensor> gdn_forward(torch::Tensor q,
torch::Tensor k,
torch::Tensor v,
torch::Tensor gate,
torch::Tensor beta,
c10::optional<torch::Tensor> initial_state,
double scale) {
validate_tensor(q, "q", 4);
validate_tensor(k, "k", 4);
validate_tensor(v, "v", 4);
validate_tensor(gate, "gate", 3);
validate_tensor(beta, "beta", 3);
TORCH_CHECK(q.sizes() == k.sizes(), "q and k must have the same shape");
const int batch = static_cast<int>(q.size(0));
const int tokens = static_cast<int>(q.size(1));
const int q_heads = static_cast<int>(q.size(2));
const int dim = static_cast<int>(q.size(3));
const int v_heads = static_cast<int>(v.size(2));
TORCH_CHECK(v.size(0) == batch && v.size(1) == tokens && v.size(3) == dim,
"v must have shape [B, T, Hv, D] matching q/k");
TORCH_CHECK(gate.size(0) == batch && gate.size(1) == tokens &&
gate.size(2) == v_heads,
"gate must have shape [B, T, Hv]");
TORCH_CHECK(beta.sizes() == gate.sizes(),
"beta must have the same shape as gate");
TORCH_CHECK(v_heads % q_heads == 0, "Hv must be divisible by Hq");
TORCH_CHECK(dim == 16 || dim == 32 || dim == 64 || dim == 128,
"D must be one of 16, 32, 64, or 128");
const float* initial_ptr = nullptr;
if (initial_state.has_value() && initial_state.value().defined()) {
const auto& h0 = initial_state.value();
validate_tensor(h0, "initial_state", 4);
TORCH_CHECK(h0.size(0) == batch && h0.size(1) == v_heads &&
h0.size(2) == dim && h0.size(3) == dim,
"initial_state must have shape [B, Hv, D, D]");
initial_ptr = h0.data_ptr<float>();
}
auto output = torch::empty_like(v);
auto final_state = torch::empty({batch, v_heads, dim, dim}, q.options());
const auto stream = at::cuda::getCurrentCUDAStream(q.device().index()).stream();
switch (dim) {
case 16:
launch_gdn_forward<16>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
case 32:
launch_gdn_forward<32>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
case 64:
launch_gdn_forward<64>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
case 128:
launch_gdn_forward<128>(q.data_ptr<float>(),
k.data_ptr<float>(),
v.data_ptr<float>(),
gate.data_ptr<float>(),
beta.data_ptr<float>(),
initial_ptr,
output.data_ptr<float>(),
final_state.data_ptr<float>(),
batch,
tokens,
q_heads,
v_heads,
static_cast<float>(scale),
stream);
break;
}
check_cuda(cudaGetLastError(), "gdn_forward launch");
return {output, final_state};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gdn_forward", &gdn_forward, "SM70/SM75 legacy GDN forward");
}

View File

@@ -0,0 +1,104 @@
# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]
from __future__ import annotations
import os
from pathlib import Path
import torch
from torch.utils.cpp_extension import load
_EXT = None
def _load_ext():
global _EXT
if _EXT is not None:
return _EXT
if not torch.cuda.is_available():
raise RuntimeError("SM70/SM75 legacy GDN backend requires CUDA")
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0;7.5")
src = Path(__file__).with_name("csrc") / "gdn_forward.cu"
_EXT = load(
name="flash_qla_legacy_gdn",
sources=[str(src)],
extra_cuda_cflags=["-O3"],
extra_cflags=["-O3"],
verbose=bool(int(os.environ.get("FLASH_QLA_LEGACY_VERBOSE_BUILD", "0"))),
)
return _EXT
def _check_inputs(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
initial_state: torch.Tensor | None,
) -> None:
tensors = [q, k, v, g, beta]
if initial_state is not None:
tensors.append(initial_state)
if any(not tensor.is_cuda for tensor in tensors):
raise ValueError("legacy GDN tensors must be CUDA tensors")
# if any(tensor.dtype != torch.float32 for tensor in tensors):
# raise ValueError("legacy GDN backend currently supports float32 tensors only")
if any(not tensor.is_contiguous() for tensor in tensors):
raise ValueError("legacy GDN tensors must be contiguous")
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError("q, k, and v must have shape [B, T, H, D]")
if g.ndim != 3 or beta.ndim != 3:
raise ValueError("g and beta must have shape [B, T, Hv]")
if q.shape != k.shape:
raise ValueError("q and k must have the same shape")
batch, tokens, q_heads, dim = q.shape
if v.shape[0] != batch or v.shape[1] != tokens or v.shape[3] != dim:
raise ValueError("v must have shape [B, T, Hv, D] matching q/k")
if g.shape != beta.shape or g.shape != v.shape[:3]:
raise ValueError("g and beta must have shape [B, T, Hv]")
if v.shape[2] % q_heads != 0:
raise ValueError("Hv must be divisible by Hq")
if dim not in (16, 32, 64, 128):
raise ValueError("legacy GDN backend supports D in {16, 32, 64, 128}")
if initial_state is not None and initial_state.shape != (batch, v.shape[2], dim, dim):
raise ValueError("initial_state must have shape [B, Hv, D, D]")
def chunk_gated_delta_rule_fwd_legacy(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run the experimental SM70/SM75 forward-only GDN backend.
This legacy backend is intentionally explicit. It does not replace the
Hopper/SM90 TileLang path and currently supports only contiguous float32
tensors for inference-oriented forward execution.
Shapes:
q, k: [B, T, Hq, D]
v: [B, T, Hv, D]
g, beta: [B, T, Hv]
initial_state: optional [B, Hv, D, D]
Returns:
output: [B, T, Hv, D]
final_state: [B, Hv, D, D]
"""
_check_inputs(q, k, v, g, beta, initial_state)
if scale is None:
scale = q.shape[-1] ** -0.5
ext = _load_ext()
return ext.gdn_forward(q, k, v, g, beta, initial_state, float(scale))