跳转至

阅读笔记: Gated Delta Networks

原文链接https://arxiv.org/pdf/2412.06464
代码链接https://github.com/fla-org/flash-linear-attention
声明:本文为个人阅读笔记。所有来自我自己的推导都可能存在错误。

一、动机

评论:感觉这篇论文的 Introduction 写得很好。

  1. 标准注意力的二次方计算复杂度催生了线性注意力的研究方向。
  2. 尽管最新的线性注意力工作在性能上有所提升,但在上下文检索(in-context retrieval)任务上仍然表现不佳。
  3. 这并不意外,因为隐状态(state space)的容量是有限的。
  4. Mamba2 采用全局衰减(decay)来管理记忆,但它无法为每个 KV 对单独遗忘。
  5. Delta rule 可以逐个处理 KV 的遗忘,但无法迅速抹除过去的记忆。
  6. 两者具有天然的互补性,因此作者设计了 Gated Delta Rule,结合全局衰减与逐 KV 遗忘。
  7. 剩下的挑战在于如何高效实现。作者在 chunk-wise 并行算法和 Triton kernel 优化方面做了大量工作。
  8. 最终完成的 Gated DeltaNet 架构在多项基准上取得了优异表现。

二、符号约定

  1. 使用 \(\mathbf{S, Q}\) 等粗体大写字母表示矩阵
  2. 使用 \(\mathbf{q}_t, \mathbf{k}_t\) 等表示列向量(即 \([d, 1]\) 的形式),矩阵则是 \([L, d]\) 的形式,因此会有额外的转置操作
  3. 使用 \(W_t\) 等表示可学习参数
  4. 使用 \(\mathbf{q}_t\) 表示 \(\mathbf{Q}\) 的第 \(t\)
  5. \(\square_{[t]} = \square_{[t]}^{1:C} \in \mathbb{R}^{C \times d}\) 表示第 \(t\) 个 chunk,其中 \(\square \in { \mathbf{Q, K, V, \dots} }\)

三、在线学习视角

3.1 数学预备知识

我们先回顾几个基本恒等式。

\[\begin{aligned} \|\mathbf{A}\|_F^2 &= \sum_{i,j} a_{ij}^2 = \text{Tr}(\mathbf{A}^\top \mathbf{A}) ,\quad \langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle = \text{Tr}(\boldsymbol{v}_t^\top \boldsymbol{k}_t) ,\quad \|\boldsymbol{x}\|^2 = \sum_i x_i^2 = \boldsymbol{x}^\top \boldsymbol{x} \\ \\ d \left(\text{Tr}(\mathbf{A}^\top \mathbf{A})\right) &= \text{Tr}(d(\mathbf{A}^\top \mathbf{A})) = \text{Tr}((d\mathbf{A})^\top \mathbf{A} + \mathbf{A}^\top (d\mathbf{A})) = \text{Tr}(2\mathbf{A}^\top (d\mathbf{A})) \\ \\ d(\boldsymbol{x}^\top \boldsymbol{y}) &= (d\boldsymbol{x})^\top \boldsymbol{y} + \boldsymbol{x}^\top (d\boldsymbol{y}) \Rightarrow d(\boldsymbol{x}^\top \boldsymbol{x}) = 2 \boldsymbol{x}^\top (d\boldsymbol{x}) \end{aligned}\]

由此可以得到:

\[\begin{aligned} d \|\mathbf{A}\|_F^2 &= \text{Tr}( (\frac{\partial \|\mathbf{A}\|_F^2}{\partial \mathbf{A}})^\top d\mathbf{A}) \Rightarrow \frac{\partial \|\mathbf{A}\|_F^2}{\partial \mathbf{A}} = 2 \mathbf{A} \\ \\ d \|\boldsymbol{x}\|^2 &= \text{Tr}( (\frac{\partial \|\boldsymbol{x}\|^2}{\partial \boldsymbol{x}})^\top d\boldsymbol{x}) \Rightarrow \frac{\partial \|\boldsymbol{x}\|^2}{\partial \boldsymbol{x}} = 2 \boldsymbol{x} \\ \\ d \langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle &= \text{Tr}((\frac{\partial \langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle}{\partial \boldsymbol{k}_t})^\top d \boldsymbol{k}_t) \Rightarrow \frac{\partial \langle\boldsymbol{k}_t, \boldsymbol{v}_t\rangle}{\partial \boldsymbol{k}_t} = \boldsymbol{v}_t \end{aligned}\]

此外,还会用到 Sherman–Morrison 公式

\[\begin{aligned} (I + \boldsymbol{u}\boldsymbol{v}^\top)^{-1} = I - \frac{\boldsymbol{u}\boldsymbol{v}^\top}{1 + \boldsymbol{v}^\top\boldsymbol{u}} \end{aligned}\]

3.2 推导

3.2.1 Longhorn

先看 Longhorn 的目标函数:

\[\begin{aligned} L &= \|\mathbf{S}_t - \mathbf{S}_{t-1}\|_F^2 + \beta_t \|\mathbf{S}_t \boldsymbol{k}_t - \boldsymbol{v}_t\|^2 \\ \\ \Rightarrow \frac{\partial}{\partial \mathbf{S}_t} L &= 2 (\mathbf{S}_t - \mathbf{S}_{t-1}) + 2\beta_t (\mathbf{S}_t \boldsymbol{k}_t - \boldsymbol{v}_t) \boldsymbol{k}_t^\top = 0 \\ \\ \Rightarrow \mathbf{S}_t (\mathbf{I} + \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) &= \mathbf{S}_{t-1} + \beta_t\boldsymbol{v}_t \boldsymbol{k}_t^\top \\ \\ \Rightarrow \mathbf{S}_t &= \mathbf{S}_{t-1} (\mathbf{I} - \frac{ \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top}{1 + \beta_t \boldsymbol{k}_t^\top \boldsymbol{k}_t} ) + \beta_t\boldsymbol{v}_t \boldsymbol{k}_t^\top (\mathbf{I} - \frac{ \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top}{1 + \beta_t \boldsymbol{k}_t^\top \boldsymbol{k}_t} ) \\&= \mathbf{S}_{t-1} (\mathbf{I} - \epsilon_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) + \epsilon_t \boldsymbol{v}_t \boldsymbol{k}_t^\top + \epsilon_t \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top (\boldsymbol{k}_t^\top \boldsymbol{k}_t - \boldsymbol{k}_t \boldsymbol{k}_t^\top) \\&= \mathbf{S}_{t-1} (\mathbf{I} - \epsilon_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) + \epsilon_t \boldsymbol{v}_t \boldsymbol{k}_t^\top \\ \epsilon_t &= \frac{\beta_t}{1 + \beta_t \boldsymbol{k}_t^\top \boldsymbol{k}_t} \end{aligned}\]

评论:这里的 beta_t 看起来和原论文中的定义并不完全一致。

3.2.2 Mamba2

接着,对于 Mamba2,有:

\[\begin{aligned} L &= \|\mathbf{S}_t - \alpha_t \mathbf{S}_{t-1}\|_F^2 - 2 \langle \mathbf{S}_t \boldsymbol{k}_t, \boldsymbol{v}_t\rangle \\ \\ \Rightarrow \frac{\partial}{\partial \mathbf{S}_t} L &= 2 (\mathbf{S}_t - \alpha_t \mathbf{S}_{t-1}) - 2 \boldsymbol{v}_t \boldsymbol{k}_t^\top \Rightarrow \mathbf{S}_t = \alpha_t \mathbf{S}_{t-1} + \boldsymbol{v}_t \boldsymbol{k}_t^\top \end{aligned}\]

根据论文中的表格,LADeltaNetGated DeltaNet 都可以按和 Mamba2 类似的方式推导出来。

四、Gated Delta Rule

4.1 前向传播

先回顾 Gated DeltaNet 的更新公式:

\[\begin{aligned} \mathbf{S}_t = \mathbf{S}_{t-1}\alpha_t (\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) + \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top \end{aligned}\]

接着,仿照 DeltaNet 和 GLA 中的推导,定义:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} &= \prod_{i=t C + 1}^{t C + r}(\mathbf{I} - \beta_{i} \boldsymbol{k}_{i} \boldsymbol{k}_{i}^{\top}) \in \mathbb{R}^{d \times d} ,\quad \gamma_{[t]}^{r} = \prod_{i=t C + 1}^{t C + r} \alpha_i \\ \\ \mathbf{H}_{[t]}^{r} &= \sum_{i=tC + 1}^{tC + r} \beta_{i} (\boldsymbol{v}_{i} \boldsymbol{k}_{i}^{\top}) \frac{\gamma_{[t]}^{r}}{\gamma_{[t]}^{i}} \left( \prod_{j=i + 1}^{t C + r}(\mathbf{I} - \beta_{j} \boldsymbol{k}_{j} \boldsymbol{k}_{j}^{\top}) \right) \in \mathbb{R}^{d \times d} \end{aligned}\]

于是,分块后的状态可以写成:

\[\begin{aligned} \mathbf{S}_{[t]}^{r} = \mathbf{S}_{[t-1]}^{C} \gamma_{[t]}^{r} \mathbf{P}_{[t]}^{r} + \mathbf{H}_{[t]}^{r} , \quad \text{where } \mathbf{S}_{[-1]}^{C} = \mathbf{0} \end{aligned}\]

另一方面,假设:

\[\begin{aligned} \mathbf{S}_t = \eta_t \sum_{i=1}^{t} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top \end{aligned}\]

那么由数学归纳法可得:

\[\begin{aligned} \mathbf{S}_t &= \alpha_t \mathbf{S}_{t-1} + \beta_t (\boldsymbol{v}_t - \alpha_t \mathbf{S}_{t-1} \boldsymbol{k}_t )\boldsymbol{k}_t^\top \\&= \alpha_t \eta_{t-1} \sum_{i=1}^{t-1} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top + \beta_t \left( \boldsymbol{v}_t - \alpha_t \eta_{t-1} \sum_{i=1}^{t-1} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top \boldsymbol{k}_t \right) \boldsymbol{k}_t^\top = \eta_t \sum_{i=1}^{t} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top \end{aligned}\]

其中:

\[\begin{aligned} \eta_t = \alpha_t \eta_{t-1} = \gamma_t ,\quad \boldsymbol{u^0}_t = \frac{\beta_t}{\eta_t} \boldsymbol{v}_t - \beta_t \sum_{i=1}^{t-1} \boldsymbol{u^0}_i \boldsymbol{k}_i^\top \boldsymbol{k}_t \end{aligned}\]

等价地,也可以写成:

\[\begin{aligned} \mathbf{S}_t = \sum_{i=1}^{t} \frac{\gamma_t}{\gamma_i} \boldsymbol{u}_i \boldsymbol{k}_i^\top ,\quad \boldsymbol{u}_t = \beta_t \left(\boldsymbol{v}_t - \sum_{i=1}^{t-1} \frac{\gamma_t}{\gamma_i} \boldsymbol{u}_i \boldsymbol{k}_i^\top \boldsymbol{k}_t \right) \end{aligned}\]

另一方面,形如 (I - beta_t k_t k_t^T) 的 Householder 变换乘积总可以写成低秩形式,也就是 WY representation。根据 DeltaNet 笔记,可以得到:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} = \mathbf{I} - \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} ,\quad \boldsymbol{w}_{[t]}^{r} = \beta_{[t]}^{r} \boldsymbol{k}_{[t]}^{r} - \beta_{[t]}^{r} \sum_{i=1}^{r-1} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r} \end{aligned}\]

把它代回到 chunkwise 递推式中,就有:

\[\begin{aligned} \mathbf{S}_{[t]}^{r} &= \mathbf{S}_{[t-1]}^{C} \gamma_{[t]}^{r} \left(\mathbf{I} - \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}\right) + \sum_{i=1}^{r} \frac{\gamma_{[t]}^{r}}{\gamma_{[t]}^{i}} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i\top} \\&= \gamma_{[t]}^{r} \mathbf{S}_{[t-1]}^{C} + \gamma_{[t]}^{r} \sum_{i=1}^{r} \left(\boldsymbol{u}_{[t]}^{i} - \mathbf{S}_{[t-1]}^{C} \gamma_{[t]}^{i} \boldsymbol{w}_{[t]}^{i} \right)\frac{\boldsymbol{k}_{[t]}^{i\top}}{\gamma_{[t]}^{i}} \\ \\ \boldsymbol{o}_{[t]}^{r} &= \mathbf{S}_{[t]}^{r} \boldsymbol{q}_{[t]}^{r} = \mathbf{S}_{[t-1]}^{C} \boldsymbol{q}_{[t]}^{r} \gamma_{[t]}^{r} + \sum_{i=1}^{r} \left(\boldsymbol{u}_{[t]}^{i} - \mathbf{S}_{[t-1]}^{C} \gamma_{[t]}^{i} \boldsymbol{w}_{[t]}^{i} \right)\frac{\boldsymbol{k}_{[t]}^{i\top}}{\gamma_{[t]}^{i}} \boldsymbol{q}_{[t]}^{r} \gamma_{[t]}^{r} \end{aligned}\]

现在再定义:

\[\begin{aligned} \overleftarrow{\square_{[t]}} &= \text{Diag}(\boldsymbol{\gamma}_{[t]}) \square_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \text{Diag}(\boldsymbol{\gamma}_{[t]})^{-1} \square_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{W}\} \end{aligned}\]

于是,整体计算可以进一步写成矩阵形式:

\[\begin{aligned} \mathbf{S}_{[t]}^{C} &= \gamma_{[t]}^{C} \mathbf{S}_{[t-1]}^{C} + \gamma_{[t]}^{C} \left(\mathbf{U}_{[t]}^\top - \mathbf{S}_{[t-1]}^{C} \overleftarrow{\mathbf{W}_{[t]}}^\top \right)\overrightarrow{\mathbf{K}_{[t]}} \\ \\ \mathbf{O}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t-1]}^{C \top} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \left(\mathbf{U}_{[t]} - \overleftarrow{\mathbf{W}_{[t]}} \mathbf{S}_{[t-1]}^{C \top} \right) \end{aligned}\]

评论:这里的 O 与原论文中的表达似乎并不一致。

接下来处理 u_[t]w_[t]。其中,M_{-1} = M - I 表示主对角线为零的严格下三角掩码。

u_[t] 的递推式出发,可以得到:

\[\begin{aligned} \boldsymbol{u}_{[t]}^{r} &= \beta_{[t]}^{r} \left( \boldsymbol{v}_{[t]}^{r} - \sum_{i=1}^{r-1} \frac{\gamma_{[t]}^{r}}{\gamma_{[t]}^{i}} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r} \right) \\ \\ \Rightarrow \mathbf{U}_{[t]} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]} - \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{U}_{[t]} \\ \\ \Rightarrow \mathbf{U}_{[t]} &= \left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]} \end{aligned}\]

同理,也有:

\[\begin{aligned} \mathbf{W}_{[t]} = \left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \end{aligned}\]

因此,要计算 U_[t]W_[t],关键就在于下面两个量:

\[\begin{aligned} \mathbf{T}_{[t]} = \left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \end{aligned}\]

以及:

\[\begin{aligned} \mathbf{\widetilde{T}}_{[t]} = \left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \end{aligned}\]

需要注意的是,这个表达式与原论文中的写法是等价的:

\[\begin{aligned} \mathbf{\widetilde{T}}_{[t]} = \left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{\Gamma}_{[t]} \right) \odot \mathbf{M}_{-1} \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) ,\quad (\mathbf{\Gamma}_{[t]})_{ij} = \frac{\gamma_i}{\gamma_j} \end{aligned}\]

矩阵求逆的部分可以按照 DeltaNet 笔记中的方法处理。

五、网络结构

六、实验

6.1 语言建模

6.2 上下文内检索

6.3 长序列上的长度外推

6.4 长上下文理解

Comments