Notes on RNoPE-SWA

作者系统性分析了已有的 attention 机制,然后作者提出了混合的 attention 机制,来提高模型在长上下文的表现以及维持模型在短上下文场景下的表现。

作者系统性分析了已有的 attention 机制,然后作者提出了混合的 attention 机制,来提高模型在长上下文的表现以及维持模型在短上下文场景下的表现。

Introduction

作者首先强调了提升 LLM 上下文长度面临的问题:

  1. 如何有效处理长上下文输入
  2. 如何训练长上下文 LLM
  3. 如何降低长上下文 LLM 在 inference 时的 latency 以及 memory usage

对于建模长上下文输入,我们可以从 attention 机制或者位置编码来入手。前者类似的工作有 Landmark Attention 和 Focused Transformer, 但是这些方法的问题在于训练不稳定。 QK norm 可以比较好解决 softmax 分布过于极端的问题,但是其问题在于训练时的数值不稳定性,并且可能会影响模型的长上下文能力。

另一方面,对于位置编码,已有的工作如 APE, AliBi, RoPE 等都可以提供位置信息。但是,这些方法很有可能回影响模型最终的 attention score 分布。另外,NoPE 探究了移除 position encoding 的可能性。

还有一些工作目的是降低 softmax attention 的时间复杂度和空间复杂度。比如 sliding window attention, sparse attention, attention sink 等都可以降低整体的时间/空间复杂度。但是这些方法最终的表现都有所下降。

Observation

作者首先对比了以下不同方法对模型长上下文能力的影响。

作者训练了一个 8B 的模型,然后分别对比了三种方法:

  1. RoPE: base frequency 设置为 10,000, SFT 阶段扩展到 2M
  2. QK-Norm: 在 RoPE 的基础上,对 query 和 key 先进行 normalization 再进行 RoPE
  3. NoPE: 移除 attention 中的位置编码信息

作者分别评估了三种方法的表现,实验结果如下表

ModelVal LossMMLUHellaSwagCommonsenseQAARC-EARC-CNeedles 65k
RoPE1.5248.5573.7468.3081.0539.139.82
QK-Norm1.5348.2173.6868.2380.5438.987.93
NoPE1.5847.6172.1666.4276.9437.129.03

实验结果显示,对于通用任务,RoPE 和 QK-Norm 的表现差不多,而 NoPE 的表现较差。对于长上下文任务,QK-Norm 的表现最差。

接下来,作者分析了三种方法的 attention 分布情况。作者将上下文分为四个 segments:

  • begin: 开始的 10 个 token
  • needle: 与 needle 相关的 tokens
  • context: 通用的上下文 token
  • qc: question/completion token, 语文题答案相关的 token

作者将 needle 放置在 50% 深度的位置。评测的实验结果如下

Context LengthModel Variantsbeginneedlecontextqc
8kRoPE0.38630.03280.38090.2000
QK-Norm0.02420.01730.80200.1565
NoPE0.30580.04540.45010.1987
32kRoPE0.35410.02010.43430.1915
QK-Norm0.00640.00560.85170.1364
NoPE0.28070.03250.49810.1886
128kRoPE0.34630.00100.47510.1776
QK-Norm0.00100.00040.89930.0994
NoPE0.08460.00730.81560.0925

实验结果显示:

  1. NoPE 是最关注 needle token 信息的,RoPE 次之,QK-Norm 最差
  2. QK-Norm 更关注上下文信息,对其他的信息关注度较少

作者发现 QK-Norm 对初始的 token 信息关注度较少,作者认为这是因为 normalization 会让模型更关注邻近的 token 信息。为了验证这个观点,作者在不同的上下文长度下,分别计算了不同方法的 attention 分布情况,为了避免噪声,作者将开始的 10 个 token 以及最后的 $3%$ token 排除在外,实验结果如下图所示

Attention distribution on 8K context length

Attention distribution on 32K context length

Attention distribution on 128K context length

实验结果说明,相比于 softmax attention, QK-Norm 的 attention score 分布更平滑,其对于 need token 的注意力不如 RoPE, 这也说明了为什么 QK-Norm 的长上下文能力比较差。并且,QK-Norm 的 recency bias 更严重。

作者通过计算 RoPE 和 QK-Norm 的 attention distribution 的 entropy 来进一步说明这一点,结果如下表所示

Model8k32k128k
RoPE6.026.957.62
QK-Norm10.7112.4614.14

结果显示 QK-Norm 的熵更高,这意味着 QK-Norm 的 attention score 分布更分散,也就证明了 Qk-Norm retrieval 能力比较差。

Method

考虑到 NoPE 和 RopE 各自的优点,作者提出了一个混合架构,来结合 NoPE 与 RoPE. 具体做法就是 NoPE layer 和 RoPE layer 交替进行。作者将这个模型架构记为 RNoPE.

RNoPE 不同 layer 与不同 base frequency 产生的 attention score 分布如下

ModelNoPE Layers - beginNoPE Layers - needleNoPE Layers - contextNoPE Layers - qcRoPE Layers - beginRoPE Layers - needleRoPE Layers - contextRoPE Layers - qcneedles-128k
RoPE----0.35410.02010.43430.19157.395
RNoPE-10k0.32750.07650.56720.02870.00490.00040.68050.31428.036
RNoPE-100k0.32630.07780.56330.03270.02410.00050.67820.29727.461
RNoPE-2M0.32500.07120.57350.03030.11110.00460.62330.26117.022
RNoPE-4M0.34860.03690.59810.01650.09600.00390.67740.22276.203
RNoPE-10k-swa0.33030.07420.56340.0321----9.562

实验结果显示,在 RNoPE 架构中,

  1. 提升 base frequency 带来的增益逐渐递减
  2. NoPE layer 的 retrieval 能力比较强,表现在 attention sink 现象以及对于 needle token 的 spike. 但是其 recency bias 比较弱
  3. RoPE 的 retrieval 能力比较弱,但是其 attention sink 现象比较小
  4. 当 base frequency 增加止呕,RoPE 的 recency bias 会降低,表现在 qc token 的权重降低

作者发现,RoPE 的 receptive field 会影响 NoPE layer 的 retrieval 能力。作者总结得到两个 insight

  1. NoPE layer 对于 retrieval 任务比较擅长,而 RoPE layer 对于处理 local Information 比较擅长
  2. 限制 RoPE 的 receptive field 可以保证 NoPE layer 的 retrieval 能力

基于这两个 insight, 作者构建了 RNoPE-SWA 架构,RNoPE-SWA 相比于 RNoPE, 将 full attention 变成了 sliding window attention, 来避免 RoPE layer 对下游的 NoPE layer 产生影响。

最终,作者基于 Command R+ 构建了模型,作者去除了 QK-Norm, 对于 NoPE layer, 作者使用了 full attention, 对于 RoPE layer, 作者使用了 window size 为 4096 的 sliding window attention. 作者通过实验验证了 NoPE layer 和 RoPElayer 的比例,实验结果发现 $1:3$ 的比例是最优的。最终每 4 个 layer 为一组,前三组为 SWA, 最后一层为 full attention.

最终,在通用任务上评测的结果如下

ModelMMLUHellaSwagARC-EARC-CSATEnSATMathGSM8KWinograndeMBPP
Baseline57.575.884.648.570.030.940.968.539.1
RNope-SWA59.576.282.548.871.930.542.769.539.3

在 Ruler retrieval 上评测的结果如下

Model8k16k32k64k128k256k
Baseline96.694.495.189.183.057.1
RNope-SWA96.196.194.992.090.074.8

在 Ruler QA 上评测的结果如下

Model8k16k32k64k128k256k
Baseline53.550.052.545.536.030.0
RNope-SWA55.552.555.549.046.042.5

实验结果说明,模型在通用任务任务上与 baseline 模型表现差不多,且长上下文能力更强

Conclusion

作者提出了 RNope-SWA, 一个混合 NoPE, RoPE position embedding 和 sliding window attention 的 attention 机制。RNope-SWA 可以在保持模型表现的同时提高计算效率和降低 KV cache 占用。

尽管本文和已有的工作如 YoCo, Jamba-1.5 和 MiniMax-01 均证明了混合 attention 架构的有效性。但是,目前对于其工作机制尚缺乏一个比较系统性的理解。作者认为这是一个需要探究的方向。另外,作者还认为如何更好的汇总上下文信息与位置信息也是一个可以探究的方向。

References

Built with Hugo
Theme Stack designed by Jimmy