first commit
This commit is contained in:
6
flash_qla/ops/gated_delta_rule/legacy/__init__.py
Normal file
6
flash_qla/ops/gated_delta_rule/legacy/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from .sm_legacy import chunk_gated_delta_rule_fwd_legacy
|
||||
|
||||
__all__ = ["chunk_gated_delta_rule_fwd_legacy"]
|
||||
348
flash_qla/ops/gated_delta_rule/legacy/csrc/gdn_forward.cu
Normal file
348
flash_qla/ops/gated_delta_rule/legacy/csrc/gdn_forward.cu
Normal file
@@ -0,0 +1,348 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
|
||||
void check_cuda(cudaError_t status, const char* context) {
|
||||
if (status != cudaSuccess) {
|
||||
throw std::runtime_error(std::string(context) + ": " +
|
||||
cudaGetErrorString(status));
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float subgroup_sum_lane0(float value,
|
||||
int width) {
|
||||
constexpr unsigned mask = 0xffffffffU;
|
||||
for (int offset = width / 2; offset > 0; offset >>= 1) {
|
||||
value += __shfl_down_sync(mask, value, offset, width);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float subgroup_broadcast_lane0(float value,
|
||||
int width) {
|
||||
return __shfl_sync(0xffffffffU, value, 0, width);
|
||||
}
|
||||
|
||||
template <int D, int COLS, int WIDTH>
|
||||
__global__ void gdn_forward_kernel(const float* __restrict__ q,
|
||||
const float* __restrict__ k,
|
||||
const float* __restrict__ v,
|
||||
const float* __restrict__ gate,
|
||||
const float* __restrict__ beta,
|
||||
const float* __restrict__ initial_state,
|
||||
float* __restrict__ output,
|
||||
float* __restrict__ final_state,
|
||||
int batch,
|
||||
int tokens,
|
||||
int q_heads,
|
||||
int v_heads,
|
||||
float scale) {
|
||||
static_assert(D % (COLS * (32 / WIDTH)) == 0);
|
||||
constexpr int subgroups_per_warp = 32 / WIDTH;
|
||||
constexpr int rows_per_lane = (D + WIDTH - 1) / WIDTH;
|
||||
|
||||
const int hv = blockIdx.x;
|
||||
const int b = blockIdx.y;
|
||||
const int subgroup = threadIdx.x / WIDTH;
|
||||
const int lane = threadIdx.x % WIDTH;
|
||||
const int group_base =
|
||||
(blockIdx.z * blockDim.y + threadIdx.y) * subgroups_per_warp + subgroup;
|
||||
const int col_base = group_base * COLS;
|
||||
const int hq = hv / (v_heads / q_heads);
|
||||
|
||||
float state_shard[COLS][rows_per_lane];
|
||||
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const int col = col_base + c;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
const int row = r * WIDTH + lane;
|
||||
float value = 0.0F;
|
||||
if (row < D) {
|
||||
const auto state_index =
|
||||
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
|
||||
value = initial_state == nullptr ? 0.0F : initial_state[state_index];
|
||||
}
|
||||
state_shard[c][r] = value;
|
||||
}
|
||||
}
|
||||
|
||||
for (int t = 0; t < tokens; ++t) {
|
||||
const auto gate_index =
|
||||
((static_cast<int64_t>(b) * tokens + t) * v_heads + hv);
|
||||
float gate_value = 0.0F;
|
||||
float beta_value = 0.0F;
|
||||
if (threadIdx.x == 0) {
|
||||
gate_value = __expf(gate[gate_index]);
|
||||
beta_value = beta[gate_index];
|
||||
}
|
||||
gate_value = __shfl_sync(0xffffffffU, gate_value, 0);
|
||||
beta_value = __shfl_sync(0xffffffffU, beta_value, 0);
|
||||
|
||||
float k_reg[rows_per_lane];
|
||||
float q_reg[rows_per_lane];
|
||||
float kv_partial[COLS];
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
kv_partial[c] = 0.0F;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
const int row = r * WIDTH + lane;
|
||||
float q_value = 0.0F;
|
||||
float k_value = 0.0F;
|
||||
if (row < D) {
|
||||
const auto qk_index =
|
||||
(((static_cast<int64_t>(b) * tokens + t) * q_heads + hq) * D) + row;
|
||||
q_value = q[qk_index];
|
||||
k_value = k[qk_index];
|
||||
}
|
||||
q_reg[r] = q_value;
|
||||
k_reg[r] = k_value;
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
kv_partial[c] += state_shard[c][r] * k_value;
|
||||
}
|
||||
}
|
||||
|
||||
float delta[COLS];
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const float kv_col = subgroup_sum_lane0(kv_partial[c], WIDTH);
|
||||
float delta_value = 0.0F;
|
||||
if (lane == 0) {
|
||||
const auto v_index =
|
||||
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D) +
|
||||
col_base + c;
|
||||
delta_value = (v[v_index] - gate_value * kv_col) * beta_value;
|
||||
}
|
||||
delta[c] = subgroup_broadcast_lane0(delta_value, WIDTH);
|
||||
}
|
||||
|
||||
float attn_partial[COLS];
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
attn_partial[c] = 0.0F;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const float new_state =
|
||||
fmaf(k_reg[r], delta[c], gate_value * state_shard[c][r]);
|
||||
state_shard[c][r] = new_state;
|
||||
attn_partial[c] += new_state * q_reg[r];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
attn_partial[c] = subgroup_sum_lane0(attn_partial[c], WIDTH);
|
||||
}
|
||||
|
||||
if (lane == 0) {
|
||||
const auto out_base =
|
||||
(((static_cast<int64_t>(b) * tokens + t) * v_heads + hv) * D);
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
output[out_base + col_base + c] = attn_partial[c] * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int c = 0; c < COLS; ++c) {
|
||||
const int col = col_base + c;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; ++r) {
|
||||
const int row = r * WIDTH + lane;
|
||||
if (row < D) {
|
||||
const auto state_index =
|
||||
(((static_cast<int64_t>(b) * v_heads + hv) * D + col) * D) + row;
|
||||
final_state[state_index] = state_shard[c][r];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int D>
|
||||
void launch_gdn_forward(const float* q,
|
||||
const float* k,
|
||||
const float* v,
|
||||
const float* gate,
|
||||
const float* beta,
|
||||
const float* initial_state,
|
||||
float* output,
|
||||
float* final_state,
|
||||
int batch,
|
||||
int tokens,
|
||||
int q_heads,
|
||||
int v_heads,
|
||||
float scale,
|
||||
cudaStream_t stream) {
|
||||
constexpr int cols = D == 128 ? 4 : 1;
|
||||
constexpr int width = D == 128 ? 16 : 32;
|
||||
constexpr int groups_per_warp = 32 / width;
|
||||
constexpr int column_groups_per_block = 8;
|
||||
const dim3 block(32, column_groups_per_block);
|
||||
const int groups = D / cols;
|
||||
const int z = (groups + column_groups_per_block * groups_per_warp - 1) /
|
||||
(column_groups_per_block * groups_per_warp);
|
||||
const dim3 grid(v_heads, batch, z);
|
||||
gdn_forward_kernel<D, cols, width>
|
||||
<<<grid, block, 0, stream>>>(q,
|
||||
k,
|
||||
v,
|
||||
gate,
|
||||
beta,
|
||||
initial_state,
|
||||
output,
|
||||
final_state,
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
scale);
|
||||
}
|
||||
|
||||
void validate_tensor(const torch::Tensor& tensor,
|
||||
const char* name,
|
||||
int64_t dims) {
|
||||
TORCH_CHECK(tensor.is_cuda(), name, " must be a CUDA tensor");
|
||||
TORCH_CHECK(tensor.scalar_type() == torch::kFloat32,
|
||||
name,
|
||||
" must be float32");
|
||||
TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous");
|
||||
TORCH_CHECK(tensor.dim() == dims, name, " has wrong rank");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<torch::Tensor> gdn_forward(torch::Tensor q,
|
||||
torch::Tensor k,
|
||||
torch::Tensor v,
|
||||
torch::Tensor gate,
|
||||
torch::Tensor beta,
|
||||
c10::optional<torch::Tensor> initial_state,
|
||||
double scale) {
|
||||
validate_tensor(q, "q", 4);
|
||||
validate_tensor(k, "k", 4);
|
||||
validate_tensor(v, "v", 4);
|
||||
validate_tensor(gate, "gate", 3);
|
||||
validate_tensor(beta, "beta", 3);
|
||||
|
||||
TORCH_CHECK(q.sizes() == k.sizes(), "q and k must have the same shape");
|
||||
const int batch = static_cast<int>(q.size(0));
|
||||
const int tokens = static_cast<int>(q.size(1));
|
||||
const int q_heads = static_cast<int>(q.size(2));
|
||||
const int dim = static_cast<int>(q.size(3));
|
||||
const int v_heads = static_cast<int>(v.size(2));
|
||||
TORCH_CHECK(v.size(0) == batch && v.size(1) == tokens && v.size(3) == dim,
|
||||
"v must have shape [B, T, Hv, D] matching q/k");
|
||||
TORCH_CHECK(gate.size(0) == batch && gate.size(1) == tokens &&
|
||||
gate.size(2) == v_heads,
|
||||
"gate must have shape [B, T, Hv]");
|
||||
TORCH_CHECK(beta.sizes() == gate.sizes(),
|
||||
"beta must have the same shape as gate");
|
||||
TORCH_CHECK(v_heads % q_heads == 0, "Hv must be divisible by Hq");
|
||||
TORCH_CHECK(dim == 16 || dim == 32 || dim == 64 || dim == 128,
|
||||
"D must be one of 16, 32, 64, or 128");
|
||||
|
||||
const float* initial_ptr = nullptr;
|
||||
if (initial_state.has_value() && initial_state.value().defined()) {
|
||||
const auto& h0 = initial_state.value();
|
||||
validate_tensor(h0, "initial_state", 4);
|
||||
TORCH_CHECK(h0.size(0) == batch && h0.size(1) == v_heads &&
|
||||
h0.size(2) == dim && h0.size(3) == dim,
|
||||
"initial_state must have shape [B, Hv, D, D]");
|
||||
initial_ptr = h0.data_ptr<float>();
|
||||
}
|
||||
|
||||
auto output = torch::empty_like(v);
|
||||
auto final_state = torch::empty({batch, v_heads, dim, dim}, q.options());
|
||||
|
||||
const auto stream = at::cuda::getCurrentCUDAStream(q.device().index()).stream();
|
||||
switch (dim) {
|
||||
case 16:
|
||||
launch_gdn_forward<16>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_gdn_forward<32>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
case 64:
|
||||
launch_gdn_forward<64>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
case 128:
|
||||
launch_gdn_forward<128>(q.data_ptr<float>(),
|
||||
k.data_ptr<float>(),
|
||||
v.data_ptr<float>(),
|
||||
gate.data_ptr<float>(),
|
||||
beta.data_ptr<float>(),
|
||||
initial_ptr,
|
||||
output.data_ptr<float>(),
|
||||
final_state.data_ptr<float>(),
|
||||
batch,
|
||||
tokens,
|
||||
q_heads,
|
||||
v_heads,
|
||||
static_cast<float>(scale),
|
||||
stream);
|
||||
break;
|
||||
}
|
||||
check_cuda(cudaGetLastError(), "gdn_forward launch");
|
||||
return {output, final_state};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("gdn_forward", &gdn_forward, "SM70/SM75 legacy GDN forward");
|
||||
}
|
||||
104
flash_qla/ops/gated_delta_rule/legacy/sm_legacy.py
Normal file
104
flash_qla/ops/gated_delta_rule/legacy/sm_legacy.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_EXT = None
|
||||
|
||||
|
||||
def _load_ext():
|
||||
global _EXT
|
||||
if _EXT is not None:
|
||||
return _EXT
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("SM70/SM75 legacy GDN backend requires CUDA")
|
||||
|
||||
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0;7.5")
|
||||
src = Path(__file__).with_name("csrc") / "gdn_forward.cu"
|
||||
_EXT = load(
|
||||
name="flash_qla_legacy_gdn",
|
||||
sources=[str(src)],
|
||||
extra_cuda_cflags=["-O3"],
|
||||
extra_cflags=["-O3"],
|
||||
verbose=bool(int(os.environ.get("FLASH_QLA_LEGACY_VERBOSE_BUILD", "0"))),
|
||||
)
|
||||
return _EXT
|
||||
|
||||
|
||||
def _check_inputs(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
initial_state: torch.Tensor | None,
|
||||
) -> None:
|
||||
tensors = [q, k, v, g, beta]
|
||||
if initial_state is not None:
|
||||
tensors.append(initial_state)
|
||||
|
||||
if any(not tensor.is_cuda for tensor in tensors):
|
||||
raise ValueError("legacy GDN tensors must be CUDA tensors")
|
||||
# if any(tensor.dtype != torch.float32 for tensor in tensors):
|
||||
# raise ValueError("legacy GDN backend currently supports float32 tensors only")
|
||||
if any(not tensor.is_contiguous() for tensor in tensors):
|
||||
raise ValueError("legacy GDN tensors must be contiguous")
|
||||
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
|
||||
raise ValueError("q, k, and v must have shape [B, T, H, D]")
|
||||
if g.ndim != 3 or beta.ndim != 3:
|
||||
raise ValueError("g and beta must have shape [B, T, Hv]")
|
||||
if q.shape != k.shape:
|
||||
raise ValueError("q and k must have the same shape")
|
||||
|
||||
batch, tokens, q_heads, dim = q.shape
|
||||
if v.shape[0] != batch or v.shape[1] != tokens or v.shape[3] != dim:
|
||||
raise ValueError("v must have shape [B, T, Hv, D] matching q/k")
|
||||
if g.shape != beta.shape or g.shape != v.shape[:3]:
|
||||
raise ValueError("g and beta must have shape [B, T, Hv]")
|
||||
if v.shape[2] % q_heads != 0:
|
||||
raise ValueError("Hv must be divisible by Hq")
|
||||
if dim not in (16, 32, 64, 128):
|
||||
raise ValueError("legacy GDN backend supports D in {16, 32, 64, 128}")
|
||||
if initial_state is not None and initial_state.shape != (batch, v.shape[2], dim, dim):
|
||||
raise ValueError("initial_state must have shape [B, Hv, D, D]")
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_legacy(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Run the experimental SM70/SM75 forward-only GDN backend.
|
||||
|
||||
This legacy backend is intentionally explicit. It does not replace the
|
||||
Hopper/SM90 TileLang path and currently supports only contiguous float32
|
||||
tensors for inference-oriented forward execution.
|
||||
|
||||
Shapes:
|
||||
q, k: [B, T, Hq, D]
|
||||
v: [B, T, Hv, D]
|
||||
g, beta: [B, T, Hv]
|
||||
initial_state: optional [B, Hv, D, D]
|
||||
|
||||
Returns:
|
||||
output: [B, T, Hv, D]
|
||||
final_state: [B, Hv, D, D]
|
||||
"""
|
||||
|
||||
_check_inputs(q, k, v, g, beta, initial_state)
|
||||
if scale is None:
|
||||
scale = q.shape[-1] ** -0.5
|
||||
|
||||
ext = _load_ext()
|
||||
return ext.gdn_forward(q, k, v, g, beta, initial_state, float(scale))
|
||||
Reference in New Issue
Block a user