Transformer 的 attention 是一个平方度复杂度的算法,这个平方复杂度既体现在时间复杂度上(矩阵乘法),也体现在空间复杂度上(需要存储中间结果)。因此,要降低 attention 的复杂度,我们有两种思路:
从时间复杂度上入手,比如使用稀疏 attention 机制或者线性注意力机制
从空间复杂度上入手,比如使用 GQA , MQA 等减少内存的访问开销
本文提出的 flash attention (Dao et al., 2022 ) 就属于降低空间复杂度的一种做法。作者认为,我们应该设计一种 IO-aware 的 attention 算法,来减少 attention 计算式的内存访问开销,进而提高 attention 的计算效率。
作者首先提到,一个未解决的问题就是:
降低 attention 的内存访问开销是否可以提高 attention 的计算效率?
作者发现,已有的一些工作虽然在理论上降低了 attention 的计算效率,但是在实际中,他们的效果并没有提升太多。作者分析原因认为,已有工作主要关注于降低 FLOPs, 但是忽略了内存访问开销。
因此,作者在本文中就提出了 flashattention, 一个 IO-aware 的 attention 算法,作者通过尽可能降低内存访问开销来提高模型的计算效率。具体做法就是,避免从内存中读写 attention matrix, 作者认为这个目标有两个挑战:
计算 softmax 的时候不访问所有的输入
在反向传播时不存储中间的 attention matrix
作者提出了两个方法来分别解决这两个问题:
作者使用了 tiling 技巧,将 input 分成多个 block, 然后分别进行处理,进而降低 softmax 的内存访问开销
作者使用了 recompute 技巧,在反向传播时,重新计算 softmax normalization factor
通过这些改进,我们可以让 attention 运行更快,并且降低内存访问开销。
作者还从理论上分析了 flashattention 的复杂度,提供了理论基础。
作者通过实验验证了 flashattention 的有效性,主要是三点:
训练效率更高:相比于 Huggingface 和 Megatron, flashattention 的训练效率提升了 2-3 倍
模型的表现更好:相比于 GPT-2, 模型的 perplexity 提升了 0.7 个点左右
速度更快:flashattention 比标准的 attention 实现快 3 倍以上
作者首先介绍了以下 GPU 的内存架构,如下图所示
可以看到,GPU 内存可以分为三个层级:
SRAM: GPU 的寄存器,容量小,但是访问速度极快
High bandwith memory (HBM): GPU 的高速内存,访问速度较快,容量中等
DRAM: CPU 内存,容量最大,但是访问速度较慢
接下来作者介绍了 Execution model 的概念,GPU 有多个线程来执行同一个操作(SPMD),这个操作也被称为 kernel, kernel 会从 HBM 中加载输入到 SRAM 中进行计算,然后写回 HBM.
对一个算法,我们可以将其归类为 compute-bound 和 memory-bound 两类, 我们可以用 arithmetic intensity 来进行区分,arithmetic intensity 定义为 arithmetic operations 与 memory access 的比率。
compute bound: 算法的瓶颈在于算力,由于算力不足导致运行时间慢,比如矩阵乘法
memory-bound: 算法的瓶颈在于内存访问效率,比如 element-wise 操作或者是 reduction
为了提高 memory-bound 类型算法的效率,我们进行 kernel fusion, 即把多个访问同一片内存的操作放在一起处理,避免多次读写内存
作者还回顾了一下标准化的 attention 实现。
给定 Q , K , V ∈ R N × d Q,K,V\in\mathbb{R}^{N\times d} Q , K , V ∈ R N × d , 其中 N N N 是序列长度, d d d 是 head dimension, attention 的定义如下
S = Q K T ∈ R N × N , P = s o f t m a x ( S ) ∈ R N × N , O = P V ∈ R N × d S = QK^T\in\mathbb{R}^{N\times N},\quad P = \mathrm{softmax}(S)\in\mathbb{R}^{N\times N},\quad O = PV\in\mathbb{R}^{N\times d} S = Q K T ∈ R N × N , P = softmax ( S ) ∈ R N × N , O = P V ∈ R N × d
这里 s o f t m a x \mathrm{softmax} softmax 是逐行计算的。
算法的执行过程如下
我们有第一个结论
Proposition 1
标准化 attention 前向传播时访问 HBM 的内存访问开销为 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
证明:对于 attention, 我们需要从 HBM 中加载 Q , K , V ∈ R N × d Q,K,V\in\mathbb{R}^{N\times d} Q , K , V ∈ R N × d , 然后输出 O ∈ R N × d O\in\mathbb{R}^{N\times d} O ∈ R N × d 并保存到内存中。
首先我们需要计算 S = Q K T S = QK^T S = Q K T , 这一步需要加载 Q , K Q,K Q , K 并将 S S S 保存到 HBM 中,内存访问量为 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
接下来,我们需要计算 P = s o f t m a x ( S ) P = \mathrm{softmax}(S) P = softmax ( S ) , 这一步需要加载 S S S 然后将 P P P 保存到 HBM 中,内存访问量为 O ( N 2 ) \mathcal{O}(N^2) O ( N 2 ) .
最后,我们需要计算 O = P V O = PV O = P V , 这一步需要加载 P P P 和 V V V 然后将 O O O 保存到 HBM 中,内存访问量为 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
总的来说,标准化 attention 的内存访问开销为 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
标准 attention 反向传播过程如下图所示
Proposition 2
标准化 attention 反向传播时访问 HBM 的内存访问开销为 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
证明:对于标准化 attention 的反向传播,我们需要从 HBM 中加载 Q , K , V , d O ∈ R N × d Q,K,V,dO\in\mathbb{R}^{N\times d} Q , K , V , d O ∈ R N × d , 然后输出 d Q , d K , d V dQ,dK,dV d Q , d K , d V 并保存到 HBM 中。
首先我们计算 d V = P T d O dV=P^TdO d V = P T d O , 这一步需要加载 P , d O P,dO P , d O 并将 d V dV d V 保存到 HBM 中,内存访问开销为 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
接下来我们计算 d P = d O V T dP=dOV^T d P = d O V T , 这一步需要加载 d O , V dO, V d O , V 并将 d P dP d P 保存到 HBM 中,内存访问开销为 O ( N d ) \mathcal{O}(Nd) O ( N d ) .
然后我们计算 d S dS d S , 这一步需要加载 P P P 并将 d S dS d S 保存到 HBM 中,内存访问开销为 O ( N 2 ) \mathcal{O}(N^2) O ( N 2 ) .
对于 d Q dQ d Q 和 d K dK d K 的计算,内存访问开销都是 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
因此,标准化 attention 的内存访问开销为 O ( N d + N 2 ) \mathcal{O}(Nd+N^2) O ( N d + N 2 ) .
作者在本节首先介绍了 flashattention 算法,然后作者证明了 flashattention 的正确性以及分析了复杂度。最后作者对 flashattention 进行扩展得到了 Block-sparse Flashattention.
attention 模块的输入是 Q , K , V ∈ R N × d Q,K,V\in\mathbb{R}^{N\times d} Q , K , V ∈ R N × d , 输出是 O ∈ R N × d O\in\mathbb{R}^{N\times d} O ∈ R N × d , 作者的目标是减少计算过程中的 HBM 访问次数
作者分别使用了 tiling 和 recomputation 来解决 attention 前向传播和反向传播中的内存访问开销。flashattention 的核心思想是,我们将 Q , K , V Q,K,V Q , K , V 分割成 block, 然后在 block 层面进行加载和计算。
首先作者介绍了一下如何使用 tiling 来计算 softmax.
给定一个向量 x ∈ R B x\in\mathbb{R}^{B} x ∈ R B , 其 softmax 计算方式如下
m ( x ) = max i x i , f ( x ) = [ e x 1 − m ( x ) , … , e x B − m ( x ) ] , ℓ ( x ) = ∑ i f ( x ) i , s o f t m a x ( x ) = f ( x ) ℓ ( x ) m(x) = \max_i x_i,\ f(x) = [e^{x_1-m(x)},\dots,e^{x_B-m(x)}], \ \ell(x)=\sum_if(x)_i, \ \mathrm{softmax}(x) = \frac{f(x)}{\ell(x)} m ( x ) = i max x i , f ( x ) = [ e x 1 − m ( x ) , … , e x B − m ( x ) ] , ℓ ( x ) = i ∑ f ( x ) i , softmax ( x ) = ℓ ( x ) f ( x )
如果我们现在有两个向量 x ( 1 ) , x ( 2 ) ∈ R B x^{(1)}, x^{(2)}\in\mathbb{R}^{B} x ( 1 ) , x ( 2 ) ∈ R B , 记 x = [ x ( 1 ) , x ( 2 ) ] T ∈ R 2 B x=[x^{(1)}, x^{(2)}]^T\in\mathbb{R}^{2B} x = [ x ( 1 ) , x ( 2 ) ] T ∈ R 2 B , 我们可以将 s o f t m a x ( x ) \mathrm{softmax}(x) softmax ( x ) 的计算分解为
m ( x ) = max ( m ( x ( 1 ) ) , m ( x ( 2 ) ) ) , f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) , e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] ℓ ( x ) = e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) , s o f t m a x ( x ) = f ( x ) ℓ ( x ) \begin{aligned}
m(x) &= \max(m(x^{(1)}), m(x^{(2)})), f(x) = [e^{m(x^{(1)})-m(x)}f(x^{(1)}),e^{m(x^{(2)})-m(x)}f(x^{(2)})]\\
\ell(x) &= e^{m(x^{(1)})-m(x)}\ell(x^{(1)}) + e^{m(x^{(2)})-m(x)}\ell(x^{(2)}), \mathrm{softmax}(x) = \frac{f(x)}{\ell(x)}
\end{aligned} m ( x ) ℓ ( x ) = max ( m ( x ( 1 ) ) , m ( x ( 2 ) )) , f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) , e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) )] = e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) , softmax ( x ) = ℓ ( x ) f ( x )
因此,如果我们额外记录 m ( x ) m(x) m ( x ) 以及 ℓ ( x ) \ell(x) ℓ ( x ) 这两个量,那么我们可以每次仅计算 softmax 的一个 block. 具体细节见softmax .
在反向传播过程中,一般我们需要存储 S , P ∈ R N × N S,P\in\mathbb{R}^{N\times N} S , P ∈ R N × N , 需要的空间复杂度为 O ( N 2 ) \mathcal{O}(N^2) O ( N 2 ) . 但是,通过存储 O ∈ R N × d O\in\mathbb{R}^{N\times d} O ∈ R N × d 以及 ( m , ℓ ) (m,\ell) ( m , ℓ ) , 我们可以避免重新计算 S , P S,P S , P ,这可以看做是 gradient checkpointing. 但是与 checkpointing 相比,因为 flashattention 减少了内存访问开销,因此其反向过程并没有变得更慢。
最终,flashattention 的算法如下图所示
算法的正确性由定理 1 给出
Theorem 1
flashattention (即算法 1) 输出 O = s o f t m a x ( Q K T ) V O=\mathrm{softmax}(QK^T)V O = softmax ( Q K T ) V , 其时间复杂度为 O ( N 2 d ) \mathcal{O}(N^2d) O ( N 2 d ) , 空间复杂度为 O ( N ) \mathcal{O}(N) O ( N ) .
证明:时间复杂度主要由矩阵乘法决定。在计算 S i j = Q i K j T S_{ij}=Q_iK_j^T S ij = Q i K j T 时,所花费的 FLOPS 为 O ( B r B c d ) \mathcal{O}(B_rB_cd) O ( B r B c d ) . 在计算 P ~ i j V j \tilde{P}_{ij}V_j P ~ ij V j 时,所花费的 FLOPS 为 O ( B r B c d ) \mathcal{O}(B_rB_cd) O ( B r B c d ) . 循环一共执行了
T c T r = ⌈ N B c ⌉ ⌈ N B r ⌉ T_cT_r = \left\lceil\frac{N}{B_c}\right\rceil\left\lceil\frac{N}{B_r}\right\rceil T c T r = ⌈ B c N ⌉ ⌈ B r N ⌉
从而总的 FLOPS 为
O ( N 2 B r B c B r B c d ) = O ( N 2 d ) \mathcal{O}\left(\frac{N^2}{B_rB_c}B_rB_cd\right) = \mathcal{O}(N^2d) O ( B r B c N 2 B r B c d ) = O ( N 2 d )
在 flashattention 的计算过程中,我们只需要保存 ( ℓ , m ) (\ell, m) ( ℓ , m ) 即可,因此需要的额外内存空间为 O ( N ) \mathcal{O}(N) O ( N ) .
接下来,我们可以证明 flashattention 的正确性,我们使用归纳法来证明。令 j j j 满足 0 ≤ j ≤ T c 0\leq j\leq T_c 0 ≤ j ≤ T c , K : j ∈ R j B c × d K_{:j}\in\mathbb{R}^{jB_c\times d} K : j ∈ R j B c × d , V : j ∈ R j B c × d V_{:j}\in\mathbb{R}^{jB_c\times d} V : j ∈ R j B c × d 分别为 K K K 和 V V V 的前 j B c jB_c j B c 行。 S : , : j = Q K : j T ∈ R N × j B c S_{:, :j}=QK_{:j}^T\in\mathbb{R}^{N\times jB_c} S : , : j = Q K : j T ∈ R N × j B c , P : , : j = s o f t m a x ( S : , : j ) ∈ R N × j B c P_{:,:j}=\mathrm{softmax}(S_{:,:j})\in\mathbb{R}^{N\times jB_c} P : , : j = softmax ( S : , : j ) ∈ R N × j B c , m ( j ) , ℓ ( j ) , O ( j ) m^{(j)}, \ell^{(j)}, O^{(j)} m ( j ) , ℓ ( j ) , O ( j ) 分别为 m , ℓ , O m,\ell, O m , ℓ , O 的第 j j j 个元素。我们证明经过第 j j j 次迭代后,HBM 中保存的是
m ( j ) = r o w m a x ( S : , : j ) ∈ R N , ℓ ( j ) = r o w s u m ( exp ( S : , : j − m ( j ) ) ) ∈ R N , O ( j ) = P : , : j V : j ∈ R N × d m^{(j)}=\mathrm{rowmax}(S_{:,:j})\in\mathbb{R}^N, \ell^{(j)}=\mathrm{rowsum}(\exp(S_{:,:j}-m^{(j)}))\in\mathbb{R}^N, O^{(j)} = P_{:,:j}V_{:j}\in\mathbb{R}^{N\times d} m ( j ) = rowmax ( S : , : j ) ∈ R N , ℓ ( j ) = rowsum ( exp ( S : , : j − m ( j ) )) ∈ R N , O ( j ) = P : , : j V : j ∈ R N × d
当 j = 0 j=0 j = 0 时,上面的结果显然成立。现在我们假设对某个 j = 0 , … , T c − 1 j=0,\dots, T_c-1 j = 0 , … , T c − 1 上面的结果成立,我们需要证明对 j + 1 j+1 j + 1 也成立。
首先
m ( j + 1 ) = max ( m ( j ) , m ~ ) = max ( r o w m a x ( S : , : j ) , r o w m a x ( S : , j : j + 1 ) ) = r o w m a x ( S : , : j + 1 ) m^{(j+1)}=\max(m^{(j)}, \tilde{m}) = \max(\mathrm{rowmax}(S_{:,:j}), \mathrm{rowmax}(S_{:,j:j+1}))=\mathrm{rowmax}(S_{:,:j+1}) m ( j + 1 ) = max ( m ( j ) , m ~ ) = max ( rowmax ( S : , : j ) , rowmax ( S : , j : j + 1 )) = rowmax ( S : , : j + 1 )
接下来
ℓ ( j + 1 ) = exp ( m ( j ) − m ( j + 1 ) ) ℓ ( j ) + exp ( m ~ − m ( j + 1 ) ) ℓ ~ = exp ( m ( j ) − m ( j + 1 ) ) r o w s u m ( exp ( S : , : j − m ( j ) ) ) + exp ( m ~ − m ( j + 1 ) ) r o w s u m ( exp ( S : , j : j + 1 − m ~ ) ) = r o w s u m ( exp ( S : , : j − m ( j + 1 ) ) ) + r o w s u m ( exp ( S : , j : j + 1 − m ( j + 1 ) ) ) = r o w s u m ( exp ( S : , : j + 1 − m ( j + 1 ) ) ) \begin{aligned}
\ell^{(j+1)} &= \exp(m^{(j)}-m^{(j+1)})\ell^{(j)} + \exp(\tilde{m}-m^{(j+1)})\tilde{\ell}\\
&=\exp(m^{(j)}-m^{(j+1)})\mathrm{rowsum}(\exp(S_{:,:j}-m^{(j)})) + \exp(\tilde{m}-m^{(j+1)})\mathrm{rowsum}(\exp(S_{:,j:j+1}-\tilde{m}))\\
&= \mathrm{rowsum}(\exp(S_{:,:j}-m^{(j+1)})) + \mathrm{rowsum}(\exp(S_{:,j:j+1}-m^{(j+1)}))\\
&= \mathrm{rowsum}(\exp(S_{:,:j+1}-m^{(j+1)}))
\end{aligned} ℓ ( j + 1 ) = exp ( m ( j ) − m ( j + 1 ) ) ℓ ( j ) + exp ( m ~ − m ( j + 1 ) ) ℓ ~ = exp ( m ( j ) − m ( j + 1 ) ) rowsum ( exp ( S : , : j − m ( j ) )) + exp ( m ~ − m ( j + 1 ) ) rowsum ( exp ( S : , j : j + 1 − m ~ )) = rowsum ( exp ( S : , : j − m ( j + 1 ) )) + rowsum ( exp ( S : , j : j + 1 − m ( j + 1 ) )) = rowsum ( exp ( S : , : j + 1 − m ( j + 1 ) ))
最后,我们计算 O ( j + 1 ) O^{(j+1)} O ( j + 1 ) 得到:
O ( j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( d i a g ( ℓ ( j ) ) exp ( m ( j ) − m ( j + 1 ) ) O ( j ) + exp ( m ~ − m ( j + 1 ) ) exp ( S : , j : j + 1 − m ~ ) V : , j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( d i a g ( ℓ ( j ) ) exp ( m ( j ) − m ( j + 1 ) ) P : , : j V : , : j + exp ( − m ( j + 1 ) ) exp ( S : , j : j + 1 ) V : , j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( d i a g ( ℓ ( j ) ) exp ( m ( j ) − m ( j + 1 ) ) d i a g ( ℓ ( j ) ) − 1 exp ( S : , : j − m ( j ) ) V : , : j + exp ( − m ( j + 1 ) ) exp ( S : , j : j + 1 ) V : , j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( exp ( − m ( j + 1 ) ) exp ( S : , : j ) ) V : , : j + exp ( − m ( j + 1 ) ) exp ( S : , j : j + 1 ) V : , j : j + 1 ) = d i a g ( ℓ ( j + 1 ) ) − 1 ( [ exp ( S : , : j − m ( j + 1 ) ) exp ( S : , : j − m ( j + 1 ) ) ] [ V : , : j V : , j : j + 1 ] = s o f t m a x ( S : , : j + 1 ) V : , : j + 1 \begin{aligned}
O^{(j+1)} &= \mathrm{diag}(\ell^{(j+1)})^{-1}(\mathrm{diag}(\ell^{(j)})\exp(m^{(j)}-m^{(j+1)})O^{(j)}+\exp(\tilde{m}-m^{(j+1)})\exp(S_{:,j:j+1}-\tilde{m})V_{:,j:j+1})\\
&= \mathrm{diag}(\ell^{(j+1)})^{-1}(\mathrm{diag}(\ell^{(j)})\exp(m^{(j)}-m^{(j+1)})P_{:,:j}V_{:,:j}+\exp(-m^{(j+1)})\exp(S_{:,j:j+1})V_{:,j:j+1})\\
&= \mathrm{diag}(\ell^{(j+1)})^{-1}(\mathrm{diag}(\ell^{(j)})\exp(m^{(j)}-m^{(j+1)})\mathrm{diag}(\ell^{(j)})^{-1}\exp(S_{:,:j}-m^{(j)})V_{:,:j}+\exp(-m^{(j+1)})\exp(S_{:,j:j+1})V_{:,j:j+1})\\
&= \mathrm{diag}(\ell^{(j+1)})^{-1}(\exp(-m^{(j+1)})\exp(S_{:,:j}))V_{:,:j}+\exp(-m^{(j+1)})\exp(S_{:,j:j+1})V_{:,j:j+1})\\
&= \mathrm{diag}(\ell^{(j+1)})^{-1}(
\begin{bmatrix}
\exp(S_{:,:j}-m^{(j+1)}) & \exp(S_{:,:j}-m^{(j+1)})
\end{bmatrix}\begin{bmatrix}
V_{:,:j} \\
V_{:,j:j+1}
\end{bmatrix}\\
&= \mathrm{softmax}(S_{:,:j+1})V_{:,:j+1}
\end{aligned} O ( j + 1 ) = diag ( ℓ ( j + 1 ) ) − 1 ( diag ( ℓ ( j ) ) exp ( m ( j ) − m ( j + 1 ) ) O ( j ) + exp ( m ~ − m ( j + 1 ) ) exp ( S : , j : j + 1 − m ~ ) V : , j : j + 1 ) = diag ( ℓ ( j + 1 ) ) − 1 ( diag ( ℓ ( j ) ) exp ( m ( j ) − m ( j + 1 ) ) P : , : j V : , : j + exp ( − m ( j + 1 ) ) exp ( S : , j : j + 1 ) V : , j : j + 1 ) = diag ( ℓ ( j + 1 ) ) − 1 ( diag ( ℓ ( j ) ) exp ( m ( j ) − m ( j + 1 ) ) diag ( ℓ ( j ) ) − 1 exp ( S : , : j − m ( j ) ) V : , : j + exp ( − m ( j + 1 ) ) exp ( S : , j : j + 1 ) V : , j : j + 1 ) = diag ( ℓ ( j + 1 ) ) − 1 ( exp ( − m ( j + 1 ) ) exp ( S : , : j )) V : , : j + exp ( − m ( j + 1 ) ) exp ( S : , j : j + 1 ) V : , j : j + 1 ) = diag ( ℓ ( j + 1 ) ) − 1 ( [ exp ( S : , : j − m ( j + 1 ) ) exp ( S : , : j − m ( j + 1 ) ) ] [ V : , : j V : , j : j + 1 ] = softmax ( S : , : j + 1 ) V : , : j + 1
因此上面的结果对 j + 1 j+1 j + 1 也成立,从而 flashattention 的结果对 j = 0 , … , T c j=0,\dots,T_c j = 0 , … , T c 都成立。
第一个问题是如何提高 softmax 计算的效率,作者的做法先先计算 normalization constant 然后再分别计算不同的 column.
给定 Q , K , V ∈ R N × d Q,K,V\in\mathbb{R}^{N\times d} Q , K , V ∈ R N × d , 其中 N N N 是序列长度, d d d 是 head dimension, attention 的定义如下
S = Q K T ∈ R N × N , P = s o f t m a x ( S ) ∈ R N × N , O = P V ∈ R N × d S = QK^T\in\mathbb{R}^{N\times N},\quad P = \mathrm{softmax}(S)\in\mathbb{R}^{N\times N},\quad O = PV\in\mathbb{R}^{N\times d} S = Q K T ∈ R N × N , P = softmax ( S ) ∈ R N × N , O = P V ∈ R N × d
我们有 S i j = q i T k j S_{ij}=q_i^Tk_j S ij = q i T k j , 这里 q i q_i q i 和 k j k_j k j 分别是 Q Q Q 和 K K K 的第 i i i 列以及第 j j j 列, normalization constant 定义为:
L i = ∑ j = 1 N exp ( q i T k j ) L_i = \sum_{j=1}^N \exp\left(q_i^Tk_j\right) L i = j = 1 ∑ N exp ( q i T k j )
对任意 i i i , 计算 L i L_i L i 只需要 O ( N ) \mathcal{O}(N) O ( N ) 的空间复杂度。
令 v j v_j v j 是 V V V 的第 i i i 列,则输出 O O O 的第 i i i 列 o i o_i o i 为
o i = P i : V = ∑ j = 1 N P i j v j = ∑ j = 1 N exp ( q i T k j ) L i v j o_i = P_{i:}V = \sum_{j=1}^N P_{ij}v_j = \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}v_j o i = P i : V = j = 1 ∑ N P ij v j = j = 1 ∑ N L i exp ( q i T k j ) v j
这个过程中,对任意 i i i , 计算 o i o_i o i 也只需要 O ( N ) \mathcal{O}(N) O ( N ) 的空间复杂度。
因此,在 L i L_i L i 已经计算好的情况下,我们可以在 O ( N ) \mathcal{O}(N) O ( N ) 的空间复杂度下计算 o i o_i o i .
最终,flashattention 的 forward pass 过程如下图所示
接下来,作者分析了 flashattention 的内存访问开销。结论如下
Theorem 2
令 N N N 为 sequence length, d d d 为 head dimension, M M M 是 SRAM 的 size, 且满足 d ≤ M ≤ N d d\leq M\leq Nd d ≤ M ≤ N d . 则 flashattention 前向传播的内存访问开销为 Θ ( N 2 d 2 M − 1 ) \Theta(N^2d^2M^{-1}) Θ ( N 2 d 2 M − 1 ) .
证明:由 Algorithm 1(或者 Algorithm 2)可以知道,K K K 和 V V V 的每一个元素都只需要从 HBM 中加载一次,而每一次外层循环都会从 HBM 中加载一次 O O O 和 Q Q Q , 因此总的 HBM 访问次数为 O ( N d + N d T c ) = O ( N d T c ) \mathcal{O}(Nd+NdT_c)=\mathcal{O}(NdT_c) O ( N d + N d T c ) = O ( N d T c ) .
接下来,我们给出每一次内层循环的内存访问开销,这是由 SRAM 的大小决定的。由于我们需要 SRAM 可以存储 K j ∈ R B c × d K_j\in\mathbb{R}^{B_c\times d} K j ∈ R B c × d 以及 V j ∈ R B c × d V_j\in\mathbb{R}^{B_c\times d} V j ∈ R B c × d ,我们的 block size 需要满足
B c d = O ( M ) ⇒ B c = O ( M d ) B_cd = \mathcal{O}(M) \Rightarrow B_c = \mathcal{O}\left(\frac{M}{d}\right) B c d = O ( M ) ⇒ B c = O ( d M )
同理,对于 O O O 和 Q Q Q , 我们有
B r d = O ( M ) ⇒ B r = O ( M d ) B_rd = \mathcal{O}(M) \Rightarrow B_r = \mathcal{O}\left(\frac{M}{d}\right) B r d = O ( M ) ⇒ B r = O ( d M )
最后,我们还需要 SRAM 可以存储 S i j ∈ R B r × B c S_{ij}\in\mathbb{R}^{B_r\times B_c} S ij ∈ R B r × B c , 因此
B r B c = O ( M ) B_rB_c=\mathcal{O}(M) B r B c = O ( M )
这样,
B c = O ( M d ) , B r = O ( min ( M d , M B c ) ) = O ( min ( M d , d ) ) B_c = \mathcal{O}\left(\frac{M}{d}\right), B_r=\mathcal{O}\left(\min\left(\frac{M}{d},\frac{M}{B_c}\right)\right)=\mathcal{O}\left(\min\left(\frac{M}{d},d\right)\right) B c = O ( d M ) , B r = O ( min ( d M , B c M ) ) = O ( min ( d M , d ) )
从而
T c = N B c = O ( N d M ) T_c = \frac{N}{B_c} = \mathcal{O}\left(\frac{Nd}{M}\right) T c = B c N = O ( M N d )
最终,总的内存访问开销为
O ( N d T c ) = O ( N 2 d 2 M ) \mathcal{O}(NdT_c) = \mathcal{O}\left(\frac{N^2d^2}{M}\right) O ( N d T c ) = O ( M N 2 d 2 )
一般来说, d d d 的大小为 64 − 128 64-128 64 − 128 , M M M 的大小为 100 K B 100 KB 100 K B 左右, $d^2 \ll M, 因此 flashattention 的内存访问开销远小于标准化 attention 的内存访问开销。
作者还证明 flashattention 的内存访问开销是一个下界,即
Proposition 3
令 N N N 为 sequence length, d d d 为 head dimension, M M M 是 SRAM 的 size, 且满足 d ≤ M ≤ N d d\leq M\leq Nd d ≤ M ≤ N d . 则不存在一个对任意 M ∈ [ d , N d ] M\in[d,Nd] M ∈ [ d , N d ] 都可以在 内存访问开销为 Θ ( N 2 d 2 M − 1 ) \Theta(N^2d^2M^{-1}) Θ ( N 2 d 2 M − 1 ) 的条件下完成 attention 计算的算法。
证明可以用反证法,基本思想是加载 Q , K , V Q,K,V Q , K , V 的 HBM 访问次数至少为 O ( N d ) \mathcal{O}(Nd) O ( N d ) .
第二个问题是能否在线性空间复杂度下计算 attention 的反向传播过程。
首先我们记损失函数为 ϕ \phi ϕ , 然后令 ϕ \phi ϕ 对 O , Q , K , V O,Q,K,V O , Q , K , V 的梯度分别为 d O , d Q , d K , d V ∈ R N × d dO,dQ,dK, dV\in\mathbb{R}^{N\times d} d O , d Q , d K , d V ∈ R N × d , 我们的目标是计算 d Q , d K , d V dQ, dK, dV d Q , d K , d V .
d V dV d V 的计算是最容易的,我们有 d V = P T d O dV=P^TdO d V = P T d O , 因此
d v j = ∑ i = 1 N P i j d o i = ∑ i = 1 N exp ( q i T k j ) L i d o i dv_j = \sum_{i=1}^N P_{ij}do_i = \sum_{i=1}^N\frac{\exp(q_i^Tk_j)}{L_i}do_i d v j = i = 1 ∑ N P ij d o i = i = 1 ∑ N L i exp ( q i T k j ) d o i
由于我们已经计算了 L i L_i L i , 因此,d v j dv_j d v j 只需要 O ( d ) \mathcal{O}(d) O ( d ) 的空间复杂度。
接下来,注意到 d P = d O V T dP=dOV^T d P = d O V T , 因此我们有
d P i j = d o i T v j dP_{ij} = do_i^Tv_j d P ij = d o i T v j
计算的空间复杂度也是要 O ( N ) \mathcal{O}(N) O ( N ) 的
注意到 P i : = s o f t m a x ( s i : ) P_{i:}=\mathrm{softmax}(s_{i:}) P i : = softmax ( s i : ) , 且 y = s o f t m a x ( x ) y=\mathrm{softmax}(x) y = softmax ( x ) 的 Jacobian 是 d i a g ( y ) − y y T \mathrm{diag}(y)-yy^T diag ( y ) − y y T (推导过程见 softmax ), 我们有
d S i : = ( d i a g ( P i : ) − P i : P i : T ) d P i : = P i : ⊙ d P i : − ( P i : T d P i : ) P i : dS_{i:} = (\mathrm{diag}(P_{i:})-P_{i:}P_{i:}^T)dP_{i:} = P_{i:} \odot dP_{i:} - (P_{i:}^TdP_{i:})P_{i:} d S i : = ( diag ( P i : ) − P i : P i : T ) d P i : = P i : ⊙ d P i : − ( P i : T d P i : ) P i :
我们定义
D i = P i : T d P i : = ∑ j = 1 N exp ( q i T k j ) L i d o i T v j = d o i T ∑ j = 1 N exp ( q i T k j ) L i v j = d o i T o i D_i = P_{i:}^TdP_{i:}= \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}do_i^Tv_j = do_i^T\sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}v_j = do_i^To_i D i = P i : T d P i : = j = 1 ∑ N L i exp ( q i T k j ) d o i T v j = d o i T j = 1 ∑ N L i exp ( q i T k j ) v j = d o i T o i
D i D_i D i 的空间复杂度也只需要 O ( N ) \mathcal{O}(N) O ( N ) .
则
d S i : = P i : ⊙ d P i : − D i P i : dS_{i:} =P_{i:} \odot dP_{i:} - D_iP_{i:} d S i : = P i : ⊙ d P i : − D i P i :
我们有
d S i j = P i j d P i j − D i P i j = P i j ( d P i j − D i ) dS_{ij} = P_{ij}dP_{ij} - D_iP_{ij} = P_{ij}(dP_{ij}-D_i) d S ij = P ij d P ij − D i P ij = P ij ( d P ij − D i )
注意到 S i j = q i T k j S_{ij}=q_i^Tk_j S ij = q i T k j , 我们有
d q i = ∑ j = 1 N d S i j k j = ∑ j = 1 N P i j ( d P i j − D i ) k j = ∑ j = 1 N exp ( q i T k j ) L i ( d o i T v j − D i ) k j dq_i = \sum_{j=1}^N dS_{ij}k_j = \sum_{j=1}^NP_{ij}(dP_{ij}-D_i)k_j = \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}(do_i^Tv_j-D_i)k_j d q i = j = 1 ∑ N d S ij k j = j = 1 ∑ N P ij ( d P ij − D i ) k j = j = 1 ∑ N L i exp ( q i T k j ) ( d o i T v j − D i ) k j
因此计算 d q i dq_i d q i 的空间复杂度为 O ( d ) \mathcal{O}(d) O ( d ) .
同样的,
d k j = ∑ j = 1 N d S i j q i = ∑ j = 1 N P i j ( d P i j − D i ) q i = ∑ j = 1 N exp ( q i T k j ) L i ( d o i T v j − D i ) q i dk_j = \sum_{j=1}^N dS_{ij}q_i = \sum_{j=1}^NP_{ij}(dP_{ij}-D_i)q_i = \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}(do_i^Tv_j-D_i)q_i d k j = j = 1 ∑ N d S ij q i = j = 1 ∑ N P ij ( d P ij − D i ) q i = j = 1 ∑ N L i exp ( q i T k j ) ( d o i T v j − D i ) q i
其空间复杂度为 O ( N ) \mathcal{O}(N) O ( N ) .
总之,attention 的反向传播过程所需要的空间复杂度为 O ( N ) \mathcal{O}(N) O ( N ) .
作者发现有两点可以改进:
attention mask 不需要存储,我们只需要保存 forward pass 时的输入,然后在 backward pass 时重新生成即可,这样只需要 O ( N ) \mathcal{O}(N) O ( N ) 的空间复杂度。
计算 softmax 的梯度是,如果使用公式 D i = P i : T d P i : D_i=P_{i:}^TdP_{i:} D i = P i : T d P i : 来计算的话,由于 P i : ∈ R N P_{i:}\in\mathbb{R}^N P i : ∈ R N , 可能会导致超过 SRAM 的内存使用限制,因此,我们可以使用 D i = d o i T o i D_i=do_i^To_i D i = d o i T o i 来避免这个问题,其中 o i ∈ R d o_i\in\mathbb{R}^d o i ∈ R d .
最终,flashattention 的 backward pass 过程如下图所示
经过前面的分析,flashattention 的反向传播的时间复杂度为 O ( N 2 ) \mathcal{O}(N^2) O ( N 2 ) , 空间复杂度为 O ( N ) \mathcal{O}(N) O ( N ) .
Theorem 5
令 N N N 为 sequence length, d d d 为 head dimension, M M M 是 SRAM 的 size, 且满足 d ≤ M ≤ N d d\leq M\leq Nd d ≤ M ≤ N d . 则 flashattention 反向传播的内存访问开销为 Θ ( N 2 d 2 M − 1 ) \Theta(N^2d^2M^{-1}) Θ ( N 2 d 2 M − 1 ) .
定理的证明与 Theorem 2 基本一致,我们此处不再赘述。
当 attention 具有 block sparsity 的性质时,作者提出了 blck-sparse flashattention 来进一步提高 attention 的计算效率。
给定 Q , K , V ∈ R N × d Q,K,V\in\mathbb{R}^{N\times d} Q , K , V ∈ R N × d , 以及一个 mask M ∈ { 0 , 1 } N × N M\in\{0,1\}^{N\times N} M ∈ { 0 , 1 } N × N , 我们需要计算
S = Q K T ∈ R N × N , P = s o f t m a x ( S ⊙ 1 M ) ∈ R N × N , O = P V ∈ R N × d S = QK^T\in\mathbb{R}^{N\times N},\quad P = \mathrm{softmax}(S\odot \mathbb{1}_{M})\in\mathbb{R}^{N\times N},\quad O = PV\in\mathbb{R}^{N\times d} S = Q K T ∈ R N × N , P = softmax ( S ⊙ 1 M ) ∈ R N × N , O = P V ∈ R N × d
其中当 M k l = 1 M_{kl}=1 M k l = 1 时, ( S ⊙ 1 M ) k l = S k l (S\odot \mathbb{1}_ {M})_ {kl}=S_ {kl} ( S ⊙ 1 M ) k l = S k l , 否则 ( S ⊙ 1 M ) k l = 0 (S\odot \mathbb{1}_ {M})_{kl}=0 ( S ⊙ 1 M ) k l = 0 .
Block-sparse attention 的算法如下所示
Proposition 4
令 N N N 为 sequence length, d d d 为 head dimension, M M M 是 SRAM 的 size, 且满足 d ≤ M ≤ N d d\leq M\leq Nd d ≤ M ≤ N d . 则 block-sparse attention 的内存访问开销为 Θ ( N d + N 2 d 2 M − 1 s ) \Theta(Nd+N^2d^2M^{-1}s) Θ ( N d + N 2 d 2 M − 1 s ) , 其中 s s s 是 block-sparse mask 中的非零 block 的比例
证明与 Theorem 2 的证明是类似的,总的内存访问开销为 O ( N d + N d T c ) \mathcal{O}(Nd+NdT_c) O ( N d + N d T c ) , 但是在计算的过程中,由于 mask 矩阵的 block-sparsity, 我们实际上只需要计算一小部分 M i j ≠ 0 M_{ij}\neq0 M ij = 0 的情况,因此最终的内存访问开销为
O ( N d + N 2 d 2 M s ) \mathcal{O}\left(Nd+\frac{N^2d^2}{M}s\right) O ( N d + M N 2 d 2 s )
可以看到,attention mask 的 sparsity 越高,block-sparse flashattention 的效率也就越高。当 N N N 非常大时,s s s 通常为 1 / N 1/\sqrt{N} 1/ N 或者 N − 1 log N N^{-1}\log N N − 1 log N , 从而最终的内存访问开销为 O ( N N ) \mathcal{O}(N\sqrt{N}) O ( N N ) 或者 O ( N log N ) \mathcal{O}(N\log N) O ( N log N ) .
作者对比了以下 block-sparse flashattention 和 flashattention 的效率对比,结果如下图所示
作者通过实验验证了 flashattention 的有效性,如下表所示
可以看到,尽管 flashattention 相比于标准化 attention 需要更多的算力,但是由于其内存访问开销更少,所以最终的运行时间大有了大幅度降低
作者还探究了 block size 对 flashattention 性能对的影响,实验结果如下图所示
可以看到,随着 block size 增加,循环次数降低,内存访问开销也逐渐降低。但是当 block size 充分大 ( > 256 > 256 > 256 ) 之后,运行时间就会被别的因素所限制,并且过大的 block size 可能会导致 SRAM 的内存溢出
作者首先在 BERT 和 GPT-2 上验证了 flashattention 的表现,BERT 的实验结果如下表所示
GPT-2 的实验结果如下表所示
实验结果显示,flashattention 比 Huggingface 快 3 倍左右,比 Megatron 快 1.7 倍左右
训练速度:实验显示,flashattention 在 BERT 上,比 MLPerf 1.1 快 15 % 15\% 15% , 在 GPT-2 上比 HuggingFace 快 3 倍,比 Megatron 快 1.8 倍
准确率:flashattention 是第一个在 Path-X 上比随机表现更好的 transformer 模型;block-sparse flashattention 是第一个在 Path-256 上比随机表现更好的的 sequence model
Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Re, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. In A. H. Oh, A. Agarwal, D. Belgrave, & K. Cho (Eds.), Advances in Neural Information Processing Systems . https://openreview.net/forum?id=H4DqfPSibmx
flash attention 2 在 flash attention 的基础上进行了进一步改进。
https://zhuanlan.zhihu.com/p/665170554
https://zhuanlan.zhihu.com/p/4264163756?utm_id=0
https://zhuanlan.zhihu.com/p/17533058076