Reading Notes: Implementation Notes for GLA in FLA
Paper: https://arxiv.org/pdf/2312.06635
Code: https://github.com/fla-org/flash-linear-attention
Disclaimer: These notes provide a step-by-step mathematical reading of the Triton implementation of GLA in the flash-linear-attention repository, and are intended to be read together with the previous note on the GLA paper.
1. Forward
Entry function signature:
```python
def chunk_gla_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: torch.Tensor | None,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    chunk_indices: torch.LongTensor | None = None,
):
    ...
    return g_cumsum, A, h, ht, o
```
Step 1: Compute the cumulative gate in the log domain
Notation note: In this note, gamma and Gamma correspond to log gamma and log Gamma in the previous GLA note. Therefore, all formulas involving gamma and Gamma below are written in the log domain.
Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
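To make the step concrete, here is a plain-NumPy reference sketch of a chunk-local cumulative sum over log-domain gates (the `_ref` name and signature are mine for illustration, not the repository's Triton kernel, which operates per batch/head):

```python
import numpy as np

def chunk_local_cumsum_ref(g, chunk_size):
    """Cumulative sum of the log-domain gates, restarting at every chunk
    boundary. g: (T, K); T is assumed divisible by chunk_size."""
    T, K = g.shape
    blocks = g.reshape(T // chunk_size, chunk_size, K)
    return blocks.cumsum(axis=1).reshape(T, K)
```

Restarting the cumsum at each chunk boundary keeps the exponents that later steps feed into exp() bounded by a chunk's worth of log-decay.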
Step 2: Recursively compute the hidden state
Notation note: In this note, S_[t]^C corresponds to S_[t]^{C\top} in the previous GLA note. The same convention is used throughout.
Corresponding code: fla.ops.common.chunk_h -> chunk_fwd_h
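A dependency-light NumPy sketch of the chunkwise state recurrence (my reconstruction of the math; the repository works on batched torch tensors in Triton):

```python
import numpy as np

def chunk_fwd_h_ref(k, v, g_cumsum, chunk_size, h0=None):
    """Chunkwise recurrence for the hidden state S (K x V).

    k: (T, K), v: (T, V); g_cumsum: (T, K) chunk-local log-gate cumsum.
    Returns the state entering each chunk, plus the final state.
    """
    T, K = k.shape
    V = v.shape[1]
    S = np.zeros((K, V)) if h0 is None else h0.astype(float).copy()
    h = []
    for t0 in range(0, T, chunk_size):
        h.append(S.copy())                           # state entering this chunk
        b = g_cumsum[t0:t0 + chunk_size]             # (C, K)
        bC = b[-1]                                   # total log-decay of the chunk
        kd = k[t0:t0 + chunk_size] * np.exp(bC - b)  # decay-weighted keys
        S = np.exp(bC)[:, None] * S + kd.T @ v[t0:t0 + chunk_size]
    return np.stack(h), S
```

Since gates lie in (0, 1], the log-gates are non-positive and bC - b stays non-positive, so every exp() argument here is bounded above by zero.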
Step 3: Compute the intra-chunk attention matrix A for off-diagonal sub-blocks
In this step, we only consider the [t]-th chunk. Let A_[i,j] denote the (i,j)-th sub-block, let Gamma_[i] denote the cumsum values of the i-th sub-chunk, and let gamma_[i]^k denote the k-th element in the i-th sub-chunk.
Corresponding code: fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_inter
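The sub-block decomposition can be sketched in NumPy as follows (illustrative `_ref` naming; anchoring the exponents at the sub-chunk boundary Gamma_[i] is my reading of how the kernel keeps exp() numerically bounded):

```python
import numpy as np

def A_sub_inter_ref(q, k, b, scale, sub=4):
    """Off-diagonal sub-blocks A[i, j] (row sub-chunk i > column sub-chunk j)
    of one chunk's intra-chunk attention matrix.

    q, k: (C, K) for one chunk; b: (C, K) chunk-local log-gate cumsum.
    """
    C, K = q.shape
    A = np.zeros((C, C))
    for i in range( 0, C, sub):
        ref = b[i]                                   # Gamma anchor of sub-chunk i
        qi = q[i:i + sub] * np.exp(b[i:i + sub] - ref)
        for j in range(0, i, sub):
            kj = k[j:j + sub] * np.exp(ref - b[j:j + sub])
            A[i:i + sub, j:j + sub] = scale * (qi @ kj.T)
    return A
```

Each off-diagonal sub-block is a dense matmul: the relative decay exp(b_i - b_j) factorizes through the anchor, so no per-element exponent is needed.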
Step 4: Compute the intra-chunk attention matrix A for diagonal sub-blocks
Again, we only consider the [t]-th chunk, and the notation is the same as in Step 3.
On diagonal sub-blocks, A is computed column by column in the implementation. For simplicity, however, we write it here in elementwise form.
Corresponding code: fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra
If the computation is split along the d dimension and then merged, the following two functions are used. Their mathematical content is essentially the same as above:
Corresponding code: fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra_split
Corresponding code: fla.ops.gla.chunk -> chunk_gla_fwd_intra_gk -> chunk_gla_fwd_A_kernel_intra_sub_intra_merge
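In elementwise form, the diagonal sub-blocks can be sketched as (illustrative `_ref` naming; the buffer starts at zero here for clarity, whereas the real kernel writes into a torch.empty buffer, which is why Step 5 re-masks):

```python
import numpy as np

def A_sub_intra_ref(q, k, b, scale, sub=4):
    """Diagonal sub-blocks of one chunk's attention matrix, elementwise
    (the kernel computes them column by column).

    q, k: (C, K) for one chunk; b: (C, K) chunk-local log-gate cumsum.
    Only causal entries j <= i within each sub-block are written.
    """
    C, K = q.shape
    A = np.zeros((C, C))
    for i0 in range(0, C, sub):
        for r in range(sub):
            for c in range(r + 1):                   # j <= i inside the sub-block
                i, j = i0 + r, i0 + c
                A[i, j] = scale * np.sum(q[i] * k[j] * np.exp(b[i] - b[j]))
    return A
```

Together with the off-diagonal sub-blocks of Step 3, this fills the whole lower triangle of the chunk's A.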
Step 5: Compute the final output
Since A is initialized with torch.empty in chunk_gla_fwd_intra_gk, the upper-triangular part may contain garbage values. Therefore, a causal mask needs to be applied once again when computing the final output.
Corresponding code: fla.ops.gla.chunk -> chunk_gla_fwd_o_gk
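A NumPy sketch of the output step, including the re-applied causal mask (my reconstruction; `_ref` names and the exact argument layout are illustrative):

```python
import numpy as np

def chunk_fwd_o_ref(q, v, A, h, g_cumsum, chunk_size, scale):
    """Final output: masked intra-chunk term A @ v plus inter-chunk term
    scale * (q * exp(b)) @ h.

    A: (T, T), only the block-diagonal chunks are read; h: (T//C, K, V) is
    the state entering each chunk. The causal mask is applied again here
    because A's strict upper triangle may hold garbage from torch.empty.
    """
    T, K = q.shape
    V = v.shape[1]
    o = np.zeros((T, V))
    mask = np.tril(np.ones((chunk_size, chunk_size), dtype=bool))
    for c, t0 in enumerate(range(0, T, chunk_size)):
        sl = slice(t0, t0 + chunk_size)
        b = g_cumsum[sl]
        Ac = np.where(mask, A[sl, sl], 0.0)          # drop garbage entries
        o[sl] = Ac @ v[sl] + scale * (q[sl] * np.exp(b)) @ h[c]
    return o
```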
2. Backward
Entry function signature:
```python
def chunk_gla_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    g_cumsum: torch.Tensor | None,
    scale: float,
    initial_state: torch.Tensor,
    h: torch.Tensor,
    A: torch.Tensor,
    do: torch.Tensor,
    dht: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    chunk_indices: torch.LongTensor | None = None,
):
    ...
    return dq, dk, dv, dg, dh0
```
Step 1: Compute the cumulative gate in the log domain
This is the same as Forward Step 1.
Corresponding code: fla.ops.utils.cumsum -> chunk_local_cumsum
Step 2: Recompute the hidden state recursively
This is the same as Forward Step 2.
Corresponding code: fla.ops.common.chunk_h -> chunk_fwd_h
Step 3: Recursively compute the gradient of the hidden state
Corresponding code: fla.ops.common.chunk_h -> chunk_bwd_dh
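A NumPy sketch of the reverse recurrence (my reconstruction; whether dh[c] stores the gradient of the state entering or exiting chunk c is a convention I am assuming here, and the `_ref` name is illustrative):

```python
import numpy as np

def chunk_bwd_dh_ref(q, g_cumsum, do, dht, chunk_size, scale):
    """Reverse-order recurrence for the hidden-state gradient.

    Returns dh with dh[c] = gradient w.r.t. the state *exiting* chunk c,
    plus dh0, the gradient w.r.t. the initial state. The gradient entering
    a chunk is the decayed next-state gradient plus the output term
    (scale * q * exp(b))^T @ do.
    """
    T, K = q.shape
    V = do.shape[1]
    N = T // chunk_size
    dS = dht.astype(float).copy()
    dh = np.zeros((N, K, V))
    for c in range(N - 1, -1, -1):
        t0 = c * chunk_size
        b = g_cumsum[t0:t0 + chunk_size]
        dh[c] = dS                                   # grad of the state exiting chunk c
        qg = scale * q[t0:t0 + chunk_size] * np.exp(b)
        dS = np.exp(b[-1])[:, None] * dS + qg.T @ do[t0:t0 + chunk_size]
    return dh, dS                                    # dS is now dh0
```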
Step 4: Compute the gradient with respect to V
Corresponding code: fla.ops.gla.chunk -> chunk_gla_bwd_dv
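In sketch form, dV has an intra-chunk term through A and an inter-chunk term through the state update (my reconstruction with illustrative `_ref` naming; dh[c] is taken as the gradient w.r.t. the state exiting chunk c, an assumed convention):

```python
import numpy as np

def chunk_gla_bwd_dv_ref(k, A, do, dh, g_cumsum, chunk_size):
    """Gradient w.r.t. V.

    k: (T, K); A: (T, T) intra-chunk attention (upper triangle may be
    garbage, so it is masked again); do: (T, V); dh: (T//C, K, V).
    """
    T, K = k.shape
    V = do.shape[1]
    dv = np.zeros((T, V))
    mask = np.tril(np.ones((chunk_size, chunk_size), dtype=bool))
    for c, t0 in enumerate(range(0, T, chunk_size)):
        sl = slice(t0, t0 + chunk_size)
        b = g_cumsum[sl]
        bC = b[-1]
        Ac = np.where(mask, A[sl, sl], 0.0)
        # intra: v_j feeds o_i through A[i, j] for i >= j
        dv[sl] = Ac.T @ do[sl]
        # inter: v_j feeds the next state through the decay-weighted key
        dv[sl] += (k[sl] * np.exp(bC - b)) @ dh[c]
    return dv
```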
Step 5: Compute the gradient with respect to A
Corresponding code: fla.ops.gla.chunk -> chunk_gla_bwd_dA
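Since the intra-chunk part of the output is o = A @ v per chunk, dA is simply do @ v^T, restricted to the causal lower triangle that the forward pass actually used. A sketch (illustrative `_ref` naming):

```python
import numpy as np

def chunk_gla_bwd_dA_ref(v, do, chunk_size):
    """Gradient w.r.t. the intra-chunk attention matrix A, masked to the
    causal lower triangle of each block-diagonal chunk."""
    T, V = v.shape
    dA = np.zeros((T, T))
    mask = np.tril(np.ones((chunk_size, chunk_size), dtype=bool))
    for t0 in range(0, T, chunk_size):
        sl = slice(t0, t0 + chunk_size)
        dA[sl, sl] = np.where(mask, do[sl] @ v[sl].T, 0.0)
    return dA
```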
Step 6: Compute the intra-chunk gradients of Q and K
In this step, we only consider the [t]-th chunk. Let A_[i,j] denote the (i,j)-th sub-block, let gamma_[i]^k denote the cumsum value of the k-th element in the i-th sub-chunk, and let gamma_[i]^{C_i} denote the cumsum value of the last element in the i-th sub-chunk.
Intra-chunk contribution to dQ:
Intra-chunk contribution to dK:
Corresponding code: fla.ops.gla.chunk -> chunk_gla_bwd_dqk_intra
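In elementwise form, the intra-chunk contributions can be sketched as follows (my reconstruction; the `_ref` name and single-chunk signature are illustrative, and the real kernel vectorizes over sub-blocks rather than looping):

```python
import numpy as np

def chunk_gla_bwd_dqk_intra_ref(q, k, dA, b, scale):
    """Intra-chunk contributions to dQ and dK for one chunk.

    q, k: (C, K); b: (C, K) chunk-local log-gate cumsum; dA: (C, C) masked.
    dq_i = scale * sum_{j<=i} dA[i, j] * k_j * exp(b_i - b_j)
    dk_j = scale * sum_{i>=j} dA[i, j] * q_i * exp(b_i - b_j)
    """
    C, K = q.shape
    dq = np.zeros((C, K))
    dk = np.zeros((C, K))
    for i in range(C):
        for j in range(i + 1):
            w = dA[i, j] * np.exp(b[i] - b[j])
            dq[i] += scale * w * k[j]
            dk[j] += scale * w * q[i]
    return dq, dk
```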
Step 7: Compute the inter-chunk gradients and merge them to obtain the final dQ, dK, and dg
In this step, the inter-chunk contributions are combined with the intra-chunk contributions from Step 6, and the gradient of the gate is also computed.
Corresponding code: fla.ops.gla.chunk -> chunk_gla_bwd_dqkg
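For a single chunk, the merge and the gate gradient can be sketched as follows (my reconstruction under assumed conventions: b is the chunk-local log-gate cumsum, h_in the state entering the chunk, dh_out the gradient of the state exiting it; because b_i = sum_{t<=i} g_t, dg is a reversed cumulative sum of the per-position db; the `_ref` name is illustrative):

```python
import numpy as np

def chunk_gla_bwd_dqkg_ref(q, k, v, do, dA, b, h_in, dh_out, scale):
    """Merge intra- and inter-chunk gradients for one chunk and form dg."""
    C, K = q.shape
    bC = b[-1]
    # intra-chunk parts (same math as Step 6)
    dq = np.zeros((C, K))
    dk = np.zeros((C, K))
    for i in range(C):
        for j in range(i + 1):
            w = dA[i, j] * np.exp(b[i] - b[j])
            dq[i] += scale * w * k[j]
            dk[j] += scale * w * q[i]
    # inter-chunk parts: o_i += scale*(q_i*e^{b_i}) @ h_in and
    # h_out += sum_j (k_j*e^{bC-b_j}) outer v_j
    dq += scale * np.exp(b) * (do @ h_in.T)
    dk_inter = np.exp(bC - b) * (v @ dh_out.T)
    dk += dk_inter
    # gate gradient: d/d(b_i) = q_i*dq_i - k_i*dk_i elementwise, with the
    # chunk-boundary decay terms (through bC) landing on the last position
    db = q * dq - k * dk
    db[-1] += np.exp(bC) * (dh_out * h_in).sum(axis=1) + (k * dk_inter).sum(axis=0)
    # b_i = sum_{t<=i} g_t, so dg is a reversed cumulative sum of db
    dg = np.cumsum(db[::-1], axis=0)[::-1].copy()
    return dq, dk, dg
```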