Reading Notes: The DeltaNet Implementation in FLA
Paper: https://arxiv.org/abs/2406.06484
Code: https://github.com/fla-org/flash-linear-attention
Note: these notes are a step-by-step mathematical walkthrough of the DeltaNet Triton implementation in the flash-linear-attention repository, intended to be read alongside the previous DeltaNet paper notes.
1. Forward Pass
Entry function signature:
```python
def chunk_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
):
    ...
    return o, A, final_state
```
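Before diving into the chunked kernels, a per-token reference of the delta rule helps pin down the semantics. The sketch below is mine, not FLA's API (single sequence, single head; the scale is applied to q, and the helper name is hypothetical):

```python
import torch

def delta_rule_recurrent_ref(q, k, v, beta, scale):
    # q, k: [T, d_k]; v: [T, d_v]; beta: [T]
    T, d_k = k.shape
    d_v = v.shape[-1]
    S = k.new_zeros(d_k, d_v)                       # state in the [d_k, d_v] layout
    o = k.new_empty(T, d_v)
    for t in range(T):
        v_new = v[t] - S.t() @ k[t]                 # delta: target minus current prediction
        S = S + beta[t] * torch.outer(k[t], v_new)  # rank-1 state correction
        o[t] = (scale * q[t]) @ S                   # causal output includes token t itself
    return o, S                                     # S is the final_state
```

The chunked algorithm below reproduces exactly this recurrence, processing C tokens at a time.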
Step 1: Compute the \(\mathbf{K}\mathbf{K}^\top\) matrix
\[\begin{aligned}
\mathbf{A} = (\boldsymbol{\beta}_{[t]} \boldsymbol{1}^\top ) \odot \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right)
\end{aligned}\]
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> chunk_scaled_dot_kkt_fwd
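A dense single-chunk sketch of what this kernel produces (the helper name and shapes are my assumptions for illustration):

```python
def scaled_dot_kkt_ref(k, beta):
    # k: [C, d_k]; beta: [C]
    C = k.shape[0]
    A = beta[:, None] * (k @ k.t())        # (beta 1^T) ⊙ (K K^T)
    m1 = k.new_ones(C, C).tril(-1)         # strictly lower-triangular mask M_{-1}
    return A * m1
```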
Step 2: Invert the unit lower-triangular matrix
The code solves this with a combination of row-wise forward substitution and blockwise matrix inversion: the \(16 \times 16\) sub-blocks are handled by forward substitution. Here \(\mathbf{A_i}\) plays the role of the \(\mathbf{X}^{-1}\) defined in the previous notes.
\[\begin{aligned}
\mathbf{A_i} = (\mathbf{I} + \mathbf{A})^{-1}
\end{aligned}\]
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> solve_tril
Comment: a Neumann series combined with doubling would be faster, but at lower precision.
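A dense sketch of the inversion that solve_tril performs, using plain row-wise forward substitution (the real kernel additionally splits the chunk into 16×16 blocks):

```python
import torch

def solve_unit_lower_tril_ref(A):
    # A: [C, C], strictly lower triangular; returns (I + A)^{-1}
    C = A.shape[0]
    X = torch.eye(C, dtype=A.dtype, device=A.device)
    for i in range(1, C):
        # row i must satisfy X[i] + A[i, :i] @ X[:i] = e_i
        X[i] -= A[i, :i] @ X[:i]
    return X
```

The blockwise combination relies on the identity that for a lower block-triangular matrix with diagonal blocks \(T_{11}, T_{22}\) and off-diagonal block \(T_{21}\), the inverse's off-diagonal block is \(-T_{22}^{-1} T_{21} T_{11}^{-1}\); this is how the 16×16 inverses are stitched into the full chunk-level inverse.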
Step 3: Compute U and W
\[\begin{aligned}
\mathbf{U}_{[t]}
&=
\mathbf{A_i} \left((\boldsymbol{\beta}_{[t]} \boldsymbol{1}^\top) \odot \mathbf{V}_{[t]} \right)
\\
\\
\mathbf{W}_{[t]}
&=
\mathbf{A_i} \left((\boldsymbol{\beta}_{[t]} \boldsymbol{1}^\top) \odot \mathbf{K}_{[t]} \right)
\end{aligned}\]
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> recompute_w_u_fwd
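In dense form, Step 3 is just two scaled matmuls (a sketch; A_inv stands for \(\mathbf{A_i}\)):

```python
def compute_u_w_ref(A_inv, k, v, beta):
    # A_inv: [C, C]; k: [C, d_k]; v: [C, d_v]; beta: [C]
    u = A_inv @ (beta[:, None] * v)    # U = A_i (Diag(beta) V)
    w = A_inv @ (beta[:, None] * k)    # W = A_i (Diag(beta) K)
    return u, w
```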
Step 4: Recurrently compute the hidden state
Notation: \(\mathbf{S}_{[t]}^C\) here uses the \([d_k, d_v]\) layout, so it is the transpose of the state used in the derivation in the previous paper notes.
\[\begin{aligned}
\mathbf{V}_{[t],new} &= \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C}
\\
\\
\mathbf{S}_{[t]}^{C}
&=
\mathbf{S}_{[t-1]}^{C}
+ \mathbf{K}_{[t]}^\top \mathbf{V}_{[t],new}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
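A sketch of the inter-chunk recurrence over lists of per-chunk tensors (a Python loop for clarity; the kernel keeps the state in SRAM):

```python
def chunk_state_recurrence_ref(ks, ws, us, S0=None):
    # ks[i], ws[i]: [C, d_k]; us[i]: [C, d_v]
    d_k, d_v = ks[0].shape[-1], us[0].shape[-1]
    S = S0 if S0 is not None else ks[0].new_zeros(d_k, d_v)
    states, v_news = [], []
    for K, W, U in zip(ks, ws, us):
        states.append(S)               # S_{[t-1]}^C, reused by the output step
        v_new = U - W @ S              # V_new = U - W S_{[t-1]}
        S = S + K.t() @ v_new          # S_{[t]} = S_{[t-1]} + K^T V_new
        v_news.append(v_new)
    return states, v_news, S           # trailing S is the final_state
```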
Step 5: Compute the final output
\[\begin{aligned}
\mathbf{O}_{[t]}
=
\mathbf{Q}_{[t]} \mathbf{S}_{[t-1]}^{C}
+
\left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],new}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_o -> chunk_fwd_o
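The per-chunk output then combines the inter-chunk and intra-chunk terms (a sketch; scale is applied to q, matching the entry function's convention):

```python
def chunk_output_ref(q, k, v_new, S_prev, scale):
    # q, k: [C, d_k]; v_new: [C, d_v]; S_prev = S_{[t-1]}^C: [d_k, d_v]
    C = q.shape[0]
    M = q.new_ones(C, C).tril()        # causal mask including the diagonal
    qs = scale * q
    return qs @ S_prev + ((qs @ k.t()) * M) @ v_new
```

Chaining Steps 1–5 over chunks of size C and comparing against the per-token reference at the top is a convenient sanity check of the algebra.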
2. Backward Pass
Entry function signature:
```python
def chunk_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
):
    ...
    return dq, dk, dv, db, dh0
```
Step 1: Recompute U and W
\(\mathbf{A_i}\) was saved during the forward pass, while \(\mathbf{U}\) and \(\mathbf{W}\) are recomputed.
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> recompute_w_u_fwd
Step 2: Recompute the hidden states
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
Step 3: Compute the intra-chunk part of \(\delta \mathbf{U}\)
\[\begin{aligned}
\left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}
=
\left( \mathbf{K}_{[t]} \mathbf{Q}_{[t]}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dv_local
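A dense sketch of this transposed-mask product (names are mine):

```python
def du_intra_ref(q, k, do, scale):
    # q, k: [C, d_k]; do: [C, d_v]
    C = q.shape[0]
    Mt = q.new_ones(C, C).triu()               # M^T: upper triangular incl. diagonal
    return ((k @ (scale * q).t()) * Mt) @ do   # (K Q^T ⊙ M^T) dO
```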
Step 4: Backward recurrence for \(\delta \mathbf{U}\) and \(\delta \mathbf{S}\)
\[\begin{aligned}
\delta \mathbf{U}_{[t]}
=
\mathbf{K}_{[t]} \delta \mathbf{S}_{[t]}^{C}
+
\left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}
\end{aligned}\]
\[\begin{aligned}
\delta \mathbf{S}_{[t]}^{C}
=
\delta \mathbf{S}_{[t+1]}^{C}
+
\mathbf{Q}_{[t+1]}^\top \delta \mathbf{O}_{[t+1]}
-
\mathbf{W}_{[t+1]} ^\top
\delta \mathbf{U}_{[t+1]}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu
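A sketch of the reverse-order loop (the gradient dht of the final state, if present, seeds dS; whatever remains in dS afterwards is dh0):

```python
def bwd_state_recurrence_ref(qs, ks, ws, dos, du_intras, scale, dST=None):
    # qs[t], ks[t], ws[t]: [C, d_k]; dos[t], du_intras[t]: [C, d_v]
    d_k, d_v = ks[0].shape[-1], dos[0].shape[-1]
    dS = dST if dST is not None else ks[0].new_zeros(d_k, d_v)  # from dht, else zeros
    dus = [None] * len(ks)
    for t in reversed(range(len(ks))):
        dus[t] = ks[t] @ dS + du_intras[t]                      # dU_[t]
        # shift to the previous chunk: dS_{[t-1]} = dS_{[t]} + Q^T dO - W^T dU
        dS = dS + (scale * qs[t]).t() @ dos[t] - ws[t].t() @ dus[t]
    return dus, dS                                              # dS is now dh0
```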
Step 5: Compute \(\delta \mathbf{Q}\), \(\delta \mathbf{K}_{part1}\), and \(\delta \mathbf{W}\)
\[\begin{aligned}
\delta \mathbf{B}_{[t]}
&=
\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M}
\\
\\
\delta \mathbf{Q}_{[t]}
&=
\delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
+
\delta \mathbf{B}_{[t]} \mathbf{K}_{[t]}
\\
\\
\delta \mathbf{K}_{[t], \text{part1}}
&=
\mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C\top}
+
\delta \mathbf{B}_{[t]}^\top\mathbf{Q}_{[t]}
\\
\\
\delta \mathbf{W}_{[t]}
&=
- \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
\end{aligned}\]
Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dqkwg
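A dense sketch (scale handling follows the q-scaling convention above; dB is the gradient of the masked intra-chunk score matrix):

```python
def bwd_dqkw_ref(q, k, v_new, do, du, S_prev, dS, scale):
    # S_prev = S_{[t-1]}^C: [d_k, d_v]; dS = dS_{[t]}^C: [d_k, d_v]
    C = q.shape[0]
    M = q.new_ones(C, C).tril()
    dB = (do @ v_new.t()) * M                     # dB_[t]
    dq = scale * (do @ S_prev.t() + dB @ k)       # dQ, chain rule through q * scale
    dk1 = v_new @ dS.t() + dB.t() @ (scale * q)   # dK, part 1
    dw = -(du @ S_prev.t())                       # dW
    return dq, dk1, dw
```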
Step 6: Compute \(\delta \mathbf{V}\), \(\delta \mathbf{K}_{part2}\), and \(\delta \boldsymbol{\beta}\)
This step backpropagates through the WY representation, producing \(\delta \mathbf{V}\), the remaining part of \(\delta \mathbf{K}\), and \(\delta \boldsymbol{\beta}\).
\[\begin{aligned}
\delta \mathbf{V}_{[t]}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A_i}_{[t]}^\top \delta \mathbf{U}_{[t]}
\\
\\
\delta \mathbf{A_i}_{[t]}
&=
\delta \mathbf{U}_{[t]} (\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]})^\top
+
\delta \mathbf{W}_{[t]} (\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]})^\top
\\
\\
\delta \mathbf{A_x}_{[t]}
&= \delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1}
=
-\mathbf{A_i}_{[t]}^\top \delta \mathbf{A_i}_{[t]} \mathbf{A_i}_{[t]}^\top \odot \mathbf{M}_{-1}
\\
\\
\delta \mathbf{K}_{[t], \text{part2}}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A_i}_{[t]}^\top \delta \mathbf{W}_{[t]}
+
\left(\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1} \right)^\top (\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]})
+
\text{Diag}(\boldsymbol{\beta}_{[t]}) \left(\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1} \right) \mathbf{K}_{[t]}
\\
\\
\delta \boldsymbol{\beta}_{[t]}
&=
\text{diag}\left(\mathbf{A_i}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top\right)
+
\text{diag}\left(\mathbf{A_i}_{[t]}^\top \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top\right)
+
\text{diag}\left((\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1})
\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top\right)
\end{aligned}\]
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_bwd
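Putting Step 6 together in dense form (a sketch; dA below is \(\delta \mathbf{A_x}\), the masked gradient that flows back into \(\mathbf{K}\) and \(\boldsymbol{\beta}\)):

```python
def wy_repr_bwd_ref(A_inv, k, v, beta, du, dw):
    # A_inv: [C, C]; k: [C, d_k]; v: [C, d_v]; beta: [C]
    C = k.shape[0]
    m1 = k.new_ones(C, C).tril(-1)                     # M_{-1}
    bk, bv = beta[:, None] * k, beta[:, None] * v      # Diag(beta) K, Diag(beta) V
    dv = beta[:, None] * (A_inv.t() @ du)              # dV
    dA_inv = du @ bv.t() + dw @ bk.t()                 # dA_i
    dA = -(A_inv.t() @ dA_inv @ A_inv.t()) * m1        # dA_x = dX ⊙ M_{-1}
    dk2 = beta[:, None] * (A_inv.t() @ dw) \
        + dA.t() @ bk \
        + beta[:, None] * (dA @ k)                     # dK, part 2
    dbeta = (A_inv.t() @ du @ v.t()).diagonal() \
          + (A_inv.t() @ dw @ k.t()).diagonal() \
          + (dA @ k @ k.t()).diagonal()                # dbeta
    return dv, dk2, dbeta
```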