Reading Notes: Implementation Notes for KDA in FLA

Paper: https://arxiv.org/pdf/2510.26692
Code: https://github.com/fla-org/flash-linear-attention
Disclaimer: These are personal reading notes. Some derivations are my own and may be incorrect.

Forward

Entry function signature:

def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    cu_seqlens_cpu: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    chunk_size: int = 64,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    use_gate_in_kernel: bool = False,
    A_log: torch.Tensor | None = None,
    dt_bias: torch.Tensor | None = None,
    disable_recompute: bool = False,
    return_intermediate_states: bool = False,
    cp_context: FLACPContext | None = None,
    transpose_state_layout: bool = False,
):
    return o, final_state, g, Aqk, Akk, w, u, qg, kg, v_new, h, initial_state

Notation

\[\begin{aligned} \overleftarrow{\square_{[t]}} &= \square_{[t]} \odot \mathbf{\Gamma}_{[t]} ,\quad \overrightarrow{\square_{[t]}} = \square_{[t]} \oslash \mathbf{\Gamma}_{[t]} \quad\text{for}\quad \square \in \{ \mathbf{Q}, \mathbf{K}\} \end{aligned}\]

Step 1: Compute the cumulative gate in the log domain

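Here, \(\boldsymbol{\gamma}\) (and likewise \(\boldsymbol{\Gamma}\)) denotes what was previously defined as \(\log \boldsymbol{\gamma}\), so an explicit \(\exp\) appears wherever the decay itself is needed in the formulas that follow. In particular, the arrows in the notation above should be read as multiplying and dividing by \(\exp \boldsymbol{\Gamma}_{[t]}\).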

\[\begin{aligned} \boldsymbol{\gamma}_{[t]}^r &= \sum_{i=tC+1}^{tC+r} \log \boldsymbol{\alpha}_i \in \mathbb{R}^{d_k} \end{aligned}\]

Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
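
To make the indexing concrete, here is a minimal NumPy sketch of a chunk-local cumulative sum. It is only a reference for the formula above, not the FLA kernel: the helper name chunk_local_cumsum_ref, the [T, d_k] layout, and the requirement that T be a multiple of the chunk size are my own assumptions.

```python
import numpy as np


def chunk_local_cumsum_ref(log_alpha: np.ndarray, chunk_size: int) -> np.ndarray:
    """Cumulative sum of log-gates that restarts at every chunk boundary.

    log_alpha: [T, d_k] per-token, per-channel log decay (assumed layout).
    Row t*C + r of the result holds gamma_[t]^r (with r counted from 1).
    """
    T, d_k = log_alpha.shape
    assert T % chunk_size == 0, "sketch assumes T is a multiple of the chunk size"
    chunks = log_alpha.reshape(T // chunk_size, chunk_size, d_k)
    return np.cumsum(chunks, axis=1).reshape(T, d_k)


# Four tokens, one channel, chunk size 2: the sum restarts at the third token.
log_alpha = np.log(np.array([[0.5], [0.5], [0.25], [0.5]]))
gamma = chunk_local_cumsum_ref(log_alpha, chunk_size=2)
decay = np.exp(gamma[:, 0])  # cumulative decay inside each chunk
```

Exponentiating gamma gives the in-chunk cumulative decay, which is what the later steps consume.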

Step 2: Solve the lower-triangular system

\[\begin{aligned} \mathbf{A_{qk}}_{[t]} &= \overleftarrow{\mathbf{Q}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M} ,\quad \mathbf{A_{kk0}}_{[t]} = \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \overleftarrow{\mathbf{K}_{[t]}} \overrightarrow{\mathbf{K}_{[t]}}^\top \odot \mathbf{M}_{-1} \right) \\ \\ \mathbf{A}_{[t]} &= (\mathbf{I} + \mathbf{A_{kk0}}_{[t]})^{-1} \end{aligned}\]

Corresponding code: fla.ops.kda.chunk_intra -> chunk_kda_fwd_intra -> chunk_kda_fwd_kernel_intra_sub_chunk / chunk_kda_fwd_kernel_inter_solve_fused
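
The following NumPy sketch spells out Step 2 for a single chunk. It forms the pairwise decay \(\exp(\boldsymbol{\Gamma}_i - \boldsymbol{\Gamma}_j)\) explicitly instead of scaling Q and K separately, which is slower but cannot overflow. The function name intra_chunk_matrices and the tensor layouts are assumptions made for illustration; the real kernels work on sub-chunks and fuse the solve.

```python
import numpy as np


def intra_chunk_matrices(q, k, g_cum, beta):
    """Reference for one chunk. q, k, g_cum: [C, d_k]; beta: [C].

    g_cum is the chunk-local cumulative log-gate (Gamma in the notes).
    Returns (Aqk, Akk0, A) with A = (I + Akk0)^{-1}.
    """
    C = q.shape[0]
    causal = np.tril(np.ones((C, C), dtype=bool))        # mask M (diagonal kept)
    strict = np.tril(np.ones((C, C), dtype=bool), k=-1)  # mask M_{-1}
    # decay[i, j, :] = exp(Gamma_i - Gamma_j) for i >= j, and 0 otherwise.
    diff = g_cum[:, None, :] - g_cum[None, :, :]
    decay = np.exp(np.where(causal[:, :, None], diff, -np.inf))
    Aqk = np.einsum("id,ijd,jd->ij", q, decay, k)
    Akk0 = beta[:, None] * np.einsum("id,ijd,jd->ij", k, decay, k) * strict
    # I + Akk0 is unit lower-triangular, hence always invertible.
    A = np.linalg.solve(np.eye(C) + Akk0, np.eye(C))
    return Aqk, Akk0, A


rng = np.random.default_rng(0)
C, d_k = 8, 4
q, k = rng.normal(size=(2, C, d_k))
beta = rng.uniform(0.1, 0.9, size=C)
g_cum = np.cumsum(np.log(rng.uniform(0.5, 1.0, size=(C, d_k))), axis=0)
Aqk, Akk0, A = intra_chunk_matrices(q, k, g_cum, beta)
```

For mild gates like these, the result agrees with the factorized form \(\overleftarrow{\mathbf{Q}} \overrightarrow{\mathbf{K}}^\top \odot \mathbf{M}\) from the formula above.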

Step 3: Compute U and W

\[\begin{aligned} \mathbf{U}_{[t]} &= \mathbf{A}_{[t]} \left( \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{V}_{[t]} \right) \\ \\ \mathbf{W}_{[t]} &= \mathbf{A}_{[t]} \left( (\exp \boldsymbol{\Gamma}_{[t]}) \odot \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{K}_{[t]} \right) \\ \\ \mathbf{Q_g}_{[t]} &= (\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{Q}_{[t]} \\ \\ \mathbf{K_g}_{[t]} &= (\exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]})) \odot \mathbf{K}_{[t]} \end{aligned}\]

To save memory bandwidth during the forward pass, \(\mathbf{Q_g}\) is not materialized if activation recomputation is enabled. Instead, it is computed on-the-fly during the backward pass.

Corresponding code: fla.ops.kda.wy_fast -> recompute_w_u_fwd

Step 4: Recursively compute the hidden state

Here, \(\mathbf{S}_{[t]}^{C}\) uses the [d_k, d_v] layout, so it is transposed relative to the previous derivation.

\[\begin{aligned} \mathbf{V}_{[t],new} &= \mathbf{U}_{[t]} - \mathbf{W}_{[t]} \mathbf{S}_{[t-1]}^{C} \\ \\ \mathbf{S}_{[t]}^{C} &= \text{Diag}(\exp \boldsymbol{\gamma}_{[t]}^C) \mathbf{S}_{[t-1]}^{C} + \mathbf{K_g}_{[t]}^\top \mathbf{V}_{[t],new} \end{aligned}\]

Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h

Step 5: Compute the final output

\[\begin{aligned} \mathbf{O}_{[t]} = \left((\exp \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{Q}_{[t]}\right) \mathbf{S}_{[t-1]}^{C} + \mathbf{A_{qk}}_{[t]} \mathbf{V}_{[t],new} \end{aligned}\]

Corresponding code: fla.ops.gla.chunk -> chunk_gla_fwd_o_gk
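
As a sanity check on Steps 1 to 5, the sketch below runs the chunked forward pass in NumPy and compares it with a token-by-token recurrence. Several things here are my assumptions rather than something stated in these notes: the recurrence \(\mathbf{S}_t = (\mathbf{I} - \beta_t \mathbf{k}_t \mathbf{k}_t^\top)\,\text{Diag}(\boldsymbol{\alpha}_t)\,\mathbf{S}_{t-1} + \beta_t \mathbf{k}_t \mathbf{v}_t^\top\) with \(\mathbf{o}_t = \mathbf{S}_t^\top \mathbf{q}_t\), a zero initial state, and the omission of scale. The function names are hypothetical.

```python
import numpy as np


def kda_recurrent_ref(q, k, v, log_alpha, beta):
    """Token-by-token reference (assumed KDA recurrence, scale omitted)."""
    S = np.zeros((k.shape[1], v.shape[1]))               # [d_k, d_v] layout
    o = np.zeros_like(v)
    for t in range(k.shape[0]):
        S = np.exp(log_alpha[t])[:, None] * S            # per-channel decay
        S = S + beta[t] * np.outer(k[t], v[t] - k[t] @ S)  # delta-rule update
        o[t] = q[t] @ S
    return o, S


def kda_chunk_fwd_ref(q, k, v, log_alpha, beta, C):
    """Chunked forward following Steps 1-5; T must be a multiple of C."""
    S = np.zeros((k.shape[1], v.shape[1]))
    o = np.zeros_like(v)
    M, M1 = np.tril(np.ones((C, C))), np.tril(np.ones((C, C)), -1)
    for s in range(0, k.shape[0], C):
        qc, kc, vc, bc = q[s:s + C], k[s:s + C], v[s:s + C], beta[s:s + C]
        G = np.cumsum(log_alpha[s:s + C], axis=0)        # Step 1: Gamma
        # exp(-G) is only safe for mild gates; fine for this toy example.
        k_in, k_out = kc * np.exp(G), kc * np.exp(-G)
        Aqk = ((qc * np.exp(G)) @ k_out.T) * M           # Step 2
        Akk0 = bc[:, None] * ((k_in @ k_out.T) * M1)
        A = np.linalg.inv(np.eye(C) + Akk0)
        U = A @ (bc[:, None] * vc)                       # Step 3
        W = A @ (bc[:, None] * k_in)
        Kg = kc * np.exp(G[-1] - G)
        v_new = U - W @ S                                # Step 4
        o[s:s + C] = (qc * np.exp(G)) @ S + Aqk @ v_new  # Step 5, uses S_[t-1]
        S = np.exp(G[-1])[:, None] * S + Kg.T @ v_new    # decay by exp(gamma^C)
    return o, S


rng = np.random.default_rng(0)
T, C, d_k, d_v = 16, 4, 3, 2
q = rng.normal(size=(T, d_k))
k = rng.normal(size=(T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.normal(size=(T, d_v))
beta = rng.uniform(0.1, 0.9, size=T)
log_alpha = np.log(rng.uniform(0.6, 1.0, size=(T, d_k)))

o_rec, S_rec = kda_recurrent_ref(q, k, v, log_alpha, beta)
o_chunk, S_chunk = kda_chunk_fwd_ref(q, k, v, log_alpha, beta, C)
```

Under these assumptions both the outputs and the final states should agree to floating-point precision.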

Backward

Entry function signature:

def chunk_kda_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    Aqk: torch.Tensor,
    Akk: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    g: torch.Tensor | None = None,
    g_org: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    chunk_size: int = 64,
    safe_gate: bool = False,
    lower_bound: float | None = None,
    use_gate_in_kernel: bool = False,
    A_log: torch.Tensor | None = None,
    dt_bias: torch.Tensor | None = None,
    disable_recompute: bool = False,
    cp_context: FLACPContext | None = None,
    transpose_state_layout: bool = False,
    **kwargs,
):
    return dq, dk, dv, db, dg, dh0, dA, dbias

Step 1: Recompute U and W

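\(\mathbf{A}_{[t]} = (\mathbf{I} + \mathbf{A_{kk0}}_{[t]})^{-1}\) was already stored during the forward pass (presumably the Akk argument), so it is not recomputed. Everything else is the same as in Step 3 of the forward pass, except that \(\mathbf{Q_g}\) is additionally computed here.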

Corresponding code: fla.ops.kda.wy_fast -> recompute_w_u_fwd

Step 2: Recompute the hidden state

This is the same as in the forward pass.

Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h

Step 3: Derive Initial Gradients from Output (\(\delta \mathbf{A_{qk}}\) and \(\delta \mathbf{U}\))

\[\begin{aligned} \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} &= \mathbf{A_{qk}}_{[t]}^\top \delta \mathbf{O}_{[t]} \\ \\ \delta \mathbf{A_{qk}}_{[t]} &= \delta \mathbf{O}_{[t]} \mathbf{V}_{[t],new}^\top \odot \mathbf{M} \end{aligned}\]

Corresponding code: fla.ops.kda.chunk_bwd -> chunk_kda_bwd_dAv -> chunk_kda_bwd_kernel_dAv
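
A quick finite-difference check of the two formulas above, restricted to the intra-chunk term \(\mathbf{A_{qk}} \mathbf{V}_{new}\) of the output. This is a NumPy sketch with random data; the helper numerical_grad and the scalar loss are my own constructions for testing.

```python
import numpy as np

rng = np.random.default_rng(0)
C, d_v = 5, 3
M = np.tril(np.ones((C, C)))                 # causal mask, diagonal kept
Aqk = rng.normal(size=(C, C)) * M
v_new = rng.normal(size=(C, d_v))
dO = rng.normal(size=(C, d_v))               # upstream gradient w.r.t. O


def loss(Aqk, v_new):
    # Scalar whose gradient with respect to O = (Aqk * M) @ v_new is exactly dO.
    return np.sum(dO * ((Aqk * M) @ v_new))


def numerical_grad(f, x, eps=1e-6):
    """Central finite differences, one entry at a time."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g


# Analytic gradients from the formulas in this step.
dU = Aqk.T @ dO
dAqk = (dO @ v_new.T) * M

dU_num = numerical_grad(lambda x: loss(Aqk, x), v_new)
dAqk_num = numerical_grad(lambda x: loss(x, v_new), Aqk)
```

The mask on \(\delta \mathbf{A_{qk}}\) matters: entries above the diagonal never influence the output, so their gradient is zero.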

Step 4: Inter-Chunk Backward Recurrence for Hidden States

\[\begin{aligned} \delta \mathbf{U}_{[t]} = \mathbf{K_g}_{[t]} \delta \mathbf{S}_{[t]}^{C} + \left.\delta \mathbf{U}_{[t]}\right|_{\text{from } \mathbf{O}_{[t]}} \end{aligned}\]
\[\begin{aligned} \delta \mathbf{S}_{[t]}^{C} = \text{Diag}(\exp \boldsymbol{\gamma}_{[t+1]}^C) \delta \mathbf{S}_{[t+1]}^{C} + \mathbf{Q_g}_{[t+1]}^\top \delta \mathbf{O}_{[t+1]} - \mathbf{W}_{[t+1]}^\top \delta \mathbf{U}_{[t+1]} \end{aligned}\]

Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu

Step 5: The Fused WY-Representation Kernel

\[\begin{aligned}
\delta \mathbf{Q}_{[t],part1} &= (\exp \boldsymbol{\Gamma}_{[t]}) \odot \left( \delta \mathbf{O}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \right) \\ \\
\delta \mathbf{W}_{[t]} &= - \delta \mathbf{U}_{[t]} \mathbf{S}_{[t-1]}^{C\top} \\ \\
\left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \mathbf{A}_{[t]} } &= \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \odot \left( \mathbf{V}_{[t],new} \delta \mathbf{S}_{[t]}^{C\top} \right) \\ \\
\left.\delta \mathbf{K}_{[t],part1}\right|_{\text{from } \mathbf{W}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } &= \mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} \odot (\text{Diag}(\boldsymbol{\beta}_{[t]}) \exp \mathbf{\Gamma}_{[t]} ) \\ \\
\delta \mathbf{V}_{[t]} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \mathbf{A}_{[t]}^\top \delta \mathbf{U}_{[t]} \\ \\
\delta \boldsymbol{\gamma}_{[t]}^C &= \delta \exp \boldsymbol{\gamma}_{[t]}^C \odot \exp \boldsymbol{\gamma}_{[t]}^C \\
&= \text{diag}\left( \mathbf{S}_{[t-1]}^C \delta \mathbf{S}_{[t]}^{C\top} \right) \odot \exp \boldsymbol{\gamma}_{[t]}^C + \left( \left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \mathbf{A}_{[t]} } \odot \mathbf{K}_{[t]} \right)^\top \boldsymbol{1} \\ \\
\delta \mathbf{\Gamma}_{[t],part1} &= \left.\delta \exp \mathbf{\Gamma}_{[t]}\right|_{\text{w/o extra } \boldsymbol{\gamma}_{[t]}^C} \odot \exp \mathbf{\Gamma}_{[t]} \\
&= \delta \mathbf{Q}_{[t],part1} \odot \mathbf{Q}_{[t]} - \left.\delta \mathbf{K}_{[t], part1}\right|_{\text{from } \mathbf{S}_{[t]}^C \text{ w/o } \mathbf{A}_{[t]} } \odot \mathbf{K}_{[t]} + \left.\delta \mathbf{K}_{[t],part1}\right|_{\text{from } \mathbf{W}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } \odot \mathbf{K}_{[t]} \\
&+ [0,0,...,\delta \boldsymbol{\gamma}_{[t]}^C ] \\ \\
\delta \boldsymbol{\beta}_{[t],part1} &= \text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]})) = \text{diag}( \mathbf{A}_{[t]}^\top \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top ) + \text{diag}( \mathbf{A}_{[t]}^\top \delta \mathbf{W}_{[t]} ( \exp \boldsymbol{\Gamma}_{[t]} \odot \mathbf{K}_{[t]} )^\top ) \\ \\
\delta \boldsymbol{\beta}_{[t]} &= \text{diag}(\delta \text{Diag}(\boldsymbol{\beta}_{[t]})) \\
&= \text{diag}\left( \widetilde{\mathbf{A}}_{[t]}^\top \delta \widetilde{\mathbf{A}}_{[t]} \right) \odot \boldsymbol{\beta}_{[t]}^{-1} + \text{diag}\left( \left( \delta \widetilde{\mathbf{X}}_{[t]} \odot \mathbf{M}_{-1} \right) \left( \overrightarrow{\mathbf{K}_{[t]}} \overleftarrow{\mathbf{K}_{[t]}}^\top \right) \right) \\ \\
\delta \mathbf{A}_{[t]} &= \left( \delta \mathbf{U}_{[t]} \mathbf{V}_{[t]}^\top + \delta \mathbf{W}_{[t]} (\exp \boldsymbol{\Gamma}_{[t]} \odot \mathbf{K}_{[t]})^\top \right) \text{Diag}(\boldsymbol{\beta}_{[t]}) \\ \\
\delta \mathbf{A_X}_{[t]} &= - \mathbf{A}_{[t]}^\top \delta \mathbf{A}_{[t]} \mathbf{A}_{[t]}^\top \odot \mathbf{M}_{-1}
\end{aligned}\]

Corresponding code: fla.ops.kda.chunk_bwd -> chunk_kda_bwd_wy_dqkg_fused
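
The last line, \(\delta \mathbf{A_X} = -\mathbf{A}^\top \delta \mathbf{A}\, \mathbf{A}^\top \odot \mathbf{M}_{-1}\), is the standard gradient of a matrix inverse restricted to the strictly lower triangle. Below is a small NumPy finite-difference check with random data; X stands for \(\mathbf{A_{kk0}}\), and the helper numerical_grad is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 5
strict = np.tril(np.ones((C, C)), -1)        # mask M_{-1}
X = 0.5 * rng.normal(size=(C, C)) * strict   # plays the role of A_kk0
dA = rng.normal(size=(C, C))                 # upstream gradient w.r.t. A


def loss(X):
    # Only the strictly lower triangle of X enters A = (I + X)^{-1}.
    A = np.linalg.inv(np.eye(C) + X * strict)
    return np.sum(dA * A)


def numerical_grad(f, x, eps=1e-6):
    """Central finite differences, one entry at a time."""
    g = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        g[idx] = (f(xp) - f(xm)) / (2 * eps)
    return g


A = np.linalg.inv(np.eye(C) + X)
dX = -(A.T @ dA @ A.T) * strict              # formula from this step
dX_num = numerical_grad(loss, X)
```

The analytic and numerical gradients should match up to finite-difference error, and both vanish on and above the diagonal.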

Step 6: Intra-Chunk Gradients

\[\begin{aligned}
\delta \mathbf{Q}_{[t],part2} &= \delta \mathbf{A_{qk}}_{[t]} \left( \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \odot \mathbf{K}_{[t]} \right) \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top}) \\ \\
\delta \mathbf{K}_{[t],part2} &= \text{Diag}(\boldsymbol{\beta}_{[t]}) \delta \mathbf{A_X}_{[t]} \left( \mathbf{K}_{[t]} \odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \right) \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top}) \\
&+ \delta \mathbf{A_{qk}}_{[t]}^\top \left( \mathbf{Q}_{[t]} \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})\right) \odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \\
&+ \delta \mathbf{A_X}_{[t]}^\top \text{Diag}(\boldsymbol{\beta}_{[t]}) \left( \mathbf{K}_{[t]} \odot \exp (\boldsymbol{\Gamma}_{[t]} - \boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top})\right) \odot \exp (\boldsymbol{1} \boldsymbol{\gamma}_{[t]}^{C\top} - \boldsymbol{\Gamma}_{[t]}) \\
&= \left.\delta \mathbf{K}_{[t]}\right|_{\leftarrow \text{from } \mathbf{A}_{[t]} } + \left( \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } + \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from }\mathbf{A}_{[t]} } \right) \\ \\
\delta \mathbf{\Gamma}_{[t],part2} &= \delta \exp \mathbf{\Gamma}_{[t]} \odot \exp \mathbf{\Gamma}_{[t]} \\
&= \delta \mathbf{Q}_{[t],part2} \odot \mathbf{Q}_{[t]} + \left( \left.\delta \mathbf{K}_{[t]}\right|_{\leftarrow \text{from } \mathbf{A}_{[t]} } - \left( \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from } \mathbf{O}_{[t]} \text{ w/o } \mathbf{A}_{[t]} } + \left.\delta \mathbf{K}_{[t]}\right|_{\rightarrow \text{from }\mathbf{A}_{[t]} } \right) \right) \odot \mathbf{K}_{[t]}
\end{aligned}\]

Corresponding code: fla.ops.kda.chunk_intra -> chunk_kda_bwd_intra -> chunk_kda_bwd_kernel_intra

Comment: In chunk_kda_bwd_kernel_intra, the diagonal blocks of the secondary chunking are optimized with the SAFE_GATE mechanism. For causal pairs inside a block, \(\exp(g_i - g_j)\) itself is at most 1, but a matrix multiplication needs it as a product of separately exponentiated factors, and those factors can overflow in FP16. Without a fix, this typically forces a fallback to slow, token-by-token vector multiplications. SAFE_GATE resolves this by gathering a local reference point to normalize the exponents (acting as a local max trick). With the exponents stabilized, the sequential vector loops can be replaced with hardware-accelerated tl.dot, which is considerably faster while preserving numerical precision.
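
The overflow and the effect of a reference point can be reproduced with NumPy's float16. This is only an illustration of the idea described in the comment, under my own assumptions (a constant log-gate of -0.5 per token, a sub-chunk of 16 tokens, the first token of the sub-chunk as the reference); it does not reproduce how the Triton kernel actually picks or applies its reference.

```python
import numpy as np

C, sub = 64, 16
g = np.cumsum(np.full(C, -0.5))      # cumulative log-gate: -0.5, -1.0, ..., -32.0
i, j = 63, 60                        # a causal pair inside the last sub-chunk
exact = np.exp(g[i] - g[j])          # exp(-1.5), computed in float64

g16 = g.astype(np.float16)           # these values are exactly representable

with np.errstate(all="ignore"):
    # Naive factorization: exp(g_i) underflows and exp(-g_j) overflows in FP16,
    # so the product is not a finite number.
    naive = np.exp(g16[i]) * np.exp(-g16[j])

# Shift both exponents by a reference taken from the same sub-chunk.
ref = g16[(i // sub) * sub]
safe = np.exp(g16[i] - ref) * np.exp(ref - g16[j])
```

Both shifted exponents stay within the sub-chunk's gate range (at most 8 in magnitude here), so each factor fits comfortably in FP16 and the product matches the exact value up to half-precision rounding.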

Step 7: Recover Gradients for the Log Gate

\[\begin{aligned} \delta \log \boldsymbol{\alpha}_{[t]} &= \text{suffix\_cumsum}(\delta \boldsymbol{\gamma}_{[t]}) \end{aligned}\]

Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
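
The suffix sum appears because the backward pass of a cumulative sum is its adjoint: token i contributes to every \(\boldsymbol{\gamma}^r\) with r >= i inside its chunk. A minimal NumPy sketch follows; the helper chunk_local_cumsum_ref and its reverse flag are my own reference, not the FLA signature.

```python
import numpy as np


def chunk_local_cumsum_ref(x, chunk_size, reverse=False):
    """Chunk-local prefix sum, or suffix sum when reverse=True. x: [T, d]."""
    T, d = x.shape
    c = x.reshape(T // chunk_size, chunk_size, d)
    if reverse:
        # Flip, take the prefix sum, flip back: a suffix sum within each chunk.
        out = np.flip(np.cumsum(np.flip(c, axis=1), axis=1), axis=1)
    else:
        out = np.cumsum(c, axis=1)
    return out.reshape(T, d)


# Concrete case: chunks [1, 2] and [3, 4] give suffix sums [3, 2] and [7, 4].
small = chunk_local_cumsum_ref(np.array([[1.0], [2.0], [3.0], [4.0]]), 2, reverse=True)

# Adjoint identity: <cumsum(x), dgamma> == <x, suffix_cumsum(dgamma)>.
rng = np.random.default_rng(0)
log_alpha = rng.normal(size=(8, 3))
dgamma = rng.normal(size=(8, 3))
lhs = np.sum(chunk_local_cumsum_ref(log_alpha, 4) * dgamma)
rhs = np.sum(log_alpha * chunk_local_cumsum_ref(dgamma, 4, reverse=True))
```

The two inner products agree, which is exactly the statement that the suffix sum of \(\delta \boldsymbol{\gamma}\) is the gradient with respect to \(\log \boldsymbol{\alpha}\).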
