
Reading Notes: Backward Pass for DeltaNet

Mathematical Preliminaries

Besides the material already covered in the GLA notes, we also need the differential of the matrix inverse and the gradient rule that follows from it:

\[\begin{aligned} & \mathbf{A} = \mathbf{B}^{-1} \Rightarrow d \mathbf{A} = - \mathbf{B}^{-1} (d \mathbf{B}) \mathbf{B}^{-1} \\ \\ \Rightarrow & dy = \text{Tr}\left((\delta \mathbf{A})^\top (d \mathbf{A})\right) = \text{Tr}\left(-(\delta \mathbf{A})^\top \mathbf{B}^{-1} (d \mathbf{B}) \mathbf{B}^{-1}\right) = \text{Tr}\left(- \mathbf{B}^{-1} (\delta \mathbf{A})^\top \mathbf{B}^{-1} (d \mathbf{B}) \right) \\ \\ \Rightarrow & \delta \mathbf{B} = - \mathbf{B}^{-1 \top} (\delta \mathbf{A}) \mathbf{B}^{-1 \top} = - \mathbf{A}^\top (\delta \mathbf{A}) \mathbf{A}^\top \end{aligned}\]
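The gradient rule above is easy to check numerically. The sketch below uses NumPy, a hypothetical helper `num_grad` (central finite differences) and an illustrative test loss \(y = \sum (\mathbf{G} \odot \mathbf{B}^{-1})\), chosen so that \(\delta \mathbf{A} = \mathbf{G}\). It then compares \(-\mathbf{A}^\top (\delta \mathbf{A}) \mathbf{A}^\top\) with the numerical gradient.

```python
import numpy as np

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
n = 4
B = np.eye(n) + 0.1 * rng.standard_normal((n, n))  # well-conditioned test matrix
G = rng.standard_normal((n, n))                    # plays the role of delta A

def loss(B_):
    # y = sum(G * B^{-1}), so the upstream gradient w.r.t. A = B^{-1} is G.
    return np.sum(G * np.linalg.inv(B_))

A = np.linalg.inv(B)
dB_analytic = -A.T @ G @ A.T                       # delta B = -A^T (delta A) A^T
dB_numeric = num_grad(loss, B)
err = np.max(np.abs(dB_analytic - dB_numeric))
print(err)
```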

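We will also use the following identity, valid whenever \(\mathbf{A}\), \(\mathbf{B}\) and \(\mathbf{C}\) have the same shape (entry \(i\) of either side equals \(\sum_j A_{ij} B_{ij} C_{ij}\)):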

\[\begin{aligned} \text{diag}\left(\mathbf{A} \left(\mathbf{B} \odot \mathbf{C}\right)^\top \right) = \text{diag}\left(\left(\mathbf{A} \odot \mathbf{C}\right) \mathbf{B}^\top \right) \end{aligned}\]
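A quick numerical check of this identity with random matrices of a common (arbitrarily chosen) shape, using NumPy:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((5, 3))
B = rng.standard_normal((5, 3))
C = rng.standard_normal((5, 3))

lhs = np.diag(A @ (B * C).T)      # diag(A (B ⊙ C)^T)
rhs = np.diag((A * C) @ B.T)      # diag((A ⊙ C) B^T)
# Both sides reduce to the row-wise sums of A ⊙ B ⊙ C.
row_sums = np.sum(A * B * C, axis=1)
print(np.allclose(lhs, rhs), np.allclose(lhs, row_sums))
```

The row-sum form also suggests a cheaper evaluation: \(\text{diag}(\mathbf{P}\mathbf{R}^\top)\) is the row-wise sum of \(\mathbf{P} \odot \mathbf{R}\), so the full product never needs to be formed.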

Revisiting the Forward Equations

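We first recall the key forward equations for chunk \(t\), where \(\mathbf{M}\) is the lower-triangular causal mask (diagonal included) and \(\mathbf{M}_{-1}\) is its strictly lower-triangular part: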

\[\begin{aligned} \mathbf{V}_{[t],new} &:= \left(\mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C \top} \right) \\ \\ \mathbf{X}_{[t]} &= \mathbf{I} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right) \\ \\ \mathbf{A}_{[t]} &= \mathbf{X}_{[t]}^{-1} ,\quad \mathbf{T}_{[t]} = \mathbf{A}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]}) \\ \\ \mathbf{W}_{[t]} &= \mathbf{T}_{[t]} \mathbf{K}_{[t]} ,\quad \mathbf{U}_{[t]} = \mathbf{T}_{[t]} \mathbf{V}_{[t]} \\ \\ \mathbf{S}_{[t]}^{C} &= \mathbf{S}_{[t-1]}^{C} + \mathbf{V}_{[t],new}^\top\mathbf{K}_{[t]} \\ \\ \mathbf{O}_{[t]} &= \mathbf{Q}_{[t]} \mathbf{S}_{[t-1]}^{C \top} + \left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M} \right) \mathbf{V}_{[t],new} \end{aligned}\]
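To make these equations concrete, here is a minimal NumPy sketch of one chunk. It is checked against a token-by-token delta rule \(\mathbf{S} \leftarrow \mathbf{S}(\mathbf{I} - \beta \mathbf{k}\mathbf{k}^\top) + \beta \mathbf{v}\mathbf{k}^\top\), \(\mathbf{o} = \mathbf{S}\mathbf{q}\), which the chunk equations should reproduce up to floating-point error. The function names, the layout (\(\mathbf{S}\) is \(d_v \times d_k\), rows of \(\mathbf{Q}, \mathbf{K}, \mathbf{V}\) are tokens), the zero initial state and the L2-normalised keys are illustrative assumptions. The explicit inverse is used for readability, not efficiency.

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; S is (d_v, d_k), Q/K are (C, d_k), V is (C, d_v)."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))           # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)    # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    T = np.linalg.inv(X) @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O

def sequential_forward(S, Q, K, V, beta):
    """Token-by-token delta rule: S <- S (I - b k k^T) + b v k^T, then o = S q."""
    outs = []
    for q, k, v, b in zip(Q, K, V, beta):
        S = S - b * np.outer(S @ k - v, k)
        outs.append(S @ q)
    return S, np.stack(outs)

rng = np.random.default_rng(0)
L, C, d_k, d_v = 8, 4, 3, 2               # sequence length, chunk size, head dims
Q = rng.standard_normal((L, d_k))
K = rng.standard_normal((L, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((L, d_v))
beta = rng.uniform(0.1, 0.9, size=L)

S_chunk = np.zeros((d_v, d_k))
O_parts = []
for s in range(0, L, C):
    S_chunk, O_t = chunk_forward(S_chunk, Q[s:s + C], K[s:s + C], V[s:s + C], beta[s:s + C])
    O_parts.append(O_t)
O_chunk = np.concatenate(O_parts)

S_seq, O_seq = sequential_forward(np.zeros((d_v, d_k)), Q, K, V, beta)
print(np.allclose(S_chunk, S_seq), np.allclose(O_chunk, O_seq))
```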

For \(\delta \mathbf{S}_{[t]}^C\)

We begin with the gradient of the chunk state.

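Since \(\mathbf{S}_{[t]}^{C}\) depends on \(\mathbf{S}_{[t-1]}^{C}\) both directly and through \(\mathbf{V}_{[t],new}\), the contribution from \(\mathbf{S}_{[t]}^{C}\) is: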

\[\begin{aligned} \left.\delta \mathbf{S}_{[t-1]}^{C} \right|_{\text{from } \mathbf{S}_{[t]}^{C}} &= \delta \mathbf{S}_{[t]}^{C} - \delta \mathbf{S}_{[t]}^{C} \mathbf{K}_{[t]}^\top \mathbf{W}_{[t]} \end{aligned}\]

Meanwhile, the contribution from \(\mathbf{O}_{[t]}\) is:

\[\begin{aligned} \left.\delta \mathbf{S}_{[t-1]}^{C} \right|_{\text{from } \mathbf{O}_{[t]}} &= \delta \mathbf{O}_{[t]}^\top \mathbf{Q}_{[t]} - \delta \mathbf{O}_{[t]}^\top \left( \mathbf{Q}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M} \right) \mathbf{W}_{[t]} \end{aligned}\]

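Adding the two contributions and shifting the chunk index from \(t\) to \(t+1\), we obtain the backward recurrence for the chunk state. The second line substitutes the expression for \(\delta \mathbf{U}_{[t]}\) (taken at chunk \(t+1\)) that is derived in a later subsection: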

\[\begin{aligned} \delta \mathbf{S}_{[t]}^{C} &= \delta \mathbf{S}_{[t+1]}^{C} - \delta \mathbf{S}_{[t+1]}^{C} \mathbf{K}_{[t+1]}^\top \mathbf{W}_{[t+1]} + \delta \mathbf{O}_{[t+1]}^\top \mathbf{Q}_{[t+1]} - \delta \mathbf{O}_{[t+1]}^\top \left( \mathbf{Q}_{[t+1]} \mathbf{K}_{[t+1]}^\top \odot \mathbf{M} \right) \mathbf{W}_{[t+1]} \\ \\ &= \delta \mathbf{S}_{[t+1]}^{C} + \delta \mathbf{O}_{[t+1]}^\top \mathbf{Q}_{[t+1]} - \delta \mathbf{U}_{[t+1]}^\top \mathbf{W}_{[t+1]} \end{aligned}\]
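The recurrence can be checked on a single chunk. The sketch below is NumPy; `chunk_forward`, `num_grad` and the surrogate loss \(\sum(\delta\mathbf{S}_{[t]}^C \odot \mathbf{S}_{[t]}^C) + \sum(\delta\mathbf{O}_{[t]} \odot \mathbf{O}_{[t]})\) with random upstream gradients are illustrative assumptions. It compares \(\delta \mathbf{S}_{[t]}^C + \delta \mathbf{O}_{[t]}^\top \mathbf{Q}_{[t]} - \delta \mathbf{U}_{[t]}^\top \mathbf{W}_{[t]}\) (the gradient passed back to \(\mathbf{S}_{[t-1]}^C\)) with finite differences.

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; returns S_t, O_t and cached intermediates."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    A = np.linalg.inv(X)
    T = A @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O, dict(M=M, M1=M1, A=A, T=T, W=W, V_new=V_new)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

# Random single-chunk test data (C tokens, key dim d_k, value dim d_v).
rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, size=C)
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(S_prev, Q, K, V, beta):
    """Surrogate loss whose gradients w.r.t. S_t and O_t are exactly dS and dO."""
    S, O, _ = chunk_forward(S_prev, Q, K, V, beta)
    return np.sum(dS * S) + np.sum(dO * O)

c = chunk_forward(S_prev, Q, K, V, beta)[2]
# Gradient reaching U_t (it only enters through V_new).
dU = K @ dS.T + ((K @ Q.T) * c["M"].T) @ dO
# Backward recurrence: gradient passed to the previous chunk state.
dS_prev = dS + dO.T @ Q - dU.T @ c["W"]

err = np.max(np.abs(dS_prev - num_grad(lambda x: loss(x, Q, K, V, beta), S_prev)))
print(err)
```

Over a full sequence, a backward pass would apply this update from the last chunk to the first, feeding `dS_prev` of chunk \(t\) in as `dS` for chunk \(t-1\).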

For \(\delta \mathbf{Q}_{[t]}\)

Next, the gradient with respect to \(\mathbf{Q}_{[t]}\) is:

\[\begin{aligned} \delta \mathbf{Q}_{[t]} &= \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^C + \left(\delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M}\right) \mathbf{K}_{[t]} \end{aligned}\]
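A finite-difference check of \(\delta \mathbf{Q}_{[t]}\) under the same illustrative single-chunk setup (NumPy; `chunk_forward`, `num_grad` and the surrogate loss are hypothetical helpers, repeated so the snippet runs on its own):

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; returns S_t, O_t and cached intermediates."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    T = np.linalg.inv(X) @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O, dict(M=M, V_new=V_new)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, size=C)
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(S_prev, Q, K, V, beta):
    S, O, _ = chunk_forward(S_prev, Q, K, V, beta)
    return np.sum(dS * S) + np.sum(dO * O)

c = chunk_forward(S_prev, Q, K, V, beta)[2]
# delta Q = delta O S_{t-1} + (delta O V_new^T ⊙ M) K
dQ = dO @ S_prev + ((dO @ c["V_new"].T) * c["M"]) @ K

err = np.max(np.abs(dQ - num_grad(lambda x: loss(S_prev, x, K, V, beta), Q)))
print(err)
```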

For \(\delta \mathbf{U}_{[t]}\)

Similarly, the gradient with respect to \(\mathbf{U}_{[t]}\) is:

\[\begin{aligned} \delta \mathbf{U}_{[t]} = \mathbf{K}_{[t]} \delta \mathbf{S}_{[t]}^{C \top} + \left( \mathbf{K}_{[t]} \mathbf{Q}_{[t]}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]} \end{aligned}\]

For \(\delta \mathbf{W}_{[t]}\)

From the definition of \(\mathbf{V}_{[t],new}\), we further obtain:

\[\begin{aligned} \delta \mathbf{W}_{[t]} = - \mathbf{K}_{[t]} \delta \mathbf{S}_{[t]}^{C \top} \mathbf{S}_{[t-1]}^{C} - \left( \mathbf{K}_{[t]} \mathbf{Q}_{[t]}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C} = - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C} \end{aligned}\]
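Both \(\delta \mathbf{U}_{[t]}\) and \(\delta \mathbf{W}_{[t]}\) can be checked by treating \(\mathbf{U}_{[t]}\) and \(\mathbf{W}_{[t]}\) as free inputs, which isolates the paths through \(\mathbf{V}_{[t],new}\). The sketch below (NumPy; `forward_from_UW`, `num_grad` and the surrogate loss are illustrative helpers, and the test data are random) does exactly that:

```python
import numpy as np

def forward_from_UW(S_prev, Q, K, U, W):
    """Chunk outputs with U_t and W_t treated as free inputs (the T_t path is not modelled)."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
U = rng.standard_normal((C, d_v))
W = rng.standard_normal((C, d_k))
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(U, W):
    S, O = forward_from_UW(S_prev, Q, K, U, W)
    return np.sum(dS * S) + np.sum(dO * O)

M = np.tril(np.ones((C, C)))
dU = K @ dS.T + ((K @ Q.T) * M.T) @ dO      # delta U
dW = -dU @ S_prev                           # delta W = -delta U S_{t-1}

err_U = np.max(np.abs(dU - num_grad(lambda x: loss(x, W), U)))
err_W = np.max(np.abs(dW - num_grad(lambda x: loss(U, x), W)))
print(err_U, err_W)
```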

For \(\delta \mathbf{V}_{[t]}\)

Since \(\mathbf{U}_{[t]} = \mathbf{T}_{[t]} \mathbf{V}_{[t]}\), it follows that:

\[\begin{aligned} \delta \mathbf{V}_{[t]} = \mathbf{T}_{[t]}^\top \delta \mathbf{U}_{[t]} \end{aligned}\]
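A finite-difference check of \(\delta \mathbf{V}_{[t]} = \mathbf{T}_{[t]}^\top \delta \mathbf{U}_{[t]}\), again using hypothetical NumPy helpers (`chunk_forward`, `num_grad`, a surrogate loss with random upstream gradients) repeated so the snippet is self-contained:

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; returns S_t, O_t and cached intermediates."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    T = np.linalg.inv(X) @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O, dict(M=M, T=T)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, size=C)
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(S_prev, Q, K, V, beta):
    S, O, _ = chunk_forward(S_prev, Q, K, V, beta)
    return np.sum(dS * S) + np.sum(dO * O)

c = chunk_forward(S_prev, Q, K, V, beta)[2]
dU = K @ dS.T + ((K @ Q.T) * c["M"].T) @ dO
dV = c["T"].T @ dU                          # V only enters through U = T V

err = np.max(np.abs(dV - num_grad(lambda x: loss(S_prev, Q, K, x, beta), V)))
print(err)
```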

For \(\delta \mathbf{K}_{[t]}\)

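This is the most involved part, since \(\mathbf{K}_{[t]}\) reaches the loss through four paths: directly through \(\mathbf{S}_{[t]}^{C}\) and \(\mathbf{O}_{[t]}\), through \(\mathbf{W}_{[t]} = \mathbf{T}_{[t]} \mathbf{K}_{[t]}\), and through \(\mathbf{T}_{[t]}\) via \(\mathbf{X}_{[t]}\).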

We first collect the relevant intermediate expressions:

\[\begin{aligned} \delta \mathbf{W}_{[t]} &= - \mathbf{K}_{[t]} \delta \mathbf{S}_{[t]}^{C \top} \mathbf{S}_{[t-1]}^{C} - \left( \mathbf{K}_{[t]} \mathbf{Q}_{[t]}^\top \odot \mathbf{M}^\top \right) \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C} = - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C} \\ \\ \left.\delta \mathbf{T}_{[t]}\right|_{\text{from } \mathbf{U}_{[t]} \text{ and } \mathbf{W}_{[t]}} &= \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top + \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top \\ \\ \left.\delta \mathbf{X}_{[t]}\right|_{\text{from } \mathbf{T}_{[t]}} &= - \mathbf{X}_{[t]}^{-\top} \delta \mathbf{T}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{X}_{[t]}^{-\top} = - \mathbf{X}_{[t]}^{-\top} \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^{\top} \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ w/o } \mathbf{U}_{[t]} \text{ or } \mathbf{W}_{[t]} } &= \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C} \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{U}_{[t]} \text{ or } \mathbf{W}_{[t]} } &= \left(\mathbf{V}_{[t],new} \delta \mathbf{O}_{[t]}^\top \odot \mathbf{M}^\top \right)\mathbf{Q}_{[t]} \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{W}_{[t]} \text{ w/o } \mathbf{T}_{[t]} } &= \mathbf{T}_{[t]}^\top \delta \mathbf{W}_{[t]} = - \mathbf{T}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C} \\ \\ \left.\delta (\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top)\right|_{\text{from } \mathbf{T}_{[t]}} &= - \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{X}_{[t]}^{-\top} \delta \mathbf{T}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{X}_{[t]}^{-\top} \odot \mathbf{M}_{-1} = - \mathbf{T}_{[t]}^{\top} \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^{\top} \odot \mathbf{M}_{-1} \end{aligned}\]

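From these, the contribution through \(\mathbf{T}_{[t]}\) follows. Because \(\mathbf{K}_{[t]}\) enters \(\mathbf{X}_{[t]}\) only through the product \(\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top\), a gradient \(\mathbf{G}\) with respect to that product contributes \((\mathbf{G} + \mathbf{G}^\top) \mathbf{K}_{[t]}\):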

\[\begin{aligned} \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{T}_{[t]}} &= - \left(\mathbf{T}_{[t]}^{\top} \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^{\top} \odot \mathbf{M}_{-1} \right) \mathbf{K}_{[t]} - \left(\mathbf{T}_{[t]}^{\top} \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^{\top} \odot \mathbf{M}_{-1} \right)^\top \mathbf{K}_{[t]} \\&= - \left(\mathbf{T}_{[t]}^{\top} \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^{\top} \odot \mathbf{M}_{-1} + \mathbf{T}_{[t]} \delta \mathbf{T}_{[t]}^\top \mathbf{T}_{[t]} \odot \mathbf{M}_{-1}^\top \right) \mathbf{K}_{[t]} \\ \\ \left.\delta \mathbf{T}_{[t]}\right|_{\text{from } \mathbf{U}_{[t]} \text{ and } \mathbf{W}_{[t]}} &= \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top + \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top = \delta \mathbf{U}_{[t]} \left(\mathbf{V}_{[t]}^\top - \mathbf{S}_{[t-1]}^{C}\mathbf{K}_{[t]}^\top \right) \end{aligned}\]

Putting all terms together, we finally get:

\[\begin{aligned} \delta \mathbf{K}_{[t]} &= \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C} + \left(\mathbf{V}_{[t],new} \delta \mathbf{O}_{[t]}^\top \odot \mathbf{M}^\top \right)\mathbf{Q}_{[t]} - \mathbf{T}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C} \\ &- \left(\mathbf{T}_{[t]}^{\top} \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^{\top} \odot \mathbf{M}_{-1} + \mathbf{T}_{[t]} \delta \mathbf{T}_{[t]} ^\top \mathbf{T}_{[t]} \odot \mathbf{M}_{-1}^\top \right) \mathbf{K}_{[t]} \end{aligned}\]
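The full expression for \(\delta \mathbf{K}_{[t]}\) collects four paths, so a numerical check is worthwhile. The NumPy sketch below uses hypothetical helpers (`chunk_forward`, `num_grad`, a surrogate loss with random upstream gradients), repeated so it runs on its own. It also confirms the factored form \(\delta \mathbf{T}_{[t]} = \delta \mathbf{U}_{[t]} (\mathbf{V}_{[t]}^\top - \mathbf{S}_{[t-1]}^{C}\mathbf{K}_{[t]}^\top)\).

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; returns S_t, O_t and cached intermediates."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    T = np.linalg.inv(X) @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O, dict(M=M, M1=M1, T=T, V_new=V_new)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, size=C)
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(S_prev, Q, K, V, beta):
    S, O, _ = chunk_forward(S_prev, Q, K, V, beta)
    return np.sum(dS * S) + np.sum(dO * O)

c = chunk_forward(S_prev, Q, K, V, beta)[2]
M, M1, T, V_new = c["M"], c["M1"], c["T"], c["V_new"]
dU = K @ dS.T + ((K @ Q.T) * M.T) @ dO
dW = -dU @ S_prev
dT = dU @ V.T + dW @ K.T
factored_ok = np.allclose(dT, dU @ (V.T - S_prev @ K.T))
G = (T.T @ dT @ T.T) * M1                   # minus the gradient w.r.t. K K^T via T
dK = (V_new @ dS                            # direct path through S_t
      + ((V_new @ dO.T) * M.T) @ Q          # direct path through O_t
      + T.T @ dW                            # through W_t = T_t K_t
      - (G + G.T) @ K)                      # through T_t (via X_t)

err = np.max(np.abs(dK - num_grad(lambda x: loss(S_prev, Q, x, V, beta), K)))
print(factored_ok, err)
```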

For \(\delta \boldsymbol{\beta}_{[t]}\)

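Next, \(\boldsymbol{\beta}_{[t]}\) enters both directly through \(\mathbf{T}_{[t]} = \mathbf{A}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]})\) and through \(\mathbf{X}_{[t]}\). Collecting both paths gives the following. The second line uses \(\mathbf{X}_{[t]}^{-\top} = \text{Diag}(\boldsymbol{\beta}_{[t]})^{-1} \mathbf{T}_{[t]}^\top\), which requires every entry of \(\boldsymbol{\beta}_{[t]}\) to be nonzero: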

\[\begin{aligned} \delta \text{Diag}(\boldsymbol{\beta}_{[t]}) &= \mathbf{X}_{[t]}^{-\top} \delta \mathbf{T}_{[t]} - \mathbf{X}_{[t]}^{-\top} \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^\top \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1}^\top \right) \\ &= \text{Diag}(\boldsymbol{\beta}_{[t]})^{-1}\mathbf{T}_{[t]}^{\top} \delta \mathbf{T}_{[t]} \left(\mathbf{I} - \mathbf{T}_{[t]}^\top \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1}^\top \right) \right) \\ \\ \delta \boldsymbol{\beta}_{[t]} &= \text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]})) \end{aligned}\]
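Both lines of this derivation can be compared against finite differences. In the NumPy sketch below, `chunk_forward`, `num_grad` and the surrogate loss are hypothetical helpers, and \(\boldsymbol{\beta}\) is drawn from \([0.1, 0.9]\) so the \(\text{Diag}(\boldsymbol{\beta})^{-1}\) form is well defined:

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; returns S_t, O_t and cached intermediates."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    A = np.linalg.inv(X)
    T = A @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O, dict(M=M, M1=M1, A=A, T=T)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, size=C)
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(S_prev, Q, K, V, beta):
    S, O, _ = chunk_forward(S_prev, Q, K, V, beta)
    return np.sum(dS * S) + np.sum(dO * O)

c = chunk_forward(S_prev, Q, K, V, beta)[2]
M1, A, T = c["M1"], c["A"], c["T"]
dU = K @ dS.T + ((K @ Q.T) * c["M"].T) @ dO
dW = -dU @ S_prev
dT = dU @ V.T + dW @ K.T
Lmat = (K @ K.T) * M1                       # K K^T masked by M_{-1}
# First line: X^{-T} dT - X^{-T} dT T^T (K K^T ⊙ M_{-1}^T), with X^{-T} = A^T.
dbeta = np.diag(A.T @ dT - A.T @ dT @ T.T @ Lmat.T)
# Second line: X^{-T} = Diag(beta)^{-1} T^T, so diag(Diag(beta)^{-1} Z) = diag(Z) / beta.
dbeta_alt = np.diag(T.T @ dT @ (np.eye(C) - T.T @ Lmat.T)) / beta

num = num_grad(lambda x: loss(S_prev, Q, K, V, x), beta)
err, err_alt = np.max(np.abs(dbeta - num)), np.max(np.abs(dbeta_alt - num))
print(err, err_alt)
```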

An Equivalent Form for \(\delta \mathbf{K}_{[t]}\) and \(\delta \boldsymbol{\beta}_{[t]}\)

Working with \(\mathbf{A}_{[t]} = \mathbf{X}_{[t]}^{-1}\) directly avoids the factor \(\text{Diag}(\boldsymbol{\beta}_{[t]})^{-1}\) and yields an equivalent form that is more convenient for implementation.

First, we have:

\[\begin{aligned} \delta \mathbf{W}_{[t]} &= - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C} \\ \\ \left.\delta \mathbf{T}_{[t]}\right|_{\text{from } \mathbf{U}_{[t]} \text{ and } \mathbf{W}_{[t]}} &= \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top + \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top \\ \\ \delta \mathbf{A}_{[t]} &= \delta \mathbf{T}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]})^\top = \delta \mathbf{U}_{[t]} (\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]})^\top + \delta \mathbf{W}_{[t]} (\text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]})^\top \\ \\ \delta \mathbf{X}_{[t]} &= - \mathbf{A}_{[t]}^\top \delta \mathbf{A}_{[t]} \mathbf{A}_{[t]}^\top , \quad \mathbf{T}_{[t]}^\top \delta \mathbf{T}_{[t]} \mathbf{T}_{[t]}^\top = \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A}_{[t]}^\top \delta \mathbf{A}_{[t]} \mathbf{A}_{[t]}^\top \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{S}_{[t]} \text{ and } \mathbf{O}_{[t]} \text{ w/o } \mathbf{U}_{[t]} \text{ or } \mathbf{W}_{[t]} } &= \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C} +\left(\mathbf{V}_{[t],new} \delta \mathbf{O}_{[t]}^\top \odot \mathbf{M}^\top \right)\mathbf{Q}_{[t]} \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{W}_{[t]} \text{ w/o } \mathbf{T}_{[t]} } &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} \\ \\ \left.\delta (\mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top)\right|_{\text{from } \mathbf{T}_{[t]}} &= - \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{X}_{[t]}^{-\top} \delta \mathbf{T}_{[t]} \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{X}_{[t]}^{-\top} \odot \mathbf{M}_{-1} = \text{Diag}(\boldsymbol{\beta}_{[t]}) (\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1}) \\ \\ \left.\delta \mathbf{K}_{[t]}\right|_{\text{from } \mathbf{T}_{[t]}} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \left(\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1} \right) \mathbf{K}_{[t]} + \left(\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1} \right)^\top \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \\ \\ \delta \mathbf{K}_{[t]} &= \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C} + \left(\mathbf{V}_{[t],new} \delta \mathbf{O}_{[t]}^\top \odot \mathbf{M}^\top \right)\mathbf{Q}_{[t]} + \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} \\ &+ \text{Diag}(\boldsymbol{\beta}_{[t]}) \left(\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1} \right) \mathbf{K}_{[t]} + \left(\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1} \right)^\top \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \end{aligned}\]
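The \(\mathbf{A}_{[t]}\)-based form can be checked the same way. This NumPy sketch uses the same hypothetical helpers and random data as the earlier checks. It builds \(\delta \mathbf{A}_{[t]}\) from the \(\boldsymbol{\beta}\)-scaled values and keys, masks \(\delta \mathbf{X}_{[t]}\) once, and compares the resulting \(\delta \mathbf{K}_{[t]}\) with finite differences.

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; returns S_t, O_t and cached intermediates."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    A = np.linalg.inv(X)
    T = A @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O, dict(M=M, M1=M1, A=A, V_new=V_new)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, size=C)
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(S_prev, Q, K, V, beta):
    S, O, _ = chunk_forward(S_prev, Q, K, V, beta)
    return np.sum(dS * S) + np.sum(dO * O)

c = chunk_forward(S_prev, Q, K, V, beta)[2]
M, M1, A, V_new = c["M"], c["M1"], c["A"], c["V_new"]
D = np.diag(beta)
dU = K @ dS.T + ((K @ Q.T) * M.T) @ dO
dW = -dU @ S_prev
dA = dU @ (D @ V).T + dW @ (D @ K).T        # uses beta-scaled values and keys
dXm = (-A.T @ dA @ A.T) * M1                # delta X masked by M_{-1}
dK = (V_new @ dS + ((V_new @ dO.T) * M.T) @ Q + D @ A.T @ dW
      + D @ dXm @ K + dXm.T @ D @ K)

err = np.max(np.abs(dK - num_grad(lambda x: loss(S_prev, Q, x, V, beta), K)))
print(err)
```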

We also obtain the following equivalent form for \(\delta \boldsymbol{\beta}\):

\[\begin{aligned} \delta \text{Diag}(\boldsymbol{\beta}_{[t]}) &= \mathbf{X}_{[t]}^{-\top} \delta \mathbf{T}_{[t]} + \delta \mathbf{X}_{[t]} \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1}^\top \right) \\ &= \mathbf{A}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top + \mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top + \delta \mathbf{X}_{[t]} \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right)^\top \\ \\ \delta \boldsymbol{\beta}_{[t]} &= \text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]})) \\ &= \text{diag}\left(\mathbf{A}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top\right) + \text{diag}\left(\mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top\right) + \text{diag}\left(\delta \mathbf{X}_{[t]} \left( \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top \odot \mathbf{M}_{-1} \right)^\top\right) \\&= \text{diag}\left(\mathbf{A}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top\right) + \text{diag}\left(\mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} \mathbf{K}_{[t]}^\top\right) + \text{diag}\left((\delta \mathbf{X}_{[t]} \odot \mathbf{M}_{-1}) \mathbf{K}_{[t]} \mathbf{K}_{[t]}^\top\right) \end{aligned}\]
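The last line maps directly onto code. In this NumPy sketch (same hypothetical helpers and random data as before), each \(\text{diag}(\mathbf{P}\mathbf{R}^\top)\) term is evaluated as a row-wise sum of \(\mathbf{P} \odot \mathbf{R}\), so the three diagonal terms avoid forming full \(C \times C\) products:

```python
import numpy as np

def chunk_forward(S_prev, Q, K, V, beta):
    """One chunk of the forward equations; returns S_t, O_t and cached intermediates."""
    C = Q.shape[0]
    M = np.tril(np.ones((C, C)))            # causal mask, diagonal included
    M1 = np.tril(np.ones((C, C)), k=-1)     # strictly lower-triangular mask M_{-1}
    X = np.eye(C) + np.diag(beta) @ ((K @ K.T) * M1)
    A = np.linalg.inv(X)
    T = A @ np.diag(beta)
    W, U = T @ K, T @ V
    V_new = U - W @ S_prev.T
    S = S_prev + V_new.T @ K
    O = Q @ S_prev.T + ((Q @ K.T) * M) @ V_new
    return S, O, dict(M=M, M1=M1, A=A)

def num_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function f at x."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
C, d_k, d_v = 4, 3, 2
S_prev = rng.standard_normal((d_v, d_k))
Q = rng.standard_normal((C, d_k))
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
beta = rng.uniform(0.1, 0.9, size=C)
dS = rng.standard_normal((d_v, d_k))    # upstream gradient w.r.t. S_t
dO = rng.standard_normal((C, d_v))      # upstream gradient w.r.t. O_t

def loss(S_prev, Q, K, V, beta):
    S, O, _ = chunk_forward(S_prev, Q, K, V, beta)
    return np.sum(dS * S) + np.sum(dO * O)

c = chunk_forward(S_prev, Q, K, V, beta)[2]
M1, A = c["M1"], c["A"]
D = np.diag(beta)
dU = K @ dS.T + ((K @ Q.T) * c["M"].T) @ dO
dW = -dU @ S_prev
dA = dU @ (D @ V).T + dW @ (D @ K).T
dXm = (-A.T @ dA @ A.T) * M1                # delta X masked by M_{-1}
# diag(P R^T) equals the row-wise sum of P * R.
dbeta = (np.sum((A.T @ dU) * V, axis=1)     # diag(A^T dU V^T)
         + np.sum((A.T @ dW) * K, axis=1)   # diag(A^T dW K^T)
         + np.sum((dXm @ K) * K, axis=1))   # diag((dX ⊙ M_{-1}) K K^T)

err = np.max(np.abs(dbeta - num_grad(lambda x: loss(S_prev, Q, K, V, x), beta)))
print(err)
```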