
Reading Notes: Implementation Notes for KDA in FLA

Paper: https://arxiv.org/pdf/2510.26692
Code: https://github.com/fla-org/flash-linear-attention
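Disclaimer: These are personal reading notes. Some of the derivations are my own and may contain mistakes; I recommend cross-checking them against the Triton source code.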

Forward Pass

Entry point:

```python
from __future__ import annotations  # keeps annotations such as FLACPContext unevaluated

import torch


def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens_cpu: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    chunk_size: int = 64,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    use_gate_in_kernel: bool = False,
    A_log: torch.Tensor | None = None,
    dt_bias: torch.Tensor | None = None,
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    cp_context: FLACPContext | None = None,
    transpose_state_layout: bool = False,
):
    # Signature excerpt only: the function body is elided in these notes.
    return o, final_state, g, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state
```

Notation

\[\begin{aligned} \overleftarrow{\square_{[t]}} &= \square_{[t]} \odot \exp\mathbf{\Gamma}_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \square_{[t]} \oslash \exp\mathbf{\Gamma}_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}\} \end{aligned}\]

Step 1: Compute the cumulative gates in log space

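Here the variable gamma actually denotes a cumulative sum in log space (i.e., log gamma); this convention holds throughout these notes. Γ_[t] (shape [C, d_k]) stacks γ_[t]^1, ..., γ_[t]^C as its rows, so exp Γ_[t] is the within-chunk cumulative product of the gates.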

\[\begin{aligned} \boldsymbol{\gamma}_{[t]}^r &= \sum_{i=tC+1}^{tC+r} \log \boldsymbol{\alpha}_i \in \mathbb{R}^{d_k} \end{aligned}\]

Code: fla.ops.utils.cumsum -> chunk_local_cumsum
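A minimal NumPy sketch of Step 1 for a single sequence and head, assuming g = log α is laid out as [T, d_k]. `chunk_local_cumsum_ref` is a hypothetical name; FLA's `chunk_local_cumsum` runs as a Triton kernel on batched, multi-head tensors, and this sketch does not reproduce its interface.

```python
import numpy as np


def chunk_local_cumsum_ref(g: np.ndarray, chunk_size: int) -> np.ndarray:
    """Chunk-local prefix sum of log-gates g with shape [T, d_k].

    Row r of chunk t holds gamma_[t]^r = sum_{i=tC+1}^{tC+r} log alpha_i,
    so the running sum restarts at every chunk boundary.
    """
    out = np.empty_like(g)
    for start in range(0, g.shape[0], chunk_size):
        end = min(start + chunk_size, g.shape[0])  # the last chunk may be shorter
        out[start:end] = np.cumsum(g[start:end], axis=0)
    return out


rng = np.random.default_rng(0)
alpha = rng.uniform(0.5, 1.0, size=(5, 2))  # per-channel gates in (0, 1], T = 5, d_k = 2
g = np.log(alpha)                            # log alpha <= 0
gamma = chunk_local_cumsum_ref(g, chunk_size=2)
```

Exponentiating `gamma` recovers the within-chunk product of the gates, which is why the later steps can work with exp Γ directly.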

Step 2: Solve the lower-triangular system for the intra-chunk attention

\[\begin{aligned} \mathbf{A_{qk}}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} ,\quad \mathbf{A_{kk0}}_{[t]} = \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \\ \\ \mathbf{A}_{[t]} &= (\mathbf{I} + \mathbf{A_{kk0}}_{[t]})^{-1} \end{aligned}\]

Code: fla.ops.kda.chunk_intra -> chunk_kda_fwd_intra -> chunk_kda_fwd_kernel_intra_sub_chunk / chunk_kda_fwd_kernel_inter_solve_fused
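Since I + A_kk0 is unit lower-triangular, A can be obtained by plain forward substitution. The NumPy sketch below (float64, one chunk; `intra_chunk_matrices` and `unit_lower_inverse` are hypothetical names) builds A_qk and A_kk0 exactly as in the formulas and inverts I + A_kk0 row by row. The FLA kernels named above use their own sub-chunk blocking scheme, which this textbook substitution does not reproduce.

```python
import numpy as np


def intra_chunk_matrices(q, k, beta, gamma):
    """Step 2 for one chunk: q, k [C, d_k], beta [C], gamma [C, d_k] (log space)."""
    C = q.shape[0]
    M = np.tril(np.ones((C, C)))                 # causal mask including the diagonal
    M_strict = np.tril(np.ones((C, C)), -1)      # M_{-1}: strictly lower triangular
    k_right = k * np.exp(-gamma)                 # →K = K ⊙ exp(-Γ)
    A_qk = ((q * np.exp(gamma)) @ k_right.T) * M
    A_kk0 = beta[:, None] * (((k * np.exp(gamma)) @ k_right.T) * M_strict)
    return A_qk, A_kk0


def unit_lower_inverse(L):
    """(I + L)^{-1} for a strictly lower-triangular L, by forward substitution."""
    C = L.shape[0]
    A = np.eye(C)
    for i in range(1, C):
        # row i of (I + L) A = I gives A[i] = e_i - sum_{j<i} L[i, j] A[j]
        A[i] -= L[i, :i] @ A[:i]
    return A


rng = np.random.default_rng(0)
C, d_k = 8, 4
q = rng.standard_normal((C, d_k))
k = rng.standard_normal((C, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)    # L2-normalized keys
beta = rng.uniform(0.1, 0.9, C)
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, (C, d_k))), axis=0)

A_qk, A_kk0 = intra_chunk_matrices(q, k, beta, gamma)
A = unit_lower_inverse(A_kk0)
```

The result is itself unit lower-triangular, so each row only needs the rows above it; that dependency structure is what blocked GPU solvers exploit.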

Step 3: Compute the aggregated terms U and W

\[\begin{aligned} \mathbf{U}_{[t]} &= \mathbf{A}_{[t]} \left( \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]} \right) \\ \\ \mathbf{W}_{[t]} &= \mathbf{A}_{[t]} \left( (\exp \boldsymbol{\Gamma}_{[t]}) \odot \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \right) \\ \\ \mathbf{Q_g}_{[t]} &= (\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{Q}_{[t]} \\ \\ \mathbf{K_g}_{[t]} &= (\exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]})) \odot \mathbf{K}_{[t]} \end{aligned}\]

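To save memory bandwidth during the training forward pass, Q_g is not computed and written to memory here when activation recomputation is enabled; instead, it is computed on the fly at the start of the backward pass.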

Code: fla.ops.kda.wy_fast -> recompute_w_u_fwd
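To see why U and W are defined this way, here is a NumPy sketch (float64, one chunk, one head, no scale; `wy_terms` and `delta_rule_chunk` are hypothetical helper names). It assumes the per-token KDA recurrence S_t = (I - β_t k_t k_t^T) Diag(α_t) S_{t-1} + β_t k_t v_t^T and checks that V_new = U - W S_[t-1]^C reproduces the per-token delta-rule corrections, and that K_g carries each token's contribution to the end of the chunk.

```python
import numpy as np


def wy_terms(k, v, beta, gamma):
    """Step 3 for one chunk: U, W and K_g from k [C, d_k], v [C, d_v], beta [C], gamma [C, d_k]."""
    C = k.shape[0]
    M_strict = np.tril(np.ones((C, C)), -1)
    A_kk0 = beta[:, None] * (((k * np.exp(gamma)) @ (k * np.exp(-gamma)).T) * M_strict)
    A = np.linalg.inv(np.eye(C) + A_kk0)
    U = A @ (beta[:, None] * v)
    W = A @ (beta[:, None] * k * np.exp(gamma))
    K_g = np.exp(gamma[-1][None, :] - gamma) * k
    return U, W, K_g


def delta_rule_chunk(k, v, beta, g, S0):
    """Token-by-token delta rule inside one chunk, S in [d_k, d_v] layout.

    Returns the corrections beta_r (v_r - S~_r^T k_r) with S~_r = Diag(alpha_r) S_{r-1},
    plus the state after the chunk.
    """
    S = S0.copy()
    corrections = []
    for r in range(k.shape[0]):
        S = np.exp(g[r])[:, None] * S             # Diag(alpha_r) S_{r-1}
        u_r = beta[r] * (v[r] - S.T @ k[r])       # delta-rule correction
        S = S + np.outer(k[r], u_r)
        corrections.append(u_r)
    return np.stack(corrections), S


rng = np.random.default_rng(0)
C, d_k, d_v = 8, 4, 3
k = rng.standard_normal((C, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, C)
g = np.log(rng.uniform(0.8, 1.0, (C, d_k)))      # log alpha
gamma = np.cumsum(g, axis=0)                      # Step 1 for a single chunk
S0 = rng.standard_normal((d_k, d_v))

U, W, K_g = wy_terms(k, v, beta, gamma)
V_new = U - W @ S0
corrections, S_end = delta_rule_chunk(k, v, beta, g, S0)
```

Q_g is omitted because it is just the elementwise product exp Γ ⊙ Q.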

Step 4: Inter-chunk state recurrence

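Here the hidden state S_[t]^C uses the [d_k, d_v] memory layout, so in the low-level code it is transposed relative to the purely mathematical derivation used earlier.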

\[\begin{aligned} \mathbf{V}_{[t],new} &= \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C} \\ \\ \mathbf{S}_{[t]}^{C} &= \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}^C) \mathbf{S}_{[t-1]}^{C} + \mathbf{K_g}_{[t]}^\top \mathbf{V}_{[t],new} \end{aligned}\]

Code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
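A small NumPy sketch of the Step 4 update in the [d_k, d_v] layout, next to the same update written for a transposed [d_v, d_k] layout. Because γ_[t]^C lives in log space with one entry per key channel, the decay enters as a scaling by exp(γ_[t]^C) along the d_k axis. The transposed variant only illustrates what a layout change implies; I do not know how the `transpose_state_layout` flag in the signature is actually implemented, and the helper names are hypothetical.

```python
import numpy as np


def state_step_kv(S_prev, U, W, Kg, gamma_C):
    """Step 4 with S stored as [d_k, d_v]."""
    v_new = U - W @ S_prev
    S = np.exp(gamma_C)[:, None] * S_prev + Kg.T @ v_new     # decay scales the d_k rows
    return v_new, S


def state_step_vk(S_prev_T, U, W, Kg, gamma_C):
    """The same update with S stored transposed, as [d_v, d_k]."""
    v_new = U - (S_prev_T @ W.T).T
    S_T = S_prev_T * np.exp(gamma_C)[None, :] + v_new.T @ Kg  # decay now scales columns
    return v_new, S_T


rng = np.random.default_rng(0)
C, d_k, d_v = 8, 4, 3
S_prev = rng.standard_normal((d_k, d_v))      # S_[t-1]^C
U = rng.standard_normal((C, d_v))
W = rng.standard_normal((C, d_k))
Kg = rng.standard_normal((C, d_k))
gamma_C = -rng.uniform(0.0, 1.0, d_k)          # gamma_[t]^C in log space

v_new, S = state_step_kv(S_prev, U, W, Kg, gamma_C)
v_new_T, S_T = state_step_vk(S_prev.T, U, W, Kg, gamma_C)
```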

Step 5: Compute the final output

\[\begin{aligned} \mathbf{O}_{[t]} = \left((\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{Q}_{[t]}\right) \mathbf{S}_{[t-1]}^{C} + \mathbf{A_{qk}}_{[t]} \mathbf{V}_{[t],new} \end{aligned}\]

Code: fla.ops.gla.chunk -> chunk_gla_fwd_o_gk
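Putting Steps 1 to 5 together, the following float64 NumPy sketch checks the chunkwise forward against the token-by-token recurrence S_t = (I - β_t k_t k_t^T) Diag(α_t) S_{t-1} + β_t k_t v_t^T, o_t = S_t^T q_t. It assumes a single sequence and head, no `scale`, no `cu_seqlens`, and none of the kernels' numerical tricks; `naive_kda` and `chunk_kda_ref` are hypothetical names, not FLA functions.

```python
import numpy as np


def naive_kda(q, k, v, g, beta, S0):
    """Token-by-token KDA recurrence, S in [d_k, d_v] layout, o_t = S_t^T q_t."""
    S = S0.copy()
    o = np.zeros((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        S = np.exp(g[t])[:, None] * S                          # Diag(alpha_t) S_{t-1}
        S = S + np.outer(k[t], beta[t] * (v[t] - S.T @ k[t]))  # delta-rule update
        o[t] = S.T @ q[t]
    return o, S


def chunk_kda_ref(q, k, v, g, beta, S0, C):
    """Steps 1-5 in chunkwise form (float64, no scale, no stability tricks)."""
    T = q.shape[0]
    S = S0.copy()
    o = np.zeros((T, v.shape[1]))
    for s in range(0, T, C):
        sl = slice(s, min(s + C, T))
        qc, kc, vc, bc = q[sl], k[sl], v[sl], beta[sl]
        n = qc.shape[0]
        gam = np.cumsum(g[sl], axis=0)                                 # Step 1
        M = np.tril(np.ones((n, n)))
        Qg, k_right = qc * np.exp(gam), kc * np.exp(-gam)
        A_qk = (Qg @ k_right.T) * M                                    # Step 2
        A_kk0 = bc[:, None] * (((kc * np.exp(gam)) @ k_right.T) * np.tril(M, -1))
        A = np.linalg.inv(np.eye(n) + A_kk0)
        U = A @ (bc[:, None] * vc)                                     # Step 3
        W = A @ (bc[:, None] * kc * np.exp(gam))
        Kg = kc * np.exp(gam[-1][None, :] - gam)
        v_new = U - W @ S                                              # Step 4
        o[sl] = Qg @ S + A_qk @ v_new                                  # Step 5 (uses S_[t-1])
        S = np.exp(gam[-1])[:, None] * S + Kg.T @ v_new
    return o, S


rng = np.random.default_rng(0)
T, C, d_k, d_v = 20, 8, 4, 3                       # the last chunk has only 4 tokens
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((T, d_v))
beta = rng.uniform(0.1, 0.9, T)
g = np.log(rng.uniform(0.8, 1.0, (T, d_k)))
S0 = rng.standard_normal((d_k, d_v))

o_ref, S_ref = naive_kda(q, k, v, g, beta, S0)
o_chunk, S_chunk = chunk_kda_ref(q, k, v, g, beta, S0, C)
```

The explicit `np.exp(-gam)` is harmless here only because the gates are mild and the chunks are short; in real kernels that factor is exactly what can overflow.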

Backward Pass

Entry point:

```python
from __future__ import annotations  # keeps annotations such as FLACPContext unevaluated

import torch


def chunk_kda_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    Aqk: torch.Tensor,
    Akk: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    g: torch.Tensor | None = None,
    g_org: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    chunk_size: int = 64,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    use_gate_in_kernel: bool = False,
    A_log: torch.Tensor | None = None,
    dt_bias: torch.Tensor | None = None,
    disable_recompute: bool = False,
    cp_context: FLACPContext | None = None,
    transpose_state_layout: bool = False,
    **kwargs,
):
    # Signature excerpt only: the function body is elided in these notes.
    return dq, dk, dv, db, dg, dh0, dA, dbias
```

Step 1: Recompute the aggregated terms U and W

The computation in this step is identical to the forward pass; the only additional operation is that Q_g is computed and stored.

Code: fla.ops.kda.wy_fast -> recompute_w_u_fwd

Step 2: Recompute the hidden state

Code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h

Step 3: Derive the initial gradients from the output

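Here we separate out the initial derivatives with respect to U and A_qk from the gradient of the output. Since U enters O only through V_[t],new = U_[t] - W_[t] S_[t-1]^C, the gradient with respect to V_[t],new is also the gradient with respect to U_[t].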

\[\begin{aligned} \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} &= \mathbf{A_{qk}}_{[t]}^\top \delta \mathbf{O}_{[t]} \\ \\ \delta \mathbf{A_{qk}}_{[t]} &= \delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M} \end{aligned}\]

Code: fla.ops.kda.chunk_bwd -> chunk_kda_bwd_dAv -> chunk_kda_bwd_kernel_dAv
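The intra-chunk part of O is bilinear in A_qk and V_new, so Step 3 can be verified with central finite differences. A NumPy sketch with random inputs (names are illustrative); the inter-chunk term Q_g S_[t-1]^C depends on neither quantity, so it is left out.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_v = 6, 3
M = np.tril(np.ones((C, C)))
A_qk = rng.standard_normal((C, C)) * M       # causal intra-chunk attention
V_new = rng.standard_normal((C, d_v))
dO = rng.standard_normal((C, d_v))

# Step 3, analytic: gradients of the intra-chunk term A_qk @ V_new
dU_from_O = A_qk.T @ dO                      # reaches U through V_new = U - W S
dA_qk = (dO @ V_new.T) * M


def loss():
    return np.sum(dO * (A_qk @ V_new))       # <dO, A_qk V_new>


def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. x (modified in place, then restored)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + eps
        f_plus = f()
        x[idx] = old - eps
        f_minus = f()
        x[idx] = old
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


num_dU = num_grad(loss, V_new)
num_dA = num_grad(loss, A_qk) * M            # only the causal entries of A_qk are free
```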

Step 4: Inter-chunk backward recurrence of the hidden state

\[\begin{aligned} \delta \mathbf{U}_{[t]} = \mathbf{K_g}_{[t]} \delta \mathbf{S}_{[t]}^{C} + \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} \end{aligned}\]
\[\begin{aligned} \delta \mathbf{S}_{[t]}^{C} = \text{Diag}(\exp \boldsymbol{\gamma}_{[t+1]}^C) \delta \mathbf{S}_{[t+1]}^{C} + \mathbf{Q_g}_{[t+1]}^\top \delta \mathbf{O}_{[t+1]} - \mathbf{W}_{[t+1]}^\top \delta \mathbf{U}_{[t+1]} \end{aligned}\]

Code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu
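The backward recurrence can be checked in isolation. This NumPy sketch (hypothetical names `forward_states` / `backward_states`) treats Q_g, K_g, W, U, A_qk and γ^C as fixed random inputs, runs the forward recurrence of Steps 4 and 5, and compares the gradient that the backward recurrence produces for the initial state against central finite differences. The loss is affine in the initial state, so the finite differences are exact up to rounding.

```python
import numpy as np


def forward_states(S0, U, W, Qg, Kg, Aqk, gC):
    """Steps 4-5 across chunks; returns per-chunk outputs and the final state."""
    S = S0
    outs = []
    for t in range(len(U)):
        v_new = U[t] - W[t] @ S                          # V_[t],new
        outs.append(Qg[t] @ S + Aqk[t] @ v_new)          # O_[t] uses S_[t-1]
        S = np.exp(gC[t])[:, None] * S + Kg[t].T @ v_new
    return outs, S


def backward_states(dO, dht, W, Qg, Kg, Aqk, gC):
    """Reverse recurrence over chunks; returns the gradient w.r.t. the initial state."""
    dS = dht                                             # gradient w.r.t. the final state
    for t in reversed(range(len(dO))):
        dU = Kg[t] @ dS + Aqk[t].T @ dO[t]               # dU_[t]
        # dS holds dS_[t] here; replace it with dS_[t-1]
        dS = np.exp(gC[t])[:, None] * dS + Qg[t].T @ dO[t] - W[t].T @ dU
    return dS


rng = np.random.default_rng(0)
NT, C, d_k, d_v = 3, 4, 3, 2
U = rng.standard_normal((NT, C, d_v))
W = 0.5 * rng.standard_normal((NT, C, d_k))
Qg = rng.standard_normal((NT, C, d_k))
Kg = 0.5 * rng.standard_normal((NT, C, d_k))
Aqk = np.tril(rng.standard_normal((NT, C, C)))
gC = -rng.uniform(0.0, 1.0, (NT, d_k))                   # gamma_[t]^C per chunk (log space)
dO = rng.standard_normal((NT, C, d_v))
dht = rng.standard_normal((d_k, d_v))
S0 = rng.standard_normal((d_k, d_v))


def loss(S_init):
    outs, S_final = forward_states(S_init, U, W, Qg, Kg, Aqk, gC)
    return sum(np.sum(dO[t] * outs[t]) for t in range(NT)) + np.sum(dht * S_final)


dh0 = backward_states(dO, dht, W, Qg, Kg, Aqk, gC)
eps = 1e-3
num_dh0 = np.zeros_like(S0)
for idx in np.ndindex(S0.shape):
    E = np.zeros_like(S0)
    E[idx] = eps
    num_dh0[idx] = (loss(S0 + E) - loss(S0 - E)) / (2 * eps)
```

Note that the decay applied when stepping back through chunk t is exp γ_[t]^C, the gate of the chunk that consumed the state.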

Step 5: The core fused kernel of the WY representation

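This is an extremely large fused kernel. Within a single kernel it computes the gradients of W, V, beta and A_X, along with partial contributions to the gradients of Q, K and Gamma.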

\[\begin{aligned} \delta \mathbf{Q}_{[t],part1} &= (\exp \boldsymbol{\Gamma}_{[t]}) \odot \left( \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \right) \\ \\ \delta \mathbf{W}_{[t]} &= - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \\ \\ \left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \mathbf{A}_{[t]} } &= \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C\top} \\ \\ \left.\delta \mathbf{K}_{[t],part1}\right|_{\text{from } \mathbf{W}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } &= \mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} \odot (\text{Diag}(\boldsymbol{\beta}_{[t]}) \exp \mathbf{\Gamma}_{[t]} ) \\ \\ \delta \mathbf{V}_{[t]} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A}_{[t]}^\top \delta \mathbf{U}_{[t]} \\ \\ \delta \boldsymbol{\gamma}_{[t]}^C &= \delta \exp \boldsymbol{\gamma}_{[t]}^C \odot \exp \boldsymbol{\gamma}_{[t]}^C + \boldsymbol{1}^\top \left( \left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \mathbf{A}_{[t]} } \odot \mathbf{K}_{[t]} \right) \\ &= \text{diag}\left( \mathbf{S}_{[t-1]}^C \delta \mathbf{S}_{[t]}^{C\top} \right) \odot \exp \boldsymbol{\gamma}_{[t]}^C + \boldsymbol{1}^\top \left( \left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \mathbf{A}_{[t]} } \odot \mathbf{K}_{[t]} \right) \\ \\ \delta \mathbf{\Gamma}_{[t],part1} &= \left.\delta \exp \mathbf{\Gamma}_{[t]}\right|_{\text{w/o extra } \boldsymbol{\gamma}_{[t]}^C} \odot \exp \mathbf{\Gamma}_{[t]} + [\mathbf{0},\mathbf{0},\dots,\delta \boldsymbol{\gamma}_{[t]}^C ]^\top \\ &= \delta \mathbf{Q}_{[t],part1} \odot \mathbf{Q}_{[t]} - \left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \mathbf{A}_{[t]} } \odot \mathbf{K}_{[t]} + \left.\delta \mathbf{K}_{[t],part1}\right|_{\text{from } \mathbf{W}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } \odot \mathbf{K}_{[t]} \\ &+ [\mathbf{0},\mathbf{0},\dots,\delta \boldsymbol{\gamma}_{[t]}^C ]^\top \\ \\ \delta \boldsymbol{\beta}_{[t],part1} &= \text{diag}( \mathbf{A}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top ) + \text{diag}( \mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} ( \exp \boldsymbol{\Gamma}_{[t]} \odot \mathbf{K}_{[t]} )^\top ) \\ \\ \delta \boldsymbol{\beta}_{[t]} &= \text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]})) \\ &= \delta \boldsymbol{\beta}_{[t],part1} + \text{diag}\left( \left( \delta \mathbf{A_X}_{[t]} \odot \mathbf{M}_{-1} \right) \left( \overrightarrow{\mathbf{K}_{[t]}} \overleftarrow{\mathbf{K}_{[t]}}^\top \right) \right) \\ \\ \delta \mathbf{A}_{[t]} &= \left( \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top + \delta \mathbf{W}_{[t]} (\exp \boldsymbol{\Gamma}_{[t]} \odot \mathbf{K}_{[t]})^\top \right) \text{Diag}(\boldsymbol{\beta}_{[t]}) \\ \\ \delta \mathbf{A_X}_{[t]} &= - \mathbf{A}_{[t]}^\top \delta \mathbf{A}_{[t]} \mathbf{A}_{[t]}^\top \odot \mathbf{M}_{-1} \end{aligned}\]

Code: fla.ops.kda.chunk_bwd -> chunk_kda_bwd_wy_dqkg_fused
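Several of the WY gradients can be checked numerically in a few lines. This NumPy sketch (hypothetical names, float64, one chunk) holds A_kk0 as an independent input, so it covers δV, the gradient of K ⊙ exp Γ through W, δβ_part1 and δA_X, but not the extra β term that arises because A_kk0 itself depends on β. The array `k_dec` stands in for K ⊙ exp Γ; multiplying its gradient by exp Γ gives the "from W" contribution to δK.

```python
import numpy as np


def wy_forward(A_kk0, beta, v, k_dec):
    """U = A Diag(beta) V and W = A Diag(beta) (K ⊙ exp Γ), with A = (I + A_kk0)^{-1}."""
    A = np.linalg.inv(np.eye(len(beta)) + A_kk0)
    return A, A @ (beta[:, None] * v), A @ (beta[:, None] * k_dec)


def wy_backward(A, beta, v, k_dec, dU, dW, M_strict):
    dv = beta[:, None] * (A.T @ dU)                                # dV = Diag(β) Aᵀ dU
    dk_dec = beta[:, None] * (A.T @ dW)                            # d(K ⊙ exp Γ)
    dbeta = np.diag(A.T @ dU @ v.T) + np.diag(A.T @ dW @ k_dec.T)  # dβ_part1
    dA = (dU @ v.T + dW @ k_dec.T) * beta[None, :]                 # (...) Diag(β)
    dA_X = -(A.T @ dA @ A.T) * M_strict                            # gradient w.r.t. A_kk0
    return dv, dk_dec, dbeta, dA_X


def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. x (modified in place, then restored)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + eps
        f_plus = f()
        x[idx] = old - eps
        f_minus = f()
        x[idx] = old
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
C, d_k, d_v = 6, 4, 3
M_strict = np.tril(np.ones((C, C)), -1)
A_kk0 = 0.3 * rng.standard_normal((C, C)) * M_strict
beta = rng.uniform(0.1, 0.9, C)
v = rng.standard_normal((C, d_v))
k_dec = rng.standard_normal((C, d_k))      # plays the role of K ⊙ exp(Γ)
dU = rng.standard_normal((C, d_v))
dW = rng.standard_normal((C, d_k))


def loss():
    _, U, W = wy_forward(A_kk0, beta, v, k_dec)
    return np.sum(dU * U) + np.sum(dW * W)


A, _, _ = wy_forward(A_kk0, beta, v, k_dec)
dv, dk_dec, dbeta, dA_X = wy_backward(A, beta, v, k_dec, dU, dW, M_strict)
```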

Step 6: Intra-chunk local gradients

\[\begin{aligned} \delta \mathbf{Q}_{[t],part2} &= \delta \mathbf{A_{qk}}_{[t]} \left( \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{K}_{[t]} \right) \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top}) \\ \\ \delta \mathbf{K}_{[t],part2} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \delta \mathbf{A_X}_{[t]} \left( \mathbf{K}_{[t]} \odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \right) \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top}) \\&+ \delta \mathbf{A_{qk}}_{[t]}^\top \left( \mathbf{Q}_{[t]} \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})\right) \odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \\&+ \delta \mathbf{A_X}_{[t]}^\top \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})\right) \odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \\&=\left.\delta \mathbf{K}_{[t]}\right|_{\leftarrow \text{from } \mathbf{A}_{[t]} } + \left( \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } + \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from }\mathbf{A}_{[t]} } \right) \\ \\ \delta \mathbf{\Gamma}_{[t],part2} &= \delta \exp \mathbf{\Gamma}_{[t]} \odot \exp \mathbf{\Gamma}_{[t]} \\&= \delta \mathbf{Q}_{[t],part2} \odot \mathbf{Q}_{[t]} + \left( \left.\delta \mathbf{K}_{[t]}\right|_{\leftarrow \text{from } \mathbf{A}_{[t]} } - \left( \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } + \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from }\mathbf{A}_{[t]} } \right) \right) \odot \mathbf{K}_{[t]} \end{aligned}\]

Code: fla.ops.kda.chunk_intra -> chunk_kda_bwd_intra -> chunk_kda_bwd_kernel_intra
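Step 6 can be verified against finite differences of the scalar <δA_qk, A_qk> + <δA_X, A_kk0>, with δA_qk and δA_X treated as fixed random inputs. In this NumPy sketch (hypothetical names, float64, one chunk) the rebasing on γ^C cancels algebraically, so the numerical gradient can be taken directly with respect to q, k and Γ. The code writes the right-factor A_kk0 term as δA_X^T Diag(β), which is what the chain rule gives for A_kk0 = Diag(β)(...).

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k = 6, 4
q = rng.standard_normal((C, d_k))
k = rng.standard_normal((C, d_k))
beta = rng.uniform(0.1, 0.9, C)
Gam = np.cumsum(np.log(rng.uniform(0.7, 1.0, (C, d_k))), axis=0)   # Γ_[t] in log space
M = np.tril(np.ones((C, C)))
M_strict = np.tril(M, -1)
dAqk = rng.standard_normal((C, C)) * M          # plays the role of dA_qk from Step 3
dAX = rng.standard_normal((C, C)) * M_strict    # plays the role of dA_X from Step 5


def loss(q, k, Gam):
    """<dA_qk, A_qk> + <dA_X, A_kk0>; its gradient is what Step 6 computes."""
    left, right = np.exp(Gam), np.exp(-Gam)
    A_qk = ((q * left) @ (k * right).T) * M
    A_kk0 = beta[:, None] * (((k * left) @ (k * right).T) * M_strict)
    return np.sum(dAqk * A_qk) + np.sum(dAX * A_kk0)


def intra_bwd(q, k, Gam):
    gC = Gam[-1][None, :]                       # 1 γ^{C⊤}: cancels, only rebases the exponents
    e_fwd, e_bwd = np.exp(Gam - gC), np.exp(gC - Gam)
    dq = (dAqk @ (k * e_bwd)) * e_fwd
    dk_left = ((beta[:, None] * dAX) @ (k * e_bwd)) * e_fwd     # K as ←K in A_kk0
    dk_right = (                                                # K as →K in A_qk and A_kk0
        (dAqk.T @ (q * e_fwd)) * e_bwd
        + ((dAX.T * beta[None, :]) @ (k * e_fwd)) * e_bwd
    )
    dGam = dq * q + (dk_left - dk_right) * k
    return dq, dk_left + dk_right, dGam


def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. x (modified in place, then restored)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + eps
        f_plus = f()
        x[idx] = old - eps
        f_minus = f()
        x[idx] = old
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad


dq, dk, dGam = intra_bwd(q, k, Gam)
```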

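Remark: In chunk_kda_bwd_kernel_intra, the diagonal blocks are carefully optimized with the SAFE_GATE mechanism. Naively computing the exponential terms required by the intra-chunk causal dependencies easily overflows in FP16 when the gates vary strongly; the usual compromise is to fall back to scalar- or vector-level for loops, which slows the kernel down considerably. SAFE_GATE instead uses a gather to pick a local reference point and normalizes the exponentials against it (similar in spirit to a local max trick). Mathematically this keeps the exponents bounded and therefore numerically stable, which lets the kernel safely replace the inefficient vector loops with hardware-accelerated tl.dot calls. The design preserves numerical accuracy while gaining some speed.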
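The remark can be illustrated with a toy NumPy experiment. This is not the SAFE_GATE kernel: the 16-row block, the use of the block's middle row as the reference point, fp16 operands with fp32 accumulation, and a per-step log-gate bounded below by -1 are all assumptions made for illustration. Splitting exp(Γ_i - Γ_j) into exp(Γ_i) and exp(-Γ_j) lets one matmul compute the block, but deep inside a chunk those factors underflow and overflow in fp16. Rebasing both on a reference row r, as exp(Γ_i - Γ_r) and exp(Γ_r - Γ_j), keeps each factor within fp16 range when the block is short and the gates are bounded.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d = 16, 4
c = np.linspace(0.5, 1.0, d)                       # per-step log-gate is -c (bounded below by -1)
Gam = -np.arange(1, 65)[:, None] * c[None, :]      # chunk-local Γ for a 64-token chunk
G = Gam[48:64]                                     # a 16-row diagonal block at the end of the chunk
q = rng.standard_normal((B, d))
k = rng.standard_normal((B, d))
causal = np.tril(np.ones((B, B)))


def fp16_dot(a, b):
    """a @ b.T with fp16 operands and fp32 accumulation (a stand-in for tl.dot)."""
    a32 = a.astype(np.float16).astype(np.float32)
    b32 = b.astype(np.float16).astype(np.float32)
    return (a32[:, None, :] * b32[None, :, :]).sum(axis=-1)


# float64 reference: P[i, j] = sum_d q[i, d] k[j, d] exp(G[i, d] - G[j, d]) for j <= i
P_ref = np.einsum("id,jd,ijd->ij", q, k, np.exp(G[:, None, :] - G[None, :, :])) * causal

with np.errstate(over="ignore", invalid="ignore"):
    # naive split: exp(G) underflows to 0 and exp(-G) overflows to inf in fp16, so 0 * inf = nan
    P_naive = fp16_dot(q * np.exp(G), k * np.exp(-G)) * causal

r = B // 2                                         # hypothetical reference row: middle of the block
P_safe = fp16_dot(q * np.exp(G - G[r]), k * np.exp(G[r] - G)) * causal
```

With the reference row in the middle, every exponent in the rebased factors stays within about 8 in magnitude here, so both operands fit in fp16 and the product can go through a single matrix multiply.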

Step 7: Inverting the cumulative sum with a suffix sum

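\[\begin{aligned} \delta \log \boldsymbol{\alpha}_{[t]} &= \text{suffix\_cumsum}(\delta\mathbf{\Gamma}_{[t]}), \quad \delta\mathbf{\Gamma}_{[t]} = \delta\mathbf{\Gamma}_{[t],part1} + \delta\mathbf{\Gamma}_{[t],part2} \end{aligned}\]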

Code: fla.ops.utils.cumsum -> chunk_local_cumsum
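Because Γ_[t] is a prefix sum of log α within each chunk, its adjoint is a reversed (suffix) sum within each chunk. A NumPy sketch with hypothetical helper names, for a single sequence and head laid out as [T, d_k]:

```python
import numpy as np


def chunk_local_cumsum_ref(x, chunk_size):
    """Forward: prefix sum restarted at every chunk boundary."""
    out = np.empty_like(x)
    for s in range(0, x.shape[0], chunk_size):
        out[s:s + chunk_size] = np.cumsum(x[s:s + chunk_size], axis=0)
    return out


def chunk_local_suffix_sum_ref(dx, chunk_size):
    """Backward of chunk_local_cumsum_ref: reversed cumsum inside each chunk."""
    out = np.empty_like(dx)
    for s in range(0, dx.shape[0], chunk_size):
        blk = dx[s:s + chunk_size]
        out[s:s + chunk_size] = np.cumsum(blk[::-1], axis=0)[::-1]
    return out


rng = np.random.default_rng(0)
g = rng.standard_normal((10, 3))       # log alpha (values are arbitrary for this check)
dGam = rng.standard_normal((10, 3))    # upstream gradient w.r.t. Γ
dg = chunk_local_suffix_sum_ref(dGam, chunk_size=4)
```

The check that matters is the adjoint identity <δΓ, cumsum(g)> = <suffix_sum(δΓ), g>: token i receives the gradient of every later position in its own chunk, and nothing from other chunks.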
