
Reading Notes: How GDN Is Implemented in FLA

Paper: https://arxiv.org/pdf/2412.06464
Code: https://github.com/fla-org/flash-linear-attention
Disclaimer: these are personal reading notes. Any derivation of my own may contain errors.

1. Forward Pass

Entry-point signature:

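from __future__ import annotations  # FLACPContext is defined inside FLA; postponed annotations avoid evaluating it here

import torch


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    cp_context: FLACPContext | None = None,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = True,
    transpose_state_layout: bool = False,
):
    # Body omitted in these notes; it runs Steps 1-5 below.
    # g: chunk-local cumulative log gate (Step 1); A: (I + A0)^{-1} from Step 2, kept for backward;
    # o: output (Step 5); final_state: state after the last chunk.
    return g, o, A, final_state, initial_state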

Notation

\[\begin{aligned} \overleftarrow{\square_{[t]}} &= \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \square_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \square_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{W}\} \end{aligned}\]

M is the lower-triangular causal mask with the diagonal included, and M_{-1} is the strictly lower-triangular mask.

Step 1: Compute the cumulative gate in the log domain

Note that γ here is the log-domain γ defined earlier, so read the formulas that follow with this in mind.

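\[\begin{aligned} \boldsymbol{\gamma}_{[t]} = \left[\gamma_{[t]}^{1}, \dots, \gamma_{[t]}^{C}\right]^\top, \quad \gamma_{[t]}^{r} = \sum_{i=1}^{r} \log \alpha_{[t]}^{i}, \quad \boldsymbol{\gamma}_{[t]}^{C} = \gamma_{[t]}^{C} \end{aligned}\]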

Code: fla.ops.utils.cumsum -> chunk_local_cumsum
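Here is a minimal NumPy sketch of Step 1. It assumes the sequence length is a multiple of the chunk size and ignores variable-length batches (cu_seqlens). chunk_local_cumsum_ref is a hypothetical reference helper, not FLA's Triton kernel. Its only job is to show that the prefix sum restarts at every chunk boundary.

```python
import numpy as np

def chunk_local_cumsum_ref(g: np.ndarray, chunk_size: int) -> np.ndarray:
    """Chunk-local inclusive prefix sum over time (reference sketch, not FLA's kernel).

    g: [T] per-token log gates log(alpha_t). T must be a multiple of chunk_size here.
    """
    T = g.shape[0]
    assert T % chunk_size == 0, "this sketch does not handle a ragged last chunk"
    # reshape to [num_chunks, chunk_size] so the running sum restarts in every chunk
    return np.cumsum(g.reshape(-1, chunk_size), axis=1).reshape(T)

log_alpha = np.log(np.full(8, 0.5))           # alpha_t = 0.5 for every token
gamma = chunk_local_cumsum_ref(log_alpha, 4)  # two chunks of size C = 4
gamma_C = gamma[3::4]                         # last entry of each chunk, i.e. gamma^C
```

With α = 0.5 everywhere, exp γ runs 0.5, 0.25, 0.125, 0.0625 inside each chunk and then starts over. This is why every decay factor in the later steps stays local to one chunk.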

Step 2: Solve the lower-triangular linear system

\[\begin{aligned} \widetilde{\mathbf{A_0}}_{[t]} = \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \\ \\ \widetilde{\mathbf{A}}_{[t]} = (\mathbf{I} + \widetilde{\mathbf{A_0}}_{[t]})^{-1} \end{aligned}\]

Code: fla.ops.gated_delta_rule.chunk_fwd -> chunk_gated_delta_rule_fwd_intra -> chunk_gated_delta_rule_fwd_kkt_solve_kernel
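The sketch below builds Ã_[t] for a single chunk in NumPy. It uses the log-domain convention, so ←K →K^T becomes K K^T ⊙ exp(γ_i − γ_j). The name kkt_solve_ref is hypothetical. The inverse comes from plain row-by-row forward substitution, which is valid because I + Ã0 is unit lower triangular. I am not claiming this matches how the FLA kernel blocks the computation.

```python
import numpy as np

def kkt_solve_ref(k, beta, gamma):
    """Step 2 for one chunk (sketch). Returns (A_tilde, A0) with
    A0 = Diag(beta) (K K^T * exp(gamma_i - gamma_j) * M_{-1}) and A_tilde = (I + A0)^{-1}."""
    C = k.shape[0]
    decay = np.exp(gamma[:, None] - gamma[None, :])   # <-K ->K^T = K K^T (.) decay
    strict_lower = np.tril(np.ones((C, C)), k=-1)      # M_{-1}
    A0 = beta[:, None] * (k @ k.T) * decay * strict_lower
    L = np.eye(C) + A0                                 # unit lower triangular
    A_tilde = np.zeros((C, C))
    for i in range(C):
        # row i of L X = I:  X_i = e_i - sum_{j<i} L_ij X_j  (forward substitution)
        A_tilde[i] = np.eye(C)[i] - L[i, :i] @ A_tilde[:i]
    return A_tilde, A0

rng = np.random.default_rng(0)
C, d_k = 4, 3
k = rng.standard_normal((C, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)   # normalized keys keep the example well conditioned
beta = rng.uniform(0.1, 0.9, C)
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, C)))
A_tilde, A0 = kkt_solve_ref(k, beta, gamma)
```

The inverse of a lower-triangular matrix is itself lower triangular, so Ã_[t] only has nonzeros on and below the diagonal.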

Step 3: Compute U and the left-scaled W

\[\begin{aligned} \mathbf{U}_{[t]} &= \widetilde{\mathbf{A}}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]} \\ \\ \overleftarrow{\mathbf{W}_{[t]}} &= \widetilde{\mathbf{A}}_{[t]} \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \end{aligned}\]

Code: fla.ops.gated_delta_rule.wy_fast -> recompute_w_u_fwd
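The NumPy sketch continues with Step 3 for one chunk, with no batch or head axes. The helper recompute_w_u_fwd_ref is hypothetical. Each of U and ←W is a single matrix product with Ã_[t]. Equivalently, U solves (I + Ã0) X = Diag(β) V and ←W solves (I + Ã0) X = Diag(β ⊙ exp γ) K.

```python
import numpy as np

def recompute_w_u_fwd_ref(k, v, beta, gamma, A_tilde):
    """Step 3 for one chunk: U = A~ Diag(beta) V,  <-W = A~ Diag(exp(gamma)) Diag(beta) K."""
    u = A_tilde @ (beta[:, None] * v)
    w = A_tilde @ ((beta * np.exp(gamma))[:, None] * k)
    return w, u

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
k = rng.standard_normal((C, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, C)
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, C)))
# A0 and A~ as in Step 2 (np.linalg.inv is fine for a 4x4 example)
A0 = beta[:, None] * (k @ k.T) * np.exp(gamma[:, None] - gamma[None, :]) * np.tril(np.ones((C, C)), -1)
A_tilde = np.linalg.inv(np.eye(C) + A0)
w, u = recompute_w_u_fwd_ref(k, v, beta, gamma, A_tilde)
```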

Step 4: Compute the hidden states recurrently

Here S_[t]^C uses the [d_k, d_v] layout, so it is the transpose of the state in the earlier derivation.

\[\begin{aligned} \mathbf{V}_{[t],new} &= \mathbf{U}_{[t]} - \overleftarrow{\mathbf{W}_{[t]}} \mathbf{S}_{[t-1]}^{C} \\ \\ \mathbf{S}_{[t]}^{C} &= \exp(\boldsymbol{\gamma}_{[t]}^C) \mathbf{S}_{[t-1]}^{C} + \mathbf{K}_{[t]}^\top \left( \text{Diag}\left(\exp(\boldsymbol{\gamma}_{[t]}^C - \boldsymbol{\gamma}_{[t]})\right) \mathbf{V}_{[t],new} \right) \end{aligned}\]

Code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
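This small NumPy comparison checks that the chunked recurrence of Steps 1-4 reproduces the token-by-token gated delta rule. It makes several assumptions. The recurrence is taken as S_r = α_r (I − β_r k_r k_r^T) S_{r−1} + β_r k_r v_r^T in the [d_k, d_v] layout used above. The sequence length is divisible by the chunk size, and there is a single head. Both function names are made up for this sketch, and the chunked loop repeats Steps 1-3 inline.

```python
import numpy as np

def naive_state(k, v, log_alpha, beta, S0):
    """Token-by-token gated delta rule, S in [d_k, d_v] layout."""
    S = S0.copy()
    for r in range(k.shape[0]):
        S = np.exp(log_alpha[r]) * (S - beta[r] * np.outer(k[r], k[r] @ S)) \
            + beta[r] * np.outer(k[r], v[r])
    return S

def chunked_state(k, v, log_alpha, beta, S0, C):
    """Steps 1-4, one chunk at a time; returns the state after the last chunk."""
    S = S0.copy()
    strict = np.tril(np.ones((C, C)), -1)
    for s in range(0, k.shape[0], C):
        kc, vc, bc = k[s:s + C], v[s:s + C], beta[s:s + C]
        g = np.cumsum(log_alpha[s:s + C])                                    # Step 1
        A0 = bc[:, None] * (kc @ kc.T) * np.exp(g[:, None] - g[None, :]) * strict
        A_tilde = np.linalg.inv(np.eye(C) + A0)                              # Step 2
        u = A_tilde @ (bc[:, None] * vc)                                     # Step 3
        w = A_tilde @ ((bc * np.exp(g))[:, None] * kc)
        v_new = u - w @ S                                                    # Step 4
        S = np.exp(g[-1]) * S + kc.T @ (np.exp(g[-1] - g)[:, None] * v_new)
    return S

rng = np.random.default_rng(0)
T, C, d_k, d_v = 8, 4, 3, 2
k = rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((T, d_v))
log_alpha = np.log(rng.uniform(0.8, 1.0, T))
beta = rng.uniform(0.1, 0.9, T)
S0 = rng.standard_normal((d_k, d_v))
S_naive = naive_state(k, v, log_alpha, beta, S0)
S_chunk = chunked_state(k, v, log_alpha, beta, S0, C)
```

The two states agree up to floating-point error. Inside a chunk, V_[t],new absorbs all the delta-rule corrections, so the state only needs one update per chunk.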

Step 5: Compute the final output

\[\begin{aligned} \mathbf{O}_{[t]} = \text{Diag}\left(\exp \boldsymbol{\gamma}_{[t]} \right) \mathbf{Q}_{[t]} \mathbf{S}_{[t-1]}^{C} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],new} \end{aligned}\]

Code: fla.ops.common.chunk_o -> chunk_fwd_o
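The following NumPy sketch of Step 5 uses hypothetical helper names. It compares the matrix form against the per-row sum o_r = exp(γ_r) S_[t−1]^T q_r + Σ_{j≤r} exp(γ_r − γ_j)(q_r · k_j) v_new,j, which is what the arrows and the inclusive mask M encode. The scale argument from the signature is left out, as in the formula.

```python
import numpy as np

def chunk_fwd_o_ref(q, k, gamma, S_prev, v_new):
    """O = Diag(e^gamma) Q S_prev + (<-Q ->K^T (.) M) V_new for one chunk."""
    C = q.shape[0]
    q_left = np.exp(gamma)[:, None] * q          # <-Q
    k_right = np.exp(-gamma)[:, None] * k        # ->K
    M = np.tril(np.ones((C, C)))                 # causal mask, diagonal included
    return q_left @ S_prev + ((q_left @ k_right.T) * M) @ v_new

def chunk_fwd_o_loop(q, k, gamma, S_prev, v_new):
    """Same quantity written as explicit per-row sums."""
    C = q.shape[0]
    out = np.zeros((C, v_new.shape[1]))
    for r in range(C):
        out[r] = np.exp(gamma[r]) * (S_prev.T @ q[r])
        for j in range(r + 1):
            out[r] += np.exp(gamma[r] - gamma[j]) * (q[r] @ k[j]) * v_new[j]
    return out

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
q, k = rng.standard_normal((C, d_k)), rng.standard_normal((C, d_k))
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, C)))
S_prev, v_new = rng.standard_normal((d_k, d_v)), rng.standard_normal((C, d_v))
o_mat = chunk_fwd_o_ref(q, k, gamma, S_prev, v_new)
o_loop = chunk_fwd_o_loop(q, k, gamma, S_prev, v_new)
```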

Remark: judging from the code, Equation 9 in 2412.06464v3 is indeed a typo.

2. Backward Pass

Entry-point signature:

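from __future__ import annotations  # FLACPContext is defined inside FLA; postponed annotations avoid evaluating it here

import torch


def chunk_gated_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    cp_context: FLACPContext | None = None,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = True,
    transpose_state_layout: bool = False,
):
    # Body omitted in these notes; it runs backward Steps 1-7 below.
    # A: the (I + A0)^{-1} stored by the forward pass; do / dht: gradients w.r.t. o and final_state.
    # db: gradient w.r.t. beta; dg: gradient w.r.t. the log gate; dh0: gradient w.r.t. initial_state.
    return dq, dk, dv, db, dg, dh0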

Step 1: Reuse the stored inverse

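The matrix A_[t] = (I + A0_[t])^{-1} was already stored by the forward pass (it is the A argument). Only W and U are recomputed from it, exactly as in forward Step 3.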

Code: fla.ops.gated_delta_rule.wy_fast -> recompute_w_u_fwd

Step 2: Recompute the hidden states

This step is identical to forward Step 4.

Code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h

Step 3: Compute the part of dU that comes from O

\[\begin{aligned} \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} = \left( \overrightarrow{\mathbf{K}_{[t]}} \overleftarrow{\mathbf{Q}_{[t]}}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]} \end{aligned}\]

Code: fla.ops.common.chunk_o -> chunk_bwd_dv_local
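Here is a NumPy sketch of this step for one chunk. The name bwd_dv_local_ref is hypothetical. O_[t] depends linearly on V_[t],new through P = ←Q →K^T ⊙ M, so the gradient is P^T δO. The expression above is that transpose written out. Building ←Q and →K separately, as done here, is fine for a small example. With long chunks, however, exp(−γ) can overflow, and a practical kernel would more likely form exp(γ_i − γ_j) directly.

```python
import numpy as np

def bwd_dv_local_ref(q, k, gamma, do):
    """dV_new contribution of O for one chunk: (->K <-Q^T (.) M^T) dO."""
    C = q.shape[0]
    q_left = np.exp(gamma)[:, None] * q            # <-Q
    k_right = np.exp(-gamma)[:, None] * k          # ->K
    M = np.tril(np.ones((C, C)))                   # causal mask, diagonal included
    return ((k_right @ q_left.T) * M.T) @ do

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
q, k = rng.standard_normal((C, d_k)), rng.standard_normal((C, d_k))
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, C)))
do, v_new = rng.standard_normal((C, d_v)), rng.standard_normal((C, d_v))
dv_local = bwd_dv_local_ref(q, k, gamma, do)
# forward intra-chunk operator P = <-Q ->K^T (.) M, built with exp(gamma_i - gamma_j) directly
P = (q @ k.T) * np.exp(gamma[:, None] - gamma[None, :]) * np.tril(np.ones((C, C)))
```

The adjoint identity ⟨δO, P V_new⟩ = ⟨P^T δO, V_new⟩ is exactly what makes this the gradient.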

Step 4: Backward recurrence for dU and the state gradient

\[\begin{aligned} \delta \mathbf{U}_{[t]} = \text{Diag}\left(\exp(\boldsymbol{\gamma}_{[t]}^C - \boldsymbol{\gamma}_{[t]})\right) \mathbf{K}_{[t]} \delta \mathbf{S}_{[t]}^{C} + \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} \end{aligned}\]
\[\begin{aligned} \delta \mathbf{S}_{[t]}^{C} = \exp(\boldsymbol{\gamma}_{[t+1]}^C) \delta \mathbf{S}_{[t+1]}^{C} + \overleftarrow{\mathbf{Q}_{[t+1]}}^\top \delta \mathbf{O}_{[t+1]} - \overleftarrow{\mathbf{W}_{[t+1]}}^\top \delta \mathbf{U}_{[t+1]} \end{aligned}\]

Code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu
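The reverse sweep can be checked numerically. This NumPy sketch assumes one head, T divisible by C, and hypothetical function names. It runs the chunked forward with loss L = ⟨O, δO⟩ + ⟨S_final, δS_final⟩, then steps back from the last chunk to the first. At each step the decay on δS is the total gate of the chunk being stepped back through. The resulting gradient of the initial state (dh0 in the signature) is then compared with central finite differences. In the loop, du is the total gradient of V_[t],new, so it includes both terms of the first equation.

```python
import numpy as np

def forward(q, k, v, log_alpha, beta, S0, C):
    """Chunked forward (Steps 1-5). Returns O, final state, and per-chunk (q, k, <-W, gamma)."""
    S, outs, cache = S0, [], []
    strict, causal = np.tril(np.ones((C, C)), -1), np.tril(np.ones((C, C)))
    for s in range(0, q.shape[0], C):
        qc, kc, vc, bc = q[s:s + C], k[s:s + C], v[s:s + C], beta[s:s + C]
        g = np.cumsum(log_alpha[s:s + C])
        D = np.exp(g[:, None] - g[None, :])
        A_tilde = np.linalg.inv(np.eye(C) + bc[:, None] * (kc @ kc.T) * D * strict)
        u = A_tilde @ (bc[:, None] * vc)
        w = A_tilde @ ((bc * np.exp(g))[:, None] * kc)
        v_new = u - w @ S
        outs.append(np.exp(g)[:, None] * (qc @ S) + ((qc @ kc.T) * D * causal) @ v_new)
        S = np.exp(g[-1]) * S + kc.T @ (np.exp(g[-1] - g)[:, None] * v_new)
        cache.append((qc, kc, w, g))
    return np.concatenate(outs), S, cache

def bwd_dhu_ref(cache, do, dht, C):
    """Reverse sweep of backward Step 4; returns dL/dS0."""
    dS = dht.copy()
    causal = np.tril(np.ones((C, C)))
    for t in reversed(range(len(cache))):
        qc, kc, w, g = cache[t]
        doc = do[t * C:(t + 1) * C]
        P = (qc @ kc.T) * np.exp(g[:, None] - g[None, :]) * causal
        du = np.exp(g[-1] - g)[:, None] * (kc @ dS) + P.T @ doc        # total dU of chunk t
        dS = np.exp(g[-1]) * dS + (np.exp(g)[:, None] * qc).T @ doc - w.T @ du
    return dS

rng = np.random.default_rng(0)
T, C, d_k, d_v = 8, 4, 3, 2
q, k = rng.standard_normal((T, d_k)), rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((T, d_v))
log_alpha = np.log(rng.uniform(0.8, 1.0, T))
beta = rng.uniform(0.1, 0.9, T)
S0 = rng.standard_normal((d_k, d_v))
do, dht = rng.standard_normal((T, d_v)), rng.standard_normal((d_k, d_v))

_, _, cache = forward(q, k, v, log_alpha, beta, S0, C)
dh0 = bwd_dhu_ref(cache, do, dht, C)

def loss(S_init):
    O, S_T, _ = forward(q, k, v, log_alpha, beta, S_init, C)
    return np.sum(O * do) + np.sum(S_T * dht)

eps = 1e-6
dh0_fd = np.zeros_like(S0)
for idx in np.ndindex(S0.shape):
    E = np.zeros_like(S0)
    E[idx] = eps
    dh0_fd[idx] = (loss(S0 + E) - loss(S0 - E)) / (2 * eps)
```

The loss is linear in S0, so the central difference is exact up to rounding.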

Step 5: Compute dB, dQ, part of dK, dW, and part of dγ

\[\begin{aligned} \delta \mathbf{B}_{[t]} &= \delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M} \\ \\ \delta \mathbf{Q}_{[t]} &= \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} + \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]} \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \right) \mathbf{K}_{[t]} \\ \\ \delta \mathbf{K}_{[t], \text{part1}} &= \text{Diag}(\exp(\boldsymbol{\gamma}_{[t]}^C - \boldsymbol{\gamma}_{[t]})) \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C\top} + \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]} \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \right)^\top \mathbf{Q}_{[t]} \\&= \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}} + \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}} \\ \\ \delta \overleftarrow{\mathbf{W}_{[t]}} &= - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \end{aligned}\]
\[\begin{aligned} \delta \boldsymbol{\gamma}_{[t]}^C &= \delta\left(\exp \boldsymbol{\gamma}_{[t]}^C\right) \exp \boldsymbol{\gamma}_{[t]}^C = \text{Tr}\left( \delta \mathbf{S}_{[t]}^{C\top} \mathbf{S}_{[t-1]}^C \right) \exp \boldsymbol{\gamma}_{[t]}^C + \text{Tr}\left( \left( \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}} \right) \mathbf{K}_{[t]}^\top\right) \\ \\ &= \text{Tr}\left( \delta \mathbf{S}_{[t]}^{C\top} \mathbf{S}_{[t-1]}^C + \delta \mathbf{S}_{[t]}^{C\top} \mathbf{K}_{[t]}^\top \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \mathbf{V}_{[t],new} \right) \exp \boldsymbol{\gamma}_{[t]}^C \\ \\ \delta \boldsymbol{\gamma}_{[t],\text{part1}} &= \delta\left(\exp \boldsymbol{\gamma}_{[t]}\right)_{\text{part1}} \odot \exp\boldsymbol{\gamma}_{[t]} \\&= \text{diag}\left( \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \right) \mathbf{Q}_{[t]}^\top \right) \\&- \text{diag}\left( \mathbf{K}_{[t]} \left( \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}} \right)^\top \right) \\&+ \text{diag}\left( \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]} \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \right) \left( \mathbf{Q}_{[t]}\mathbf{K}_{[t]}^\top \right)^\top \right) \\&- \text{diag}\left( \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \delta \mathbf{B}_{[t]} \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \right)^\top \left( \mathbf{Q}_{[t]}\mathbf{K}_{[t]}^\top \right) \right) \\&+ [0,0,\dots,0,\delta \boldsymbol{\gamma}_{[t]}^C]^\top \\&= \left( \left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}} + \left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{V}_{[t],new}} + \left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \boldsymbol{\gamma}_{[t]}^C} \right) \odot \exp\boldsymbol{\gamma}_{[t]} \end{aligned}\]

Code: fla.ops.common.chunk_o -> chunk_bwd_dqkwg
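Here is a NumPy finite-difference check of Step 5 for one chunk. It follows the "w/o V_new" convention above, so S_[t−1] and V_[t],new are held fixed. γ_[t] is treated as a free vector, because the cumulative-sum structure is undone in Step 7. The loss is L = ⟨O_[t], δO_[t]⟩ + ⟨S_[t], δS_[t]⟩. All names are hypothetical. The γ^C term is added to the last entry because γ^C is the same variable as the last entry of γ.

```python
import numpy as np

def loss_fn(q, k, gamma, S_prev, v_new, do, dS):
    """L = <O, dO> + <S_new, dS> for one chunk, with S_prev and V_new held fixed."""
    C = q.shape[0]
    D = np.exp(gamma[:, None] - gamma[None, :])
    O = np.exp(gamma)[:, None] * (q @ S_prev) + ((q @ k.T) * D * np.tril(np.ones((C, C)))) @ v_new
    S_new = np.exp(gamma[-1]) * S_prev + k.T @ (np.exp(gamma[-1] - gamma)[:, None] * v_new)
    return np.sum(O * do) + np.sum(S_new * dS)

def bwd_dqkg_ref(q, k, gamma, S_prev, v_new, do, dS):
    C = q.shape[0]
    eg = np.exp(gamma)
    dB = (do @ v_new.T) * np.tril(np.ones((C, C)))              # dB = dO V_new^T (.) M
    dB_dec = eg[:, None] * dB / eg[None, :]                     # Diag(e^g) dB Diag(e^g)^{-1}
    dO_S = eg[:, None] * (do @ S_prev.T)                        # Diag(e^g) dO S_prev^T
    dq = dO_S + dB_dec @ k
    dk_S = np.exp(gamma[-1] - gamma)[:, None] * (v_new @ dS.T)  # dK from S, w/o V_new
    dk = dk_S + dB_dec.T @ q
    qk = q @ k.T
    dg = (np.sum(dO_S * q, axis=1) - np.sum(k * dk_S, axis=1)
          + np.sum(dB_dec * qk, axis=1) - np.sum(dB_dec * qk, axis=0))
    dg[-1] += np.sum(dS * S_prev) * eg[-1] + np.sum(dk_S * k)   # gamma^C term
    return dq, dk, dg

def numgrad(f, x, eps=1e-6):
    """Central finite differences; perturbs x in place and restores it."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + eps
        fp = f()
        x[idx] = old - eps
        fm = f()
        x[idx] = old
        grad[idx] = (fp - fm) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
q, k = rng.standard_normal((C, d_k)), rng.standard_normal((C, d_k))
gamma = np.cumsum(np.log(rng.uniform(0.7, 1.0, C)))
S_prev, dS = rng.standard_normal((d_k, d_v)), rng.standard_normal((d_k, d_v))
v_new, do = rng.standard_normal((C, d_v)), rng.standard_normal((C, d_v))

dq, dk, dg = bwd_dqkg_ref(q, k, gamma, S_prev, v_new, do, dS)
f = lambda: loss_fn(q, k, gamma, S_prev, v_new, do, dS)
```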

Step 6: Compute dK, dV, dβ, and the remaining part of dγ

\[\begin{aligned} \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \overleftarrow{\mathbf{W}_{[t]}} \text{ w/o } \widetilde{\mathbf{A}}_{[t]} } &= \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \widetilde{\mathbf{A}}_{[t]}^\top \delta \overleftarrow{\mathbf{W}_{[t]}} \right) \\ \\ \delta \mathbf{V}_{[t]} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \widetilde{\mathbf{A}}_{[t]}^\top \delta \mathbf{U}_{[t]} \\ \\ \delta \widetilde{\mathbf{A}}_{[t]} &= \delta \overleftarrow{\mathbf{W}_{[t]}} \left(\text{Diag}(\boldsymbol{\beta}_{[t]} \odot \exp \boldsymbol{\gamma}_{[t]})\mathbf{K}_{[t]}\right)^\top + \delta \mathbf{U}_{[t]} \left(\text{Diag}(\boldsymbol{\beta}_{[t]})\mathbf{V}_{[t]} \right)^\top \\ \\ \delta \widetilde{\mathbf{A_x}}_{[t]} &= - \left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \widetilde{\mathbf{A}}_{[t]}^\top \delta \widetilde{\mathbf{A}}_{[t]} \widetilde{\mathbf{A}}_{[t]}^\top \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})^{-1} \right) \odot \mathbf{M}_{-1} \\&= \text{Diag}(\boldsymbol{\beta}_{[t]})^{-1} \left.\delta (\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top)\right|_{\text{from } \widetilde{\mathbf{A_0}}_{[t]}} \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \widetilde{\mathbf{A_0}}_{[t]}} &= \left( \text{Diag}(\boldsymbol{\beta}_{[t]}) \delta \widetilde{\mathbf{A_x}}_{[t]} \right) \mathbf{K}_{[t]} + \left( \left( \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \right)^\top \delta \widetilde{\mathbf{A_x}}_{[t]} \right)^\top \\ \\ \delta \boldsymbol{\beta}_{[t]} &= \text{diag}\left( \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \widetilde{\mathbf{A}}_{[t]}^\top \delta \overleftarrow{\mathbf{W}_{[t]}} \mathbf{K}_{[t]}^\top \right) + \text{diag}\left( \widetilde{\mathbf{A}}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top \right) + \text{diag}\left( \delta \widetilde{\mathbf{A_x}}_{[t]} \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top\right) \\ \\ \delta \boldsymbol{\gamma}_{[t],\text{part2}} &= \text{diag}\left( \left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \overleftarrow{\mathbf{W}_{[t]}} \text{ w/o } \widetilde{\mathbf{A}}_{[t]} } \right) \odot \exp \boldsymbol{\gamma}_{[t]} + \text{diag}\left( \left.\delta \text{Diag}(\exp \boldsymbol{\gamma}_{[t]})\right|_{\text{from } \widetilde{\mathbf{A}}_{[t]} } \right) \odot \exp \boldsymbol{\gamma}_{[t]} \\&= \text{diag}\left( \widetilde{\mathbf{A}}_{[t]}^\top \delta \overleftarrow{\mathbf{W}_{[t]}} \left( \text{Diag}(\boldsymbol{\beta}_{[t]}) \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}) \mathbf{K}_{[t]} \right)^\top \right) \\&+ \text{diag}\left( \delta \widetilde{\mathbf{A_x}}_{[t]} \left( \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \right)^\top \right) - \text{diag}\left( \left( \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \right)^\top \delta \widetilde{\mathbf{A_x}}_{[t]} \right) \end{aligned}\]

Code: fla.ops.gated_delta_rule.wy_fast -> prepare_wy_repr_bwd
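The same kind of check works for Step 6. This NumPy sketch covers one chunk and uses hypothetical names. With loss L = ⟨←W, δ←W⟩ + ⟨U, δU⟩, it computes the gradients of K, V, β and γ from the formulas above and compares them against central finite differences. The only nontrivial identity involved is d(X^{-1}) = −X^{-1} dX X^{-1}, which is the source of the −Ã^T δÃ Ã^T factor.

```python
import numpy as np

def wy_fwd(k, v, beta, gamma):
    """Steps 2-3 for one chunk: returns <-W, U and A~ = (I + A0)^{-1}."""
    C = k.shape[0]
    D = np.exp(gamma[:, None] - gamma[None, :])
    A0 = beta[:, None] * (k @ k.T) * D * np.tril(np.ones((C, C)), -1)
    A = np.linalg.inv(np.eye(C) + A0)
    w = A @ ((beta * np.exp(gamma))[:, None] * k)
    u = A @ (beta[:, None] * v)
    return w, u, A

def wy_bwd_ref(k, v, beta, gamma, dw, du):
    """Step 6 formulas: gradients of <W, dw> + <U, du> w.r.t. k, v, beta, gamma."""
    C = k.shape[0]
    _, _, A = wy_fwd(k, v, beta, gamma)
    eg, kk = np.exp(gamma), k @ k.T
    At_dw, At_du = A.T @ dw, A.T @ du
    dA = dw @ ((beta * eg)[:, None] * k).T + du @ (beta[:, None] * v).T
    # d(X^{-1}) = -X^{-1} dX X^{-1}; multiply by exp(gamma_i - gamma_j) and keep the strict lower part
    dAx = -(eg[:, None] * (A.T @ dA @ A.T) / eg[None, :]) * np.tril(np.ones((C, C)), -1)
    dk = ((beta * eg)[:, None] * At_dw                                   # from <-W, w/o A~
          + (beta[:, None] * dAx) @ k + dAx.T @ (beta[:, None] * k))     # from A0
    dv = beta[:, None] * At_du
    dbeta = eg * np.sum(At_dw * k, axis=1) + np.sum(At_du * v, axis=1) + np.sum(dAx * kk, axis=1)
    bkk = beta[:, None] * kk
    dg = (np.sum(At_dw * ((beta * eg)[:, None] * k), axis=1)
          + np.sum(dAx * bkk, axis=1) - np.sum(dAx * bkk, axis=0))
    return dk, dv, dbeta, dg

def numgrad(f, x, eps=1e-6):
    """Central finite differences; perturbs x in place and restores it."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        old = x[idx]
        x[idx] = old + eps
        fp = f()
        x[idx] = old - eps
        fm = f()
        x[idx] = old
        grad[idx] = (fp - fm) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
k = rng.standard_normal((C, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, C)
gamma = np.cumsum(np.log(rng.uniform(0.8, 1.0, C)))
dw, du = rng.standard_normal((C, d_k)), rng.standard_normal((C, d_v))

def loss():
    w, u, _ = wy_fwd(k, v, beta, gamma)
    return np.sum(w * dw) + np.sum(u * du)

dk, dv, dbeta, dg = wy_bwd_ref(k, v, beta, gamma, dw, du)
```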

Step 7: Recover the gradient of log α

\[\begin{aligned} \delta \log \boldsymbol{\alpha}_{[t]} &= \text{suffix\_cumsum}(\delta\boldsymbol{\gamma}_{[t]}) \end{aligned}\]

Code: fla.ops.utils.cumsum -> chunk_local_cumsum
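Step 7 relies on a simple fact: the transpose of a chunk-local prefix sum is a chunk-local suffix sum. The NumPy sketch below uses a reverse flag that belongs to the sketch and is not necessarily FLA's signature. It also runs an adjoint check ⟨cumsum(a), b⟩ = ⟨a, suffix_cumsum(b)⟩.

```python
import numpy as np

def chunk_local_cumsum_ref(x, C, reverse=False):
    """Chunk-local prefix sum, or suffix sum when reverse=True (sketch; T divisible by C)."""
    T = x.shape[0]
    xc = x.reshape(-1, C)
    if reverse:  # out[r] = sum_{i >= r} x[i] within the chunk
        return np.flip(np.cumsum(np.flip(xc, axis=1), axis=1), axis=1).reshape(T)
    return np.cumsum(xc, axis=1).reshape(T)

rng = np.random.default_rng(0)
log_alpha, dgamma = rng.standard_normal(8), rng.standard_normal(8)
gamma = chunk_local_cumsum_ref(log_alpha, 4)                     # forward Step 1
dlog_alpha = chunk_local_cumsum_ref(dgamma, 4, reverse=True)     # backward Step 7
```

Each log α_i feeds every γ_r with r ≥ i in the same chunk. Its gradient therefore collects δγ_r over exactly those positions, which is a suffix sum.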
