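Reading notes: the GDN implementation in FLA
Paper: https://arxiv.org/pdf/2412.06464
Code: https://github.com/fla-org/flash-linear-attention
Disclaimer: these are personal reading notes. Any derivation that is my own may contain errors.
I. Forward pass
Entry-point function signature: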
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    cp_context: FLACPContext | None = None,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = True,
    transpose_state_layout: bool = False,
):
    ...  # body omitted in these notes; see Steps 1-5 below
    return g, o, A, final_state, initial_state
Notation
\[\begin{aligned}
\overleftarrow{\square_{[t]}}
&=
\text{Diag}(\boldsymbol{\gamma}_{[t]})
\square_{[t]}
,\quad
\overrightarrow{\square_{[t]}}
=
\text{Diag}(\boldsymbol{\gamma}_{[t]})^{-1}
\square_{[t]}
\quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{W}\}
\end{aligned}\]
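Step 1: Compute the cumulative gate in the log domain
From here on, gamma denotes what was previously written as log gamma, i.e. the cumulative gate in the log domain. Keep this in mind when reading the formulas below: the scaling Diag(gamma) in the notation above corresponds to Diag(exp gamma) under this convention.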
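\[\begin{aligned}
\boldsymbol{\gamma}_{[t]}^{r}
=
\sum_{i=1}^{r} \log \boldsymbol{\alpha}_{[t]}^{i}
,\quad r = 1, \dots, C
\end{aligned}\]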
Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
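As a rough illustration of what a chunk-local cumulative sum computes, here is a minimal NumPy sketch. It is not the FLA Triton kernel: the helper name `chunk_local_cumsum_ref`, the 1-D input, and the requirement that the length be a multiple of the chunk size are all simplifying assumptions of mine.

```python
import numpy as np

def chunk_local_cumsum_ref(g, chunk_size):
    """Cumulative sum that restarts at every chunk boundary (1-D reference)."""
    T = g.shape[0]
    assert T % chunk_size == 0  # sketch only: no handling of a ragged last chunk
    return g.reshape(T // chunk_size, chunk_size).cumsum(axis=1).reshape(T)

# Per-token gates alpha; the kernel works with log(alpha).
log_alpha = np.log(np.array([0.5, 0.5, 0.25, 0.5]))
gamma = chunk_local_cumsum_ref(log_alpha, chunk_size=2)

# exp(gamma) is the running product of alpha inside each chunk.
print(np.exp(gamma))
```

Working in the log domain turns the running product of gates into a running sum, which is why later steps only ever exponentiate differences of gamma.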
Step 2: Solve the lower-triangular linear system
\[\begin{aligned}
\widetilde{\mathbf{A_0}}_{[t]} = \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)
\\
\\
\widetilde{\mathbf{A}}_{[t]} = (\mathbf{I} + \widetilde{\mathbf{A_0}}_{[t]})^{-1}
\end{aligned}\]
Corresponding code: fla.ops.gated_delta_rule.chunk_fwd -> chunk_gated_delta_rule_fwd_intra -> chunk_gated_delta_rule_fwd_kkt_solve_kernel
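The two formulas above can be sketched for a single chunk in NumPy. This is only a reference under my own assumptions (random inputs, a single head, gamma taken in the log domain so that the two arrows combine into exp(gamma_i - gamma_j)); the real kernel is blocked and fused.

```python
import numpy as np

rng = np.random.default_rng(0)
C, dk = 4, 3                                   # chunk size and key dim (illustrative)
K = rng.standard_normal((C, dk))
beta = rng.uniform(0.1, 0.9, C)
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, C)))   # log-domain cumulative gate

# exp(gamma_i - gamma_j): the left arrow on row i combined with the right arrow on row j
decay = np.exp(gamma[:, None] - gamma[None, :])
# Strictly lower-triangular mask M_{-1}, then row scaling by beta
A0 = beta[:, None] * np.tril((K @ K.T) * decay, k=-1)

# I + A0 is unit lower triangular, so its inverse follows from forward substitution:
# row i of the inverse only depends on the rows above it.
A = np.zeros((C, C))
for i in range(C):
    A[i, i] = 1.0
    A[i, :i] = -A0[i, :i] @ A[:i, :i]

print(np.abs((np.eye(C) + A0) @ A - np.eye(C)).max())
```

Forming the difference gamma_i - gamma_j before exponentiating avoids computing exp(-gamma) on its own, which can overflow when the accumulated decay is strong.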
Step 3: Compute U and the left-scaled W
\[\begin{aligned}
\mathbf{U}_{[t]}
&=
\widetilde{\mathbf{A}}_{[t]}
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{V}_{[t]}
\\
\\
\overleftarrow{\mathbf{W}_{[t]}}
&=
\widetilde{\mathbf{A}}_{[t]}
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{K}_{[t]}
\end{aligned}\]
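Corresponding code: fla.ops.gated_delta_rule.wy_fast -> recompute_w_u_fwd
Step 4: Compute the hidden state recurrently
Here S_[t]^C uses the [d_k, d_v] layout, so it is the transpose of the state used in the earlier derivation.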
\[\begin{aligned}
\mathbf{V}_{[t],new}
&=
\mathbf{U}_{[t]}
- \overleftarrow{\mathbf{W}_{[t]}}
\mathbf{S}_{[t-1]}^{C}
\\
\\
\mathbf{S}_{[t]}^{C}
&=
\exp(\boldsymbol{\gamma}_{[t]}^C)
\mathbf{S}_{[t-1]}^{C}
+ \mathbf{K}_{[t]}^\top
\left(
\text{Diag}\left(\exp(\boldsymbol{\gamma}_{[t]}^C - \boldsymbol{\gamma}_{[t]})\right)
\mathbf{V}_{[t],new}
\right)
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
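A short sketch of where the first equation comes from. I assume here that the token-level gated delta rule in the [d_k, d_v] layout reads S_r = alpha_r (I - beta_r k_r k_r^T) S_{r-1} + beta_r k_r v_r^T; this recurrence is my reading of the paper and is not stated in these notes. The index r runs over the positions inside chunk t, and gamma^r is the log-domain cumulative gate at position r.

\[\begin{aligned}
\mathbf{S}_{r}
&=
\alpha_r \mathbf{S}_{r-1}
+ \mathbf{k}_r \mathbf{v}_{r,new}^\top
,\quad
\mathbf{v}_{r,new}
=
\beta_r \left( \mathbf{v}_r - \alpha_r \mathbf{S}_{r-1}^\top \mathbf{k}_r \right)
\\
\\
\alpha_r \mathbf{S}_{r-1}
&=
\exp(\gamma^{r}) \mathbf{S}_{[t-1]}^{C}
+ \sum_{i<r} \exp(\gamma^{r} - \gamma^{i}) \mathbf{k}_i \mathbf{v}_{i,new}^\top
\\
\\
\mathbf{v}_{r,new}
+ \beta_r \sum_{i<r} \exp(\gamma^{r} - \gamma^{i}) (\mathbf{k}_r^\top \mathbf{k}_i) \mathbf{v}_{i,new}
&=
\beta_r \mathbf{v}_r
- \beta_r \exp(\gamma^{r}) \mathbf{S}_{[t-1]}^{C\top} \mathbf{k}_r
\\
\\
(\mathbf{I} + \widetilde{\mathbf{A_0}}_{[t]}) \mathbf{V}_{[t],new}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]}
- \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \mathbf{S}_{[t-1]}^{C}
\end{aligned}\]

Left-multiplying the last line by the inverse from Step 2 gives U minus the left-scaled W times the incoming state, with U and W as defined in Step 3.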
Step 5: Compute the final output
\[\begin{aligned}
\mathbf{O}_{[t]}
=
\text{Diag}\left(\exp \boldsymbol{\gamma}_{[t]} \right)
\mathbf{Q}_{[t]} \mathbf{S}_{[t-1]}^{C}
+
\left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],new}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_o -> chunk_fwd_o
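To sanity-check Steps 1 to 5 together, the sketch below compares a chunked NumPy forward pass against a token-by-token recurrence. Several things are assumptions of mine rather than facts about FLA: the recurrence is my reading of the gated delta rule, `scale` is taken as 1, there is a single head, the initial state is zero, and options such as `cu_seqlens` and `use_exp2` are ignored. The helper names `recurrent` and `chunked` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
T, C, dk, dv = 8, 4, 3, 5                      # sequence length, chunk size, dims
q = rng.standard_normal((T, dk))
k = rng.standard_normal((T, dk)) / np.sqrt(dk)
v = rng.standard_normal((T, dv))
beta = rng.uniform(0.1, 0.9, T)
log_alpha = np.log(rng.uniform(0.8, 1.0, T))

def recurrent(q, k, v, beta, log_alpha):
    """Token-by-token reference, state kept in the [d_k, d_v] layout."""
    S = np.zeros((dk, dv))
    out = np.zeros((T, dv))
    for r in range(T):
        S = np.exp(log_alpha[r]) * S                            # decay
        S = S + beta[r] * np.outer(k[r], v[r] - S.T @ k[r])     # delta-rule write
        out[r] = S.T @ q[r]
    return out, S

def chunked(q, k, v, beta, log_alpha):
    """Chunk-wise forward following Steps 1-5 of these notes."""
    S = np.zeros((dk, dv))
    out = np.zeros((T, dv))
    for s in range(0, T, C):
        Q, K, V, b = q[s:s + C], k[s:s + C], v[s:s + C], beta[s:s + C]
        gam = np.cumsum(log_alpha[s:s + C])                     # Step 1
        decay = np.exp(gam[:, None] - gam[None, :])
        A0 = b[:, None] * np.tril((K @ K.T) * decay, k=-1)      # Step 2
        A = np.linalg.inv(np.eye(C) + A0)
        U = A @ (b[:, None] * V)                                # Step 3
        W = A @ ((np.exp(gam) * b)[:, None] * K)
        V_new = U - W @ S                                       # Step 4
        inter = (np.exp(gam)[:, None] * Q) @ S                  # Step 5, state term
        intra = np.tril((Q @ K.T) * decay) @ V_new              # Step 5, local term
        out[s:s + C] = inter + intra
        S = np.exp(gam[-1]) * S + K.T @ (np.exp(gam[-1] - gam)[:, None] * V_new)
    return out, S

o_ref, S_ref = recurrent(q, k, v, beta, log_alpha)
o_chk, S_chk = chunked(q, k, v, beta, log_alpha)
print(np.abs(o_ref - o_chk).max(), np.abs(S_ref - S_chk).max())
```

Note that the state update inside `chunked` multiplies the previous state by exp of the chunk-end gate, which is the log-domain reading of the Step 4 formula.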
Comment: judging from the code, Equation 9 in 2412.06464v3 is indeed a typo.
II. Backward pass
Entry-point function signature:
def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    cp_context: FLACPContext | None = None,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = True,
    transpose_state_layout: bool = False,
):
    ...  # body omitted in these notes; see Steps 1-7 below
    return dq, dk, dv, db, dg, dh0
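Step 1: Reuse the stored inverse matrix
Here A_[t] = (I + A0_[t])^{-1} was already stored by the forward pass (it is the A argument in the signature above), so the triangular solve is not repeated. The rest is the same as in the forward pass.
Corresponding code: fla.ops.gated_delta_rule.wy_fast -> recompute_w_u_fwd
Step 2: Recompute the hidden states
This step is the same as in the forward pass.
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
Step 3: Compute the contribution to dU that comes from O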
\[\begin{aligned}
\left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}
=
\left(
\overrightarrow{\mathbf{K}_{[t]}} \overleftarrow{\mathbf{Q}_{[t]}}^\top
\odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dv_local
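The formula above is the transpose of the intra-chunk attention matrix applied to dO. A small NumPy sketch (single chunk, random data, variable names of my own choosing) makes this concrete by checking the adjoint identity between the forward local term and this gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
C, dk, dv = 4, 3, 5
Q = rng.standard_normal((C, dk))
K = rng.standard_normal((C, dk))
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, C)))   # log-domain cumulative gate
decay = np.exp(gamma[:, None] - gamma[None, :])       # decay[i, j] = exp(gamma_i - gamma_j)

# Forward local term of Step 5: O_local = B @ V_new, with an inclusive lower-triangular mask
B = np.tril((Q @ K.T) * decay)

dO = rng.standard_normal((C, dv))
# Backward: (right-scaled K) (left-scaled Q)^T masked by M^T, applied to dO
dU_from_O = np.triu((K @ Q.T) * decay.T) @ dO

# Adjoint check: <dO, B V> must equal <B^T dO, V> for any V
V = rng.standard_normal((C, dv))
lhs = np.sum(dO * (B @ V))
rhs = np.sum(dU_from_O * V)
print(lhs, rhs)
```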
Step 4: Backward recurrence over U and the state
\[\begin{aligned}
\delta \mathbf{U}_{[t]}
=
\text{Diag}\left(\exp(\boldsymbol{\gamma}_{[t]}^C - \boldsymbol{\gamma}_{[t]})\right)
\mathbf{K}_{[t]} \delta \mathbf{S}_{[t]}^{C}
+
\left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}
\end{aligned}\]
\[\begin{aligned}
\delta \mathbf{S}_{[t]}^{C}
=
\exp(\boldsymbol{\gamma}_{[t+1]}^C)
\delta \mathbf{S}_{[t+1]}^{C}
+
\overleftarrow{\mathbf{Q}_{[t+1]}}^\top \delta \mathbf{O}_{[t+1]}
-
\overleftarrow{\mathbf{W}_{[t+1]}} ^\top
\delta \mathbf{U}_{[t+1]}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu
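The reverse scan can be checked numerically. In the sketch below I treat U and the left-scaled W as independent inputs, define a loss that is linear in the outputs, and compare the scan's dU with an exact directional difference (the loss is affine in U, so no step size is needed). Everything here is an assumption for illustration: the shapes, the 0.3 scaling that keeps values small, a zero initial state, no gradient flowing into the final state, and the helper names `pieces`, `loss` and `grad_U`.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, dk, dv = 3, 4, 5, 6                      # chunks, chunk size, key dim, value dim
gam = np.cumsum(np.log(rng.uniform(0.8, 1.0, (N, C))), axis=1)  # per-chunk log gates
Q = rng.standard_normal((N, C, dk))
K = 0.3 * rng.standard_normal((N, C, dk))
W = 0.3 * rng.standard_normal((N, C, dk))      # stands in for the left-scaled W
U = rng.standard_normal((N, C, dv))
dO = rng.standard_normal((N, C, dv))

def pieces(t):
    decay = np.exp(gam[t][:, None] - gam[t][None, :])
    B = np.tril((Q[t] @ K[t].T) * decay)                 # local attention matrix
    Ql = np.exp(gam[t])[:, None] * Q[t]                  # left-scaled Q
    Kd = np.exp(gam[t, -1] - gam[t])[:, None] * K[t]     # Diag(exp(gamma^C - gamma)) K
    return B, Ql, Kd

def loss(U):
    S = np.zeros((dk, dv))
    total = 0.0
    for t in range(N):
        B, Ql, Kd = pieces(t)
        V_new = U[t] - W[t] @ S
        total += np.sum(dO[t] * (Ql @ S + B @ V_new))
        S = np.exp(gam[t, -1]) * S + Kd.T @ V_new
    return total

def grad_U():
    dS = np.zeros((dk, dv))        # gradient w.r.t. the state leaving chunk t
    dU = np.zeros_like(U)
    for t in reversed(range(N)):
        B, Ql, Kd = pieces(t)
        dU[t] = Kd @ dS + B.T @ dO[t]
        # Gradient w.r.t. the state entering chunk t, decayed by this chunk's gate
        dS = np.exp(gam[t, -1]) * dS + Ql.T @ dO[t] - W[t].T @ dU[t]
    return dU

E = rng.standard_normal(U.shape)
lhs = loss(U + E) - loss(U)
rhs = np.sum(grad_U() * E)
print(lhs, rhs)
```

The decay factor applied to dS belongs to the chunk whose inputs are being consumed in that iteration, which is the chunk following the state being differentiated.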
Step 5: Compute dB, dQ, part of dK, dW, and part of d gamma
\[\begin{aligned}
\delta \mathbf{B}_{[t]}
&=
\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M}
\\
\\
\delta \mathbf{Q}_{[t]}
&=
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
+
\left(
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]}
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1}
\right)
\mathbf{K}_{[t]}
\\
\\
\delta \mathbf{K}_{[t], \text{part1}}
&=
\text{Diag}(\exp(\boldsymbol{\gamma}_{[t]}^C - \boldsymbol{\gamma}_{[t]})) \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C\top}
+
\left(
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]}
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1}
\right)^\top
\mathbf{Q}_{[t]}
\\&=
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}}
+
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}}
\\
\\
\delta \mathbf{W}_{[t]}
&=
- \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
\end{aligned}\]
\[\begin{aligned}
\delta \boldsymbol{\gamma}_{[t]}^C
&=
\delta \boldsymbol{\gamma}_{[t]}^C
\exp \boldsymbol{\gamma}_{[t]}^C
=
\text{Tr}\left(
\delta \mathbf{S}_{[t]}^{C\top} \mathbf{S}_{[t-1]}^C
\right)
\exp \boldsymbol{\gamma}_{[t]}^C
+
\text{Tr}\left(
\left(
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}}
\right)
\mathbf{K}_{[t]}^\top\right)
\\ \\ &=
\text{Tr}\left(
\delta \mathbf{S}_{[t]}^{C\top} \mathbf{S}_{[t-1]}^C
+
\delta \mathbf{S}_{[t]}^{C\top}
\left(
\mathbf{K}_{[t]}^\top
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1}
\mathbf{V}_{[t],new}
\right)
\right)
\exp \boldsymbol{\gamma}_{[t]}^C
\\
\\
\delta \boldsymbol{\gamma}_{[t],\text{part1}}
&=
\delta \boldsymbol{\gamma}_{[t],\text{part1}} \odot \exp\boldsymbol{\gamma}_{[t]}
\\&=
\text{diag}\left(
\left(
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
\right)
\mathbf{Q}_{[t]}^\top
\right)
\\&-
\text{diag}\left(
\mathbf{K}_{[t]}
\left(
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}}
\right)^\top
\right)
\\&+
\text{diag}\left(
\left(
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]}
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1}
\right)
\left(
\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^\top
\right)^\top
\right)
\\&-
\text{diag}\left(
\left(
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]}
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1}
\right)^\top
\left(
\mathbf{Q}_{[t]}\mathbf{K}_{[t]}^\top
\right)
\right)
\\&+
[0,0,...,\delta \boldsymbol{\gamma}_{[t]}^C
\exp \boldsymbol{\gamma}_{[t]}^C]^\top
\\&=
\left(
\left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}}
+
\left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}}
+
\left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \boldsymbol{\gamma}_{[t]}^C}
\right)
\odot \exp\boldsymbol{\gamma}_{[t]}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dqkwg
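The dB and dQ formulas, and the part of dK that comes from O, can be checked on one chunk. The sketch below assumes gamma is in the log domain (so every diagonal scaling uses exp gamma), takes `scale` as 1, and uses random stand-ins for the incoming state and V_new; `forward_o` is a hypothetical helper that mirrors the Step 5 forward formula. Because the output is linear in Q and in K separately, an exact difference replaces a finite-difference step.

```python
import numpy as np

rng = np.random.default_rng(0)
C, dk, dv = 4, 3, 5
Q = rng.standard_normal((C, dk))
K = rng.standard_normal((C, dk))
S_prev = rng.standard_normal((dk, dv))         # state entering the chunk
V_new = rng.standard_normal((C, dv))
dO = rng.standard_normal((C, dv))
G = np.exp(np.cumsum(np.log(rng.uniform(0.8, 1.0, C))))   # exp(gamma)
ratio = G[:, None] / G[None, :]                # exp(gamma_i - gamma_j)

def forward_o(Q, K):
    return (G[:, None] * Q) @ S_prev + np.tril((Q @ K.T) * ratio) @ V_new

dB = np.tril(dO @ V_new.T)                     # dB = dO V_new^T, masked by M
scaled_dB = ratio * dB                         # Diag(exp gamma) dB Diag(exp gamma)^{-1}
dQ = (G[:, None] * dO) @ S_prev.T + scaled_dB @ K
dK_from_O = scaled_dB.T @ Q

# Directional checks against the forward formula
E_q = rng.standard_normal(Q.shape)
E_k = rng.standard_normal(K.shape)
lhs_q = np.sum(dO * (forward_o(Q + E_q, K) - forward_o(Q, K)))
rhs_q = np.sum(dQ * E_q)
lhs_k = np.sum(dO * (forward_o(Q, K + E_k) - forward_o(Q, K)))
rhs_k = np.sum(dK_from_O * E_k)
print(lhs_q, rhs_q, lhs_k, rhs_k)
```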
Step 6: Compute the remaining parts of dK, dV, d beta, and d gamma
\[\begin{aligned}
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \overleftarrow{\mathbf{W}_{[t]}} \text{ w/o } \mathbf{T}_{[t]} }
&=
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})
\text{Diag}(\boldsymbol{\beta}_{[t]})
\left(
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \overleftarrow{\mathbf{W}_{[t]}}
\right)
\\
\\
\delta \mathbf{V}_{[t]}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]})
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{U}_{[t]}
\\
\\
\delta \widetilde{\mathbf{A}}_{[t]}
&=
\delta \overleftarrow{\mathbf{W}_{[t]}} (\text{Diag}(\boldsymbol{\beta}_{[t]} \exp \boldsymbol{\gamma}_{[t]})\mathbf{K}_{[t]})^\top
+
\delta \mathbf{U}_{[t]} (\text{Diag}(\boldsymbol{\beta}_{[t]})\mathbf{V}_{[t]} )^\top
\\
\\
\delta \widetilde{\mathbf{A_x}}_{[t]}
&= -
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \widetilde{\mathbf{A}}_{[t]}
\widetilde{\mathbf{A}}_{[t]}^\top
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1}
\odot
\mathbf{M}_{-1}
\\&=
\text{Diag}(\boldsymbol{\beta}_{[t]})^{-1}
\left.\delta (\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top)\right|_{\text{from } \widetilde{\mathbf{X}}_{[t]}}
\\
\\
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{X}_{[t]}}
&=
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \widetilde{\mathbf{A_x}}_{[t]}
\right)
\mathbf{K}_{[t]}
+
\left(
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{K}_{[t]}
\right)^\top
\delta \widetilde{\mathbf{A_x}}_{[t]}
\right)^\top
\\
\\
\delta \boldsymbol{\beta}_{[t]}
&=
\text{diag}\left(
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \overleftarrow{\mathbf{W}_{[t]}}
\mathbf{K}_{[t]}^\top
\right)
+
\text{diag}\left(
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{U}_{[t]}
\mathbf{V}_{[t]}^\top
\right)
+
\text{diag}\left(
\delta \widetilde{\mathbf{A_x}}_{[t]}
\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top\right)
\\
\\
\delta \boldsymbol{\gamma}_{[t],\text{part2}}
&=
\text{diag}\left(
\left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \mathbf{U}_{[t]} \text{ w/ } \mathbf{W}_{[t]} \text{ w/o } \widetilde{\mathbf{A}}_{[t]} }
\right)
\odot \exp \boldsymbol{\gamma}_{[t]}
+
\text{diag}\left(
\left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \widetilde{\mathbf{A}}_{[t]} }
\right)
\odot \exp \boldsymbol{\gamma}_{[t]}
\\&=
\text{diag}\left(
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \overleftarrow{\mathbf{W}_{[t]}}
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\text{Diag}(\exp \boldsymbol{\gamma}_{[t]})
\mathbf{K}_{[t]}
\right)^\top
\right)
\\&+
\text{diag}\left(
\delta \widetilde{\mathbf{A_x}}_{[t]}
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top
\right)^\top
\right)
-
\text{diag}\left(
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top
\right)^\top
\delta \widetilde{\mathbf{A_x}}_{[t]}
\right)
\end{aligned}\]
Corresponding code: fla.ops.gated_delta_rule.wy_fast -> prepare_wy_repr_bwd
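The core of this step is the gradient through a matrix inverse: if A = (I + A0)^{-1} and dA is the gradient with respect to A, then the gradient with respect to A0 is -A^T dA A^T, restricted to the strictly lower triangle. The sketch below checks only that identity with a central finite difference; the diagonal gate scalings and beta from the formulas above are left out, and all names and sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 4
A0 = 0.3 * np.tril(rng.standard_normal((C, C)), k=-1)   # strictly lower triangular
A = np.linalg.inv(np.eye(C) + A0)
dA = rng.standard_normal((C, C))                         # upstream gradient w.r.t. A

# Gradient w.r.t. A0; only the strictly lower entries are free parameters
dA0 = np.tril(-A.T @ dA @ A.T, k=-1)

def loss(A0):
    return np.sum(dA * np.linalg.inv(np.eye(C) + A0))

# Central finite difference along a random strictly lower-triangular direction
E = np.tril(rng.standard_normal((C, C)), k=-1)
eps = 1e-6
fd = (loss(A0 + eps * E) - loss(A0 - eps * E)) / (2 * eps)
analytic = np.sum(dA0 * E)
print(fd, analytic)
```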
Step 7: Recover the gradient of log alpha
\[\begin{aligned}
\delta \log \boldsymbol{\alpha}_{[t]}
&=
\text{suffix\_cumsum}(\delta \boldsymbol{\gamma}_{[t]})
\end{aligned}\]
Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
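The suffix sum appears because gamma is a prefix sum of log alpha inside each chunk, and the adjoint of a prefix sum is a suffix sum. A minimal NumPy sketch for one chunk follows; the helper `suffix_cumsum` is my own and is not an FLA function.

```python
import numpy as np

def suffix_cumsum(x):
    """Reverse (right-to-left) cumulative sum of a 1-D array."""
    return np.cumsum(x[::-1])[::-1]

rng = np.random.default_rng(0)
C = 5
d_gamma = rng.standard_normal(C)               # gradient w.r.t. gamma for one chunk
d_log_alpha = suffix_cumsum(d_gamma)

# Adjoint check: <d_gamma, cumsum(e)> equals <suffix_cumsum(d_gamma), e> for any e
e = rng.standard_normal(C)
lhs = np.sum(d_gamma * np.cumsum(e))
rhs = np.sum(d_log_alpha * e)
print(lhs, rhs)
```

Intuitively, log alpha at position i contributes to every gamma at positions i and later, so its gradient collects d gamma from i to the end of the chunk.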