Reading Notes: Implementation Notes for KDA in FLA
Paper: https://arxiv.org/pdf/2510.26692
Code: https://github.com/fla-org/flash-linear-attention
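Disclaimer: These are personal reading notes. Some of the derivations are my own and may contain errors; please cross-check them against the Triton source.
Forward Pass
Entry function: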
def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens_cpu: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    chunk_size: int = 64,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    use_gate_in_kernel: bool = False,
    A_log: torch.Tensor | None = None,
    dt_bias: torch.Tensor | None = None,
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    cp_context: FLACPContext | None = None,
    transpose_state_layout: bool = False,
):
    ...  # body omitted in these notes
    return o, final_state, g, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state
Notation
\[\begin{aligned}
\overleftarrow{\square_{[t]}}
&=
\square_{[t]}
\odot
\exp \mathbf{\Gamma}_{[t]}
,\quad
\overrightarrow{\square_{[t]}}
=
\square_{[t]}
\oslash
\exp \mathbf{\Gamma}_{[t]}
\quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}\}
\end{aligned}\]
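Step 1: Compute the cumulative gate in the log domain
Note: the variable gamma here denotes the cumulative sum in the log domain (that is, log gamma); this convention holds throughout these notes.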
\[\begin{aligned}
\boldsymbol{\gamma}_{[t]}^r &= \sum_{i=tC+1}^{tC+r} \log \boldsymbol{\alpha}_i \in \mathbb{R}^{d_k}
\end{aligned}\]
Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
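As a rough illustration of Step 1, here is a NumPy sketch of a cumulative sum that restarts at every chunk boundary. It is not the Triton kernel behind chunk_local_cumsum: the name chunk_local_cumsum_ref, the [T, d_k] layout, and the requirement that T be a multiple of the chunk size are all assumptions made for this example.

```python
import numpy as np

def chunk_local_cumsum_ref(log_alpha: np.ndarray, chunk_size: int) -> np.ndarray:
    """Cumulative sum of log-gates that restarts at every chunk boundary.

    log_alpha: [T, d_k] per-token, per-channel log decay (hypothetical layout).
    Returns Gamma with the same shape; row r of chunk t holds gamma_[t]^r.
    """
    T, d_k = log_alpha.shape
    assert T % chunk_size == 0, "sketch assumes T is a multiple of chunk_size"
    chunks = log_alpha.reshape(T // chunk_size, chunk_size, d_k)
    return np.cumsum(chunks, axis=1).reshape(T, d_k)

# Every gate is 0.5, so the decay inside a chunk of two tokens is 0.5, then 0.25.
log_alpha = np.log(np.full((4, 2), 0.5))
Gamma = chunk_local_cumsum_ref(log_alpha, chunk_size=2)
```

Exponentiating Gamma recovers the accumulated decay inside each chunk, which is the quantity the later steps multiply into Q and K.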
Step 2: Solve the lower-triangular system to obtain the intra-chunk attention
\[\begin{aligned}
\mathbf{A_{qk}}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}
,\quad
\mathbf{A_{kk0}}_{[t]} =
\text{Diag}(\boldsymbol{\beta}_{[t]})
\left(
\overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1}
\right)
\\
\\
\mathbf{A}_{[t]} &= (\mathbf{I} + \mathbf{A_{kk0}}_{[t]})^{-1}
\end{aligned}\]
Corresponding code: fla.ops.kda.chunk_intra -> chunk_kda_fwd_intra -> chunk_kda_fwd_kernel_intra_sub_chunk / chunk_kda_fwd_kernel_inter_solve_fused
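The formulas of Step 2 can be sketched with dense NumPy operations for a single chunk. This is only my reading of the math above: the function name intra_chunk_matrices is made up, Gamma is taken to be the log-domain cumulative gate, and a dense np.linalg.solve stands in for whatever blocked solve the Triton kernels actually perform.

```python
import numpy as np

def intra_chunk_matrices(Q, K, beta, Gamma):
    """Q, K, Gamma: [C, d_k]; beta: [C]. Gamma is the log-domain cumulative gate."""
    C = Q.shape[0]
    Q_in = Q * np.exp(Gamma)                 # "left arrow" Q: decay applied
    K_in = K * np.exp(Gamma)                 # "left arrow" K
    K_out = K * np.exp(-Gamma)               # "right arrow" K: decay divided out
    M = np.tril(np.ones((C, C)))             # causal mask, diagonal included
    M_strict = np.tril(np.ones((C, C)), -1)  # M_{-1}: strictly lower triangle
    A_qk = (Q_in @ K_out.T) * M
    A_kk0 = beta[:, None] * ((K_in @ K_out.T) * M_strict)
    # (I + A_kk0) is unit lower-triangular, so it is always invertible.
    A = np.linalg.solve(np.eye(C) + A_kk0, np.eye(C))
    return A_qk, A_kk0, A

rng = np.random.default_rng(0)
C, d_k = 4, 3
Q, K = rng.standard_normal((2, C, d_k))
beta = rng.uniform(0.1, 0.9, C)
Gamma = np.cumsum(np.log(rng.uniform(0.7, 1.0, (C, d_k))), axis=0)
A_qk, A_kk0, A = intra_chunk_matrices(Q, K, beta, Gamma)
```

Because I + A_kk0 is unit lower-triangular, A is too, which is why a forward-substitution style solve is sufficient.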
Step 3: Compute the aggregated states U and W
\[\begin{aligned}
\mathbf{U}_{[t]}
&=
\mathbf{A}_{[t]}
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{V}_{[t]}
\right)
\\
\\
\mathbf{W}_{[t]}
&=
\mathbf{A}_{[t]}
\left(
(\exp \boldsymbol{\Gamma}_{[t]}) \odot
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{K}_{[t]}
\right)
\\
\\
\mathbf{Q_g}_{[t]}
&=
(\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{Q}_{[t]}
\\
\\
\mathbf{K_g}_{[t]}
&=
(\exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]})) \odot
\mathbf{K}_{[t]}
\end{aligned}\]
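To save memory bandwidth during the training forward pass, Q_g is not computed and written to memory at this point when activation recomputation is enabled; it is instead computed on the fly at the start of the backward pass.
Corresponding code: fla.ops.kda.wy_fast -> recompute_w_u_fwd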
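A small NumPy sketch of the four quantities in Step 3, for one chunk. The helper name wy_aggregates is hypothetical, and the identity matrix passed as A below is only a placeholder for (I + A_kk0)^{-1} so that the example stays short.

```python
import numpy as np

def wy_aggregates(A, Q, K, V, beta, Gamma):
    """A: [C, C]; Q, K, Gamma: [C, d_k]; V: [C, d_v]; beta: [C]."""
    decay = np.exp(Gamma)
    U = A @ (beta[:, None] * V)
    W = A @ (decay * beta[:, None] * K)
    Q_g = decay * Q
    # Gamma[-1] is gamma^C, the total log-decay of the chunk.
    K_g = np.exp(Gamma[-1][None, :] - Gamma) * K
    return U, W, Q_g, K_g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
Q, K = rng.standard_normal((2, C, d_k))
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, C)
Gamma = np.cumsum(np.log(rng.uniform(0.7, 1.0, (C, d_k))), axis=0)
A = np.eye(C)  # placeholder for (I + A_kk0)^{-1}; any [C, C] matrix works here
U, W, Q_g, K_g = wy_aggregates(A, Q, K, V, beta, Gamma)
```

Note that the last row of K_g equals the last row of K, since the remaining decay to the end of the chunk is zero there.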
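Step 4: Inter-chunk state recurrence
Here the hidden state S_[t]^C uses the [d_k, d_v] memory layout, so in the underlying code it is transposed relative to the earlier pure-math derivation.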
\[\begin{aligned}
\mathbf{V}_{[t],new}
&= \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C}
\\
\\
\mathbf{S}_{[t]}^{C}
&= \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}^C) \mathbf{S}_{[t-1]}^{C}
+ \mathbf{K_g}_{[t]}^\top \mathbf{V}_{[t],new}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
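One step of the inter-chunk recurrence can be written in a few lines of NumPy. This is a sketch under my own assumptions: chunk_state_step is a made-up name, the state uses the [d_k, d_v] layout mentioned above, and gamma_C is kept in the log domain, so the decay that multiplies the previous state is exp(gamma_C) applied per key channel.

```python
import numpy as np

def chunk_state_step(S_prev, U, W, K_g, gamma_C):
    """S_prev: [d_k, d_v]; U: [C, d_v]; W, K_g: [C, d_k]; gamma_C: [d_k] (log domain)."""
    V_new = U - W @ S_prev
    S = np.exp(gamma_C)[:, None] * S_prev + K_g.T @ V_new
    return V_new, S

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_k, d_v))
U = rng.standard_normal((C, d_v))
gamma_C = np.log(np.array([0.5, 0.25, 1.0]))
# With W = 0 and K_g = 0 the update reduces to a pure per-channel decay.
V_new, S = chunk_state_step(S_prev, U, np.zeros((C, d_k)), np.zeros((C, d_k)), gamma_C)
```

The degenerate call at the end isolates the decay term, which makes the per-channel scaling easy to inspect.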
Step 5: Compute the final output
\[\begin{aligned}
\mathbf{O}_{[t]}
=
\left((\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{Q}_{[t]}\right) \mathbf{S}_{[t-1]}^{C}
+ \mathbf{A_{qk}}_{[t]} \mathbf{V}_{[t],new}
\end{aligned}\]
Corresponding code: fla.ops.gla.chunk -> chunk_gla_fwd_o_gk
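To sanity-check Steps 1 to 5 together, the sketch below compares a dense NumPy version of the chunked forward pass against a token-by-token recurrence. Everything here is my own illustration: the function names are hypothetical, the scale factor is omitted, the initial state is zero, and the recurrence S <- (I - beta k k^T) Diag(alpha) S + beta k v^T with o = S^T q is the gated delta rule with per-channel decay as I understand it from the paper.

```python
import numpy as np

def kda_recurrent(q, k, v, log_alpha, beta):
    """Token-by-token reference: decay the state, then apply the delta-rule update."""
    T, d_k = k.shape
    S = np.zeros((d_k, v.shape[1]))
    o = np.zeros_like(v)
    for t in range(T):
        S = np.exp(log_alpha[t])[:, None] * S
        S = S + beta[t] * np.outer(k[t], v[t] - k[t] @ S)
        o[t] = q[t] @ S
    return o, S

def kda_chunked(q, k, v, log_alpha, beta, C):
    """Steps 1-5 of the notes with dense NumPy ops (scale factor omitted)."""
    T, d_k = k.shape
    assert T % C == 0, "sketch assumes T is a multiple of the chunk size"
    S = np.zeros((d_k, v.shape[1]))
    o = np.zeros_like(v)
    M = np.tril(np.ones((C, C)))
    M_strict = np.tril(np.ones((C, C)), -1)
    for s in range(0, T, C):
        Q, K, V, b = q[s:s+C], k[s:s+C], v[s:s+C], beta[s:s+C]
        G = np.cumsum(log_alpha[s:s+C], axis=0)                  # Step 1
        K_out = K * np.exp(-G)
        A_qk = ((Q * np.exp(G)) @ K_out.T) * M                   # Step 2
        A_kk0 = b[:, None] * (((K * np.exp(G)) @ K_out.T) * M_strict)
        A = np.linalg.solve(np.eye(C) + A_kk0, np.eye(C))
        U = A @ (b[:, None] * V)                                 # Step 3
        W = A @ (np.exp(G) * b[:, None] * K)
        K_g = np.exp(G[-1] - G) * K
        V_new = U - W @ S                                        # Step 4
        o[s:s+C] = (np.exp(G) * Q) @ S + A_qk @ V_new            # Step 5, uses the previous state
        S = np.exp(G[-1])[:, None] * S + K_g.T @ V_new
    return o, S

rng = np.random.default_rng(0)
T, C, d_k, d_v = 16, 4, 3, 2
q, k = rng.standard_normal((2, T, d_k))
k /= np.linalg.norm(k, axis=-1, keepdims=True)
v = rng.standard_normal((T, d_v))
log_alpha = np.log(rng.uniform(0.7, 1.0, (T, d_k)))
beta = rng.uniform(0.1, 0.9, T)
o_ref, S_ref = kda_recurrent(q, k, v, log_alpha, beta)
o_chk, S_chk = kda_chunked(q, k, v, log_alpha, beta, C)
```

Agreement between the two outputs (np.allclose on o and on the final state) is a cheap check of the derivation, though it says nothing about the numerical behavior of the real kernels.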
Backward Pass
Entry function:
def chunk_kda_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    Aqk: torch.Tensor,
    Akk: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    g: torch.Tensor | None = None,
    g_org: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    chunk_size: int = 64,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    use_gate_in_kernel: bool = False,
    A_log: torch.Tensor | None = None,
    dt_bias: torch.Tensor | None = None,
    disable_recompute: bool = False,
    cp_context: FLACPContext | None = None,
    transpose_state_layout: bool = False,
    **kwargs,
):
    ...  # body omitted in these notes
    return dq, dk, dv, db, dg, dh0, dA, dbias
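Step 1: Recompute the aggregated states U and W
The computation here is identical to the forward pass; the only extra work is that Q_g is also computed and stored.
Corresponding code: fla.ops.kda.wy_fast -> recompute_w_u_fwd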
Step 2: Recompute the hidden state
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
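Step 3: Derive the initial gradients from the output
Here we extract the initial gradients with respect to U and A_qk from the gradient of the output.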
\[\begin{aligned}
\left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}
&=
\mathbf{A_{qk}}_{[t]}^\top \delta \mathbf{O}_{[t]}
\\
\\
\delta \mathbf{A_{qk}}_{[t]}
&=
\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M}
\end{aligned}\]
Corresponding code: fla.ops.kda.chunk_bwd -> chunk_kda_bwd_dAv -> chunk_kda_bwd_kernel_dAv
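The two closed forms of Step 3 can be checked against central finite differences. The sketch below is my own: it treats O = Q_g S + A_qk V_new as the only path from A_qk and V_new to the loss, uses a linear loss sum(O * dO), and relies on the fact that V_new = U - W S, so the gradient with respect to V_new is also the gradient with respect to U.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
Q_g = rng.standard_normal((C, d_k))
S_prev = rng.standard_normal((d_k, d_v))
A_qk = np.tril(rng.standard_normal((C, C)))
V_new = rng.standard_normal((C, d_v))
dO = rng.standard_normal((C, d_v))      # upstream gradient of the loss w.r.t. O

def loss(A_qk, V_new):
    O = Q_g @ S_prev + A_qk @ V_new     # Step 5 of the forward pass
    return np.sum(O * dO)

# Closed forms from the notes.
dU_from_O = A_qk.T @ dO
dA_qk = np.tril(dO @ V_new.T)

def finite_diff(f, x, eps=1e-6):
    """Central differences, one entry at a time."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

dV_fd = finite_diff(lambda x: loss(A_qk, x), V_new)
dA_fd = np.tril(finite_diff(lambda x: loss(x, V_new), A_qk))
```

The mask is applied to the finite-difference result as well, because entries above the diagonal of A_qk are structurally zero and carry no gradient.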
Step 4: Backward inter-chunk recurrence of the hidden state
\[\begin{aligned}
\delta \mathbf{U}_{[t]}
=
\mathbf{K_g}_{[t]} \delta \mathbf{S}_{[t]}^{C}
+
\left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}
\end{aligned}\]
\[\begin{aligned}
\delta \mathbf{S}_{[t]}^{C}
=
\text{Diag}(\exp \boldsymbol{\gamma}_{[t+1]}^C)
\delta \mathbf{S}_{[t+1]}^{C}
+
\mathbf{Q_g}_{[t+1]}^\top \delta \mathbf{O}_{[t+1]}
-
\mathbf{W}_{[t+1]}^\top
\delta \mathbf{U}_{[t+1]}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu
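The backward recurrence can be checked numerically by treating U, W, K_g, Q_g, A_qk and the chunk decays as fixed inputs and differentiating only with respect to the initial state. The sketch below is my own construction, not the chunk_gated_delta_rule_bwd_dhu kernel: the gradient of the final state is taken to be zero, and when walking backwards through a chunk the decay applied to the state gradient is the one of that same chunk's forward update.

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, d_k, d_v = 3, 4, 3, 2                      # N chunks of C tokens
U = rng.standard_normal((N, C, d_v))
W, K_g, Q_g = 0.3 * rng.standard_normal((3, N, C, d_k))
A_qk = np.tril(rng.standard_normal((N, C, C)))   # tril acts on the last two axes
gamma_C = np.log(rng.uniform(0.5, 1.0, (N, d_k)))
dO = rng.standard_normal((N, C, d_v))
S0 = rng.standard_normal((d_k, d_v))

def forward_loss(S0):
    S, total = S0, 0.0
    for t in range(N):
        V_new = U[t] - W[t] @ S
        O = Q_g[t] @ S + A_qk[t] @ V_new
        total += np.sum(O * dO[t])
        S = np.exp(gamma_C[t])[:, None] * S + K_g[t].T @ V_new
    return total

def backward_dS0():
    dS = np.zeros((d_k, d_v))                    # gradient w.r.t. the final state (zero here)
    for t in reversed(range(N)):
        dU = K_g[t] @ dS + A_qk[t].T @ dO[t]     # total gradient reaching V_new of chunk t
        dS = np.exp(gamma_C[t])[:, None] * dS + Q_g[t].T @ dO[t] - W[t].T @ dU
    return dS

eps = 1e-6
dS0_fd = np.zeros_like(S0)
for idx in np.ndindex(S0.shape):
    Sp, Sm = S0.copy(), S0.copy()
    Sp[idx] += eps
    Sm[idx] -= eps
    dS0_fd[idx] = (forward_loss(Sp) - forward_loss(Sm)) / (2 * eps)
dS0 = backward_dS0()
```

Since the loss is affine in S0 once the other tensors are frozen, central differences are exact up to rounding, which makes this a convenient test of the recurrence alone.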
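Step 5: The core fused operator for the WY representation
This is a very large fused operator. Within a single kernel it computes the gradients of W, V, beta, and A_X, along with partial contributions to the gradients of Q, K, and Gamma.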
\[\begin{aligned}
\delta \mathbf{Q}_{[t],part1}
&=
(\exp \boldsymbol{\Gamma}_{[t]}) \odot
\delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
\\
\\
\delta \mathbf{W}_{[t]}
&= - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
\\
\\
\left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{w/o} \mathbf{A}_{[t]} }
&=
\exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \odot
\mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C\top}
\\
\\
\left.\delta \mathbf{K}_{[t],part1}\right|_{\text{from } \mathbf{W}_{[t]} \text{w/o} \mathbf{A}_{[t]} }
&=
\mathbf{A}_{[t]}^\top
\delta \mathbf{W}_{[t]}
\odot (\text{Diag}(\boldsymbol{\beta}_{[t]}) \exp \mathbf{\Gamma}_{[t]} )
\\
\\
\delta \mathbf{V}_{[t]}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{A}_{[t]}^\top
\delta \mathbf{U}_{[t]}
\\
\\
\delta \boldsymbol{\gamma}_{[t]}^C
&=
\delta \exp \boldsymbol{\gamma}_{[t]}^C
\odot \exp \boldsymbol{\gamma}_{[t]}^C
\\ &=
\text{diag}\left(
\mathbf{S}_{[t-1]}^C
\delta \mathbf{S}_{[t]}^{C\top}
\right)
\odot \exp \boldsymbol{\gamma}_{[t]}^C
+
\left(
\left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{w/o} \mathbf{A}_{[t]} }
\odot \mathbf{K}_{[t]}
\right)^\top \boldsymbol{1}
\\
\\
\delta \mathbf{\Gamma}_{[t],part1}
&=
\left.\delta \exp \mathbf{\Gamma}_{[t]}\right|_{\text{w/o extra} \boldsymbol{\gamma}_{[t]}^C}
\odot \exp \mathbf{\Gamma}_{[t]}
\\ &=
\delta \mathbf{Q}_{[t],part1} \odot \mathbf{Q}_{[t]}
-
\left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{w/o} \mathbf{A}_{[t]} }
\odot \mathbf{K}_{[t]}
+
\left.\delta \mathbf{K}_{[t],part1}\right|_{\text{from } \mathbf{W}_{[t]} \text{w/o} \mathbf{A}_{[t]} }
\odot \mathbf{K}_{[t]}
\\ &+
[0,0,...,\delta \boldsymbol{\gamma}_{[t]}^C ]
\\
\\
\delta \boldsymbol{\beta}_{[t],part1}
&=
\text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]}))
=
\text{diag}(
\mathbf{A}_{[t]}^\top
\delta \mathbf{U}_{[t]}
\mathbf{V}_{[t]}^\top
)
+
\text{diag}(
\mathbf{A}_{[t]}^\top
\delta \mathbf{W}_{[t]}
(
(\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{K}_{[t]}
)^\top
)
\\
\\
\delta \boldsymbol{\beta}_{[t]}
&=
\text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]}))
\\ &=
\text{diag}\left(
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \widetilde{\mathbf{A}}_{[t]}
\right)
\odot
\boldsymbol{\beta}_{[t]}^{-1}
+
\text{diag}\left(
\left(
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)
\left(
\overrightarrow{\mathbf{K}_{[t]}} \overleftarrow{\mathbf{K}_{[t]}}^\top
\right)
\right)
\\
\\
\delta \mathbf{A}_{[t]}
&=
\left(
\delta \mathbf{U}_{[t]}
\mathbf{V}_{[t]}^\top
+
\delta \mathbf{W}_{[t]}
((\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{K}_{[t]})^\top
\right)
\text{Diag}(\boldsymbol{\beta}_{[t]})
\\
\\
\delta \mathbf{A_X}_{[t]}
&=
- \mathbf{A}_{[t]}^\top
\delta \mathbf{A}_{[t]}
\mathbf{A}_{[t]}^\top \odot \mathbf{M}_{-1}
\end{aligned}\]
Corresponding code: fla.ops.kda.chunk_bwd -> chunk_kda_bwd_wy_dqkg_fused
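The last formula of Step 5 is the standard derivative of a matrix inverse, restricted to the strictly lower triangle. As a check of that one identity only (not of the fused kernel), the sketch below compares it against finite differences; the random matrix X plays the role of A_kk0 and the linear loss sum(A * dA) is an assumption of the example.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 4
X = np.tril(rng.standard_normal((C, C)), -1)    # plays the role of A_kk0
dA = rng.standard_normal((C, C))                # upstream gradient w.r.t. A

def loss(X):
    A = np.linalg.inv(np.eye(C) + X)
    return np.sum(A * dA)

A = np.linalg.inv(np.eye(C) + X)
dX = np.tril(-A.T @ dA @ A.T, -1)               # closed form from the notes

eps = 1e-6
dX_fd = np.zeros_like(X)
for i in range(C):
    for j in range(i):                          # strictly lower entries only
        Xp, Xm = X.copy(), X.copy()
        Xp[i, j] += eps
        Xm[i, j] -= eps
        dX_fd[i, j] = (loss(Xp) - loss(Xm)) / (2 * eps)
```

Only the strictly lower entries are perturbed because those are the only free entries of A_kk0.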
Step 6: Intra-chunk local gradients
\[\begin{aligned}
\delta \mathbf{Q}_{[t],part2}
&=
\delta \mathbf{A_{qk}}_{[t]}
\left(
\exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]})
\odot \mathbf{K}_{[t]}
\right)
\odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})
\\
\\
\delta \mathbf{K}_{[t],part2}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \mathbf{A_X}_{[t]}
\left(
\mathbf{K}_{[t]} \odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]})
\right)
\odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})
\\&+
\delta \mathbf{A_{qk}}_{[t]}^\top
\left( \mathbf{Q}_{[t]} \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})\right)
\odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]})
\\&+
\delta \mathbf{A_X}_{[t]}^\top
\text{Diag}(\boldsymbol{\beta}_{[t]})
\left( \mathbf{K}_{[t]} \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})\right)
\odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]})
\\&=\left.\delta \mathbf{K}_{[t]}\right|_{\leftarrow \text{from } \mathbf{A}_{[t]} }
+
\left(
\left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from } \mathbf{O}_{[t]} \text{w/o } \mathbf{A}_{[t]} }
+
\left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from }\mathbf{A}_{[t]} }
\right)
\\
\\
\delta \mathbf{\Gamma}_{[t],part2}
&=
\delta \exp \mathbf{\Gamma}_{[t]}
\odot \exp \mathbf{\Gamma}_{[t]}
\\&=
\delta \mathbf{Q}_{[t],part2} \odot \mathbf{Q}_{[t]}
+
\left(
\left.\delta \mathbf{K}_{[t]}\right|_{\leftarrow \text{from } \mathbf{A}_{[t]} }
-
\left(
\left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from } \mathbf{O}_{[t]} \text{w/o } \mathbf{A}_{[t]} }
+
\left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from }\mathbf{A}_{[t]} }
\right)
\right)
\odot \mathbf{K}_{[t]}
\end{aligned}\]
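Corresponding code: fla.ops.kda.chunk_intra -> chunk_kda_bwd_intra -> chunk_kda_bwd_kernel_intra
Remark: In chunk_kda_bwd_kernel_intra, the diagonal part of the computation is carefully optimized using the SAFE_GATE mechanism. Computing the exponentials needed for the intra-chunk causal dependencies naively can easily overflow in FP16 when the gates fluctuate strongly; the usual workaround is to fall back to scalar/vector-level for loops, which is much slower. SAFE_GATE instead gathers a local reference point and normalizes the exponents against it (similar in effect to a local max-trick). This keeps the computation numerically stable at the mathematical level, so the kernel can safely replace the slow vector loops with hardware-accelerated tl.dot. The design gains some speed while preserving numerical precision.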
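The effect of normalizing against a reference point can be reproduced in NumPy. The sketch below is only an illustration of the idea and is not taken from the kernel: the sub-chunk size of 16, the constant log-gate of -5, the use of float32, and the choice of the first row of the sub-chunk as the reference are all assumptions on my part.

```python
import numpy as np

def scores_naive(q, k, Gamma):
    """Factor exp(Gamma_i - Gamma_j) as exp(Gamma_i) * exp(-Gamma_j)."""
    return np.tril((q * np.exp(Gamma)) @ (k * np.exp(-Gamma)).T)

def scores_anchored(q, k, Gamma):
    """Same product, but both exponents are shifted by a reference row first."""
    ref = Gamma[:1]                       # assumed anchor: first row of the sub-chunk
    return np.tril((q * np.exp(Gamma - ref)) @ (k * np.exp(ref - Gamma)).T)

rng = np.random.default_rng(0)
B, d_k = 16, 4                            # hypothetical sub-chunk size
q, k = rng.standard_normal((2, B, d_k)).astype(np.float32)
# Rows 48..63 of a 64-token chunk whose log-gate is -5 at every step.
steps = np.arange(49, 65, dtype=np.float32)
Gamma = (-5.0 * steps)[:, None] * np.ones((1, d_k), dtype=np.float32)

with np.errstate(over="ignore", invalid="ignore"):
    naive = scores_naive(q, k, Gamma)     # exp(+245 .. +320) overflows float32
anchored = scores_anchored(q, k, Gamma)   # shifted exponents stay within [-75, 75]

# float64 reference built directly from the pairwise differences.
G64 = Gamma.astype(np.float64)
diff = G64[:, None, :] - G64[None, :, :]
exact = np.tril(np.einsum("id,jd,ijd->ij", q.astype(np.float64),
                          k.astype(np.float64), np.exp(diff)))
```

Both functions are plain matrix products, which is the form that maps onto tl.dot; only the anchored one stays finite here, and it agrees with the float64 reference.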
Step 7: Undo the cumulative sum with a suffix sum
\[\begin{aligned}
\delta \log \boldsymbol{\alpha}_{[t]}
&=
\text{suffix\_cumsum}(\delta \mathbf{\Gamma}_{[t]})
\end{aligned}\]
Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
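Step 7 works because the adjoint of a chunk-local prefix sum is a chunk-local suffix sum. A minimal NumPy sketch of that fact follows; the helper name chunk_local_suffix_cumsum and the [T, d_k] layout are assumptions made for the example and do not describe the library's interface.

```python
import numpy as np

def chunk_local_suffix_cumsum(d_Gamma, chunk_size):
    """Reverse (suffix) cumulative sum inside each chunk; d_Gamma: [T, d_k]."""
    T, d_k = d_Gamma.shape
    assert T % chunk_size == 0, "sketch assumes T is a multiple of chunk_size"
    g = d_Gamma.reshape(T // chunk_size, chunk_size, d_k)
    return np.flip(np.cumsum(np.flip(g, axis=1), axis=1), axis=1).reshape(T, d_k)

rng = np.random.default_rng(0)
T, C, d_k = 8, 4, 3
log_alpha = rng.standard_normal((T, d_k))
d_Gamma = rng.standard_normal((T, d_k))          # gradient w.r.t. the cumulative gate
Gamma = np.cumsum(log_alpha.reshape(T // C, C, d_k), axis=1).reshape(T, d_k)
d_log_alpha = chunk_local_suffix_cumsum(d_Gamma, C)
ones_demo = chunk_local_suffix_cumsum(np.ones((4, 1)), 2).ravel()
```

The adjoint identity sum(Gamma * d_Gamma) == sum(log_alpha * d_log_alpha) holds for any inputs, which is exactly the statement that the suffix sum backpropagates through the prefix sum.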