阅读笔记:GLA 在 FLA 中的代码实现笔记
原文链接:https://arxiv.org/pdf/2312.06635
代码链接:https://github.com/fla-org/flash-linear-attention
笔记链接:https://mzeromiko.github.io/blogs
声明:本笔记是对 flash-linear-attention 仓库中 GLA 部分 Triton 实现的逐步数学解读,与前一篇 GLA 论文笔记配合阅读。
一、前向传播(Forward)
入口函数签名:
def chunk_gla_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
g_cumsum: torch.Tensor | None,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
chunk_indices: torch.LongTensor | None = None,
):
return g_cumsum, A, h, ht, o
Step 1:计算 log 域累积门控
符号说明:本笔记中的 \(\gamma, \Gamma\) 定义为前一篇 GLA 笔记中的 \(\log \gamma, \log \Gamma\),后续所有涉及 \(\gamma, \Gamma\) 的公式均在 log 域下,请注意区分。
\[\begin{aligned}
\boldsymbol{\gamma}_{[t]}^r &= \sum_{i=tC+1}^{tC+r} \log \boldsymbol{\alpha}_i \in \mathbb{R}^{d \times 1}
\\
\\
\mathbf{\Gamma}_{[t]} &= [ \boldsymbol{\gamma}_{[t]}^{1}, \boldsymbol{\gamma}_{[t]}^{2}, \dots, \boldsymbol{\gamma}_{[t]}^{C} ]^\top \in \mathbb{R}^{C \times d}
\end{aligned}\]
对应代码:fla.ops.utils.cumsum -> chunk_local_cumsum
Step 2:递推计算隐藏状态
符号说明:本笔记中的 \(\mathbf{S}_{[t]}^{C}\) 对应 GLA 论文笔记中的 \(\mathbf{S}_{[t]}^{C\top}\),全文皆如此。
\[\begin{aligned}
\mathbf{S}_{[t]}^{C}
&=
\mathbf{S}_{[t-1]}^{C} \odot \exp(\boldsymbol{\gamma}_{[t]}^{C} \boldsymbol{1}^\top)
+
\left(
\boldsymbol{K}_{[t]}^\top \odot
\exp\left(\boldsymbol{\gamma}_{[t]}^{C} \boldsymbol{1}^\top - \mathbf{\Gamma}_{[t]}^\top \right)
\right)
\mathbf{V}_{[t]}
\end{aligned}\]
对应代码:fla.ops.common.chunk_h -> chunk_fwd_h
Step 3:计算 chunk 内注意力矩阵 A(非对角子块部分)
仅在第 \([t]\) 块内讨论。令 \(\mathbf{A}_{[i,j]}\) 表示第 \([i,j]\) 个子块,\(\mathbf{\Gamma}_{[i]}\) 表示第 \([i]\) 子块的 cumsum 值,\(\boldsymbol{\gamma}_{[i]}^{k}\) 表示第 \([i]\) 子块的第 \(k\) 个元素。
\[\begin{aligned}
\mathbf{A}_{[i,j]}
=
\left(\mathbf{Q}_{[i]} \odot
\exp\left(\mathbf{\Gamma}_{[i]} - \boldsymbol{\gamma}_{[i]}^1 \boldsymbol{1}^\top \right)
\right)
\left(\mathbf{K}_{[j]}^\top \odot
\exp\left(\boldsymbol{\gamma}_{[i]}^{1\top} \boldsymbol{1} - \mathbf{\Gamma}_{[j]}^{\top} \right)
\right)
,\quad i \gt j
\end{aligned}\]
对应代码:fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_inter
Step 4:计算 chunk 内注意力矩阵 A(对角子块部分)
仅在第 \([t]\) 块内讨论。符号定义同 Step 3。
对角子块上,\(\mathbf{A}\) 采用逐列计算(代码实现),为简便起见,此处以逐元素形式表述。
\[\begin{aligned}
\mathbf{A}_{[i,j]}^{k_1, k_2}
=
\sum_{d}
\left(\mathbf{Q}_{[i]}^{k_1} \odot \mathbf{K}_{[j]}^{k_2} \odot
\exp\left(\mathbf{\Gamma}_{[i]}^{k_1} - \mathbf{\Gamma}_{[j]}^{k_2} \right)
\right)
, \quad i = j, k_1 \geq k_2
\end{aligned}\]
对应代码:fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra
如果沿 \(d\) 维切成块分别计算后再合并,则使用以下两个函数,其数学内容与上述函数基本一致:
fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra_split
fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra_merge
Step 5:计算最终输出
由于在 chunk_gla_fwd_intra_gk 中,\(\mathbf{A}\) 使用 torch.empty 初始化(上三角子块区域包含垃圾值),因此在计算输出时需要再做一次下三角掩码。
\[\begin{aligned}
\mathbf{O}_{[t]}
=
(\mathbf{Q}_{[t]} \odot \exp\mathbf{\Gamma}_{[t]}) \mathbf{S}_{[t-1]}^{C}
+
(\mathbf{A}_{[t]} \odot \mathbf{M}) \mathbf{V}_{[t]}
\end{aligned}\]
对应代码:fla.ops.gla.chunk -> chunk_gla_fwd_o_gk
二、反向传播(Backward)
入口函数签名:
def chunk_gla_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
g_cumsum: torch.Tensor | None,
scale: float,
initial_state: torch.Tensor,
h: torch.Tensor,
A: torch.Tensor,
do: torch.Tensor,
dht: torch.Tensor,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
chunk_indices: torch.LongTensor | None = None,
):
return dq, dk, dv, dg, dh0
Step 1:计算 log 域累积门控
同前向 Step 1。
对应代码:fla.ops.utils.cumsum -> chunk_local_cumsum
Step 2:递推计算隐藏状态(重计算)
同前向 Step 2。
对应代码:fla.ops.common.chunk_h -> chunk_fwd_h
Step 3:反向递推计算状态梯度
\[\begin{aligned}
\delta \mathbf{S}_{[t]}^{C} = \delta \mathbf{S}_{[t+1]}^{C} \odot \exp (\boldsymbol{\gamma}_{[t+1]}^{C} \boldsymbol{1}^\top) + \left(\mathbf{Q}_{[t+1]}^\top \odot \exp(\mathbf{\Gamma}_{[t+1]}^\top)\right) \delta \mathbf{O}_{[t+1]}
\end{aligned}\]
对应代码:fla.ops.common.chunk_h -> chunk_bwd_dh
Step 4:计算 V 的梯度
\[\begin{aligned}
\delta \mathbf{V}_{[t]} = \left(\mathbf{A}_{[t]}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]} + \left(\mathbf{K}_{[t]} \odot \exp(\boldsymbol{\gamma}_{[t]}^C - \mathbf{\Gamma}_{[t]}) \right) \delta \mathbf{S}_{[t]}^{C}
\end{aligned}\]
对应代码:fla.ops.gla.chunk -> chunk_gla_bwd_dv
Step 5:计算 A 的梯度
\[\begin{aligned}
\delta \mathbf{A}_{[t]} = \delta \mathbf{O}_{[t]} \mathbf{V}_{[t]}^{\top} \odot \mathbf{M}
\end{aligned}\]
对应代码:fla.ops.gla.chunk -> chunk_gla_bwd_dA
Step 6:计算 Q、K 的 intra-chunk 梯度
仅在第 \([t]\) 块内讨论。令 \(\mathbf{A}_{[i,j]}\) 表示第 \([i,j]\) 个子块,\(\boldsymbol{\gamma}_{[i]}^{k}\) 表示第 \([i]\) 子块第 \(k\) 个元素的 cumsum 值,\(\boldsymbol{\gamma}_{[i]}^{C_i}\) 表示第 \([i]\) 子块最后一个元素的 cumsum 值。
\(\delta \mathbf{Q}\) 的 intra-chunk 梯度:
\[\begin{aligned}
\left.\delta \mathbf{Q}_{[i]}\right|_{\text{from } \mathbf{A}_{[i,j]}}
&= \delta \mathbf{A}_{[i,j]} \left(\mathbf{K}_{[j]} \odot \exp\left( \boldsymbol{\gamma}_{[i]}^1 \boldsymbol{1}^\top - \mathbf{\Gamma}_{[j]}
\right) \right) \odot \exp\left(
\mathbf{\Gamma}_{[i]} - \boldsymbol{\gamma}_{[i]}^1 \boldsymbol{1}^\top
\right)
,\quad i \gt j
\\
\\
\left.\delta \mathbf{Q}_{[i]}^{k_1}\right|_{\text{from } \mathbf{A}_{[i,j]}}
&= \sum_{k_2 \leq k_1}
\delta \mathbf{A}_{[i,j]}^{k_1, k_2} \mathbf{K}_{[j]}^{k_2} \odot \exp\left( \boldsymbol{\gamma}_{[i]}^{k_1} - \boldsymbol{\gamma}_{[j]}^{k_2}
\right)
,\quad i = j
\end{aligned}\]
\(\delta \mathbf{K}\) 的 intra-chunk 梯度:
\[\begin{aligned}
\left.\delta \mathbf{K}_{[j]}\right|_{\text{from } \mathbf{A}_{[i,j]}}
&= \delta \mathbf{A}_{[i,j]}^\top \left(\mathbf{Q}_{[i]} \odot \exp\left(
\mathbf{\Gamma}_{[i]} - \boldsymbol{\gamma}_{[i]}^{C_i} \boldsymbol{1}^\top
\right) \right) \odot \exp\left( \boldsymbol{\gamma}_{[i]}^{C_i} \boldsymbol{1}^\top - \mathbf{\Gamma}_{[j]}
\right)
,\quad i \gt j
\\
\\
\left.\delta \mathbf{K}_{[i]}^{k_2}\right|_{\text{from } \mathbf{A}_{[i,j]}}
&= \sum_{k_1 \geq k_2}
\delta \mathbf{A}_{[i,j]}^{k_1, k_2 \top} \mathbf{Q}_{[i]}^{k_1} \odot \exp\left(
\boldsymbol{\gamma}_{[i]}^{k_1} - \boldsymbol{\gamma}_{[i]}^{k_2}
\right)
,\quad i = j
\end{aligned}\]
对应代码:fla.ops.gla.chunk -> chunk_gla_bwd_dqk_intra
Step 7:计算 inter-chunk 梯度并合并,得到最终 dQ、dK、dg
本步骤将 inter-chunk 贡献与 Step 6 的 intra-chunk 贡献合并,并计算门控值的梯度。
\[\begin{aligned}
\delta \mathbf{\gamma}_{[t]}^C
&=
(\delta \exp\mathbf{\gamma}_{[t]}) \odot \exp \boldsymbol{\gamma}_{[t]}^C
\\&=
\left(\sum_{d_v}
\mathbf{S}_{[t-1]}^{C\top} \odot \delta \mathbf{S}_{[t]}^{C\top} \right)
\odot \exp \boldsymbol{\gamma}_{[t]}^C
+
\sum_{r}
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]}} \odot \mathbf{K}_{[t]}
\quad \text{in float32}
\\
\\
\left.\delta \mathbf{Q}_{[t]}\right|_{\text{from } \mathbf{O}_{[t], \text{part 1}}}
&=
\delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \odot \exp\mathbf{\Gamma}_{[t]}
\\
\\
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]}}
&=
\mathbf{V}_{[t]} \delta \mathbf{S}_{[t-1]}^{C\top}
\odot \exp\left(\boldsymbol{\gamma}_{[t]}^C - \mathbf{\Gamma}_{[t]}\right)
\\
\\
\delta \mathbf{Q}_{[t]}
&=
\left.\delta \mathbf{Q}_{[t]}\right|_{\text{from } \mathbf{O}_{[t], \text{part 1}}} + \left.\delta \mathbf{Q}_{[t]}\right|_{\text{from } \mathbf{A}_{[t]}}
\\
\\
\delta \mathbf{K}_{[t]}
&=
\left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]}} + \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{A}_{[t]}}
\\
\\
\left.\delta \mathbf{\Gamma}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}
&=
\delta \mathbf{Q}_{[t]} \odot \mathbf{Q}_{[t]}
-
\delta \mathbf{K}_{[t]} \odot \mathbf{K}_{[t]}
\\
\\
\delta \log \boldsymbol{\alpha}_{[t]}
&=
\text{suffix\_cumsum}(\left.\delta\mathbf{\Gamma}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}) + \delta \mathbf{\gamma}_{[t]}^C
\end{aligned}\]
对应代码:fla.ops.gla.chunk -> chunk_gla_bwd_dqkg