跳转至

阅读笔记:GLA 在 FLA 中的代码实现笔记

原文链接https://arxiv.org/pdf/2312.06635
代码链接https://github.com/fla-org/flash-linear-attention
笔记链接https://mzeromiko.github.io/blogs
声明:本笔记是对 flash-linear-attention 仓库中 GLA 部分 Triton 实现的逐步数学解读,与前一篇 GLA 论文笔记配合阅读。

一、前向传播(Forward)

入口函数签名:

def chunk_gla_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: torch.Tensor | None,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    chunk_indices: torch.LongTensor | None = None,
):
return g_cumsum, A, h, ht, o

Step 1:计算 log 域累积门控

符号说明:本笔记中的 \(\gamma, \Gamma\) 定义为前一篇 GLA 笔记中的 \(\log \gamma, \log \Gamma\),后续所有涉及 \(\gamma, \Gamma\) 的公式均在 log 域下,请注意区分。

\[\begin{aligned} \boldsymbol{\gamma}_{[t]}^r &= \sum_{i=tC+1}^{tC+r} \log \boldsymbol{\alpha}_i \in \mathbb{R}^{d \times 1} \\ \\ \mathbf{\Gamma}_{[t]} &= [ \boldsymbol{\gamma}_{[t]}^{1}, \boldsymbol{\gamma}_{[t]}^{2}, \dots, \boldsymbol{\gamma}_{[t]}^{C} ]^\top \in \mathbb{R}^{C \times d} \end{aligned}\]

对应代码fla.ops.utils.cumsum -> chunk_local_cumsum

Step 2:递推计算隐藏状态

符号说明:本笔记中的 \(\mathbf{S}_{[t]}^{C}\) 对应 GLA 论文笔记中的 \(\mathbf{S}_{[t]}^{C\top}\),全文皆如此。

\[\begin{aligned} \mathbf{S}_{[t]}^{C} &= \mathbf{S}_{[t-1]}^{C} \odot \exp(\boldsymbol{\gamma}_{[t]}^{C} \boldsymbol{1}^\top) + \left( \boldsymbol{K}_{[t]}^\top \odot \exp\left(\boldsymbol{\gamma}_{[t]}^{C} \boldsymbol{1}^\top - \mathbf{\Gamma}_{[t]}^\top \right) \right) \mathbf{V}_{[t]} \end{aligned}\]

对应代码fla.ops.common.chunk_h -> chunk_fwd_h

Step 3:计算 chunk 内注意力矩阵 A(非对角子块部分)

仅在第 \([t]\) 块内讨论。令 \(\mathbf{A}_{[i,j]}\) 表示第 \([i,j]\) 个子块,\(\mathbf{\Gamma}_{[i]}\) 表示第 \([i]\) 子块的 cumsum 值,\(\boldsymbol{\gamma}_{[i]}^{k}\) 表示第 \([i]\) 子块的第 \(k\) 个元素。

\[\begin{aligned} \mathbf{A}_{[i,j]} = \left(\mathbf{Q}_{[i]} \odot \exp\left(\mathbf{\Gamma}_{[i]} - \boldsymbol{\gamma}_{[i]}^1 \boldsymbol{1}^\top \right) \right) \left(\mathbf{K}_{[j]}^\top \odot \exp\left(\boldsymbol{\gamma}_{[i]}^{1\top} \boldsymbol{1} - \mathbf{\Gamma}_{[j]}^{\top} \right) \right) ,\quad i \gt j \end{aligned}\]

对应代码fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_inter

Step 4:计算 chunk 内注意力矩阵 A(对角子块部分)

仅在第 \([t]\) 块内讨论。符号定义同 Step 3。

对角子块上,\(\mathbf{A}\) 采用逐列计算(代码实现),为简便起见,此处以逐元素形式表述。

\[\begin{aligned} \mathbf{A}_{[i,j]}^{k_1, k_2} = \sum_{d} \left(\mathbf{Q}_{[i]}^{k_1} \odot \mathbf{K}_{[j]}^{k_2} \odot \exp\left(\mathbf{\Gamma}_{[i]}^{k_1} - \mathbf{\Gamma}_{[j]}^{k_2} \right) \right) , \quad i = j, k_1 \geq k_2 \end{aligned}\]

对应代码fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra

如果沿 \(d\) 维切成块分别计算后再合并,则使用以下两个函数,其数学内容与上述函数基本一致:

fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra_split fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra_merge


Step 5:计算最终输出

由于在 chunk_gla_fwd_intra_gk 中,\(\mathbf{A}\) 使用 torch.empty 初始化(上三角子块区域包含垃圾值),因此在计算输出时需要再做一次下三角掩码。

\[\begin{aligned} \mathbf{O}_{[t]} = (\mathbf{Q}_{[t]} \odot \exp\mathbf{\Gamma}_{[t]}) \mathbf{S}_{[t-1]}^{C} + (\mathbf{A}_{[t]} \odot \mathbf{M}) \mathbf{V}_{[t]} \end{aligned}\]

对应代码fla.ops.gla.chunk -> chunk_gla_fwd_o_gk

二、反向传播(Backward)

入口函数签名:

def chunk_gla_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: torch.Tensor | None,
    scale: float,
    initial_state: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    chunk_indices: torch.LongTensor | None = None,
):
    return dq, dk, dv, dg, dh0

Step 1:计算 log 域累积门控

同前向 Step 1。

对应代码fla.ops.utils.cumsum -> chunk_local_cumsum

Step 2:递推计算隐藏状态(重计算)

同前向 Step 2。

对应代码fla.ops.common.chunk_h -> chunk_fwd_h

Step 3:反向递推计算状态梯度

\[\begin{aligned} \delta \mathbf{S}_{[t]}^{C} = \delta \mathbf{S}_{[t+1]}^{C} \odot \exp (\boldsymbol{\gamma}_{[t+1]}^{C} \boldsymbol{1}^\top) + \left(\mathbf{Q}_{[t+1]}^\top \odot \exp(\mathbf{\Gamma}_{[t+1]}^\top)\right) \delta \mathbf{O}_{[t+1]} \end{aligned}\]

对应代码fla.ops.common.chunk_h -> chunk_bwd_dh

Step 4:计算 V 的梯度

\[\begin{aligned} \delta \mathbf{V}_{[t]} = \left(\mathbf{A}_{[t]}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]} + \left(\mathbf{K}_{[t]} \odot \exp(\boldsymbol{\gamma}_{[t]}^C - \mathbf{\Gamma}_{[t]}) \right) \delta \mathbf{S}_{[t]}^{C} \end{aligned}\]

对应代码fla.ops.gla.chunk -> chunk_gla_bwd_dv

Step 5:计算 A 的梯度

\[\begin{aligned} \delta \mathbf{A}_{[t]} = \delta \mathbf{O}_{[t]} \mathbf{V}_{[t]}^{\top} \odot \mathbf{M} \end{aligned}\]

对应代码fla.ops.gla.chunk -> chunk_gla_bwd_dA

Step 6:计算 Q、K 的 intra-chunk 梯度

仅在第 \([t]\) 块内讨论。令 \(\mathbf{A}_{[i,j]}\) 表示第 \([i,j]\) 个子块,\(\boldsymbol{\gamma}_{[i]}^{k}\) 表示第 \([i]\) 子块第 \(k\) 个元素的 cumsum 值,\(\boldsymbol{\gamma}_{[i]}^{C_i}\) 表示第 \([i]\) 子块最后一个元素的 cumsum 值。

\(\delta \mathbf{Q}\) 的 intra-chunk 梯度

\[\begin{aligned} \left.\delta \mathbf{Q}_{[i]}\right|_{\text{from } \mathbf{A}_{[i,j]}} &= \delta \mathbf{A}_{[i,j]} \left(\mathbf{K}_{[j]} \odot \exp\left( \boldsymbol{\gamma}_{[i]}^1 \boldsymbol{1}^\top - \mathbf{\Gamma}_{[j]} \right) \right) \odot \exp\left( \mathbf{\Gamma}_{[i]} - \boldsymbol{\gamma}_{[i]}^1 \boldsymbol{1}^\top \right) ,\quad i \gt j \\ \\ \left.\delta \mathbf{Q}_{[i]}^{k_1}\right|_{\text{from } \mathbf{A}_{[i,j]}} &= \sum_{k_2 \leq k_1} \delta \mathbf{A}_{[i,j]}^{k_1, k_2} \mathbf{K}_{[j]}^{k_2} \odot \exp\left( \boldsymbol{\gamma}_{[i]}^{k_1} - \boldsymbol{\gamma}_{[j]}^{k_2} \right) ,\quad i = j \end{aligned}\]

\(\delta \mathbf{K}\) 的 intra-chunk 梯度

\[\begin{aligned} \left.\delta \mathbf{K}_{[j]}\right|_{\text{from } \mathbf{A}_{[i,j]}} &= \delta \mathbf{A}_{[i,j]}^\top \left(\mathbf{Q}_{[i]} \odot \exp\left( \mathbf{\Gamma}_{[i]} - \boldsymbol{\gamma}_{[i]}^{C_i} \boldsymbol{1}^\top \right) \right) \odot \exp\left( \boldsymbol{\gamma}_{[i]}^{C_i} \boldsymbol{1}^\top - \mathbf{\Gamma}_{[j]} \right) ,\quad i \gt j \\ \\ \left.\delta \mathbf{K}_{[i]}^{k_2}\right|_{\text{from } \mathbf{A}_{[i,j]}} &= \sum_{k_1 \geq k_2} \delta \mathbf{A}_{[i,j]}^{k_1, k_2 \top} \mathbf{Q}_{[i]}^{k_1} \odot \exp\left( \boldsymbol{\gamma}_{[i]}^{k_1} - \boldsymbol{\gamma}_{[i]}^{k_2} \right) ,\quad i = j \end{aligned}\]

对应代码fla.ops.gla.chunk -> chunk_gla_bwd_dqk_intra

Step 7:计算 inter-chunk 梯度并合并,得到最终 dQ、dK、dg

本步骤将 inter-chunk 贡献与 Step 6 的 intra-chunk 贡献合并,并计算门控值的梯度。

\[\begin{aligned} \delta \mathbf{\gamma}_{[t]}^C &= (\delta \exp\mathbf{\gamma}_{[t]}) \odot \exp \boldsymbol{\gamma}_{[t]}^C \\&= \left(\sum_{d_v} \mathbf{S}_{[t-1]}^{C\top} \odot \delta \mathbf{S}_{[t]}^{C\top} \right) \odot \exp \boldsymbol{\gamma}_{[t]}^C + \sum_{r} \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]}} \odot \mathbf{K}_{[t]} \quad \text{in float32} \\ \\ \left.\delta \mathbf{Q}_{[t]}\right|_{\text{from } \mathbf{O}_{[t], \text{part 1}}} &= \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \odot \exp\mathbf{\Gamma}_{[t]} \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]}} &= \mathbf{V}_{[t]} \delta \mathbf{S}_{[t-1]}^{C\top} \odot \exp\left(\boldsymbol{\gamma}_{[t]}^C - \mathbf{\Gamma}_{[t]}\right) \\ \\ \delta \mathbf{Q}_{[t]} &= \left.\delta \mathbf{Q}_{[t]}\right|_{\text{from } \mathbf{O}_{[t], \text{part 1}}} + \left.\delta \mathbf{Q}_{[t]}\right|_{\text{from } \mathbf{A}_{[t]}} \\ \\ \delta \mathbf{K}_{[t]} &= \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]}} + \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{A}_{[t]}} \\ \\ \left.\delta \mathbf{\Gamma}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} &= \delta \mathbf{Q}_{[t]} \odot \mathbf{Q}_{[t]} - \delta \mathbf{K}_{[t]} \odot \mathbf{K}_{[t]} \\ \\ \delta \log \boldsymbol{\alpha}_{[t]} &= \text{suffix\_cumsum}(\left.\delta\mathbf{\Gamma}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}}) + \delta \mathbf{\gamma}_{[t]}^C \end{aligned}\]

对应代码fla.ops.gla.chunk -> chunk_gla_bwd_dqkg

Comments