作者提出了 flashattention, 一个通过降低 multi head attention 内存访问开销来提高 attention 计算效率的方法
Transformer 的 attention 是一个平方度复杂度的算法,这个平方复杂度既体现在时间复杂度上(矩阵乘法),也体现在空间复杂度上(需要存储中间结果)。因此,要降低 attention 的复杂度,我们有两种思路:
- 从时间复杂度上入手,比如使用稀疏 attention 机制或者线性注意力机制
- 从空间复杂度上入手,比如使用 GQA, MQA 等减少内存的访问开销
本文提出的 flashattention 就属于降低空间复杂度的一种做法。作者认为,我们应该设计一种 IO-aware 的 attention 算法,来减少 attention 计算式的内存访问开销,进而提高 attention 的计算效率。
作者首先提到,一个未解决的问题就是:
降低 attention 的内存访问开销是否可以提高 attention 的计算效率?
作者发现,已有的一些工作虽然在理论上降低了 attention 的计算效率,但是在实际中,他们的效果并没有提升太多。作者分析原因认为,已有工作主要关注于降低 FLOPs, 但是忽略了内存访问开销。
因此,作者在本文中就提出了 flashattention, 一个 IO-aware 的 attention 算法,作者通过尽可能降低内存访问开销来提高模型的计算效率。具体做法就是,避免从内存中读写 attention matrix, 作者认为这个目标有两个挑战:
- 计算 softmax 的时候不访问所有的输入
- 在反向传播时不存储中间的 attention matrix
作者提出了两个方法来分别解决这两个问题:
- 作者使用了 tiling 技巧,将 input 分成多个 block, 然后分别进行处理,进而降低 softmax 的内存访问开销
- 作者使用了 recompute 技巧,在反向传播时,重新计算 softmax normalization factor
通过这些改进,我们可以让 attention 运行更快,并且降低内存访问开销。
作者还从理论上分析了 flashattention 的复杂度,提供了理论基础。
作者通过实验验证了 flashattention 的有效性,主要是三点:
- 训练效率更高:相比于 Huggingface 和 Megatron, flashattention 的训练效率提升了 2-3 倍
- 模型的表现更好:相比于 GPT-2, 模型的 perplexity 提升了 0.7 个点左右
- 速度更快:flashattention 比标准的 attention 实现快 3 倍以上
作者首先介绍了以下 GPU 的内存架构,如下图所示
可以看到,GPU 内存可以分为三个层级:
- SRAM: GPU 的寄存器,容量小,但是访问速度极快
- High bandwith memory (HBM): GPU 的高速内存,访问速度较快,容量中等
- DRAM: CPU 内存,容量最大,但是访问速度较慢
接下来作者介绍了 Execution model 的概念,GPU 有多个线程来执行同一个操作(SPMD),这个操作也被称为 kernel, kernel 会从 HBM 中加载输入到 SRAM 中进行计算,然后写回 HBM.
对一个算法,我们可以将其归类为 compute-bound 和 memory-bound 两类, 我们可以用 arithmetic intensity 来进行区分,arithmetic intensity 定义为 arithmetic operations 与 memory access 的比率。
- compute bound: 算法的瓶颈在于算力,由于算力不足导致运行时间慢,比如矩阵乘法
- memory-bound: 算法的瓶颈在于内存访问效率,比如 element-wise 操作或者是 reduction
为了提高 memory-bound 类型算法的效率,我们进行 kernel fusion, 即把多个访问同一片内存的操作放在一起处理,避免多次读写内存
作者还回顾了一下标准化的 attention 实现。
给定 Q,K,V∈RN×d, 其中 N 是序列长度, d 是 head dimension, attention 的定义如下
S=QKT∈RN×N,P=softmax(S)∈RN×N,O=PV∈RN×d
这里 softmax 是逐行计算的。
算法的执行过程如下
我们有第一个结论
Proposition 1
标准化 attention 前向传播时访问 HBM 的内存访问开销为 O(Nd+N2).
证明:对于 attention, 我们需要从 HBM 中加载 Q,K,V∈RN×d, 然后输出 O∈RN×d 并保存到内存中。
首先我们需要计算 S=QKT, 这一步需要加载 Q,K 并将 S 保存到 HBM 中,内存访问量为 O(Nd+N2).
接下来,我们需要计算 P=softmax(S), 这一步需要加载 S 然后将 P 保存到 HBM 中,内存访问量为 O(N2).
最后,我们需要计算 O=PV, 这一步需要加载 P 和 V 然后将 O 保存到 HBM 中,内存访问量为 O(Nd+N2).
总的来说,标准化 attention 的内存访问开销为 O(Nd+N2).
标准 attention 反向传播过程如下图所示
Proposition 2
标准化 attention 反向传播时访问 HBM 的内存访问开销为 O(Nd+N2).
证明:对于标准化 attention 的反向传播,我们需要从 HBM 中加载 Q,K,V,dO∈RN×d , 然后输出 dQ,dK,dV 并保存到 HBM 中。
首先我们计算 dV=PTdO, 这一步需要加载 P,dO 并将 dV 保存到 HBM 中,内存访问开销为 O(Nd+N2).
接下来我们计算 dP=dOVT, 这一步需要加载 dO,V 并将 dP 保存到 HBM 中,内存访问开销为 O(Nd).
然后我们计算 dS, 这一步需要加载 P 并将 dS 保存到 HBM 中,内存访问开销为 O(N2).
对于 dQ 和 dK 的计算,内存访问开销都是 O(Nd+N2).
因此,标准化 attention 的内存访问开销为 O(Nd+N2).
作者在本节首先介绍了 flashattention 算法,然后作者证明了 flashattention 的正确性以及分析了复杂度。最后作者对 flashattention 进行扩展得到了 Block-sparse Flashattention.
attention 模块的输入是 Q,K,V∈RN×d, 输出是 O∈RN×d, 作者的目标是减少计算过程中的 HBM 访问次数
作者分别使用了 tiling 和 recomputation 来解决 attention 前向传播和反向传播中的内存访问开销。flashattention 的核心思想是,我们将 Q,K,V 分割成 block, 然后在 block 层面进行加载和计算。
首先作者介绍了一下如何使用 tiling 来计算 softmax.
给定一个向量 x∈RB, 其 softmax 计算方式如下
m(x)=imaxxi, f(x)=[ex1−m(x),…,exB−m(x)], ℓ(x)=i∑f(x)i, softmax(x)=ℓ(x)f(x)
如果我们现在有两个向量 x(1),x(2)∈RB, 记 x=[x(1),x(2)]T∈R2B, 我们可以将 softmax(x) 的计算分解为
m(x)ℓ(x)=max(m(x(1)),m(x(2))),f(x)=[em(x(1))−m(x)f(x(1)),em(x(2))−m(x)f(x(2))]=em(x(1))−m(x)ℓ(x(1))+em(x(2))−m(x)ℓ(x(2)),softmax(x)=ℓ(x)f(x)
因此,如果我们额外记录 m(x) 以及 ℓ(x) 这两个量,那么我们可以每次仅计算 softmax 的一个 block. 具体细节见softmax.
在反向传播过程中,一般我们需要存储 S,P∈RN×N, 需要的空间复杂度为 O(N2). 但是,通过存储 O∈RN×d 以及 (m,ℓ), 我们可以避免重新计算 S,P,这可以看做是 gradient checkpointing. 但是与 checkpointing 相比,因为 flashattention 减少了内存访问开销,因此其反向过程并没有变得更慢。
最终,flashattention 的算法如下图所示
算法的正确性由定理 1 给出
Theorem 1
flashattention (即算法 1) 输出 O=softmax(QKT)V, 其时间复杂度为 O(N2d), 空间复杂度为 O(N).
证明:时间复杂度主要由矩阵乘法决定。在计算 Sij=QiKjT 时,所花费的 FLOPS 为 O(BrBcd). 在计算 P~ijVj 时,所花费的 FLOPS 为 O(BrBcd). 循环一共执行了
TcTr=⌈BcN⌉⌈BrN⌉
从而总的 FLOPS 为
O(BrBcN2BrBcd)=O(N2d)
在 flashattention 的计算过程中,我们只需要保存 (ℓ,m) 即可,因此需要的额外内存空间为 O(N).
接下来,我们可以证明 flashattention 的正确性,我们使用归纳法来证明。令 j 满足 0≤j≤Tc, K:j∈RjBc×d, V:j∈RjBc×d 分别为 K 和 V 的前 jBc 行。 S:,:j=QK:jT∈RN×jBc, P:,:j=softmax(S:,:j)∈RN×jBc, m(j),ℓ(j),O(j) 分别为 m,ℓ,O 的第 j 个元素。我们证明经过第 j 次迭代后,HBM 中保存的是
m(j)=rowmax(S:,:j)∈RN,ℓ(j)=rowsum(exp(S:,:j−m(j)))∈RN,O(j)=P:,:jV:j∈RN×d
当 j=0 时,上面的结果显然成立。现在我们假设对某个 j=0,…,Tc−1 上面的结果成立,我们需要证明对 j+1 也成立。
首先
m(j+1)=max(m(j),m~)=max(rowmax(S:,:j),rowmax(S:,j:j+1))=rowmax(S:,:j+1)
接下来
ℓ(j+1)=exp(m(j)−m(j+1))ℓ(j)+exp(m~−m(j+1))ℓ~=exp(m(j)−m(j+1))rowsum(exp(S:,:j−m(j)))+exp(m~−m(j+1))rowsum(exp(S:,j:j+1−m~))=rowsum(exp(S:,:j−m(j+1)))+rowsum(exp(S:,j:j+1−m(j+1)))=rowsum(exp(S:,:j+1−m(j+1)))
最后,我们计算 O(j+1) 得到:
O(j+1)=diag(ℓ(j+1))−1(diag(ℓ(j))exp(m(j)−m(j+1))O(j)+exp(m~−m(j+1))exp(S:,j:j+1−m~)V:,j:j+1)=diag(ℓ(j+1))−1(diag(ℓ(j))exp(m(j)−m(j+1))P:,:jV:,:j+exp(−m(j+1))exp(S:,j:j+1)V:,j:j+1)=diag(ℓ(j+1))−1(diag(ℓ(j))exp(m(j)−m(j+1))diag(ℓ(j))−1exp(S:,:j−m(j))V:,:j+exp(−m(j+1))exp(S:,j:j+1)V:,j:j+1)=diag(ℓ(j+1))−1(exp(−m(j+1))exp(S:,:j))V:,:j+exp(−m(j+1))exp(S:,j:j+1)V:,j:j+1)=diag(ℓ(j+1))−1([exp(S:,:j−m(j+1))exp(S:,:j−m(j+1))][V:,:jV:,j:j+1]=softmax(S:,:j+1)V:,:j+1
因此上面的结果对 j+1 也成立,从而 flashattention 的结果对 j=0,…,Tc 都成立。
第一个问题是如何提高 softmax 计算的效率,作者的做法先先计算 normalization constant 然后再分别计算不同的 column.
给定 Q,K,V∈RN×d, 其中 N 是序列长度, d 是 head dimension, attention 的定义如下
S=QKT∈RN×N,P=softmax(S)∈RN×N,O=PV∈RN×d
我们有 Sij=qiTkj, 这里 qi 和 kj 分别是 Q 和 K 的第 i 列以及第 j 列, normalization constant 定义为:
Li=j=1∑Nexp(qiTkj)
对任意 i, 计算 Li 只需要 O(N) 的空间复杂度。
令 vj 是 V 的第 i 列,则输出 O 的第 i 列 oi 为
oi=Pi:V=j=1∑NPijvj=j=1∑NLiexp(qiTkj)vj
这个过程中,对任意 i, 计算 oi 也只需要 O(N) 的空间复杂度。
因此,在 Li 已经计算好的情况下,我们可以在 O(N) 的空间复杂度下计算 oi.
最终,flashattention 的 forward pass 过程如下图所示
接下来,作者分析了 flashattention 的内存访问开销。结论如下
Theorem 2
令 N 为 sequence length, d 为 head dimension, M 是 SRAM 的 size, 且满足 d≤M≤Nd. 则 flashattention 前向传播的内存访问开销为 Θ(N2d2M−1).
证明:由 Algorithm 1(或者 Algorithm 2)可以知道,K 和 V 的每一个元素都只需要从 HBM 中加载一次,而每一次外层循环都会从 HBM 中加载一次 O 和 Q, 因此总的 HBM 访问次数为 O(Nd+NdTc)=O(NdTc).
接下来,我们给出每一次内层循环的内存访问开销,这是由 SRAM 的大小决定的。由于我们需要 SRAM 可以存储 Kj∈RBc×d 以及 Vj∈RBc×d ,我们的 block size 需要满足
Bcd=O(M)⇒Bc=O(dM)
同理,对于 O 和 Q, 我们有
Brd=O(M)⇒Br=O(dM)
最后,我们还需要 SRAM 可以存储 Sij∈RBr×Bc, 因此
BrBc=O(M)
这样,
Bc=O(dM),Br=O(min(dM,BcM))=O(min(dM,d))
从而
Tc=BcN=O(MNd)
最终,总的内存访问开销为
O(NdTc)=O(MN2d2)
一般来说, d 的大小为 64−128, M 的大小为 100KB 左右, $d^2 \ll M, 因此 flashattention 的内存访问开销远小于标准化 attention 的内存访问开销。
作者还证明 flashattention 的内存访问开销是一个下界,即
Proposition 3
令 N 为 sequence length, d 为 head dimension, M 是 SRAM 的 size, 且满足 d≤M≤Nd. 则不存在一个对任意 M∈[d,Nd] 都可以在 内存访问开销为 Θ(N2d2M−1) 的条件下完成 attention 计算的算法。
证明可以用反证法,基本思想是加载 Q,K,V 的 HBM 访问次数至少为 O(Nd).
第二个问题是能否在线性空间复杂度下计算 attention 的反向传播过程。
首先我们记损失函数为 ϕ, 然后令 ϕ 对 O,Q,K,V 的梯度分别为 dO,dQ,dK,dV∈RN×d, 我们的目标是计算 dQ,dK,dV.
dV 的计算是最容易的,我们有 dV=PTdO, 因此
dvj=i=1∑NPijdoi=i=1∑NLiexp(qiTkj)doi
由于我们已经计算了 Li, 因此,dvj 只需要 O(d) 的空间复杂度。
接下来,注意到 dP=dOVT, 因此我们有
dPij=doiTvj
计算的空间复杂度也是要 O(N) 的
注意到 Pi:=softmax(si:), 且 y=softmax(x) 的 Jacobian 是 diag(y)−yyT (推导过程见 softmax), 我们有
dSi:=(diag(Pi:)−Pi:Pi:T)dPi:=Pi:⊙dPi:−(Pi:TdPi:)Pi:
我们定义
Di=Pi:TdPi:=j=1∑NLiexp(qiTkj)doiTvj=doiTj=1∑NLiexp(qiTkj)vj=doiToi
Di 的空间复杂度也只需要 O(N).
则
dSi:=Pi:⊙dPi:−DiPi:
我们有
dSij=PijdPij−DiPij=Pij(dPij−Di)
注意到 Sij=qiTkj, 我们有
dqi=j=1∑NdSijkj=j=1∑NPij(dPij−Di)kj=j=1∑NLiexp(qiTkj)(doiTvj−Di)kj
因此计算 dqi 的空间复杂度为 O(d).
同样的,
dkj=j=1∑NdSijqi=j=1∑NPij(dPij−Di)qi=j=1∑NLiexp(qiTkj)(doiTvj−Di)qi
其空间复杂度为 O(N).
总之,attention 的反向传播过程所需要的空间复杂度为 O(N).
作者发现有两点可以改进:
- attention mask 不需要存储,我们只需要保存 forward pass 时的输入,然后在 backward pass 时重新生成即可,这样只需要 O(N) 的空间复杂度。
- 计算 softmax 的梯度是,如果使用公式 Di=Pi:TdPi: 来计算的话,由于 Pi:∈RN, 可能会导致超过 SRAM 的内存使用限制,因此,我们可以使用 Di=doiToi 来避免这个问题,其中 oi∈Rd.
最终,flashattention 的 backward pass 过程如下图所示
经过前面的分析,flashattention 的反向传播的时间复杂度为 O(N2), 空间复杂度为 O(N).
Theorem 5
令 N 为 sequence length, d 为 head dimension, M 是 SRAM 的 size, 且满足 d≤M≤Nd. 则 flashattention 反向传播的内存访问开销为 Θ(N2d2M−1).
定理的证明与 Theorem 2 基本一致,我们此处不再赘述。
当 attention 具有 block sparsity 的性质时,作者提出了 blck-sparse flashattention 来进一步提高 attention 的计算效率。
给定 Q,K,V∈RN×d, 以及一个 mask M∈{0,1}N×N, 我们需要计算
S=QKT∈RN×N,P=softmax(S⊙1M)∈RN×N,O=PV∈RN×d
其中当 Mkl=1 时, (S⊙1M)kl=Skl, 否则 (S⊙1M)kl=0.
Block-sparse attention 的算法如下所示
Proposition 4
令 N 为 sequence length, d 为 head dimension, M 是 SRAM 的 size, 且满足 d≤M≤Nd. 则 block-sparse attention 的内存访问开销为 Θ(Nd+N2d2M−1s), 其中 s 是 block-sparse mask 中的非零 block 的比例
证明与 Theorem 2 的证明是类似的,总的内存访问开销为 O(Nd+NdTc), 但是在计算的过程中,由于 mask 矩阵的 block-sparsity, 我们实际上只需要计算一小部分 Mij=0 的情况,因此最终的内存访问开销为
O(Nd+MN2d2s)
可以看到,attention mask 的 sparsity 越高,block-sparse flashattention 的效率也就越高。当 N 非常大时,s 通常为 1/N 或者 N−1logN, 从而最终的内存访问开销为 O(NN) 或者 O(NlogN).
作者对比了以下 block-sparse flashattention 和 flashattention 的效率对比,结果如下图所示
作者通过实验验证了 flashattention 的有效性,如下表所示
可以看到,尽管 flashattention 相比于标准化 attention 需要更多的算力,但是由于其内存访问开销更少,所以最终的运行时间大有了大幅度降低
作者还探究了 block size 对 flashattention 性能对的影响,实验结果如下图所示
可以看到,随着 block size 增加,循环次数降低,内存访问开销也逐渐降低。但是当 block size 充分大 ( >256) 之后,运行时间就会被别的因素所限制,并且过大的 block size 可能会导致 SRAM 的内存溢出
作者首先在 BERT 和 GPT-2 上验证了 flashattention 的表现,BERT 的实验结果如下表所示
GPT-2 的实验结果如下表所示
实验结果显示,flashattention 比 Huggingface 快 3 倍左右,比 Megatron 快 1.7 倍左右
- 训练速度:实验显示,flashattention 在 BERT 上,比 MLPerf 1.1 快 15%, 在 GPT-2 上比 HuggingFace 快 3 倍,比 Megatron 快 1.8 倍
- 准确率:flashattention 是第一个在 Path-X 上比随机表现更好的 transformer 模型;block-sparse flashattention 是第一个在 Path-256 上比随机表现更好的的 sequence model
作者提出了 flashattention, 一个通过优化标准 attention 内存访问效率来提高 attention 计算效率的方法,作者详细介绍了算法设计的原理与证明,并通过实验证明了结果的有效性。