Reading Notes: Backward Pass for Kimi Delta Attention
Revisiting the Forward Equations
We first recall the key forward equations:
\[\begin{aligned}
\mathbf{V}_{[t],new} &:= \left(\mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right)
\\
\\
\widetilde{\mathbf{X}}_{[t]} &= \mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)
,\quad
\widetilde{\mathbf{A}}_{[t]}
= \widetilde{\mathbf{X}}_{[t]}^{-1}
\\
\\
\mathbf{W}_{[t]} &=
\widetilde{\mathbf{A}}_{[t]}
\text{Diag}(\boldsymbol{\beta}_{[t]})
\overleftarrow{\mathbf{K}_{[t]}}
,\quad
\mathbf{U}_{[t]} =
\widetilde{\mathbf{A}}_{[t]}
\text{Diag}(\boldsymbol{\beta}_{[t]})
\mathbf{V}_{[t]}
\\
\\
\mathbf{S}_{[t]}^{C}
&=
\mathbf{S}_{[t-1]}^{C}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+ \mathbf{V}_{[t],new}^\top
\overrightarrow{\mathbf{K}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\\
\\
\mathbf{O}_{[t]}
&=
\overleftarrow{\mathbf{Q}_{[t]}}
\mathbf{S}_{[t-1]}^{C \top}
+ \left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot \mathbf{M}
\right)
\mathbf{V}_{[t],new}
\\
\\
\mathbf{\Gamma}_{[t]} &= [
\boldsymbol{\gamma}_{[t]}^1,
\boldsymbol{\gamma}_{[t]}^2,
\ldots,
\boldsymbol{\gamma}_{[t]}^C
]^\top
\\
\\
\overleftarrow{\square_{[t]}}
&=
\square_{[t]}
\odot
\mathbf{\Gamma}_{[t]}
,\quad
\overrightarrow{\square_{[t]}}
=
\square_{[t]}
\oslash
\mathbf{\Gamma}_{[t]}
\quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}\}
\end{aligned}\]
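As a sanity check on the forward equations, here is a minimal NumPy sketch of one chunk. It is a sketch under assumptions, not the reference kernel. The function names (`chunk_forward`, `token_recurrence`), the shapes, and the construction of \(\mathbf{\Gamma}_{[t]}\) as a cumulative product of hypothetical per-token decays `alpha` are not from the notes. The token-level recurrence is the one obtained by unrolling the chunk equations, and the comparison only shows that the two forms agree on random data.

```python
import numpy as np

def chunk_forward(Q, K, V, beta, Gamma, S_prev):
    """Chunked forward for one chunk, following the equations above.

    Assumed shapes: Q, K, Gamma are (C, d_k); V is (C, d_v);
    beta is (C,); S_prev is (d_v, d_k).
    """
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))           # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)    # M_{-1}: strictly lower triangular
    Q_left, K_left, K_right = Q * Gamma, K * Gamma, K / Gamma
    X = np.eye(C) + beta[:, None] * ((K_left @ K_right.T) * M1)
    A = np.linalg.inv(X)
    W = A @ (beta[:, None] * K_left)
    U = A @ (beta[:, None] * V)
    V_new = U - W @ S_prev.T
    S = (S_prev + V_new.T @ K_right) * Gamma[-1]   # columns scaled by gamma^C
    O = Q_left @ S_prev.T + ((Q_left @ K_right.T) * M) @ V_new
    return O, S

def token_recurrence(Q, K, V, beta, alpha, S_prev):
    """Token-by-token reference obtained by unrolling the chunk equations."""
    S = S_prev.copy()
    outputs = []
    for q, k, v, b, a in zip(Q, K, V, beta, alpha):
        S = S * a                    # decay: S Diag(alpha_i)
        v_new = b * (v - S @ k)      # delta-rule corrected value
        S = S + np.outer(v_new, k)   # rank-one state update
        outputs.append(S @ q)
    return np.stack(outputs), S

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
Q, K = rng.normal(size=(C, d_k)), rng.normal(size=(C, d_k))
V = rng.normal(size=(C, d_v))
beta = rng.uniform(0.2, 0.9, size=C)
alpha = rng.uniform(0.7, 1.0, size=(C, d_k))   # assumed per-token, per-channel decay
Gamma = np.cumprod(alpha, axis=0)              # row i holds gamma^i
S_prev = rng.normal(size=(d_v, d_k))

O_chunk, S_chunk = chunk_forward(Q, K, V, beta, Gamma, S_prev)
O_ref, S_ref = token_recurrence(Q, K, V, beta, alpha, S_prev)
print(np.allclose(O_chunk, O_ref), np.allclose(S_chunk, S_ref))
```

The explicit `np.linalg.inv` is used only for brevity. Since \(\widetilde{\mathbf{X}}_{[t]}\) is unit lower triangular, a triangular solve would be the more natural choice in practice.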
For \(\delta \mathbf{U}_{[t]}\)
The gradient with respect to \(\mathbf{U}_{[t]}\) is:
\[\begin{aligned}
\delta \mathbf{U}_{[t]}
=
\delta \mathbf{V}_{[t],new}
=
\overrightarrow{\mathbf{K}_{[t]}}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\delta \mathbf{S}_{[t]}^{C \top}
+
\left(
\overleftarrow{\mathbf{Q}_{[t]}}
\overrightarrow{\mathbf{K}_{[t]}}^\top
\odot \mathbf{M}
\right)^\top
\delta \mathbf{O}_{[t]}
\end{aligned}\]
For \(\delta \mathbf{W}_{[t]}\)
From the definition of \(\mathbf{V}_{[t],new}\), we further obtain:
\[\begin{aligned}
\delta \mathbf{W}_{[t]}
= - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C}
\end{aligned}\]
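For \(\delta \mathbf{S}_{[t]}^C\)
The previous state \(\mathbf{S}_{[t-1]}^{C}\) receives gradient along two paths: through \(\mathbf{V}_{[t],new}\), and directly through \(\mathbf{S}_{[t]}^{C}\) and \(\mathbf{O}_{[t]}\):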
\[\begin{aligned}
\left.\delta \mathbf{S}_{[t-1]}^{C} \right|_{\text{from } \mathbf{V}_{[t],new}}
&=
-\delta \mathbf{U}_{[t]}^\top
\mathbf{W}_{[t]}
\\
\\
\left.\delta \mathbf{S}_{[t-1]}^{C} \right|_{\text{from } \mathbf{S}_{[t]}^{C} \text{ and } \mathbf{O}_{[t]} \text{, w/o } \mathbf{V}_{[t],new}}
&=
\delta \mathbf{S}_{[t]}^{C}
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+ \delta \mathbf{O}_{[t]}^\top
\overleftarrow{\mathbf{Q}_{[t]}}
\end{aligned}\]
Combining the two terms and shifting the chunk index from \(t-1\) to \(t\), we obtain:
\[\begin{aligned}
\delta \mathbf{S}_{[t]}^{C}
&=
\delta \mathbf{S}_{[t+1]}^{C}
\text{Diag}(\boldsymbol{\gamma}_{[t+1]}^C)
+ \delta \mathbf{O}_{[t+1]}^\top
\overleftarrow{\mathbf{Q}_{[t+1]}}
-\delta \mathbf{U}_{[t+1]}^\top
\mathbf{W}_{[t+1]}
\end{aligned}\]
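The three formulas for \(\delta \mathbf{U}_{[t]}\), \(\delta \mathbf{W}_{[t]}\) and the state gradient can be checked numerically for a single chunk. The sketch below treats \(\mathbf{U}\), \(\mathbf{W}\) and \(\mathbf{S}_{[t-1]}^{C}\) as independent inputs and compares the closed forms against central finite differences. The scalar loss (an inner product with random stand-in gradients `dO`, `dS`), the shapes, and the helper `num_grad` are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
Q_left = rng.normal(size=(C, d_k))     # decayed queries
K_right = rng.normal(size=(C, d_k))    # inverse-decayed keys
U = rng.normal(size=(C, d_v))
W = rng.normal(size=(C, d_k))
S_prev = rng.normal(size=(d_v, d_k))
gamma_C = rng.uniform(0.5, 1.0, size=d_k)
dO = rng.normal(size=(C, d_v))         # stand-in upstream gradient of O
dS = rng.normal(size=(d_v, d_k))       # stand-in upstream gradient of S_[t]
M = np.tril(np.ones((C, C)))

def loss(U, W, S_prev):
    V_new = U - W @ S_prev.T
    S = (S_prev + V_new.T @ K_right) * gamma_C
    O = Q_left @ S_prev.T + ((Q_left @ K_right.T) * M) @ V_new
    return np.sum(O * dO) + np.sum(S * dS)

def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar function f at array x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

# Closed forms from the derivation above.
dU = (K_right * gamma_C) @ dS.T + ((Q_left @ K_right.T) * M).T @ dO
dW = -dU @ S_prev
dS_prev = dS * gamma_C + dO.T @ Q_left - dU.T @ W

err_U = np.abs(dU - num_grad(lambda x: loss(x, W, S_prev), U)).max()
err_W = np.abs(dW - num_grad(lambda x: loss(U, x, S_prev), W)).max()
err_S = np.abs(dS_prev - num_grad(lambda x: loss(U, W, x), S_prev)).max()
print(err_U, err_W, err_S)
```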
For \(\delta \mathbf{Q}_{[t]}\)
\[\begin{aligned}
\delta \overleftarrow{\mathbf{Q}_{[t]}}
&=
\delta \mathbf{O}_{[t]}
\mathbf{S}_{[t-1]}^C
+
\left(\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],new}^\top \odot \mathbf{M}\right) \overrightarrow{\mathbf{K}_{[t]}}
\\
\\
\delta \mathbf{Q}_{[t]}
&=
\delta \overleftarrow{\mathbf{Q}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
\end{aligned}\]
For \(\delta \mathbf{V}_{[t]}\)
\[\begin{aligned}
\delta \mathbf{V}_{[t]}
=
\text{Diag}(\boldsymbol{\beta}_{[t]})
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{U}_{[t]}
\end{aligned}\]
For \(\delta \widetilde{\mathbf{A}}_{[t]}\) and \(\delta \widetilde{\mathbf{X}}_{[t]}\)
First, since \(\widetilde{\mathbf{A}}_{[t]}\) enters both \(\mathbf{W}_{[t]}\) and \(\mathbf{U}_{[t]}\), we have:
\[\begin{aligned}
\delta \widetilde{\mathbf{A}}_{[t]}
&=
\delta \mathbf{W}_{[t]}
\overleftarrow{\mathbf{K}_{[t]}}^\top
\text{Diag}(\boldsymbol{\beta}_{[t]})
+
\delta \mathbf{U}_{[t]}
\mathbf{V}_{[t]}^\top
\text{Diag}(\boldsymbol{\beta}_{[t]})
\end{aligned}\]
Then, using the differential formula for the matrix inverse, we obtain:
\[\begin{aligned}
\delta \widetilde{\mathbf{X}}_{[t]}
&=
- \widetilde{\mathbf{X}}_{[t]}^{-\top}
\delta (\widetilde{\mathbf{X}}_{[t]}^{-1})
\widetilde{\mathbf{X}}_{[t]}^{-\top}
=
- \widetilde{\mathbf{A}}_{[t]}^\top
\delta \widetilde{\mathbf{A}}_{[t]}
\widetilde{\mathbf{A}}_{[t]}^\top
\end{aligned}\]
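The inverse rule \(\delta \widetilde{\mathbf{X}} = -\widetilde{\mathbf{A}}^\top \, \delta \widetilde{\mathbf{A}} \, \widetilde{\mathbf{A}}^\top\) is easy to confirm in isolation. The sketch below builds a random unit lower-triangular matrix as a stand-in for \(\widetilde{\mathbf{X}}_{[t]}\), uses a random matrix `G` as a stand-in for \(\delta \widetilde{\mathbf{A}}_{[t]}\), and compares the formula with finite differences of \(\langle \mathbf{G}, \mathbf{X}^{-1} \rangle\). All names here are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 4
# Unit lower-triangular stand-in for X = I + Diag(beta) (K_left K_right^T ⊙ M_{-1}).
X = np.eye(C) + np.tril(rng.normal(size=(C, C)), k=-1)
G = rng.normal(size=(C, C))            # stand-in for the upstream gradient delta A

def loss(X):
    return np.sum(np.linalg.inv(X) * G)

def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar function f at array x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

A = np.linalg.inv(X)
dX = -A.T @ G @ A.T                    # closed form from the derivation
print(np.abs(dX - num_grad(loss, X)).max())
```

The formula gives a gradient for every entry of \(\widetilde{\mathbf{X}}\). In the derivation only the strictly lower-triangular part is used downstream, because of the mask \(\mathbf{M}_{-1}\).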
For \(\delta \mathbf{K}_{[t]}\)
The keys enter the forward pass in two forms, \(\overleftarrow{\mathbf{K}_{[t]}}\) and \(\overrightarrow{\mathbf{K}_{[t]}}\). We derive the gradient of each in turn.
For \(\overleftarrow{\mathbf{K}_{[t]}}\), we have:
\[\begin{aligned}
\left.\delta \overleftarrow{\mathbf{K}_{[t]}}\right|_{\text{from } \mathbf{W}_{[t]} \text{ w/o } \widetilde{\mathbf{A}}_{[t]} }
&=
\text{Diag}(\boldsymbol{\beta}_{[t]})
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{W}_{[t]}
\\
\\
\left.\delta \overleftarrow{\mathbf{K}_{[t]}}\right|_{\text{from } \widetilde{\mathbf{A}}_{[t]} }
&=
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)
\overrightarrow{\mathbf{K}_{[t]}}
\end{aligned}\]
Summing the two contributions gives:
\[\begin{aligned}
\delta \overleftarrow{\mathbf{K}_{[t]}}
&=
\text{Diag}(\boldsymbol{\beta}_{[t]})
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{W}_{[t]}
+
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)
\overrightarrow{\mathbf{K}_{[t]}}
\end{aligned}\]
For \(\overrightarrow{\mathbf{K}_{[t]}}\), we have:
\[\begin{aligned}
\left.\delta \overrightarrow{\mathbf{K}_{[t]}}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \widetilde{\mathbf{A}}_{[t]} }
&=
\mathbf{V}_{[t],new}
\delta \mathbf{S}_{[t]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\\
\\
\left.\delta \overrightarrow{\mathbf{K}_{[t]}}\right|_{\text{from } \mathbf{O}_{[t]} \text{ w/o } \widetilde{\mathbf{A}}_{[t]} }
&=
\left(
\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],new}^\top
\odot \mathbf{M}
\right)^\top
\overleftarrow{\mathbf{Q}_{[t]}}
\\
\\
\left.\delta \overrightarrow{\mathbf{K}_{[t]}}\right|_{\text{from } \widetilde{\mathbf{A}}_{[t]} }
&=
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)^\top
\overleftarrow{\mathbf{K}_{[t]}}
\end{aligned}\]
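Summing the three contributions gives:
\[\begin{aligned}
\delta \overrightarrow{\mathbf{K}_{[t]}}
&=
\mathbf{V}_{[t],new}
\delta \mathbf{S}_{[t]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
+
\left(
\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],new}^\top
\odot \mathbf{M}
\right)^\top
\overleftarrow{\mathbf{Q}_{[t]}}
\\&+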
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)^\top
\overleftarrow{\mathbf{K}_{[t]}}
\end{aligned}\]
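Finally, we merge the two paths, using \(\overleftarrow{\mathbf{K}_{[t]}} = \mathbf{K}_{[t]} \odot \mathbf{\Gamma}_{[t]}\) and \(\overrightarrow{\mathbf{K}_{[t]}} = \mathbf{K}_{[t]} \oslash \mathbf{\Gamma}_{[t]}\):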
\[\begin{aligned}
\delta \mathbf{K}_{[t]}
&=
\delta \overleftarrow{\mathbf{K}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
+
\delta \overrightarrow{\mathbf{K}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\\&=
\text{Diag}(\boldsymbol{\beta}_{[t]})
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{W}_{[t]}
\odot \mathbf{\Gamma}_{[t]}
+
\mathbf{V}_{[t],new}
\delta \mathbf{S}_{[t]}^C
\text{Diag}(\boldsymbol{\gamma}_{[t]}^C)
\oslash \mathbf{\Gamma}_{[t]}
\\&+
\left(
\delta \mathbf{O}_{[t]}
\mathbf{V}_{[t],new}^\top
\odot \mathbf{M}
\right)^\top
\overleftarrow{\mathbf{Q}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\\&+
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)
\overrightarrow{\mathbf{K}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
+
\left(
\text{Diag}(\boldsymbol{\beta}_{[t]})
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)^\top
\overleftarrow{\mathbf{K}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\end{aligned}\]
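The merged expression for \(\delta \mathbf{K}_{[t]}\) has enough terms that a numerical check is worthwhile. The sketch below runs the full single-chunk forward as a function of \(\mathbf{K}\) only, assembles \(\delta \mathbf{K}\) from the pieces derived above (the \(\mathbf{W}\), \(\mathbf{S}\), \(\mathbf{O}\) and \(\widetilde{\mathbf{X}}\) paths), and compares against finite differences. The random data, the stand-in upstream gradients `dO` and `dS`, and the helper names are assumptions. The incoming \(\delta \mathbf{S}_{[t]}^{C}\) is taken as given rather than propagated from a later chunk.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
Q, K = rng.normal(size=(C, d_k)), rng.normal(size=(C, d_k))
V = rng.normal(size=(C, d_v))
beta = rng.uniform(0.2, 0.9, size=C)
Gamma = np.cumprod(rng.uniform(0.7, 1.0, size=(C, d_k)), axis=0)  # assumed decay
S_prev = rng.normal(size=(d_v, d_k))
dO = rng.normal(size=(C, d_v))       # stand-in upstream gradient of O
dS = rng.normal(size=(d_v, d_k))     # stand-in upstream gradient of S_[t]
M = np.tril(np.ones((C, C)))
M1 = np.tril(np.ones((C, C)), k=-1)
Q_left = Q * Gamma

def forward(K):
    K_left, K_right = K * Gamma, K / Gamma
    X = np.eye(C) + beta[:, None] * ((K_left @ K_right.T) * M1)
    A = np.linalg.inv(X)
    W = A @ (beta[:, None] * K_left)
    U = A @ (beta[:, None] * V)
    V_new = U - W @ S_prev.T
    S = (S_prev + V_new.T @ K_right) * Gamma[-1]
    O = Q_left @ S_prev.T + ((Q_left @ K_right.T) * M) @ V_new
    return K_left, K_right, A, V_new, S, O

def loss(K):
    *_, S, O = forward(K)
    return np.sum(O * dO) + np.sum(S * dS)

def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar function f at array x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

K_left, K_right, A, V_new, S, O = forward(K)
dU = (K_right * Gamma[-1]) @ dS.T + ((Q_left @ K_right.T) * M).T @ dO
dW = -dU @ S_prev
dA = (dW @ K_left.T + dU @ V.T) * beta          # right-multiply by Diag(beta)
dX = -A.T @ dA @ A.T
dP = (beta[:, None] * dX) * M1                  # gradient of K_left K_right^T via X
dK_left = beta[:, None] * (A.T @ dW) + dP @ K_right
dK_right = ((V_new @ dS) * Gamma[-1]
            + ((dO @ V_new.T) * M).T @ Q_left
            + dP.T @ K_left)
dK = dK_left * Gamma + dK_right / Gamma
print(np.abs(dK - num_grad(loss, K)).max())
```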
For \(\delta \boldsymbol{\beta}_{[t]}\)
Next, the gradient with respect to \(\boldsymbol{\beta}_{[t]}\) can be written as:
\[\begin{aligned}
\delta \text{Diag}(\boldsymbol{\beta}_{[t]})
&=
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{W}_{[t]}
\overleftarrow{\mathbf{K}_{[t]}}^\top
+
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \mathbf{U}_{[t]}
\mathbf{V}_{[t]}^\top
+
\delta \widetilde{\mathbf{X}}_{[t]}
\left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)^\top
\\&=
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \widetilde{\mathbf{A}}_{[t]}
\text{Diag}(\boldsymbol{\beta}_{[t]})^{-1}
+
\delta \widetilde{\mathbf{X}}_{[t]}
\left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right)^\top
\\
\\
\delta \boldsymbol{\beta}_{[t]}
&=
\text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]}))
\\ &=
\text{diag}\left(
\widetilde{\mathbf{A}}_{[t]}^\top
\delta \widetilde{\mathbf{A}}_{[t]}
\right)
\odot
\boldsymbol{\beta}_{[t]}^{-1}
+
\text{diag}\left(
\left(
\delta \widetilde{\mathbf{X}}_{[t]}
\odot \mathbf{M}_{-1}
\right)
\left(
\overrightarrow{\mathbf{K}_{[t]}} \overleftarrow{\mathbf{K}_{[t]}}^\top
\right)
\right)
\end{aligned}\]
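The final expression for \(\delta \boldsymbol{\beta}_{[t]}\) can be checked on the sub-graph \(\boldsymbol{\beta} \mapsto (\mathbf{W}, \mathbf{U})\) alone. In the sketch below, `dW` and `dU` are random stand-ins for the upstream gradients, \(\overleftarrow{\mathbf{K}}\) and \(\overrightarrow{\mathbf{K}}\) are drawn independently, and the closed form is compared with finite differences. These choices and the helper names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
K_left = rng.normal(size=(C, d_k))
K_right = rng.normal(size=(C, d_k))
V = rng.normal(size=(C, d_v))
beta = rng.uniform(0.2, 0.9, size=C)
dW = rng.normal(size=(C, d_k))       # stand-in upstream gradient of W
dU = rng.normal(size=(C, d_v))       # stand-in upstream gradient of U
M1 = np.tril(np.ones((C, C)), k=-1)  # M_{-1}

def inverse_A(beta):
    X = np.eye(C) + beta[:, None] * ((K_left @ K_right.T) * M1)
    return np.linalg.inv(X)

def loss(beta):
    A = inverse_A(beta)
    W = A @ (beta[:, None] * K_left)
    U = A @ (beta[:, None] * V)
    return np.sum(W * dW) + np.sum(U * dU)

def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar function f at array x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

A = inverse_A(beta)
dA = (dW @ K_left.T + dU @ V.T) * beta   # right-multiply by Diag(beta)
dX = -A.T @ dA @ A.T
dbeta = (np.diag(A.T @ dA) / beta
         + np.diag((dX * M1) @ (K_right @ K_left.T)))
print(np.abs(dbeta - num_grad(loss, beta)).max())
```

Note that the first term divides by \(\boldsymbol{\beta}_{[t]}\), so this form assumes every entry of \(\boldsymbol{\beta}_{[t]}\) is nonzero. The unsimplified first line of the derivation avoids the division.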
For \(\delta \mathbf{\Gamma}_{[t]}\)
\[\begin{aligned}
\delta \boldsymbol{\gamma}_{[t]}^C
&=
\text{diag}\left(
\left(
\mathbf{S}_{[t-1]}^C
+
\mathbf{V}_{[t],new}^\top
\overrightarrow{\mathbf{K}_{[t]}}
\right)^\top
\delta \mathbf{S}_{[t]}^C
\right)
=
\left(\boldsymbol{\gamma}_{[t]}^{C}\right)^{-1}
\odot
\text{diag}\left(
\mathbf{S}_{[t]}^{C\top}
\delta \mathbf{S}_{[t]}^C
\right)
\end{aligned}\]
\[\begin{aligned}
\left.\delta \mathbf{\Gamma}_{[t]}\right|_{\text{w/o extra } \boldsymbol{\gamma}_{[t]}^C}
&=
\delta \overleftarrow{\mathbf{Q}_{[t]}}
\odot \mathbf{Q}_{[t]}
+
\delta \overleftarrow{\mathbf{K}_{[t]}}
\odot \mathbf{K}_{[t]}
-
\left(
\delta \overrightarrow{\mathbf{K}_{[t]}}
\odot \mathbf{K}_{[t]}
\right)
\oslash
\left(
\mathbf{\Gamma}_{[t]}
\odot \mathbf{\Gamma}_{[t]}
\right)
\\&=
\delta \overleftarrow{\mathbf{Q}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
\odot \mathbf{Q}_{[t]}
\oslash \mathbf{\Gamma}_{[t]}
+
\left(
\delta \overleftarrow{\mathbf{K}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
\odot \mathbf{K}_{[t]}
-
\delta \overrightarrow{\mathbf{K}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\odot \mathbf{K}_{[t]}
\right)
\oslash
\mathbf{\Gamma}_{[t]}
\\&=
\left(
\delta \mathbf{Q}_{[t]}
\odot \mathbf{Q}_{[t]}
+
\left(
\delta \overleftarrow{\mathbf{K}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
\right)
\odot \mathbf{K}_{[t]}
-
\left(
\delta \overrightarrow{\mathbf{K}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\right)
\odot \mathbf{K}_{[t]}
\right)
\oslash
\mathbf{\Gamma}_{[t]}
\end{aligned}\]
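In particular, when the decay is parameterized in log space, the gradient with respect to \(\log \mathbf{\Gamma}_{[t]}\) is: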
\[\begin{aligned}
\delta \log \mathbf{\Gamma}_{[t]}
&=
\delta \mathbf{\Gamma}_{[t]}
\odot
\mathbf{\Gamma}_{[t]}
\\&=
\delta \mathbf{Q}_{[t]}
\odot \mathbf{Q}_{[t]}
+
\left(
\left(
\delta \overleftarrow{\mathbf{K}_{[t]}}
\odot \mathbf{\Gamma}_{[t]}
\right)
\odot \mathbf{K}_{[t]}
-
\left(
\delta \overrightarrow{\mathbf{K}_{[t]}}
\oslash \mathbf{\Gamma}_{[t]}
\right)
\odot \mathbf{K}_{[t]}
\right)
\\&+
[\mathbf{0}, \mathbf{0}, \ldots, \text{diag}\left(
\mathbf{S}_{[t]}^{C\top}
\delta \mathbf{S}_{[t]}^C
\right)]^\top
\end{aligned}\]
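The elementwise part of the \(\log \mathbf{\Gamma}_{[t]}\) gradient (everything except the extra last-row term from \(\boldsymbol{\gamma}_{[t]}^C\)) can be checked with a small sketch. It treats \(\delta \overleftarrow{\mathbf{Q}}\), \(\delta \overleftarrow{\mathbf{K}}\) and \(\delta \overrightarrow{\mathbf{K}}\) as fixed random stand-ins and differentiates only through the scalings \(\odot \mathbf{\Gamma}\) and \(\oslash \mathbf{\Gamma}\). The variable names and the way `log_Gamma` is generated are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_k = 4, 3
Q, K = rng.normal(size=(C, d_k)), rng.normal(size=(C, d_k))
# Assumed parameterization: cumulative sum of per-token log-decays.
log_Gamma = np.cumsum(np.log(rng.uniform(0.7, 1.0, size=(C, d_k))), axis=0)
dQ_left = rng.normal(size=(C, d_k))    # stand-in for delta Q_left
dK_left = rng.normal(size=(C, d_k))    # stand-in for delta K_left
dK_right = rng.normal(size=(C, d_k))   # stand-in for delta K_right

def loss(log_Gamma):
    Gamma = np.exp(log_Gamma)
    return (np.sum((Q * Gamma) * dQ_left)
            + np.sum((K * Gamma) * dK_left)
            + np.sum((K / Gamma) * dK_right))

def num_grad(f, x, eps=1e-6):
    """Central finite differences of the scalar function f at array x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

Gamma = np.exp(log_Gamma)
dQ = dQ_left * Gamma
d_log_Gamma = dQ * Q + (dK_left * Gamma) * K - (dK_right / Gamma) * K
print(np.abs(d_log_Gamma - num_grad(loss, log_Gamma)).max())
```

One reading of why the log-space form is convenient: the two key terms are exactly the two summands of \(\delta \mathbf{K}_{[t]}\), so they can be reused without another division by \(\mathbf{\Gamma}_{[t]}\).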