
Diagonal Plus Low Rank

Code: https://github.com/fla-org/flash-linear-attention
Disclaimer: These are personal reading notes. Some derivations are my own and may be incorrect, so they should be cross-checked against the code later.

I. Notation

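  1. Bold uppercase letters such as \(\mathbf{S}, \mathbf{Q}\) denote matrices.
  2. Symbols such as \(\mathbf{q}_t, \mathbf{k}_t\) denote column vectors (shape \([d, 1]\)), while matrices have shape \([L, d]\); this is why extra transposes appear in the formulas.
  3. Symbols such as \(W_t\) denote learnable parameters.
  4. \(\mathbf{q}_t\) denotes the \(t\)-th row of \(\mathbf{Q}\), written as a column vector.
  5. \(\square_{[t]} = \square_{[t]}^{1:C} \in \mathbb{R}^{C \times d}\) denotes the \(t\)-th chunk of length \(C\), and \(\square_{[t]}^{r}\) its \(r\)-th row, where \(\square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{V}, \dots \}\).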

II. Forward Pass Derivation for Gated DPLR

The original recurrence:

\[\begin{aligned} \mathbf{S}_t &= \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) (\mathbf{I} - \boldsymbol{b}_t \boldsymbol{a}_t^\top) + \boldsymbol{v}_t \boldsymbol{k}_t^\top \in \mathbb{R}^{d_v \times d_k} \\ \\ \boldsymbol{o}_t &= \mathbf{S}_{t}\boldsymbol{q}_t \end{aligned}\]
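For reference while following the chunkwise derivation, the per-token recurrence can be written directly. Below is a minimal NumPy sketch, assuming inputs are stored row-wise as in the notation section (`q`, `k`, `a`, `b`, `alpha` of shape \([L, d_k]\), `v` of shape \([L, d_v]\)). The helper name `dplr_recurrent` is hypothetical and is not part of the flash-linear-attention API.

```python
import numpy as np


def dplr_recurrent(q, k, v, a, b, alpha):
    """Naive gated DPLR recurrence (hypothetical reference helper).

    S_t = S_{t-1} Diag(alpha_t) (I - b_t a_t^T) + v_t k_t^T,  o_t = S_t q_t
    q, k, a, b, alpha: [L, dk]; v: [L, dv]. Returns outputs [L, dv] and final state [dv, dk].
    """
    L, dk = q.shape
    dv = v.shape[1]
    S = np.zeros((dv, dk))
    o = np.zeros((L, dv))
    eye = np.eye(dk)
    for t in range(L):
        # The transition acts on the right of S, as in the formula above.
        transition = np.diag(alpha[t]) @ (eye - np.outer(b[t], a[t]))
        S = S @ transition + np.outer(v[t], k[t])
        o[t] = S @ q[t]
    return o, S


# With alpha = 1 and b = 0 the recurrence collapses to causal linear attention.
rng = np.random.default_rng(0)
L, dk, dv = 6, 4, 3
q, k = rng.normal(size=(L, dk)), rng.normal(size=(L, dk))
v = rng.normal(size=(L, dv))
o, S = dplr_recurrent(q, k, v, np.zeros((L, dk)), np.zeros((L, dk)), np.ones((L, dk)))
o_linear = ((q @ k.T) * np.tril(np.ones((L, L)))) @ v
```

The degenerate case in the example (no decay, no low-rank term) is a cheap sanity check for any faster implementation before testing the general case.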

Next, following the DeltaNet derivation, we introduce chunk-level notation and the following auxiliary quantities:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} &= \prod_{i=t C + 1}^{t C + r} \text{Diag}(\boldsymbol{\alpha}_i)(\mathbf{I} - \boldsymbol{b}_{i} \boldsymbol{a}_{i}^{\top}) \in \mathbb{R}^{d_k \times d_k} \\ \\ \mathbf{H}_{[t]}^{r} &= \sum_{i=tC + 1}^{tC + r} (\boldsymbol{v}_{i} \boldsymbol{k}_{i}^{\top}) \prod_{j=i + 1}^{t C + r} \text{Diag}(\boldsymbol{\alpha}_j)(\mathbf{I} - \boldsymbol{b}_{j} \boldsymbol{a}_{j}^{\top}) \in \mathbb{R}^{d_v \times d_k} \end{aligned}\]

The chunk-level state can then be written as:

\[\begin{aligned} \mathbf{S}_{[t]}^{r} = \mathbf{S}_{[t-1]}^{C} \mathbf{P}_{[t]}^{r} + \mathbf{H}_{[t]}^{r} , \quad \text{where } \mathbf{S}_{[-1]}^{C} = \mathbf{0} \end{aligned}\]

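Meanwhile, a product of Householder-like transforms of the form \((\mathbf{I} - \beta_t \boldsymbol{k}_t \boldsymbol{k}_t^\top)\) can always be written as the identity minus a low-rank term using the WY representation. We therefore derive the analogous form for the gated factors \(\text{Diag}(\boldsymbol{\alpha}_t)(\mathbf{I} - \boldsymbol{b}_t \boldsymbol{a}_t^\top)\), again by almost the same induction.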

When the low-rank part vanishes (\(\boldsymbol{b}_{[t]}^i = \mathbf{0}\)), we have:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} = \prod_{i=1}^{r} \text{Diag}(\boldsymbol{\alpha}_{[t]}^i) = \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \end{aligned}\]

We therefore assume the following form:

\[\begin{aligned} \mathbf{P}_{[t]}^{r} = \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) - \sum_{i=1}^{r} \text{Diag}(\boldsymbol{\eta}_{[t]}^r) \text{Diag}(\boldsymbol{\xi}_{[t]}^i) \boldsymbol{w}_{[t]}^{i} \boldsymbol{a}_{[t]}^{i \top} \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) \text{Diag}(\boldsymbol{\lambda}_{[t]}^r) \end{aligned}\]

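Expanding one step with the assumed form of \(\mathbf{P}_{[t]}^{r-1}\) and using \(\text{Diag}(\boldsymbol{\gamma}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\alpha}_{[t]}^r) = \text{Diag}(\boldsymbol{\gamma}_{[t]}^r)\), we have: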

\[\begin{aligned} \mathbf{P}_{[t]}^{r} &= \mathbf{P}_{[t]}^{r-1} \text{Diag}(\boldsymbol{\alpha}_{[t]}^r) (\mathbf{I} - \boldsymbol{b}_{[t]}^r \boldsymbol{a}_{[t]}^{r\top}) \\ \\&= \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) - \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{b}_{[t]}^r \boldsymbol{a}_{[t]}^{r\top} \\ \\&- \sum_{i=1}^{r-1} \text{Diag}(\boldsymbol{\eta}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\xi}_{[t]}^i) \boldsymbol{w}_{[t]}^{i} \boldsymbol{a}_{[t]}^{i \top} \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) \text{Diag}(\boldsymbol{\lambda}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\alpha}_{[t]}^r) \\ \\&+ \sum_{i=1}^{r-1} \text{Diag}(\boldsymbol{\eta}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\xi}_{[t]}^i) \boldsymbol{w}_{[t]}^{i} \boldsymbol{a}_{[t]}^{i \top} \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) \text{Diag}(\boldsymbol{\lambda}_{[t]}^{r-1}) \text{Diag}(\boldsymbol{\alpha}_{[t]}^r) \boldsymbol{b}_{[t]}^r \boldsymbol{a}_{[t]}^{r\top} \end{aligned}\]

Matching like terms, choosing the parameters as below, and absorbing \(\boldsymbol{\xi}\) into \(\boldsymbol{w}\), we obtain:

\[\begin{aligned} \text{Diag}(\boldsymbol{\lambda}_{[t]}^i) &= \text{Diag}(\boldsymbol{\gamma}_{[t]}^i) = \prod_{j=1}^i \text{Diag}(\boldsymbol{\alpha}_{[t]}^j) ,\quad \text{Diag}(\boldsymbol{\eta}_{[t]}^i) = \mathbf{I} ,\quad \text{Diag}(\boldsymbol{\epsilon}_{[t]}^i) = \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \\ \\ \boldsymbol{w}_{[t]}^{r} &= \left(\mathbf{I} - \sum_{i=1}^{r-1} \boldsymbol{w}_{[t]}^{i} \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^{i} \right)^\top \right) \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{b}_{[t]}^r \\ \\ \mathbf{P}_{[t]}^{r} &= \left(\mathbf{I}- \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^{i} \right)^\top \right) \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \end{aligned}\]
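The WY form of \(\mathbf{P}_{[t]}^r\) can be checked numerically. The sketch below assumes NumPy and chunk-local 0-based rows (row `r` of an array holds the quantity with superscript \(r+1\)); it builds \(\boldsymbol{w}_{[t]}^r\) from the recursion above and compares the resulting WY form against the explicit product of transition factors. All helper names are hypothetical.

```python
import numpy as np


def wy_vectors(a, b, alpha):
    """Return gamma [C, dk] and w [C, dk] for one chunk via the recursion for w^r."""
    C, dk = a.shape
    gamma = np.cumprod(alpha, axis=0)  # gamma[r] = prod_{j<=r} alpha[j]
    w = np.zeros((C, dk))
    for r in range(C):
        proj = np.eye(dk)
        for i in range(r):
            proj -= np.outer(w[i], a[i] / gamma[i])
        w[r] = proj @ (gamma[r] * b[r])
    return gamma, w


def p_explicit(a, b, alpha, r):
    """P^r as the explicit product of the first r transition factors."""
    dk = a.shape[1]
    P = np.eye(dk)
    for i in range(r):
        P = P @ np.diag(alpha[i]) @ (np.eye(dk) - np.outer(b[i], a[i]))
    return P


def p_wy(gamma, w, a, r):
    """P^r = (I - sum_{i<=r} w^i (a^i / gamma^i)^T) Diag(gamma^r), for r >= 1."""
    dk = a.shape[1]
    proj = np.eye(dk)
    for i in range(r):
        proj -= np.outer(w[i], a[i] / gamma[i])
    return proj @ np.diag(gamma[r - 1])


rng = np.random.default_rng(0)
C, dk = 6, 4
a, b = 0.5 * rng.normal(size=(C, dk)), 0.5 * rng.normal(size=(C, dk))
alpha = rng.uniform(0.7, 1.0, size=(C, dk))
gamma, w = wy_vectors(a, b, alpha)
max_err = max(np.abs(p_explicit(a, b, alpha, r) - p_wy(gamma, w, a, r)).max() for r in range(1, C + 1))
```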

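The forward pass for \(\mathbf{H}\) is the part that differs most from the earlier derivation. Here we assume that \(\mathbf{S}\) can be written as a sum of two accumulated terms: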

\[\begin{aligned} \mathbf{S}_t = \sum_{i=1}^{t} \text{Diag}(\boldsymbol{\eta}_t) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_t) + \sum_{i=1}^{t} \text{Diag}(\boldsymbol{\eta^1}_t) \boldsymbol{c}_i \boldsymbol{a}_i^\top \text{Diag}(\boldsymbol{\epsilon^1}_i) \text{Diag}(\boldsymbol{\gamma^1}_t) \end{aligned}\]

By induction, we obtain:

\[\begin{aligned} \mathbf{S}_t &= \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) - \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) \boldsymbol{b}_t \boldsymbol{a}_t^\top + \boldsymbol{v}_t \boldsymbol{k}_t^\top \\&= \sum_{i=1}^{t-1} \text{Diag}(\boldsymbol{\eta}_{t-1}) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_{t-1}) \text{Diag}(\boldsymbol{\alpha}_{t}) \\&+ \sum_{i=1}^{t-1} \text{Diag}(\boldsymbol{\eta^1}_{t-1}) \boldsymbol{c}_i \boldsymbol{a}_i^\top \text{Diag}(\boldsymbol{\epsilon^1}_i) \text{Diag}(\boldsymbol{\gamma^1}_{t-1}) \text{Diag}(\boldsymbol{\alpha}_{t}) \\&- \sum_{i=1}^{t-1} \text{Diag}(\boldsymbol{\eta}_{t-1}) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_{t-1}) \text{Diag}(\boldsymbol{\alpha}_{t}) \boldsymbol{b}_t \boldsymbol{a}_t^\top \\&- \sum_{i=1}^{t-1} \text{Diag}(\boldsymbol{\eta^1}_{t-1}) \boldsymbol{c}_i \boldsymbol{a}_i^\top \text{Diag}(\boldsymbol{\epsilon^1}_i) \text{Diag}(\boldsymbol{\gamma^1}_{t-1}) \text{Diag}(\boldsymbol{\alpha}_{t}) \boldsymbol{b}_t \boldsymbol{a}_t^\top \\&+ \boldsymbol{v}_t \boldsymbol{k}_t^\top \\&= \sum_{i=1}^{t} \text{Diag}(\boldsymbol{\eta}_{t}) \boldsymbol{u}_i \boldsymbol{k}_i^\top \text{Diag}(\boldsymbol{\epsilon}_i) \text{Diag}(\boldsymbol{\gamma}_{t}) \\&+ \sum_{i=1}^{t} \text{Diag}(\boldsymbol{\eta^1}_{t}) \boldsymbol{c}_i \boldsymbol{a}_i^\top \text{Diag}(\boldsymbol{\epsilon^1}_i) \text{Diag}(\boldsymbol{\gamma^1}_{t}) \end{aligned}\]

We define the following auxiliary quantities (with these choices, the last equality above holds):

\[\begin{aligned} \text{Diag}(\boldsymbol{\gamma}_t) &= \text{Diag}(\boldsymbol{\gamma^1}_t) = \prod_{i=1}^t \text{Diag}(\boldsymbol{\alpha}_i) ,\quad \text{Diag}(\boldsymbol{\epsilon}_t) = \text{Diag}(\boldsymbol{\epsilon^1}_t) = \text{Diag}(\boldsymbol{\gamma}_t)^{-1} \\ \\ \text{Diag}(\boldsymbol{\eta}_t) &= \text{Diag}(\boldsymbol{\eta^1}_t) = \mathbf{I} \end{aligned}\]

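Applying this to \(\mathbf{H}_{[t]}^{r}\), which obeys the same recurrence inside the chunk starting from \(\mathbf{H}_{[t]}^{0} = \mathbf{0}\), we obtain: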

\[\begin{aligned} \boldsymbol{u}_{[t]}^r &= \boldsymbol{v}_{[t]}^r \\ \\ \boldsymbol{c}_{[t]}^r &= - \sum_{i=1}^{r-1} \left( \boldsymbol{v}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top + \boldsymbol{c}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top \right) \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{b}_{[t]}^r\right) \end{aligned}\]
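A small NumPy check of the two-term form for a single chunk, assuming the same chunk-local 0-based indexing (helper names are hypothetical): \(\mathbf{H}_{[t]}^r\) is computed once from the raw recurrence starting at zero, and once from \(\boldsymbol{u}_{[t]}^r = \boldsymbol{v}_{[t]}^r\) together with the recursion for \(\boldsymbol{c}_{[t]}^r\).

```python
import numpy as np


def h_recurrent(k, v, a, b, alpha):
    """H^r from the raw recurrence inside one chunk, starting from H^0 = 0."""
    C, dk = k.shape
    H = np.zeros((v.shape[1], dk))
    out = []
    for r in range(C):
        H = H @ np.diag(alpha[r]) @ (np.eye(dk) - np.outer(b[r], a[r])) + np.outer(v[r], k[r])
        out.append(H)
    return out


def h_two_term(k, v, a, b, alpha):
    """H^r = (sum_{i<=r} v^i (k^i/gamma^i)^T + c^i (a^i/gamma^i)^T) Diag(gamma^r)."""
    C, dk = k.shape
    dv = v.shape[1]
    gamma = np.cumprod(alpha, axis=0)
    c = np.zeros((C, dv))
    acc = np.zeros((dv, dk))  # running sum over i < r of the bracketed terms
    out = []
    for r in range(C):
        c[r] = -acc @ (gamma[r] * b[r])  # recursion for c^r uses only i < r
        acc = acc + np.outer(v[r], k[r] / gamma[r]) + np.outer(c[r], a[r] / gamma[r])
        out.append(acc @ np.diag(gamma[r]))
    return out, c


rng = np.random.default_rng(0)
C, dk, dv = 5, 4, 3
k, v = rng.normal(size=(C, dk)), rng.normal(size=(C, dv))
a, b = 0.5 * rng.normal(size=(C, dk)), 0.5 * rng.normal(size=(C, dk))
alpha = rng.uniform(0.7, 1.0, size=(C, dk))
H_ref = h_recurrent(k, v, a, b, alpha)
H_two, c = h_two_term(k, v, a, b, alpha)
```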

\(\mathbf{S}\) and \(\boldsymbol{o}\) can then be written as:

\[\begin{aligned} \mathbf{S}_{[t]}^r &= \mathbf{S}_{[t-1]}^C \left(\mathbf{I}- \sum_{i=1}^{r} \boldsymbol{w}_{[t]}^{i} \left( \text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^{i} \right)^\top \right) \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \\&+ \sum_{i=1}^{r} \boldsymbol{v}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) + \sum_{i=1}^{r} \boldsymbol{c}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \\ &= \mathbf{S}_{[t-1]}^C \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) + \sum_{i=1}^{r} \boldsymbol{v}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \\&+ \sum_{i=1}^{r} \left( \boldsymbol{c}_{[t]}^i - \mathbf{S}_{[t-1]}^C \boldsymbol{w}_{[t]}^{i} \right) \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \\ \\ \boldsymbol{o}_{[t]}^r &= \mathbf{S}_{[t-1]}^C \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{q}_{[t]}^r + \sum_{i=1}^{r} \boldsymbol{v}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{k}_{[t]}^i\right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{q}_{[t]}^r \\&+ \sum_{i=1}^{r} \left( \boldsymbol{c}_{[t]}^i - \mathbf{S}_{[t-1]}^C \boldsymbol{w}_{[t]}^{i} \right) \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top \text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{q}_{[t]}^r \end{aligned}\]

Define (\(\oslash\) denotes elementwise division):

\[\begin{aligned} \mathbf{\Gamma}_{[t]} &= [ \boldsymbol{\gamma}_{[t]}^1, \boldsymbol{\gamma}_{[t]}^2, ..., \boldsymbol{\gamma}_{[t]}^C ]^\top \\ \\ \overleftarrow{\square_{[t]}} &= \square_{[t]} \odot \mathbf{\Gamma}_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \square_{[t]} \oslash \mathbf{\Gamma}_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{A}, \mathbf{B}\} \end{aligned}\]

The computation can then be rewritten in matrix form:

\[\begin{aligned} \mathbf{S}_{[t]}^C &= \mathbf{S}_{[t-1]}^C \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \mathbf{V}_{[t]}^\top \overrightarrow{\mathbf{K}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \left( \mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \right)^\top \overrightarrow{\mathbf{A}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \\ \\ \mathbf{O}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t-1]}^{C\top} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t]} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M} \right) \left( \mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \right) \end{aligned}\]
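The matrix form can be sketched end to end in NumPy and compared against the naive per-token recurrence (included again so the block is self-contained). This is an illustrative sketch under stated assumptions, not the fla kernel: the sequence length is assumed to be a multiple of \(C\), \(\mathbf{E}_{[t]}\) is applied through `np.linalg.solve` instead of an explicit inverse, and the division by \(\mathbf{\Gamma}_{[t]}\) is done naively, which can overflow for long chunks or strong decay. Function names are hypothetical.

```python
import numpy as np


def dplr_recurrent(q, k, v, a, b, alpha):
    """Naive per-token recurrence, used as ground truth."""
    L, dk = q.shape
    S = np.zeros((v.shape[1], dk))
    o = np.zeros((L, v.shape[1]))
    for t in range(L):
        S = S @ np.diag(alpha[t]) @ (np.eye(dk) - np.outer(b[t], a[t])) + np.outer(v[t], k[t])
        o[t] = S @ q[t]
    return o, S


def dplr_chunkwise(q, k, v, a, b, alpha, C):
    """Chunkwise form: O_[t] and S^C_[t] from S^C_[t-1], W_[t], C_[t] (L must be a multiple of C)."""
    L, dk = q.shape
    dv = v.shape[1]
    M = np.tril(np.ones((C, C)))            # causal mask including the diagonal
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower mask M_{-1}
    S = np.zeros((dv, dk))                  # S^C_[-1] = 0
    O = np.zeros((L, dv))
    for s in range(0, L, C):
        sl = slice(s, s + C)
        G = np.cumprod(alpha[sl], axis=0)   # Gamma_[t], chunk-local cumulative decay
        Ql, Bl = q[sl] * G, b[sl] * G       # "left arrow": multiply by Gamma
        Kr, Ar = k[sl] / G, a[sl] / G       # "right arrow": divide by Gamma
        V = v[sl]
        F = np.eye(C) + (Bl @ Ar.T) * M1
        W = np.linalg.solve(F, Bl)                          # W = E B_left
        Cm = -np.linalg.solve(F, ((Bl @ Kr.T) * M1) @ V)   # C = -E (B_left K_right^T * M_{-1}) V
        V2 = Cm - W @ S.T
        O[sl] = Ql @ S.T + ((Ql @ Kr.T) * M) @ V + ((Ql @ Ar.T) * M) @ V2
        D = np.diag(G[-1])
        S = S @ D + V.T @ Kr @ D + V2.T @ Ar @ D
    return O, S


rng = np.random.default_rng(0)
L, C, dk, dv = 12, 4, 4, 3
q, k, v = rng.normal(size=(L, dk)), rng.normal(size=(L, dk)), rng.normal(size=(L, dv))
a, b = 0.3 * rng.normal(size=(L, dk)), 0.3 * rng.normal(size=(L, dk))
alpha = rng.uniform(0.8, 1.0, size=(L, dk))
o_ref, S_ref = dplr_recurrent(q, k, v, a, b, alpha)
o_chunk, S_chunk = dplr_chunkwise(q, k, v, a, b, alpha, C)
```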

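Next we turn to \(\mathbf{W}_{[t]}\) and \(\mathbf{C}_{[t]}\). Here \(\mathbf{M}\) is the lower-triangular causal mask (including the diagonal), and \(\mathbf{M}_{-1} = \mathbf{M} - \mathbf{I}\) is the strictly lower-triangular mask with zeros on the diagonal.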

u_[t] 的递推式出发,我们得到:

\[\begin{aligned} \boldsymbol{w}_{[t]}^r &= \left( \mathbf{I} - \sum_{i=1}^{r-1} \boldsymbol{w}_{[t]}^i \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^i)^{-1} \boldsymbol{a}_{[t]}^i\right)^\top \right) \left(\text{Diag}(\boldsymbol{\gamma}_{[t]}^r) \boldsymbol{b}_{[t]}^r\right) \\ \\ \Rightarrow \mathbf{W}_{[t]} &= \overleftarrow{\mathbf{B}_{[t]}} - \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{W}_{[t]} \\ \\ \Rightarrow \mathbf{W}_{[t]} &= \left( \mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \overleftarrow{\mathbf{B}_{[t]}} \end{aligned}\]

By the same reasoning, we also have:

\[\begin{aligned} \mathbf{C}_{[t]} &= - \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{V}_{[t]} - \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{C}_{[t]} \\ \\ \Rightarrow \mathbf{C}_{[t]} &= - \left( \mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{V}_{[t]} \end{aligned}\]
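The two triangular systems above can be checked against the row-by-row recursions for \(\boldsymbol{w}_{[t]}^r\) and \(\boldsymbol{c}_{[t]}^r\). A NumPy sketch for one chunk, with hypothetical helper names and an explicit `np.linalg.inv` used only for clarity:

```python
import numpy as np


def wc_rowwise(k, v, a, b, alpha):
    """w^r and c^r for one chunk from the row-by-row recursions."""
    C, dk = a.shape
    dv = v.shape[1]
    G = np.cumprod(alpha, axis=0)
    w, c = np.zeros((C, dk)), np.zeros((C, dv))
    for r in range(C):
        proj = np.eye(dk)
        hist = np.zeros((dv, dk))
        for i in range(r):
            proj -= np.outer(w[i], a[i] / G[i])
            hist += np.outer(v[i], k[i] / G[i]) + np.outer(c[i], a[i] / G[i])
        w[r] = proj @ (G[r] * b[r])
        c[r] = -hist @ (G[r] * b[r])
    return w, c


def wc_matrix(k, v, a, b, alpha):
    """W_[t] = E B_left and C_[t] = -E (B_left K_right^T * M_{-1}) V_[t]."""
    C = a.shape[0]
    G = np.cumprod(alpha, axis=0)
    Bl, Ar, Kr = b * G, a / G, k / G
    M1 = np.tril(np.ones((C, C)), k=-1)
    E = np.linalg.inv(np.eye(C) + (Bl @ Ar.T) * M1)
    return E @ Bl, -E @ ((Bl @ Kr.T) * M1) @ v


rng = np.random.default_rng(1)
C, dk, dv = 6, 4, 3
k, v = rng.normal(size=(C, dk)), rng.normal(size=(C, dv))
a, b = 0.5 * rng.normal(size=(C, dk)), 0.5 * rng.normal(size=(C, dk))
alpha = rng.uniform(0.7, 1.0, size=(C, dk))
w_row, c_row = wc_rowwise(k, v, a, b, alpha)
W_mat, C_mat = wc_matrix(k, v, a, b, alpha)
```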

Therefore, the key quantity for computing \(\mathbf{W}_{[t]}\) and \(\mathbf{C}_{[t]}\) is:

\[\begin{aligned} \mathbf{E}_{[t]} = \left(\mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \end{aligned}\]

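Since \(\mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)\) is unit lower-triangular, the inversion can be handled in the same way as in the DeltaNet notes.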
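Because the matrix being inverted is unit lower-triangular, its inverse can be obtained by forward substitution, one row at a time. The NumPy sketch below illustrates that idea; treating it as equivalent to the specific scheme used in the DeltaNet notes or in the fla kernels is an assumption, and `unit_lower_inverse` is a hypothetical name.

```python
import numpy as np


def unit_lower_inverse(L_strict):
    """Inverse of (I + L_strict), with L_strict strictly lower-triangular, by forward substitution.

    Row i of (I + L) E = I gives E[i] = e_i - sum_{j<i} L[i, j] E[j],
    and rows j < i of E are already known when row i is computed.
    """
    n = L_strict.shape[0]
    E = np.eye(n)
    for i in range(1, n):
        E[i] = E[i] - L_strict[i, :i] @ E[:i]
    return E


rng = np.random.default_rng(0)
n = 8
# Plays the role of (B_left A_right^T * M_{-1}) for one chunk.
L_strict = np.tril(0.5 * rng.normal(size=(n, n)), k=-1)
E = unit_lower_inverse(L_strict)
```

Each row costs one small matrix-vector product against the rows already computed, so no pivoting or general solver is needed, and the result stays unit lower-triangular.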

III. Backward Pass Derivation for Gated DPLR

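Recall the forward pass (below, \(\delta \mathbf{X}\) denotes the gradient of the loss with respect to \(\mathbf{X}\)):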

\[\begin{aligned} \mathbf{F}_{[t]} &= \mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) ,\quad \mathbf{E}_{[t]} = \mathbf{F}_{[t]}^{-1} \\ \\ \mathbf{W}_{[t]} &= \mathbf{E}_{[t]} \overleftarrow{\mathbf{B}_{[t]}} ,\quad \mathbf{C}_{[t]} = - \mathbf{E}_{[t]} \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{V}_{[t]} \\ \\ \mathbf{V}_{[t],1} &:= \mathbf{V}_{[t]} ,\quad \mathbf{V}_{[t],2} := \left(\mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right) \\ \\ \mathbf{S}_{[t]}^C &= \mathbf{S}_{[t-1]}^C \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \mathbf{V}_{[t],1}^\top \overrightarrow{\mathbf{K}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \mathbf{V}_{[t],2}^\top \overrightarrow{\mathbf{A}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \\ \\ \mathbf{O}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t-1]}^{C\top} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],1} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],2} \\ \\ \mathbf{\Gamma}_{[t]} &= [ \boldsymbol{\gamma}_{[t]}^1, \boldsymbol{\gamma}_{[t]}^2, ..., \boldsymbol{\gamma}_{[t]}^C ]^\top \\ \\ \overleftarrow{\square_{[t]}} &= \square_{[t]} \odot \mathbf{\Gamma}_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \square_{[t]} \oslash \mathbf{\Gamma}_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{A}, \mathbf{B}\} \end{aligned}\]

For \(\delta \mathbf{C}_{[t]}\), \(\delta \mathbf{V}_{[t],1}\), and \(\delta \mathbf{W}_{[t]}\):

\[\begin{aligned} \delta \mathbf{C}_{[t]} &= \delta \mathbf{V}_{[t],2} = \overrightarrow{\mathbf{A}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \delta \mathbf{S}_{[t]}^{C \top} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M} \right)^\top \delta \mathbf{O}_{[t]} \\ \\ \delta \mathbf{V}_{[t],1} & = \overrightarrow{\mathbf{K}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \delta \mathbf{S}_{[t]}^{C \top} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right)^\top \delta \mathbf{O}_{[t]} \\ \\ \delta \mathbf{W}_{[t]} &= - \delta \mathbf{C}_{[t]} \mathbf{S}_{[t-1]}^{C} \end{aligned}\]

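For \(\delta \mathbf{E}_{[t]}\) and \(\delta \mathbf{F}_{[t]}\) (the latter follows from \(d\mathbf{E} = -\mathbf{E}\, d\mathbf{F}\, \mathbf{E}\) for \(\mathbf{E} = \mathbf{F}^{-1}\)):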

\[\begin{aligned} \delta \mathbf{E}_{[t]} &= - \delta \mathbf{C}_{[t]} \mathbf{V}_{[t]}^\top \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)^\top + \delta \mathbf{W}_{[t]} \overleftarrow{\mathbf{B}_{[t]}}^\top \\ \\ \delta \mathbf{F}_{[t]} &= - \mathbf{E}_{[t]}^\top \delta \mathbf{E}_{[t]} \mathbf{E}_{[t]}^\top \end{aligned}\]
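The identity \(\delta \mathbf{F} = -\mathbf{E}^\top \delta \mathbf{E}\, \mathbf{E}^\top\) is the standard vector-Jacobian product of a matrix inverse. A NumPy finite-difference check on a unit lower-triangular \(\mathbf{F}\) (helper names are hypothetical):

```python
import numpy as np


def inverse_vjp(F, dE):
    """Given E = F^{-1} and the upstream gradient dE, return dF = -E^T dE E^T."""
    E = np.linalg.inv(F)
    return -E.T @ dE @ E.T


def inverse_vjp_numeric(F, dE, eps=1e-6):
    """Central finite differences of the scalar sum(dE * inv(F)) with respect to F."""
    g = np.zeros_like(F)
    for idx in np.ndindex(F.shape):
        Fp, Fm = F.copy(), F.copy()
        Fp[idx] += eps
        Fm[idx] -= eps
        g[idx] = (np.sum(dE * np.linalg.inv(Fp)) - np.sum(dE * np.linalg.inv(Fm))) / (2 * eps)
    return g


rng = np.random.default_rng(0)
C = 5
F = np.eye(C) + np.tril(0.5 * rng.normal(size=(C, C)), k=-1)  # unit lower-triangular, like F_[t]
dE = rng.normal(size=(C, C))
dF = inverse_vjp(F, dE)
dF_num = inverse_vjp_numeric(F, dE)
```

The check covers all entries of \(\mathbf{F}\); downstream only the strictly lower-triangular part \(\delta \mathbf{F}_{[t]} \odot \mathbf{M}_{-1}\) is used, because only those entries of \(\mathbf{F}_{[t]}\) depend on the inputs.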

For \(\delta \mathbf{S}_{[t]}^C\):

\[\begin{aligned} \delta \mathbf{S}_{[t-1]}^{C} &= -\delta \mathbf{C}_{[t]}^\top \mathbf{W}_{[t]} + \delta \mathbf{S}_{[t]}^{C} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \delta \mathbf{O}_{[t]}^\top \overleftarrow{\mathbf{Q}_{[t]}} \\ \\ \Rightarrow \delta \mathbf{S}_{[t]}^{C} &= \delta \mathbf{S}_{[t+1]}^{C} \text{Diag}(\boldsymbol{\gamma}_{[t+1]}^C) + \delta \mathbf{O}_{[t+1]}^\top \overleftarrow{\mathbf{Q}_{[t+1]}} -\delta \mathbf{C}_{[t+1]}^\top \mathbf{W}_{[t+1]} \end{aligned}\]

For \(\delta \mathbf{Q}_{[t]}\):

\[\begin{aligned} \delta \overleftarrow{\mathbf{Q}_{[t]}} &= \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^C + \left(\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],1}^\top \odot \mathbf{M}\right) \overrightarrow{\mathbf{K}_{[t]}} + \left(\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],2}^\top \odot \mathbf{M}\right) \overrightarrow{\mathbf{A}_{[t]}} \\ \\ \delta \mathbf{Q}_{[t]} &= \delta \overleftarrow{\mathbf{Q}_{[t]}} \odot \mathbf{\Gamma}_{[t]} \end{aligned}\]

For \(\delta \mathbf{B}_{[t]}\):

\[\begin{aligned} \delta \overleftarrow{\mathbf{B}_{[t]}} &= \mathbf{E}_{[t]}^\top \delta \mathbf{W}_{[t]} - \left( \mathbf{E}_{[t]}^\top \delta \mathbf{C}_{[t]} \mathbf{V}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \overrightarrow{\mathbf{K}_{[t]}} + \left( \delta \mathbf{F}_{[t]} \odot \mathbf{M}_{-1} \right) \overrightarrow{\mathbf{A}_{[t]}} \\ \\ \delta \mathbf{B}_{[t]} &= \delta \overleftarrow{\mathbf{B}_{[t]}} \odot \mathbf{\Gamma}_{[t]} \end{aligned}\]

For \(\delta \mathbf{K}_{[t]}\):

\[\begin{aligned} \delta \overrightarrow{\mathbf{K}_{[t]}} &= \left(\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],1}^\top \odot \mathbf{M}\right)^\top \overleftarrow{\mathbf{Q}_{[t]}} - \left( \mathbf{E}_{[t]}^\top \delta \mathbf{C}_{[t]} \mathbf{V}_{[t]}^\top \odot \mathbf{M}_{-1} \right)^\top \overleftarrow{\mathbf{B}_{[t]}} + \mathbf{V}_{[t],1} \delta \mathbf{S}_{[t]}^{C} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \\ \\ \delta \mathbf{K}_{[t]} &= \delta \overrightarrow{\mathbf{K}_{[t]}} \oslash \mathbf{\Gamma}_{[t]} \end{aligned}\]

For \(\delta \mathbf{A}_{[t]}\):

\[\begin{aligned} \delta \overrightarrow{\mathbf{A}_{[t]}} &= \left(\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],2}^\top \odot \mathbf{M}\right)^\top \overleftarrow{\mathbf{Q}_{[t]}} + \left( \delta \mathbf{F}_{[t]} \odot \mathbf{M}_{-1} \right)^\top \overleftarrow{\mathbf{B}_{[t]}} + \mathbf{V}_{[t],2} \delta \mathbf{S}_{[t]}^{C} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \\ \\ \delta \mathbf{A}_{[t]} &= \delta \overrightarrow{\mathbf{A}_{[t]}} \oslash \mathbf{\Gamma}_{[t]} \end{aligned}\]

For \(\delta \mathbf{V}_{[t]}\):

\[\begin{aligned} \delta \mathbf{V}_{[t]} = \delta \mathbf{V}_{[t],1} - \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)^\top \mathbf{E}_{[t]} ^\top \delta \mathbf{C}_{[t]} \end{aligned}\]
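To sanity-check the chunk-level gradients, the NumPy sketch below treats the scaled inputs \(\overleftarrow{\mathbf{Q}_{[t]}}, \overrightarrow{\mathbf{K}_{[t]}}, \overrightarrow{\mathbf{A}_{[t]}}, \overleftarrow{\mathbf{B}_{[t]}}, \mathbf{V}_{[t]}, \mathbf{S}_{[t-1]}^C\) as independent variables (the chain rule through \(\mathbf{\Gamma}_{[t]}\) is handled separately), uses the scalar loss \(\langle \delta \mathbf{O}_{[t]}, \mathbf{O}_{[t]} \rangle + \langle \delta \mathbf{S}_{[t]}^C, \mathbf{S}_{[t]}^C \rangle\), and compares the analytic expressions with central finite differences. Note that \(\overrightarrow{\mathbf{K}_{[t]}}\) and \(\overrightarrow{\mathbf{A}_{[t]}}\) also enter \(\mathbf{S}_{[t]}^C\) through \(\mathbf{V}_{[t],1}^\top \overrightarrow{\mathbf{K}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C)\) and \(\mathbf{V}_{[t],2}^\top \overrightarrow{\mathbf{A}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C)\), contributing \(\mathbf{V}_{[t],1} \delta \mathbf{S}_{[t]}^C \text{Diag}(\boldsymbol{\gamma}_{[t]}^C)\) and \(\mathbf{V}_{[t],2} \delta \mathbf{S}_{[t]}^C \text{Diag}(\boldsymbol{\gamma}_{[t]}^C)\); the sketch includes these terms. All names are hypothetical.

```python
import numpy as np


def chunk_forward(Ql, Kr, Ar, Bl, V, S0, gC):
    """One gated-DPLR chunk in matrix form; Ql, Kr, Ar, Bl are already scaled by Gamma."""
    C = Ql.shape[0]
    M = np.tril(np.ones((C, C)))
    M1 = np.tril(np.ones((C, C)), k=-1)
    E = np.linalg.inv(np.eye(C) + (Bl @ Ar.T) * M1)
    X = (Bl @ Kr.T) * M1
    W = E @ Bl
    V2 = -E @ X @ V - W @ S0.T                       # V_{[t],2} = C_[t] - W_[t] S0^T
    D = np.diag(gC)
    O = Ql @ S0.T + ((Ql @ Kr.T) * M) @ V + ((Ql @ Ar.T) * M) @ V2
    S1 = S0 @ D + V.T @ Kr @ D + V2.T @ Ar @ D
    return O, S1, (M, M1, E, X, W, V2, D)


def chunk_backward(Ql, Kr, Ar, Bl, V, S0, gC, dO, dS1):
    """Analytic gradients following the backward derivation, including the state-update terms."""
    _, _, (M, M1, E, X, W, V2, D) = chunk_forward(Ql, Kr, Ar, Bl, V, S0, gC)
    dC = Ar @ D @ dS1.T + ((Ql @ Ar.T) * M).T @ dO   # = dV_{[t],2}
    dV1 = Kr @ D @ dS1.T + ((Ql @ Kr.T) * M).T @ dO
    dW = -dC @ S0                                     # S0 is S^C_[t-1]
    dE = -dC @ V.T @ X.T + dW @ Bl.T
    dF = -E.T @ dE @ E.T
    dX = -(E.T @ dC @ V.T) * M1                       # gradient w.r.t. Bl Kr^T through C_[t]
    return {
        "Kr": ((dO @ V.T) * M).T @ Ql + dX.T @ Bl + V @ dS1 @ D,
        "Ar": ((dO @ V2.T) * M).T @ Ql + (dF * M1).T @ Bl + V2 @ dS1 @ D,
        "Bl": E.T @ dW + dX @ Kr + (dF * M1) @ Ar,
        "V": dV1 - X.T @ E.T @ dC,
        "S0": -dC.T @ W + dS1 @ D + dO.T @ Ql,
    }


def numeric_grad(args, idx, dO, dS1, eps=1e-6):
    """Central finite differences of <dO, O> + <dS1, S1> with respect to args[idx]."""
    def loss(xs):
        O, S1, _ = chunk_forward(*xs)
        return np.sum(dO * O) + np.sum(dS1 * S1)

    g = np.zeros_like(args[idx])
    for pos in np.ndindex(g.shape):
        plus, minus = list(args), list(args)
        plus[idx] = args[idx].copy()
        plus[idx][pos] += eps
        minus[idx] = args[idx].copy()
        minus[idx][pos] -= eps
        g[pos] = (loss(plus) - loss(minus)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
C, dk, dv = 4, 3, 2
args = [0.5 * rng.normal(size=(C, dk)) for _ in range(4)]  # Ql, Kr, Ar, Bl
args += [rng.normal(size=(C, dv)), rng.normal(size=(dv, dk)), rng.uniform(0.5, 1.0, size=dk)]
dO, dS1 = rng.normal(size=(C, dv)), rng.normal(size=(dv, dk))
grads = chunk_backward(*args, dO, dS1)
```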

For \(\delta \mathbf{\Gamma}_{[t]}\):

\[\begin{aligned} \delta \boldsymbol{\gamma}_{[t]}^C &= \text{diag}\left( \left( \mathbf{S}_{[t-1]}^C + \mathbf{V}_{[t],1}^\top \overrightarrow{\mathbf{K}_{[t]}} + \mathbf{V}_{[t],2}^\top \overrightarrow{\mathbf{A}_{[t]}} \right)^\top \delta \mathbf{S}_{[t]}^C \right) = \left(\boldsymbol{\gamma}_{[t]}^{C}\right)^{-1} \odot \text{diag}\left( \mathbf{S}_{[t]}^{C\top} \delta \mathbf{S}_{[t]}^C \right) \end{aligned}\]
\[\begin{aligned} \left.\delta \mathbf{\Gamma}_{[t]}\right|_{\text{w/o extra } \boldsymbol{\gamma}_{[t]}^C} &= \delta \overleftarrow{\mathbf{Q}_{[t]}} \odot \mathbf{Q}_{[t]} + \delta \overleftarrow{\mathbf{B}_{[t]}} \odot \mathbf{B}_{[t]} \\&- \delta \overrightarrow{\mathbf{K}_{[t]}} \odot \mathbf{K}_{[t]} \oslash \mathbf{\Gamma}_{[t]} \oslash \mathbf{\Gamma}_{[t]} - \delta \overrightarrow{\mathbf{A}_{[t]}} \odot \mathbf{A}_{[t]} \oslash \mathbf{\Gamma}_{[t]} \oslash \mathbf{\Gamma}_{[t]} \end{aligned}\]

In particular, for the log-decay:

\[\begin{aligned} \delta \log \mathbf{\Gamma}_{[t]} &= \delta \mathbf{\Gamma}_{[t]} \odot \mathbf{\Gamma}_{[t]} \\&= \delta \mathbf{Q}_{[t]} \odot \mathbf{Q}_{[t]} + \delta \mathbf{B}_{[t]} \odot \mathbf{B}_{[t]} - \delta \mathbf{K}_{[t]} \odot \mathbf{K}_{[t]} - \delta \mathbf{A}_{[t]} \odot \mathbf{A}_{[t]} \\&+ [0,0,...,\text{diag}\left( \mathbf{S}_{[t]}^{C\top} \delta \mathbf{S}_{[t]}^C \right)] \end{aligned}\]
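Since \(\mathbf{\Gamma}_{[t]}\) is a within-chunk cumulative product, \(\log \boldsymbol{\gamma}_{[t]}^r = \sum_{j \le r} \log \boldsymbol{\alpha}_{[t]}^j\), so the gradient with respect to the per-step log-decay is a reverse cumulative sum of \(\delta \log \mathbf{\Gamma}_{[t]}\) over the chunk. A minimal NumPy sketch (the function name is hypothetical):

```python
import numpy as np


def dlog_alpha_from_dlog_gamma(dlog_gamma):
    """Reverse cumulative sum over the chunk axis: dlog_alpha[j] = sum_{r >= j} dlog_gamma[r]."""
    return np.flip(np.cumsum(np.flip(dlog_gamma, axis=0), axis=0), axis=0)


rng = np.random.default_rng(0)
C, dk = 5, 3
dlog_gamma = rng.normal(size=(C, dk))
dlog_alpha = dlog_alpha_from_dlog_gamma(dlog_gamma)
# log Gamma = T @ log alpha with T the lower-triangular matrix of ones,
# so the gradient with respect to log alpha is T^T @ dlog_gamma.
T = np.tril(np.ones((C, C)))
```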

Discussion

Relation to KDA

From the perspective of the recurrent form, DPLR reduces directly to KDA by setting:

\[\begin{aligned} \boldsymbol{a}_t &= \boldsymbol{k}_t ,\quad \boldsymbol{b}_t = \beta_t \boldsymbol{k}_t ,\quad \boldsymbol{v}_t = \beta_t \boldsymbol{v'}_t \\ \\ \mathbf{S}_t &= \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) (\mathbf{I} - \boldsymbol{b}_t \boldsymbol{a}_t^\top) + \boldsymbol{v}_t \boldsymbol{k}_t^\top \end{aligned}\]

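Here we write \(\boldsymbol{b}_t = \beta_t \boldsymbol{k}_t\) (rather than putting \(\beta_t\) into \(\boldsymbol{a}_t\)) because the correspondence is simpler in the chunkwise form. Since \(\mathbf{A}_{[t]} = \mathbf{K}_{[t]}\), the two value streams \(\mathbf{V}_{[t],1}\) and \(\mathbf{V}_{[t],2}\) multiply the same key matrix and can be merged. The relations are: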

\[\begin{aligned} \mathbf{E}_{[t]} &= \left(\mathbf{I} + \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} = \left( \mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right)^{-1} \\ \\ \mathbf{C}_{[t]} &= - \mathbf{E}_{[t]} \left( \overleftarrow{\mathbf{B}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{V}_{[t]} = - \mathbf{E}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{V}_{[t]} \\ \\ \mathbf{U}_{[t]} &:= \mathbf{V}_{[t]} + \mathbf{C}_{[t]} = \left( \mathbf{I} - \mathbf{E}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right) \mathbf{V}_{[t]} \\&= \mathbf{E}_{[t]} \left( \mathbf{E}_{[t]}^{-1} - \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \right) \mathbf{V}_{[t]} = \mathbf{E}_{[t]} \mathbf{V}_{[t]} \\ \\ \mathbf{V}_{[t],new} &:= \mathbf{V}_{[t],1} + \mathbf{V}_{[t],2} = \mathbf{V}_{[t]} + \mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} = \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \end{aligned}\]
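A NumPy sketch of this reduction, comparing the merged chunkwise update (\(\mathbf{U}_{[t]} = \mathbf{E}_{[t]} \mathbf{V}_{[t]}\), \(\mathbf{V}_{[t],new} = \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C\top}\)) with the per-token recurrence above. Assumptions: a per-token scalar \(\beta_t\), per-channel decay \(\boldsymbol{\alpha}_t\), keys L2-normalised only to keep the toy example well-conditioned, and hypothetical function names; this is not the fla KDA kernel.

```python
import numpy as np


def kda_recurrent(q, k, v_raw, beta, alpha):
    """S_t = S_{t-1} Diag(alpha_t)(I - beta_t k_t k_t^T) + beta_t v'_t k_t^T,  o_t = S_t q_t."""
    L, dk = q.shape
    S = np.zeros((v_raw.shape[1], dk))
    o = np.zeros((L, v_raw.shape[1]))
    for t in range(L):
        transition = np.diag(alpha[t]) @ (np.eye(dk) - beta[t] * np.outer(k[t], k[t]))
        S = S @ transition + beta[t] * np.outer(v_raw[t], k[t])
        o[t] = S @ q[t]
    return o, S


def kda_chunkwise(q, k, v_raw, beta, alpha, C):
    """Chunkwise form via U = E V and V_new = U - W S^T (L must be a multiple of C)."""
    L, dk = q.shape
    dv = v_raw.shape[1]
    M = np.tril(np.ones((C, C)))
    M1 = np.tril(np.ones((C, C)), k=-1)
    S = np.zeros((dv, dk))
    O = np.zeros((L, dv))
    for s in range(0, L, C):
        sl = slice(s, s + C)
        G = np.cumprod(alpha[sl], axis=0)
        Ql, Kl, Kr = q[sl] * G, k[sl] * G, k[sl] / G
        bt = beta[sl][:, None]                                  # Diag(beta) as row scaling
        E = np.linalg.inv(np.eye(C) + bt * ((Kl @ Kr.T) * M1))
        U = E @ (bt * v_raw[sl])                                # V_[t] = beta * v'
        W = E @ (bt * Kl)                                       # B_left = Diag(beta) K_left
        V_new = U - W @ S.T
        O[sl] = Ql @ S.T + ((Ql @ Kr.T) * M) @ V_new
        S = (S + V_new.T @ Kr) @ np.diag(G[-1])
    return O, S


rng = np.random.default_rng(0)
L, C, dk, dv = 12, 4, 4, 3
q, v_raw = rng.normal(size=(L, dk)), rng.normal(size=(L, dv))
k = rng.normal(size=(L, dk))
k /= np.linalg.norm(k, axis=1, keepdims=True)  # L2-normalised keys (assumption for conditioning)
beta = rng.uniform(0.0, 1.0, size=L)
alpha = rng.uniform(0.8, 1.0, size=(L, dk))
o_ref, S_ref = kda_recurrent(q, k, v_raw, beta, alpha)
o_chunk, S_chunk = kda_chunkwise(q, k, v_raw, beta, alpha, C)
```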

DPLR Without an External Gate

We can absorb the diagonal \(\text{Diag}(\boldsymbol{\alpha}_t)\) into \(\boldsymbol{b}_t\):

\[\begin{aligned} \boldsymbol{b}_t &= \text{Diag}(\boldsymbol{\alpha_t})^{-1} \boldsymbol{d}_t \\ \\ \mathbf{S}_t &= \mathbf{S}_{t-1} \text{Diag}(\boldsymbol{\alpha}_t) (\mathbf{I} - \boldsymbol{b}_t \boldsymbol{a}_t^\top) + \boldsymbol{v}_t \boldsymbol{k}_t^\top = \mathbf{S}_{t-1} (\text{Diag}(\boldsymbol{\alpha}_t) - \boldsymbol{d}_t \boldsymbol{a}_t^\top) + \boldsymbol{v}_t \boldsymbol{k}_t^\top \end{aligned}\]

Substituting \(\boldsymbol{b}_{[t]}^r = \text{Diag}(\boldsymbol{\alpha}_{[t]}^r)^{-1} \boldsymbol{d}_{[t]}^r\) into \(\overleftarrow{\mathbf{B}_{[t]}} = \mathbf{B}_{[t]} \odot \mathbf{\Gamma}_{[t]}\) gives a matrix \(\widetilde{\mathbf{D}}_{[t]}\) whose \(r\)-th row is \(\boldsymbol{\gamma}_{[t]}^{r-1} \odot \boldsymbol{d}_{[t]}^r\) (with \(\boldsymbol{\gamma}_{[t]}^{0} = \mathbf{1}\)), so no division by \(\boldsymbol{\alpha}\) is needed. We then have

\[\begin{aligned} \mathbf{F}_{[t]} &= \mathbf{I} + \left(\widetilde{\mathbf{D}}_{[t]} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) ,\quad \mathbf{E}_{[t]} = \mathbf{F}_{[t]}^{-1} \\ \\ \mathbf{W}_{[t]} &= \mathbf{E}_{[t]} \widetilde{\mathbf{D}}_{[t]} ,\quad \mathbf{C}_{[t]} = - \mathbf{E}_{[t]} \left( \widetilde{\mathbf{D}}_{[t]} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \mathbf{V}_{[t]} \\ \\ \mathbf{V}_{[t],1} &:= \mathbf{V}_{[t]} ,\quad \mathbf{V}_{[t],2} := \left(\mathbf{C}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right) \\ \\ \mathbf{S}_{[t]}^C &= \mathbf{S}_{[t-1]}^C \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \mathbf{V}_{[t],1}^\top \overrightarrow{\mathbf{K}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) + \mathbf{V}_{[t],2}^\top \overrightarrow{\mathbf{A}_{[t]}} \text{Diag}(\boldsymbol{\gamma}_{[t]}^C) \\ \\ \mathbf{O}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \mathbf{S}_{[t-1]}^{C\top} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],1} + \left( \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{A}_{[t]}}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],2} \\ \\ \mathbf{\Gamma}_{[t]} &= [ \boldsymbol{\gamma}_{[t]}^1, \boldsymbol{\gamma}_{[t]}^2, ..., \boldsymbol{\gamma}_{[t]}^C ]^\top \\ \\ \overleftarrow{\square_{[t]}} &= \square_{[t]} \odot \mathbf{\Gamma}_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \square_{[t]} \oslash \mathbf{\Gamma}_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}, \mathbf{A}\} \end{aligned}\]
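A NumPy sketch checking the gate-free chunkwise form against the recurrence \(\mathbf{S}_t = \mathbf{S}_{t-1}(\text{Diag}(\boldsymbol{\alpha}_t) - \boldsymbol{d}_t \boldsymbol{a}_t^\top) + \boldsymbol{v}_t \boldsymbol{k}_t^\top\). The rows of the left factor are computed as \(\boldsymbol{\gamma}_{[t]}^{r-1} \odot \boldsymbol{d}_{[t]}^r\), which follows from \(\overleftarrow{\mathbf{B}_{[t]}} = \mathbf{B}_{[t]} \odot \mathbf{\Gamma}_{[t]}\) with \(\boldsymbol{b}_t = \text{Diag}(\boldsymbol{\alpha}_t)^{-1}\boldsymbol{d}_t\). Function names are hypothetical, and the sequence length is assumed to be a multiple of \(C\).

```python
import numpy as np


def dplr_nogate_recurrent(q, k, v, a, d, alpha):
    """S_t = S_{t-1} (Diag(alpha_t) - d_t a_t^T) + v_t k_t^T,  o_t = S_t q_t."""
    L, dk = q.shape
    S = np.zeros((v.shape[1], dk))
    o = np.zeros((L, v.shape[1]))
    for t in range(L):
        S = S @ (np.diag(alpha[t]) - np.outer(d[t], a[t])) + np.outer(v[t], k[t])
        o[t] = S @ q[t]
    return o, S


def dplr_nogate_chunkwise(q, k, v, a, d, alpha, C):
    """Gated chunkwise form with B_left replaced by D_tilde (rows gamma^{r-1} * d^r)."""
    L, dk = q.shape
    dv = v.shape[1]
    M = np.tril(np.ones((C, C)))
    M1 = np.tril(np.ones((C, C)), k=-1)
    S = np.zeros((dv, dk))
    O = np.zeros((L, dv))
    for s in range(0, L, C):
        sl = slice(s, s + C)
        G = np.cumprod(alpha[sl], axis=0)
        G_excl = np.vstack([np.ones((1, dk)), G[:-1]])   # gamma^{r-1}, with gamma^0 = 1
        Dt = d[sl] * G_excl                                # no division by alpha needed
        Ql, Kr, Ar, V = q[sl] * G, k[sl] / G, a[sl] / G, v[sl]
        E = np.linalg.inv(np.eye(C) + (Dt @ Ar.T) * M1)
        W = E @ Dt
        V2 = -E @ ((Dt @ Kr.T) * M1) @ V - W @ S.T
        O[sl] = Ql @ S.T + ((Ql @ Kr.T) * M) @ V + ((Ql @ Ar.T) * M) @ V2
        S = (S + V.T @ Kr + V2.T @ Ar) @ np.diag(G[-1])
    return O, S


rng = np.random.default_rng(0)
L, C, dk, dv = 12, 4, 4, 3
q, k, v = rng.normal(size=(L, dk)), rng.normal(size=(L, dk)), rng.normal(size=(L, dv))
a, d = 0.3 * rng.normal(size=(L, dk)), 0.3 * rng.normal(size=(L, dk))
alpha = rng.uniform(0.8, 1.0, size=(L, dk))
o_ref, S_ref = dplr_nogate_recurrent(q, k, v, a, d, alpha)
o_chunk, S_chunk = dplr_nogate_chunkwise(q, k, v, a, d, alpha, C)
```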