阅读笔记:DeltaNet
原文链接:https://arxiv.org/abs/2406.06484
代码链接:https://github.com/fla-org/flash-linear-attention
声明:本文为个人阅读笔记。所有来自我自己的推导都可能存在错误,后续需要对照代码进行交叉验证。
一、动机
- 大部分线性注意力模型不如 Transformer,尤其在需要 In-context Retrieval 的场景下
- Delta Rule 的训练效率不高
二、符号约定
- 使用 \(\mathbf{S, Q}\) 等粗体大写字母表示矩阵
- 使用 \(\mathbf{q}_t, \mathbf{k}_t\) 等表示列向量(即 \([d, 1]\) 的形式),矩阵则是 \([L, d]\) 的形式,因此会有额外的转置操作
- 使用 \(W_t\) 等表示可学习参数
- 使用 \(\mathbf{q}_t\) 表示 \(\mathbf{Q}\) 的第 \(t\) 行
三、背景知识
1. GLA
\[\begin{aligned}
\mathbf{O} = (\mathbf{Q}\mathbf{K}^\top \odot \mathbf{M}) \mathbf{V}
\Leftrightarrow
\boldsymbol{o}_r = \sum_{i=1}^r \boldsymbol{v}_i \boldsymbol{k}_i^\top \boldsymbol{q}_r
\end{aligned}\]
2. DeltaNet
DeltaNet 可以被视作一种 SGD 优化器。
四、DeltaNet 的并行化
4.1 前向传播
评论:推导目标是尽量减少显存占用,并多使用矩阵运算。
假设前人没做过,那么公式推导中的难点在于 Householder 变换的性质。
回顾 DeltaNet 公式
\[\begin{aligned}
\mathbf{S}_t
&=
\mathbf{S}_{t-1} - \boldsymbol{v}_t^{\text{old}} \boldsymbol{k}_t^\top + \boldsymbol{v}_t^{\text{new}} \boldsymbol{k}_t^\top
\\&=
\mathbf{S}_{t-1} - \beta_t (\mathbf{S}_{t-1} \boldsymbol{k}_t) \boldsymbol{k}_t^\top + \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top
\\&=
\mathbf{S}_{t-1}(\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) + \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top
\end{aligned}\]
仿照 GLA,我们定义分块变量:
\[\begin{aligned}
\square_{[t]} &= \square_{[t]}^{1:C} \in \mathbb{R}^{C \times d} \quad\text{for}\quad \square \in \{ \mathbf{Q, K, V,...} \}
\\
\\
\mathbf{P}_{[t]}^{r} &= \prod_{i=t C + 1}^{t C + r}(\mathbf{I} - \beta_{i} \boldsymbol{k}_{i} \boldsymbol{k}_{i}^{\top})
\in \mathbb{R}^{d \times d}
,\quad
\mathbf{H}_{[t]}^{r} = \sum_{i=tC + 1}^{tC + r} \beta_{i} (\boldsymbol{v}_{i} \boldsymbol{k}_{i}^{\top})
\left(
\prod_{j=i + 1}^{t C + r}(\mathbf{I} - \beta_{j} \boldsymbol{k}_{j} \boldsymbol{k}_{j}^{\top})
\right)
\in \mathbb{R}^{d \times d}
\end{aligned}\]
则有:
\[\begin{aligned}
\mathbf{S}_{[t]}^{r}
=
\mathbf{S}_{[t-1]}^{C} \mathbf{P}_{[t]}^{r} + \mathbf{H}_{[t]}^{r}
, \quad
\text{where }
\mathbf{S}_{[-1]}^{C} = \mathbf{0}
\end{aligned}\]
另一方面,假设初始条件:
\[\begin{aligned}
\boldsymbol{u}_1 = \beta_1 \boldsymbol{v}_1 , \quad \mathbf{S}_1 = \beta_1 \boldsymbol{v}_1 \boldsymbol{k}_1^\top
\end{aligned}\]
利用数学归纳法可得:
\[\begin{aligned}
\mathbf{S}_t = \mathbf{S}_{t-1} + \beta_t (\boldsymbol{v}_t - \mathbf{S}_{t-1} \boldsymbol{k}_t )\boldsymbol{k}_t^\top
=
\sum_{i=1}^{t-1} \boldsymbol{u}_i \boldsymbol{k}_i^\top
+
\underbrace{\beta_t \left( \boldsymbol{v}_t - \sum_{i=1}^{t-1} \boldsymbol{u}_i (\boldsymbol{k}_i^\top \boldsymbol{k}_t) \right)}_{\text{defined as } \boldsymbol{u}_t} \boldsymbol{k}_t^\top
=
\sum_{i=1}^{t} \boldsymbol{u}_i \boldsymbol{k}_i^\top
\end{aligned}\]
因此:
\[\begin{aligned}
\mathbf{H}_{[t]}^{r} =
\sum_{i=1}^{r} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i\top}
, \quad
\boldsymbol{u}_{[t]}^{r} = \beta_{[t]}^{r} \left( \boldsymbol{v}_{[t]}^{r} - \sum_{i=1}^{r-1} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r} \right)
\end{aligned}\]
另一方面,Householder 变换(形如 \(\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top\))的乘积一定可以表示成低秩模式(WY Representation),我们不妨假设:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r} = \mathbf{I} - \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}
\end{aligned}\]
利用数学归纳法可得:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r}
=
(\mathbf{I} - \sum_{i=1}^{r - 1} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}) (\mathbf{I} - \beta_{[t]}^{r} \boldsymbol{k}_{[t]}^{r} \boldsymbol{k}_{[t]}^{r \top})
=
\mathbf{I} - \sum_{i=1}^{r - 1} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}
-
\underbrace{\beta_{[t]}^{r} \left(\boldsymbol{k}_{[t]}^{r}
-
\sum_{i=1}^{r - 1} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r} \right) }_{\text{defined as } \boldsymbol{w}_{[t]}^{r}}
\boldsymbol{k}_{[t]}^{r \top}
\end{aligned}\]
于是:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r} = \mathbf{I} - \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}
, \quad
\boldsymbol{w}_{[t]}^{r} = \beta_{[t]}^{r} \left( \boldsymbol{k}_{[t]}^{r} - \sum_{i=1}^{r-1} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r} \right)
\end{aligned}\]
进一步:
\[\begin{aligned}
\mathbf{S}_{[t]}^{r}
&=
\mathbf{S}_{[t-1]}^{C} \left(\mathbf{I} - \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top}\right)
+
\sum_{i=1}^{r} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i\top}
=
\mathbf{S}_{[t-1]}^{C}
+
\sum_{i=1}^{r} \left(\boldsymbol{u}_{[t]}^{i} - \mathbf{S}_{[t-1]}^{C} \boldsymbol{w}_{[t]}^{i} \right)\boldsymbol{k}_{[t]}^{i\top}
\\
\\
\boldsymbol{o}_{[t]}^{r}
&=
\mathbf{S}_{[t]}^{r} \boldsymbol{q}_{[t]}^{r}
=
\mathbf{S}_{[t-1]}^{C} \boldsymbol{q}_{[t]}^{r}
+
\sum_{i=1}^{r} \left(\boldsymbol{u}_{[t]}^{i} - \mathbf{S}_{[t-1]}^{C} \boldsymbol{w}_{[t]}^{i} \right)\boldsymbol{k}_{[t]}^{i\top} \boldsymbol{q}_{[t]}^{r}
\end{aligned}\]
我们可以得到矩阵形式:
\[\begin{aligned}
\mathbf{S}_{[t]}^{C}
&=
\mathbf{S}_{[t-1]}^{C}
+
\left(\mathbf{U}_{[t]}^\top - \mathbf{S}_{[t-1]}^{C} \mathbf{W}_{[t]}^\top \right)\mathbf{K}_{[t]}
\\
\\
\mathbf{O}_{[t]}
&=
\mathbf{Q}_{[t]} \mathbf{S}_{[t-1]}^{C \top}
+
\left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M} \right) \left(\mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right)
\end{aligned}\]
接着处理 \(\boldsymbol{u}_{[t]}, \boldsymbol{w}_{[t]}\) 的计算。其中 \(\mathbf{M}_{-1} = \mathbf{M} - \mathbf{I}\) 表示对角线元素为 0 的严格下三角掩码矩阵。
\[\begin{aligned}
& \boldsymbol{u}_{[t]}^{r} = \beta_{[t]}^{r} \left( \boldsymbol{v}_{[t]}^{r} - \sum_{i=1}^{r-1} \boldsymbol{u}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \boldsymbol{k}_{[t]}^{r} \right)
\\
\\
\Rightarrow &
\mathbf{U}_{[t]}
=
\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]}
-
\text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \mathbf{U}_{[t]}
\\
\\
\Rightarrow &
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right) \mathbf{U}_{[t]}
= \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]}
\\
\\
\Rightarrow &
\mathbf{U}_{[t]}
=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]}
\end{aligned}\]
同理:
\[\begin{aligned}
\mathbf{W}_{[t]}
=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]}
\end{aligned}\]
因此,要求 \(\mathbf{U}_{[t]}, \mathbf{W}_{[t]}\),核心在于计算:
\[\begin{aligned}
\mathbf{T}_{[t]}
=
\left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]})
\end{aligned}\]
4.2 单位下三角矩阵求逆
本小节讨论 \(\mathbf{T}_{[t]}\) 的求解。更一般地,形如 \(\mathbf{A} = \mathbf{I} + \mathbf{L} \in \mathbb{R}^{N \times N}\) 的单位下三角矩阵的求逆问题。流程上可以先对矩阵进行分块,然后对每一块采用 Neumann 级数配合倍增法求逆。
Neumann 级数
\[\begin{aligned}
(\mathbf{I} + \mathbf{L})^{-1}
=
\sum_{n=0}^{\infty} (- \mathbf{L})^n
\end{aligned}\]
由于严格下三角矩阵 \(\mathbf{L}\) 具有幂零性质,即:
\[\begin{aligned}
\mathbf{L}^{C} = \mathbf{0} \quad \forall~ C \gt N
\end{aligned}\]
因此级数有限截断:
\[\begin{aligned}
(\mathbf{I} + \mathbf{L})^{-1}
=
\sum_{n=0}^{N - 1} (- \mathbf{L})^n
\end{aligned}\]
倍增法
定义:
\[\begin{aligned}
\mathbf{S}_{k} = \sum_{n=0}^{2^k-1} (- \mathbf{L})^n
, \quad
\mathbf{G}_{k} = (- \mathbf{L})^{2^k}
\end{aligned}\]
则有递推关系:
\[\begin{aligned}
\mathbf{S}_{k+1} = \mathbf{S}_{k+1}(\mathbf{I} + \mathbf{G}_{k})
, \quad
\mathbf{G}_{k+1} = \mathbf{G}_{k} \mathbf{G}_{k}
\end{aligned}\]
分块矩阵求逆(形式一)
\[\begin{aligned}
\begin{pmatrix}
\mathbf{A}_{11} & \mathbf{0} \\
\mathbf{A}_{21} & \mathbf{A}_{22}
\end{pmatrix}
\begin{pmatrix}
\mathbf{A}_{11}^{-1} & \mathbf{0} \\
-\mathbf{A}_{22}^{-1} \mathbf{A}_{21} \mathbf{A}_{11}^{-1} & \mathbf{A}_{22}^{-1}
\end{pmatrix}
=
\begin{pmatrix}
\mathbf{I} & \mathbf{0} \\
\mathbf{0} & \mathbf{I}
\end{pmatrix}
\end{aligned}\]
分块矩阵求逆(形式二)
假设 \(\mathbf{A}\mathbf{B} = \mathbf{I}\),其中 \(\mathbf{A}\) 为分块下三角矩阵,则 \(\mathbf{B}\) 也是分块下三角矩阵,且:
\[\begin{aligned}
\mathbf{B}_{ij} = -\mathbf{A}_{ii}^{-1} \sum_{k=j}^{i-1} \mathbf{A}_{ik} \mathbf{B}_{kj}
, \quad i \gt j
\end{aligned}\]
五、网络架构
- RMSNorm:为了更稳定的训练(与 VMamba、Mamba-2 中提到的做法一致)
- Q、K 的参数化:采用 \(\frac{\text{SiLU}(\mathbf{W}\boldsymbol{x}_t)}{|\text{SiLU}(\mathbf{W}\boldsymbol{x}_t)|_2}\) 的形式,其中 SiLU 替代了原来的 ELU+1,L2 归一化确保特征值小于 1
评论:这里 Q 只是顺带做了归一化,真正需要的是 K 小于 1,因为递推中有 \(\boldsymbol{k}^\top \boldsymbol{k}\) 项。
- Short Convolution(Shift-SSM):理由来自 H3(Hungry Hungry Hippos)
评论:我觉得苏剑林的理由更充分:https://spaces.ac.cn/archives/11320
- 混合网络设计:GDN + SWA + Full Attention

六、实验
6.1 加速效果

6.2 MQAR
来自论文"Measuring and improving recall in efficient language models"。MQAR 用于评估模型的 recall 能力。该实验中 DeltaNet 没有使用 conv,其他训练设定与原文相同。
6.3 MAD
合成 token 操作任务集合,用于评估线性注意力模型。

6.4 语言建模
实验设置与 GLA 相同。

6.5 RegBench

评论:Mamba 是否带 conv 的差距太大了。这是否可能是因为 hidden state 更小的原因?
七、相关工作
