跳转至

阅读笔记: Kimi Delta Attention

原文链接: https://arxiv.org/pdf/2510.26692
代码链接: https://github.com/fla-org/flash-linear-attention
声明: These are personal reading notes. Some derivations are my own and may be incorrect, so they should be cross-checked against the code later.

一、动机

  1. 使用细粒度门控改进 delta rule

二、符号约定

  1. 使用 \(\mathbf{S, Q}\) 等粗体大写字母表示矩阵
  2. 使用 \(\mathbf{q}_t, \mathbf{k}_t\) 等表示列向量(即 \([d, 1]\) 的形式),矩阵则是 \([L, d]\) 的形式,因此会有额外的转置操作
  3. 使用 \(W_t\) 等表示可学习参数
  4. 使用 \(\mathbf{q}_t\) 表示 \(\mathbf{Q}\) 的第 \(t\)
  5. \(\square_{[t]} = \square_{[t]}^{1:C} \in \mathbb{R}^{C \times d}\) 表示第 \(t\) 个 chunk,其中 \(\square \in { \mathbf{Q, K, V, \dots} }\)

三、背景

  1. Gated Linear Attention
  2. Online Learning
  3. DeltaNet and Gated DeltaNet

四、Kimi Delta Attention

前向传播

评论: 这里再次应用了 WY 表示法 (即假设 P 可以写成 X - WY 的形式)

原始公式:

\[\begin{aligned} \mathbf{S}_t &= \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) (\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top) + \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top \in \mathbb{R}^{d_v \times d_k} \\ \\ \boldsymbol{o}_t &= \mathbf{S}_{t}\boldsymbol{q}_t \end{aligned}\]

接下来,遵循 DeltaNet 的推导,定义块级符号和以下辅助量:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} &= \prod_{i=t C + 1}^{t C + r} \text{Diag}(\boldsymbol{\alpha}_i) (\mathbf{I} - \beta_{i} \boldsymbol{k}_{i} \boldsymbol{k}_{i}^{\top}) \in \mathbb{R}^{d_k \times d_k} \\ \\ \mathbf{H}_{[t]}^{r} &= \sum_{i=tC + 1}^{tC + r} \beta_{i} (\boldsymbol{v}_{i} \boldsymbol{k}_{i}^{\top}) \prod_{j=i + 1}^{t C + r} \text{Diag}(\boldsymbol{\alpha}_j) (\mathbf{I} - \beta_{j} \boldsymbol{k}_{j} \boldsymbol{k}_{j}^{\top}) \in \mathbb{R}^{d_v \times d_k} \end{aligned}\]

然后,块级状态可以写为:

\[\begin{aligned} \mathbf{S}_{[t]}^{r} = \mathbf{S}_{[t-1]}^{C} \mathbf{P}_{[t]}^{r} + \mathbf{H}_{[t]}^{r} , \quad \text{where } \mathbf{S}_{[-1]}^{C} = \mathbf{0} \end{aligned}\]

另一方面,假设:

\[\begin{aligned} \mathbf{S}_t = \sum_{i=1}^{t} \text{Diag}(\boldsymbol{\eta}_t) \text{Diag}(\boldsymbol{\xi}_i) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_t) \end{aligned}\]

利用数学归纳法,我们得到:

\[\begin{aligned} \mathbf{S}_t &= \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) + \beta_t \left( \boldsymbol{v}_t - \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) \boldsymbol{k}_t \right) \boldsymbol{k}_t^\top \\&= \sum_{i=1}^{t-1} \text{Diag}(\boldsymbol{\eta}_{t-1}) \text{Diag}(\boldsymbol{\xi}_i) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_{t-1}) \text{Diag}(\boldsymbol{\alpha}_{t}) \\&+ \beta_t \left( \boldsymbol{v}_t - \sum_{i=1}^{t-1} \text{Diag}(\boldsymbol{\eta}_{t-1}) \text{Diag}(\boldsymbol{\xi}_i) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_{t-1}) \text{Diag}(\boldsymbol{\alpha}_{t}) \boldsymbol{k}_t \right) \boldsymbol{k}_t^\top \\&= \sum_{i=1}^{t} \text{Diag}(\boldsymbol{\eta}_{t}) \text{Diag}(\boldsymbol{\xi}_i) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_{t}) \end{aligned}\]

在我们设置如下参数之后:

\[\begin{aligned} \text{Diag}(\boldsymbol{\gamma}_t) = \prod_{i=1}^t \text{Diag}(\boldsymbol{\alpha}_i) ,\quad \text{Diag}(\boldsymbol{\eta}_t) = \mathbf{I} ,\quad \text{Diag}(\boldsymbol{\epsilon}_t) = \text{Diag}(\boldsymbol{\gamma}_t)^{-1} \end{aligned}\]

那么我们很容易得到:

\[\begin{aligned} \boldsymbol{u}_t &= \beta_t \text{Diag}(\boldsymbol{\xi}_t)^{-1} \left( \boldsymbol{v}_t - \sum_{i=1}^{t-1} \text{Diag}(\boldsymbol{\xi}_i) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_{t}) \boldsymbol{k}_t \right) \end{aligned}\]

\xi 吸收进 u 后,我们最终得到:

\[\begin{aligned} \mathbf{S}_t = \sum_{i=1}^{t} \boldsymbol{u}_i \left(\text{Diag}(\boldsymbol{\gamma}_i)^{-1} \boldsymbol{k}_i\right)^\top \text{Diag}(\boldsymbol{\gamma}_t) ,\quad \boldsymbol{u}_t &= \beta_t \left( \boldsymbol{v}_t - \sum_{i=1}^{t-1} \boldsymbol{u}_i \left(\text{Diag}(\boldsymbol{\gamma}_i)^{-1} \boldsymbol{k}_i\right)^\top \left(\text{Diag}(\boldsymbol{\gamma}_{t}) \boldsymbol{k}_t\right) \right) \end{aligned}\]

这与 Gated DeltaNet 中几乎完全相同。

同时,形如 (I - beta_t k_t k_t^T)Householder变换 的乘积,总是可以使用 WY表示法 写成低秩形式。因此我们进一步推导它,再次使用几乎相同的归纳过程。

当 k = 0 时,我们有:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} = \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \end{aligned}\]

因此我们假设:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} = \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) - \sum_{i=1}^{r} \text{Diag}(\boldsymbol{\eta}_{[t]}^r) \text{Diag}(\boldsymbol{\xi}_{[t]}^i) \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \end{aligned}\]

然后我们有:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} &= \mathbf{P}_{[t]}^{r-1} \text{Diag}(\boldsymbol{\alpha}_{[t]}^r) (\mathbf{I} - \beta_{[t]}^r \boldsymbol{k}_{[t]}^r \boldsymbol{k}_{[t]}^{r\top}) \\ \\&= \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) - \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \beta_{[t]}^r \boldsymbol{k}_{[t]}^r \boldsymbol{k}_{[t]}^{r\top} \\ \\&- \sum_{i=1}^{r-1} \text{Diag}(\boldsymbol{\eta}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\xi}_{[t]}^i) \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) \text{Diag}(\boldsymbol{\gamma}_{[t]}^{r-1}) \\ \\&+ \sum_{i=1}^{r-1} \text{Diag}(\boldsymbol{\eta}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\xi}_{[t]}^i) \boldsymbol{w}_{[t]}^{i} \boldsymbol{k}_{[t]}^{i \top} \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) \text{Diag}(\boldsymbol{\gamma}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\alpha}_{[t]}^r) \beta_{[t]}^r \boldsymbol{k}_{[t]}^r \boldsymbol{k}_{[t]}^{r\top} \end{aligned}\]

通过消去同类项、如下设置参数并将 \xi 吸收进 w,我们得到:

\[\begin{aligned} \text{Diag}(\boldsymbol{\gamma}_{[t]}^i) &= \prod_{j=1}^i \text{Diag}(\boldsymbol{\alpha}_{[t]}^j) ,\quad \text{Diag}(\boldsymbol{\eta}_{[t]}^i) = \mathbf{I} ,\quad \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) = \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \\ \\ \boldsymbol{w}_{[t]}^{r} &= \beta_{[t]}^r \left(\mathbf{I} - \sum_{i=1}^{r-1} \boldsymbol{w}_{[t]}^{i} \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^{i} \right)^\top \right) \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{k}_{[t]}^r \\ \\ \mathbf{P}_{[t]}^{r} &= \left(\mathbf{I}- \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^{i} \right)^\top \right) \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \end{aligned}\]

将其代入块级递推中得到:

\[\begin{aligned} \mathbf{S}_{[t]}^{r} &= \mathbf{S}_{[t-1]}^{C} \left(\mathbf{I}- \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^{i} \right)^\top \right) \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \\&+ \sum_{i=1}^{r} \boldsymbol{u}_{[t]}^i \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \\&= \mathbf{S}_{[t-1]}^{C} \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) + \sum_{i=1}^{r} \left( \boldsymbol{u}_{[t]}^i - \mathbf{S}_{[t-1]}^{C} \boldsymbol{w}_{[t]}^{i} \right) \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^{i} \right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \\ \\ \boldsymbol{o}_{[t]}^{r} &= \mathbf{S}_{[t]}^{r} \boldsymbol{q}_{[t]}^{r} \\ \\&= \mathbf{S}_{[t-1]}^{C} \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{q}_{[t]}^{r} \right) + \sum_{i=1}^{r} \left( \boldsymbol{u}_{[t]}^i - \mathbf{S}_{[t-1]}^{C} \boldsymbol{w}_{[t]}^{i} \right) \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^{i} \right)^\top \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{q}_{[t]}^{r} \right) \end{aligned}\]

现在定义:

\[\begin{aligned} \mathbf{\Gamma}_{[t]} &= [ \boldsymbol{\gamma}_{[t]}^1, \boldsymbol{\gamma}_{[t]}^2, ..., \boldsymbol{\gamma}_{[t]}^C ]^\top \\ \\ \overleftarrow{\square_{[t]}} &= \square_{[t]} \odot \mathbf{\Gamma}_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \square_{[t]} \oslash \mathbf{\Gamma}_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{W}\} \end{aligned}\]

然后我们可以将计算重写为矩阵形式(温馨提示:不要忘记对 O 进行转置):

\[\begin{aligned} \mathbf{S}_{[t]}^{C} &= \mathbf{S}_{[t-1]}^{C} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \left(\mathbf{U}_{[t]}^\top - \mathbf{S}_{[t-1]}^{C} \mathbf{W}_{[t]}^\top \right)\overrightarrow{\mathbf{K}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \\ \\ \mathbf{O}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t-1]}^{C \top} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \left(\mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right) \end{aligned}\]

评论: 这种形式非常自然,该推导再次证明了 DeltaNet 的强大之处。

接下来,我们转向 u_[t]w_[t]。这里,M_{-1} = M - I 表示对角线上为零的严格下三角掩码。

u_[t] 的递推式出发,我们得到:

\[\begin{aligned} \boldsymbol{u}_{[t]}^r &= \beta_{[t]}^r \left( \boldsymbol{v}_{[t]}^r - \sum_{i=1}^{r-1} \boldsymbol{u}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{k}_{[t]}^r\right) \right) \\ \\ \Rightarrow \mathbf{U}_{[t]} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{V}_{[t]} - \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{U}_{[t]} \right) \\ \\ \Rightarrow \mathbf{U}_{[t]} &= \left( \mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]} \end{aligned}\]

基于相同的理由,我们也有:

\[\begin{aligned} \mathbf{W}_{[t]} = \left( \mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \text{Diag}(\boldsymbol{\beta}_{[t]}) \overleftarrow{\mathbf{K}_{[t]}} \end{aligned}\]

因此,为了计算 U_[t]W_[t],关键的量是:

\[\begin{aligned} \mathbf{\widetilde{A}}_{[t]} = \left(\mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \end{aligned}\]

矩阵求逆步骤可以按照与 DeltaNet 笔记中相同的方法进行处理。

五、相关方法

与 DPLR 的对比

\[\begin{aligned} \mathbf{S}_t &= \mathbf{S}_{t-1} (\text{Diag}(\boldsymbol{\alpha}_t) - \boldsymbol{b}_t \boldsymbol{a}_t^\top) + \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top \in \mathbb{R}^{d_v \times d_k} \end{aligned}\]

评论(来自原文): 先前的研究(如 GLA)已经指出,当与衰减参数 Diag 交互时,必须高度注意数值精度问题。最好在 全精度下使用 次级分块,这是一种用速度换取稳定性的做法。在 DPLR 中,这可能会导致四个数值极其脆弱的项,即 Q @ K、Q @ B、A @ B、A @ K。KDA 将变量 a 和 b 都绑定到了 k 上,从而将其减少到只有两项,即 Q * K、K @ K。

与其他模型的对比

测试时训练的视角

六、实验

模型架构

  1. KDA 层
\[\begin{aligned} \boldsymbol{q}_{t}^{h}, \boldsymbol{k}_{t}^{h} &= \operatorname{L2Norm}\left(\operatorname{Swish}\left(\operatorname{ShortConv}\left(\mathbf{W}_{q / k}^{h} \boldsymbol{x}_{t}\right)\right)\right) \in \mathbb{R}^{d_{k}} \\ \\ \boldsymbol{v}_{t}^{h} &=\operatorname{Swish}\left(\operatorname{ShortConv}\left(\mathbf{W}_{v}^{h} \boldsymbol{x}_{t}\right)\right) \in \mathbb{R}^{d_{v}} \\ \\ \alpha_{t}^{h} &=f\left(\mathbf{W}_{\alpha}^{\uparrow} \mathbf{W}_{\alpha}^{\downarrow} \boldsymbol{x}_{t}\right) \in[0,1]^{d_{k}} \\ \\ \beta_{t}^{h} &=\operatorname{Sigmoid}\left(\mathbf{W}_{\beta}^{h} \boldsymbol{x}_{t}\right) \in[0,1] \\ \\ \boldsymbol{o}_{t} &=\mathbf{W}_{o}\left(\operatorname{Sigmoid}\left(\mathbf{W}_{g}^{\uparrow} \mathbf{W}_{g}^{\downarrow} \boldsymbol{x}_{t}\right) \odot \operatorname{RMSNorm}\left(\operatorname{KDA}\left(\boldsymbol{q}_{t}, \boldsymbol{k}_{t}, \boldsymbol{v}_{t}, \boldsymbol{\alpha}_{t}, \beta_{t}\right)\right)\right) \end{aligned}\]

门控 \(\alpha_t^h\) 的计算方式为 \alpha_t^h = -exp(A_log) * softplus(W^up W^down x + dt_bias),这类似于 mamba。其中A_log 的形状为 [H],而 dt_bias 的形状为 [H * K]

  1. 层间混合 Kimi Linear 使用了 3:1 的 KDA 与 MLA 比例。选择层间混合而不是层内混合(例如每个头分别采用不同的attention)的主要原因是 infra 更为简单。

  2. 为 MLA 使用 NoPE MLA 使用 NoPE,因为 KDA 能够通过其循环衰减机制天生地捕获位置信息。这种设计对于长上下文的外推也非常有益。

合成数据测试

Palindrome 生成反转序列

Multi Query Associative Recall (MQAR) 输出下一个

Stack 具有后进先出 (LIFO) 的栈操作

缩放定律

  • 基座:MoE Moonlight 架构
  • 激活专家 / 总专家数:8 / 64
  • 优化器:Muon Optimizer
  • 使用了 Chinchilla 缩放定律

语言建模

预训练

  • 基座 MoE Moonlight 架构
  • 激活专家 / 总专家数:8 / 256,包含一个共享专家
  • 48B-A8B 规模
  • 4096 上下文窗口
  • Muon Clip 优化器
  • WSD 学习率调度
  • 来自 K2 预训练语料库的 1.4T tokens
  • 学习率设置为 1.1 × 10−3
  • 全局 batch size 固定为 32M tokens
  • 与 Kimi K2 中确立的退火调度和长上下文激活阶段相同

后训练: SFT - Kimi K2 SFT 数据 + 额外的推理任务 - 多阶段 SFT 方法:最初用于通用指令遵循,随后是推理密集型数据

后训练: RL - 来自 Kimi K2 数据的数学、代码和 STEM 任务,筛选出对于起始权重为中等难度的数据。 - 额外的 PTX loss。PTX dataset涵盖了推理和通用任务。

[注意] 后训练: RL

训练和推理引擎之间的精度不匹配可能会导致RL 学习不稳定 -> 采用截断的重要性采样,动态调整 KL 惩罚项以及 mini batch 大小。

评估

  • 语言理解与推理: Hellaswag, ARC-Challenge, Winogrande, MMLU, TriviaQA, MMLU-Redux, MMLU-Pro, GPQA-Diamond, BBH, Livebench.
  • 代码生成: LiveCodeBench v6, EvalPlus
  • 数学和推理: AIME 2025, MATH 500, HMMT 2025, PolyMath-en.
  • 长上下文: MRCR 5 , RULER, Frames, HELMET-ICL, RepoQA, Long Code Arena, LongBench v2.
  • 中文语言理解与推理: C-Eval, and CMMLU.
  • Temprature=1.0, LM-Harness-Evaluation
  • Base模型对 MMLU, MMLU-Redux, GPQA-Diamond, and C-Eval 采用基于PPL的评估, 其他情况使用基于生成的评估.

Comments