meta 等提出了 ALiBi, 一个通过 linear biases 来实现位置编码的方法来提高 LLM 在推理阶段的外推能力。
Introduction
当下,有若干种位置编码的方式:
作者通过实验对比了不同的位置编码方法,发现这些方法在推理阶段的外推能力都比较差。
为了解决这个问题,作者提出了 ALiBi (attention with linear biases), 一个几乎不增加计算和内存开销的位置编码方法,来提高 LLM 在推理阶段的外推能力。
Method
作者将外推能力定义为
a model’s ability to continue performing well as the number of input tokens during validation increases beyond the number of tokens on which the the model was trained.
计 为训练阶段的上下文长度, 为推理阶段的上下文长度。
作者首先对比了不同的位置编码方法的外推能力,结果如下图所示
结果显示,不同位置编码在推理阶段扩展模型的上下文能力均有限。
| Context Length | ||
|---|---|---|
| Sinusoidal | 512 | 50 |
| 1024 | 50 | |
| RoPE | 512 | 200 |
| 1024 | 100 | |
| T5 bias | 512 | 600 |
| 1024 | 800 | |
| ALiBi | 512 | - |
| 1024 | - |
为了解决这个问题,作者提出了 AliBi, 其表达式为
其中 是一个和 heads 相关的超参数。如果我们有 8 个 heads, 则对应的 scaling 值分别为 , 如果我们有 16 个 heads, 则我们对 8 个 heads 的结果进行插值,得到 . ALiBi 的示意图如下所示
ALiBi 通过 bias 惩罚了较远的 query-key pairs, 并且不同的 heads 的惩罚项也不同,从而每个 head 对距离的信息敏感度也不尽相同。
Experiments
ALiBi 在 WikiText-103 上的实验结果如下图所示
Conclusion
作者分析了已有的 position embedding 方法,发现已有的方法在推理阶段均不能有效扩展模型的上下文长度。因此,作者提出了 AliBi, 一个通过 linear bias 来增加位置信息的方法,作者通过实验验证了 ALiBi 的有效性。
Introduction
【参考文献 1】中系统性对比了 AliBi, RoPE , T5 提出的 T5 bias 以及 Transformer 提出的绝对位置编码 (APE).
作者发现,常用的方法在 length generalization 上表现并不是最好的,而 NoPE 不需要额外的计算开销反而效果最好。
【参考文献 2】 进一步探究了 NoPE 长度外推的泛化性。作者有三点发现:
- NoPE 相比于 RoPE, 其长度外推泛化能力更强
- 对于 NoPE 来说,模型会在还没有到达预训练上下文长度之前,表现就出现下降的情况
- 通过调整 softmax 的温度超参数,我们可以提高 NoPE 的长度外推泛化性能力。
Method
【参考文献 1】对比了不同 position encoding 的相似度,结果如下图所示
实验结果表明,NoPE 与 T5 提出的 T5 bias 最相似。
作者在理论上推导出了 NoPE 的两个性质:
Theorem 1 (Absolute Encoding) Let be an input sequence of length to the model. Then, the first layer of can recover absolute positions in the hidden state . That is, there exist , and such that the self-attention and feedforward operations in the first layer compute absolute positions and write it to the next hidden state.
Theorem 2 (Relative Encoding) Suppose that the hidden state contains absolute positional information, as stated in Theorem 1, and assume that it is not overwritten by any subsequent layers. Then, the self-attention in all subsequent layers can implement a relative positional encoding: there exists a parameterization of fθ such that, for , the attention dot product between query and key at positions n and m can be expressed as:
where is a function of their content, and is a function of their relative distance.
【参考文献 2】探究了 softmax 中 normalization factor 对模型表现的影响,作者定义 attention 为
实验结果如下图所示
结果说明,通过调整 我们可以有效提高 NoPE 的上下文扩展泛化能力
Conclusion
NoPE 说明在 transformer 中我们可以不需要加入位置编码模块,这两篇论文均验证了 NoPE 的有效性。
- The Impact of Positional Encoding on Length Generalization in Transformers
- Length Generalization of Causal Transformers without Position Encoding
本文前半部分参考 参考文献1,推荐大家看博客原文。
Position encoding总结
在 上一篇blog 中, 我们介绍了 Attention 的两个性质,也就是在不加 position encoding 的情况下,Attention 对于 query 是 permutation equivariant 的,对于 key 和 value 是 permutation invariant 的。
但是“我爱你”和“你爱我”这两句话所表示的含义应该是不一样的,我们将这两句话作为 key 和 value 的时候,我们发现模型的输出是一致的,这显然是不能接受的。因此,我们就需要加入 position encoding,让模型学习到语序信息,从而明白不同的语序有不同的含义。
下面是测试代码 (来自 参考文献1)
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
model_id = "meta-llama/Llama-3.2-1B"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
text = "The dog chased another dog"
tokens = tok(text, return_tensors="pt")["input_ids"]
embeddings = model.embed_tokens(tokens)
hdim = embeddings.shape[-1]
W_q = nn.Linear(hdim, hdim, bias=False)
W_k = nn.Linear(hdim, hdim, bias=False)
W_v = nn.Linear(hdim, hdim, bias=False)
mha = nn.MultiheadAttention(embed_dim=hdim, num_heads=4, batch_first=True)
with torch.no_grad():
for param in mha.parameters():
nn.init.normal_(param, std=0.1) # Initialize weights to be non-negligible
output, _ = mha(W_q(embeddings), W_k(embeddings), W_v(embeddings))
dog1_out = output[0, 2]
dog2_out = output[0, 5]
print(f"Dog output identical?: {torch.allclose(dog1_out, dog2_out, atol=1e-6)}") #True
Position encoding 可以分为绝对位置编码 (absolute position encoding, APE),相对位置编码 (relative position encoding, RPE) 以及可学习的位置编码。可学习位置编码主要是 BERT 类的模型在使用,其训练成本比较高,本文不做讨论。绝对位置编码是原始 transformer 里提出的编码模式,现在的大多数基于 transformer 模型使用的都是相对位置编码。
本文中,我们先介绍位置编码应该具有的性质,然后我们分别介绍绝对位置编码和相对位置编码,我们将着重关注苏剑林老师提出来的 RoPE。
位置编码
在介绍位置编码之前,我们首先应该关注位置编码的性质,位置编码的目标是为输入的 token embedding 增加位置信息,那么理想的位置编码应该是怎么样的呢?
我们这里直接引用 参考文献1 中给定的性质:
- 性质 1: token sequence 中每个位置的位置编码都是唯一的。这个很好理解,如果不唯一的话,那么根据前面推导的性质,这两个位置的 attention 输出就完全一致了
- 性质 2: 线性相关性。也就是说,如果我们知道了位置 处的位置编码,那么理想情况下,我们应该能比较简单地得到 处的位置编码,理想情况下,我们应该有 .
- 性质 3: 泛化到长上下文中去。我们希望位置编码不仅在 8K 的上下文起作用,还希望位置编码能够泛化到 32K 的上下文
- 性质 4: 生成模式是固定的。固定的模式有助于模型更好地学习位置相关的信息
- 性质 5: 可以扩展到多维。我们希望位置编码可以从文本扩展到图片再到视频,也就是从 到 .
绝对位置编码
绝对位置编码依照其名称,其思想就是为每个位置的 token 分配一个固定的位置信息,也就是对于输入的 hidden states , 我们有
这里,. 我们的 attention 就变成了
这里
整数位置编码
一个最简单的想法就是我们使用正整数来标记 token 所在的位置,也就是
可以看到,这个简单的设计满足性质 1,性质 2,性质 3,性质 4.
但是,注意到 attention 的输入 通常是经过 Layer Normalization 处理过后的,因此其按列符合正态分布,并且均值和方差一般较小。当我们加上整数位置编码之后,其 token 本身的信息就会被污染,也就是信噪比非常低。一个解决方法就是我们对 进行 normalization,即
现在所有的位置编码的值都比较小,但是我们发现新的位置编码不满足性质 2 了,这是因为现在位置编码还和 sequence 长度有关,我们从位置 到位置 不仅取决于 还取决于 sequence 长度
二进制位置编码
既然整数位置编码的主要问题是对输入影响太大,我们能否找一个不影响输入的整数位置编码方式呢? 参考文献1 提出了二进制位置编码,因为每个 token 是 维的,因此我们可以使用 位二进制来表示 . 比如说,当 , 时,我们的位置编码分别为
现在,我们二进制位置编码满足性质 1,性质 2. 对于性质 3,由于 位二进制的表示范围为 ,因此其泛化性受到 的影响。
参考文献1 画出了不同位置的值的变化情况。我们这里也模仿绘制出类似的曲线图
我们发现,二进制位置编码高位,也就是 的变化很慢,而低位,也就是 变化很快,
二进制位置编码解决了整数位置编码的信噪比过低和线性相关性。但是其问题是其对不同位置的 token embedding 产生的影响是不一样的。比如位置 1 和位置 2 的相同的 token embedding 之间的区别是:
一般来说, 比较小,因此使用二进制位置编码的问题是输入位置的微小变化(增加一个 token 或减少一个 token)都会对最终结果产生巨大影响。因此,我们需要想办法解决这个问题。
Sinusoidal
前面提到二进制位置编码的问题是相邻 token 之间变化太大,不够光滑。因此我们想要增加一个光滑性质,也就是说我们希望:
- 位置编码值在 之间,防止对 token embedding 产生影响
- 相邻 token 的位置编码尽可能相近,即 , 其中 是一个比较小的数。
- 与二进制一样,高位的变化比较慢,低位的变化比较快
一个想法就是利用三角函数 或者 ,三角函数满足前两个性质, 对于第三个性质,我们可以通过控制频率来满足。这样我们得到的位置编码就具有如下形式:
其中 是我们的超参数。
我们现在来推导一下上面位置编码的线性相关性:
我们发现, 位置编码不满足线性相关性。但是出现的 给了我们启发,也就是我们可以同时使用 和 来完成位置编码,这也是原始 transformer 里提出来的 Sinusoidal 位置编码,其形式为:
现在,记 , 我们再推导一下线性相关性,就得到:
也就是说,Sinusoidal 位置编码满足线性相关性。对于 Sinusoidal 位置编码我们也可以进行可视化:
相对位置编码
前面介绍了绝对位置编码,每个位置的位置编码是固定的。但是绝对位置编码的问题是,模型比较难以学习相对位置关系。
举个例子,我们提到上下文时,通常会使用“上一节”,“上一章”这些表示相对位置关系的词。
因此,我们希望让模型学习相对位置关系而不是绝对位置关系,因为相对关系更符合我们的认知。
RoPE
RoPE 由苏剑林老师提出,最早应用于 LLaMA 架构(没有确认),后续被大多数模型所采用。
之前的 PE 大多数关注于加性位置编码,也就是假设位置编码的形式为 , 基于这种假设,已有的工作基本都集中于优化下面的 Q 和 K 的内积
这里 , .
而 RoPE 里面,作者使用了一个不同的假设: 假设内积应该仅包含两者的相对信息,也就是
这里的 和 都是未知函数。我们的目标就是从这个公式中推导出一个合适的位置编码出来。
不失一般性,我们可以假设
这个假设代表初始条件下,我们不对输入做任何改变,也就是不增加位置信息。
2D 推导
与 RoPE 一样,我们直接使用复平面来进行推导。
我们假设 , 注意到二维平面上的每个点都可以表示为如下形式
其中 ( 定义参考 维基百科)
现在,对于三个向量 , , 我们可以写出其极坐标形式:
我们计算内积并比较同类项得到:
我们接下来分别推导 和 的形式
我们令 可以得到初始条件
我们再令 ,得到
这里最后一个等式带入了原始等式 (3),注意到左侧与 无关,因此右侧我们选取 , 得到
令 我们有
因此我们最终的表达式为:
并且,通过分别设置 以及 我们还可以得到
令 , 我们得到初始条件
令 , 我们有
这里我们带入了公式 (3),注意到公式左边与 无关,因此在公式右侧我们令 , 得到
分别令 并相加这些等式,我们得到
即
注意到
带入上式我们就得到
在 (4) 式中再令 ,并带入 就有
汇总
最后,我们将以上结果放在一起,就得到
这里 是一个超参数,用于控制频率。
我们记
则我们有:
并且
多维扩展
上面是 2D 的情况,对于多维情况,苏剑林老师通过将两个元素组对,然后分别进行处理,得到了多维的情形:
我们可以验证公式 (5) 仍然是成立的。
RoPE 的远程衰减性质
我们接下来看一下结果与相对距离 之间的关系, 注意到
这里 , 分别是对应的 pair,我们考虑其中一个分量,不妨假设 , 我们有
其中,第一个不等式是因为 两个向量相等时其内积最大,第二个不等式是由于二次型最大值为矩阵的特征值。
这样我们就有
我们可以简单画出对应的曲线:
这里针对不同的配置,结果也会略微不同,具体分析可以参加知乎回答 RoPE的远距离衰减
RoPE 代码实现与理解
Naive 实现
我们接下来看一下如何实现 RoPE
在实现的时候,我们一般根据 和 进行分组,也就是
我们通常按照奇偶 index 来分别计算,然后通过重排序来得到最终的结果,实现代码如下:
def apply_rotary_pos_emb(x: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
x_even = x[..., ::2] # (seq_len, d_k_half)
x_odd = x[..., 1::2] # (seq_len, d_k_half)
odds = cos * x_even - sin * x_odd # (...,seq_len, d_k_half)
evens = sin * x_even + cos * x_odd # (...,seq_len, d_k_half)
stacked = torch.stack((odds, evens), -2) # (...,seq_len, 2, d_k_half)
stacked_trans = rearrange(
stacked, "... seq_len double d_k_half -> ... seq_len d_k_half double"
) # (...,seq_len, d_k_half, 2)
out = rearrange(
stacked_trans, "... seq_len d_k_half double -> ... seq_len (d_k_half double)"
) # (..., seq_len, d_k)
return out
LLaMA 实现
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
在 LLaMA 中,我们首先还是计算 , 然后在计算的过程中,我们将 视作一个复数,然后 乘以 , 最后再取实部得到最终的结果
通用实现
实际上,naive 版本的实现与现在大语言模型所采用的实现并不一致,我们先看一下现有的大语言模型的 RoPE 实现,这里我们将 LLaMA的transformer代码 放在下面,
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
我们将上述代码翻译成公式,现在我们的 变成了 (对应 emb = torch.cat((freqs, freqs), dim=-1))
实际上 部分对应的向量现在变成了
我们带回到原始公式,可以得到对应的 RoPE 操作变成了
这列每一行的 和 都相差了 列.
因此,这里的区别在于,原始 RoPE 计算的 pair 为 , 而 LLaMA 里的 RoPE 计算的 pair 为 . transformers library 使用这种方式,可以减少计算量,提高整体的计算效率。
为了适应使用 LLaMA 中实现的 RoPE 的,Huggingface 对权重进行了转换,使得基于原始 RoPE 实现的模型也可以获得加速.
假设 ,原始 RoPE 的 pair 为 [(q_0, q_1), (q_2, q_3), (q_4, q_5), (q_6, q_7)], 新的 pair 为 [(q_0, q_4), (q_1, q_5), (q_2, q_6), (q_3, q_7)]. 我们希望对 index 进行 remap,我们发现一个满足条件的 permutation 为 [0, 2, 4, 6, 1, 3, 5, 7], 也就是 q_0->q_0, q_2->q_1, …, q_7->q_7.
但是,如果我们在推理时这样做,就会降低整体速度,因此 Huggingface 的做法是改变 和 的权重,具体来说,就是 , 左边是在线转换,右侧离线转换。转换好 之后,正常计算就可以了。具体代码 为
# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wq.weight"]
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded[f"layers.{layer_i}.attention.wk.weight"]
),
...
}
结论
本文中,我们回顾了位置编码,包括绝对位置编码和相对位置编码,我们着重介绍了 RoPE 的原理,推导以及代码实现。
参考文献
- You could have designed state of the art positional encoding
- Is LLaMA rotary embedding implementation correct?
- [LLaMA] Rotary positional embedding differs with official implementation
- RoPE blog
- RoFormer
- 位置编码之路
- RoPE的远距离衰减
作者系统性分析了已有的 attention 机制,然后作者提出了混合的 attention 机制,来提高模型在长上下文的表现以及维持模型在短上下文场景下的表现。
Introduction
作者首先强调了提升 LLM 上下文长度面临的问题:
- 如何有效处理长上下文输入
- 如何训练长上下文 LLM
- 如何降低长上下文 LLM 在 inference 时的 latency 以及 memory usage
对于建模长上下文输入,我们可以从 attention 机制或者位置编码来入手。前者类似的工作有 Landmark Attention 和 Focused Transformer, 但是这些方法的问题在于训练不稳定。 QK norm 可以比较好解决 softmax 分布过于极端的问题,但是其问题在于训练时的数值不稳定性,并且可能会影响模型的长上下文能力。
另一方面,对于位置编码,已有的工作如 APE, AliBi, RoPE 等都可以提供位置信息。但是,这些方法很有可能回影响模型最终的 attention score 分布。另外,NoPE 探究了移除 position encoding 的可能性。
还有一些工作目的是降低 softmax attention 的时间复杂度和空间复杂度。比如 sliding window attention, sparse attention, attention sink 等都可以降低整体的时间/空间复杂度。但是这些方法最终的表现都有所下降。
Observation
作者首先对比了以下不同方法对模型长上下文能力的影响。
作者训练了一个 8B 的模型,然后分别对比了三种方法:
- RoPE: base frequency 设置为 10,000, SFT 阶段扩展到 2M
- QK-Norm: 在 RoPE 的基础上,对 query 和 key 先进行 normalization 再进行 RoPE
- NoPE: 移除 attention 中的位置编码信息
作者分别评估了三种方法的表现,实验结果如下表
| Model | Val Loss | MMLU | HellaSwag | CommonsenseQA | ARC-E | ARC-C | Needles 65k |
|---|---|---|---|---|---|---|---|
| RoPE | 1.52 | 48.55 | 73.74 | 68.30 | 81.05 | 39.13 | 9.82 |
| QK-Norm | 1.53 | 48.21 | 73.68 | 68.23 | 80.54 | 38.98 | 7.93 |
| NoPE | 1.58 | 47.61 | 72.16 | 66.42 | 76.94 | 37.12 | 9.03 |
实验结果显示,对于通用任务,RoPE 和 QK-Norm 的表现差不多,而 NoPE 的表现较差。对于长上下文任务,QK-Norm 的表现最差。
接下来,作者分析了三种方法的 attention 分布情况。作者将上下文分为四个 segments:
- begin: 开始的 10 个 token
- needle: 与 needle 相关的 tokens
- context: 通用的上下文 token
- qc: question/completion token, 语文题答案相关的 token
作者将 needle 放置在 50% 深度的位置。评测的实验结果如下
| Context Length | Model Variants | begin | needle | context | qc |
|---|---|---|---|---|---|
| 8k | RoPE | 0.3863 | 0.0328 | 0.3809 | 0.2000 |
| QK-Norm | 0.0242 | 0.0173 | 0.8020 | 0.1565 | |
| NoPE | 0.3058 | 0.0454 | 0.4501 | 0.1987 | |
| 32k | RoPE | 0.3541 | 0.0201 | 0.4343 | 0.1915 |
| QK-Norm | 0.0064 | 0.0056 | 0.8517 | 0.1364 | |
| NoPE | 0.2807 | 0.0325 | 0.4981 | 0.1886 | |
| 128k | RoPE | 0.3463 | 0.0010 | 0.4751 | 0.1776 |
| QK-Norm | 0.0010 | 0.0004 | 0.8993 | 0.0994 | |
| NoPE | 0.0846 | 0.0073 | 0.8156 | 0.0925 |
实验结果显示:
- NoPE 是最关注 needle token 信息的,RoPE 次之,QK-Norm 最差
- QK-Norm 更关注上下文信息,对其他的信息关注度较少
作者发现 QK-Norm 对初始的 token 信息关注度较少,作者认为这是因为 normalization 会让模型更关注邻近的 token 信息。为了验证这个观点,作者在不同的上下文长度下,分别计算了不同方法的 attention 分布情况,为了避免噪声,作者将开始的 10 个 token 以及最后的 token 排除在外,实验结果如下图所示
实验结果说明,相比于 softmax attention, QK-Norm 的 attention score 分布更平滑,其对于 need token 的注意力不如 RoPE, 这也说明了为什么 QK-Norm 的长上下文能力比较差。并且,QK-Norm 的 recency bias 更严重。
作者通过计算 RoPE 和 QK-Norm 的 attention distribution 的 entropy 来进一步说明这一点,结果如下表所示
| Model | 8k | 32k | 128k |
|---|---|---|---|
| RoPE | 6.02 | 6.95 | 7.62 |
| QK-Norm | 10.71 | 12.46 | 14.14 |
结果显示 QK-Norm 的熵更高,这意味着 QK-Norm 的 attention score 分布更分散,也就证明了 Qk-Norm retrieval 能力比较差。
Method
考虑到 NoPE 和 RopE 各自的优点,作者提出了一个混合架构,来结合 NoPE 与 RoPE. 具体做法就是 NoPE layer 和 RoPE layer 交替进行。作者将这个模型架构记为 RNoPE.
RNoPE 不同 layer 与不同 base frequency 产生的 attention score 分布如下
| Model | NoPE Layers - begin | NoPE Layers - needle | NoPE Layers - context | NoPE Layers - qc | RoPE Layers - begin | RoPE Layers - needle | RoPE Layers - context | RoPE Layers - qc | needles-128k |
|---|---|---|---|---|---|---|---|---|---|
| RoPE | - | - | - | - | 0.3541 | 0.0201 | 0.4343 | 0.1915 | 7.395 |
| RNoPE-10k | 0.3275 | 0.0765 | 0.5672 | 0.0287 | 0.0049 | 0.0004 | 0.6805 | 0.3142 | 8.036 |
| RNoPE-100k | 0.3263 | 0.0778 | 0.5633 | 0.0327 | 0.0241 | 0.0005 | 0.6782 | 0.2972 | 7.461 |
| RNoPE-2M | 0.3250 | 0.0712 | 0.5735 | 0.0303 | 0.1111 | 0.0046 | 0.6233 | 0.2611 | 7.022 |
| RNoPE-4M | 0.3486 | 0.0369 | 0.5981 | 0.0165 | 0.0960 | 0.0039 | 0.6774 | 0.2227 | 6.203 |
| RNoPE-10k-swa | 0.3303 | 0.0742 | 0.5634 | 0.0321 | - | - | - | - | 9.562 |
实验结果显示,在 RNoPE 架构中,
- 提升 base frequency 带来的增益逐渐递减
- NoPE layer 的 retrieval 能力比较强,表现在 attention sink 现象以及对于 needle token 的 spike. 但是其 recency bias 比较弱
- RoPE 的 retrieval 能力比较弱,但是其 attention sink 现象比较小
- 当 base frequency 增加止呕,RoPE 的 recency bias 会降低,表现在 qc token 的权重降低
作者发现,RoPE 的 receptive field 会影响 NoPE layer 的 retrieval 能力。作者总结得到两个 insight
- NoPE layer 对于 retrieval 任务比较擅长,而 RoPE layer 对于处理 local Information 比较擅长
- 限制 RoPE 的 receptive field 可以保证 NoPE layer 的 retrieval 能力
基于这两个 insight, 作者构建了 RNoPE-SWA 架构,RNoPE-SWA 相比于 RNoPE, 将 full attention 变成了 sliding window attention, 来避免 RoPE layer 对下游的 NoPE layer 产生影响。
最终,作者基于 Command R+ 构建了模型,作者去除了 QK-Norm, 对于 NoPE layer, 作者使用了 full attention, 对于 RoPE layer, 作者使用了 window size 为 4096 的 sliding window attention. 作者通过实验验证了 NoPE layer 和 RoPElayer 的比例,实验结果发现 的比例是最优的。最终每 4 个 layer 为一组,前三组为 SWA, 最后一层为 full attention.
最终,在通用任务上评测的结果如下
| Model | MMLU | HellaSwag | ARC-E | ARC-C | SATEn | SATMath | GSM8K | Winogrande | MBPP |
|---|---|---|---|---|---|---|---|---|---|
| Baseline | 57.5 | 75.8 | 84.6 | 48.5 | 70.0 | 30.9 | 40.9 | 68.5 | 39.1 |
| RNope-SWA | 59.5 | 76.2 | 82.5 | 48.8 | 71.9 | 30.5 | 42.7 | 69.5 | 39.3 |
在 Ruler retrieval 上评测的结果如下
| Model | 8k | 16k | 32k | 64k | 128k | 256k |
|---|---|---|---|---|---|---|
| Baseline | 96.6 | 94.4 | 95.1 | 89.1 | 83.0 | 57.1 |
| RNope-SWA | 96.1 | 96.1 | 94.9 | 92.0 | 90.0 | 74.8 |
在 Ruler QA 上评测的结果如下
| Model | 8k | 16k | 32k | 64k | 128k | 256k |
|---|---|---|---|---|---|---|
| Baseline | 53.5 | 50.0 | 52.5 | 45.5 | 36.0 | 30.0 |
| RNope-SWA | 55.5 | 52.5 | 55.5 | 49.0 | 46.0 | 42.5 |
实验结果说明,模型在通用任务任务上与 baseline 模型表现差不多,且长上下文能力更强
Conclusion
作者提出了 RNope-SWA, 一个混合 NoPE, RoPE position embedding 和 sliding window attention 的 attention 机制。RNope-SWA 可以在保持模型表现的同时提高计算效率和降低 KV cache 占用。
尽管本文和已有的工作如 YoCo, Jamba-1.5 和 MiniMax-01 均证明了混合 attention 架构的有效性。但是,目前对于其工作机制尚缺乏一个比较系统性的理解。作者认为这是一个需要探究的方向。另外,作者还认为如何更好的汇总上下文信息与位置信息也是一个可以探究的方向。