Diagonal Plus Low Rank
Code: https://github.com/fla-org/flash-linear-attention
Disclaimer: these are personal reading notes. Some derivations are my own and may be incorrect, so they should be cross-checked against the code later.
1. Notation
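- Bold uppercase letters such as \(\mathbf{S}, \mathbf{Q}\) denote matrices
- \(\mathbf{q}_t, \mathbf{k}_t\), etc. denote column vectors (shape \([d, 1]\)), whereas matrices have shape \([L, d]\), so extra transposes appear in the matrix forms
- \(W_t\), etc. denote learnable parameters
- \(\mathbf{q}_t\) denotes the \(t\)-th row of \(\mathbf{Q}\), written as a column vector
- \(\square_{[t]} = \square_{[t]}^{1:C} \in \mathbb{R}^{C \times d}\) denotes the \(t\)-th chunk (of length \(C\)) and \(\square_{[t]}^r = \square_{tC+r}\) its \(r\)-th row, where \(\square \in \{\mathbf{Q}, \mathbf{K}, \mathbf{V}, \dots\}\)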
2. Forward pass of gated DPLR
The original recurrence:
\[\begin{aligned}
\mathbf{S}_t
&= \mathbf{S}_{t-1}
\text{Diag}(\boldsymbol{\alpha}_t)
(\mathbf{I} - \boldsymbol{b}_t \boldsymbol{a}_t^\top)
+
\boldsymbol{v}_t \boldsymbol{k}_t^\top
\in \mathbb{R}^{d_v \times d_k}
\\
\\
\boldsymbol{o}_t
&= \mathbf{S}_{t}\boldsymbol{q}_t
\end{aligned}\]
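For cross-checking the chunked formulas below, a naive token-by-token reference is handy. The following is a minimal NumPy sketch, assuming inputs `q, k, a, b, alpha` of shape `[T, d_k]` and `v` of shape `[T, d_v]`; the name `dplr_recurrent` is hypothetical and not part of flash-linear-attention.

```python
import numpy as np

def dplr_recurrent(q, k, v, a, b, alpha, S0=None):
    """Naive token-by-token gated DPLR recurrence (hypothetical reference helper).

    q, k, a, b, alpha: [T, d_k];  v: [T, d_v];  state S: [d_v, d_k].
    S_t = S_{t-1} Diag(alpha_t) (I - b_t a_t^T) + v_t k_t^T,   o_t = S_t q_t.
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_v, d_k)) if S0 is None else S0.copy()
    out = np.empty((T, d_v))
    for t in range(T):
        S = S * alpha[t]                    # S Diag(alpha_t): scales the columns of S
        S = S - np.outer(S @ b[t], a[t])    # ... times (I - b_t a_t^T)
        S = S + np.outer(v[t], k[t])        # rank-1 write v_t k_t^T
        out[t] = S @ q[t]                   # o_t = S_t q_t
    return out, S
```

Right-multiplying by \(\text{Diag}(\boldsymbol{\alpha}_t)\) is only a column scaling, and the low-rank correction needs just the matrix-vector product \(\mathbf{S}\boldsymbol{b}_t\), so each step costs \(O(d_v d_k)\).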
Following the DeltaNet derivation, we use chunk-level notation and define the following auxiliary quantities:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r} &= \prod_{i=t C + 1}^{t C + r}
\text{Diag}(\boldsymbol{\alpha}_i)(\mathbf{I} - \boldsymbol{b}_{i} \boldsymbol{a}_{i}^{\top})
\in \mathbb{R}^{d_k \times d_k}
\\
\\
\mathbf{H}_{[t]}^{r} &= \sum_{i=tC + 1}^{tC + r}
(\boldsymbol{v}_{i} \boldsymbol{k}_{i}^{\top})
\prod_{j=i + 1}^{t C + r}
\text{Diag}(\boldsymbol{\alpha}_j)(\mathbf{I} - \boldsymbol{b}_{j} \boldsymbol{a}_{j}^{\top})
\in \mathbb{R}^{d_v \times d_k}
\end{aligned}\]
The chunk-level state can then be written as:
\[\begin{aligned}
\mathbf{S}_{[t]}^{r}
=
\mathbf{S}_{[t-1]}^{C}
\mathbf{P}_{[t]}^{r} + \mathbf{H}_{[t]}^{r}
, \quad
\text{where }
\mathbf{S}_{[-1]}^{C} = \mathbf{0}
\end{aligned}\]
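Moreover, a product of Householder-like transforms of the form \((\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top)\) can always be written in low-rank form using the WY representation. We therefore derive such a form for \(\mathbf{P}_{[t]}^{r}\), again by an almost identical induction.
If the low-rank factors vanish (\(\boldsymbol{b}_{[t]}^i = \mathbf{0}\)), only the gates remain:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r}
= \prod_{i=1}^{r} \text{Diag}(\boldsymbol{\alpha}_{[t]}^i)
= \text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\end{aligned}\]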
We therefore posit the ansatz:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r} = \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) - \sum_{i=1}^{r}
\text{Diag}(\boldsymbol{\eta}_{[t]}^r)
\text{Diag}(\boldsymbol{\xi}_{[t]}^i)
\boldsymbol{w}_{[t]}^{i} \boldsymbol{a}_{[t]}^{i \top}
\text{Diag}(\boldsymbol{\epsilon}_{[t]}^i)
\text{Diag}(\boldsymbol{\lambda}_{[t]}^r)
\end{aligned}\]
Expanding one step of the recursion gives:
\[\begin{aligned}
\mathbf{P}_{[t]}^{r}
&=
\mathbf{P}_{[t]}^{r-1}
\text{Diag}(\boldsymbol{\alpha}_{[t]}^r)
(\mathbf{I} - \boldsymbol{b}_{[t]}^r \boldsymbol{a}_{[t]}^{r\top})
\\ \\&=
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
-
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\boldsymbol{b}_{[t]}^r \boldsymbol{a}_{[t]}^{r\top}
\\ \\&-
\sum_{i=1}^{r-1}
\text{Diag}(\boldsymbol{\eta}_{[t]}^{r-1})
\text{Diag}(\boldsymbol{\xi}_{[t]}^i)
\boldsymbol{w}_{[t]}^{i} \boldsymbol{a}_{[t]}^{i \top}
\text{Diag}(\boldsymbol{\epsilon}_{[t]}^i)
\text{Diag}(\boldsymbol{\lambda}_{[t]}^{r-1})
\text{Diag}(\boldsymbol{\alpha}_{[t]}^r)
\\ \\&+
\sum_{i=1}^{r-1}
\text{Diag}(\boldsymbol{\eta}_{[t]}^{r-1})
\text{Diag}(\boldsymbol{\xi}_{[t]}^i)
\boldsymbol{w}_{[t]}^{i} \boldsymbol{a}_{[t]}^{i \top}
\text{Diag}(\boldsymbol{\epsilon}_{[t]}^i)
\text{Diag}(\boldsymbol{\lambda}_{[t]}^{r-1})
\text{Diag}(\boldsymbol{\alpha}_{[t]}^r)
\boldsymbol{b}_{[t]}^r \boldsymbol{a}_{[t]}^{r\top}
\end{aligned}\]
Matching like terms, choosing the parameters as follows, and absorbing \(\boldsymbol{\xi}\) into \(\boldsymbol{w}\), we obtain:
\[\begin{aligned}
\text{Diag}(\boldsymbol{\lambda}_{[t]}^i)
&=
\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)
=
\prod_{j=1}^i
\text{Diag}(\boldsymbol{\alpha}_{[t]}^j)
,\quad
\text{Diag}(\boldsymbol{\eta}_{[t]}^i) = \mathbf{I}
,\quad
\text{Diag}(\boldsymbol{\epsilon}_{[t]}^i)
= \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1}
\\
\\
\boldsymbol{w}_{[t]}^{r}
&=
\left(\mathbf{I} -
\sum_{i=1}^{r-1}
\boldsymbol{w}_{[t]}^{i}
\left(
\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1}
\boldsymbol{a}_{[t]}^{i}
\right)^\top
\right)
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\boldsymbol{b}_{[t]}^r
\\
\\
\mathbf{P}_{[t]}^{r} &=
\left(\mathbf{I}- \sum_{i=1}^{r}
\boldsymbol{w}_{[t]}^{i}
\left(
\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1}
\boldsymbol{a}_{[t]}^{i}
\right)^\top
\right)
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\end{aligned}\]
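A quick numerical check of the WY form of \(\mathbf{P}_{[t]}^r\): the sketch below builds \(\boldsymbol{w}_{[t]}^r\) from its recursion and compares the factorized \(\mathbf{P}_{[t]}^r\) with the explicit product, for a single chunk with random inputs. The function names are hypothetical, and indices are 0-based in the code.

```python
import numpy as np

def wy_factor(a, b, alpha):
    """WY factors of P^r = prod_{i<=r} Diag(alpha_i)(I - b_i a_i^T) within one chunk.

    a, b, alpha: [C, d_k] (row r holds the r-th token of the chunk).
    Returns gamma (cumulative gate products) and w, both [C, d_k].
    """
    C, d_k = a.shape
    gamma = np.cumprod(alpha, axis=0)        # gamma_r = prod_{j<=r} alpha_j
    a_div = a / gamma                        # Diag(gamma_i)^{-1} a_i
    w = np.zeros((C, d_k))
    for r in range(C):
        b_scaled = gamma[r] * b[r]           # Diag(gamma_r) b_r
        # w_r = (I - sum_{i<r} w_i (a_i / gamma_i)^T) Diag(gamma_r) b_r
        w[r] = b_scaled - w[:r].T @ (a_div[:r] @ b_scaled)
    return gamma, w

def P_from_wy(a, gamma, w, r):
    """Rebuild P^r (tokens 0..r) as (I - sum_{i<=r} w_i (a_i/gamma_i)^T) Diag(gamma_r)."""
    d_k = a.shape[1]
    low_rank = w[: r + 1].T @ (a[: r + 1] / gamma[: r + 1])
    return (np.eye(d_k) - low_rank) * gamma[r]   # column scaling = right-multiply by Diag
```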
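The forward pass for \(\mathbf{H}\) is where this derivation differs most from the earlier one. Here we assume that the state, accumulated from a zero initial state, can be written as a sum of two terms: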
\[\begin{aligned}
\mathbf{S}_t
=
\sum_{i=1}^{t}
\text{Diag}(\boldsymbol{\eta}_t)
\boldsymbol{u}_i
\boldsymbol{k}_i^\top
\text{Diag}(\boldsymbol{\epsilon}_i)
\text{Diag}(\boldsymbol{\gamma}_t)
+
\sum_{i=1}^{t}
\text{Diag}(\boldsymbol{\eta^1}_t)
\boldsymbol{c}_i
\boldsymbol{a}_i^\top
\text{Diag}(\boldsymbol{\epsilon^1}_i)
\text{Diag}(\boldsymbol{\gamma^1}_t)
\end{aligned}\]
By induction (expanding one step of the recurrence):
\[\begin{aligned}
\mathbf{S}_t
&=
\mathbf{S}_{t-1}
\text{Diag}(\boldsymbol{\alpha}_t)
- \mathbf{S}_{t-1}
\text{Diag}(\boldsymbol{\alpha}_t)
\boldsymbol{b}_t
\boldsymbol{a}_t^\top
+ \boldsymbol{v}_t \boldsymbol{k}_t^\top
\\&=
\sum_{i=1}^{t-1}
\text{Diag}(\boldsymbol{\eta}_{t-1})
\boldsymbol{u}_i
\boldsymbol{k}_i^\top
\text{Diag}(\boldsymbol{\epsilon}_i)
\text{Diag}(\boldsymbol{\gamma}_{t-1})
\text{Diag}(\boldsymbol{\alpha}_{t})
\\&+
\sum_{i=1}^{t-1}
\text{Diag}(\boldsymbol{\eta^1}_{t-1})
\boldsymbol{c}_i
\boldsymbol{a}_i^\top
\text{Diag}(\boldsymbol{\epsilon^1}_i)
\text{Diag}(\boldsymbol{\gamma^1}_{t-1})
\text{Diag}(\boldsymbol{\alpha}_{t})
\\&-
\sum_{i=1}^{t-1}
\text{Diag}(\boldsymbol{\eta}_{t-1})
\boldsymbol{u}_i
\boldsymbol{k}_i^\top
\text{Diag}(\boldsymbol{\epsilon}_i)
\text{Diag}(\boldsymbol{\gamma}_{t-1})
\text{Diag}(\boldsymbol{\alpha}_{t})
\boldsymbol{b}_t
\boldsymbol{a}_t^\top
\\&-
\sum_{i=1}^{t-1}
\text{Diag}(\boldsymbol{\eta^1}_{t-1})
\boldsymbol{c}_i
\boldsymbol{a}_i^\top
\text{Diag}(\boldsymbol{\epsilon^1}_i)
\text{Diag}(\boldsymbol{\gamma^1}_{t-1})
\text{Diag}(\boldsymbol{\alpha}_{t})
\boldsymbol{b}_t
\boldsymbol{a}_t^\top
\\&+ \boldsymbol{v}_t \boldsymbol{k}_t^\top
\\&=
\sum_{i=1}^{t}
\text{Diag}(\boldsymbol{\eta}_{t})
\boldsymbol{u}_i
\boldsymbol{k}_i^\top
\text{Diag}(\boldsymbol{\epsilon}_i)
\text{Diag}(\boldsymbol{\gamma}_{t})
\\&+
\sum_{i=1}^{t}
\text{Diag}(\boldsymbol{\eta^1}_{t})
\boldsymbol{c}_i
\boldsymbol{a}_i^\top
\text{Diag}(\boldsymbol{\epsilon^1}_i)
\text{Diag}(\boldsymbol{\gamma^1}_{t})
\end{aligned}\]
We define the following auxiliary quantities:
\[\begin{aligned}
\text{Diag}(\boldsymbol{\gamma}_t)
&= \text{Diag}(\boldsymbol{\gamma^1}_t)
= \prod_{i=1}^t
\text{Diag}(\boldsymbol{\alpha}_i)
,\quad
\text{Diag}(\boldsymbol{\epsilon}_t)
= \text{Diag}(\boldsymbol{\epsilon^1}_t)
= \text{Diag}(\boldsymbol{\gamma}_t)^{-1}
\\
\\
\text{Diag}(\boldsymbol{\eta}_t)
&= \text{Diag}(\boldsymbol{\eta^1}_t)
= \mathbf{I}
\end{aligned}\]
In chunk-local notation, we then obtain:
\[\begin{aligned}
\boldsymbol{u}_{[t]}^r &= \boldsymbol{v}_{[t]}^r
\\
\\
\boldsymbol{c}_{[t]}^r
&=
- \sum_{i=1}^{r-1} \left(
\boldsymbol{v}_{[t]}^i
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top
+ \boldsymbol{c}_{[t]}^i
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top
\right)
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{b}_{[t]}^r\right)
\end{aligned}\]
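The two-term decomposition of \(\mathbf{H}_{[t]}^r\) can be checked the same way. This sketch, with hypothetical helper names, computes \(\boldsymbol{c}_{[t]}^r\) from the recursion above (with \(\boldsymbol{u}_{[t]}^r = \boldsymbol{v}_{[t]}^r\)) and compares the reconstructed state against the recurrence run from a zero state inside one chunk.

```python
import numpy as np

def c_recursion(k, v, a, b, alpha):
    """c_r = -sum_{i<r} (v_i (k_i/gamma_i)^T + c_i (a_i/gamma_i)^T) Diag(gamma_r) b_r (0-based)."""
    C = k.shape[0]
    gamma = np.cumprod(alpha, axis=0)
    k_div, a_div = k / gamma, a / gamma
    c = np.zeros_like(v)
    for r in range(C):
        b_scaled = gamma[r] * b[r]                           # Diag(gamma_r) b_r
        c[r] = -(v[:r].T @ (k_div[:r] @ b_scaled) + c[:r].T @ (a_div[:r] @ b_scaled))
    return gamma, c

def H_from_uc(k, v, a, gamma, c, r):
    """H^r = sum_{i<=r} [v_i (k_i/gamma_i)^T + c_i (a_i/gamma_i)^T] Diag(gamma_r)."""
    s = slice(0, r + 1)
    return (v[s].T @ (k[s] / gamma[s]) + c[s].T @ (a[s] / gamma[s])) * gamma[r]
```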
\(\mathbf{S}\) and \(\boldsymbol{o}\) can then be written as:
\[\begin{aligned}
\mathbf{S}_{[t]}^r
&=
\mathbf{S}_{[t-1]}^C
\left(\mathbf{I}- \sum_{i=1}^{r}
\boldsymbol{w}_{[t]}^{i}
\left(
\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1}
\boldsymbol{a}_{[t]}^{i}
\right)^\top
\right)
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\\&+
\sum_{i=1}^{r}
\boldsymbol{v}_{[t]}^i
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
+
\sum_{i=1}^{r}
\boldsymbol{c}_{[t]}^i
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\\ &=
\mathbf{S}_{[t-1]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
+
\sum_{i=1}^{r}
\boldsymbol{v}_{[t]}^i
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\\&+
\sum_{i=1}^{r}
\left(
\boldsymbol{c}_{[t]}^i
- \mathbf{S}_{[t-1]}^C \boldsymbol{w}_{[t]}^{i}
\right)
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\\
\\
\boldsymbol{o}_{[t]}^r
&=
\mathbf{S}_{[t-1]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\boldsymbol{q}_{[t]}^r
+
\sum_{i=1}^{r}
\boldsymbol{v}_{[t]}^i
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\boldsymbol{q}_{[t]}^r
\\&+
\sum_{i=1}^{r}
\left(
\boldsymbol{c}_{[t]}^i
- \mathbf{S}_{[t-1]}^C \boldsymbol{w}_{[t]}^{i}
\right)
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top
\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)
\boldsymbol{q}_{[t]}^r
\end{aligned}\]
Define:
\[\begin{aligned}
\mathbf{\Gamma}_{[t]} &= [
\boldsymbol{\gamma}_{[t]}^1,
\boldsymbol{\gamma}_{[t]}^2,
...,
\boldsymbol{\gamma}_{[t]}^C
]^\top
\\
\\
\overleftarrow{\square_{[t]}}
&=
\square_{[t]}
\odot
\mathbf{\Gamma}_{[t]}
,\quad
\overrightarrow{\square_{[t]}}
=
\square_{[t]}
\oslash
\mathbf{\Gamma}_{[t]}
\quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{A}, \mathbf{B}\}
\end{aligned}\]
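In code, \(\mathbf{\Gamma}_{[t]}\) and the two arrow operators are one-liners. The sketch below assumes the gates are supplied as per-token log values (accumulating in log space is a common choice, but it is an assumption here, not taken from these notes); the helper names are hypothetical. It also ignores a numerical caveat: dividing by \(\mathbf{\Gamma}\) can produce large values when the decay inside a chunk is strong.

```python
import numpy as np

def chunk_gamma(log_alpha):
    """Gamma_[t]: row r is gamma^r = prod_{j<=r} alpha^j, accumulated in log space."""
    return np.exp(np.cumsum(log_alpha, axis=0))

def arrow_left(X, Gamma):
    """overleftarrow{X} = X ⊙ Gamma (used for Q and B)."""
    return X * Gamma

def arrow_right(X, Gamma):
    """overrightarrow{X} = X ⊘ Gamma (used for K and A)."""
    return X / Gamma
```

Entry \((r, i)\) of `arrow_left(Q, G) @ arrow_right(K, G).T` with \(i \le r\) carries the decay \(\boldsymbol{\gamma}^r / \boldsymbol{\gamma}^i\) accumulated between tokens \(i\) and \(r\), which is what the masked products in the matrix form rely on.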
The computation can then be written in matrix form:
\[\begin{aligned}
\mathbf{S}_{[t]}^C
&=
\mathbf{S}_{[t-1]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+
\mathbf{V}_{[t]}^\top
\overrightarrow{\mathbf{K}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+
\left(
\mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
\right)^\top
\overrightarrow{\mathbf{A}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\\
\\
\mathbf{O}_{[t]}
&=
\overleftarrow{\mathbf{Q}_{[t]}}
\mathbf{S}_{[t-1]}^{C\top}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot \mathbf{M} \right) \mathbf{V}_{[t]}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot \mathbf{M} \right)
\left(
\mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C\top}
\right)
\end{aligned}\]
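The matrix form can be validated end to end against the naive recurrence. In this sketch (hypothetical helper `chunk_step`), \(\mathbf{W}_{[t]}\) and \(\mathbf{C}_{[t]}\) are built row by row from the recursions for \(\boldsymbol{w}_{[t]}^r\) and \(\boldsymbol{c}_{[t]}^r\); `np.tril` plays the role of \(\odot \mathbf{M}\).

```python
import numpy as np

def chunk_step(S_prev, q, k, v, a, b, alpha):
    """One chunk of the matrix-form update. S_prev: [d_v, d_k]; q, k, a, b, alpha: [C, d_k]; v: [C, d_v]."""
    C = k.shape[0]
    Gam = np.cumprod(alpha, axis=0)
    Ql, Bl = q * Gam, b * Gam                  # overleftarrow Q, B
    Kr, Ar = k / Gam, a / Gam                  # overrightarrow K, A
    W = np.zeros_like(k)
    Cm = np.zeros_like(v)
    for r in range(C):                         # row recursions for w_[t]^r and c_[t]^r
        W[r] = Bl[r] - W[:r].T @ (Ar[:r] @ Bl[r])
        Cm[r] = -(v[:r].T @ (Kr[:r] @ Bl[r]) + Cm[:r].T @ (Ar[:r] @ Bl[r]))
    V2 = Cm - W @ S_prev.T                     # C_[t] - W_[t] S^T
    S_new = (S_prev + v.T @ Kr + V2.T @ Ar) * Gam[-1]   # right-multiply by Diag(gamma^C)
    O = Ql @ S_prev.T + np.tril(Ql @ Kr.T) @ v + np.tril(Ql @ Ar.T) @ V2
    return O, S_new
```

Running `chunk_step` over consecutive chunks and carrying `S_new` forward should reproduce the outputs and final state of a token-by-token loop.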
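Next we turn to \(\mathbf{W}_{[t]}\) and \(\mathbf{C}_{[t]}\). Here \(\mathbf{M}\) is the lower-triangular causal mask used above (ones on and below the diagonal), and \(\mathbf{M}_{-1} = \mathbf{M} - \mathbf{I}\) is the strictly lower-triangular mask with zeros on the diagonal.
Starting from the recursion for \(\boldsymbol{w}_{[t]}^r\), we get: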
\[\begin{aligned}
\boldsymbol{w}_{[t]}^r
&=
\left(
\mathbf{I} -
\sum_{i=1}^{r-1}
\boldsymbol{w}_{[t]}^i
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top
\right)
\left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{b}_{[t]}^r\right)
\\
\\
\Rightarrow
\mathbf{W}_{[t]}
&=
\overleftarrow{\mathbf{B}_{[t]}} -
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{W}_{[t]}
\\
\\
\Rightarrow
\mathbf{W}_{[t]}
&=
\left(
\mathbf{I} +
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\right)^{-1}
\overleftarrow{\mathbf{B}_{[t]}}
\end{aligned}\]
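The closed form for \(\mathbf{W}_{[t]}\) can be compared directly with the row recursion. In this sketch (hypothetical name `wy_W`), `Bl` and `Ar` stand for \(\overleftarrow{\mathbf{B}_{[t]}}\) and \(\overrightarrow{\mathbf{A}_{[t]}}\); a general `np.linalg.solve` is used for brevity even though the system is triangular.

```python
import numpy as np

def wy_W(Bl, Ar):
    """W = (I + tril(Bl Ar^T, -1))^{-1} Bl for one chunk."""
    n = Bl.shape[0]
    F = np.eye(n) + np.tril(Bl @ Ar.T, k=-1)    # unit lower-triangular
    return np.linalg.solve(F, Bl)
```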
By the same reasoning, for \(\mathbf{C}_{[t]}\):
\[\begin{aligned}
\mathbf{C}_{[t]}
&=
- \left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{V}_{[t]}
- \left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{C}_{[t]}
\\
\\
\Rightarrow
\mathbf{C}_{[t]}
&=
- \left(
\mathbf{I} +
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\right)^{-1}
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{V}_{[t]}
\end{aligned}\]
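Likewise for \(\mathbf{C}_{[t]}\): the sketch below (hypothetical name `wy_C`) evaluates the closed form and can be compared with the row recursion for \(\boldsymbol{c}_{[t]}^r\). Note that the same unit lower-triangular matrix appears as in \(\mathbf{W}_{[t]}\), so in practice one inverse can serve both.

```python
import numpy as np

def wy_C(Bl, Kr, Ar, V):
    """C = -(I + tril(Bl Ar^T, -1))^{-1} tril(Bl Kr^T, -1) V for one chunk."""
    n = Bl.shape[0]
    F = np.eye(n) + np.tril(Bl @ Ar.T, k=-1)
    G = np.tril(Bl @ Kr.T, k=-1)
    return -np.linalg.solve(F, G @ V)
```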
Hence the key quantity for computing \(\mathbf{W}_{[t]}\) and \(\mathbf{C}_{[t]}\) is:
\[\begin{aligned}
\mathbf{E}_{[t]}
=
\left(\mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1}
\end{aligned}\]
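Since \(\mathbf{I} + (\overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1})\) is unit lower triangular, it is always invertible, and the inversion can be handled the same way as in the DeltaNet notes.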
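One straightforward way to invert a unit lower-triangular \(\mathbf{F} = \mathbf{I} + \mathbf{L}\) is forward substitution, row by row. This is a plain sketch (hypothetical name `unit_lower_inverse`), not necessarily the blocked scheme used in the DeltaNet notes or in the kernels.

```python
import numpy as np

def unit_lower_inverse(L):
    """Invert F = I + L with L strictly lower triangular.

    From F E = I, row r of E satisfies E[r] = e_r - sum_{j<r} L[r, j] E[j];
    E is itself unit lower triangular, so only columns < r need updating.
    """
    n = L.shape[0]
    E = np.eye(n)
    for r in range(1, n):
        E[r, :r] -= L[r, :r] @ E[:r, :r]
    return E
```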
3. Backward pass of gated DPLR
Recall the forward pass:
\[\begin{aligned}
\mathbf{F}_{[t]}
&=
\mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)
,\quad
\mathbf{E}_{[t]} = \mathbf{F}_{[t]}^{-1}
\\
\\
\mathbf{W}_{[t]} &= \mathbf{E}_{[t]} \overleftarrow{\mathbf{B}_{[t]}}
,\quad
\mathbf{C}_{[t]} =
- \mathbf{E}_{[t]}
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{V}_{[t]}
\\
\\
\mathbf{V}_{[t],1} &:= \mathbf{V}_{[t]}
,\quad
\mathbf{V}_{[t],2} := \left(\mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right)
\\
\\
\mathbf{S}_{[t]}^C
&=
\mathbf{S}_{[t-1]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+
\mathbf{V}_{[t],1}^\top
\overrightarrow{\mathbf{K}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+
\mathbf{V}_{[t],2}^\top
\overrightarrow{\mathbf{A}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\\
\\
\mathbf{O}_{[t]}
&=
\overleftarrow{\mathbf{Q}_{[t]}}
\mathbf{S}_{[t-1]}^{C\top}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot \mathbf{M} \right) \mathbf{V}_{[t],1}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot \mathbf{M} \right)
\mathbf{V}_{[t],2}
\\
\\
\mathbf{\Gamma}_{[t]} &= [
\boldsymbol{\gamma}_{[t]}^1,
\boldsymbol{\gamma}_{[t]}^2,
...,
\boldsymbol{\gamma}_{[t]}^C
]^\top
\\
\\
\overleftarrow{\square_{[t]}}
&=
\square_{[t]}
\odot
\mathbf{\Gamma}_{[t]}
,\quad
\overrightarrow{\square_{[t]}}
=
\square_{[t]}
\oslash
\mathbf{\Gamma}_{[t]}
\quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{A}, \mathbf{B}\}
\end{aligned}\]
For \(\delta \mathbf{C}_{[t]}\), \(\delta \mathbf{V}_{[t],1}\), and \(\delta \mathbf{W}_{[t]}\):
\[\begin{aligned}
\delta \mathbf{C}_{[t]}
&= \delta \mathbf{V}_{[t],2}
=
\overrightarrow{\mathbf{A}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\delta \mathbf{S}_{[t]}^{C \top}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot \mathbf{M}
\right)^\top
\delta \mathbf{O}_{[t]}
\\
\\
\delta \mathbf{V}_{[t],1}
& =
\overrightarrow{\mathbf{K}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\delta \mathbf{S}_{[t]}^{C \top}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot \mathbf{M}
\right)^\top
\delta \mathbf{O}_{[t]}
\\
\\
\delta \mathbf{W}_{[t]}
&=
- \delta \mathbf{C}_{[t]} \mathbf{S}_{[t-1]}^{C}
\end{aligned}\]
For \(\delta \mathbf{E}_{[t]}\) and \(\delta \mathbf{F}_{[t]}\) (using \(\mathrm{d}\mathbf{E} = -\mathbf{E}\,\mathrm{d}\mathbf{F}\,\mathbf{E}\)):
\[\begin{aligned}
\delta \mathbf{E}_{[t]}
&=
- \delta \mathbf{C}_{[t]}
\mathbf{V}_{[t]}^\top
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)^\top
+ \delta \mathbf{W}_{[t]} \overleftarrow{\mathbf{B}_{[t]}}^\top
\\
\\
\delta \mathbf{F}_{[t]}
&=
- \mathbf{E}_{[t]}^\top \delta \mathbf{E}_{[t]} \mathbf{E}_{[t]}^\top
\end{aligned}\]
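The inverse rule \(\delta \mathbf{F} = -\mathbf{E}^\top \delta \mathbf{E}\, \mathbf{E}^\top\) is easy to confirm with central finite differences. Both helpers below are hypothetical; `numerical_grad` perturbs its array in place and restores it.

```python
import numpy as np

def grad_F_from_grad_E(E, dE):
    """Backprop through E = F^{-1}: since dE = -E dF E, the gradient is dF = -E^T dE E^T."""
    return -E.T @ dE @ E.T

def numerical_grad(f, X, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. array X (modified in place, then restored)."""
    g = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        old = X[idx]
        X[idx] = old + eps
        fp = f()
        X[idx] = old - eps
        fm = f()
        X[idx] = old
        g[idx] = (fp - fm) / (2 * eps)
    return g
```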
For \(\delta \mathbf{S}_{[t]}^C\), propagated backward across chunks:
\[\begin{aligned}
\delta \mathbf{S}_{[t-1]}^{C}
&=
-\delta \mathbf{C}_{[t]}^\top
\mathbf{W}_{[t]}
+
\delta \mathbf{S}_{[t]}^{C}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+ \delta \mathbf{O}_{[t]}^\top
\overleftarrow{\mathbf{Q}_{[t]}}
\\
\\
\Rightarrow
\delta \mathbf{S}_{[t]}^{C}
&=
\delta \mathbf{S}_{[t+1]}^{C}
\text{Diag}(\boldsymbol{\gamma}_{[t+1]}^C)
+ \delta \mathbf{O}_{[t+1]}^\top
\overleftarrow{\mathbf{Q}_{[t+1]}}
-\delta \mathbf{C}_{[t+1]}^\top
\mathbf{W}_{[t+1]}
\end{aligned}\]
For \(\delta \mathbf{Q}_{[t]}\):
\[\begin{aligned}
\delta \overleftarrow{\mathbf{Q}_{[t]}}
&=
\delta \mathbf{O}_{[t]}
\mathbf{S}_{[t-1]}^C
+
\left(\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],1}^\top \odot \mathbf{M}\right) \overrightarrow{\mathbf{K}_{[t]}}
+
\left(\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],2}^\top \odot \mathbf{M}\right) \overrightarrow{\mathbf{A}_{[t]}}
\\
\\
\delta \mathbf{Q}_{[t]}
&=
\delta \overleftarrow{\mathbf{Q}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
\end{aligned}\]
For \(\delta \mathbf{B}_{[t]}\):
\[\begin{aligned}
\delta \overleftarrow{\mathbf{B}_{[t]}}
&=
\mathbf{E}_{[t]}^\top
\delta \mathbf{W}_{[t]}
-
\left(
\mathbf{E}_{[t]}^\top
\delta \mathbf{C}_{[t]}
\mathbf{V}_{[t]}^\top
\odot \mathbf{M}_{-1}
\right)
\overrightarrow{\mathbf{K}_{[t]}}
+
\left(
\delta \mathbf{F}_{[t]}
\odot \mathbf{M}_{-1}
\right)
\overrightarrow{\mathbf{A}_{[t]}}
\\
\\
\delta \mathbf{B}_{[t]}
&=
\delta \overleftarrow{\mathbf{B}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
\end{aligned}\]
For \(\delta \mathbf{K}_{[t]}\):
\[\begin{aligned}
\delta \overrightarrow{\mathbf{K}_{[t]}}
&=
\left(\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],1}^\top \odot \mathbf{M}\right)^\top
\overleftarrow{\mathbf{Q}_{[t]}}
-
\left(
\mathbf{E}_{[t]}^\top
\delta \mathbf{C}_{[t]}
\mathbf{V}_{[t]}^\top
\odot \mathbf{M}_{-1}
\right)^\top
\overleftarrow{\mathbf{B}_{[t]}}
+
\mathbf{V}_{[t],1}
\delta \mathbf{S}_{[t]}^{C}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\\
\\
\delta \mathbf{K}_{[t]}
&=
\delta \overrightarrow{\mathbf{K}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\end{aligned}\]
For \(\delta \mathbf{A}_{[t]}\):
\[\begin{aligned}
\delta \overrightarrow{\mathbf{A}_{[t]}}
&=
\left(\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],2}^\top \odot \mathbf{M}\right)^\top
\overleftarrow{\mathbf{Q}_{[t]}}
+
\left(
\delta \mathbf{F}_{[t]}
\odot \mathbf{M}_{-1}
\right)^\top
\overleftarrow{\mathbf{B}_{[t]}}
+
\mathbf{V}_{[t],2}
\delta \mathbf{S}_{[t]}^{C}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\\
\\
\delta \mathbf{A}_{[t]}
&=
\delta \overrightarrow{\mathbf{A}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\end{aligned}\]
For \(\delta \mathbf{V}_{[t]}\):
\[\begin{aligned}
\delta \mathbf{V}_{[t]}
=
\delta \mathbf{V}_{[t],1}
-
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)^\top
\mathbf{E}_{[t]} ^\top
\delta \mathbf{C}_{[t]}
\end{aligned}\]
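For \(\delta \mathbf{\Gamma}_{[t]}\): the last row \(\boldsymbol{\gamma}_{[t]}^C\) receives an extra gradient through the factor \(\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)\) in \(\mathbf{S}_{[t]}^C\):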
\[\begin{aligned}
\delta \boldsymbol{\gamma}_{[t]}^C
&=
\text{diag}\left(
\left(
\mathbf{S}_{[t-1]}^C
+
\mathbf{V}_{[t],1}^\top
\overrightarrow{\mathbf{K}_{[t]}}
+
\mathbf{V}_{[t],2}^\top
\overrightarrow{\mathbf{A}_{[t]}}
\right)^\top
\delta \mathbf{S}_{[t]}^C
\right)
=
\left(\boldsymbol{\gamma}_{[t]}^{C}\right)^{-1}
\odot
\text{diag}\left(
\mathbf{S}_{[t]}^{C\top}
\delta \mathbf{S}_{[t]}^C
\right)
\end{aligned}\]
\[\begin{aligned}
\left.\delta \mathbf{\Gamma}_{[t]}\right|_{\text{w/o extra} \boldsymbol{\gamma}_{[t]}^C}
&=
\delta \overleftarrow{\mathbf{Q}_{[t]}} \odot \mathbf{Q}_{[t]}
+
\delta \overleftarrow{\mathbf{B}_{[t]}} \odot \mathbf{B}_{[t]}
\\&-
\delta \overrightarrow{\mathbf{K}_{[t]}} \odot \mathbf{K}_{[t]} \oslash \mathbf{\Gamma}_{[t]} \oslash \mathbf{\Gamma}_{[t]}
-
\delta \overrightarrow{\mathbf{A}_{[t]}} \odot \mathbf{A}_{[t]} \oslash \mathbf{\Gamma}_{[t]} \oslash \mathbf{\Gamma}_{[t]}
\end{aligned}\]
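The per-chunk backward expressions can be checked together by finite differences. The sketch below treats the scaled inputs \(\overleftarrow{\mathbf{Q}_{[t]}}, \overrightarrow{\mathbf{K}_{[t]}}, \overrightarrow{\mathbf{A}_{[t]}}, \overleftarrow{\mathbf{B}_{[t]}}\) and \(\boldsymbol{\gamma}_{[t]}^C\) as independent leaves (in the real model they are tied through \(\mathbf{\Gamma}_{[t]}\), which the \(\delta\mathbf{\Gamma}\) formulas account for). It computes \(\delta\mathbf{W}_{[t]}\) with the state \(\mathbf{S}_{[t-1]}^C\) entering the chunk, and it includes the direct path through \(\mathbf{S}_{[t]}^C\) for \(\overrightarrow{\mathbf{K}_{[t]}}\) and \(\overrightarrow{\mathbf{A}_{[t]}}\), i.e. the terms \(\mathbf{V}_{[t],i}\,\delta\mathbf{S}_{[t]}^C\,\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)\). All function names are hypothetical.

```python
import numpy as np

def chunk_forward(Ql, Kr, Ar, Bl, V, S0, gC):
    """One chunk; S0 = S_[t-1]^C is [d_v, d_k], gC = gamma^C is [d_k]."""
    n = Ql.shape[0]
    E = np.linalg.inv(np.eye(n) + np.tril(Bl @ Ar.T, -1))   # E = F^{-1}
    G = np.tril(Bl @ Kr.T, -1)
    W = E @ Bl
    V2 = -E @ G @ V - W @ S0.T                                # C - W S0^T
    S1 = (S0 + V.T @ Kr + V2.T @ Ar) * gC                     # right-multiply by Diag(gC)
    O = Ql @ S0.T + np.tril(Ql @ Kr.T) @ V + np.tril(Ql @ Ar.T) @ V2
    return S1, O, (E, G, W, V2)

def chunk_loss(params, dS1, dO):
    """Scalar <dS1, S1> + <dO, O> whose gradients chunk_backward returns."""
    S1, O, _ = chunk_forward(**params)
    return np.sum(dS1 * S1) + np.sum(dO * O)

def chunk_backward(Ql, Kr, Ar, Bl, V, S0, gC, dS1, dO):
    S1, _, (E, G, W, V2) = chunk_forward(Ql, Kr, Ar, Bl, V, S0, gC)
    dC = (Ar * gC) @ dS1.T + np.tril(Ql @ Ar.T).T @ dO        # = dV2
    dV1 = (Kr * gC) @ dS1.T + np.tril(Ql @ Kr.T).T @ dO
    dW = -dC @ S0                                              # S0 = S_[t-1]^C
    dE = -dC @ V.T @ G.T + dW @ Bl.T
    dF = np.tril(-E.T @ dE @ E.T, -1)                          # (dF ⊙ M_{-1})
    dG = np.tril(-E.T @ dC @ V.T, -1)                          # grad of the masked B K^T block
    PdV1, PdV2 = np.tril(dO @ V.T), np.tril(dO @ V2.T)         # (dO V_i^T) ⊙ M
    return {
        "Ql": dO @ S0 + PdV1 @ Kr + PdV2 @ Ar,
        "Bl": E.T @ dW + dG @ Kr + dF @ Ar,
        "Kr": PdV1.T @ Ql + dG.T @ Bl + (V @ dS1) * gC,        # last term: path through S1
        "Ar": PdV2.T @ Ql + dF.T @ Bl + (V2 @ dS1) * gC,       # last term: path through S1
        "V": dV1 - G.T @ E.T @ dC,
        "S0": -dC.T @ W + dS1 * gC + dO.T @ Ql,
        "gC": np.sum((S1 / gC) * dS1, axis=0),                 # diag(X^T dS1), S1 = X Diag(gC)
    }

def numerical_grad(f, X, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. array X (modified in place, then restored)."""
    g = np.zeros_like(X)
    for idx in np.ndindex(X.shape):
        old = X[idx]
        X[idx] = old + eps
        fp = f()
        X[idx] = old - eps
        fm = f()
        X[idx] = old
        g[idx] = (fp - fm) / (2 * eps)
    return g
```

Dropping either of the \((\cdot)\,\delta\mathbf{S}_{[t]}^C\,\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)\) terms, or using the output state instead of `S0` in `dW`, makes the corresponding finite-difference comparison fail.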
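In particular, the gradient with respect to \(\log \mathbf{\Gamma}_{[t]}\) takes a simple form; the bracketed term is zero in every row except the last: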
\[\begin{aligned}
\delta \log \mathbf{\Gamma}_{[t]}
&=
\delta \mathbf{\Gamma}_{[t]}
\odot
\mathbf{\Gamma}_{[t]}
\\&=
\delta \mathbf{Q}_{[t]} \odot \mathbf{Q}_{[t]}
+
\delta \mathbf{B}_{[t]} \odot \mathbf{B}_{[t]}
-
\delta \mathbf{K}_{[t]} \odot \mathbf{K}_{[t]}
-
\delta \mathbf{A}_{[t]} \odot \mathbf{A}_{[t]}
\\&+
[0,0,...,\text{diag}\left(
\mathbf{S}_{[t]}^{C\top}
\delta \mathbf{S}_{[t]}^C
\right)]
\end{aligned}\]
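If the gates are parameterized by per-token \(\log \boldsymbol{\alpha}\) (an assumption here; these notes stop at \(\delta \log \mathbf{\Gamma}\)), one more step is needed. Since \(\log \boldsymbol{\gamma}_{[t]}^r = \sum_{j \le r} \log \boldsymbol{\alpha}_{[t]}^j\), the gradient of \(\log \boldsymbol{\alpha}_{[t]}^j\) is the reverse cumulative sum \(\sum_{r \ge j} \delta \log \boldsymbol{\gamma}_{[t]}^r\) within the chunk. A minimal sketch with a hypothetical helper name:

```python
import numpy as np

def dlog_alpha_from_dlog_gamma(dlogG):
    """Reverse cumulative sum along the chunk axis: dlog_alpha[j] = sum_{r>=j} dlogG[r]."""
    return np.flip(np.cumsum(np.flip(dlogG, axis=0), axis=0), axis=0)
```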
Discussion
Relation to KDA
In recurrent form, DPLR reduces directly to KDA by setting:
\[\begin{aligned}
\boldsymbol{a}_t &= \boldsymbol{k}_t
,\quad
\boldsymbol{b}_t = \beta_t \boldsymbol{k}_t
,\quad
\boldsymbol{v}_t = \beta_t \boldsymbol{v'}_t
\\
\\
\mathbf{S}_t
&= \mathbf{S}_{t-1}
\text{Diag}(\boldsymbol{\alpha}_t)
(\mathbf{I} - \boldsymbol{b}_t \boldsymbol{a}_t^\top)
+
\boldsymbol{v}_t \boldsymbol{k}_t^\top
\end{aligned}\]
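We fold \(\beta_t\) into \(\boldsymbol{b}_t\) rather than \(\boldsymbol{a}_t\) because the correspondence is simpler in the chunkwise form. The relations are: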
\[\begin{aligned}
\mathbf{E}_{[t]}
&=
\left(\mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1}
=
\left( \mathbf{I} +
\text{Diag}(\boldsymbol{\beta}_{[t]})
\left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)
\right)^{-1}
\\
\\
\mathbf{C}_{[t]} &=
- \mathbf{E}_{[t]}
\left(
\overleftarrow{\mathbf{B}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{V}_{[t]}
=
- \mathbf{E}_{[t]}
\text{Diag}(\boldsymbol{\beta}_{[t]})
\left(
\overleftarrow{\mathbf{K}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{V}_{[t]}
\\
\\
\mathbf{U}_{[t]} &:=
\mathbf{V}_{[t]} + \mathbf{C}_{[t]}
=
\left(
\mathbf{I}
- \mathbf{E}_{[t]}
\text{Diag}(\boldsymbol{\beta}_{[t]})
\left(
\overleftarrow{\mathbf{K}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\right)
\mathbf{V}_{[t]}
\\&=
\mathbf{E}_{[t]}
\left(
\mathbf{E}_{[t]}^{-1}
-
\text{Diag}(\boldsymbol{\beta}_{[t]})
\left(
\overleftarrow{\mathbf{K}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\right)
\mathbf{V}_{[t]}
=
\mathbf{E}_{[t]} \mathbf{V}_{[t]}
\\
\\
\mathbf{V}_{[t],new}
&:= \mathbf{V}_{[t],1} + \mathbf{V}_{[t],2}
= \mathbf{V}_{[t]} + \mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top}
= \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top}
\end{aligned}\]
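A small check of the KDA specialization: with \(\boldsymbol{a}_t = \boldsymbol{k}_t\) and \(\boldsymbol{b}_t = \beta_t \boldsymbol{k}_t\), the masked \(\mathbf{B}\mathbf{K}^\top\) block coincides with the off-diagonal part of \(\mathbf{F}_{[t]}\), which is why \(\mathbf{U}_{[t]} = \mathbf{E}_{[t]}\mathbf{V}_{[t]}\). The helper name is hypothetical, and `v` is assumed to already include the \(\beta\) factor, as in the substitution above.

```python
import numpy as np

def kda_chunk_quantities(k, beta, v, alpha):
    """E, U = E V and W for one chunk of the KDA special case (a = k, b = beta * k)."""
    n = k.shape[0]
    Gam = np.cumprod(alpha, axis=0)
    Kl, Kr = k * Gam, k / Gam                       # overleftarrow K, overrightarrow K
    E = np.linalg.inv(np.eye(n) + beta[:, None] * np.tril(Kl @ Kr.T, -1))
    U = E @ v
    W = E @ (beta[:, None] * Kl)                    # W = E overleftarrow{B}, B = Diag(beta) K
    return E, U, W
```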
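DPLR without an external gate
The diagonal \(\boldsymbol{\alpha}_t\) can be absorbed into \(\boldsymbol{b}_t\) by writing \(\boldsymbol{b}_t = \text{Diag}(\boldsymbol{\alpha}_t)^{-1}\boldsymbol{d}_t\):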
\[\begin{aligned}
\boldsymbol{b}_t &= \text{Diag}(\boldsymbol{\alpha}_t)^{-1} \boldsymbol{d}_t
\\
\\
\mathbf{S}_t
&= \mathbf{S}_{t-1}
\text{Diag}(\boldsymbol{\alpha}_t)
(\mathbf{I} - \boldsymbol{b}_t \boldsymbol{a}_t^\top)
+
\boldsymbol{v}_t \boldsymbol{k}_t^\top
= \mathbf{S}_{t-1}
(\text{Diag}(\boldsymbol{\alpha}_t) - \boldsymbol{d}_t \boldsymbol{a}_t^\top)
+
\boldsymbol{v}_t \boldsymbol{k}_t^\top
\end{aligned}\]
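Since \(\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)\boldsymbol{b}_{[t]}^r = \text{Diag}(\boldsymbol{\gamma}_{[t]}^{r-1})\boldsymbol{d}_{[t]}^r\) (with \(\boldsymbol{\gamma}_{[t]}^0 = \mathbf{1}\)), define \(\overleftarrow{\mathbf{D}_{[t]}} := \overleftarrow{\mathbf{B}_{[t]}}\), i.e. \(\mathbf{D}_{[t]}\) with row \(r\) scaled by the exclusive cumulative product \(\boldsymbol{\gamma}_{[t]}^{r-1}\). We then have:
\[\begin{aligned}
\mathbf{F}_{[t]}
&=
\mathbf{I} + \left(\overleftarrow{\mathbf{D}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)
,\quad
\mathbf{E}_{[t]} = \mathbf{F}_{[t]}^{-1}
\\
\\
\mathbf{W}_{[t]} &= \mathbf{E}_{[t]} \overleftarrow{\mathbf{D}_{[t]}}
,\quad
\mathbf{C}_{[t]} =
- \mathbf{E}_{[t]}
\left(
\overleftarrow{\mathbf{D}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot
\mathbf{M}_{-1}
\right)
\mathbf{V}_{[t]}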
\\
\\
\mathbf{V}_{[t],1} &:= \mathbf{V}_{[t]}
,\quad
\mathbf{V}_{[t],2} := \left(\mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right)
\\
\\
\mathbf{S}_{[t]}^C
&=
\mathbf{S}_{[t-1]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+
\mathbf{V}_{[t],1}^\top
\overrightarrow{\mathbf{K}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+
\mathbf{V}_{[t],2}^\top
\overrightarrow{\mathbf{A}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\\
\\
\mathbf{O}_{[t]}
&=
\overleftarrow{\mathbf{Q}_{[t]}}
\mathbf{S}_{[t-1]}^{C\top}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot \mathbf{M} \right) \mathbf{V}_{[t],1}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{A}_{[t]}}^\top
\odot \mathbf{M} \right)
\mathbf{V}_{[t],2}
\\
\\
\mathbf{\Gamma}_{[t]} &= [
\boldsymbol{\gamma}_{[t]}^1,
\boldsymbol{\gamma}_{[t]}^2,
...,
\boldsymbol{\gamma}_{[t]}^C
]^\top
\\
\\
\overleftarrow{\square_{[t]}}
&=
\square_{[t]}
\odot
\mathbf{\Gamma}_{[t]}
,\quad
\overrightarrow{\square_{[t]}}
=
\square_{[t]}
\oslash
\mathbf{\Gamma}_{[t]}
\quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{A}, \mathbf{B}\}
\end{aligned}\]
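Under the substitution \(\boldsymbol{b}_t = \text{Diag}(\boldsymbol{\alpha}_t)^{-1}\boldsymbol{d}_t\), the gate-scaled rows \(\text{Diag}(\boldsymbol{\gamma}_{[t]}^r)\boldsymbol{b}_{[t]}^r\) equal \(\text{Diag}(\boldsymbol{\gamma}_{[t]}^{r-1})\boldsymbol{d}_{[t]}^r\), so wherever \(\overleftarrow{\mathbf{B}_{[t]}}\) appears it can be formed from \(\mathbf{D}_{[t]}\) with the exclusive cumulative product of the gates. A minimal sketch (hypothetical helper name) that checks this identity and the rewritten transition \(\text{Diag}(\boldsymbol{\alpha}_t) - \boldsymbol{d}_t \boldsymbol{a}_t^\top\):

```python
import numpy as np

def left_scaled_d(d, alpha):
    """Rows Diag(gamma^{r-1}) d^r: D scaled by the exclusive cumulative gate product (gamma^0 = 1)."""
    ones = np.ones((1, d.shape[1]))
    gamma_excl = np.vstack([ones, np.cumprod(alpha, axis=0)[:-1]])
    return d * gamma_excl
```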