Overview of Attention Mechanism

A hands-on guide to understanding and implementing the Attention mechanism in deep learning models.

Author

Published

2026-05-08 18:28:18+08:00

Dual Chunk Attention (DCA) 由阿里巴巴在 2024 年 9 月份提出,DCA 是一个无需训练的,扩展 LLM 上下文长度的方法,后续,DCA 被应用于 Qwen2, Qwen2.5, Qwen2.5-1M 以及 Qwen3 中,与 YARN 一起作为扩展模型上下文的有效手段

Introduction

提升 LLM 上下文长度的方法可以分为两类:一类是 training-free 的,包括 LM-infinite 和 StreamingLLM 等,这些方法以损失 long range dependency 为代价来保持较低的 perplexity。另一类为了保留全局信息,则是通过外插来扩展模型的上下文,主要工作我们在 YARN 中已经回顾了。

第二类方法的问题在于,其依赖训练,在 training-free 的 setting 下,这些方法也会导致 perplexity 的上升

因此,在本文中,作者就提出了 Dual Chunk Attention (DCA) ,一个无需训练的,扩展 LLM 上下文长度的方法。DCA 的主要做法是将 attention 的计算进行分块,这样就可以提高计算效率。

通过实验,作者给出了三点关键发现:

  1. Extrapolation: DCA 可以在无需训练的情况下,将 LLM 的上下文提升到 32K,而不导致 Perplexity 大幅度增加
  2. Orthogonality: DCA 可以和其他方法一起使用,如 YARN, 这一点已经在 Qwen2.5-1M 以及 Qwen3 中得到了应用
  3. Long Context Understanding: DCA 可以在无需训练的情况下,在长上下文设置下,达到已有 SOTA 模型的表现

Preliminary

对于一个长度为 LL 的 token 序列,我们首先定义对应的 position id 如下

Pq=[0,1,,L1],Pk=[0,1,,L1]P_{\mathbf{q}} = [0,1,\dots,L-1],\quad P_{\mathbf{k}} = [0,1,\dots,L-1]

然后,对于第 ii 个位置和第 jj 个位置的 token,其 attention score 定义为:

f(q,i),f(k,j)=Rθ,iq,Rθ,jk=qTRθ,ijk\langle f(\mathbf{q}, i), f(\mathbf{k}, j)\rangle =\langle R_{\theta,i}\mathbf{q}, R_{\theta,j}\mathbf{k}\rangle =\mathbf{q}^TR_{\theta, i-j}\mathbf{k}

具体细节参考 Position Encoding 中的 RoPE 部分介绍。这里面的关键在于,最后的结果只与相对位置 iji-j 相关,而与绝对位置 iijj 无关。因此,我们可以用一个相对位置矩阵 MRL×LM\in\mathbb{R}^{L\times L} 来表示这个信息,其中 Mij=Pq,iPk,jM_{ij}=P_{\mathbf{q},i}- P_{\mathbf{k},j} 代表了第 ii 个位置的 query q\mathbf{q} 和第 jj 个位置的 key k\mathbf{k} 的相对位置信息,其示意图如下所示

Relative Position Visualization

原始版本的 RoPE 的问题在于,在训练时,模型没有见过更长的上下文,因此其泛化性也最差,这一点在 YARN 已经得到了验证

Method

DCA 的关键在于将 sequence 分割为若干个 Chunk,然后将 attention 的计算拆分为三个部分:

  1. intra-chunk:负责计算每个 chunk 内部的 attention
  2. inter-chunk :负责计算 chunk 之间的 attention
  3. successive-chunk:负责计算相邻两个 chunk 之间的 attention

为了更好理解 DCA,我们接下来假设 L=12L=12, 这时我们有

Pq=[0,1,,11],Pk=[0,1,,11]P_{\mathbf{q}} = [0,1,\dots,11],\quad P_{\mathbf{k}} = [0,1,\dots,11]

Intra-Chunk Attention

我们首先定义个超参数 chunk size s>0s>0, 然后我们将我们的 sequence 分割成 L/sL/s 个 chunk,然后每个 chunk 重新进行编号,就得到了如下的 position id

PqIntra=[0,1,,L1]mods,PkIntra=[0,1,,L1]modsP_{\mathbf{q}}^{Intra} = [0,1,\dots,L-1]\mod s,\quad P_{\mathbf{k}}^{Intra} = [0,1,\dots,L-1]\mod s

接下来,我们定义 intra-chunk 的相对位置矩阵 MM, 此时,我们仅在每个 chunk 内部进行计算 attention,即

M[i][j]={Pq,iIntraPk,jIntra,ifPq,i/s=Pk,j/s,0,otherwiseM[i][j] = \begin{cases} P_{\mathbf{q},i}^{Intra} - P_{\mathbf{k},j}^{Intra},&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor = \lfloor P_{\mathbf{k},j} /s\rfloor ,\\ 0,&\text{otherwise} \end{cases}

在上面的例子中,假设 s=6s=6, 那么我们新的 position id 就变成了

PqIntra=[0,1,2,3,4,5,0,1,2,3,4,5]PkIntra=[0,1,2,3,4,5Chunk 0,0,1,2,3,4,5Chunk 1]\begin{aligned} P_{\mathbf{q}}^{Intra} = [0,1,2,3,4,5,0,1,2,3,4,5]\\ P_{\mathbf{k}}^{Intra} = [\underbrace{0,1,2,3,4,5}_{\text{Chunk 0}},\underbrace{0,1,2,3,4,5}_{\text{Chunk 1}}] \end{aligned}

对其进行可视化,我们就得到

Intra Chunk Attention Visualization

Inter-Chunk Attention

接下来,我们来看一下不同 chunk 之间如何计算彼此的 attention。在 Intra-chunk attention 计算中,我们忽略了跨 chunk 的信息,而且,由于现在的 position id 不再是单调递增的了,我们直接使用 PqIntraP_{\mathbf{q}}^{Intra}PkIntraP_{\mathbf{k}}^{Intra} 给出的位置信息不对,这也是为什么我们在 Intra-chunk attention 中要求 query 和 key 在同一个 chunk 中才能计算的原因。

为了解决这个问题,作者构建了一个新的 position id。首先我们引入一个新的超参数 c>maxiPq,ic>\max_i P_{\mathbf{q},i}, cc 代表了模型预训练时的上下文长度,如 4096。

接下来,基于 cc, 我们定义新的 position id 如下:

PqInter=[c1,c1,,c1]Rs,PkInter=PkIntraP_{\mathbf{q}}^{Inter} = [c-1,c-1,\dots,c-1]\in\mathbb{R}^s,\quad P_{\mathbf{k}}^{Inter} = P_{\mathbf{k}}^{Intra}

注:这里的 PqInterP_{\mathbf{q}}^{Inter} 指的是某一个 chunk 中的 position id,每个 chunk 中的 position id 都相同,请参考例子理解,后面不再赘述。

也就是说,在计算跨 chunk 的 attention 的时候,我们直接把 query 的 position id 设置为最大值,然后 key 的 position id 依然使用 intra-chunk 的位置信息。由于 maxiPk,i=s1\max_i P_{\mathbf{k},i}=s-1, 因此我们有

M[i][j]=Pq,iInterPk,jInter=c1Pk,jInterc1(s1)cs.M[i][j] = P_{\mathbf{q},i}^{Inter} - P_{\mathbf{k},j}^{Inter} = c - 1 - P_{\mathbf{k},j}^{Inter}\geq c - 1 - (s- 1) \geq c-s.

最后,我们对于 inter chunk 的位置矩阵 MM 定义如下:

M[i][j]={Pq,iInterPk,jInter,ifPq,i/sPk,j/s,0,otherwiseM[i][j] = \begin{cases} P_{\mathbf{q},i}^{Inter} - P_{\mathbf{k},j}^{Inter},&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor \neq \lfloor P_{\mathbf{k},j} /s\rfloor ,\\ 0,&\text{otherwise} \end{cases}

在上面的例子中,当 c=10c=10 时,c1=9c-1=9, 我们有

PqInter=[9,9,9,9,9,9Chunk 0,9,9,9,9,9,9Chunk 1]P_{\mathbf{q}}^{Inter}=[\underbrace{9,9,9,9,9,9}_{\text{Chunk 0}},\underbrace{9,9,9,9,9,9}_{\text{Chunk 1}}]

对其进行可视化,得到

Inter Chunk Attention Visualization

Successive-Chunk Attention

现在我们既可以计算 intra-chunk,也可以计算 inter-block 的 attention,但是问题是对于相邻的 chunk,其位置信息不对了,从上面的可视化中,我们可以看到,当 Pq,i=6P_{\mathbf{q},i}=6 , Pk,j=5P_{\mathbf{k},j}=5 时,我们有

Pq,iInterPk,jInter=95=41=Pq,iPk,jP_{\mathbf{q},i}^{Inter} - P_{\mathbf{k},j}^{Inter}=9-5=4\neq 1 = P_{\mathbf{q},i}-P_{\mathbf{k},j}

也就是说,inter-block 的 attention 会破坏原有的相对位置信息,因此我们就通过 successive chunk attention 来解决这个问题,使得 Pq,iInterPk,jInterPq,iPk,jP_{\mathbf{q},i}^{Inter} - P_{\mathbf{k},j}^{Inter}\approx P_{\mathbf{q},i}-P_{\mathbf{k},j}.

作者发现,这个问题不是所有的 chunk 都有,而是只存在于相邻的 chunk 中,因此,作者又加入了一个超参数 w>0w>0 代表了 local window size,我们可以直接将其设置为 csc-s, 通过这个 local window,我们调整对应的 position id 如下:

PqSucc=[s,s+1,,s+w1w elements,c1,,c1]Rs,PkSucc=PkInterP_{\mathbf{q}}^{Succ} = [\overbrace{s,s+1,\dots,s+w-1}^{w \text{ elements}},c-1, \dots,c-1]\in\mathbb{R}^s,\quad P_{\mathbf{k}}^{Succ} = P_{\mathbf{k}}^{Inter}

对于 successive chunk 的位置矩阵 MM 定义如下:

M[i][j]={Pq,iInterPk,jInter,ifPq,i/sPk,j/s=1,0,otherwiseM[i][j] = \begin{cases} P_{\mathbf{q},i}^{Inter} - P_{\mathbf{k},j}^{Inter},&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor - \lfloor P_{\mathbf{k},j} /s\rfloor=1 ,\\ 0,&\text{otherwise} \end{cases}

在上面的例子中,我们设置 w=4w=4, 就得到

PqSucc=[6,7,8,9,9,9Chunk 0,6,7,8,9,9,9Chunk 1]P_{\mathbf{q}}^{Succ}=[\underbrace{6,7,8,9,9,9}_{\text{Chunk 0}},\underbrace{6,7,8,9,9,9}_{\text{Chunk 1}}]

对其进行可视化,得到

Successive Chunk Attention Visualization

Computation

接下来,我们把所有的改进放在一起,就得到

M[i][j]={Pq,iIntraPk,j,ifPq,i/sPk,j/s=0,Pq,iSuccPk,j,ifPq,i/sPk,j/s=1,Pq,iInterPk,j,ifPq,i/sPk,j/s>1.M[i][j] = \begin{cases} P_{\mathbf{q},i}^{Intra} - P_{\mathbf{k},j},&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor - \lfloor P_{\mathbf{k},j} /s\rfloor=0 ,\\ P_{\mathbf{q},i}^{Succ} - P_{\mathbf{k},j},&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor - \lfloor P_{\mathbf{k},j} /s\rfloor=1 ,\\ P_{\mathbf{q},i}^{Inter} - P_{\mathbf{k},j},&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor - \lfloor P_{\mathbf{k},j} /s\rfloor>1. \end{cases}

基于上面的位置矩阵 MM, 我们再依次计算对应的 attention score

f(q,i),f(k,j)={f(q,Pq,iIntra),f(k,Pk,j),ifPq,i/sPk,j/s=0,f(q,Pq,iSucc),f(k,Pk,j),ifPq,i/sPk,j/s=1,f(q,Pq,iInter),f(k,Pk,j),ifPq,i/sPk,j/s>1.\langle f(\mathbf{q}, i), f(\mathbf{k}, j)\rangle = \begin{cases} \langle f(\mathbf{q}, P_{\mathbf{q},i}^{Intra}) , f(\mathbf{k}, P_{\mathbf{k},j})\rangle,&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor - \lfloor P_{\mathbf{k},j} /s\rfloor=0 ,\\ \langle f(\mathbf{q}, P_{\mathbf{q},i}^{Succ}) , f(\mathbf{k}, P_{\mathbf{k},j})\rangle,&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor - \lfloor P_{\mathbf{k},j} /s\rfloor=1 ,\\ \langle f(\mathbf{q}, P_{\mathbf{q},i}^{Inter}) , f(\mathbf{k}, P_{\mathbf{k},j})\rangle,&\text{if}\lfloor P_{\mathbf{q},i}/s \rfloor - \lfloor P_{\mathbf{k},j} /s\rfloor>1. \end{cases}

Code

首先是 RotaryEmbedding 部分的修改

class DCARotaryEmbedding(nn.Module):
    def __init__(self, max_len, chunk_size, local_window):
        self.max_len = max_len
        self.chunk_size = chunk_size
        self.local_window = local_window

        self.inv_freq = ...

    def forward(self, x):
        q_t =  torch.arange(self.chunk_size)
        qc_t = (q_t + self.chunk_size).clamp(max=self.chunk_size)
        k_t = torch.arange(seq_len) % self.chunk_size

        q_freqs = torch.outer(q_t, self.inv_freq)  # seq_len x dim/2
        qc_freqs = torch.outer(qc_t, self.inv_freq)
        k_freqs = torch.outer(k_t, self.inv_freq)  # seq_len x dim/2

        q_emb = torch.cat((q_freqs, q_freqs), dim=-1)  # seq_len x dim
        qc_emb = torch.cat((qc_freqs, qc_freqs), dim=-1)
        k_emb = torch.cat((k_freqs, k_freqs), dim=-1)  # seq_len x dim
        # compute related sin, cos
        return q_sin, q_cos, qc_sin, qc_cos, k_sin, k_cos

attention 计算时的逻辑

class Attention(nn.Module):
    def forward(...):
        
        key_states = apply_rotary_pos_emb(key_states, k_cos, k_sin, position_ids)
        q_states_intra = apply_rotary_pos_emb(query_states[:, :, :chunk_len, :], q_cos, q_sin,
                                              position_ids[:, :chunk_len])
        k_states_prev = key_states[:, :, :chunk_len, :]
        v_states_prev = value_states[:, :, :chunk_len, :]
        # first chunk
        flash_result = do_flash_attn(q_states_intra, k_states_prev, v_states_prev)
        flash_results.append(flash_result)
        remain_len = kv_seq_len - chunk_len

         while remain_len > 0:
            flash_per_chunk = []
            begin = kv_seq_len - remain_len
            curr_chunk_len = min(chunk_len, remain_len)
            end = begin + curr_chunk_len
            # current chunk, intra-chunk attention
            q_states_intra = apply_rotary_pos_emb(query_states[:, :, begin:end, :], q_cos, q_sin,
                                                  position_ids[:, begin:end])

            k_states_intra = key_states[:, :, begin:end, :]
            v_states_intra = value_states[:, :, begin:end, :]
            flash_result = do_flash_attn(q_states_intra, k_states_intra, v_states_intra)
            flash_per_chunk.append(flash_result)
            # successive chunk attention
            q_states_succ = apply_rotary_pos_emb(query_states[:, :, begin:end, :], qc_cos, qc_sin,
                                                 position_ids[:, begin:end])
            flash_result = do_flash_attn(q_states_succ, k_states_prev, v_states_prev, False)
            flash_per_chunk.append(flash_result)
            # inter chunk attention
            if begin - (k_states_prev.size(-2)) > 0:
                prev_len = k_states_prev.size(-2)
                q_states_inter = apply_rotary_pos_emb(query_states[:, :, begin:end, :], qc_cos, qc_sin,
                                                      position_ids[:, chunk_len - 1][:, None].repeat(1, curr_chunk_len))
                k_states_inter = key_states[:, :, :begin - prev_len, :]
                v_states_inter = value_states[:, :, :begin - prev_len, :]
                flash_result = do_flash_attn(q_states_inter, k_states_inter, v_states_inter, False)
                flash_per_chunk.append(flash_result)

            flash_results.append(flash_per_chunk)
            k_states_prev = k_states_intra
            v_states_prev = v_states_intra
            remain_len = remain_len - chunk_len
        # merge the final results
        attn_output = merge_attn_outputs(flash_results)

Evaluation

Perplexity evaluation on PG19

作者还分析了一下 DCA 的效率,结果如下

Efficiency of DCA

可以看到,在 flash attention 的基础上加上 DCA 之后,内存占用和推理时间并没有发生太大变化

作者还分析了三种 attention 对结果的贡献,如下图所示

Ablation study on three modules

结果显示,intra block 的 perplexity 是最低的,但是其在下游任务上表现是最差的。当三者结合在一起之后,perplexity 和下游任务上的表现都是最好的。

Conclusion

本文中,我们回顾了 Qwen 系列扩展大模型上下文的方法 Dual Chunk Attention (DCA) 通过将 attention 切分成更小的 chunk,然后将 attention 的计算分为 intra-chunk,inter-chunk 和 successive-chunk,分别处理 chunk 内部,chunk 之间以及相邻 chunk 的 attention,通过这种方式,在无需训练的情况下,我们可以有效将模型上下文长度扩展 4 倍以上。

    Google Research 在 23 年 12 月份提出了 Group Query Attention (GQA), 一个提升 multi-head attention 效率的方法。GQA 自 Qwen2 系列开始被应用。

    Introduction

    Multi-head attention (MHA) 的问题在于 inference 阶段,每次 decoding,都需要重新加载 attention 模块中 query layer, key layer 和 value layer 的权重,而加载权重会受带宽限制。

    已有的工作有 MQA, 也就是我们把多个 head 的 key layer 以及 value layer 压缩成一个,这样对于 hh 个 head 的 attention,我们有 hh 个 query layer,11 个 key layer 以及 1 个 value layer. 但是 MQA 的问题在于其会导致性能下降,而且训练过程会不稳定。

    因此,在本文中作者就作出了两点贡献:

    1. 如何将一个 MHA 模型转化为一个一个 MQA 模型
    2. 提出了 Group Query Attention (GQA),在保持模型性能的同时,提高计算效率

    Method

    Uptraining

    将 MHA 模型转化为 MQA 模型分为两步:

    1. 将 MHA 权重转化为 MQA 权重
    2. 额外的预训练

    具体来讲,作者使用了一个 mean pooling 的方法,来将不同 head 的 query layer 以及 key layer 的权重转化为 MQA 对应 layer 的权重。然后作者 pre-training 若干步来让模型适应新的结构。

    GQA

    GQA 的思路在于在 MHA 和 MQA 之间达到一个平衡,也就是说我们将 key layer 和 value layer 进行分组,每个组内共享一个 key layer 和 value layer, 我们假设有 hh 个 head,GG 个 group,那么

    1. G=1G=1 时,所有的 head 共享一个 key layer 和一个 value layer, 此时 GQA 等价于 MQA
    2. G=HG=H 时,每个 head 都有一个 key layer 和一个 value layer, 此时 GQA 等价于 MHA
    3. 1<G<H1<G<H 时,GQA 时 MQA 和 MHA 的一个 trade-off,兼顾两者的性能与效率

    三者的示意图如下所示

    Overview of grouped-query methods

    Code

    MQA 的代码也比较好理解,我们首先定义 group size,即 num_key_value_heads, 然后基于 group size 定义对应的 key layer self.k_proj 和 value layer self.v_proj.

    计算得到 key_statesvalue_states 之后,在计算 attention,即 eager_attention_forward 的时候,我们对 key_statesvalue_states 进行复制,即 repeat_kv

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
    
    def eager_attention_forward(
        module: nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        scaling: float,
        dropout: float = 0.0,
        **kwargs: Unpack[TransformersKwargs],
    ):
        key_states = repeat_kv(key, module.num_key_value_groups)
        value_states = repeat_kv(value, module.num_key_value_groups)
    
        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask
    
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
    
        return attn_output, attn_weights
    
    class Qwen3Attention(nn.Module):
        def __init__(self, config: Qwen3Config, layer_idx: int):
            super().__init__()
            self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
            self.scaling = self.head_dim**-0.5
    
            self.q_proj = nn.Linear(
                config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
            )
            self.k_proj = nn.Linear(
                config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
            )
            self.v_proj = nn.Linear(
                config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
            )
            self.o_proj = nn.Linear(
                config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
            )
    
        def forward(
            self,
            hidden_states: torch.Tensor,
            position_embeddings: tuple[torch.Tensor, torch.Tensor],
            attention_mask: Optional[torch.Tensor],
            past_key_value: Optional[Cache] = None,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs: Unpack[FlashAttentionKwargs],
        ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
            input_shape = hidden_states.shape[:-1]
            hidden_shape = (*input_shape, -1, self.head_dim)
    
            query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
            value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    
            attn_output, attn_weights = eager_attention_forward(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=0.0,
                scaling=self.scaling,
                sliding_window=self.sliding_window,  # diff with Llama
                **kwargs,
            )
    
            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            attn_output = self.o_proj(attn_output)
            return attn_output, attn_weights
    

    Conclusion

    本文中,作者提出了一个解决 multi-query attention 的 uptraining 方法,以及提出了 GQA,一个结合 MHA 表现和 MQA 效率的新型注意力机制。

      Introduction

      现有的大部分模型都基于 Transformer 提出的 softmax attention (SDPA), 虽然也有相关的改进工作,但是主要集中于降低 attention 计算复杂度,提高 attention 在推理时的内存使用效率等。之前的工作提出了关于 attention 的两个问题:

      1. attention sink, 即模型的注意力会放在初始几个 token 上, 这限制了模型的上下文扩展能力
      2. massive activation, 少部分 token 的 hidden states 会非常大,这限制了模型的训练稳定性

      在本文中,作者通过在 attention 中加入 gating 机制来探索 gating 对模型表现和训练稳定性的影响。尽管 gating 并没有降低 attention 计算复杂度,但是 gating 提出了一个新的视角,即 sparity 与 attention sink 和 massive activation 息息相关,这为后面 sparse attention 的研究提供了 Insight.

      作者发现,对 Multi head attention 的输出进行 head-specific gating 的效果最好,并且这种方式还可以提高训练稳定性,模型的表达能力和长上下文能力。作者还进一步分析了这种 gating 方式更好的原因,发现有两点:

      1. non-linearity: 通过 gating 可以有效提高 output projection layer 输入的秩,进而提高表达能力
      2. sparsity: gating 可以降低 massive activation 和 attention sink 的影响

      作者最终推荐使用 element-wise SDPA gating 方式来进行训练

      作者主要介绍了 gating 和 attention sink 这两部分的工作。

      gating 早在 LSTM 和 GRU 使其就得到了广泛的运用,在 transformer 之后,相关的现行注意力也有应用,比如 MiniMax-01 所使用的 Lightning Attention 等,但是这些工作没有系统性探究 gating 背后的机制。

      第二部分是 attention sink, attention sink 现象由 StreamingLLM 提出, 即模型会将相当一部分注意力权重方开始开始的几个 token 上。而本文提出的 gating 机制可以缓解 attention sink 现象。

      Method

      首先是标准 MHA 定义:

      Q=XWQ,K=XWK,V=XWVAttni(Q,K,V)=softmax(QKTdk)V,i=1,,hMHA(Q,K,V)=Concat([Attn1,,Attnh])O=MHA(Q,K,V)WO\begin{aligned} Q &= XW_Q, K=XW_K, V=XW_V\\ \mathrm{Attn}_i(Q,K,V) &= \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V, i=1,\dots,h\\ \mathrm{MHA}(Q, K, V) &= \mathrm{Concat}([\mathrm{Attn}_1,\dots,\mathrm{Attn}_h])\\ O &= \mathrm{MHA}(Q, K, V) W_O \end{aligned}

      这里 XRn×dX\in\mathbb{R}^{n\times d} 是 transformer layer pre-normalization 的输出(或者 attention block 的输入), nn 是 sequence length, dd 是 hidden size, hh 是 number of heads, dkd_k 是 head dimension.

      接下来,作者介绍了不同的 gating 策略。这里作者用同一的公式来进行表示

      Y=g(Y,X,Wθ,σ)=Yσ(XWθ)Y' = g(Y,X,W_\theta, \sigma) = Y\odot \sigma(XW_\theta)

      这里 YY 是输入, XX 是 attention 的输入,WθW_\theta 是可学习权重

      Position 首先是位置,作者考虑了如下几种变体:

      MHA(Q,K,V)=MHA(Q,K,V)σ(XWθ))Q=Qσ(XWθ)K=Kσ(XWθ)V=Vσ(XWθ)O=Oσ(XWθ)\begin{align} \mathrm{MHA}(Q, K, V)' &= \mathrm{MHA}(Q, K, V)\odot \sigma\left(X W_\theta)\right) \tag{G1}\\ Q' &= Q\odot \sigma\left(XW_\theta\right) \tag{G2}\\ K' &= K\odot \sigma\left(XW_\theta\right) \tag{G3}\\ V' &= V\odot \sigma\left(XW_\theta\right) \tag{G4}\\ O' &= O\odot \sigma\left(XW_\theta\right) \tag{G5}\\ \end{align}

      这里 σ\sigma 是激活函数,WθW_\theta 是激活函数的可学习参数,我们可以将其理解为一个 linear layer, 即当前模块的输出取决于输入 hidden sates 经过一个线性层和激活层之后的结果,相似的做法还有 MoE 中的 gating layer, NSA 中的 gating layer 等。对应的示意图如下所示

      Positions of different gating methods

      granularity 作者设计了不同粒度的 gating(假设输入为 XRn×h×dkX\in\mathbb{R}^{n\times h\times d_k}):

      1. head-shared: 不同 head 共享 gating score, Y'[i,h,k]=gate[i,k]*Y[i,h,k]
      2. head-wise: 同一个 head 共享 gating score, Y'[i,h,:]=gate[i,h]*Y[i,h,:]
      3. element-wise: 不同元素不共享 gating score, Y'[i,h,k]=gate[i,h,k]*Y[i,h,k]

      从 attention 的角度看,不同 head 本身就承担不同的语义子空间,如果强行共享 gating,会破坏这种分工。

      format 作者还构建了 multiplication 和 addition 两种形式:

      1. multiplication: Y=Yσ(XWθ)Y'=Y\odot \sigma(XW_\theta)
      2. addition: Y=Y+σ(XWθ)Y'=Y+\sigma(XW_\theta)

      activation function 本文中作者使用了 SiLU 和 sigmoid 两种形式,即

      σsigmoid(x)=11+ex,σSiLU=xσsigmoid(x)=x1+ex\sigma_{\mathrm{sigmoid}}(x) = \frac{1}{1+e^{-x}},\quad \sigma_{\mathrm{SiLU}} = x*\sigma_{\mathrm{sigmoid}}(x)=\frac{x}{1+e^{-x}}

      Experiments

      作者构建了三个模型进行实验,模型配置如下表所示

      Model1.7B-28 layers1.7B-48 layers15B-A2.4B MoE
      Layers284824
      query heads161632
      key/value heads884
      head dim128128128
      tie embeddingyesyesno
      QK normalizationyesyesyes
      hidden size204815362048
      ffn hidden size61444608768
      experts--128
      top-K--8

      首先是不同 gating 方法对 MoE model 影响,结果如下图所示

      Performance of different gating variants

      结论如下:

      1. 对 SDPA 的输出 (G1) 或者 value (G2) 进行 gating 效果最好
      2. head-specific gating 效果更好
      3. multiplication 效果比 addition 效果更好
      4. sigmoid 效果比 SiLU 效果更好

      总的来说,position 对最终结果提升最明显,其次是 granularity 和 activation function.

      接下来是不同 gating 方法对 dense model 的影响,作者构建了两个 dense 模型,参数都是 1.7B, 这两个模型的 layers 和 FFN hidden size 不同(通过调整保持总参数一致)。作者对比了 G1 和 baseline 的表现, 结果如下图所示

      Performance of dense models with Gated attention

      结论验证了 gating 机制可以有效提高模型的表现。作者还发现使用 gating 之后,模型的训练也更加稳定,训练的损失变化曲线如下图所示

      training loss curve of gated attention

      Analysis

      首先,作者对 multi head attention 进行了重写,得到如下形式

      oik=j=1i(SijkXjWVk)WOk=j=1iSijkXj(WVkWOk)o_i^k = \sum_{j=1}^i\left(S_{ij}^k X_jW_V^k\right)W_O^k = \sum_{j=1}^i S_{ij}^k X_j(W_V^kW_O^k)

      也就是说,WKW_KWOW_O 可以吸收到一起,由于 WVjRd×dkW_V^j\in\mathbb{R}^{d\times d_k}, WOkRdk×dW_O^k\in\mathbb{R}^{d_k\times d}, 从而 rank(WVjWOk)max(rank(WVj),rank(WOk))dk\mathrm{rank}(W_V^jW_O^k)\leq \max(\mathrm{rank}(W_V^j), \mathrm{rank}(W_O^k))\leq d_k. 对于 GQAMQA, 最终的有效秩会进一步降低。

      而使用本文提到的 G1 和 G2 gating 策略之后,我们相当于是通过非线性机制提高了上面的秩,进而解决了 softmax attention 表达能力不足的问题, 实际上,StepFun 的 MFA 也是类似的思想。下面是 G1 和 G2 做的改进:

      oik=j=1i(Sijkgating(XjWVk))WOkoik=gating(j=1iSijkXjWVk)WOk\begin{align} o_i^k &= \sum_{j=1}^i\left(S_{ij}^k \mathrm{gating}(X_jW_V^k)\right)W_O^k\tag{G1}\\ o_i^k &= \mathrm{gating}\left(\sum_{j=1}^iS_{ij}^k X_jW_V^k\right)W_O^k \tag{G2} \end{align}

      通过 gating 的非线性机制,我们提高的矩阵的秩,进而提高了模型的表达能力,而 G5 提升有限的原因也在于此。实验结果如下图所示

      Performance of different non-linearity variants

      可以看到,不同的 non-linearity 方法对模型表现都有提升,这验证了矩阵秩会影响模型表达能力的分析。

      接下来,作者探究了 gating 机制对 attention score distribution 的影响,结果如下图所示

      attention score distribution of different methods

      实验结果说明:

      1. 有效的 gating 机制对应的 attention score 是非常稀疏的
      2. head-specific sparsity 非常重要,当在不同的 head 共享 gating 时,模型表现会有所下降
      3. gating 必须与 query 相关,与 G2 先比,G1 的表现更好,这说明 gating score 更依赖于 query. 作者认为基于当前 query token 构建 gating, 可以有效过滤历史 token 的噪音信息
      4. non-sparse gating 效果比较差,作者构建了一个 non-sparse 版本的 sigmoid, 结果发现模型表现非常差,这说明了 attention score 应该是一个稀疏形式

      通过前面的分析和实验结果,作者认为 gating 机制还可以缓解 attention sink 现象,作者对 baseline 以及 G1 两种方法的 attention 分布进行了可视化,结果如下图所示

      Visualization of attention sink

      实验结果整理如下表所示

      methodmassive activationattention sink
      baselinehighhigh
      input-independencehighhigh
      head-shared gatinglowhigh
      head-specific gatinglowlow

      因此,作者的结论为,input-dependent, head-specific gating 可以提高 attention score distribution 的 sparsity, 进而减缓 attention sink. 并且引入 spaisity 之后,我们还可以避免 massive activation, 进而使用更低的精度进行训练。

      最后,作者探究了以下 gating 机制的上下文扩展能力,作者在已有的模型上基于 32k 上下文长度使用了 80B token 进行 continue pre-training, 然后使用 YARN 将模型上下文长度扩展到了 128K。 测试的结果如下图所示

      Method4k8k16k32k64k128k
      Baseline88.8985.8883.1579.50--
      SDPA-Gate90.5687.1184.6179.77--
      YaRN Extended
      Baseline82.90 (-6.0)71.52 (-14.4)61.23 (-21.9)37.94 (-41.56)37.5131.65
      SDPA-Gate88.13 (-2.4)80.01 (-7.1)76.74 (-7.87)72.88 (-6.89)66.6058.82

      可以看到,对于短上下文,虽然两者表现都有所下降,但是本文提出的 gating 表现下降程度较小。而对于长上下文,本文提出的 gating 机制效果明显更好。作者分析原因认为这是由于 softmax attention 倾向于退化为对少数 token 的依赖, 而 gating 通过引入 token-level sparsity,避免了这种路径依赖。

      Conclusion

      在本文中,作者系统性探究了 attention 中的 gating 机制,包括 gating 对模型表现,训练稳定性以及训练动态的影响。作者发现,通过提高 non-linearity 和 sparsity 我们可以有效提高模型的上下文能力以及减缓 attention sink 现象。

      从更高层次看,本文的结果可以总结为一点:

      attention 的问题不在于 softmax 本身,而在于线性 aggregation 的表达上限与缺乏选择性。而 gating 提供了一种几乎零成本、却极其有效的方式来引入非线性与稀疏性。

        Appendix

        作者在附录中还进一步分析了 massive activation 以及 attention sink.

        1. massive activation 并不是 attention sink 产生的必要原因,并且 sparsity 可以减缓这一现象
        2. head-specific gating 会提升 gating score 的值,因此不同的 head 需要安排不同的 sparsity
        3. 并不能通过 clipping 的方式来提高训练稳定性
        4. 在 continue pre-training 阶段加入 gating 机制并不能提高模型的表现

        阶跃星辰等提出了 Multi-matrix Factorization Attention (MFA), 一个新型注意力机制,用于在 KV cache 限制下最大化模型的表现。

        Introduction

        multi-head attention (MHA) 的问题在于,其 KV cache 的内存占用(memory footprint)随 sequence length 以及 batch size 线性增长,从而成为了 LLM 在 decoding 阶段的瓶颈。

        为了解决 MHA 的内存占用过高问题,已有的工作如 MQA, GQA 等通过共享 key, value projection 来降低 KV cache size. 而 DeepSeek-V3 提出的 MLA 则是通过对 key, value projection 进行 low-rank compression, 然后只存储 latents 的方法来降低 KV cache size.

        但是,已有的这些方法的问题在于,当我们设置 KV cache budget 之后,它们的表现就比标准的 MHA 要差。

        基于以上这些发现,作者首先分析了已有 attention 机制的 modeling capacity, 然后使用一个统一的框架来表示这些 attention 机制。作者发现,attention heads 的个数以及 dimension 对模型表现有较大影响。

        基于这个发现,作者提出了 Multi-matrix Factorization Attention (MFA), 以及其变体 MFA-Key-Reuse (MFA-KR). MFA 的主要目的是在有限的 KV cache size 下提高模型的表现。

        Background

        作者首先介绍了 GMHA 的概念,GMHA 由三部分组成:

        1. QK circuit: 决定了信息之间如何交互
        2. valueoutput (VO) circuits:决定了信息如何传递
        3. per-head softmax attention.

        接下来,作者介绍了 Fully Parameterized Bilinear Attention (FPBA), FPBA 的定义如下:

        O=c=1d(j=1Nϕ(xWcxjH)xjUc)O = \sum_{c=1}^d\left(\sum_{j=1}^N\phi\left(\frac{xW_cx_j}{H}\right)x_jU_c\right)

        其中 ϕ\phi 是 softmax 函数,dd 是模型的 hidden dimension, NN 是 sequence length, Wc,UcRd×dW_c,U_c\in\mathbb{R}^{d\times d} 每个 channel 上的参数矩阵

        1. 每个 channel 都有各自的参数 Wc,UcW_c, U_c 来获取 xix_ixjx_j 之间的信息
        2. 提高泛化性,所有 channel 的 UcU_c 组合起来可以遍历 dd 维空间中的任意一个 permutation, 这样就避免来的信息损失
        3. 利用率高,FPBA 获取了 xix_ixjx_j 之间 dd 维空间可能的表示

        基于以上这三个特点,作者认为 FPBA 是 GMHA 框架的一个 capacity upper bound. 此时每个 token 的 KV cache 占用为 2d22d^2 (key and value).

        然后,作者分析了 MHA 及其变体与 GMHA 的关系,MHA 可以写作如下形式

        O=c=1h(j=1Nϕ(xQc(xjKc)Td)xjVc)OcT=c=1h(j=1Nϕ(x(QcKcT)xjTd)xjVcOcT)\begin{aligned} O &= \sum_{c=1}^h\left(\sum_{j=1}^N\phi\left(\frac{xQ_c(x_jK_c)^T}{\sqrt{d}}\right)x_jV_c\right)O_c^T\\ &= \sum_{c=1}^h\left(\sum_{j=1}^N\phi\left(\frac{x(Q_cK_c^T)x_j^T}{\sqrt{d}}\right)x_jV_cO_c^T\right) \end{aligned}

        其中 Qc,Kc,VcRd×hdQ_c,K_c,V_c\in\mathbb{R}^{d\times h_d}, OcRd×hdO_c\in\mathbb{R}^{d\times h_d} 分别是 query, key, value, output projection layer 对应的权重矩阵,nn 是 attention head 的个数,令 hdh_d 为每个 attention 的 head dimension,则我们有 nhd=dnh_d=d.

        可以看到,MHA 实际上是一个特殊的 FPBA, 其中,WcW_cUcU_c 分别由秩为 hdh_d 的低秩分解 QcKcTQ_cK_c^T 以及 VcOcTV_cO_c^T 近似。此时每个 token 的 KV cache 占用为 2d2d (key and value).

        MQA 可以看作是 GQA 的一个特殊情况。对于 GQA 来说,我们有一个 group size g[1,h]g\in[1, h], 当 g=1g=1 时,GQA 就是 MHA. 当 g=hg=h 时,GQA 就是 MQA, 通常 gg 满足 h % g=0h\ \%\ g=0. GQA 的表达式与 MHA 基本相同,只是多个 head 会共享一个 KcK_c 以及 VcV_c. 此时,每个 token 的 KV cache 占用为 2ghd2gh_d. 对于 MQA,其每个 token 的 KV cache 占用为 2hd2h_d.

        对于 MLA, 其表达式如下所示

        O=c=1m(j=1Nϕ(xSQQc(xjSKKc)Td)xjSVVc)OcT=c=1m(j=1Nϕ(x(SQQcKcTSKT)xjTd)xjSVVcOcT)\begin{aligned} O &= \sum_{c=1}^m\left(\sum_{j=1}^N\phi\left(\frac{xS_QQ_c(x_jS_KK_c)^T}{\sqrt{d}}\right)x_jS_VV_c\right)O_c^T\\ &= \sum_{c=1}^m\left(\sum_{j=1}^N\phi\left(\frac{x(S_QQ_cK_c^TS_K^T)x_j^T}{\sqrt{d}}\right)x_jS_VV_cO_c^T\right) \end{aligned}

        其中,SQ,SK,SVRd×CS_Q,S_K,S_V\in\mathbb{R}^{d\times C} 在所有的 heads 中是共享的,Qc,Kc,VcRC×hdQ_c,K_c,V_c\in\mathbb{R}^{C\times h_d} 是每个 head 的 query, key, value projection layer 的参数, 是 latent factorization 的维度。与 FPBA 相比,我们可以看到,MLA 实际上是在 d/md/m 个 head 上共享了参数,其中,WcW_cUcU_c 分别由秩为 的低秩分解 SQQcKcTSKTS_QQ_cK_c^TS_K^T 以及 SVVcOcTS_VV_cO_c^T 近似。尽管模型中 C>hdC>h_d, 但是最终的 rank 仍然是 hdh_d, 因此模型的表现也就受到了限制。

        Method

        对已有的 attention 分析之后,作者认为,要提高模型的表现,attention 需要做到亮点:

        1. 最小化 KV cache 占用和参数量
        2. attention 的 capacity 尽可能接近 FPBA

        基于这两个原则,作者提出了 MFA, MFA 主要依赖三个策略:

        1. 提升 attention heads 的 head dimension, 通过提高 head dimension, 我们可以有效提高 attention head 的表达能力
        2. 使用矩阵分解来降低参数量
        3. 使用单一的 KV head 来降低 KV cache 内存占用

        最终,MFA 的表达式如下所示

        O=c=1m(j=1Nϕ(xSQQc(xjSK)Td)xjSV)OcT=c=1m(j=1N(x(SQQcSKT)xjTd)xjSVOcT)\begin{aligned} O &= \sum_{c=1}^m\left(\sum_{j=1}^N\phi\left(\frac{xS_QQ_c(x_jS_K)^T}{\sqrt{d}}\right)x_jS_V\right)O_c^T\\ &= \sum_{c=1}^m\left(\sum_{j=1}^N\left(\frac{x(S_QQ_cS_K^T)x_j^T}{\sqrt{d}}\right)x_jS_VO_c^T\right) \end{aligned}

        其中 SQ,SK,SVRd×CS_Q,S_K,S_V\in\mathbb{R}^{d\times C} 是所有的 attention head 所共享的,Qc,OcRC×CQ_c,O_c\in\mathbb{R}^{C\times C} 是每个 head 的 query up projection 和 output projection, CC 是 latent factorization 的维度。

        在 inference 的时候,由于我们只需要保存 xjSKx_jS_KxjSVx_jS_V, 因此所需要的 KV cache size 为 2C2C. 与 FPBA 相比,MFA 分别使用 SQQcSKTS_QQ_cS_K^TSVOcTS_VO_c^T 来近似 WcW_cUcU_c, 近似矩阵的 rank 为 CC. 由于 C>dC>d, 因此其表达能力也更强,MFA 有如下优势:

        1. scalable head count: MFA 可以支持使用更多的 attention heads, 每增加一个 heads, 所需要的额外参数为 2C22C^2. 并且,增加 attention heads 个数不会增加 KV cache 占用
        2. enhanced head expressiveness: MFA 近似矩阵的 rank 为 C>dC>d, 因此表达能力更强
        3. Compatibility with position encodings: MFA 可以无缝集成 position encoding.

        为了进一步降低 MFA 的 KV cache 占用,作者提出了 MFA-Key-Reuse (MFA-KA). 核心思想是使用 SKS_K 来表示 SVS_V, 这样可以额外降低 50%50\% 的 KV cache 占用,表示方法如下所示

        SV=SK+αNSK=(I+diag(α)N)SKS_V = S_K + \alpha\odot NS_K = (I +\mathrm{diag}(\alpha)N)S_K

        其中 NRN×NN\in\mathbb{R}^{N\times N}, αRC\alpha\in\mathbb{R}^C.

        最终,MFA, MFA-KR 与 GQA 的对比如下图所示

        Comparison of MFA with GQA

        不同 attention 的量化对比如下表所示

        MethodKV CacheParameterHeadsFactor. rank per headShared latent subspace Dim.Total effec. rank
        FPBA2d22d^22d32d^3ddddddd2d^2
        MHA2d2d4d24d^2nnhdh_dddnhdnh_d
        MQA2hd2h_d(2+2/n)d2(2 + 2/n)d^2nnhdh_dhdh_dnhdnh_d
        GQA2ghd2gh_d(2+2g/n)d2(2 + 2g/n)d^2nnhdh_dghdgh_dnhdnh_d
        MLA2C2C5dC+d25dC + d^2mmhdh_dCCmhdmh_d
        MFA2C2C3Cd+2mC23Cd + 2mC^2mmCCCCmCmC

        Code

        class Step3vAttention(nn.Module):
            def __init__(self, config: Step3VLConfig, layer_idx):
                super().__init__()
                self.config = config
                self.layer_idx = layer_idx
                self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
                self.num_key_value_heads = 1
                self.total_num_kv_heads = self.num_key_value_heads
                self.num_attention_heads = config.num_attention_heads
                self.num_key_value_groups = config.num_attention_heads // self.num_key_value_heads
                self.q_size = getattr(config, "share_q_dim", self.head_dim)
                self.kv_size = self.num_key_value_heads * self.head_dim
                self.scaling = self.head_dim**-0.5
                self.is_causal = True
        
                self.q_proj = nn.Linear(config.hidden_size, self.q_size , bias=False)
                self.k_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
                self.v_proj = nn.Linear(config.hidden_size, self.head_dim, bias=False)
                self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
                # query down projection normalization
                self.inter_norm = Step3vRMSNorm(self.q_size, eps=config.rms_norm_eps)
                # query up projection
                self.wq = nn.Linear(self.q_size, self.head_dim * self.num_attention_heads, bias=False)
        
            def forward(
                self,
                hidden_states: torch.Tensor,
                position_embeddings: Tuple[torch.Tensor, torch.Tensor],
                attention_mask: Optional[torch.Tensor],
                past_key_value: Optional[Cache] = None,
                cache_position: Optional[torch.LongTensor] = None,
                **kwargs: Unpack[FlashAttentionKwargs],
            ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
                input_shape = hidden_states.shape[:-1]
        
                query_states = self.q_proj(hidden_states)
                key_states = self.k_proj(hidden_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)
                value_states = self.v_proj(hidden_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)
        
                query_states = self.inter_norm(query_states)        
                query_states = self.wq(query_states).view((*input_shape, -1, self.head_dim)).transpose(1, 2)
                
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        
                ...
        

        Conclusion

        作者在本文中提出了 MFA 以及 MFA-KR, 一个在 KV cache 有限的条件下最大限度提高 attention 表达能力的 attention 机制。

          DeepSeek 在 2024 年 5 月提出了 multi-head latent attention (MLA), 用于提高 attention 的 Inference 效率

          Introduction

          传统的 multi head attention (MHA) 虽然效果好,但是在 inference 时,其 KV cache 会变成瓶颈,影响推理效率。为了解决这个问题,已有的工作如 MQAGQA 通过共享权重来减少 KV cache 内存占用,但是结果发现模型的表现也会降低。

          为了解决这个问题,作者提出了 multi-head latent attention (MLA), 来压缩 KV cache.

          MHA

          dd 为 hidden size, nhn_h 为 attention heads 的个数,\ell 为 transformer layer 的层数,dhd_h 为每个 head 的 dimension, htRdh_t\in\mathbb{R}^d 为 attention layer 中第 tt 个 token 对应的 hidden states。对于标准的 MHA, 我们首先计算 Q, K, V 如下:

          qt=WQht,kt=WKht,vt=WVhtq_t=W^{Q}h_t,\quad k_t=W^Kh_t,\quad v_t = W^Vh_t

          其中,WQ,WK,WVRdhnh×dW^Q,W^K,W^V\in\mathbb{R}^{d_hn_h\times d} 分别为 query, key, value projection layer 的权重。接下来 MHA 的计算方式如下

          ot,i=j=1tsoftmaxj(qt,iTkj,idh)vj,i,ut=WO[ot,1;ot,2;,;ot,nh]\begin{aligned} o_{t,i} &= \sum_{j=1}^t\mathrm{softmax}_j\left(\frac{q_{t,i}^Tk_{j,i}}{\sqrt{d_h}}\right)v_{j,i},\\ u_t&= W^O[o_{t,1};o_{t,2};\dots,;o_{t,n_h}] \end{aligned}

          其中 qt=[qt,1;qt,2;,;qt,nh]q_t=[q_{t,1};q_{t,2};\dots,;q_{t,n_h}], kt=[kt,1;kt,2;,;kt,nh]k_t=[k_{t,1};k_{t,2};\dots,;k_{t,n_h}], vt=[vt,1;vt,2;,;vt,nh]v_t=[v_{t,1};v_{t,2};\dots,;v_{t,n_h}]. WORd×dhnhW^O\in\mathbb{R}^{d\times d_hn_h} 为 output projection 的权重。在 inference 阶段,每个 token 需要缓存其 key 以及 value 对应的值,从而每个 token 的 kv cache 占用为 2nhdh2n_hd_h\ell. 当序列长度过大时,KV cache 会影响整体的 inference efficiency.

          MQA & GQA

          MQA 通过在所有的 heads 中共享 key 和 value 来实现降低 kv cache 的作用,在 MQA 中,WK,WVRdh×dW^K, W^V\in\mathbb{R}^{d_h\times d}, 在计算时,对应的 ktk_tvtv_t 通过广播机制参与 attention 的计算。此时,KV cache 占用为 MHA 的 1/nh1/n_h, 即 2dh2d_h\ell.

          但是,MQA 的问题是表达能力太弱(表现差),因此后续 GQA 进行了改进,GQA 在 MQA 和 MHA 之间进行了权衡,即将 heads 分为若干个 group, 每个 group 中共享 key 和 value, 即 WK,WVRngdh×dW^K, W^V\in\mathbb{R}^{n_gd_h\times d}, 这里 ngn_g 是 group 个数,在计算 attention 时,key 和 value 在 group 内部共享,此时,GQA 的 KV cache 占用是 MQA 的 ngn_g 倍,即 2ngdh2n_gd_h\ell.

          这部分具体介绍见 MQAGQA.

          MLA

          MLA 的架构图如下所示

          MLA architecture (sourced from MHA2MLA)

          MLA 使用 low-rank joint compression 来压缩 key 以及 value 的 KV cache:

          ctKV=WDKVht,ktC=WUKctKV,vtC=WUVctKVc_t^{KV} = W^{DKV}h_t,\quad k_t^C = W^{UK}c_t^{KV}, v_t^C = W^{UV}c_t^{KV}

          这里 ctKVRdcc_t^{KV}\in\mathbb{R}^{d_c} 为 key 以及 value 压缩后的 latent vector. dcdhnhd_c \ll d_hn_h 为 KV cache compression dimension. WDKVRdc×dW^{DKV}\in\mathbb{R}^{d_c\times d} 为 down projection matrix, 这个矩阵是 key 和 value 共享的,WUK,WUVRdhnh×dcW^{UK}, W^{UV}\in\mathbb{R}^{d_hn_h\times d_c} 为 key, value 对应的 up projection matrix.

          另外,为了减少训练时的 activation memory, 作者对于 query 同样也执行了 low-rank compression, 压缩方式如下

          ctQ=WDQht,qtC=WUQctQc_t^Q = W^{DQ}h_t,\quad q_t^C = W^{UQ}c_t^Q

          其中 ctQRdcc_t^Q\in\mathbb{R}^{d_c'} 为 query 压缩后的 latent vector, dcdhnhd_c' \ll d_hn_h 为 query compression dimension, WDQRdc×dW^{DQ}\in\mathbb{R}^{d_c'\times d}, WUQRdhnh×dcW^{UQ}\in\mathbb{R}^{d_hn_h\times d_c'} 分别时 down projection, up projection matrix.

          最后 attention 的计算与 MHA 保持一致:

          ot,i=j=1tsoftmaxj(qt,iTkj,idh)vj,i,ut=WO[ot,1;ot,2;,;ot,nh]\begin{aligned} o_{t,i} &= \sum_{j=1}^t\mathrm{softmax}_j\left(\frac{q_{t,i}^Tk_{j,i}}{\sqrt{d_h}}\right)v_{j,i},\\ u_t&= W^O[o_{t,1};o_{t,2};\dots,;o_{t,n_h}] \end{aligned}

          在推理的时候,我们只需要缓存 ctKVc_t^{KV} 即可,这样每个 token 的 KV cache 为 dcd_c\ell. 并且在 inference 时,我们可以将 WUKW^{UK}WQW^{Q} 融合在一起,将 WUVW^{UV}WOW^{O} 融合在一起,也就是说我们不需要显式的计算出 ktk_t 以及 vtv_t, 即

          qtTkt=(WUQctQ)T(WUKctKV)=(ctQ)T((WUQ)TWUK)ctKVq_t^Tk_t = (W^{UQ}c_t^Q)^T(W^{UK}c_t^{KV}) = (c_t^Q)^T((W^{UQ})^TW^{UK})\boxed{c_t^{KV}}

          以及

          ut=WO[ot,1;ot,2;,;ot,nh]=i=1tWiOot,i=i=1tWiOj=1tsoftmaxj(qt,iTkj,idh)vj,i=i=1tWiOj=1tsoftmaxj(qt,iTkj,idh)WiKVctKV=i=1t(WiOWiKV)j=1tsoftmaxj(qt,iTkj,idh)ctKV\begin{aligned} u_t &= W^O[o_{t,1};o_{t,2};\dots,;o_{t,n_h}] \\ &= \sum_{i=1}^tW_i^Oo_{t,i}\\ &= \sum_{i=1}^tW_i^O\sum_{j=1}^t\mathrm{softmax}_j\left(\frac{q_{t,i}^Tk_{j,i}}{\sqrt{d_h}}\right)v_{j,i}\\ &= \sum_{i=1}^tW_i^O\sum_{j=1}^t\mathrm{softmax}_j\left(\frac{q_{t,i}^Tk_{j,i}}{\sqrt{d_h}}\right)W_i^{KV}c_t^{KV}\\ &= \sum_{i=1}^t(W_i^OW_i^{KV})\sum_{j=1}^t\mathrm{softmax}_j\left(\frac{q_{t,i}^Tk_{j,i}}{\sqrt{d_h}}\right)\boxed{c_t^{KV}} \end{aligned}

          这里 WO=[W1O,,WnhO]W^O = [W^O_1,\dots,W^O_{n_h}], WUV=[W1KV;;WnhKV]W^{UV}=[W^{KV}_1;\dots;W^{KV}_{n_h}], WiORd×dhW_i^O\in\mathbb{R}^{d\times d_h}, WiUVRdh×dcW^{UV}_i\in\mathbb{R}^{d_h\times d_c}.

          Decoupled Position Embedding

          接下来,作者介绍了如何解决 RoPE 不相容的问题。如果说我们直接在 ktCk_t^C 上进行 RoPE, 那么我们有

          qtTkt=(RmWUQctQ)T(RnWUKctKV)=(ctQ)T((WUQ)TRmnWUK)ctKVq_t^Tk_t = (R_mW^{UQ}c_t^Q)^T(R_nW^{UK}c_t^{KV}) = (c_t^Q)^T((W^{UQ})^TR_{m-n}W^{UK})\boxed{c_t^{KV}}

          此时,我们没有办法将 WUKW^{UK} 吸收到 WUQW^{UQ} 中,这样就导致在 inference 时我们必须重新计算所有 prefix token 对应的 key, 这显然会降低 inference efficiency

          为了解决这个问题,作者使用了partial RoPE的技巧,即将query和key拆解为NoPE以及RoPE两部分,前者由MLA产生,后者携带位置信息。RoPE部分包括query qt,iRRdhRq_{t,i}^R\in\mathbb{R}^{d_h^R} 以及一个共享的 key ktRRdhRk_t^R\in\mathbb{R}^{d_h^R}, 其中 dhRd_h^R 是 decoupled query 以及 decoupled key 的 head dimension.

          [!remark] 这里 key 对应的 RoPE 共享的原因是这部分信息也需要使用 KV cache 进行缓存,通过共享可以降低 KV cache 占用;而 query 对应的 RoPE 不共享的原因是提高 head 的表达能力,与 MHA 原理一致。

          对应 MLA 的计算公式如下

          qtR=RoPE(WQRctQ)ktR=RoPE(WKRht)qt,i=[qt,iC;qt,iR]kt,i=[kt,iC;ktR]ot,i=j=1tsoftmaxj(qt,iTkj,idh+dhR)vj,iC,ut=WO[ot,1;ot,2;,;ot,nh]\begin{aligned} q_t^R&=\mathrm{RoPE}(W^{QR}c_t^Q)\\ k_t^R &= \mathrm{RoPE}(W^{KR}h_t)\\ q_{t,i} &= [q_{t,i}^C;q_{t,i}^R]\\ k_{t,i} &= [k_{t,i}^C;k_{t}^R]\\ o_{t,i} &= \sum_{j=1}^t\mathrm{softmax}_j\left(\frac{q_{t,i}^Tk_{j,i}}{\sqrt{d_h+d_h^R}}\right)v_{j,i}^C,\\ u_t&= W^O[o_{t,1};o_{t,2};\dots,;o_{t,n_h}] \end{aligned}

          其中 qtR=[qt,1R;qt,2R;,;qt,nhR]q_t^R=[q_{t,1}^R;q_{t,2}^R;\dots,;q_{t,n_h}^R], WQRRdhRnh×dcW^{QR}\in\mathbb{R}^{d_h^Rn_h\times d_c'}, WKRRdhR×dW^{KR}\in\mathbb{R}^{d_h^R\times d} . RoPE()\mathrm{RoPE}(\cdot) 只执行 RoPE 矩阵乘法的操作。

          在这种情形下,attention 的计算如下所示

          qt,iTkt,i=[qt,iC;qt,iR]T[kt,iC;ktR]=(qt,iC)Tkt,iC+(qt,iR)TktR\begin{aligned} q_{t,i}^Tk_{t,i} &= [q_{t,i}^C;q_{t,i}^R]^T[k_{t,i}^C;k_{t}^R]\\ &= (q_{t,i}^C)^Tk_{t,i}^C + (q_{t,i}^R)^Tk_{t}^R \end{aligned}

          可以看到,现在 attention 的计算分为了两部分,一部分是 MLA 自身的计算,这部分计算前面已经证明可以通过矩阵吸收的方式来进行优化,第二部分是关于 RoPE 部分的计算,这部分计算量不是很大

          最终,MLA 完整的计算公式如下

          ctQ=WDQht[qt,1C;;qt,nhC]=qtC=WUQctQ[qt,1R;;qt,nhR]=qtR=RoPE(WQRctQ)qt,i=[qt,iC;qt,iR]ctKV=WDKVht[kt,1C;;kt,nhC]=ktC=WUKctKVktR=RoPE(WKRht)kt,i=[kt,iC;ktR][vt,1C;;vt,nhC]=vtC=WUVctKVot,i=j=1tsoftmaxj(qt,iTkj,idh+dhR)vj,iC,ut=WO[ot,1;ot,2;,;ot,nh]\begin{aligned} c_t^Q=W^{DQ}h_t\\ [q_{t,1}^C;\dots;q_{t,n_h}^C]=q_t^C&= W^{UQ}c_t^Q\\ [q_{t,1}^R;\dots;q_{t,n_h}^R]=q_t^R&= \mathrm{RoPE}(W^{QR}c_t^Q)\\ q_{t,i} &= [q_{t,i}^C;q_{t,i}^R]\\ \boxed{c_t^{KV}} &= W^{DKV}h_t\\ [k_{t,1}^C;\dots;k_{t,n_h}^C]=k_t^C&= W^{UK}c_t^{KV}\\ \boxed{k_t^R} &= \mathrm{RoPE}(W^{KR}h_t)\\ k_{t,i} &= [k_{t,i}^C;k_{t}^R]\\ [v_{t,1}^C;\dots;v_{t,n_h}^C]=v_t^C&= W^{UV}c_t^{KV}\\ o_{t,i} &= \sum_{j=1}^t\mathrm{softmax}_j\left(\frac{q_{t,i}^Tk_{j,i}}{\sqrt{d_h+d_h^R}}\right)v_{j,i}^C,\\ u_t&= W^O[o_{t,1};o_{t,2};\dots,;o_{t,n_h}] \end{aligned}

          在 inference 时,decoupled key 也需要被缓存,因此 DeepSeek-V2 每个 token 所需要的 KV cache 为 (dc+dhR)(d_c+d_h^R)\ell, 框选的部分即为 Inference 阶段需要缓存的内容

          MLA 与 MHA, MQA, GQA 的对比如下图所示

          Comparison of different attention mechanisms

          Comparison of KV Cache

          接下来,作者对比了不同 attention 机制的 KV cache, 结果如下表所示

          Attention MechanismKV Cache per Token (# Element)Capability
          Multi-Head Attention (MHA)2nhdh2n_hd_h\ellStrong
          Grouped-Query Attention (GQA)2ngdh2n_gd_h\ellModerate
          Multi-Query Attention (MQA)2dh2d_h\ellWeak
          MLA (Ours)(dc+dhR)9/2dh(d_c+d_h^R)\ell\approx 9/2d_h\ellStronger

          这里作者将 dcd_c 设置为 4dh4d_h, dhRd_h^R 设置为 dh/2d_h/2, 因此得到了上面的 9/2dh9/2d_h\ell 的近似。与 GQA 相比,相当于 MLA 使用了 2.25 个 group, 但是可以得到更强的效果。

          为了避免 low-rank compression 以及 fine-grained expert segmentation 对输出的 scale 产生影响,作者对 compressed latent vectors ctQ,ctKVc_t^Q, c_t^{KV} 进行了 normalization.

          Code

          首先是代码变量与公式变量的对应关系

          code namevariable nameValue
          hidden_sizedd5120
          kv_lora_rankdcd_c512
          q_lora_rankdcd_c'1536
          qk_nope_head_dimdhd_h128
          qk_rope_head_dimdhRd_h^R64
          v_head_dimdhd_h128
          num_attention_headsnhn_h128

          在具体实现时,作者对计算过程进行了优化,具体就是先合并计算然后通过 split 进行拆分,这部分策略应用于三个部分:

          [qtcqtR]=[WUQWQR]WDQht[ctKVktR]=[WDKVWKR]ht[ktcvtc]=[WUKWUV]ctKV\begin{aligned} \begin{bmatrix} q_t^c\\ q_t^R \end{bmatrix} &= \begin{bmatrix} W^{UQ}\\ W^{QR} \end{bmatrix}W^{DQ}h_t\\ \begin{bmatrix} c_t^{KV}\\ k_t^R \end{bmatrix} &= \begin{bmatrix} W^{DKV}\\ W^{KR} \end{bmatrix}h_t\\ \begin{bmatrix} k_t^c\\ v_t^c \end{bmatrix} &= \begin{bmatrix} W^{UK}\\ W^{UV} \end{bmatrix}c_t^{KV} \end{aligned}

          代码如下所示

          class DeepseekV2Attention(nn.Module):
              def __init__(self, config, layer_idx):
                  # d_h + d_h^R
                  self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
                  # W^{DQ}
                  self.q_a_proj = nn.Linear(
                      self.hidden_size, config.q_lora_rank, bias=config.attention_bias
                      )
                  # [W^{UQ}; W^{QR}]
                  self.q_b_proj = nn.Linear(
                      config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False)
                  # [W^{DKV}; W^{KR}]
                  self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attention_bias)
                  # [W^{UK}; W^{UV}]
                  self.kv_b_proj = nn.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False)
                  # W^O
                  self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias)
                  
              def forward(self, hidden_states, ...):
                  # [q_t^c; q_t^R]
                  q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
                  q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
                  q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
                  
                  # [c_t^{KV}; k_t^R]
                  compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
                  compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
                  k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
          
                  # [k_t^c; v_t^c]
                  kv = (self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
                  k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
          
                  # q_t^R, k_t^R
                  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
          
                  # q_{t, i}
                  query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
                  query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
                  query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
                  
                  # k_{t, i}
                  key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
                  key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
                  key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
                  
                  # Q^TK
                  attn_weights = (torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale)
                  # softmax(...) in FP32
                  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                  
                  # o_{t, i}
                  attn_output = torch.matmul(attn_weights, value_states)
                  
                  # u_t
                  attn_output = self.o_proj(attn_output)
                          
                  return attn_output, attn_weights, past_key_value
                  
          

          参数量计算

          首先,我们结合 DeepSeek-V2 的 config 计算一下 MLA 部分的参数量:

          MatrixParametersvaluesratio
          WDKVW^{DKV}ddcdd_c26214401.91%
          WUKW^{UK}dhnhdcd_hn_hd_c83886086.12%
          WUVW^{UV}dhnhdcd_hn_hd_c83886086.12%
          WDQW^{DQ}dcdd_c'd78643205.74%
          WUQW^{UQ}dhnhdcd_hn_hd_c'2516582418.37%
          WKRW^{KR}dhRdd_h^Rd3276800.24%
          WQRW^{QR}dhRdd_h^Rd3276800.24%
          WOW^{O}ddhnhdd_hn_h8388608061.24%
          Totald(dc+dc+2hhR)+dhnh(2dc+dc+d)d(d_c+d_c'+2h_h^R)+d_hn_h(2d_c+d_c'+d)136970240100%

          我们接下来对比一下各个模型架构之间 attention 部分的参数量,可以看到与 MHA 一致,大部分参数量都集中在最后的 Output projection layer 上

          Experiments

          作者首先对比了 MHA, GQA, MQA 的表现,作者基于一个 7B 的 dense 模型,使用 1.33T token 进行训练,实验结果如下

          Benchmark (Metric)# ShotsMQAGQA(8 Groups)MHA
          # Params-7.1B6.9B6.9B
          BBH (EM)3-shot33.235.637.0
          MMLU (Acc.)5-shot37.941.245.2
          C-Eval (Acc.)5-shot30.037.742.9
          CMMLU (Acc.)5-shot34.638.443.5

          实验结果显示,MHA 的表现显著优于 GQA 和 MQA. 这说明了 MQA 和 GQA 虽然减少了 KV cache 的占用,但是相应地,它们对应的表现也有所降低。

          接下来,作者对比了 MLA 和 MHA 的表现,实验结果如下

          Benchmark (Metric)# ShotsMHAMLAMHAMLA
          # Activated Params-2.5B2.4B25.0B21.5B
          # Total Params-15.8B15.7B250.8B247.4B
          KV Cache per Token (# Element)-110.6K15.6K860.2K34.6K
          BBH (EM)3-shot37.939.046.650.7
          MMLU (Acc.)5-shot48.750.057.559.0
          C-Eval (Acc.)5-shot51.650.957.959.2
          CMMLU (Acc.)5-shot52.353.460.762.5

          可以看到,MLA 的表现比 MHA 的表现更好,并且 KV cache 也更少。

          Conclusion

          作者提出了 MLA, 一个基于 low rank compression 的注意力机制,通过将 key value vector 压缩到低维空间,MLA 可以有效降低 Inference latency, 作者通过实现证明 MLA 的表现可以与 MHA 相比,并且 KV cache 更小。

            Google 在 2019 年提出了 multi-query attention (MQA), 用于解决 MQA 内存带宽瓶颈问题。

            Method

            Background

            对于 multi-head attention, 我们假设其 hidden size 为 dd, 有 hh 个 heads, 每个 head 的 size 为 dh=d/hd_h=d/h, 输入 sequence 长度为 nn, batch size 为 dd. 则总的 arithmetic operations 为 O(bnd2)O(bnd^2). 总的内存访问量为 O(bnd+bhn2+d2)O(bnd + bhn^2+d^2), 第一项是 Q,K,VQ,K,V 的内存占用(Q,K,VQ,K,V 分别是 query, key 和 value layer 的输出),第二项是 attention score 的占用,第三项是 query, key 和 value layer 的权重。

            因此,其 Memory Access Ratio (MAR), 也就是内存访问量 与 arithmetic operations 之比为

            O(1k+1bn)O\left(\frac1k + \frac{1}{bn}\right)

            对于现代的 GPU 来说,其一般算力比较强,但是内存访问带宽相对较慢,因此我们希望 MAR 越低越好,以充分发挥 GPU 的算力。

            MHA Analysis

            在训练的时候,由于我们知道 ground truth sequence, 因此我们可以并行计算。但是在 inference 的时候,我们只能 token-by-token 进行计算,因此我们分析一下 token-by-token 场景下的 MAR

            我们整体的 arithmetic operations 还是 O(bnd2)O(bnd^2).

            但是,现在我们要调用 nn 次 multi-head attention, 因此我们总的内存访问量为 O(bn2d+nd2)O(bn^2d + nd^2), 第一项是 KKVV , 第二项是 query, key 和 value layer 的权重。

            这种情况下,MAR 就变成了

            O(nd+1b)O\left(\frac{n}{d} + \frac{1}{b}\right)

            ndn\approx d 或者 b1b\approx 1 时,MAR 就非常接近于 1,意味着内存带宽成了一个主要的瓶颈。为了解决这个问题,我们有两种做法:

            1. 提升 batch size bb, 也就是同时 inference 多次
            2. 降低 KKVV 的大小

            MQA

            MQA 的做法就是第二种,也就是降低 KKVV 的大小,但是 K,VK,V 分别是 key 和 value layer 的输出,要降低输出大小,我们就必须改变 key 和 value layer 的 size。基于这个考虑,作者在所有的 head 上共享了一个 key 和 value layer,也就是说,原来

            self.k_proj = nn.Linear(hidden_size, num_heads * head_dim) # (d, n*d_h)
            self.v_proj = nn.Linear(hidden_size, num_heads * head_dim) # (d, n*d_h)
            

            现在在 MQA 里,其变成了

            self.k_proj = nn.Linear(hidden_size, head_dim) # (d, n*d_h)
            self.v_proj = nn.Linear(hidden_size, head_dim) # (d, n*d_h)
            

            MQA Analysis

            我们还是在 token-by-token 的场景下进行分析。

            我们整体的 arithmetic operations 还是 O(bnd2)O(bnd^2).

            调用 nn 次 multi-query attention 的总的内存访问量为 O(bnd+bn2dh+nd2)O(bnd +bn^2d_h+ nd^2), 第一项是 qq , 第二项是 KKVV , 第三项是是 query, key 和 value layer 的权重。

            此时,MAR 变成了

            O(1d+ndh+1b)O\left(\frac{1}{d} + \frac{n}{dh}+\frac{1}{b}\right)

            现在,我们就将 n/dn/d 这一项给降低了 hh 倍。如果我们的 batch size 足够大的话,理论上 MQA 应该能极大提高整体的计算效率。

            Conclusion

            MQA 为了追求极致的内存带宽占用,选择使用单一的 key 和 value, 来极大提高 inference 的 decoding 效率,但是后来在 GQA 中验证发现,MQA 虽然非常高效,但是其表现比较差,这也是后来没有得以应用的原因。

              DeepSeek 在 25 年 1 月提出了 Natively trainable Sparse Attention (NSA), 一个软硬件结合的稀疏注意力机制,NSA 可以在提高模型推理效率的同时提高计算效率。

              Introduction

              现有的大模型主要是基于 Transformer 提出的 softmax attention, 其主要问题在于随上下文长度增加,其 latency 也上升更快。理论估计,对于 64k 上下文长度的输出,softmax attention 部分的计算占 70%80%70\%\sim80\% 的 latency.

              为了解决 softmax 的 high latency 问题,,一个做法就是使用稀疏注意力机制,如 MInference 等,但是这些系数注意力机制大多没有实际部署,且它们一般只在 inference 阶段使用

              作者认为解决这个问题有两个挑战:

              1. Hardware-aligned inference speedup: 降低 inference latency 需要算法与硬件结合,不能只关注算法层面的改进
              2. Training-aware algorithm design: 需要在训练阶段也支持算法,从而可以降低训练的算力消耗并且保持模型的表现

              为了解决这两个问题,作者就提出了 natively trainable sparse attention (NSA) 架构。NSA 通过将 key 和 value 分割为不同的 block, 然后基于三种 path: compressed coarse-grained tokens, selectively retrained fine-grained tokens 以及 sliding windows for local contextual information 来进行处理和过滤。NSA 提出了两点观点改进:

              1. Hardware-aligned system: 优化了 blockwise sparse attention 来平衡 arithmetic intensity.
              2. Training-aware design: 支持端到端的训练和部署

              Method

              Overview

              作者首先回顾了 attention 的定义如下:

              ot=Attn(qt,k:,t,v:,t)=i=1tαt,it,iαt,ivi, αt,i=exp(qtTkidk)\mathbf{o}_t=\mathrm{Attn}(\mathbf{q_t},\mathbf{k}_{:,t}, \mathbf{v}_{:,t})=\sum_{i=1}^t \frac{\alpha_{t,i}}{\sum_{t,i}\alpha_{t,i}}\mathbf{v_i},\ \alpha_{t,i} = \exp\left(\frac{\mathbf{q_t}^T\mathbf{k}_{i}}{\sqrt{d_k}}\right)

              其中 qtRdk\mathbf{q_t}\in\mathbb{R}^{d_k}.

              接下来是 Arithmetic Intensity. Arithmetic intensity 指的是 FLOPs 与内存访问次数之比。由于现在的 GPU 都是计算密集型设备,理想情况下应该是 Arithmetic intensity 越高越好。

              对于 causal self-attention 来说,在训练以及 prefilling 阶段,由于 batch 较大,因此整体的 Arithmetic intensity 较高,因而这两个阶段是 computer-bound. 但是在 decoding 阶段,由于其 token-by-token generation 的性质,每次生成新的 token 时都需要重新加载 KV cache, 因而是 memory-bound.

              从而我们的优化目标也变得不一致:在训练阶段,我们希望降低计算消耗,而在推理 (decodng) 阶段,我们希望降低内存访问次数。

              基于这两个目标,作者提出了使用 k:,t,v:,t\mathbf{k}_{:,t}, \mathbf{v}_{:,t} 的子集 K~t,V~t\tilde{K}_t, \tilde{V}_t 来参与计算,其对应的 attention 如下所示

              K~t=fK(qt,k:,t,v:,t),V~t=fV(qt,k:,t,v:,t),ot=Attn(qt,K~t,V~t)\tilde{K}_t=f_K(\mathbf{q_t},\mathbf{k}_{:,t}, \mathbf{v}_{:,t}), \tilde{V}_t=f_V(\mathbf{q_t},\mathbf{k}_{:,t}, \mathbf{v}_{:,t}), \mathbf{o}_t=\mathrm{Attn}(\mathbf{q_t},\tilde{K}_t, \tilde{V}_t)

              我们还可以结合不同的方法来进行组合:

              ot=cCgtcAttn(qt,K~tc,V~tc)\mathbf{o}_t^*=\sum_{c\in\mathcal{C}}g_t^c\mathrm{Attn}(\mathbf{q_t},\tilde{K}_t^c, \tilde{V}_t^c)

              作者在本文中使用了三种方法 C={cmp,slc,win}\mathcal{C}=\{\mathrm{cmp},\mathrm{slc},\mathrm{win}\}, 分别代表了 compression, selection 以及 sliding window, gtc[0,1]g_t^c\in[0,1] 代表了不同方法对应的 gating score, 类似于 MoE 的 gating layer, gtcg_t^c 由一个 MLP 和一个 sigmoid activation 生成。最终 NSA 的架构如下图所示

              Overview of NSA architecture

              作者定义 NtN_t 代表参与计算的 KV 的总个数:

              Nt=cCsize[K~tc].N_t = \sum_{c\in\mathcal{C}} \mathrm{size}[\tilde{K}_t^c].

              作者使用了一个较高的 sparsity ratio 来保证 NttN_t \ll t.

              Design

              接下来作者分别介绍了每一部分的设计

              Token Compression

              对于 token compression, 其定义如下:

              K~tcmp=fKcmp(k:,t)={ϕ(kid+1:id+l)0itld}Rdk×tld\tilde{K}_t^{\mathrm{cmp}} = f_K^{\mathrm{cmp}}(\mathbf{k}_{:,t})=\left\{\phi(\mathbf{k}_{id+1:id+l})\mid 0\leq i\leq \left\lfloor\frac{t-l}{d}\right\rfloor\right\}\in\mathbb{R}^{d_k\times \left\lfloor\frac{t-l}{d}\right\rfloor}

              其中 ll 是 block size, dd 是 sliding stride, ϕ:Rl×dkRkd\phi:\mathbb{R}^{l\times d_k}\to \mathbb{R}^d_k 是一个 MLP 用于将 block key 映射为一个单一的 key. 对于 V~tcmp\tilde{V}_t^{\mathrm{cmp}} 作者也使用了类似的做法。

              Token Selection

              仅使用 compressed token 的话,可能会丢失一些细粒度的信息。因此,作者额外提出了 token selection 机制来解决这个问题。

              作者使用的做法是 blockwise selection. 这样做的原因有两点:

              1. hardware efficiency. 这样做的原因是 GPU 访问内存是在 block 层面进行的,因而更加高效
              2. inherent distribution patterns of attention scores. MInference 证明了 attention score 在空间上存在连续性。即相邻的 key 对应的重要性非常相似

              为了实现 block-wise selection, 作者首先将 key value sequences 分割为 blocks, 然后针对每个 blocks 分配 Importance score.

              作者首先介绍了如何计算不同 block 的 importance score.

              如果 selection block size 与 compression block size ,即 l=ll'=l 相同的话,则我们可以直接用 compression block 提供的信息:

              ptcmp=sotmax(qtTK~tcmp)\mathbf{p}_{t}^{\mathrm{cmp}} = \mathrm{sotmax}\left(\mathbf{q}_t^T\tilde{K}_t^{\mathrm{cmp}}\right)

              其中 ptcmpRtld+1\mathbf{p}_{t}^{\mathrm{cmp}}\in\mathbb{R}^{\left\lfloor\frac{t-l}{d}\right\rfloor+1} 代表了 qt\mathbf{q}_t 和 compressed key K~tcmp\tilde{K}_t^{\mathrm{cmp}} 之间的 attention score.

              如果 lll'\neq l 的话,作者通过空间关系来进行计算,假设 lll\leq l', dld\mid l, dmodld\mod l', 则我们有

              ptslc[j]=m=0l/d1n=0l/d1ptcmp[ldjmn]\mathbf{p}_{t}^{\mathrm{slc}}[j] = \sum_{m=0}^{l'/d-1}\sum_{n=0}^{l/d-1}\mathbf{p}_{t}^{\mathrm{cmp}}\left[\frac{l'}{d}j-m-n\right]

              对于 GQAMQA, 由于其 KV-cache 在 heads 之间共享,因此我们必须保证不同 heads 之间的 consistency, 因此作者提出了 shared importance score 如下:

              ptslc=h=1Hptslc,(h)\mathbf{p}_{t}^{\mathrm{slc}'} = \sum_{h=1}^H\mathbf{p}_{t}^{\mathrm{slc},(h)}

              接下来,对于每个 block 及其对应的 Importance score, 作者保存 top-nn sparse blcoks, 如下所示

              It={irank(ptslc[i])n}K~tslc=Cat[{kil+1:(i+1)liIt}]\begin{aligned} \mathcal{I}_t &= \{i\mid \mathrm{rank}(p_t^{\mathrm{slc}'}[i])\leq n\}\\ \tilde{K}_t^{\mathrm{slc}} &= \mathrm{Cat}\left[\{\mathbf{k}_{il'+1:(i+1)l'}\mid i\in \mathcal{I}_t\}\right] \end{aligned}

              其中 rank()\mathrm{rank}(\cdot) 代表了降序排列的 importance scores. It\mathcal{I}_t 是选择出来的 block indices, Cat()\mathrm{Cat}(\cdot) 表示了 concatenation operation. K~tslcRdk×il\tilde{K}_t^{\mathrm{slc}}\in\mathbb{R}^{d_k\times il'} 代表了选择出来的 key.

              Sliding Window

              为了避免 local pattern 对 compression token 以及 selection token 的学习产生影响,作者额外使用了一个 branch 来学习这个 local pattern. 其具体做法就是维持一个 sliding window 用于最近的若干个 token, 即

              K~twin=ktw:t,V~twin=vtw:t\tilde{K}_t^{\mathrm{win}} = \mathbf{k}_{t-w:t}, \tilde{V}_t^{\mathrm{win}} = \mathbf{v}_{t-w:t}

              这里 ww 是 window size.

              为了进一步避免 shortcut learning, 对于三个 branch 作者提供了不同的 key 和 values

              Kernel Design

              接下来是针对硬件设计进行的优化。由于 flash attention 2 对 compression attention 以及 sliding window attention 已经支持的比较好,作者这里介绍了如何针对 selection attention 进行优化。

              Experiments

              作者构建了一个 27B-A3B 的 MoE 模型,attention 基于 GQA, MoE 基于 DeepSeekMoE. 模型配置如下表所示

              fieldvalue
              layers30
              hidden dimension2560
              head groups4
              attention heads64
              query head dimension192
              value head dimension128
              routed experts72
              shared experts2
              activated experts6
              dense layers1

              NSA 配置如下

              fieldvalue
              ll32
              dd16
              ll'64
              nn16
              ww512

              其中 selection blocks 包含初始的一个 block 以及最近的 2 个 block.

              模型先在 8K 的上下文长度下使用 270B token 进行预训练,接下来在使用 YARN 将模型上下文通过 continual pre-training 以及 SFT 扩展到 32K. 训练过程的损失如下图所示

              Training loss of NSA

              作者从 general performance, long-context performance 以及 CoT reasoning performance 三个层面来评估 NSA 的表现。

              首先是 NSA 与其他 sparse attention 以及 baseline 在通用任务上表现的对比,结果如下图所示

              Performance of NSA on general benchmarks

              接下来是 NSA 在 LongBench 上的表现:

              Performance of NSA on LongBench

              作者还使用了 DeepSeek-R1 中的知识蒸馏方法,结果如下表所示

              Generation token limit819216384
              Full Attention-R0.0460.092
              NSA-R0.1210.146

              上面的结果均验证了 NSA 的有效性

              Analysis

              接下来,作者分析了 NSA 的性质。作者首先对比了 NSA 和 flash attention 2 的训练速度,结果如下图所示

              Performance comparison between NSA and flash attention 2

              可以看到,相比于 flash attention 2, NSA 在 forward 过程和 backward 过程的的效率分别提升了 9 倍和 6 倍。作者认为这是由于两个优点:

              1. NSA 使用了 block-wise memory access, 提高了 tensor core 的利用率
              2. loop scheduling 减少了 KV transfer 时的 kernel 冗余

              作者还对比了不同 attention 的解码速度,在 NSA 中,每次只需要 sld+nl+w\left\lfloor\frac{s-l}{d}\right\rfloor+nl'+w 个 token 就可以完成计算,作者对比不同 attention 所需余姚的 token 如下表所示如下表所示

              Context Length8192163843276865536
              Full attention8192163843276865536
              NSA2048256035845632
              speedup4x6.4x9.1x11.6x

              Discussion

              Conclusion

              作者在本文中提出了 NSA, 一个通过软硬件协同结合 compression, selection 以及 sliding window 的稀疏注意力机制,作者通过实验验证了其有效性。