
Reading Notes: The DeltaNet Implementation in FLA

Paper: https://arxiv.org/abs/2406.06484
Code: https://github.com/fla-org/flash-linear-attention
Note: These notes are a step-by-step mathematical walkthrough of the Triton implementation of DeltaNet in the flash-linear-attention repository, meant to be read alongside the previous DeltaNet paper notes.

Part 1: Forward Pass

Entry function signature:

def chunk_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
):
    ...  # kernel body elided; it dispatches to Steps 1-5 below
    return o, A, final_state

Step 1: Compute the KK^T matrix

\[\begin{aligned} \mathbf{A} = (\boldsymbol{\beta}_{[t]}^\top \boldsymbol{1} ) \odot \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \end{aligned}\]

Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> chunk_scaled_dot_kkt_fwd
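For reference, here is a minimal PyTorch sketch of this step (the shapes, names, and single-head layout are my own simplifications; the actual chunk_scaled_dot_kkt_fwd is a Triton kernel that also handles batch and head dimensions and variable-length sequences). Throughout these sketches, k, q, and w have shape [num_chunks, C, d_k], v and u have shape [num_chunks, C, d_v], and beta has shape [num_chunks, C]:

import torch

def chunk_scaled_dot_kkt_ref(k: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    C = k.shape[1]
    # M_{-1}: strictly lower-triangular mask (diagonal excluded)
    mask = torch.ones(C, C, dtype=k.dtype, device=k.device).tril(-1)
    # row i of each chunk is scaled by beta_i before masking
    return beta[..., None] * (k @ k.transpose(-1, -2)) * mask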

Step 2: Invert the unit lower-triangular matrix

The code solves this with row-by-row forward substitution combined with blockwise matrix inversion: the \(16 \times 16\) diagonal sub-blocks are inverted by forward substitution, and the off-diagonal blocks are then filled in blockwise. Here \(\mathbf{A_i}\) corresponds to the \(\mathbf{X}^{-1}\) defined in the previous notes.

\[\begin{aligned} \mathbf{A_i} = (\mathbf{I} + \mathbf{A})^{-1} \end{aligned}\]

Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> solve_tril
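A minimal PyTorch reference for the same two-level strategy (my own sketch, not the Triton kernel), assuming the chunk size C is a multiple of 16 and working on a single [C, C] matrix:

import torch

def inv_unit_lower(T: torch.Tensor) -> torch.Tensor:
    # Forward substitution: solve T @ X = I row by row (T has a unit diagonal).
    C = T.shape[-1]
    X = torch.eye(C, dtype=T.dtype, device=T.device)
    for i in range(1, C):
        X[i] -= T[i, :i] @ X[:i]
    return X

def solve_tril_ref(A: torch.Tensor, block: int = 16) -> torch.Tensor:
    # A is strictly lower triangular; returns (I + A)^{-1}.
    C = A.shape[-1]
    T = torch.eye(C, dtype=A.dtype, device=A.device) + A
    X = torch.zeros_like(T)
    nb = C // block

    def bs(b):
        return slice(b * block, (b + 1) * block)

    # 1) invert each 16x16 diagonal block by forward substitution
    for b in range(nb):
        X[bs(b), bs(b)] = inv_unit_lower(T[bs(b), bs(b)])
    # 2) fill the off-diagonal blocks: X_ij = -X_ii @ sum_{j<=m<i} T_im @ X_mj
    for i in range(1, nb):
        for j in range(i):
            acc = sum(T[bs(i), bs(m)] @ X[bs(m), bs(j)] for m in range(j, i))
            X[bs(i), bs(j)] = -X[bs(i), bs(i)] @ acc
    return X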

Comment: a Neumann series combined with doubling would be faster, but at lower precision.
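To illustrate that comment, here is a sketch of the doubling idea as a Newton-Schulz iteration (my own illustration, not code from the repository). Since \(\mathbf{A}\) is strictly lower triangular it is nilpotent, so the residual is squared at every step and about \(\log_2 C\) iterations recover the exact inverse up to floating-point error:

import torch

def inv_neumann_doubling(A: torch.Tensor, steps: int) -> torch.Tensor:
    # X <- X @ (2I - T @ X) doubles the number of correct Neumann terms
    # per step: fewer but larger matmuls than forward substitution, at the
    # cost of accumulating rounding error in low precision.
    C = A.shape[-1]
    I = torch.eye(C, dtype=A.dtype, device=A.device)
    T = I + A
    X = I - A  # first-order Neumann approximation of T^{-1}
    for _ in range(steps):
        X = X @ (2 * I - T @ X)
    return X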

Step 3: Compute U and W

\[\begin{aligned} \mathbf{U}_{[t]} &= \mathbf{A_i} \left((\boldsymbol{\beta}_{[t]}^\top \boldsymbol{1}) \odot \mathbf{V}_{[t]} \right) \\ \\ \mathbf{W}_{[t]} &= \mathbf{A_i} \left((\boldsymbol{\beta}_{[t]}^\top \boldsymbol{1}) \odot \mathbf{K}_{[t]} \right) \end{aligned}\]

Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> recompute_w_u_fwd
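In the conventions of the earlier sketches, this step is just two β-scaled matmuls against \(\mathbf{A_i}\) per chunk:

import torch

def recompute_w_u_ref(Ai: torch.Tensor, beta: torch.Tensor,
                      k: torch.Tensor, v: torch.Tensor):
    # Ai: [num_chunks, C, C]; beta[..., None] scales row i by beta_i,
    # i.e. left-multiplication by Diag(beta).
    u = Ai @ (beta[..., None] * v)
    w = Ai @ (beta[..., None] * k)
    return w, u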

Step 4: Recurrently compute the hidden state

Notation: here \(\mathbf{S}_{[t]}^C\) is stored in a \([d_k, d_v]\) layout, so it is the transpose of the state used in the derivation of the previous paper notes.

\[\begin{aligned} \mathbf{V}_{[t],new} &= \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C} \\ \\ \mathbf{S}_{[t]}^{C} &= \mathbf{S}_{[t-1]}^{C} + \mathbf{K}_{[t]}^\top \mathbf{V}_{[t],new} \end{aligned}\]

Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
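A sketch of the chunk-level recurrence (the scan over chunks is inherently sequential; the Triton kernel parallelizes over batch and heads instead):

import torch

def chunk_delta_fwd_h_ref(k: torch.Tensor, w: torch.Tensor,
                          u: torch.Tensor, h0: torch.Tensor):
    # h0: [d_k, d_v] initial state; returns V_new, the per-chunk input
    # states S_{[t-1]}, and the final state.
    S = h0
    v_new, states = [], []
    for t in range(k.shape[0]):
        states.append(S)                      # S_{[t-1]} entering chunk t
        vt = u[t] - w[t] @ S                  # V_{[t],new} = U - W S_{[t-1]}
        S = S + k[t].transpose(-1, -2) @ vt   # S_{[t]} = S_{[t-1]} + K^T V_new
        v_new.append(vt)
    return torch.stack(v_new), states, S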

Step 5: Compute the final output

\[\begin{aligned} \mathbf{O}_{[t]} = \mathbf{Q}_{[t]} \mathbf{S}_{[t-1]}^{C} + \left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],new} \end{aligned}\]

Corresponding code: fla.ops.common.chunk_o -> chunk_fwd_o
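And a sketch of the output computation. I omit the scale argument from the signature, which would multiply \(\mathbf{Q}\), to match the equation above:

import torch

def chunk_fwd_o_ref(q: torch.Tensor, k: torch.Tensor,
                    v_new: torch.Tensor, states: list) -> torch.Tensor:
    C = q.shape[1]
    # M: inclusive lower-triangular (causal) mask within a chunk
    M = torch.ones(C, C, dtype=q.dtype, device=q.device).tril()
    o = []
    for t in range(q.shape[0]):
        inter = q[t] @ states[t]                                # Q S_{[t-1]}
        intra = (q[t] @ k[t].transpose(-1, -2) * M) @ v_new[t]  # (Q K^T . M) V_new
        o.append(inter + intra)
    return torch.stack(o)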

Part 2: Backward Pass

Entry function signature:

def chunk_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
):
    ...  # kernel body elided; it dispatches to Steps 1-6 below
    return dq, dk, dv, db, dh0

Step 1: Recompute U and W

\(\mathbf{A_i}\) was already stored during the forward pass, while \(\mathbf{U}\) and \(\mathbf{W}\) need to be recomputed.

Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> recompute_w_u_fwd

Step 2: Recompute the hidden states

Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h

Step 3: Compute the intra-chunk part of \(\delta \mathbf{U}\)

\[\begin{aligned} \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} = \left( \mathbf{K}_{[t]} \mathbf{Q}_{[t]}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]} \end{aligned}\]

Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dv_local
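Sketch, in the same conventions (scale again omitted):

import torch

def chunk_bwd_dv_local_ref(q: torch.Tensor, k: torch.Tensor,
                           do: torch.Tensor) -> torch.Tensor:
    C = q.shape[1]
    # M^T: inclusive upper-triangular mask
    Mt = torch.ones(C, C, dtype=q.dtype, device=q.device).triu()
    return (k @ q.transpose(-1, -2) * Mt) @ do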

Step 4: Backward recurrence for \(\delta \mathbf{U}\) and \(\delta \mathbf{S}\)

\[\begin{aligned} \delta \mathbf{U}_{[t]} &= \mathbf{K}_{[t]} \delta \mathbf{S}_{[t]}^{C} + \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} \\ \\ \delta \mathbf{S}_{[t]}^{C} &= \delta \mathbf{S}_{[t+1]}^{C} + \mathbf{Q}_{[t+1]}^\top \delta \mathbf{O}_{[t+1]} - \mathbf{W}_{[t+1]}^\top \delta \mathbf{U}_{[t+1]} \end{aligned}\]

Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu
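A sketch of the reverse scan; the value of \(\delta \mathbf{S}\) left over after the last step is dh0, the gradient with respect to the initial state:

import torch

def chunk_bwd_dhu_ref(q, k, w, do, du_local, dht):
    # dht: gradient flowing in from the final state (zeros if unused)
    T = q.shape[0]
    dS = dht
    du = torch.empty_like(du_local)
    dS_list = [None] * T                      # dS_{[t]}, reused in Step 5
    for t in range(T - 1, -1, -1):
        dS_list[t] = dS
        du[t] = k[t] @ dS + du_local[t]       # dU_{[t]}
        dS = (dS + q[t].transpose(-1, -2) @ do[t]
                 - w[t].transpose(-1, -2) @ du[t])  # dS_{[t-1]}
    return du, dS_list, dS                    # final dS is dh0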

Step 5: Compute \(\delta \mathbf{Q}\), \(\delta \mathbf{K}_{\text{part1}}\), and \(\delta \mathbf{W}\)

\[\begin{aligned} \delta \mathbf{B}_{[t]} &= \delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M} \\ \\ \delta \mathbf{Q}_{[t]} &= \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} + \delta \mathbf{B}_{[t]} \mathbf{K}_{[t]} \\ \\ \delta \mathbf{K}_{[t], \text{part1}} &= \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C\top} + \delta \mathbf{B}_{[t]}^\top\mathbf{Q}_{[t]} \\ \\ \delta \mathbf{W}_{[t]} &= - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \end{aligned}\]

Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dqkwg
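Sketch of this step, reusing states and dS_list from the earlier sketches (scale omitted to match the equations):

import torch

def chunk_bwd_dqkw_ref(q, k, v_new, du, do, states, dS_list):
    C = q.shape[1]
    M = torch.ones(C, C, dtype=q.dtype, device=q.device).tril()
    dq, dk1, dw = [], [], []
    for t in range(q.shape[0]):
        dB = do[t] @ v_new[t].transpose(-1, -2) * M
        dq.append(do[t] @ states[t].transpose(-1, -2) + dB @ k[t])
        dk1.append(v_new[t] @ dS_list[t].transpose(-1, -2)
                   + dB.transpose(-1, -2) @ q[t])
        dw.append(-du[t] @ states[t].transpose(-1, -2))
    return torch.stack(dq), torch.stack(dk1), torch.stack(dw)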

Step 6: Compute \(\delta \mathbf{V}\), \(\delta \mathbf{K}_{\text{part2}}\), and \(\delta \boldsymbol{\beta}\)

This step backpropagates through the WY representation to obtain \(\delta \mathbf{V}\), the remaining part of \(\delta \mathbf{K}\), and \(\delta \boldsymbol{\beta}\).

\[\begin{aligned} \delta \mathbf{V}_{[t]} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A_i}_{[t]}^\top \delta \mathbf{U}_{[t]} \\ \\ \delta \mathbf{A_i}_{[t]} &= \delta \mathbf{U}_{[t]} (\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]})^\top + \delta \mathbf{W}_{[t]} (\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]})^\top \\ \\ \delta \mathbf{A_x}_{[t]} &= \delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1} = -\mathbf{A_i}_{[t]}^\top \delta \mathbf{A_i}_{[t]} \mathbf{A_i}_{[t]}^\top \odot \mathbf{M}_{-1} \\ \\ \delta \mathbf{K}_{[t], \text{part2}} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A_i}_{[t]}^\top \delta \mathbf{W}_{[t]} + \delta \mathbf{A_x}_{[t]}^\top \left(\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]}\right) + \text{Diag}(\boldsymbol{\beta}_{[t]}) \delta \mathbf{A_x}_{[t]} \mathbf{K}_{[t]} \\ \\ \delta \boldsymbol{\beta}_{[t]} &= \text{diag}\left(\mathbf{A_i}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top\right) + \text{diag}\left(\mathbf{A_i}_{[t]}^\top \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top\right) + \text{diag}\left(\delta \mathbf{A_x}_{[t]} \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top\right) \end{aligned}\]

Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_bwd
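A single-chunk sketch of this step, following the equations above (Ai: [C, C], k: [C, d_k], v: [C, d_v], beta: [C]; names are my own):

import torch

def prepare_wy_repr_bwd_ref(Ai, k, v, beta, dw, du):
    C = k.shape[0]
    Mm1 = torch.ones(C, C, dtype=k.dtype, device=k.device).tril(-1)
    b = beta[:, None]  # left-multiplication by Diag(beta) as a broadcast
    dv = b * (Ai.T @ du)
    dAi = du @ (b * v).T + dw @ (b * k).T
    dAx = -(Ai.T @ dAi @ Ai.T) * Mm1
    dk2 = b * (Ai.T @ dw) + dAx.T @ (b * k) + b * (dAx @ k)
    dbeta = ((Ai.T @ du @ v.T).diagonal()
             + (Ai.T @ dw @ k.T).diagonal()
             + (dAx @ k @ k.T).diagonal())
    return dv, dk2, dbeta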
