701 lines
22 KiB
Python
701 lines
22 KiB
Python
# Copyright (c) 2026 The Qwen team, Alibaba Group.
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
import torch
|
|
|
|
from flash_qla.utils import (
|
|
pad_and_reshape,
|
|
pack,
|
|
unpack,
|
|
fill_last_chunk_of_g,
|
|
prepare_chunk_offsets,
|
|
)
|
|
|
|
|
|
def torch_cumsum(
|
|
x: torch.Tensor, # [B, T, H]
|
|
cu_seqlens: torch.Tensor = None,
|
|
chunk_size: int = 64,
|
|
reverse: bool = False,
|
|
):
|
|
if cu_seqlens is not None:
|
|
x = unpack(x, cu_seqlens)
|
|
|
|
batch_size, num_tokens, num_heads = x.shape
|
|
|
|
x = pad_and_reshape(x, dim=1, chunk_size=chunk_size)
|
|
|
|
if reverse:
|
|
x = torch.flip(x, dims=(2,))
|
|
x = x.cumsum(dim=2)
|
|
x = torch.flip(x, dims=(2,))
|
|
else:
|
|
x = x.cumsum(dim=2)
|
|
x = x.reshape(batch_size, -1, num_heads)
|
|
x = x[:, :num_tokens]
|
|
|
|
if cu_seqlens is not None:
|
|
x = pack(x, cu_seqlens)
|
|
return x
|
|
|
|
|
|
def torch_kkt_fwd(
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
beta: torch.Tensor, # [B, T, Hv]
|
|
cu_seqlens: torch.Tensor = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
if cu_seqlens is not None:
|
|
k = unpack(k, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
beta = unpack(beta, cu_seqlens)
|
|
|
|
batch_size, num_tokens, num_k_heads, head_dim = k.shape
|
|
num_v_heads = g.shape[-1]
|
|
|
|
if num_k_heads != num_v_heads:
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, H, K]
|
|
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, H]
|
|
beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, H]
|
|
|
|
mask = torch.triu(
|
|
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
|
|
)
|
|
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
|
|
decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0)
|
|
# decay_mask = torch.where(mask[None, None, :, :, None], decay_mask, 0.0)
|
|
attn = torch.einsum(
|
|
"bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k
|
|
) * decay_mask.swapaxes(-2, -1) # [B, N, C, H, D]
|
|
attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens]
|
|
|
|
if cu_seqlens is not None:
|
|
attn = pack(attn, cu_seqlens)
|
|
return attn
|
|
|
|
|
|
def torch_solve(
|
|
x: torch.Tensor, # [B, T, H, D]
|
|
cu_seqlens: torch.Tensor = None,
|
|
):
|
|
if cu_seqlens is not None:
|
|
x = unpack(x, cu_seqlens)
|
|
|
|
batch_size, num_tokens, num_heads, chunk_size = x.shape
|
|
|
|
x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size).swapaxes(
|
|
2, 3
|
|
) # [B, N, H, C, D]
|
|
|
|
for i in range(1, chunk_size):
|
|
row = x[..., i, :i].clone()
|
|
sub = x[..., :i, :i].clone()
|
|
x[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
|
x += torch.eye(chunk_size, dtype=x.dtype, device=x.device)
|
|
x = x.swapaxes(2, 3).reshape((batch_size, -1, num_heads, chunk_size))[
|
|
:, :num_tokens
|
|
]
|
|
|
|
if cu_seqlens is not None:
|
|
x = pack(x, cu_seqlens)
|
|
return x
|
|
|
|
|
|
def torch_w_u_fwd(
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
v: torch.Tensor, # [B, T, Hv, V]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
beta: torch.Tensor, # [B, T, Hv]
|
|
A: torch.Tensor, # [B, T, Hv, D]
|
|
cu_seqlens: torch.Tensor = None,
|
|
):
|
|
if cu_seqlens is not None:
|
|
k = unpack(k, cu_seqlens)
|
|
v = unpack(v, cu_seqlens)
|
|
A = unpack(A, cu_seqlens)
|
|
beta = unpack(beta, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
|
|
batch_size, num_tokens, _, chunk_size = A.shape
|
|
_, _, num_k_heads, head_dim_k = k.shape
|
|
_, _, num_v_heads, head_dim_v = v.shape
|
|
|
|
if num_k_heads != num_v_heads:
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
k_beta = pad_and_reshape(
|
|
k * beta.unsqueeze(-1) * g.exp().unsqueeze(-1), dim=1, chunk_size=chunk_size
|
|
) # [B, N, C, Hv, K]
|
|
v_beta = pad_and_reshape(
|
|
v * beta.unsqueeze(-1), dim=1, chunk_size=chunk_size
|
|
) # [B, N, C, Hv, V]
|
|
A = pad_and_reshape(A, dim=1)
|
|
|
|
w = torch.einsum("bnchd, bndhk -> bnchk", A, k_beta).reshape(
|
|
(batch_size, -1, num_v_heads, head_dim_k)
|
|
)[:, :num_tokens]
|
|
u = torch.einsum("bnchd, bndhk -> bnchk", A, v_beta).reshape(
|
|
(batch_size, -1, num_v_heads, head_dim_v)
|
|
)[:, :num_tokens]
|
|
|
|
if cu_seqlens is not None:
|
|
w = pack(w, cu_seqlens)
|
|
u = pack(u, cu_seqlens)
|
|
return w, u
|
|
|
|
|
|
def torch_chunk_gdr_fwd(
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
w: torch.Tensor, # [B, T, Hv, K]
|
|
u: torch.Tensor, # [B, T, Hv, V]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
initial_state: torch.Tensor = None, # [B, Hv, K, V]
|
|
cu_seqlens: torch.Tensor = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
if cu_seqlens is not None:
|
|
k = unpack(k, cu_seqlens)
|
|
w = unpack(w, cu_seqlens)
|
|
u = unpack(u, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
|
|
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
|
|
_, _, num_v_heads, head_dim_v = u.shape
|
|
|
|
if num_k_heads != num_v_heads:
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
u = pad_and_reshape(u, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
|
|
|
|
if initial_state is None:
|
|
last_state = torch.zeros(
|
|
(batch_size, num_v_heads, head_dim_k, head_dim_v),
|
|
dtype=g.dtype,
|
|
device=g.device,
|
|
)
|
|
else:
|
|
last_state = initial_state.to(g.dtype, copy=True)
|
|
|
|
h, vn = [], []
|
|
for i in range(k.shape[1]):
|
|
h.append(last_state)
|
|
v_new = u[:, i] - torch.einsum("bchk, bhkv -> bchv", w[:, i], last_state)
|
|
vn.append(v_new)
|
|
last_state = last_state * g[:, i, -1, :, None, None].exp()
|
|
last_state = last_state + torch.einsum(
|
|
"bchk, bchv -> bhkv",
|
|
k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(),
|
|
v_new,
|
|
)
|
|
h = torch.stack(h, dim=1).contiguous()
|
|
vn = (
|
|
torch.stack(vn, dim=1)
|
|
.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
|
|
.contiguous()
|
|
)
|
|
|
|
if cu_seqlens is not None:
|
|
vn = pack(vn, cu_seqlens)
|
|
h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
|
|
|
|
return h, vn, last_state
|
|
|
|
|
|
def torch_chunk_o_fwd(
|
|
q: torch.Tensor, # [B, T, Hk, K]
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
v: torch.Tensor, # [B, T, Hv, K]
|
|
h: torch.Tensor, # [B, N, Hv, K, V]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
cu_seqlens: torch.Tensor = None,
|
|
scale: float = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
if cu_seqlens is not None:
|
|
q = unpack(q, cu_seqlens)
|
|
k = unpack(k, cu_seqlens)
|
|
v = unpack(v, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
|
|
|
|
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
|
|
_, _, num_v_heads, head_dim_v = v.shape
|
|
|
|
if num_k_heads != num_v_heads:
|
|
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
scale = scale or head_dim_k ** (-0.5)
|
|
|
|
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
|
|
q = q * scale
|
|
|
|
mask = torch.triu(
|
|
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
|
|
diagonal=1,
|
|
)
|
|
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
|
|
decay_mask = decay_mask.masked_fill(
|
|
mask[None, None, :, :, None], 0.0
|
|
) # [B, N, C, D, Hv]
|
|
|
|
attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask
|
|
attn_inter = torch.einsum("bnchk, bnhkv -> bnchv", q * g.exp().unsqueeze(-1), h)
|
|
o = attn_inter + torch.einsum("bncdh, bndhv -> bnchv", attn, v)
|
|
|
|
o = o.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
|
|
if cu_seqlens is not None:
|
|
o = pack(o, cu_seqlens)
|
|
return o
|
|
|
|
|
|
def torch_chunk_dv_bwd(
|
|
q: torch.Tensor, # [B, T, Hk, K]
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
do: torch.Tensor, # [B, T, Hv, V]
|
|
cu_seqlens: torch.Tensor = None,
|
|
scale: float = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
if cu_seqlens is not None:
|
|
q = unpack(q, cu_seqlens)
|
|
k = unpack(k, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
do = unpack(do, cu_seqlens)
|
|
|
|
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
|
|
_, _, num_v_heads, head_dim_v = do.shape
|
|
|
|
if num_k_heads != num_v_heads:
|
|
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
scale = scale or head_dim_k ** (-0.5)
|
|
|
|
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
|
|
q = q * scale
|
|
|
|
mask = torch.triu(
|
|
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
|
|
diagonal=1,
|
|
)
|
|
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
|
|
decay_mask = decay_mask.masked_fill(
|
|
mask[None, None, :, :, None], 0.0
|
|
) # [B, N, C, D, Hv]
|
|
|
|
attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask
|
|
dv = torch.einsum("bncdh, bnchv -> bndhv", attn, do)
|
|
|
|
dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
|
|
if cu_seqlens is not None:
|
|
dv = pack(dv, cu_seqlens)
|
|
return dv
|
|
|
|
|
|
def torch_chunk_gdr_bwd(
|
|
q: torch.Tensor, # [B, T, Hk, K]
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
w: torch.Tensor, # [B, T, Hv, K]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
do: torch.Tensor, # [B, T, Hv, V]
|
|
dv: torch.Tensor, # [B, T, Hv, V]
|
|
h0: torch.Tensor = None, # [B, Hv, K, V]
|
|
dht: torch.Tensor = None, # [B, Hv, K, V]
|
|
cu_seqlens: torch.Tensor = None,
|
|
scale: float = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
if cu_seqlens is not None:
|
|
q = unpack(q, cu_seqlens)
|
|
k = unpack(k, cu_seqlens)
|
|
w = unpack(w, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
do = unpack(do, cu_seqlens)
|
|
dv = unpack(dv, cu_seqlens)
|
|
|
|
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
|
|
_, _, num_v_heads, head_dim_v = do.shape
|
|
|
|
if num_k_heads != num_v_heads:
|
|
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
scale = scale or head_dim_k ** (-0.5)
|
|
|
|
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
|
|
|
|
q = q * scale
|
|
|
|
if dht is None:
|
|
dstate = torch.zeros(
|
|
(batch_size, num_v_heads, head_dim_k, head_dim_v),
|
|
dtype=g.dtype,
|
|
device=g.device,
|
|
)
|
|
else:
|
|
dstate = dht.to(g.dtype, copy=True)
|
|
dstate_inter = torch.einsum("bnchk, bnchv -> bnhkv", q * g.exp().unsqueeze(-1), do)
|
|
|
|
dh = []
|
|
for i in reversed(range(k.shape[1])):
|
|
dh.insert(0, dstate)
|
|
dv[:, i] += torch.einsum(
|
|
"bchk, bhkv -> bchv",
|
|
k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(),
|
|
dstate,
|
|
)
|
|
dstate = dstate * g[:, i, -1, :, None, None].exp()
|
|
dstate = (
|
|
dstate
|
|
+ dstate_inter[:, i]
|
|
- torch.einsum("bchk, bchv -> bhkv", w[:, i], dv[:, i])
|
|
)
|
|
dh = torch.stack(dh, dim=1).contiguous()
|
|
|
|
dh0 = None if h0 is None else dstate
|
|
dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
|
|
if cu_seqlens is not None:
|
|
dv = pack(dv, cu_seqlens)
|
|
dh = pack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
|
|
return dh, dh0, dv
|
|
|
|
|
|
def torch_chunk_dqkwg_bwd(
|
|
q: torch.Tensor, # [B, T, Hk, K]
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
v: torch.Tensor, # [B, T, Hv, V]
|
|
w: torch.Tensor, # [B, T, Hv, K]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
h: torch.Tensor, # [B, N, Hv, K, V]
|
|
dv: torch.Tensor, # [B, T, Hv, V]
|
|
do: torch.Tensor, # [B, T, Hv, V]
|
|
dh: torch.Tensor, # [B, N, Hv, K, V]
|
|
cu_seqlens: torch.Tensor = None,
|
|
scale: float = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
if cu_seqlens is not None:
|
|
q = unpack(q, cu_seqlens)
|
|
k = unpack(k, cu_seqlens)
|
|
v = unpack(v, cu_seqlens)
|
|
w = unpack(w, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
do = unpack(do, cu_seqlens)
|
|
dv = unpack(dv, cu_seqlens)
|
|
h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
|
|
dh = unpack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
|
|
|
|
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
|
|
_, _, num_v_heads, head_dim_v = do.shape
|
|
|
|
if num_k_heads != num_v_heads:
|
|
q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
scale = scale or head_dim_k ** (-0.5)
|
|
|
|
q = pad_and_reshape(q, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
w = pad_and_reshape(w, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
do = pad_and_reshape(do, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
|
|
|
|
mask = torch.triu(
|
|
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
|
|
diagonal=1,
|
|
)
|
|
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
|
|
decay_mask = decay_mask.masked_fill(
|
|
mask[None, None, :, :, None], 0.0
|
|
) # [B, N, C, D, Hv]
|
|
|
|
dg_last = (h * dh).sum(dim=-1).sum(dim=-1) # [B, N, Hv]
|
|
ds = torch.einsum("bnchv, bndhv -> bncdh", do, v)
|
|
dq = torch.einsum("bnchv, bnhkv -> bnchk", do, h)
|
|
dk = torch.einsum("bnchv, bnhkv -> bnchk", v, dh)
|
|
dw = -torch.einsum("bnchv, bnhkv -> bnchk", dv, h)
|
|
|
|
g_last = g[:, :, -1]
|
|
dg_last *= g_last.exp()
|
|
dq = dq * g.unsqueeze(-1).exp() * scale
|
|
dg = (q * dq).sum(dim=-1) # [B, N, C, Hv]
|
|
dk = dk * (g_last.unsqueeze(-2) - g).unsqueeze(-1).exp()
|
|
dg -= (k * dk).sum(dim=-1)
|
|
dg_last += (k * dk).sum(dim=-1).sum(dim=-2)
|
|
ds *= decay_mask * scale
|
|
ds2 = ds * torch.einsum("bnchk, bndhk -> bncdh", q, k)
|
|
dg += ds2.sum(dim=-2)
|
|
dg -= ds2.sum(dim=-3)
|
|
dq += torch.einsum("bncdh, bndhk -> bnchk", ds, k)
|
|
dk += torch.einsum("bncdh, bnchk -> bndhk", ds, q)
|
|
dg[:, :, -1] += dg_last
|
|
|
|
dg = fill_last_chunk_of_g(
|
|
dg, num_tokens, cu_seqlens, chunk_size=chunk_size, reverse=True
|
|
)
|
|
dq = dq.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
|
|
dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
|
|
dw = dw.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
|
|
dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
|
|
if cu_seqlens is not None:
|
|
dq = pack(dq, cu_seqlens)
|
|
dk = pack(dk, cu_seqlens)
|
|
dw = pack(dw, cu_seqlens)
|
|
dg = pack(dg, cu_seqlens)
|
|
return dq, dk, dw, dg
|
|
|
|
|
|
def torch_chunk_wy_bwd(
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
v: torch.Tensor, # [B, T, Hv, V]
|
|
beta: torch.Tensor, # [B, T, Hv]
|
|
A: torch.Tensor, # [B, T, Hv, D]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
dw: torch.Tensor, # [B, T, Hv, K]
|
|
du: torch.Tensor, # [B, T, Hv, V]
|
|
dk1: torch.Tensor, # [B, T, Hv, K]
|
|
dg1: torch.Tensor, # [B, T, Hv]
|
|
cu_seqlens: torch.Tensor = None,
|
|
):
|
|
if cu_seqlens is not None:
|
|
k = unpack(k, cu_seqlens)
|
|
v = unpack(v, cu_seqlens)
|
|
beta = unpack(beta, cu_seqlens)
|
|
A = unpack(A, cu_seqlens)
|
|
g = unpack(g, cu_seqlens)
|
|
dw = unpack(dw, cu_seqlens)
|
|
du = unpack(du, cu_seqlens)
|
|
dk1 = unpack(dk1, cu_seqlens)
|
|
dg1 = unpack(dg1, cu_seqlens)
|
|
|
|
batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
|
|
_, _, num_v_heads, head_dim_v = v.shape
|
|
chunk_size = A.shape[-1]
|
|
|
|
if num_k_heads != num_v_heads:
|
|
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
|
|
|
|
k = pad_and_reshape(k, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
v = pad_and_reshape(v, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
A = pad_and_reshape(A, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, D]
|
|
g = pad_and_reshape(g, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
du = pad_and_reshape(du, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, V]
|
|
dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size) # [B, N, C, Hv, K]
|
|
dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size) # [B, N, C, Hv]
|
|
|
|
dA = torch.einsum("bnchk, bndhk -> bnchd", dw, k * (beta * g.exp()).unsqueeze(-1))
|
|
dk_beta_g = torch.einsum("bnchd, bnchk -> bndhk", A, dw)
|
|
dk = dk_beta_g * (beta * g.exp()).unsqueeze(-1)
|
|
db = (dk_beta_g * k * g.exp().unsqueeze(-1)).sum(dim=-1)
|
|
dg = (dk_beta_g * k * (g.exp() * beta).unsqueeze(-1)).sum(dim=-1)
|
|
|
|
dA += torch.einsum("bnchv, bndhv -> bnchd", du, v * beta.unsqueeze(-1))
|
|
dv_beta = torch.einsum("bnchd, bnchv -> bndhv", A, du)
|
|
dv = dv_beta * beta.unsqueeze(-1)
|
|
db += (dv_beta * v).sum(dim=-1)
|
|
|
|
mask = torch.triu(
|
|
torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
|
|
)
|
|
decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
|
|
decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0).swapaxes(
|
|
-2, -1
|
|
)
|
|
dA = dA.masked_fill(mask[None, None, :, None, :], 0.0)
|
|
dA = torch.einsum("bndhc, bndhe -> bnche", A, dA)
|
|
dA = torch.einsum("bnchd, bnehd -> bnche", dA, A)
|
|
dA = -dA * decay_mask
|
|
|
|
A = torch.einsum("bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k)
|
|
dk_beta = torch.einsum("bnchd, bndhk -> bnchk", dA, k)
|
|
db += (dk_beta * k).sum(dim=-1)
|
|
dk += torch.einsum("bnchd, bnchk -> bndhk", dA, k * beta.unsqueeze(-1))
|
|
dk += dk_beta * beta.unsqueeze(-1)
|
|
dk += dk1
|
|
|
|
dg += (dA * A).sum(dim=-1) - (dA * A).sum(dim=-3).swapaxes(-1, -2)
|
|
dg += dg1
|
|
|
|
# TODO: NOTE: GVA
|
|
dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
|
|
dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
|
|
db = db.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
|
|
dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
|
|
if cu_seqlens is not None:
|
|
dk = pack(dk, cu_seqlens)
|
|
dv = pack(dv, cu_seqlens)
|
|
db = pack(db, cu_seqlens)
|
|
dg = pack(dg, cu_seqlens)
|
|
return dk, dv, db, dg
|
|
|
|
|
|
def chunk_gated_delta_rule_fwd(
|
|
q: torch.Tensor, # [B, T, Hk, K]
|
|
k: torch.Tensor, # [B, T, Hk, K]
|
|
v: torch.Tensor, # [B, T, Hv, K]
|
|
g: torch.Tensor, # [B, T, Hv]
|
|
beta: torch.Tensor, # [B, T, Hv]
|
|
cu_seqlens: torch.Tensor = None,
|
|
initial_state: torch.Tensor = None,
|
|
scale: float = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
scale = scale or q.shape[-1] ** (-0.5)
|
|
g = torch_cumsum(
|
|
x=g,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_size=chunk_size,
|
|
)
|
|
A = torch_kkt_fwd(
|
|
k=k,
|
|
g=g,
|
|
beta=beta,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_size=chunk_size,
|
|
)
|
|
A = torch_solve(
|
|
x=A,
|
|
cu_seqlens=cu_seqlens,
|
|
)
|
|
w, u = torch_w_u_fwd(
|
|
k=k,
|
|
v=v,
|
|
beta=beta,
|
|
A=A,
|
|
g=g,
|
|
cu_seqlens=cu_seqlens,
|
|
)
|
|
h, vn, final_state = torch_chunk_gdr_fwd(
|
|
k=k,
|
|
w=w,
|
|
u=u,
|
|
g=g,
|
|
initial_state=initial_state,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_size=chunk_size,
|
|
)
|
|
o = torch_chunk_o_fwd(
|
|
q=q,
|
|
k=k,
|
|
v=vn,
|
|
h=h,
|
|
g=g,
|
|
cu_seqlens=cu_seqlens,
|
|
scale=scale,
|
|
chunk_size=chunk_size,
|
|
)
|
|
return g, o, A, h, final_state
|
|
|
|
|
|
def chunk_gated_delta_rule_bwd(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
v: torch.Tensor,
|
|
g: torch.Tensor,
|
|
beta: torch.Tensor,
|
|
A: torch.Tensor,
|
|
scale: float,
|
|
initial_state: torch.Tensor,
|
|
do: torch.Tensor,
|
|
dht: torch.Tensor,
|
|
cu_seqlens: torch.Tensor = None,
|
|
chunk_size: int = 64,
|
|
):
|
|
w, u = torch_w_u_fwd(
|
|
k=k,
|
|
v=v,
|
|
beta=beta,
|
|
A=A,
|
|
g=g,
|
|
cu_seqlens=cu_seqlens,
|
|
)
|
|
h, vn, _ = torch_chunk_gdr_fwd(
|
|
k=k,
|
|
w=w,
|
|
u=u,
|
|
g=g,
|
|
initial_state=initial_state,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_size=chunk_size,
|
|
)
|
|
dv = torch_chunk_dv_bwd(
|
|
q=q,
|
|
k=k,
|
|
g=g,
|
|
do=do,
|
|
scale=scale,
|
|
cu_seqlens=cu_seqlens,
|
|
)
|
|
dh, dh0, dv = torch_chunk_gdr_bwd(
|
|
q=q,
|
|
k=k,
|
|
w=w,
|
|
g=g,
|
|
h0=initial_state,
|
|
dht=dht,
|
|
do=do,
|
|
dv=dv,
|
|
scale=scale,
|
|
cu_seqlens=cu_seqlens,
|
|
)
|
|
dq, dk1, dw, dg1 = torch_chunk_dqkwg_bwd(
|
|
q=q,
|
|
k=k,
|
|
v=vn,
|
|
w=w,
|
|
g=g,
|
|
h=h,
|
|
dv=dv,
|
|
do=do,
|
|
dh=dh,
|
|
scale=scale,
|
|
cu_seqlens=cu_seqlens,
|
|
)
|
|
dk, dv, db, dg = torch_chunk_wy_bwd(
|
|
k=k,
|
|
v=v,
|
|
beta=beta,
|
|
g=g,
|
|
A=A,
|
|
dw=dw,
|
|
du=dv,
|
|
dk1=dk1,
|
|
dg1=dg1,
|
|
cu_seqlens=cu_seqlens,
|
|
)
|
|
Hg, H = k.shape[-2], v.shape[-2]
|
|
if Hg < H:
|
|
B, T, _, K = dq.shape
|
|
dq = torch.sum(dq.reshape(B, T, Hg, -1, K), dim=3)
|
|
dk = torch.sum(dk.reshape(B, T, Hg, -1, K), dim=3)
|
|
dg = torch_cumsum(dg, chunk_size=64, reverse=True, cu_seqlens=cu_seqlens)
|
|
return dq, dk, dv, db, dg, dh0
|