Reading Notes: Implementation Notes for DeltaNet in FLA¶
Paper: https://arxiv.org/abs/2406.06484
Code: https://github.com/fla-org/flash-linear-attention
Disclaimer: These notes provide a step-by-step mathematical reading of the Triton implementation of DeltaNet in the flash-linear-attention repository, and are intended to be read together with the previous note on the DeltaNet paper.
1. Forward¶
Entry function:
```python
def chunk_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
):
    ...
    return o, A, final_state
```
Step 1: Compute the \(KK^\top\) matrix¶
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> chunk_scaled_dot_kkt_fwd
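The matrix built here is the strictly lower-triangular part of \(\mathrm{diag}(\beta)KK^\top\), computed chunk by chunk. A naive single-chunk PyTorch sketch (function and variable names are illustrative, not the Triton kernel's API):

```python
import torch

def scaled_dot_kkt(k: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # One chunk: A[i, j] = beta[i] * <k_i, k_j> for i > j, zero elsewhere.
    # Its unit lower-triangular extension I + A is inverted in the next step.
    a = (beta[:, None] * k) @ k.T    # [C, C]
    return a.tril(-1)                # keep the strictly lower part only
```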
Step 2: Invert the unit lower-triangular matrix¶
In the code, this step is implemented using row-wise forward substitution combined with block matrix inversion: the 16×16 diagonal sub-blocks are inverted by forward substitution, and the results are assembled blockwise. Here, A_i corresponds to X^{-1} in the previous note.
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> solve_tril
Comment: a Neumann-series expansion combined with recursive doubling can be faster, but its numerical precision is worse.
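The row-wise forward substitution can be sketched as follows (a naive reference for one chunk, not the blocked Triton kernel): inverting \(I + A\) for strictly lower-triangular \(A\), where row i of the inverse depends only on rows before it.

```python
import torch

def inv_unit_lower(a: torch.Tensor) -> torch.Tensor:
    # Invert I + A, with A strictly lower-triangular, by row-wise forward
    # substitution; the identity is added back at the end.
    c = a.shape[0]
    t = -a.clone()
    for i in range(1, c):
        t[i, :i] += t[i, :i] @ t[:i, :i]
    return t + torch.eye(c, dtype=a.dtype)
```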
Step 3: Compute U and W¶
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> recompute_w_u_fwd
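With the inverse \(T = A_i\) from the previous step, W and U are plain matmuls against the β-scaled keys and values. A single-chunk sketch (illustrative names):

```python
import torch

def compute_w_u(a_inv: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                beta: torch.Tensor):
    # W = T @ diag(beta) @ K and U = T @ diag(beta) @ V, where
    # T = (I + tril(diag(beta) K K^T, -1))^{-1} from the previous step.
    w = a_inv @ (beta[:, None] * k)
    u = a_inv @ (beta[:, None] * v)
    return w, u
```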
Step 4: Recursively compute the hidden state¶
Notation note: here, `S_[t]^C` uses the `[d_k, d_v]` layout, so it is transposed relative to the derivation in the previous paper note.
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
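In the `[d_k, d_v]` layout the per-chunk recurrence is \(S_{[t+1]} = S_{[t]} + K^\top (U - W S_{[t]})\). A naive sketch, which the test below checks against the token-level delta rule (illustrative names):

```python
import torch

def chunk_state_update(s: torch.Tensor, k: torch.Tensor,
                       w: torch.Tensor, u: torch.Tensor):
    # S is in [d_k, d_v] layout. The chunk-local pseudo-values are
    # u_local = U - W @ S, and the next state is S + K^T @ u_local.
    u_local = u - w @ s            # [C, d_v]
    return s + k.T @ u_local, u_local
```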
Step 5: Compute the final output¶
Corresponding code: fla.ops.common.chunk_o -> chunk_fwd_o
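The output combines an inter-chunk term \(QS\) with causal intra-chunk attention over the pseudo-values. A single-chunk sketch (illustrative names; `s` is the state at the chunk start):

```python
import torch

def chunk_output(q: torch.Tensor, k: torch.Tensor, s: torch.Tensor,
                 u_local: torch.Tensor, scale: float) -> torch.Tensor:
    # o = scale * Q @ s + tril(scale * Q K^T) @ u_local, where
    # u_local = U - W @ s are the chunk-local pseudo-values.
    attn = (q @ k.T).tril() * scale   # causal mask within the chunk
    return scale * (q @ s) + attn @ u_local
```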
2. Backward¶
Entry function:
```python
def chunk_delta_rule_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
):
    ...
    return dq, dk, dv, db, dh0
```
Step 1: Recompute U and W¶
A_i has already been stored in the forward pass, while U and W need to be recomputed.
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_fwd -> recompute_w_u_fwd
Step 2: Recompute the hidden state¶
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_fwd_h
Step 3: Compute the intra-chunk part of \(\delta U\)¶
Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dv_local
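Since \(o\) depends on the pseudo-values through the intra-chunk attention \(\mathrm{tril}(\text{scale}\,QK^\top)\), the intra-chunk part of \(\delta U\) is its transpose applied to \(\delta o\). A single-chunk sketch (illustrative names):

```python
import torch

def chunk_dv_local(q: torch.Tensor, k: torch.Tensor, do: torch.Tensor,
                   scale: float) -> torch.Tensor:
    # Intra-chunk contribution to dU: the transpose of the causal
    # intra-chunk attention matrix applied to the output gradient.
    attn = (q @ k.T).tril() * scale
    return attn.T @ do
```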
Step 4: Recursively compute \(\delta U\) and \(\delta S\)¶
Corresponding code: fla.ops.common.chunk_delta_h -> chunk_gated_delta_rule_bwd_dhu
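From \(S_{[t+1]} = S_{[t]} + K^\top(U - WS_{[t]})\) and \(o = \text{scale}\,QS_{[t]} + \text{intra part}\), the reverse-time step for one chunk is \(\delta U_t = \delta v_{\text{intra}} + K\,\delta S_{t+1}\) and \(\delta S_t = \delta S_{t+1} + \text{scale}\,Q^\top \delta o - W^\top \delta U_t\). A naive sketch (illustrative names; `dv_intra` is the intra-chunk contribution from the previous step):

```python
import torch

def chunk_bwd_du_ds(ds_next, q, k, w, do, dv_intra, scale):
    # Reverse-time step for one chunk, with S in [d_k, d_v] layout.
    du = dv_intra + k @ ds_next                    # total dU for this chunk
    ds = ds_next + scale * q.T @ do - w.T @ du     # dS at the chunk start
    return du, ds
```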
Step 5: Compute \(\delta Q\), \(\delta K_{\text{part1}}\), and \(\delta W\)¶
Corresponding code: fla.ops.common.chunk_o -> chunk_bwd_dqkwg
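These three gradients all flow through the output and the state update of one chunk: \(\delta Q\) comes from \(o\), the first part of \(\delta K\) from the intra-chunk attention plus the \(K^\top(U - WS)\) term, and \(\delta W\) from \(u_{\text{local}} = U - WS\). A naive sketch (illustrative names; `du` is the total \(\delta U\) from the previous step, `s` the state at the chunk start):

```python
import torch

def chunk_bwd_dqkw(q, k, s, ds_next, u_local, do, du, scale):
    # dattn: gradient of the masked intra-chunk attention logits.
    dattn = (do @ u_local.T).tril()
    dq = scale * (do @ s.T + dattn @ k)
    # Part 1 of dK only; the WY-representation part comes in the next step.
    dk = scale * dattn.T @ q + u_local @ ds_next.T
    dw = -du @ s.T                 # from u_local = U - W @ s
    return dq, dk, dw
```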
Step 6: Compute \(\delta V\), \(\delta K_{\text{part2}}\), and \(\delta \beta\)¶
In this step, the backward pass through the WY representation computes \(\delta V\), the remaining part of \(\delta K\), and \(\delta \beta\).
Corresponding code: fla.ops.delta_rule.wy_fast -> prepare_wy_repr_bwd
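This step backpropagates through \(W = T K_\beta\), \(U = T V_\beta\) and the triangular inverse \(T = (I + \mathrm{tril}(K_\beta K^\top, -1))^{-1}\), using \(\delta(M^{-1})\Rightarrow \delta M = -T^\top\,\delta T\,T^\top\). A naive single-chunk reference (illustrative names; the kernel works blockwise instead of materializing the inverse):

```python
import torch

def wy_repr_bwd(k, v, beta, dw, du):
    # Backprop through W = T Kb, U = T Vb, with Kb = diag(beta) K,
    # Vb = diag(beta) V and T = (I + tril(Kb K^T, -1))^{-1}.
    C = k.shape[0]
    kb = beta[:, None] * k
    a = (kb @ k.T).tril(-1)
    t = torch.linalg.inv(torch.eye(C, dtype=k.dtype) + a)
    dt = dw @ kb.T + du @ (beta[:, None] * v).T    # gradient w.r.t. T
    da = -(t.T @ dt @ t.T).tril(-1)                # through the inverse
    dkb = t.T @ dw + da @ k
    dvb = t.T @ du
    dk = beta[:, None] * dkb + da.T @ kb           # remaining part of dK
    dv = beta[:, None] * dvb
    dbeta = (dkb * k).sum(-1) + (dvb * v).sum(-1)
    return dk, dv, dbeta
```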