阅读笔记: Gated Delta Networks
原文链接:https://arxiv.org/pdf/2412.06464
代码链接:https://github.com/fla-org/flash-linear-attention
声明:本文为个人阅读笔记。所有来自我自己的推导都可能存在错误。
一、动机
评论:感觉这篇论文的 Introduction 写得很好。
- 标准注意力的二次方计算复杂度催生了线性注意力的研究方向。
- 尽管最新的线性注意力工作在性能上有所提升,但在上下文检索(in-context retrieval)任务上仍然表现不佳。
- 这并不意外,因为隐状态(state space)的容量是有限的。
- Mamba2 采用全局衰减(decay)来管理记忆,但它无法为每个 KV 对单独遗忘。
- Delta rule 可以逐个处理 KV 的遗忘,但无法迅速抹除过去的记忆。
- 两者具有天然的互补性,因此作者设计了 Gated Delta Rule,结合全局衰减与逐 KV 遗忘。
- 剩下的挑战在于如何高效实现。作者在 chunk-wise 并行算法和 Triton kernel 优化方面做了大量工作。
- 最终完成的 Gated DeltaNet 架构在多项基准上取得了优异表现。
二、符号约定
- 使用 \(\mathbf{S, Q}\) 等粗体大写字母表示矩阵
- 使用 \(\mathbf{q}_t, \mathbf{k}_t\) 等表示列向量(即 \([d, 1]\) 的形式),矩阵则是 \([L, d]\) 的形式,因此会有额外的转置操作
- 使用 \(W_t\) 等表示可学习参数
- 使用 \(\mathbf{q}_t\) 表示 \(\mathbf{Q}\) 的第 \(t\) 行
- \(\square_{[t]} = \square_{[t]}^{1:C} \in \mathbb{R}^{C \times d}\) 表示第 \(t\) 个 chunk,其中 \(\square \in { \mathbf{Q, K, V, \dots} }\)
三、在线学习视角

3.1 数学预备知识
我们先回顾几个基本恒等式。
\[\begin{aligned}
\|\mathbf{A}\|_F^2 &= \sum_{i,j} a_{ij}^2 = \text{Tr}(\mathbf{A}^\top \mathbf{A})
,\quad
\langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle = \text{Tr}(\boldsymbol{v}_t^\top \boldsymbol{k}_t)
,\quad
\|\boldsymbol{x}\|^2 = \sum_i x_i^2 = \boldsymbol{x}^\top \boldsymbol{x}
\\
\\
d \left(\text{Tr}(\mathbf{A}^\top \mathbf{A})\right) &= \text{Tr}(d(\mathbf{A}^\top \mathbf{A})) = \text{Tr}((d\mathbf{A})^\top \mathbf{A} + \mathbf{A}^\top (d\mathbf{A}))
=
\text{Tr}(2\mathbf{A}^\top (d\mathbf{A}))
\\
\\
d(\boldsymbol{x}^\top \boldsymbol{y}) &= (d\boldsymbol{x})^\top \boldsymbol{y} + \boldsymbol{x}^\top (d\boldsymbol{y})
\Rightarrow
d(\boldsymbol{x}^\top \boldsymbol{x}) = 2 \boldsymbol{x}^\top (d\boldsymbol{x})
\end{aligned}\]
由此可以得到:
\[\begin{aligned}
d \|\mathbf{A}\|_F^2 &= \text{Tr}( (\frac{\partial \|\mathbf{A}\|_F^2}{\partial \mathbf{A}})^\top d\mathbf{A})
\Rightarrow
\frac{\partial \|\mathbf{A}\|_F^2}{\partial \mathbf{A}}
=
2 \mathbf{A}
\\
\\
d \|\boldsymbol{x}\|^2 &= \text{Tr}( (\frac{\partial \|\boldsymbol{x}\|^2}{\partial \boldsymbol{x}})^\top d\boldsymbol{x})
\Rightarrow
\frac{\partial \|\boldsymbol{x}\|^2}{\partial \boldsymbol{x}}
=
2 \boldsymbol{x}
\\
\\
d \langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle
&=
\text{Tr}((\frac{\partial \langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle}{\partial \boldsymbol{k}_t})^\top d \boldsymbol{k}_t)
\Rightarrow
\frac{\partial \langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle}{\partial \boldsymbol{k}_t} = \boldsymbol{v}_t
\end{aligned}\]
此外,还会用到 Sherman–Morrison 公式:
\[\begin{aligned}
(I + \boldsymbol{u}\boldsymbol{v}^\top)^{-1} = I - \frac{\boldsymbol{u}\boldsymbol{v}^\top}{1 + \boldsymbol{v}^\top\boldsymbol{u}}
\end{aligned}\]
3.2 推导
3.2.1 Longhorn
先看 Longhorn 的目标函数:
\[\begin{aligned}
L &= \|\mathbf{S}_t - \mathbf{S}_{t-1}\|_F^2 + \beta_t \|\mathbf{S}_t \boldsymbol{k}_t - \boldsymbol{v}_t\|^2
\\
\\
\Rightarrow
\frac{\partial}{\partial \mathbf{S}_t} L
&=
2 (\mathbf{S}_t - \mathbf{S}_{t-1})
+ 2\beta_t (\mathbf{S}_t \boldsymbol{k}_t - \boldsymbol{v}_t) \boldsymbol{k}_t^\top
= 0
\\
\\
\Rightarrow
\mathbf{S}_t (\mathbf{I} + \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) &= \mathbf{S}_{t-1} + \beta_t\boldsymbol{v}_t \boldsymbol{k}_t^\top
\\
\\
\Rightarrow
\mathbf{S}_t
&= \mathbf{S}_{t-1} (\mathbf{I} - \frac{ \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top}{1 + \beta_t \boldsymbol{k}_t^\top \boldsymbol{k}_t} ) + \beta_t\boldsymbol{v}_t \boldsymbol{k}_t^\top (\mathbf{I} - \frac{ \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top}{1 + \beta_t \boldsymbol{k}_t^\top \boldsymbol{k}_t} )
\\&=
\mathbf{S}_{t-1} (\mathbf{I} - \epsilon_t \boldsymbol{k}_t \boldsymbol{k}_t^\top)
+ \epsilon_t \boldsymbol{v}_t \boldsymbol{k}_t^\top
+ \epsilon_t \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top (\boldsymbol{k}_t^\top \boldsymbol{k}_t - \boldsymbol{k}_t \boldsymbol{k}_t^\top)
\\&=
\mathbf{S}_{t-1} (\mathbf{I} - \epsilon_t \boldsymbol{k}_t \boldsymbol{k}_t^\top)
+ \epsilon_t \boldsymbol{v}_t \boldsymbol{k}_t^\top
\\
\epsilon_t &= \frac{\beta_t}{1 + \beta_t \boldsymbol{k}_t^\top \boldsymbol{k}_t}
\end{aligned}\]
评论:这里的 beta_t 看起来和原论文中的定义并不完全一致。
3.2.2 Mamba2
接着,对于 Mamba2,有:
\[\begin{aligned}
L &= \|\mathbf{S}_t - \alpha_t \mathbf{S}_{t-1}\|_F^2 - 2 \langle \mathbf{S}_t \boldsymbol{k}_t, \boldsymbol{v}_t\rangle
\\
\\
\Rightarrow
\frac{\partial}{\partial \mathbf{S}_t} L
&=
2 (\mathbf{S}_t - \alpha_t \mathbf{S}_{t-1})
-
2 \boldsymbol{v}_t \boldsymbol{k}_t^\top
\Rightarrow
\mathbf{S}_t = \alpha_t \mathbf{S}_{t-1} + \boldsymbol{v}_t \boldsymbol{k}_t^\top
\end{aligned}\]
根据论文中的表格,LA、DeltaNet 和 Gated DeltaNet 都可以按和 Mamba2 类似的方式推导出来。
四、Gated Delta Rule
4.1 前向传播
先回顾 Gated DeltaNet 的更新公式:
\[\begin{aligned}
\mathbf{S}_t
=
\mathbf{S}_{t-1}\alpha_t (\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) + \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top
\end{aligned}\]
接着,仿照 DeltaNet 和 GLA 中的推导,定义:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r} &= \prod_{i=t C + 1}^{t C + r}(\mathbf{I} - \beta_{i} \boldsymbol{k}_{i} \boldsymbol{k}_{i}^{\top})
\in \mathbb{R}^{d \times d}
,\quad
\gamma_{[t]}^{r} = \prod_{i=t C + 1}^{t C + r} \alpha_i
\\
\\
\mathbf{H}_{[t]}^{r} &= \sum_{i=tC + 1}^{tC + r} \beta_{i} (\boldsymbol{v}_{i} \boldsymbol{k}_{i}^{\top})
\frac{\gamma_{[t]}^{r}}{\gamma_{[t]}^{i}}
\left(
\prod_{j=i + 1}^{t C + r}(\mathbf{I} - \beta_{j} \boldsymbol{k}_{j} \boldsymbol{k}_{j}^{\top})
\right)
\in \mathbb{R}^{d \times d}
\end{aligned}\]
于是,分块后的状态可以写成:
\[\begin{aligned}
\mathbf{S}_{[t]}^{r}
=
\mathbf{S}_{[t-1]}^{C}
\gamma_{[t]}^{r}
\mathbf{P}_{[t]}^{r} + \mathbf{H}_{[t]}^{r}
, \quad
\text{where }
\mathbf{S}_{[-1]}^{C} = \mathbf{0}
\end{aligned}\]
另一方面,假设:
\[\begin{aligned}
\mathbf{S}_t = \eta_t \sum_{i=1}^{t} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top
\end{aligned}\]
那么由数学归纳法可得:
\[\begin{aligned}
\mathbf{S}_t &= \alpha_t \mathbf{S}_{t-1} + \beta_t (\boldsymbol{v}_t - \alpha_t \mathbf{S}_{t-1} \boldsymbol{k}_t )\boldsymbol{k}_t^\top
\\&=
\alpha_t \eta_{t-1} \sum_{i=1}^{t-1} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top
+
\beta_t \left( \boldsymbol{v}_t - \alpha_t \eta_{t-1} \sum_{i=1}^{t-1} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top \boldsymbol{k}_t \right) \boldsymbol{k}_t^\top
=
\eta_t \sum_{i=1}^{t} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top
\end{aligned}\]
其中:
\[\begin{aligned}
\eta_t = \alpha_t \eta_{t-1} = \gamma_t
,\quad
\boldsymbol{u^0}_t = \frac{\beta_t}{\eta_t} \boldsymbol{v}_t - \beta_t \sum_{i=1}^{t-1} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top \boldsymbol{k}_t
\end{aligned}\]
等价地,也可以写成:
\[\begin{aligned}
\mathbf{S}_t = \sum_{i=1}^{t} \frac{\gamma_t}{\gamma_i} \boldsymbol{u}_i \boldsymbol{k}_i^\top
,\quad
\boldsymbol{u}_t = \beta_t \left(\boldsymbol{v}_t - \sum_{i=1}^{t-1} \frac{\gamma_t}{\gamma_i} \boldsymbol{u}_i \boldsymbol{k}_i^\top \boldsymbol{k}_t \right)
\end{aligned}\]
另一方面,形如 (I - beta_t k_t k_t^T) 的 Householder 变换乘积总可以写成低秩形式,也就是 WY representation。根据 DeltaNet 笔记,可以得到:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r} = \mathbf{I} - \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}
,\quad
\boldsymbol{w}_{[t]}^{r} = \beta_{[t]}^{r} \boldsymbol{k}_{[t]}^{r} - \beta_{[t]}^{r} \sum_{i=1}^{r-1} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r}
\end{aligned}\]
把它代回到 chunkwise 递推式中,就有:
\[\begin{aligned}
\mathbf{S}_{[t]}^{r}
&=
\mathbf{S}_{[t-1]}^{C} \gamma_{[t]}^{r} \left(\mathbf{I} - \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}\right)
+
\sum_{i=1}^{r} \frac{\gamma_{[t]}^{r}}{\gamma_{[t]}^{i}} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i\top}
\\&=
\gamma_{[t]}^{r} \mathbf{S}_{[t-1]}^{C}
+
\gamma_{[t]}^{r} \sum_{i=1}^{r} \left(\boldsymbol{u}_{[t]}^{i} - \mathbf{S}_{[t-1]}^{C} \gamma_{[t]}^{i} \boldsymbol{w}_{[t]}^{i} \right)\frac{\boldsymbol{k}_{[t]}^{i\top}}{\gamma_{[t]}^{i}}
\\
\\
\boldsymbol{o}_{[t]}^{r}
&=
\mathbf{S}_{[t]}^{r} \boldsymbol{q}_{[t]}^{r}
=
\mathbf{S}_{[t-1]}^{C} \boldsymbol{q}_{[t]}^{r} \gamma_{[t]}^{r}
+
\sum_{i=1}^{r} \left(\boldsymbol{u}_{[t]}^{i} - \mathbf{S}_{[t-1]}^{C} \gamma_{[t]}^{i} \boldsymbol{w}_{[t]}^{i} \right)\frac{\boldsymbol{k}_{[t]}^{i\top}}{\gamma_{[t]}^{i}} \boldsymbol{q}_{[t]}^{r} \gamma_{[t]}^{r}
\end{aligned}\]
现在再定义:
\[\begin{aligned}
\overleftarrow{\square_{[t]}}
&=
\text{Diag}(\boldsymbol{\gamma}_{[t]})
\square_{[t]}
,\quad
\overrightarrow{\square_{[t]}}
=
\text{Diag}(\boldsymbol{\gamma}_{[t]})^{-1}
\square_{[t]}
\quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{W}\}
\end{aligned}\]
于是,整体计算可以进一步写成矩阵形式:
\[\begin{aligned}
\mathbf{S}_{[t]}^{C}
&=
\gamma_{[t]}^{C} \mathbf{S}_{[t-1]}^{C}
+
\gamma_{[t]}^{C} \left(\mathbf{U}_{[t]}^\top - \mathbf{S}_{[t-1]}^{C} \overleftarrow{\mathbf{W}_{[t]}}^\top \right)\overrightarrow{\mathbf{K}_{[t]}}
\\
\\
\mathbf{O}_{[t]}
&=
\overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t-1]}^{C \top}
+
\left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \left(\mathbf{U}_{[t]} - \overleftarrow{\mathbf{W}_{[t]}} \mathbf{S}_{[t-1]}^{C \top} \right)
\end{aligned}\]
评论:这里的 O 与原论文中的表达似乎并不一致。
接下来处理 u_[t] 和 w_[t]。其中,M_{-1} = M - I 表示主对角线为零的严格下三角掩码。
从 u_[t] 的递推式出发,可以得到:
\[\begin{aligned}
\boldsymbol{u}_{[t]}^{r} &= \beta_{[t]}^{r} \left( \boldsymbol{v}_{[t]}^{r} - \sum_{i=1}^{r-1} \frac{\gamma_{[t]}^{r}}{\gamma_{[t]}^{i}} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r} \right)
\\
\\
\Rightarrow
\mathbf{U}_{[t]}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]}
-
\text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{U}_{[t]}
\\
\\
\Rightarrow
\mathbf{U}_{[t]}
&=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]}
\end{aligned}\]
同理,也有:
\[\begin{aligned}
\mathbf{W}_{[t]}
=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]}
\end{aligned}\]
因此,要计算 U_[t] 和 W_[t],关键就在于下面两个量:
\[\begin{aligned}
\mathbf{T}_{[t]}
=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]})
\end{aligned}\]
以及:
\[\begin{aligned}
\mathbf{\widetilde{T}}_{[t]}
=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]})
\end{aligned}\]
需要注意的是,这个表达式与原论文中的写法是等价的:
\[\begin{aligned}
\mathbf{\widetilde{T}}_{[t]}
=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{\Gamma}_{[t]} \right) \odot \mathbf{M}_{-1} \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]})
,\quad
(\mathbf{\Gamma}_{[t]})_{ij} = \frac{\gamma_i}{\gamma_j}
\end{aligned}\]
矩阵求逆的部分可以按照 DeltaNet 笔记中的方法处理。
五、网络结构

六、实验
6.1 语言建模

6.2 上下文内检索

6.3 长序列上的长度外推

6.4 长上下文理解
