Reading Notes: RetNet
Paper: https://arxiv.org/abs/2307.08621
Disclaimer: These are personal reading notes. Some derivations are my own and may be incorrect, so please let me know if you spot any mistakes.
1. Motivation

2. The Retention Mechanism
2.1 Architectural Motivation
The starting point is a recurrent state-space-like formulation:
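As I recall from the paper, this formulation is

$$
s_n = A\, s_{n-1} + K_n^\top v_n,
\qquad
o_n = Q_n s_n = \sum_{m=1}^{n} Q_n\, A^{n-m} K_m^\top v_m,
$$

where $s_n$ is the hidden state and $Q_n, K_n, v_n$ are the query, key, and value at step $n$.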
The paper then diagonalizes the transition matrix A:
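My reconstruction of the diagonalization step:

$$
A = \Lambda\,(\gamma e^{i\theta})\,\Lambda^{-1}
\quad\Longrightarrow\quad
A^{n-m} = \Lambda\,(\gamma e^{i\theta})^{n-m}\,\Lambda^{-1},
$$

with $\gamma, \theta$ vectors and $\gamma e^{i\theta}$ read element-wise as a diagonal matrix.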
By absorbing the change-of-basis matrices into W_Q and W_K, the output can be rewritten into a much cleaner form:
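That cleaner form, as I remember it:

$$
o_n = \sum_{m=1}^{n} Q_n\,(\gamma e^{i\theta})^{n-m} K_m^{\dagger} v_m.
$$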
If we further assume that gamma is a scalar, the expression becomes even simpler:
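If I remember the paper correctly:

$$
o_n = \sum_{m=1}^{n} \gamma^{\,n-m}\,\big(Q_n e^{in\theta}\big)\big(K_m e^{im\theta}\big)^{\dagger} v_m.
$$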
This form is much easier to parallelize.
Comment: This derivation seems to rely on several hidden assumptions:
A needs to be diagonalizable; in an RNN-style setting its eigenvalues usually need magnitude smaller than 1 for stability; and the transition must be time-invariant. There may be other assumptions hiding in the background as well.
2.2 Recurrent Mode
In recurrent mode, the retention mechanism is written as:
This is the form used during autoregressive inference.
The state is updated step by step, which makes decoding efficient.
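As I recall, the recurrent form is

$$
S_n = \gamma\, S_{n-1} + K_n^\top V_n,
\qquad
\mathrm{Retention}(X_n) = Q_n S_n.
$$

A minimal real-valued numpy sketch of this update (the complex rotation $\Theta$ is dropped for simplicity, and `recurrent_retention` is my own name, not from any released code):

```python
import numpy as np

def recurrent_retention(Q, K, V, gamma):
    """Single-head retention, recurrent mode (real-valued sketch).

    Q, K: (T, d_k); V: (T, d_v); gamma: scalar decay in (0, 1).
    """
    T, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))                   # recurrent state, S_0 = 0
    out = np.empty((T, d_v))
    for n in range(T):
        S = gamma * S + np.outer(K[n], V[n])   # S_n = gamma * S_{n-1} + k_n^T v_n
        out[n] = Q[n] @ S                      # o_n = q_n S_n
    return out
```

Each step touches only the fixed-size state $S$, which is why decoding costs $O(1)$ per token regardless of sequence length.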
2.3 Parallel Mode
For training, the same mechanism can be rewritten into a fully parallel form:
where $\overline{\Theta}$ denotes the complex conjugate of $\Theta$.
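As I recall, the parallel form is

$$
Q = (XW_Q)\odot\Theta,\quad
K = (XW_K)\odot\overline{\Theta},\quad
V = XW_V,\quad
\Theta_n = e^{in\theta},
$$

$$
\mathrm{Retention}(X) = \big(QK^\top \odot D\big)V,
\qquad
D_{nm} =
\begin{cases}
\gamma^{\,n-m}, & n \ge m,\\
0, & n < m.
\end{cases}
$$

A real-valued numpy sketch (rotation $\Theta$ omitted; `parallel_retention` is my own name):

```python
import numpy as np

def parallel_retention(Q, K, V, gamma):
    """Single-head retention, parallel mode (real-valued sketch).

    Equivalent to the recurrent update, but computed for all T
    positions at once via the causal decay mask D.
    """
    T = Q.shape[0]
    idx = np.arange(T)
    # D[n, m] = gamma^(n-m) for n >= m, 0 above the diagonal
    D = np.where(idx[:, None] >= idx[None, :],
                 gamma ** (idx[:, None] - idx[None, :]), 0.0)
    return (Q @ K.T * D) @ V
```

On random inputs this matches the step-by-step recurrent computation to numerical precision, which is the equivalence the paper exploits.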
2.4 Chunkwise Recurrent Mode
RetNet also supports a chunkwise recurrent formulation, which sits between the fully parallel and fully recurrent views.
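A sketch of how I understand the chunkwise form (real-valued, rotation omitted; the function name is mine): each chunk is computed in parallel internally, while a recurrent state $R$ carries information across chunk boundaries.

```python
import numpy as np

def chunkwise_retention(Q, K, V, gamma, B):
    """Single-head retention, chunkwise recurrent mode (real-valued sketch).

    Assumes the sequence length T is divisible by the chunk size B.
    """
    T, d_k = Q.shape
    d_v = V.shape[1]
    assert T % B == 0, "sketch assumes T divisible by chunk size"
    j = np.arange(B)
    D = np.where(j[:, None] >= j[None, :],
                 gamma ** (j[:, None] - j[None, :]), 0.0)   # intra-chunk decay
    cross_decay = gamma ** (j + 1)       # weight of the carried-in state at local pos j
    state_decay = gamma ** (B - 1 - j)   # weight of k_j^T v_j in the state update
    R = np.zeros((d_k, d_v))             # cross-chunk recurrent state
    out = np.empty((T, d_v))
    for s in range(0, T, B):
        q, k, v = Q[s:s+B], K[s:s+B], V[s:s+B]
        # inner-chunk parallel part + decayed contribution from earlier chunks
        out[s:s+B] = (q @ k.T * D) @ v + cross_decay[:, None] * (q @ R)
        # fold this chunk into the state for the next one
        R = gamma ** B * R + (state_decay[:, None] * k).T @ v
    return out
```

With chunk size $B$ the within-chunk work is quadratic only in $B$, so this interpolates between the fully parallel and fully recurrent extremes.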

3. Model Design
3.1 Multi-Head Design
The model uses a multi-head setup, where the number of heads is $h = d_{\text{model}} / d$, with $d$ being the per-head hidden dimension.
This is conceptually similar to multi-head attention, except each head uses a different retention scale.
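If I read the paper correctly, the per-head decay rates $\gamma$ are fixed (not learned) and spaced so that different heads remember over different horizons:

```python
import numpy as np

def head_decays(h):
    # gamma_i = 1 - 2^(-5 - i); larger i => gamma closer to 1 => longer memory
    return 1 - 2.0 ** (-5 - np.arange(h))
```

For $h = 8$ this gives decays ranging from about 0.969 (shortest memory) up to about 0.9998 (longest).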
3.2 Multi-Scale Retention (MSR)
The core module is called Multi-Scale Retention (MSR).
3.3 Normalization
The paper discusses two motivations for normalization.
- Reason 1: GroupNorm has a useful scale-invariance property, which helps stabilize the model numerically when stacking many layers.
- Reason 2: different heads can have very different variances, so normalization helps align them.
The final normalization scheme is:
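As best I can reconstruct it, the scheme is GroupNorm across heads plus a gated output projection,

$$
\mathrm{MSR}(X) = \big(\operatorname{swish}(XW_G) \odot Y\big) W_O,
\qquad
Y = \operatorname{GroupNorm}_h\!\big(\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)\big),
$$

together with three scale-invariant normalization factors inside retention:

$$
QK^\top \;\to\; \frac{QK^\top}{\sqrt{d}},
\qquad
\tilde{D}_{nm} = \frac{D_{nm}}{\sqrt{\sum_{i=1}^{n} D_{ni}}},
\qquad
\tilde{R}_{nm} = \frac{R_{nm}}{\max\big(\lvert\sum_{i=1}^{n} R_{ni}\rvert,\,1\big)},
$$

where $R = \frac{QK^\top}{\sqrt{d}} \odot \tilde{D}$.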
Comment: I wonder whether those normalization coefficients should be detached during training.
3.4 A Single Layer
A single RetNet block is structured as follows:
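As I recall, the block follows the usual pre-LN residual pattern, with MSR in place of attention:

$$
Y^{l} = \mathrm{MSR}\big(\mathrm{LN}(X^{l})\big) + X^{l},
\qquad
X^{l+1} = \mathrm{FFN}\big(\mathrm{LN}(Y^{l})\big) + Y^{l},
$$

with $\mathrm{FFN}(X) = \operatorname{gelu}(XW_1)W_2$.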
3.5 Training and Inference Modes
RetNet uses different computational forms in training and inference:
- Training: parallel mode and chunkwise recurrent mode
- Inference: recurrent mode
This is one of the main selling points of the paper:
the same mechanism admits multiple equivalent implementations depending on the use case.
3.6 Parameter Allocation

4. Training and Evaluation

4.1 Model Training

4.2 Performance

4.3 Training Efficiency
RetNet is implemented in PyTorch.
Training uses the chunkwise recurrent mode with:
- chunk size = 512
- hardware = 8 × A100 80G
For the 6.7B and 13B models, the paper also uses Tensor Parallelism.

4.4 Inference Efficiency

4.5 Ablation Study
The ablation study uses a 200M-parameter model with:
- 16 layers
- hidden dimension = 1024
For H3, the head dimension is set to 8.
For RWKV, the TimeMix module is used in place of attention, while keeping the FFN the same as in the other models.
Training setup:
- 10K steps
- batch size = 0.5M tokens
- training data = the same dataset used for RetNet
