跳转至

阅读笔记:RetNet

原文链接https://arxiv.org/abs/2307.08621
笔记链接https://mzeromiko.github.io/blogs
声明:本文为个人阅读笔记。所有来自我自己的推导都可能存在错误,如有发现请指正。

一、动机

二、Retention 机制设计

2.1 架构动机

输入:

\[\begin{aligned} X^0 &= [x_1, \cdots, x_{|x|}] \in \mathbb{R}^{|x| \times d_{\text{model}}} \\ \\ s_n &= As_{n-1} + K_n^\top v_n , \quad A \in \mathbb{R}^{d \times d}, K_n \in \mathbb{R}^{1 \times d} \\ \\ o_n &= Q_n s_n = \sum_{m=1}^n Q_n A^{n-m} K_m^\top v_m , \quad Q_n \in \mathbb{R}^{1 \times d} \\ \\ Q &= X W_Q , \quad K = X W_K , \quad V = X W_V \end{aligned}\]

对矩阵 \(A\) 做对角化:

\[\begin{aligned} A = \Lambda(\gamma e^{i\theta})\Lambda^{-1} ,\quad \gamma, \theta \in \mathbb{R}^d ,\quad A^{n-m} = \Lambda(\gamma e^{i\theta})^{n-m}\Lambda^{-1} \end{aligned}\]

进一步将 \(\Lambda\) 吸收到 \(W_Q, W_K\) 中,可以得到:

\[\begin{aligned} o_n = \sum_{m=1}^n Q_n (\gamma e^{i\theta})^{n-m} K_m^\top v_m = \sum_{m=1}^n (Q_n(\gamma e^{i\theta})^n)(K_m(\gamma e^{i\theta})^{-m})^\top v_m \end{aligned}\]

进一步,如果 \(\gamma\) 是标量,则:

\[\begin{aligned} o_n = \sum_{m=1}^n \gamma^{n-m}(Q_n e^{in\theta})(K_m e^{im\theta})^\dagger v_m \end{aligned}\]

这样就非常容易并行化了。

评论:矩阵 \(A\) 需要可对角化;在 RNN 中还需要特征值模长小于 1;需要时不变。还有什么隐藏假设?

2.2 递推模式(Recurrent Mode)

\[\begin{aligned} S_n &= \gamma S_{n-1} + K_n^\top V_n \\ \\ o_n &= Q_n S_n, \quad n = 1, \cdots, |x| \end{aligned}\]

2.3 并行模式(Parallel Mode)

\[\begin{aligned} Q &= (XW_Q) \odot \Theta , \quad K = (XW_K) \odot \overline{\Theta} , \quad V = XW_V \\ \\ \Theta_n &= e^{in\theta} , \quad D_{nm} = \begin{cases} \gamma^{n-m}, & n \ge m \\ 0, & n < m \end{cases} \\ \\ O &= (QK^\top \odot D)V \end{aligned}\]

其中 \(\overline{\Theta}\)\(\Theta\) 的共轭转置。

2.4 分块递推模式(Chunkwise Recurrent Mode)

\[\begin{aligned} Q_{[i]} &= Q_{Bi:B(i+1)} , \quad K_{[i]} = K_{Bi:B(i+1)} , \quad V_{[i]} = V_{Bi:B(i+1)} \\ \\ R_i &= K_{[i]}^\top (V_{[i]} \odot \zeta) + \gamma^B R_{i-1} , \quad \zeta_{ij} = \gamma^{B-i-1} \\ \\ O_{[i]} &= \underbrace{(Q_{[i]}K_{[i]}^\top \odot D)V_{[i]}}_{\text{Inner-Chunk}} + \underbrace{(Q_{[i]}R_{i-1}) \odot \xi}_{\text{Cross-Chunk}} , \quad \xi_{ij} = \gamma^{i+1} \end{aligned}\]

三、模型设计

1. 多头设计

头数量 \(h = d_{\text{model}} / d\),其中 \(d\) 是单个头的维度。

2. Multi-Scale Retention(MSR)模块结构

\[\begin{aligned} \gamma &= 1 - 2^{-5-\text{arange}(0, h)} \in \mathbb{R}^h \\ \\ \text{head}_i &= \text{Retention}(X, \gamma_i) \\ \\ Y &= \text{GroupNorm}_h(\text{Concat}(\text{head}_1, \cdots, \text{head}_h)) \\ \\ \text{MSR}(X) &= (\text{swish}(XW_G) \odot Y)W_O \end{aligned}\]

3. 归一化(Normalization)

  • 原因一:利用 GroupNorm 的缩放不变性 \(\text{GroupNorm}(\alpha \cdot \text{head}_i) = \text{GroupNorm}(\text{head}_i)\),将并行模式乘上若干系数后,通过数学技巧确保模型在深层堆叠时的数值稳定性。
  • 原因二:不同 head 之间存在不同的方差。

最终归一化公式:

\[\begin{aligned} QK^\top &\to QK^\top / \sqrt{d} \\ \\ D_{nm} &\to \tilde{D}_{nm} = D_{nm}/\sqrt{\sum_{i=1}^n D_{ni}} \\ \\ R_{nm} &\to \tilde{R}_{nm} = R_{nm}/\max(|\sum_{i=1}^n R_{ni}|, 1) \end{aligned}\]

评论:是不是要 detach 那些归一化系数?

4. 单层网络结构

\[\begin{aligned} Y^l &= \text{MSR}(\text{LN}(X^l)) + X^l \\ \\ X^{l+1} &= \text{FFN}(\text{LN}(Y^l)) + Y^l \\ \\ \text{FFN}(X) &= \text{gelu}(XW_1)W_2 \end{aligned}\]

5. 训练与推理模式

  • 训练:采用并行模式和分块递推模式
  • 推理:采用递推模式

6. 模型参数分配

四、训练与评估

1. 模型训练

2. 性能

3. 训练效率

RetNet 采用 PyTorch 实现,训练使用 chunkwise recurrent mode,chunk size = 512,硬件为 8×A100 80G。对于 6.7B 和 13B 模型,采用了 Tensor Parallel。

4. 推理效率

5. 消融实验

采用 200M 参数的模型,16 层,hidden dim = 1024。H3 的 head dim = 8。对于 RWKV,采用 TimeMix 模块替代 attention,FFN 与其他模型相同。训练 10K steps,batch size = 0.5M tokens。训练数据集为 RetNet 的训练数据集。

Comments