# Copyright (c) 2026 The Qwen team, Alibaba Group.
# Licensed under The MIT License [see LICENSE for details]

import torch

from flash_qla.utils import (
    pad_and_reshape,
    pack,
    unpack,
    fill_last_chunk_of_g,
    prepare_chunk_offsets,
)


def torch_cumsum(
    x: torch.Tensor,  # [B, T, H]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
    reverse: bool = False,
):
    """Cumulative sum of ``x`` along the token axis, restarted in every chunk.

    Args:
        x: input values, annotated as ``[B, T, H]``.
        cu_seqlens: when given, ``x`` is passed through ``unpack`` on entry
            and the result through ``pack`` on exit (presumably a packed
            variable-length layout; see ``flash_qla.utils``).
        chunk_size: number of tokens per chunk.
        reverse: accumulate from the end of each chunk towards its start.

    Returns:
        Tensor with the same layout as ``x``.
    """
    if cu_seqlens is not None:
        x = unpack(x, cu_seqlens)
    batch_size, num_tokens, num_heads = x.shape
    x = pad_and_reshape(x, dim=1, chunk_size=chunk_size)
    if reverse:
        # Suffix sum == flip, prefix sum, flip back (dim 2 is the in-chunk axis).
        x = torch.flip(x, dims=(2,))
        x = x.cumsum(dim=2)
        x = torch.flip(x, dims=(2,))
    else:
        x = x.cumsum(dim=2)
    x = x.reshape(batch_size, -1, num_heads)
    # Drop the rows that were only added to fill the last chunk.
    x = x[:, :num_tokens]
    if cu_seqlens is not None:
        x = pack(x, cu_seqlens)
    return x


def torch_kkt_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Per-chunk ``(beta * k) @ k^T`` scaled by ``exp(g_c - g_d)`` below the diagonal.

    Entries on and above the diagonal of every ``chunk_size x chunk_size``
    tile are zero. Key heads are repeated to match the number of heads of
    ``g`` when the two differ.

    Returns:
        Tensor annotated as ``[B, T, Hv, chunk_size]`` (packed again when
        ``cu_seqlens`` is given).
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        g = unpack(g, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
    batch_size, num_tokens, num_k_heads, _ = k.shape
    num_v_heads = g.shape[-1]
    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, H, K]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, H]
    beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size)  # [B, N, C, H]
    # Upper triangle *including* the diagonal is masked out.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0)
    attn = torch.einsum(
        "bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k
    ) * decay_mask.swapaxes(-2, -1)  # [B, N, C, H, D]
    attn = attn.reshape(batch_size, -1, num_v_heads, chunk_size)[:, :num_tokens]
    if cu_seqlens is not None:
        attn = pack(attn, cu_seqlens)
    return attn


def torch_solve(
    x: torch.Tensor,  # [B, T, H, D]
    cu_seqlens: torch.Tensor = None,
):
    """Row-by-row substitution over each chunk tile of ``x``, then add identity.

    The chunk size is taken from the last dimension of ``x``. The input is
    negated, every row ``i`` is updated from the already-processed rows
    ``< i``, and the identity matrix is added at the end.

    NOTE(review): for a strictly lower-triangular tile this looks like the
    usual forward-substitution computation of ``(I + x)^-1`` -- confirm
    against the kernel this reference is compared with.
    """
    if cu_seqlens is not None:
        x = unpack(x, cu_seqlens)
    batch_size, num_tokens, num_heads, chunk_size = x.shape
    x = -pad_and_reshape(x, dim=1, chunk_size=chunk_size).swapaxes(
        2, 3
    )  # [B, N, H, C, D]
    for i in range(1, chunk_size):
        # Clone so the in-place row write below does not alias its operands.
        row = x[..., i, :i].clone()
        sub = x[..., :i, :i].clone()
        x[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    x += torch.eye(chunk_size, dtype=x.dtype, device=x.device)
    x = x.swapaxes(2, 3).reshape((batch_size, -1, num_heads, chunk_size))[
        :, :num_tokens
    ]
    if cu_seqlens is not None:
        x = pack(x, cu_seqlens)
    return x


def torch_w_u_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    A: torch.Tensor,  # [B, T, Hv, D]
    cu_seqlens: torch.Tensor = None,
):
    """Compute ``w = A @ (k * beta * exp(g))`` and ``u = A @ (v * beta)`` per chunk.

    The chunk size is taken from the last dimension of ``A``.

    Returns:
        ``(w, u)`` annotated as ``[B, T, Hv, K]`` and ``[B, T, Hv, V]``.
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        A = unpack(A, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
        g = unpack(g, cu_seqlens)
    batch_size, num_tokens, _, chunk_size = A.shape
    _, _, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    k_beta = pad_and_reshape(
        k * beta.unsqueeze(-1) * g.exp().unsqueeze(-1), dim=1, chunk_size=chunk_size
    )  # [B, N, C, Hv, K]
    v_beta = pad_and_reshape(
        v * beta.unsqueeze(-1), dim=1, chunk_size=chunk_size
    )  # [B, N, C, Hv, V]
    # A must be chunked with the same size as k_beta / v_beta; pass it
    # explicitly instead of relying on the helper's default.
    A = pad_and_reshape(A, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, D]
    w = torch.einsum("bnchd, bndhk -> bnchk", A, k_beta).reshape(
        (batch_size, -1, num_v_heads, head_dim_k)
    )[:, :num_tokens]
    u = torch.einsum("bnchd, bndhk -> bnchk", A, v_beta).reshape(
        (batch_size, -1, num_v_heads, head_dim_v)
    )[:, :num_tokens]
    if cu_seqlens is not None:
        w = pack(w, cu_seqlens)
        u = pack(u, cu_seqlens)
    return w, u


def torch_chunk_gdr_fwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    w: torch.Tensor,  # [B, T, Hv, K]
    u: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    initial_state: torch.Tensor = None,  # [B, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Sequential chunk-by-chunk state recurrence of the forward pass.

    For every chunk the state *entering* the chunk is recorded, the
    corrected values ``v_new = u - w @ state`` are computed, and the state
    is decayed by ``exp(g)`` of the chunk's last row and updated with
    ``k^T @ v_new`` (keys decayed to the end of the chunk).

    Returns:
        ``(h, v_new, final_state)``: per-chunk entering states, corrected
        values annotated ``[B, T, Hv, V]``, and the state after the last chunk.
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        w = unpack(w, cu_seqlens)
        u = unpack(u, cu_seqlens)
        g = unpack(g, cu_seqlens)
    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = u.shape
    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    w = pad_and_reshape(w, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    u = pad_and_reshape(u, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    # Presumably fixes up g in the padded tail of the last chunk so that the
    # "last row" reads below are meaningful -- see flash_qla.utils.
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
    if initial_state is None:
        last_state = torch.zeros(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            dtype=g.dtype,
            device=g.device,
        )
    else:
        # copy=True: never write into the caller's initial state.
        last_state = initial_state.to(g.dtype, copy=True)
    h, vn = [], []
    for i in range(k.shape[1]):
        h.append(last_state)
        v_new = u[:, i] - torch.einsum("bchk, bhkv -> bchv", w[:, i], last_state)
        vn.append(v_new)
        last_state = last_state * g[:, i, -1, :, None, None].exp()
        last_state = last_state + torch.einsum(
            "bchk, bchv -> bhkv",
            k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(),
            v_new,
        )
    h = torch.stack(h, dim=1).contiguous()
    vn = (
        torch.stack(vn, dim=1)
        .reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
        .contiguous()
    )
    if cu_seqlens is not None:
        vn = pack(vn, cu_seqlens)
        h = pack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
    return h, vn, last_state


def torch_chunk_o_fwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    h: torch.Tensor,  # [B, N, Hv, K, V]
    g: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Output of the forward pass: inter-chunk term plus causal intra-chunk term.

    ``o = (q * exp(g)) @ h + (tril(q @ k^T) * decay) @ v`` with ``q``
    pre-multiplied by ``scale``. A falsy ``scale`` (``None`` or ``0``)
    selects ``head_dim_k ** -0.5``.
    """
    if cu_seqlens is not None:
        q = unpack(q, cu_seqlens)
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        g = unpack(g, cu_seqlens)
        h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    if num_k_heads != num_v_heads:
        q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    scale = scale or head_dim_k ** (-0.5)
    q = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    v = pad_and_reshape(v, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    q = q * scale
    # Strict upper triangle is masked: the diagonal (token attends to itself)
    # is kept.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
        diagonal=1,
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(
        mask[None, None, :, :, None], 0.0
    )  # [B, N, C, D, Hv]
    attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask
    attn_inter = torch.einsum("bnchk, bnhkv -> bnchv", q * g.exp().unsqueeze(-1), h)
    o = attn_inter + torch.einsum("bncdh, bndhv -> bnchv", attn, v)
    o = o.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    if cu_seqlens is not None:
        o = pack(o, cu_seqlens)
    return o


def torch_chunk_dv_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    g: torch.Tensor,  # [B, T, Hv]
    do: torch.Tensor,  # [B, T, Hv, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Intra-chunk part of the gradient w.r.t. the corrected values.

    ``dv = (tril(q @ k^T) * decay)^T @ do`` per chunk, with ``q``
    pre-multiplied by ``scale``. A falsy ``scale`` (``None`` or ``0``)
    selects ``head_dim_k ** -0.5``.
    """
    if cu_seqlens is not None:
        q = unpack(q, cu_seqlens)
        k = unpack(k, cu_seqlens)
        g = unpack(g, cu_seqlens)
        do = unpack(do, cu_seqlens)
    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape
    if num_k_heads != num_v_heads:
        q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    scale = scale or head_dim_k ** (-0.5)
    q = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    do = pad_and_reshape(do, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    q = q * scale
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
        diagonal=1,
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(
        mask[None, None, :, :, None], 0.0
    )  # [B, N, C, D, Hv]
    attn = torch.einsum("bnchk, bndhk -> bncdh", q, k) * decay_mask
    dv = torch.einsum("bncdh, bnchv -> bndhv", attn, do)
    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    if cu_seqlens is not None:
        dv = pack(dv, cu_seqlens)
    return dv


def torch_chunk_gdr_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    w: torch.Tensor,  # [B, T, Hv, K]
    g: torch.Tensor,  # [B, T, Hv]
    do: torch.Tensor,  # [B, T, Hv, V]
    dv: torch.Tensor,  # [B, T, Hv, V]
    h0: torch.Tensor = None,  # [B, Hv, K, V]
    dht: torch.Tensor = None,  # [B, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Backward of the chunk state recurrence, walking the chunks in reverse.

    Args:
        dv: intra-chunk value gradient; the state-dependent contribution is
            added to a private copy of it.
        h0: only inspected for ``None``-ness to decide whether ``dh0`` is
            returned.
        dht: gradient flowing into the final state, or ``None`` for zeros.
        scale: a falsy value (``None`` or ``0``) selects ``head_dim_k ** -0.5``.

    Returns:
        ``(dh, dh0, dv)``: per-chunk state gradients, the gradient w.r.t. the
        initial state (``None`` when ``h0`` is ``None``), and the completed
        value gradient.
    """
    if cu_seqlens is not None:
        q = unpack(q, cu_seqlens)
        k = unpack(k, cu_seqlens)
        w = unpack(w, cu_seqlens)
        g = unpack(g, cu_seqlens)
        do = unpack(do, cu_seqlens)
        dv = unpack(dv, cu_seqlens)
    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape
    if num_k_heads != num_v_heads:
        q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    scale = scale or head_dim_k ** (-0.5)
    q = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    w = pad_and_reshape(w, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    do = pad_and_reshape(do, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    # dv is updated in place below. pad_and_reshape may hand back a view of
    # the caller's tensor (e.g. when no padding is needed), so work on a copy
    # to keep the argument untouched.
    dv = dv.clone()
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
    q = q * scale
    if dht is None:
        dstate = torch.zeros(
            (batch_size, num_v_heads, head_dim_k, head_dim_v),
            dtype=g.dtype,
            device=g.device,
        )
    else:
        dstate = dht.to(g.dtype, copy=True)
    dstate_inter = torch.einsum("bnchk, bnchv -> bnhkv", q * g.exp().unsqueeze(-1), do)
    dh = []
    for i in reversed(range(k.shape[1])):
        # Collected last-chunk-first; reversed once after the loop.
        dh.append(dstate)
        dv[:, i] += torch.einsum(
            "bchk, bhkv -> bchv",
            k[:, i] * (g[:, i, -1:, :, None] - g[:, i, :, :, None]).exp(),
            dstate,
        )
        dstate = dstate * g[:, i, -1, :, None, None].exp()
        dstate = (
            dstate
            + dstate_inter[:, i]
            - torch.einsum("bchk, bchv -> bhkv", w[:, i], dv[:, i])
        )
    dh.reverse()
    dh = torch.stack(dh, dim=1).contiguous()
    dh0 = None if h0 is None else dstate
    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    if cu_seqlens is not None:
        dv = pack(dv, cu_seqlens)
        dh = pack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
    return dh, dh0, dv


def torch_chunk_dqkwg_bwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    w: torch.Tensor,  # [B, T, Hv, K]
    g: torch.Tensor,  # [B, T, Hv]
    h: torch.Tensor,  # [B, N, Hv, K, V]
    dv: torch.Tensor,  # [B, T, Hv, V]
    do: torch.Tensor,  # [B, T, Hv, V]
    dh: torch.Tensor,  # [B, N, Hv, K, V]
    cu_seqlens: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Gradients w.r.t. ``q``, ``k``, ``w`` and the cumulative gate ``g``.

    ``dq``/``dk`` are returned with ``Hv`` heads (keys/queries are repeated
    to ``Hv`` heads on entry and not reduced here). A falsy ``scale``
    (``None`` or ``0``) selects ``head_dim_k ** -0.5``.

    Returns:
        ``(dq, dk, dw, dg)``.
    """
    if cu_seqlens is not None:
        q = unpack(q, cu_seqlens)
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        w = unpack(w, cu_seqlens)
        g = unpack(g, cu_seqlens)
        do = unpack(do, cu_seqlens)
        dv = unpack(dv, cu_seqlens)
        h = unpack(h, prepare_chunk_offsets(cu_seqlens, chunk_size))
        dh = unpack(dh, prepare_chunk_offsets(cu_seqlens, chunk_size))
    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = do.shape
    if num_k_heads != num_v_heads:
        q = q.repeat_interleave(num_v_heads // num_k_heads, dim=2)
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    scale = scale or head_dim_k ** (-0.5)
    q = pad_and_reshape(q, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    v = pad_and_reshape(v, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    w = pad_and_reshape(w, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    do = pad_and_reshape(do, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    dv = pad_and_reshape(dv, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    g = fill_last_chunk_of_g(g, num_tokens, cu_seqlens, chunk_size=chunk_size)
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device),
        diagonal=1,
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(
        mask[None, None, :, :, None], 0.0
    )  # [B, N, C, D, Hv]

    # Inter-chunk contributions (through the per-chunk states h / dh).
    dg_last = (h * dh).sum(dim=-1).sum(dim=-1)  # [B, N, Hv]
    ds = torch.einsum("bnchv, bndhv -> bncdh", do, v)
    dq = torch.einsum("bnchv, bnhkv -> bnchk", do, h)
    dk = torch.einsum("bnchv, bnhkv -> bnchk", v, dh)
    dw = -torch.einsum("bnchv, bnhkv -> bnchk", dv, h)
    g_last = g[:, :, -1]
    dg_last *= g_last.exp()
    dq = dq * g.unsqueeze(-1).exp() * scale
    dg = (q * dq).sum(dim=-1)  # [B, N, C, Hv]
    dk = dk * (g_last.unsqueeze(-2) - g).unsqueeze(-1).exp()
    dg -= (k * dk).sum(dim=-1)
    dg_last += (k * dk).sum(dim=-1).sum(dim=-2)

    # Intra-chunk contributions (through the masked, decayed q @ k^T).
    ds *= decay_mask * scale
    ds2 = ds * torch.einsum("bnchk, bndhk -> bncdh", q, k)
    dg += ds2.sum(dim=-2)
    dg -= ds2.sum(dim=-3)
    dq += torch.einsum("bncdh, bndhk -> bnchk", ds, k)
    dk += torch.einsum("bncdh, bnchk -> bndhk", ds, q)

    # The gradient of the chunk's last gate value lands on the last row.
    dg[:, :, -1] += dg_last
    dg = fill_last_chunk_of_g(
        dg, num_tokens, cu_seqlens, chunk_size=chunk_size, reverse=True
    )
    dq = dq.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    dw = dw.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    if cu_seqlens is not None:
        dq = pack(dq, cu_seqlens)
        dk = pack(dk, cu_seqlens)
        dw = pack(dw, cu_seqlens)
        dg = pack(dg, cu_seqlens)
    return dq, dk, dw, dg


def torch_chunk_wy_bwd(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    beta: torch.Tensor,  # [B, T, Hv]
    A: torch.Tensor,  # [B, T, Hv, D]
    g: torch.Tensor,  # [B, T, Hv]
    dw: torch.Tensor,  # [B, T, Hv, K]
    du: torch.Tensor,  # [B, T, Hv, V]
    dk1: torch.Tensor,  # [B, T, Hv, K]
    dg1: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
):
    """Backward through ``w``/``u`` and the solved matrix ``A``.

    ``dk1`` and ``dg1`` are gradients accumulated by earlier stages; they are
    added to the ``k`` and ``g`` gradients computed here. The chunk size is
    taken from the last dimension of ``A``.

    Returns:
        ``(dk, dv, db, dg)``; ``dk`` has ``Hv`` heads.
    """
    if cu_seqlens is not None:
        k = unpack(k, cu_seqlens)
        v = unpack(v, cu_seqlens)
        beta = unpack(beta, cu_seqlens)
        A = unpack(A, cu_seqlens)
        g = unpack(g, cu_seqlens)
        dw = unpack(dw, cu_seqlens)
        du = unpack(du, cu_seqlens)
        dk1 = unpack(dk1, cu_seqlens)
        dg1 = unpack(dg1, cu_seqlens)
    batch_size, num_tokens, num_k_heads, head_dim_k = k.shape
    _, _, num_v_heads, head_dim_v = v.shape
    chunk_size = A.shape[-1]
    if num_k_heads != num_v_heads:
        k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    k = pad_and_reshape(k, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    v = pad_and_reshape(v, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    beta = pad_and_reshape(beta, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    A = pad_and_reshape(A, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, D]
    g = pad_and_reshape(g, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]
    dw = pad_and_reshape(dw, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    du = pad_and_reshape(du, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, V]
    dk1 = pad_and_reshape(dk1, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv, K]
    dg1 = pad_and_reshape(dg1, dim=1, chunk_size=chunk_size)  # [B, N, C, Hv]

    # Through w = A @ (k * beta * exp(g)).
    dA = torch.einsum("bnchk, bndhk -> bnchd", dw, k * (beta * g.exp()).unsqueeze(-1))
    dk_beta_g = torch.einsum("bnchd, bnchk -> bndhk", A, dw)
    dk = dk_beta_g * (beta * g.exp()).unsqueeze(-1)
    db = (dk_beta_g * k * g.exp().unsqueeze(-1)).sum(dim=-1)
    dg = (dk_beta_g * k * (g.exp() * beta).unsqueeze(-1)).sum(dim=-1)

    # Through u = A @ (v * beta).
    dA += torch.einsum("bnchv, bndhv -> bnchd", du, v * beta.unsqueeze(-1))
    dv_beta = torch.einsum("bnchd, bnchv -> bndhv", A, du)
    dv = dv_beta * beta.unsqueeze(-1)
    db += (dv_beta * v).sum(dim=-1)

    # Through the solve and the masked, decayed (beta * k) @ k^T.
    mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=k.device)
    )
    decay_mask = torch.exp(g[:, :, :, None, :] - g[:, :, None, :, :])
    decay_mask = decay_mask.masked_fill(mask[None, None, :, :, None], 0.0).swapaxes(
        -2, -1
    )
    dA = dA.masked_fill(mask[None, None, :, None, :], 0.0)
    dA = torch.einsum("bndhc, bndhe -> bnche", A, dA)
    dA = torch.einsum("bnchd, bnehd -> bnche", dA, A)
    dA = -dA * decay_mask
    # From here on A is the raw (beta * k) @ k^T product, not the solved matrix.
    A = torch.einsum("bnchk, bndhk -> bnchd", k * beta.unsqueeze(-1), k)
    dk_beta = torch.einsum("bnchd, bndhk -> bnchk", dA, k)
    db += (dk_beta * k).sum(dim=-1)
    dk += torch.einsum("bnchd, bnchk -> bndhk", dA, k * beta.unsqueeze(-1))
    dk += dk_beta * beta.unsqueeze(-1)
    dk += dk1
    dg += (dA * A).sum(dim=-1) - (dA * A).sum(dim=-3).swapaxes(-1, -2)
    dg += dg1

    # TODO: NOTE: GVA
    dk = dk.reshape((batch_size, -1, num_v_heads, head_dim_k))[:, :num_tokens]
    # dv carries the value head dimension, not the key one.
    dv = dv.reshape((batch_size, -1, num_v_heads, head_dim_v))[:, :num_tokens]
    db = db.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    dg = dg.reshape((batch_size, -1, num_v_heads))[:, :num_tokens]
    if cu_seqlens is not None:
        dk = pack(dk, cu_seqlens)
        dv = pack(dv, cu_seqlens)
        db = pack(db, cu_seqlens)
        dg = pack(dg, cu_seqlens)
    return dk, dv, db, dg


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,  # [B, T, Hk, K]
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, Hv, V]
    g: torch.Tensor,  # [B, T, Hv]
    beta: torch.Tensor,  # [B, T, Hv]
    cu_seqlens: torch.Tensor = None,
    initial_state: torch.Tensor = None,
    scale: float = None,
    chunk_size: int = 64,
):
    """Forward pass: cumsum -> kkt -> solve -> w/u -> state recurrence -> output.

    A falsy ``scale`` (``None`` or ``0``) selects ``q.shape[-1] ** -0.5``.

    Returns:
        ``(g, o, A, h, final_state)`` where ``g`` is the chunk-local
        cumulative gate and ``A`` the solved per-chunk matrix; both are what
        ``chunk_gated_delta_rule_bwd`` expects as inputs.
    """
    scale = scale or q.shape[-1] ** (-0.5)
    g = torch_cumsum(
        x=g,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    A = torch_kkt_fwd(
        k=k,
        g=g,
        beta=beta,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    A = torch_solve(
        x=A,
        cu_seqlens=cu_seqlens,
    )
    w, u = torch_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g=g,
        cu_seqlens=cu_seqlens,
    )
    h, vn, final_state = torch_chunk_gdr_fwd(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    o = torch_chunk_o_fwd(
        q=q,
        k=k,
        v=vn,
        h=h,
        g=g,
        cu_seqlens=cu_seqlens,
        scale=scale,
        chunk_size=chunk_size,
    )
    return g, o, A, h, final_state


def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.Tensor = None,
    chunk_size: int = 64,
):
    """Backward pass; recomputes ``w``, ``u``, ``h`` and ``v_new`` from the inputs.

    Args:
        g: chunk-local cumulative gate, as returned by the forward pass.
        A: solved per-chunk matrix, as returned by the forward pass; its last
            dimension is used as the chunk size by the ``w``/``u`` stages.
        do: gradient of the output.
        dht: gradient of the final state, or ``None``.
        chunk_size: must match the value used in the forward pass; it is
            forwarded to every chunked stage.

    Returns:
        ``(dq, dk, dv, db, dg, dh0)``. ``dq``/``dk`` are summed back to the
        key head count when there are fewer key heads than value heads.
    """
    w, u = torch_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g=g,
        cu_seqlens=cu_seqlens,
    )
    h, vn, _ = torch_chunk_gdr_fwd(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dv = torch_chunk_dv_bwd(
        q=q,
        k=k,
        g=g,
        do=do,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dh, dh0, dv = torch_chunk_gdr_bwd(
        q=q,
        k=k,
        w=w,
        g=g,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dq, dk1, dw, dg1 = torch_chunk_dqkwg_bwd(
        q=q,
        k=k,
        v=vn,
        w=w,
        g=g,
        h=h,
        dv=dv,
        do=do,
        dh=dh,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    dk, dv, db, dg = torch_chunk_wy_bwd(
        k=k,
        v=v,
        beta=beta,
        g=g,
        A=A,
        dw=dw,
        du=dv,
        dk1=dk1,
        dg1=dg1,
        cu_seqlens=cu_seqlens,
    )
    Hg, H = k.shape[-2], v.shape[-2]
    if Hg < H:
        # Keys/queries were repeated to H heads inside the stages; fold the
        # gradients of the copies back onto the Hg original heads.
        B, T, _, K = dq.shape
        dq = torch.sum(dq.reshape(B, T, Hg, -1, K), dim=3)
        dk = torch.sum(dk.reshape(B, T, Hg, -1, K), dim=3)
    # Backward of the chunk-local cumsum applied to g in the forward pass:
    # a chunk-local reverse cumsum with the same chunk size.
    dg = torch_cumsum(dg, chunk_size=chunk_size, reverse=True, cu_seqlens=cu_seqlens)
    return dq, dk, dv, db, dg, dh0