Overview of Flash Attention series

作者提出了 flashattention, 一个通过降低 multi head attention 内存访问开销来提高 attention 计算效率的方法

Author

Published

2025-08-21 11:32:53+0800

PDF

作者提出了 flashattention, 一个通过降低 multi head attention 内存访问开销来提高 attention 计算效率的方法

Introduction

Transformer 的 attention 是一个平方度复杂度的算法,这个平方复杂度既体现在时间复杂度上(矩阵乘法),也体现在空间复杂度上(需要存储中间结果)。因此,要降低 attention 的复杂度,我们有两种思路:

  1. 从时间复杂度上入手,比如使用稀疏 attention 机制或者线性注意力机制
  2. 从空间复杂度上入手,比如使用 GQA, MQA 等减少内存的访问开销

本文提出的 flashattention 就属于降低空间复杂度的一种做法。作者认为,我们应该设计一种 IO-aware 的 attention 算法,来减少 attention 计算式的内存访问开销,进而提高 attention 的计算效率。

作者首先提到,一个未解决的问题就是:

降低 attention 的内存访问开销是否可以提高 attention 的计算效率?

作者发现,已有的一些工作虽然在理论上降低了 attention 的计算效率,但是在实际中,他们的效果并没有提升太多。作者分析原因认为,已有工作主要关注于降低 FLOPs, 但是忽略了内存访问开销。

因此,作者在本文中就提出了 flashattention, 一个 IO-aware 的 attention 算法,作者通过尽可能降低内存访问开销来提高模型的计算效率。具体做法就是,避免从内存中读写 attention matrix, 作者认为这个目标有两个挑战:

  1. 计算 softmax 的时候不访问所有的输入
  2. 在反向传播时不存储中间的 attention matrix

作者提出了两个方法来分别解决这两个问题:

  1. 作者使用了 tiling 技巧,将 input 分成多个 block, 然后分别进行处理,进而降低 softmax 的内存访问开销
  2. 作者使用了 recompute 技巧,在反向传播时,重新计算 softmax normalization factor

通过这些改进,我们可以让 attention 运行更快,并且降低内存访问开销。

作者还从理论上分析了 flashattention 的复杂度,提供了理论基础。

作者通过实验验证了 flashattention 的有效性,主要是三点:

  1. 训练效率更高:相比于 Huggingface 和 Megatron, flashattention 的训练效率提升了 2-3 倍
  2. 模型的表现更好:相比于 GPT-2, 模型的 perplexity 提升了 0.7 个点左右
  3. 速度更快:flashattention 比标准的 attention 实现快 3 倍以上

Background

Hardware Performance

作者首先介绍了以下 GPU 的内存架构,如下图所示

Memory Hierarchy of GPU

可以看到,GPU 内存可以分为三个层级:

  1. SRAM: GPU 的寄存器,容量小,但是访问速度极快
  2. High bandwith memory (HBM): GPU 的高速内存,访问速度较快,容量中等
  3. DRAM: CPU 内存,容量最大,但是访问速度较慢

接下来作者介绍了 Execution model 的概念,GPU 有多个线程来执行同一个操作(SPMD),这个操作也被称为 kernel, kernel 会从 HBM 中加载输入到 SRAM 中进行计算,然后写回 HBM.

对一个算法,我们可以将其归类为 compute-bound 和 memory-bound 两类, 我们可以用 arithmetic intensity 来进行区分,arithmetic intensity 定义为 arithmetic operations 与 memory access 的比率。

  1. compute bound: 算法的瓶颈在于算力,由于算力不足导致运行时间慢,比如矩阵乘法
  2. memory-bound: 算法的瓶颈在于内存访问效率,比如 element-wise 操作或者是 reduction

为了提高 memory-bound 类型算法的效率,我们进行 kernel fusion, 即把多个访问同一片内存的操作放在一起处理,避免多次读写内存

Standard Attention Implementation

作者还回顾了一下标准化的 attention 实现。

Forward Pass

给定 Q,K,VRN×dQ,K,V\in\mathbb{R}^{N\times d}, 其中 NN 是序列长度, dd 是 head dimension, attention 的定义如下

S=QKTRN×N,P=softmax(S)RN×N,O=PVRN×dS = QK^T\in\mathbb{R}^{N\times N},\quad P = \mathrm{softmax}(S)\in\mathbb{R}^{N\times N},\quad O = PV\in\mathbb{R}^{N\times d}

这里 softmax\mathrm{softmax} 是逐行计算的。

算法的执行过程如下

Algorithm 0: Standard Attention Implementation

我们有第一个结论

Proposition 1 标准化 attention 前向传播时访问 HBM 的内存访问开销为 O(Nd+N2)\mathcal{O}(Nd+N^2).

证明:对于 attention, 我们需要从 HBM 中加载 Q,K,VRN×dQ,K,V\in\mathbb{R}^{N\times d}, 然后输出 ORN×dO\in\mathbb{R}^{N\times d} 并保存到内存中。

首先我们需要计算 S=QKTS = QK^T, 这一步需要加载 Q,KQ,K 并将 SS 保存到 HBM 中,内存访问量为 O(Nd+N2)\mathcal{O}(Nd+N^2).

接下来,我们需要计算 P=softmax(S)P = \mathrm{softmax}(S), 这一步需要加载 SS 然后将 PP 保存到 HBM 中,内存访问量为 O(N2)\mathcal{O}(N^2).

最后,我们需要计算 O=PVO = PV, 这一步需要加载 PPVV 然后将 OO 保存到 HBM 中,内存访问量为 O(Nd+N2)\mathcal{O}(Nd+N^2).

总的来说,标准化 attention 的内存访问开销为 O(Nd+N2)\mathcal{O}(Nd+N^2).

Backward Pass

标准 attention 反向传播过程如下图所示

Standard attention backward pass

Proposition 2 标准化 attention 反向传播时访问 HBM 的内存访问开销为 O(Nd+N2)\mathcal{O}(Nd+N^2).

证明:对于标准化 attention 的反向传播,我们需要从 HBM 中加载 Q,K,V,dORN×dQ,K,V,dO\in\mathbb{R}^{N\times d} , 然后输出 dQ,dK,dVdQ,dK,dV 并保存到 HBM 中。

首先我们计算 dV=PTdOdV=P^TdO, 这一步需要加载 P,dOP,dO 并将 dVdV 保存到 HBM 中,内存访问开销为 O(Nd+N2)\mathcal{O}(Nd+N^2).

接下来我们计算 dP=dOVTdP=dOV^T, 这一步需要加载 dO,VdO, V 并将 dPdP 保存到 HBM 中,内存访问开销为 O(Nd)\mathcal{O}(Nd).

然后我们计算 dSdS, 这一步需要加载 PP 并将 dSdS 保存到 HBM 中,内存访问开销为 O(N2)\mathcal{O}(N^2).

对于 dQdQdKdK 的计算,内存访问开销都是 O(Nd+N2)\mathcal{O}(Nd+N^2).

因此,标准化 attention 的内存访问开销为 O(Nd+N2)\mathcal{O}(Nd+N^2).

Method

作者在本节首先介绍了 flashattention 算法,然后作者证明了 flashattention 的正确性以及分析了复杂度。最后作者对 flashattention 进行扩展得到了 Block-sparse Flashattention.

Flashattention

attention 模块的输入是 Q,K,VRN×dQ,K,V\in\mathbb{R}^{N\times d}, 输出是 ORN×dO\in\mathbb{R}^{N\times d}, 作者的目标是减少计算过程中的 HBM 访问次数

作者分别使用了 tiling 和 recomputation 来解决 attention 前向传播和反向传播中的内存访问开销。flashattention 的核心思想是,我们将 Q,K,VQ,K,V 分割成 block, 然后在 block 层面进行加载和计算。

Tiling

首先作者介绍了一下如何使用 tiling 来计算 softmax.

给定一个向量 xRBx\in\mathbb{R}^{B}, 其 softmax 计算方式如下

m(x)=maxixi, f(x)=[ex1m(x),,exBm(x)], (x)=if(x)i, softmax(x)=f(x)(x)m(x) = \max_i x_i,\ f(x) = [e^{x_1-m(x)},\dots,e^{x_B-m(x)}], \ \ell(x)=\sum_if(x)_i, \ \mathrm{softmax}(x) = \frac{f(x)}{\ell(x)}

如果我们现在有两个向量 x(1),x(2)RBx^{(1)}, x^{(2)}\in\mathbb{R}^{B}, 记 x=[x(1),x(2)]TR2Bx=[x^{(1)}, x^{(2)}]^T\in\mathbb{R}^{2B}, 我们可以将 softmax(x)\mathrm{softmax}(x) 的计算分解为

m(x)=max(m(x(1)),m(x(2)))f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))](x)=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2)),softmax(x)=f(x)(x)\begin{aligned} m(x) &= \max(m(x^{(1)}), m(x^{(2)})), f(x) = [e^{m(x^{(1)})-m(x)}f(x^{(1)}),e^{m(x^{(2)})-m(x)}f(x^{(2)})]\\ \ell(x) &= e^{m(x^{(1)})-m(x)}\ell(x^{(1)}) + e^{m(x^{(2)})-m(x)}\ell(x^{(2)}), \mathrm{softmax}(x) = \frac{f(x)}{\ell(x)} \end{aligned}

因此,如果我们额外记录 m(x)m(x) 以及 (x)\ell(x) 这两个量,那么我们可以每次仅计算 softmax 的一个 block. 具体细节见softmax.

Recomputation

在反向传播过程中,一般我们需要存储 S,PRN×NS,P\in\mathbb{R}^{N\times N}, 需要的空间复杂度为 O(N2)\mathcal{O}(N^2). 但是,通过存储 ORN×dO\in\mathbb{R}^{N\times d} 以及 (m,)(m,\ell), 我们可以避免重新计算 S,PS,P,这可以看做是 gradient checkpointing. 但是与 checkpointing 相比,因为 flashattention 减少了内存访问开销,因此其反向过程并没有变得更慢。

Algorithm

最终,flashattention 的算法如下图所示

FlashAttention algorithm

Analysis

Correctness

算法的正确性由定理 1 给出

Theorem 1 flashattention (即算法 1) 输出 O=softmax(QKT)VO=\mathrm{softmax}(QK^T)V, 其时间复杂度为 O(N2d)\mathcal{O}(N^2d), 空间复杂度为 O(N)\mathcal{O}(N).

证明:时间复杂度主要由矩阵乘法决定。在计算 Sij=QiKjTS_{ij}=Q_iK_j^T 时,所花费的 FLOPS 为 O(BrBcd)\mathcal{O}(B_rB_cd). 在计算 P~ijVj\tilde{P}_{ij}V_j 时,所花费的 FLOPS 为 O(BrBcd)\mathcal{O}(B_rB_cd). 循环一共执行了

TcTr=NBcNBrT_cT_r = \left\lceil\frac{N}{B_c}\right\rceil\left\lceil\frac{N}{B_r}\right\rceil

从而总的 FLOPS 为

O(N2BrBcBrBcd)=O(N2d)\mathcal{O}\left(\frac{N^2}{B_rB_c}B_rB_cd\right) = \mathcal{O}(N^2d)

在 flashattention 的计算过程中,我们只需要保存 (,m)(\ell, m) 即可,因此需要的额外内存空间为 O(N)\mathcal{O}(N).

接下来,我们可以证明 flashattention 的正确性,我们使用归纳法来证明。令 jj 满足 0jTc0\leq j\leq T_c, K:jRjBc×dK_{:j}\in\mathbb{R}^{jB_c\times d}, V:jRjBc×dV_{:j}\in\mathbb{R}^{jB_c\times d} 分别为 KKVV 的前 jBcjB_c 行。 S:,:j=QK:jTRN×jBcS_{:, :j}=QK_{:j}^T\in\mathbb{R}^{N\times jB_c}, P:,:j=softmax(S:,:j)RN×jBcP_{:,:j}=\mathrm{softmax}(S_{:,:j})\in\mathbb{R}^{N\times jB_c}, m(j),(j),O(j)m^{(j)}, \ell^{(j)}, O^{(j)} 分别为 m,,Om,\ell, O 的第 jj 个元素。我们证明经过第 jj 次迭代后,HBM 中保存的是

m(j)=rowmax(S:,:j)RN,(j)=rowsum(exp(S:,:jm(j)))RN,O(j)=P:,:jV:jRN×dm^{(j)}=\mathrm{rowmax}(S_{:,:j})\in\mathbb{R}^N, \ell^{(j)}=\mathrm{rowsum}(\exp(S_{:,:j}-m^{(j)}))\in\mathbb{R}^N, O^{(j)} = P_{:,:j}V_{:j}\in\mathbb{R}^{N\times d}

j=0j=0 时,上面的结果显然成立。现在我们假设对某个 j=0,,Tc1j=0,\dots, T_c-1 上面的结果成立,我们需要证明对 j+1j+1 也成立。

首先

m(j+1)=max(m(j)m~)=max(rowmax(S:,:j),rowmax(S:,j:j+1))=rowmax(S:,:j+1)m^{(j+1)}=\max(m^{(j)}, \tilde{m}) = \max(\mathrm{rowmax}(S_{:,:j}), \mathrm{rowmax}(S_{:,j:j+1}))=\mathrm{rowmax}(S_{:,:j+1})

接下来

(j+1)=exp(m(j)m(j+1))(j)+exp(m~m(j+1))~=exp(m(j)m(j+1))rowsum(exp(S:,:jm(j)))+exp(m~m(j+1))rowsum(exp(S:,j:j+1m~))=rowsum(exp(S:,:jm(j+1)))+rowsum(exp(S:,j:j+1m(j+1)))=rowsum(exp(S:,:j+1m(j+1)))\begin{aligned} \ell^{(j+1)} &= \exp(m^{(j)}-m^{(j+1)})\ell^{(j)} + \exp(\tilde{m}-m^{(j+1)})\tilde{\ell}\\ &=\exp(m^{(j)}-m^{(j+1)})\mathrm{rowsum}(\exp(S_{:,:j}-m^{(j)})) + \exp(\tilde{m}-m^{(j+1)})\mathrm{rowsum}(\exp(S_{:,j:j+1}-\tilde{m}))\\ &= \mathrm{rowsum}(\exp(S_{:,:j}-m^{(j+1)})) + \mathrm{rowsum}(\exp(S_{:,j:j+1}-m^{(j+1)}))\\ &= \mathrm{rowsum}(\exp(S_{:,:j+1}-m^{(j+1)})) \end{aligned}

最后,我们计算 O(j+1)O^{(j+1)} 得到:

O(j+1)=diag((j+1))1(diag((j))exp(m(j)m(j+1))O(j)+exp(m~m(j+1))exp(S:,j:j+1m~)V:,j:j+1)=diag((j+1))1(diag((j))exp(m(j)m(j+1))P:,:jV:,:j+exp(m(j+1))exp(S:,j:j+1)V:,j:j+1)=diag((j+1))1(diag((j))exp(m(j)m(j+1))diag((j))1exp(S:,:jm(j))V:,:j+exp(m(j+1))exp(S:,j:j+1)V:,j:j+1)=diag((j+1))1(exp(m(j+1))exp(S:,:j))V:,:j+exp(m(j+1))exp(S:,j:j+1)V:,j:j+1)=diag((j+1))1([exp(S:,:jm(j+1))exp(S:,:jm(j+1))][V:,:jV:,j:j+1]=softmax(S:,:j+1)V:,:j+1\begin{aligned} O^{(j+1)} &= \mathrm{diag}(\ell^{(j+1)})^{-1}(\mathrm{diag}(\ell^{(j)})\exp(m^{(j)}-m^{(j+1)})O^{(j)}+\exp(\tilde{m}-m^{(j+1)})\exp(S_{:,j:j+1}-\tilde{m})V_{:,j:j+1})\\ &= \mathrm{diag}(\ell^{(j+1)})^{-1}(\mathrm{diag}(\ell^{(j)})\exp(m^{(j)}-m^{(j+1)})P_{:,:j}V_{:,:j}+\exp(-m^{(j+1)})\exp(S_{:,j:j+1})V_{:,j:j+1})\\ &= \mathrm{diag}(\ell^{(j+1)})^{-1}(\mathrm{diag}(\ell^{(j)})\exp(m^{(j)}-m^{(j+1)})\mathrm{diag}(\ell^{(j)})^{-1}\exp(S_{:,:j}-m^{(j)})V_{:,:j}+\exp(-m^{(j+1)})\exp(S_{:,j:j+1})V_{:,j:j+1})\\ &= \mathrm{diag}(\ell^{(j+1)})^{-1}(\exp(-m^{(j+1)})\exp(S_{:,:j}))V_{:,:j}+\exp(-m^{(j+1)})\exp(S_{:,j:j+1})V_{:,j:j+1})\\ &= \mathrm{diag}(\ell^{(j+1)})^{-1}( \begin{bmatrix} \exp(S_{:,:j}-m^{(j+1)}) & \exp(S_{:,:j}-m^{(j+1)}) \end{bmatrix}\begin{bmatrix} V_{:,:j} \\ V_{:,j:j+1} \end{bmatrix}\\ &= \mathrm{softmax}(S_{:,:j+1})V_{:,:j+1} \end{aligned}

因此上面的结果对 j+1j+1 也成立,从而 flashattention 的结果对 j=0,,Tcj=0,\dots,T_c 都成立。

Forward Pass of flashattention

第一个问题是如何提高 softmax 计算的效率,作者的做法先先计算 normalization constant 然后再分别计算不同的 column.

给定 Q,K,VRN×dQ,K,V\in\mathbb{R}^{N\times d}, 其中 NN 是序列长度, dd 是 head dimension, attention 的定义如下

S=QKTRN×N,P=softmax(S)RN×N,O=PVRN×dS = QK^T\in\mathbb{R}^{N\times N},\quad P = \mathrm{softmax}(S)\in\mathbb{R}^{N\times N},\quad O = PV\in\mathbb{R}^{N\times d}

我们有 Sij=qiTkjS_{ij}=q_i^Tk_j, 这里 qiq_ikjk_j 分别是 QQKK 的第 ii 列以及第 jj 列, normalization constant 定义为:

Li=j=1Nexp(qiTkj)L_i = \sum_{j=1}^N \exp\left(q_i^Tk_j\right)

对任意 ii, 计算 LiL_i 只需要 O(N)\mathcal{O}(N) 的空间复杂度。

vjv_jVV 的第 ii 列,则输出 OO 的第 iioio_i

oi=Pi:V=j=1NPijvj=j=1Nexp(qiTkj)Livjo_i = P_{i:}V = \sum_{j=1}^N P_{ij}v_j = \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}v_j

这个过程中,对任意 ii, 计算 oio_i 也只需要 O(N)\mathcal{O}(N) 的空间复杂度。

因此,在 LiL_i 已经计算好的情况下,我们可以在 O(N)\mathcal{O}(N) 的空间复杂度下计算 oio_i.

最终,flashattention 的 forward pass 过程如下图所示

FlashAttention forward pass

接下来,作者分析了 flashattention 的内存访问开销。结论如下

Theorem 2 令 NN 为 sequence length, dd 为 head dimension, MM 是 SRAM 的 size, 且满足 dMNdd\leq M\leq Nd. 则 flashattention 前向传播的内存访问开销为 Θ(N2d2M1)\Theta(N^2d^2M^{-1}).

证明:由 Algorithm 1(或者 Algorithm 2)可以知道,KKVV 的每一个元素都只需要从 HBM 中加载一次,而每一次外层循环都会从 HBM 中加载一次 OOQQ, 因此总的 HBM 访问次数为 O(Nd+NdTc)=O(NdTc)\mathcal{O}(Nd+NdT_c)=\mathcal{O}(NdT_c).

接下来,我们给出每一次内层循环的内存访问开销,这是由 SRAM 的大小决定的。由于我们需要 SRAM 可以存储 KjRBc×dK_j\in\mathbb{R}^{B_c\times d} 以及 VjRBc×dV_j\in\mathbb{R}^{B_c\times d} ,我们的 block size 需要满足

Bcd=O(M)Bc=O(Md)B_cd = \mathcal{O}(M) \Rightarrow B_c = \mathcal{O}\left(\frac{M}{d}\right)

同理,对于 OOQQ, 我们有

Brd=O(M)Br=O(Md)B_rd = \mathcal{O}(M) \Rightarrow B_r = \mathcal{O}\left(\frac{M}{d}\right)

最后,我们还需要 SRAM 可以存储 SijRBr×BcS_{ij}\in\mathbb{R}^{B_r\times B_c}, 因此

BrBc=O(M)B_rB_c=\mathcal{O}(M)

这样,

Bc=O(Md),Br=O(min(Md,MBc))=O(min(Md,d))B_c = \mathcal{O}\left(\frac{M}{d}\right), B_r=\mathcal{O}\left(\min\left(\frac{M}{d},\frac{M}{B_c}\right)\right)=\mathcal{O}\left(\min\left(\frac{M}{d},d\right)\right)

从而

Tc=NBc=O(NdM)T_c = \frac{N}{B_c} = \mathcal{O}\left(\frac{Nd}{M}\right)

最终,总的内存访问开销为

O(NdTc)=O(N2d2M)\mathcal{O}(NdT_c) = \mathcal{O}\left(\frac{N^2d^2}{M}\right)

一般来说, dd 的大小为 6412864-128, MM 的大小为 100KB100 KB 左右, $d^2 \ll M, 因此 flashattention 的内存访问开销远小于标准化 attention 的内存访问开销。

作者还证明 flashattention 的内存访问开销是一个下界,即

Proposition 3 令 NN 为 sequence length, dd 为 head dimension, MM 是 SRAM 的 size, 且满足 dMNdd\leq M\leq Nd. 则不存在一个对任意 M[d,Nd]M\in[d,Nd] 都可以在 内存访问开销为 Θ(N2d2M1)\Theta(N^2d^2M^{-1}) 的条件下完成 attention 计算的算法。

证明可以用反证法,基本思想是加载 Q,K,VQ,K,V 的 HBM 访问次数至少为 O(Nd)\mathcal{O}(Nd).

Backward Pass of flashattention

第二个问题是能否在线性空间复杂度下计算 attention 的反向传播过程。

首先我们记损失函数为 ϕ\phi, 然后令 ϕ\phiO,Q,K,VO,Q,K,V 的梯度分别为 dO,dQ,dK,dVRN×ddO,dQ,dK, dV\in\mathbb{R}^{N\times d}, 我们的目标是计算 dQ,dK,dVdQ, dK, dV.

dVdV 的计算是最容易的,我们有 dV=PTdOdV=P^TdO, 因此

dvj=i=1NPijdoi=i=1Nexp(qiTkj)Lidoidv_j = \sum_{i=1}^N P_{ij}do_i = \sum_{i=1}^N\frac{\exp(q_i^Tk_j)}{L_i}do_i

由于我们已经计算了 LiL_i, 因此,dvjdv_j 只需要 O(d)\mathcal{O}(d) 的空间复杂度。

接下来,注意到 dP=dOVTdP=dOV^T, 因此我们有

dPij=doiTvjdP_{ij} = do_i^Tv_j

计算的空间复杂度也是要 O(N)\mathcal{O}(N)

注意到 Pi:=softmax(si:)P_{i:}=\mathrm{softmax}(s_{i:}), 且 y=softmax(x)y=\mathrm{softmax}(x) 的 Jacobian 是 diag(y)yyT\mathrm{diag}(y)-yy^T (推导过程见 softmax), 我们有

dSi:=(diag(Pi:)Pi:Pi:T)dPi:=Pi:dPi:(Pi:TdPi:)Pi:dS_{i:} = (\mathrm{diag}(P_{i:})-P_{i:}P_{i:}^T)dP_{i:} = P_{i:} \odot dP_{i:} - (P_{i:}^TdP_{i:})P_{i:}

我们定义

Di=Pi:TdPi:=j=1Nexp(qiTkj)LidoiTvj=doiTj=1Nexp(qiTkj)Livj=doiToiD_i = P_{i:}^TdP_{i:}= \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}do_i^Tv_j = do_i^T\sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}v_j = do_i^To_i

DiD_i 的空间复杂度也只需要 O(N)\mathcal{O}(N).

dSi:=Pi:dPi:DiPi:dS_{i:} =P_{i:} \odot dP_{i:} - D_iP_{i:}

我们有

dSij=PijdPijDiPij=Pij(dPijDi)dS_{ij} = P_{ij}dP_{ij} - D_iP_{ij} = P_{ij}(dP_{ij}-D_i)

注意到 Sij=qiTkjS_{ij}=q_i^Tk_j, 我们有

dqi=j=1NdSijkj=j=1NPij(dPijDi)kj=j=1Nexp(qiTkj)Li(doiTvjDi)kjdq_i = \sum_{j=1}^N dS_{ij}k_j = \sum_{j=1}^NP_{ij}(dP_{ij}-D_i)k_j = \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}(do_i^Tv_j-D_i)k_j

因此计算 dqidq_i 的空间复杂度为 O(d)\mathcal{O}(d).

同样的,

dkj=j=1NdSijqi=j=1NPij(dPijDi)qi=j=1Nexp(qiTkj)Li(doiTvjDi)qidk_j = \sum_{j=1}^N dS_{ij}q_i = \sum_{j=1}^NP_{ij}(dP_{ij}-D_i)q_i = \sum_{j=1}^N\frac{\exp(q_i^Tk_j)}{L_i}(do_i^Tv_j-D_i)q_i

其空间复杂度为 O(N)\mathcal{O}(N).

总之,attention 的反向传播过程所需要的空间复杂度为 O(N)\mathcal{O}(N).

作者发现有两点可以改进:

  1. attention mask 不需要存储,我们只需要保存 forward pass 时的输入,然后在 backward pass 时重新生成即可,这样只需要 O(N)\mathcal{O}(N) 的空间复杂度。
  2. 计算 softmax 的梯度是,如果使用公式 Di=Pi:TdPi:D_i=P_{i:}^TdP_{i:} 来计算的话,由于 Pi:RNP_{i:}\in\mathbb{R}^N, 可能会导致超过 SRAM 的内存使用限制,因此,我们可以使用 Di=doiToiD_i=do_i^To_i 来避免这个问题,其中 oiRdo_i\in\mathbb{R}^d.

最终,flashattention 的 backward pass 过程如下图所示

FlashAttention backward pass

经过前面的分析,flashattention 的反向传播的时间复杂度为 O(N2)\mathcal{O}(N^2), 空间复杂度为 O(N)\mathcal{O}(N).

Theorem 5 令 NN 为 sequence length, dd 为 head dimension, MM 是 SRAM 的 size, 且满足 dMNdd\leq M\leq Nd. 则 flashattention 反向传播的内存访问开销为 Θ(N2d2M1)\Theta(N^2d^2M^{-1}).

定理的证明与 Theorem 2 基本一致,我们此处不再赘述。

Block-sparse Flashattention

当 attention 具有 block sparsity 的性质时,作者提出了 blck-sparse flashattention 来进一步提高 attention 的计算效率。

给定 Q,K,VRN×dQ,K,V\in\mathbb{R}^{N\times d}, 以及一个 mask M{0,1}N×NM\in\{0,1\}^{N\times N}, 我们需要计算

S=QKTRN×N,P=softmax(S1M)RN×N,O=PVRN×dS = QK^T\in\mathbb{R}^{N\times N},\quad P = \mathrm{softmax}(S\odot \mathbb{1}_{M})\in\mathbb{R}^{N\times N},\quad O = PV\in\mathbb{R}^{N\times d}

其中当 Mkl=1M_{kl}=1 时, (S1M)kl=Skl(S\odot \mathbb{1}_ {M})_ {kl}=S_ {kl}, 否则 (S1M)kl=0(S\odot \mathbb{1}_ {M})_{kl}=0.

Block-sparse attention 的算法如下所示

Block-sparse FlashAttention forward pass

Proposition 4 令 NN 为 sequence length, dd 为 head dimension, MM 是 SRAM 的 size, 且满足 dMNdd\leq M\leq Nd. 则 block-sparse attention 的内存访问开销为 Θ(Nd+N2d2M1s)\Theta(Nd+N^2d^2M^{-1}s), 其中 ss 是 block-sparse mask 中的非零 block 的比例

证明与 Theorem 2 的证明是类似的,总的内存访问开销为 O(Nd+NdTc)\mathcal{O}(Nd+NdT_c), 但是在计算的过程中,由于 mask 矩阵的 block-sparsity, 我们实际上只需要计算一小部分 Mij0M_{ij}\neq0 的情况,因此最终的内存访问开销为

O(Nd+N2d2Ms)\mathcal{O}\left(Nd+\frac{N^2d^2}{M}s\right)

可以看到,attention mask 的 sparsity 越高,block-sparse flashattention 的效率也就越高。当 NN 非常大时,ss 通常为 1/N1/\sqrt{N} 或者 N1logNN^{-1}\log N, 从而最终的内存访问开销为 O(NN)\mathcal{O}(N\sqrt{N}) 或者 O(NlogN)\mathcal{O}(N\log N).

作者对比了以下 block-sparse flashattention 和 flashattention 的效率对比,结果如下图所示

Efficiency of block-sparse FlashAttention

Experiment

作者通过实验验证了 flashattention 的有效性,如下表所示

AttentionStandardFlashAttention
GFLOPs66.675.2
HBM R/W (GB)40.34.4
Runtime (ms)41.77.3

可以看到,尽管 flashattention 相比于标准化 attention 需要更多的算力,但是由于其内存访问开销更少,所以最终的运行时间大有了大幅度降低

作者还探究了 block size 对 flashattention 性能对的影响,实验结果如下图所示

Ablation on block size

可以看到,随着 block size 增加,循环次数降低,内存访问开销也逐渐降低。但是当 block size 充分大 ( >256> 256) 之后,运行时间就会被别的因素所限制,并且过大的 block size 可能会导致 SRAM 的内存溢出

作者首先在 BERT 和 GPT-2 上验证了 flashattention 的表现,BERT 的实验结果如下表所示

BERT ImplementationTraining time (minutes)
Nvidia MLPerf 1.120.0±1.520.0\pm1.5
FlashAttention (ours)17.4±1.417.4\pm1.4

GPT-2 的实验结果如下表所示

Model implementationsOpenWebText (ppl)Training time (speedup)
GPT-2 small - Huggingface18.29.5 days (1.0 )
GPT-2 small - Megatron-LM18.24.7 days (2.0 )
GPT-2 small - FlashAttention18.22.7 days (3.5 )
GPT-2 medium - Huggingface14.221.0 days (1.0 )
GPT-2 medium - Megatron-LM14.311.5 days (1.8 )
GPT-2 medium - FlashAttention14.36.9 days (3.0 )

实验结果显示,flashattention 比 Huggingface 快 3 倍左右,比 Megatron 快 1.7 倍左右

  1. 训练速度:实验显示,flashattention 在 BERT 上,比 MLPerf 1.1 快 15%15\%, 在 GPT-2 上比 HuggingFace 快 3 倍,比 Megatron 快 1.8 倍
  2. 准确率:flashattention 是第一个在 Path-X 上比随机表现更好的 transformer 模型;block-sparse flashattention 是第一个在 Path-256 上比随机表现更好的的 sequence model

Conclusion

作者提出了 flashattention, 一个通过优化标准 attention 内存访问效率来提高 attention 计算效率的方法,作者详细介绍了算法设计的原理与证明,并通过实验证明了结果的有效性。