The Hidden Recurrence in Softmax Attention: A Path to Flash Attention
Paper: https://arxiv.org/abs/2205.14135
Paper: https://arxiv.org/abs/2307.08691
Code: https://github.com/Dao-AILab/flash-attention
Disclaimer: These are personal reading notes. Some derivations are my own and may be incorrect, so they should be cross-checked against the code later.
Standard Attention
Forward Pass
\[\begin{aligned}
\mathbf{O} &=
\text{Diag}\left(
\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right)\boldsymbol{1}\right)^{-1}
\left(
\exp(\mathbf{Q}\mathbf{K}^\top) \odot \mathbf{M}\right)\mathbf{V}
\end{aligned}\]
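A minimal pure-Python sketch of this formula, assuming matrices are stored as lists of rows and the mask holds 0/1 entries. The function name `attention_forward` and the toy inputs are illustrative and not taken from the FlashAttention code:

```python
import math

def attention_forward(Q, K, V, mask):
    """O = Diag((exp(QK^T) * M) 1)^{-1} (exp(QK^T) * M) V, row by row."""
    out = []
    for q, mask_row in zip(Q, mask):
        # Unnormalized masked weights exp(q . k_i) * M[t][i] for one query row.
        w = [m * math.exp(sum(a * b for a, b in zip(q, k)))
             for k, m in zip(K, mask_row)]
        denom = sum(w)  # row sum, i.e. one diagonal entry before inversion
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / denom
                    for c in range(len(V[0]))])
    return out

Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
causal = [[1.0 if i <= t else 0.0 for i in range(3)] for t in range(3)]
O = attention_forward(Q, K, V, causal)
```

With the causal mask, the first query only sees the first key, so `O[0]` equals the first row of `V`.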

Backward Pass
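Before deriving the backward pass of standard attention, consider a simple case. Here \(\text{Diag}(\cdot)\) builds a diagonal matrix from a vector, and \(\text{diag}(\cdot)\) extracts the diagonal of a matrix as a vector: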
\[\begin{aligned}
\mathbf{Y} &= \text{Diag}\left( \mathbf{X}\boldsymbol{1}\right)^{-1}
\\ \\ \Rightarrow
\delta \mathbf{X} &=
\text{diag}\left(
- \mathbf{Y}^\top \delta \mathbf{Y} \mathbf{Y}^\top
\right) \boldsymbol{1}^\top
=
- \text{diag}\left(
\mathbf{Y} \delta \mathbf{Y} \mathbf{Y}
\right) \boldsymbol{1}^\top
=
- \mathbf{Y}^2 \text{diag}\left(
\delta \mathbf{Y}
\right) \boldsymbol{1}^\top
\end{aligned}\]
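The last form says that every entry in row \(i\) of \(\delta \mathbf{X}\) equals \(-Y_{ii}^2 \, \delta Y_{ii}\). A small sketch that checks this against a central finite difference, assuming the scalar loss \(\sum_i \delta Y_{ii} Y_{ii}\); the helper names `grad_X` and `loss` are made up for this check:

```python
def grad_X(X, dY_diag):
    """Gradient w.r.t. X of sum_i dY_ii * Y_ii, where Y = Diag(X 1)^{-1}."""
    y = [1.0 / sum(row) for row in X]  # diagonal entries of Y
    # Every entry in row i receives the same gradient -Y_ii^2 * dY_ii.
    return [[-(yi ** 2) * gi] * len(row) for row, yi, gi in zip(X, y, dY_diag)]

def loss(X, dY_diag):
    return sum(g / sum(row) for row, g in zip(X, dY_diag))

X = [[1.0, 2.0, 3.0], [0.5, 0.25, 0.25]]
dY_diag = [0.7, -1.3]  # only the diagonal of dY matters because Y is diagonal
analytic = grad_X(X, dY_diag)

# Central finite difference on a single entry as a sanity check.
eps = 1e-6
X_up = [row[:] for row in X]
X_up[0][1] += eps
X_dn = [row[:] for row in X]
X_dn[0][1] -= eps
numeric = (loss(X_up, dY_diag) - loss(X_dn, dY_diag)) / (2 * eps)
```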
Then the standard attention backward can be derived as follows:
\[\begin{aligned}
\mathbf{O} &=
\text{Diag}\left(
\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right)\boldsymbol{1}\right)^{-1}
\left(
\exp(\mathbf{Q}\mathbf{K}^\top) \odot \mathbf{M}\right)\mathbf{V}
\\
\\
\Rightarrow
\delta \left(\exp(\mathbf{Q}\mathbf{K}^\top) \odot \mathbf{M}\right)
&=
\text{Diag}\left(\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right) \boldsymbol{1}\right)^{-1}
\delta \mathbf{O}
\mathbf{V}^\top
\\ &-
\text{Diag}\left(\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right) \boldsymbol{1}\right)^{-2}
\text{diag}\left(
\delta \mathbf{O} \mathbf{V}^\top
\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right)^\top\right)
\boldsymbol{1}^\top
\\ &=
\text{Diag}\left(\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right) \boldsymbol{1}\right)^{-1}
\left(
\delta \mathbf{O} \mathbf{V}^\top
-
\text{diag}\left(\delta \mathbf{O} \mathbf{O}^\top
\right)
\boldsymbol{1}^\top
\right)
\\ \\ \Rightarrow
\delta \mathbf{V} &=
\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right)^\top
\text{Diag}\left(\left(\exp(\mathbf{Q}\mathbf{K}^\top ) \odot \mathbf{M}\right) \boldsymbol{1} \right)^{-1}
\delta \mathbf{O}
\\ \\
\delta \mathbf{Q} &= \left(
\delta \left(\exp(\mathbf{Q}\mathbf{K}^\top) \odot \mathbf{M}\right)
\odot \left(\exp(\mathbf{Q}\mathbf{K}^\top) \odot \mathbf{M}\right)
\right) \mathbf{K}
\\ \\
\delta \mathbf{K} &= \left(
\delta \left(\exp(\mathbf{Q}\mathbf{K}^\top) \odot \mathbf{M}\right)
\odot \left(\exp(\mathbf{Q}\mathbf{K}^\top) \odot \mathbf{M}\right)
\right)^\top \mathbf{Q}
\end{aligned}\]
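Since these derivations are my own, a numerical check is useful. The sketch below implements the formulas for \(\delta \mathbf{Q}\), \(\delta \mathbf{K}\), \(\delta \mathbf{V}\) in pure Python and compares them with central finite differences of the scalar \(\sum \delta \mathbf{O} \odot \mathbf{O}\). All helper names and the toy inputs are assumptions made for this check:

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def forward(Q, K, V, M):
    S = matmul(Q, transpose(K))
    E = [[math.exp(s) * m for s, m in zip(s_row, m_row)] for s_row, m_row in zip(S, M)]
    inv = [1.0 / sum(row) for row in E]  # diagonal of Diag(E 1)^{-1}
    O = [[d * x for x in row] for d, row in zip(inv, matmul(E, V))]
    return E, inv, O

def backward(Q, K, V, M, dO):
    E, inv, O = forward(Q, K, V, M)
    dOVt = matmul(dO, transpose(V))
    # diag(dO O^T): one scalar per query row.
    row_dot = [sum(g * o for g, o in zip(g_row, o_row)) for g_row, o_row in zip(dO, O)]
    dE = [[d * (x - r) for x in row] for d, row, r in zip(inv, dOVt, row_dot)]
    dS = [[a * b for a, b in zip(r1, r2)] for r1, r2 in zip(dE, E)]  # dE times E, elementwise
    dV = matmul(transpose(E), [[d * g for g in row] for d, row in zip(inv, dO)])
    return matmul(dS, K), matmul(transpose(dS), Q), dV

def numeric_grad(X, f, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. each entry of X (in place)."""
    G = [[0.0] * len(X[0]) for _ in X]
    for i in range(len(X)):
        for j in range(len(X[0])):
            old = X[i][j]
            X[i][j] = old + eps
            up = f()
            X[i][j] = old - eps
            down = f()
            X[i][j] = old
            G[i][j] = (up - down) / (2 * eps)
    return G

Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
M = [[1.0 if i <= t else 0.0 for i in range(3)] for t in range(3)]  # causal mask
dO = [[0.3, -0.2], [0.1, 0.4], [-0.5, 0.2]]

def loss():
    O = forward(Q, K, V, M)[2]
    return sum(g * o for g_row, o_row in zip(dO, O) for g, o in zip(g_row, o_row))

dQ, dK, dV = backward(Q, K, V, M, dO)
max_err = max(
    abs(a - n)
    for analytic, X in ((dQ, Q), (dK, K), (dV, V))
    for a_row, n_row in zip(analytic, numeric_grad(X, loss))
    for a, n in zip(a_row, n_row)
)
```

If the formulas are right, `max_err` should be at the level of finite-difference noise.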

The Hidden Recurrence behind Softmax Attention
Recurrent mode
comment: A tile-based algorithm must combine each local tile with global information, and the key to doing so is a recurrence. This section derives the recurrence hidden in softmax dot-product attention.
Consider
\[\begin{aligned}
\mathbf{O}
&= \frac{\exp(\mathbf{Q}\mathbf{K}^\top)}{\sum_k \exp(\mathbf{Q}\mathbf{K}^\top)} \mathbf{V}
\quad \Leftrightarrow \quad
\mathbf{O} =
\text{Diag}\left( \exp(\mathbf{Q}\mathbf{K}^\top) \boldsymbol{1}\right)^{-1}
\exp(\mathbf{Q}\mathbf{K}^\top)\mathbf{V}
\\ \\ \Leftrightarrow
\boldsymbol{o}_t
&= \frac{ \sum_{i=1}^L \boldsymbol{v}_i \exp(\boldsymbol{k}_i^\top\boldsymbol{q}_t) }
{\sum_{i=1}^L \exp(\boldsymbol{k}_i^\top\boldsymbol{q}_t)}
\end{aligned}\]
We set
\[\begin{aligned}
\boldsymbol{a}^j_t
=
\sum_{i=1}^t \boldsymbol{v}_i \exp(\boldsymbol{k}_i^\top\boldsymbol{q}_j)
,\quad
l^j_t = \sum_{i=1}^t \exp(\boldsymbol{k}_i^\top\boldsymbol{q}_j)
,\quad
\boldsymbol{a}^t_L = l^t_L \boldsymbol{o}_t
\end{aligned}\]
Then we have
\[\begin{aligned}
l^j_t &= l^j_{t-1} + \exp(\boldsymbol{k}_t^\top\boldsymbol{q}_j)
\\
\\
\boldsymbol{a}^j_{t} &= \sum_{i=1}^t \boldsymbol{v}_i \exp(\boldsymbol{k}_i^\top\boldsymbol{q}_j)
=
\boldsymbol{a}^j_{t-1}
+ \boldsymbol{v}_t \exp(\boldsymbol{k}_t^\top\boldsymbol{q}_j)
\end{aligned}\]
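The two recurrences can be sketched for a single query as follows. This is an unmasked toy version with no rescaling, so it is only safe for small scores; `recurrent_row` and `direct_row` are illustrative names:

```python
import math

def recurrent_row(q, K, V):
    """Output o for one query q via the running sums a_t and l_t."""
    a = [0.0] * len(V[0])   # a_t: running sum of v_i * exp(k_i . q)
    l = 0.0                 # l_t: running sum of exp(k_i . q)
    for k, v in zip(K, V):  # one step per key/value pair, t = 1..L
        w = math.exp(sum(x * y for x, y in zip(k, q)))
        l += w
        a = [ai + w * vi for ai, vi in zip(a, v)]
    return [ai / l for ai in a]  # o = a_L / l_L

def direct_row(q, K, V):
    """Same output computed directly from the closed-form softmax expression."""
    w = [math.exp(sum(x * y for x, y in zip(k, q))) for k in K]
    return [sum(wi * v[j] for wi, v in zip(w, V)) / sum(w) for j in range(len(V[0]))]

q = [0.3, -0.1]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
o_rec = recurrent_row(q, K, V)
o_dir = direct_row(q, K, V)
```

The state carried between steps is just `a` and `l`, which is what makes a tile-by-tile evaluation possible.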
Chunk-wise Mode
Below, \(\mathbf{M}\) denotes the block of the causal mask that matches the current tile, or an all-ones matrix when no mask is used.
\[\begin{aligned}
\mathbf{O}_{[t]} &= \text{Diag}\left(\boldsymbol{ll}^{[t]}_{[L]}\right)^{-1} \mathbf{A}^{[t]}_{[L]}
\\ \\
\boldsymbol{ll}^{[j]}_{[t]} &= \boldsymbol{ll}^{[j]}_{[t-1]}
+
\left(\exp \left(\mathbf{Q}_{[j]}\mathbf{K}_{[t]}^\top\right) \odot \mathbf{M}\right) \boldsymbol{1}
\\ \\
\mathbf{A}^{[j]}_{[t]} &=
\mathbf{A}^{[j]}_{[t-1]}
+
\left(\exp \left(\mathbf{Q}_{[j]}\mathbf{K}_{[t]}^\top\right) \odot \mathbf{M}\right) \mathbf{V}_{[t]}
\end{aligned}\]
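Next, let \(\boldsymbol{m}_{[t],\tau}\) be a positive per-row scale for tile \((t, \tau)\), later chosen as the exponential of the row-max of \(\mathbf{S}_{[t],[\tau]}\), and set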
\[\begin{aligned}
\mathbf{S}_{[t],[\tau]} &= \mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top
\in \mathbb{R}^{N \times M}
,\quad
\boldsymbol{m}_{[t],\tau}
\in \mathbb{R}^{N}
\\ \\
\mathbf{P}_{[t],[\tau]} &=
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)^{-1}
\left(
\exp \left(\mathbf{S}_{[t],[\tau]}\right)
\odot \mathbf{M}
\right)
\in \mathbb{R}^{N \times M}
\\ \\
\boldsymbol{l}_{[t],[\tau]} &= \mathbf{P}_{[t],[\tau]} \boldsymbol{1}
\in \mathbb{R}^{N}
\end{aligned}\]
Then, writing \(\boldsymbol{m}^{\tau}_{[t]}\) for the running scale after tile \(\tau\), we have
\[\begin{aligned}
\boldsymbol{l}^{[t]}_{[\tau]} &=
\boldsymbol{ll}^{[t]}_{[\tau]} \oslash \boldsymbol{m}^{\tau}_{[t]}
=
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\boldsymbol{l}^{[t]}_{[\tau-1]}
+ \text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\boldsymbol{l}_{[t],[\tau]}
\\ \\
\mathbf{A}^{[t]}_{[\tau]} &= \mathbf{A}^{[t]}_{[\tau-1]}
+ \text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\mathbf{P}_{[t],[\tau]}\mathbf{V}_{[\tau]}
\end{aligned}\]
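Before any rescaling enters, the chunk-wise recurrences for \(\boldsymbol{ll}\) and \(\mathbf{A}\) can be sketched as below. The sketch assumes a 0/1 mask given as a full matrix and a single chunk size for queries and keys; the function names are hypothetical:

```python
import math

def score(q, k):
    return math.exp(sum(x * y for x, y in zip(q, k)))

def chunked_attention(Q, K, V, mask, chunk):
    """Accumulate ll and A over key/value chunks, then normalize once at the end."""
    O = []
    for r0 in range(0, len(Q), chunk):                # query chunk [j]
        rows = range(r0, min(r0 + chunk, len(Q)))
        ll = {r: 0.0 for r in rows}                   # unnormalized row sums
        A = {r: [0.0] * len(V[0]) for r in rows}      # unnormalized outputs
        for c0 in range(0, len(K), chunk):            # key/value chunk [t]
            cols = range(c0, min(c0 + chunk, len(K)))
            for r in rows:
                # Row r of the masked exp(Q_[j] K_[t]^T), restricted to this chunk.
                w = [mask[r][c] * score(Q[r], K[c]) for c in cols]
                ll[r] += sum(w)
                A[r] = [a + sum(wc * V[c][j] for wc, c in zip(w, cols))
                        for j, a in enumerate(A[r])]
        O.extend([a / ll[r] for a in A[r]] for r in rows)
    return O

def direct_attention(Q, K, V, mask):
    """Reference: the full masked softmax attention without chunking."""
    O = []
    for q, m_row in zip(Q, mask):
        w = [m * score(q, k) for m, k in zip(m_row, K)]
        O.append([sum(wi * v[j] for wi, v in zip(w, V)) / sum(w)
                  for j in range(len(V[0]))])
    return O

Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [0.4, 0.4]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0], [0.1, -0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]]
causal = [[1.0 if c <= r else 0.0 for c in range(4)] for r in range(4)]
O_chunked = chunked_attention(Q, K, V, causal, chunk=2)
O_direct = direct_attention(Q, K, V, causal)
```

Because nothing is rescaled here, `math.exp` can overflow for large scores, which is the problem the next subsection addresses.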
Chunk-wise Mode with Online Softmax Re-scale
For numerical stability, softmax is usually computed with a rescaling trick. Here the scale is the exponential of the running row-max of \(\mathbf{QK}^\top\) over the tiles seen so far, so we define:
\[\begin{aligned}
\boldsymbol{m}^{\tau}_{[t]} &=
\max(\boldsymbol{m}^{\tau-1}_{[t]}, \boldsymbol{m}_{[t],\tau})
\\
\\
\boldsymbol{l}^{[t]}_{[\tau]} &=
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\boldsymbol{l}^{[t]}_{[\tau-1]}
+ \text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\boldsymbol{l}_{[t],[\tau]}
\\
\\
\mathbf{C}^{[t]}_{[\tau]} &=
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\mathbf{A}^{[t]}_{[\tau]}
\\&=
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\mathbf{C}^{[t]}_{[\tau-1]}
+
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\mathbf{P}_{[t],[\tau]}\mathbf{V}_{[\tau]}
\\
\\
\mathbf{O}^{[t]} &=
\text{Diag}\left(\boldsymbol{l}^{[t]}_{[L]}\right)^{-1}
\mathbf{C}^{[t]}_{[L]}
\end{aligned}\]
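A pure-Python sketch of these rescaled recurrences, under a few assumptions of mine: no mask, one query row at a time, and the scale tracked in log space (so `log_m` is the running row-max and the ratio \(\boldsymbol{m}^{\tau-1} / \boldsymbol{m}^{\tau}\) becomes `exp(log_m - new_log_m)`). The product of the tile scale ratio with \(\mathbf{P}\) is fused into a single `exp(s - new_log_m)`. Names such as `tiled_attention` are illustrative:

```python
import math

def tiled_attention(Q, K, V, tile):
    out = []
    for q in Q:
        log_m = -math.inf        # log of the running scale m^tau (running row-max)
        l = 0.0                  # rescaled running row sum l^tau
        c = [0.0] * len(V[0])    # rescaled running accumulator C^tau
        for start in range(0, len(K), tile):
            K_t, V_t = K[start:start + tile], V[start:start + tile]
            s = [sum(x * y for x, y in zip(q, k)) for k in K_t]
            new_log_m = max(log_m, max(s))
            alpha = math.exp(log_m - new_log_m)          # m^{tau-1} / m^tau
            p = [math.exp(si - new_log_m) for si in s]   # (m_tile / m^tau) * P, fused
            l = alpha * l + sum(p)
            c = [alpha * cj + sum(pi * v[j] for pi, v in zip(p, V_t))
                 for j, cj in enumerate(c)]
            log_m = new_log_m
        out.append([cj / l for cj in c])                 # O = Diag(l_L)^{-1} C_L
    return out

def reference_attention(Q, K, V):
    """Full-row softmax with the usual max subtraction, for comparison."""
    out = []
    for q in Q:
        s = [sum(x * y for x, y in zip(q, k)) for k in K]
        w = [math.exp(si - max(s)) for si in s]
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / sum(w)
                    for j in range(len(V[0]))])
    return out

# Large scores: a plain exp(q . k) would overflow for some of these rows.
Q = [[100.0, 200.0], [300.0, -100.0], [0.0, 500.0]]
K = [[2.0, 1.0], [-3.0, 4.0], [5.0, 0.0], [1.0, -2.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]]
O_tiled = tiled_attention(Q, K, V, tile=2)
O_ref = reference_attention(Q, K, V)
```

Only `log_m`, `l`, and `c` are carried from tile to tile, so the full score row is never materialized.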
Once we add the exchange of the inner and outer loops and the choice of where the logsumexp scale is applied, we have essentially derived the core logic of FlashAttention v1/v2 from scratch.


Flash Attention Backward Pass
comment: The exponential of the row-max, \(\boldsymbol{m}\), is computed from the input but can be treated as a constant for gradient purposes, because it cancels in the final output. In contrast, \(\boldsymbol{l}\) is a smooth function of the input, so the gradient does flow through it.
Revisiting the Forward Pass
\[\begin{aligned}
\mathbf{P}_{[t],[\tau]} &=
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)^{-1}
\left(
\exp \left(\mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top \right)\odot \mathbf{M}
\right)
\in \mathbb{R}^{N \times M}
\\
\\
\boldsymbol{l}^{[t]}_{[\tau]} &=
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\boldsymbol{l}^{[t]}_{[\tau-1]}
+ \text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\mathbf{P}_{[t],[\tau]} \boldsymbol{1}
\\
\\
\mathbf{C}^{[t]}_{[\tau]} &=
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\mathbf{C}^{[t]}_{[\tau-1]}
+
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\mathbf{P}_{[t],[\tau]}\mathbf{V}_{[\tau]}
\\
\\
\mathbf{O}^{[t]} &=
\text{Diag}\left(\boldsymbol{l}^{[t]}_{[L]}\right)^{-1}
\mathbf{C}^{[t]}_{[L]}
\end{aligned}\]
Gradient at the Output
Assuming \(\boldsymbol{l}\) and \(\boldsymbol{m}\) were stored during the forward pass, we begin backpropagation from \(\mathbf{O}^{[t]}\):
\[\begin{aligned}
\delta \mathbf{C}^{[t]}_{[L]} &=
\text{Diag}\left(\boldsymbol{l}^{[t]}_{[L]}\right)^{-1}
\delta \mathbf{O}^{[t]}
\\
\\
\delta \boldsymbol{l}^{[t]}_{[L]} &= -
\text{diag}\left(
\text{Diag}\left(\boldsymbol{l}^{[t]}_{[L]}\right)^{-\top}
\left(\delta \mathbf{O}^{[t]} (\mathbf{C}^{[t]}_{[L]})^\top \right)
\text{Diag}\left(\boldsymbol{l}^{[t]}_{[L]}\right)^{-\top}
\right)
= -
\text{diag}\left(\delta \mathbf{O}^{[t]} (\mathbf{O}^{[t]})^\top \right)
\oslash \boldsymbol{l}^{[t]}_{[L]}
\\
\\
\delta \boldsymbol{l}^{[t]}_{[\tau-1]}
&=
\left(
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\right)^\top
\delta \boldsymbol{l}^{[t]}_{[\tau]}
=
\left(
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^L_{[t]}\right)^{-1}
\right)^\top
\delta \boldsymbol{l}^{[t]}_{[L]}
\\
\\
\delta \mathbf{C}^{[t]}_{[\tau-1]}
&=
\left(
\text{Diag}\left(\boldsymbol{m}^{\tau-1}_{[t]}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\right)^\top
\delta \mathbf{C}^{[t]}_{[\tau]}
\\
\\
\left.\delta \mathbf{V}_{[\tau]}\right|_{\text{from [t]}}
&= \mathbf{P}_{[t],[\tau]}^\top
\left(
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\right)^\top
\delta \mathbf{C}^{[t]}_{[\tau]}
=
\left(
\exp \left(\mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top\right)
\odot \mathbf{M}
\right)^\top
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\delta \mathbf{C}^{[t]}_{[\tau]}
\\
\\
\delta \mathbf{P}_{[t],[\tau]}
&=
\left(
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\right)^\top
\left(
\delta \boldsymbol{l}^{[t]}_{[\tau]}
\boldsymbol{1}^\top
+
\delta \mathbf{C}^{[t]}_{[\tau]} \mathbf{V}_{[\tau]}^\top
\right)
\\ &=
\left(
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^L_{[t]}\right)^{-1}
\right)^\top
\delta \boldsymbol{l}^{[t]}_{[L]}
\boldsymbol{1}^\top
+
\left(
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)
\text{Diag}\left(\boldsymbol{m}^\tau_{[t]}\right)^{-1}
\right)^\top
\delta \mathbf{C}^{[t]}_{[\tau]} \mathbf{V}_{[\tau]}^\top
\\
\\
\left.\delta \mathbf{Q}_{[t]}\right|_{\text{from } \mathbf{P}_{[t],[\tau]}}
&=
\left(
\delta \exp \left(\mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top\right)
\odot \exp \left(\mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top\right)
\right) \mathbf{K}_{[\tau]}
\\ &=
\left(
\text{Diag}\left(\boldsymbol{m}_{[t],\tau}\right)^{-1}
\delta \mathbf{P}_{[t],[\tau]} \odot \mathbf{M}
\odot \exp \left(\mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top\right)
\right) \mathbf{K}_{[\tau]}
\\&=
\left(
\delta \mathbf{P}_{[t],[\tau]} \odot \mathbf{P}_{[t],[\tau]}
\right) \mathbf{K}_{[\tau]}
\\
\\
\left.\delta \mathbf{K}_{[\tau]}\right|_{\text{from } \mathbf{P}_{[t],[\tau]}}
&=
\left(
\delta \exp \left(\mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top\right)
\odot \exp \left(\mathbf{Q}_{[t]}\mathbf{K}_{[\tau]}^\top\right)
\right) ^\top \mathbf{Q}_{[t]}
\\ &=
\left(
\delta \mathbf{P}_{[t],[\tau]} \odot \mathbf{P}_{[t],[\tau]}
\right) ^\top \mathbf{Q}_{[t]}
\end{aligned}\]
comment: The key insight of the backward pass is that re-computation replaces storage. Rather than saving the full attention matrix \(\mathbf{P}\) from the forward pass, we only store \(\boldsymbol{l}\) and \(\boldsymbol{m}\), then recompute the attention matrix tile-by-tile during the backward pass.
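A sketch of this recomputation idea, under my own simplifications: no mask, scalar loops in place of tile matrix products, and the scale stored in log space. Substituting the telescoped \(\delta \boldsymbol{l}\) and \(\delta \mathbf{C}\) into \(\delta \mathbf{P} \odot \mathbf{P}\), the scale factors appear to cancel, leaving the score gradient \(p \, (\delta \boldsymbol{o}_t^\top \boldsymbol{v}_i - D_t)\) with \(D = \text{diag}(\delta \mathbf{O} \mathbf{O}^\top)\) and \(p\) the normalized attention weight. The code relies on that collapsed form and checks it by finite differences; all names are hypothetical:

```python
import math

def dot(x, y):
    return sum(a * b for a, b in zip(x, y))

def forward(Q, K, V):
    """Unmasked forward that also returns the per-row statistics kept for backward."""
    O, log_m, l = [], [], []
    for q in Q:
        s = [dot(q, k) for k in K]
        mx = max(s)
        p = [math.exp(si - mx) for si in s]
        O.append([sum(pi * v[j] for pi, v in zip(p, V)) / sum(p)
                  for j in range(len(V[0]))])
        log_m.append(mx)      # log of the final scale m^L
        l.append(sum(p))      # final rescaled row sum l_L
    return O, log_m, l

def tiled_backward(Q, K, V, dO, O, log_m, l, tile):
    dQ = [[0.0] * len(Q[0]) for _ in Q]
    dK = [[0.0] * len(K[0]) for _ in K]
    dV = [[0.0] * len(V[0]) for _ in V]
    D = [dot(g, o) for g, o in zip(dO, O)]           # diag(dO O^T), one scalar per row
    for start in range(0, len(K), tile):             # one key/value tile at a time
        for i in range(start, min(start + tile, len(K))):
            for t, q in enumerate(Q):
                # Recompute the normalized weight from q, k and the stored statistics.
                p = math.exp(dot(q, K[i]) - log_m[t]) / l[t]
                ds = p * (dot(dO[t], V[i]) - D[t])   # gradient w.r.t. the score s_{t,i}
                for j in range(len(V[0])):
                    dV[i][j] += p * dO[t][j]
                for j in range(len(q)):
                    dQ[t][j] += ds * K[i][j]
                    dK[i][j] += ds * q[j]
    return dQ, dK, dV

def numeric_grad(X, f, eps=1e-6):
    """Central finite differences of the scalar f() w.r.t. each entry of X (in place)."""
    G = [[0.0] * len(X[0]) for _ in X]
    for i in range(len(X)):
        for j in range(len(X[0])):
            old = X[i][j]
            X[i][j] = old + eps
            up = f()
            X[i][j] = old - eps
            down = f()
            X[i][j] = old
            G[i][j] = (up - down) / (2 * eps)
    return G

Q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0], [0.1, -0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]]
dO = [[0.3, -0.2], [0.1, 0.4], [-0.5, 0.2]]

def loss():
    return sum(dot(g, o) for g, o in zip(dO, forward(Q, K, V)[0]))

O, log_m, l = forward(Q, K, V)
dQ, dK, dV = tiled_backward(Q, K, V, dO, O, log_m, l, tile=2)
max_err = max(
    abs(a - n)
    for analytic, X in ((dQ, Q), (dK, K), (dV, V))
    for a_row, n_row in zip(analytic, numeric_grad(X, loss))
    for a, n in zip(a_row, n_row)
)
```

The backward pass never reads a stored attention matrix: it touches only `Q`, `K`, `V`, `dO`, `O`, and the two per-row statistics.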

