GPipe

google 在 2018 年提出了 GPipe, 一个使用 pipeline parallelism 来训练大规模神经网络的并行策略

Author

Published

2025-12-23 16:49:25+0800

google 在 2018 年提出了 GPipe, 一个使用 pipeline parallelism 来训练大规模神经网络的并行策略

Introduction

大规模神经网络已经在计算机视觉和自然语言处理等任务上取得了突破性进展。但是目前训练大规模神经网络存在的问题时,我们无法在单一 GPU 上训练我们的模型。基于多 GPU 训练模型需要考虑模型的切分以及通信优化。

为了解决这个问题,作者提出了 GPipe, 一个用于将大规模性模型分割部署到不同设备上的并行计算策略。

Method

Notation

作者首先定义 notation 如下表所示

notationdescription
LLnumber of layers
wiw_iweights of a layer
fif_iforward function of a layer
Fk=fjfiF_k=f_j\circ\cdots\circ f_iforward of a partition
BkB_kbackward of a partition
KKnumber of partitions
NNbatch size
MMmicro batch size

GPipe

首先是 naive pipeline parallelism (naive PP), 我们的输入为一个 batch, 然后我们依次计算 F1F_1, 通信传输,计算 F2F_2, 计算完成之后,我们再进行反向传播,更新参数。最后继续下一个 batch 的计算。

总体的过程如下图所示

illustration of pipeline parallelism

下面是一个按照时间轴给出的例子

an example of naive PP with 4 devices

naive PP 的问题在于,每个时刻只有一个 GPU 在工作,GPU 的利用效率很低。因此,GPipe 的做法在于将一个 batch 切分为 MM 个更小的 micro-batch, 下面是一个 M=4M=4 的例子

An example of pipeline parallelism with 4 devices and 4 micro batches

通过切分更小的 batch,我们可以提高 GPU 的利用率

Analysis

Bubble

接下来作者分析了 GPipe 的 bubble 情况,bubble 指的是 PP 过程中的 GPU idle time.

对于 naive PP 来说,一个 GPU 工作时,其余 GPU 都处于空闲状态,因此其 bubble 为

Tbubble=(K1)(F+B)T_{bubble} = (K-1)(F+B)

总的计算时间为

Ttotal=K(F+B)T_{total} = K(F+B)

从而 bubble rate 为

Bubblenaive=TbubbleTtotal=K1KBubble_{naive} = \frac{T_{bubble}}{T_{total}} = \frac{K-1}{K}

K=8K=8 时,我们有 Bubblenaive=87.5%Bubble_{naive}=87.5\%, 也就是说,当前训练的 GPU 空闲率为 87.5%87.5\%.

对于 GPipe 来说,由于我们将一个 batch 拆分为了更小的 batch, 我们可以提高 GPU 的利用率。

此时,我们的 bubble time 仍然是 Tbubble=(K1)(F+B)T_{bubble} = (K-1)(F+B). 但是,现在同一时刻工作的 GPU 变多了,从上面的示意图可以看到,前向过程所需要的时间为第一个 micro batch 运行的时间加上 M1M-1 个 batch 运行所需要的时间,反向同理,因此,GPipe 的总计算时间为

Ttotal=(M+K1)(F+B)T_{total} = (M+K-1)(F+B)

从而 GPipe 的 bubble rate 为

Bubblenaive=TbubbleTtotal=K1M+K1Bubble_{naive} = \frac{T_{bubble}}{T_{total}} = \frac{K-1}{M+K-1}

当我们令 M=8,K=8M=8, K=8 时,我们有 Bubblenaive=46.7%Bubble_{naive}=46.7\%, 可以看到,通过提高 micro batch 数量,我们可以显著降低 bubble rate.

Activation Memory

对于 naive PP 来说,我们需要缓存每一层的输入,因此 activation memory 为 O(N×L/K)\mathcal{O}(N\times L/K), 而使用 activation checkpointing 之后,我们现在的 activation memory 为

O(N+LK×NM)\mathcal{O}(N + \frac{L}{K}\times\frac{N}{M})

其中第一项代表了 boundary activation, 第二项代表了 Internal activation.

Experiments

作者在 image classification, machine translation 任务上进行了实验。

作者还进一步分析了影响 GPipe 性能的因素,结果如下图所示

Time step breakdown

可以看到,activation checkpointing 是 GPipe 的主要开销来源。

Conclusion

作者在本文中提出了 GPipe, 一个针对大规模神经网络训练的并行策略。通过将模型切分部署在不同的设备上以及使用 micro batch, 我们可以显著提高硬件的利用效率以及训练稳定性。

    Introduction

    随着模型参数变大,现有的 GPU 已经很难使用单一 GPU 来训练模型。对于多 GPU 训练场景,目前主要采用了 pipeline parallelism, 比如 GPipe 等,但是,这些策略需要我们对代码进行比较大的改动,这提高了开发成本。

    为了解决多 GPU 训练大规模 LLM 的效率,降低开发成本,目前主要使用了 model parallelism 策略,即对模型进行切分部署在多个 GPU 上。model parallelism 有两种范式:

    1. pipeline parallelism (PP): 将模型按照 layer 进行切分,如 GPipe 等,这种方法的问题是需要额外的逻辑来处理通信以及存在 pipeline bubbles
    2. tensor parallelism (TP): 将模型的按照权重进行切分,部署在不同的 GPU 上。

    作者在本文中基于 TP 策略来对 attention, FFN layer 进行简单改动来实现训练效率的提升。

    作者通过实现验证了 tensor parallelism 的有效性和高效率,结果发现在 512 张 GPU 的场景下,TP 可以达到 76%76\% 的 scaling efficiency (相比于 1 张 GPU 带来的性能提升)

    Method

    作者使用的 transformer 架构如下图所示

    transformer architecture

    本文中,作者探究了 BERT 和 GPT-2 两种架构。

    首先,我们假设 transformer layer 输入为 XRbs×dX\in\mathbb{R}^{bs\times d}, 这里 b,sb, s 分别为 batch size, sequence length, 接下来我们介绍如何针对 FFN, attention 以及 embedding 构建 TP 策略

    FFN

    论文中使用的 FFN 为 Linear-GeLU-Linear 的结构,对应第一层权重为 W1Rd×dffW_1\in\mathbb{R}^{d\times d_{ff}}, 第二层权重为 W2Rdff×dW_2\in\mathbb{R}^{d_{ff}\times d}, 对应数学表达式为

    Y=GeLU(XW1)W2Rbs×dY = \mathrm{GeLU}(XW_1)W_2\in\mathbb{R}^{bs\times d}

    我们首先对 W1W_1 按照 column 进行切分,得到

    W1=[W11,W12]Rd×dff, where W11Rd×d1,W12Rd×d2,d1+d2=dffW_1 = [W_{11}, W_{12}]\in\mathbb{R}^{d\times d_{ff}}, \text{ where } W_{11}\in\mathbb{R}^{d\times d_1}, W_{12}\in\mathbb{R}^{d\times d_2}, d_1+d_2=d_{ff}

    这里 d1,d2d_1, d_2 与我们并行的 GPU 数 (x-way TP) 相关,这样,我们就有

    GeLU(XW1)=GeLU(X[W11,W12])=GeLU([XW11,XW22])=[GeLU(XW11),GeLU(XW12)]\mathrm{GeLU}(XW_1) = \mathrm{GeLU}(X[W_{11}, W_{12}]) = \mathrm{GeLU}([XW_{11}, XW_{22}]) = [\mathrm{GeLU}(XW_{11}), \mathrm{GeLU}(XW_{12})]

    从而我们可以分别将 W11W_{11}W12W_{12} 部署在两个 GPU 上,然后并行计算。

    论文中还介绍如果我们对 W1W_1 按照 row 进行切分,则最终由于 GeLU(A+B)GeLU(A)+GeLU(B)\mathrm{GeLU}(A+B)\neq \mathrm{GeLU}(A)+\mathrm{GeLU}(B) 计算时会产生一次额外的同步。

    接下来,对于 W2W_2, 我们按照 row 进行切分得到

    W2=[W21W22]Rdff×d, where W21Rd1×d,W22Rd2×d,d1+d2=dffW_2 = \begin{bmatrix} W_{21}\\ W_{22} \end{bmatrix}\in\mathbb{R}^{d_{ff}\times d}, \text{ where }W_{21}\in\mathbb{R}^{d_1\times d}, W_{22}\in\mathbb{R}^{d_2\times d}, d_1+d_2=d_{ff}

    计算时,我们有

    GeLU(XW1)W2=[GeLU(XW11),GeLU(XW12)][W21W22]=GeLU(XW11)W21+GeLU(XW12)W22\mathrm{GeLU}(XW_1)W_2 = [\mathrm{GeLU}(XW_{11}), \mathrm{GeLU}(XW_{12})]\begin{bmatrix} W_{21}\\ W_{22} \end{bmatrix} = \mathrm{GeLU}(XW_{11})W_{21} + \mathrm{GeLU}(XW_{12})W_{22}

    可以看到,通过按照 row 进行切分,我们可以将 W11,W21W_{11}, W_{21} 部署在一个 GPU 上,将 W12,W22W_{12}, W_{22} 部署在另一个 GPU 上,分别计算出 GeLU(XW11)W21\mathrm{GeLU}(XW_{11})W_{21}GeLU(XW12)W22\mathrm{GeLU}(XW_{12})W_{22} 之后,再通过一此 all-reduce 操作得到最终的输出结果。计算图如下所示

    Tensor Parallelism for MLP in transformer block

    这里 ffgg 是两个对偶算子,代表了 TP 产生的额外通信开销

    operatorforwardbackward
    ffidentityall-reduce
    ggall-reduceidentity

    如果说我们使用的是 SwiGLU FFN, 即

    Y=(XW3Swish(XW1))W2Y = (XW_3\odot \mathrm{Swish}(XW_1))W_2

    我们按照 column 对 W1,W3W_1, W_3 进行切分,按照 row 对 W2W_2 进行切分(假设我们有 2 个 GPU),得到

    W1=[W11,W12]Rd×dff, where W11Rd×d1,W12Rd×d2,d1+d2=dffW3=[W31,W32]Rd×dff, where W31Rd×d1,W32Rd×d2,d1+d2=dffW2=[W21W22]Rdff×d, where W21Rd1×d,W22Rd2×d,d1+d2=dff\begin{aligned} W_1 &= [W_{11}, W_{12}]\in\mathbb{R}^{d\times d_{ff}}, \text{ where } W_{11}\in\mathbb{R}^{d\times d_1}, W_{12}\in\mathbb{R}^{d\times d_2}, d_1+d_2=d_{ff}\\ W_3 &= [W_{31}, W_{32}]\in\mathbb{R}^{d\times d_{ff}}, \text{ where } W_{31}\in\mathbb{R}^{d\times d_1}, W_{32}\in\mathbb{R}^{d\times d_2}, d_1+d_2=d_{ff}\\ W_2 &= \begin{bmatrix} W_{21}\\ W_{22} \end{bmatrix}\in\mathbb{R}^{d_{ff}\times d}, \text{ where }W_{21}\in\mathbb{R}^{d_1\times d}, W_{22}\in\mathbb{R}^{d_2\times d}, d_1+d_2=d_{ff} \end{aligned}

    然后我们将 W11,W31,W21W_{11}, W_{31}, W_{21} 放在第一个 GPU 上,将 W12,W32,W22W_{12}, W_{32}, W_{22} 放在第二个 GPU 上,此时,

    Swish(XW1)=Swish(X[W11,W12])=Swish([XW11,XW12])=[Swish(XW11,Swish(XW12]XW3Swish(XW1)=[XW31,XW32]Swish(XW1)=[XW31Swish(XW11),XW32Swish(XW12)]Y=(XW3Swish(XW1))W2=(XW3Swish(XW1))[W21W22]=XW31Swish(XW11)W21+XW32Swish(XW12)W22\begin{aligned} \mathrm{Swish}(XW_1) &= \mathrm{Swish}(X[W_{11}, W_{12}]) = \mathrm{Swish}([XW_{11}, XW_{12}])=[\mathrm{Swish}(XW_{11}, \mathrm{Swish}(XW_{12}]\\ XW_3\odot \mathrm{Swish}(XW_1) &= [XW_{31}, XW_{32}]\mathrm{Swish}(XW_1) = [XW_{31}\mathrm{Swish}(XW_{11}), XW_{32}\mathrm{Swish}(XW_{12})]\\ Y = (XW_3\odot \mathrm{Swish}(XW_1))W_2&=(XW_3\odot \mathrm{Swish}(XW_1))\begin{bmatrix} W_{21}\\ W_{22} \end{bmatrix} = XW_{31}\mathrm{Swish}(XW_{11})W_{21}+ XW_{32}\mathrm{Swish}(XW_{12})W_{22} \end{aligned}

    这样我们通过一次 all-reduce 也可以完成 SwiGLU FFN 的 tensor parallelism, 示意图如下所示

    Tensor Parallelism for SwiGLU MLP in transformer block

    Attention

    Attention 的处理与 MLP 非常相似,论文中的做法就是将不同 head 部署到不同 gpu 上分别进行计算,最后在计算 output projection 时再通过一次 all-reduce 来合并输出,这里我们假设有 hh 个 heads, 每个 head 的 dimension 为 dhd_h, 我们先对 query, key, value layer 的 weight WQ,WK,WVRd×hdhW_Q, W_K, W_V\in\mathbb{R}^{d\times hd_h} 进行切分

    WQ=[WQ1,,WQh],WK=[WK1,,Wkh],WV=[WV1,,WVh]W_Q = [W_{Q1}, \dots, W_{Qh}], W_K = [W_{K1}, \dots, W_{kh}], W_V = [W_{V1}, \dots, W_{Vh}]

    其中 WQi,WKi,WViRd×dhW_{Qi}, W_{Ki}, W_{Vi}\in\mathbb{R}^{d\times d_h} 为每个 head 对应的 query, key, value weight. 我们将切分后的 WQi,WKi,WViW_{Qi}, W_{Ki}, W_{Vi} 部署在一个 GPU 上(也可以将若干个 head 部署在一个 GPU 上),然后分别计算出每个 GPU 的 attention 结果,最后再进行汇总,如下所示

    oi=softmax((XWQi)(XWKi)Tdh)XWVi,i=1,,hO=[o1,,oh]WO\begin{aligned} o_i &= \mathrm{softmax}\left(\frac{(XW_{Qi})(XW_{Ki})^T}{d_h}\right) XW_{Vi}, i=1,\dots,h\\ O &= [o_1,\dots,o_h]W_O \end{aligned}

    下面是 multi-head attention 对应的 TP 示意图

    Tensor Parallelism for multi-head attention

    Embedding

    对于 Input embedding, 作者将 embedding matrix EEV×dE\in\mathbb{E}^{V\times d} 按照 row 进行切分(论文中使用了转置,因此是按照 column 进行切分),得到 E=[E1,E2]TE=[E_1,E_2]^T, 这里 EiRd×ViE_i\in\mathbb{R}^{d\times V_i}, V1+V2=VV_1+V_2=V, 接下来我们把切分后的 embedding matrix 部署在不同的 GPU 上,由于每个 GPU 只有部分结果,因此我们还需要进行 all-reduce 来进行汇总。

    而对于 output embedding, 我们也可以使用类似的做法进行切分,每个 GPU 上计算完结果之后我们还需要一个 all-gather 来汇总结果。

    作者在这里还额外介绍了针对 output embedding 的优化方法,由于 embedding 的输出大小为 [bs,V][bs, V], 而 VV 通常比较大,因此,为了降低通信开销,作者将 cross-entropy-loss 与 output embedding kernel 进行融合,这样我们传输的数据量就减少到了 bsbs.

    Experiments

    作者首先对 GPT-2 模型进行了修正,首先将 vocab_size 从 50257 提升到 128 的倍数,即 51200. 对于 model+data parallelism, 作者固定 global batch size 为 512. (64-way DP)

    配置如下表所示(head size 为 96)

    Hidden sizeattention headslayersparameters (B)TPTP+DP
    153616401.2164
    192020542.52128
    230424644.24256
    307232728.38512

    对应的 scaling (使用多卡训练后,每个 GPU 相对于单卡训练的利用率)如下表所示

    parallelismTP-1TP-2TP-4TP-8TP-1+DP-64TP-2+DP-64TP-4+DP-64TP-8+DP-64
    scaling100%95%82%77%96%83%79%74%

    Implementation

    首先是 linear layer 的 TP 版本,如下所示

    import torch
    import torch.nn as nn
    import torch.distributed as dist
    
    
    class ColumnParallelLinear(nn.Module):
        def __init__(self, in_features, out_features, bias=True):
            super().__init__()
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
            
            self.local_out = out_features // self.world_size
            self.weight = nn.Parameter(torch.empty(in_features, self.local_out))
    
    
        def forward(self, x):
            out = x @ self.weight
            
            gather_list = [torch.empty_like(out) for _ in range(self.world_size)]
            dist.all_gather(gather_list, out)
            return torch.cat(gather_list, dim=-1)
    
    
    class RowParallelLinear(nn.Module):
        def __init__(self, in_features, out_features, bias=True):
            super().__init__()
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
            
            self.local_in = in_features // self.world_size
            self.weight = nn.Parameter(torch.empty(self.local_in, out_features))
    
        def forward(self, x):
            x_local = torch.chunk(x, self.world_size, dim=-1)[self.rank]
    
            out = x_local @ self.weight
            
            dist.all_reduce(out, op=dist.ReduceOp.SUM)
            return out
    

    接下来是针对 LLM 中使用的 SwiGLU FFN 进行的优化,基于前面的介绍,我们不需要对基于 column linear 进行 all-reduce, 代码如下所示

    SwiGLU

    import torch
    from torch import nn
    import torch.nn.functional as F
    import torch.distributed as dist
    
    
    world_size = 1
    rank = 0
    
    class ColumnParallelLinear(nn.Module):
        def __init__(self, in_features: int, out_features: int, dtype = None):
            assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
            self.part_out_features = out_features // world_size
            self.weight = nn.Parameter(torch.empty(part_out_features, part_in_features, dtype=dtype))
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = x @ self.weight
            return y
    
    
    class RowParallelLinear(nn.Module):
        def __init__(self, in_features: int, out_features: int, dtype = None):
            assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
            self.part_in_features = in_features // world_size
            self.weight = nn.Parameter(torch.empty(out_features, part_in_features, dtype=dtype))
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            y = x @ self.weight
            if world_size > 1:
                dist.all_reduce(y)
            return y
    
    
    class MLP(nn.Module):
        def __init__(self, dim: int, inter_dim: int):
            super().__init__()
            self.w1 = ColumnParallelLinear(dim, inter_dim)
            self.w2 = RowParallelLinear(inter_dim, dim)
            self.w3 = ColumnParallelLinear(dim, inter_dim)
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(F.silu(self.w1(x)) * self.w3(x))
    

    attention

    class TPMultiHeadAttention(nn.Module):
        def __init__(self, d_model: int, num_heads: int, head_dim: int = None):
            self.d_model = d_model
            self.num_heads = num_heads
            self.head_dim = head_dim if head_dim is not None else d_model // num_heads
            
            assert self.num_heads % world_size == 0, "num_heads must be divisible by world size" 
            assert self.d_model == self.num_heads * self.head_dim, "d_model must equal num_heads * head_dim"
            
            # heads of different GPU
            self.local_num_heads = self.num_heads // world_size
            self.local_qkv_dim = self.local_num_heads * self.head_dim
            
            self.qkv_proj = ColumnParallelLinear(in_features=d_model, out_features=3 * self.local_qkv_dim)
            self.out_proj = RowParallelLinear(in_features=d_model, out_features=d_model)
            
            self.scale = 1.0 / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
            
        def _split_heads(self):
            batch_size, seq_len, _ = x.shape
            # [batch, seq_len, local_num_heads, head_dim]
            x = x.reshape(batch_size, seq_len, self.local_num_heads, self.head_dim)
            # [batch, local_num_heads, seq_len, head_dim]
            return x.transpose(1, 2)
            
        def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None):
            batch_size, seq_len, _ = x.shape
            # [batch, seq_len, local_qkv_dim * 3]
            qkv = self.qkv_proj(x)
            # [batch, seq_len, local_qkv_dim]
            q, k, v = torch.split(qkv, self.local_qkv_dim, dim=-1)
            # [batch, local_num_heads, seq_len, head_dim]
            q = self._split_heads(q) 
            k = self._split_heads(k) 
            v = self._split_heads(v)
            
            attn_scores = (q @ k.transpose(-2, -1)) * self.scale
            
            if attn_mask is not None: 
                attn_scores = attn_scores + attn_mask
                
            attn_weights = F.softmax(attn_scores, dim=-1)
            # [batch, local_num_heads, seq_len, head_dim]
            attn_output = attn_weights @ v
            # [batch, seq_len, local_num_heads, head_dim]
            attn_output = attn_output.transpose(1, 2)
            # [batch, seq_len, local_qkv_dim]
            attn_output = attn_output.reshape(batch_size, seq_len, self.local_qkv_dim)
            
            gather_list = [torch.empty_like(attn_output) for _ in range(world_size)] 
            dist.all_gather(gather_list, attn_output)
            # [batch, seq_len, d_model]
            full_attn_output = torch.cat(gather_list, dim=-1)
            
            out = self.out_proj(full_attn_output)
            
            return out
    

    Embedding

    class ParallelEmbedding(nn.Module):
        def __init__(self, vocab_size: int, dim: int):
            super().__init__()
            self.vocab_size = vocab_size
            self.dim = dim
            assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
            self.part_vocab_size = (vocab_size // world_size)
            self.vocab_start_idx = rank * self.part_vocab_size
            self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
            self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
    
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if world_size > 1:
                mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
                x = x - self.vocab_start_idx
                x[mask] = 0
            y = F.embedding(x, self.weight)
            if world_size > 1:
                y[mask] = 0
                dist.all_reduce(y)
            return y
    

    Conclusion

    作者提出了针对 transformer 架构的 tensor parallelism 策略来提高整体的训练效率,通过在训练过程加入四次 all-reduce 通信我们就可以训练更大规模的模型。

      说明:本文参考了 nanotron/ultrascale-playbookColossal-AI Concepts

      什么是分布式系统

      分布式系统允许一个软件的多个组件运行在不同的机器上。与传统集中式系统不一样,分布式系统可以有效提高系统的稳健性。 一个比较比较经典的分布式就是Git,Git允许我们把代码保存在多个remote上。这样当一个remote宕机时,其他remote也能提供服务。

      评估一个分布式系统的重要标准就是规模效益(scalablity),也就是说,我们希望使用8台设备应该要比4台设备快2倍。但是,由于通信带宽等原因,实际上加速比并不是和设备数量成线性关系。因此,我们需要设计分布式算法,来有效提高分布式系统的效率。

      为什么需要分布式训练

      我们需要分布式训练的原因主要是以下几点:

      1. 模型越来越大。当下(2025)领先模型如Qwen,LLaMA系列的最大模型都超过了100B [2][3]。LLaMA系列最大的模型甚至超过了1000B。Scaling law告诉我们模型表现与参数量,算力,数据量成正相关关系。
      2. 数据集越来越大。现在领先的模型需要的数据量基本都需要100M以上,而大语言模型训练需要的token数量也都超过了10T的量级 [2][3].
      3. 算力越来越强。现有最强的GPU H100其显存为80GB,拥有3.35TB/s 的带宽 (PcIe),这让训练大规模模型成为可能。

      超大的模型使得我们很难在一张GPU上进行训练,甚至我们都很难使用单张GPU进行部署。而10T级的数据也也需要几个月的时间才能训练完毕。因此,如何高效利用多张GPU在大规模数据上训练超大模型就是我们需要解决的问题。

      基本概念

      我们先来熟悉一下分布式训练中的一些基本概念:

      我们以下图为例:

      basic concepts

      上图中一共包含2个node (2台机器),每台机器包含4个GPU (device),当我们初始化分布式环境时,我们一共启动了8个进程(每台机器4个进程),每个进程绑定一个GPU。

      在初始化分布式环境之间,我们需要指定host和port。假设我们指定host为node 0和port为 29500,接下来,所有的进程都会基于这个host和port来与其他进程连接。默认的process group(包含所有device)的world size 为8. 其细节展示如下

      process IDrankNode indexGPU index
      0000
      1101
      2202
      3303
      4410
      5511
      6612
      7713

      我们可以创建一个新的process group,使其仅包含ID为偶数的process:

      process IDrankNode indexGPU index
      0000
      2102
      4210
      6312

      Remark: 注意,rank与process group相关,一个process在不同的process group里可能会有不同的rank.

      通信方式

      接下来,我们需要介绍一下设备间的通信方式,这是我们后面分布式训练算法的基础。根据设备数量的不同,我们可以将设备间通信分为:

      1. one-to-one: 两个device之间互相进行通信
      2. one-to-many: 一个device与多个device进行通信
      3. many-to-one: 多个device与一个device之间进行通信
      4. many-to-many: 多个device之间互相进行通信

      One-to-one

      One-to-one的情况很简单,一个process与另一个process进行通信,通信通过 sendrecv 完成。还有对应的 immediate版本,即 isendirecv,示意图如下所示

      point to point communication

      测试代码如下:

      # send_recv.py
      import os
      
      import torch
      import torch.distributed as dist
      
      
      def init_process():
          dist.init_process_group(backend='nccl')
          torch.cuda.set_device(dist.get_rank())
          
      def example_send():
          if dist.get_rank() == 0:
              tensor = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).cuda()
              dist.send(tensor, dst=1)
          elif dist.get_rank() == 1:
              tensor = torch.zeros(5, dtype=torch.float32).cuda()
              print(f"Before send on rank {dist.get_rank()}: {tensor}")
              dist.recv(tensor, src=0)
              print(f"After send on rank {dist.get_rank()}: {tensor}")
      
      
      init_process()
      example_send()
      
      # run with
      # torchrun --nproc_per_node=2 send_recv.py
      

      结果输出

      Before send on rank 1: tensor([0., 0., 0., 0., 0.], device='cuda:1')
      After send on rank 1: tensor([1., 2., 3., 4., 5.], device='cuda:1')
      

      注:为了方便,后续代码仅定义函数和运行方式,init_process()和import部分省略

      send/recv的特点是在完成通信之前,两个process是锁住的。与之相反,isend/irecv 则不会加锁,代码会继续执行然后返回Work对象,为了让通信顺利进行,我们可以在返回之前加入wait()

      # isend_irecv.py
      def example_isend():
          req = None
          if dist.get_rank() == 0:
              tensor = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).cuda()
              req = dist.isend(tensor, dst=1)
              print("Rank 0 is sending")
              
          elif dist.get_rank() == 1:
              tensor = torch.zeros(5, dtype=torch.float32).cuda()
              print(f"Before irecv on rank {dist.get_rank()}: {tensor}")
              req = dist.irecv(tensor, src=0)
              print("Rank 1 is receiving")
              req.wait()
              print(f"After isend on rank {dist.get_rank()}: {tensor}")
      
      init_process()
      example_isend()
      
      # run with
      # torchrun --nproc_per_node=2 isend_irecv.py
      

      结果输出

      Before irecv on rank 1: tensor([0., 0., 0., 0., 0.], device='cuda:1')
      Rank 0 is sending
      Rank 1 is receiving
      After isend on rank 1: tensor([1., 2., 3., 4., 5.], device='cuda:1')
      

      由于isend/irecv这种不锁的特性,我们不应该

      1. dist.isend()之前修改发送的内容tensor
      2. dist.irecv()之后读取接受的内容tensor

      req.wait() 可以保证这次通信顺利完成,因此我们可以在req.wait()之后再进行修改和读取。

      One-to-many

      One-to-many 情形下,可以分为两种:scatter 和 broadcast

      scatter的作用是将一个process的数据均分并散布到其他process。broadcast的作用是将一个process的数据广播到其他process。两者不同的地方在于其他process获取到的是全量数据(copy)还是部分数据(slice),其示意图如下所示

      scatter and broadcast

      scatter 测试代码:

      # scatter.py
      def example_scatter():
          if dist.get_rank() == 0:
              scatter_list = [
                  torch.tensor([i+1] * 5, dtype=torch.float32).cuda()
                  for i in range(dist.get_world_size())
              ]
              print(f"Rank 0 scatter list: {scatter_list}")
          else:
              scatter_list = None
          
          tensor = torch.zeros(5, dtype=torch.float32).cuda()
          print(f"Before scatter on rank {dist.get_rank()}: {tensor}")
          dist.scatter(tensor, scatter_list, src=0)
          print(f"After scatter on rank {dist.get_rank()}: {tensor}")
      
      init_process()
      example_broadcast()
      
      # run with
      # torchrun --nproc_per_node=4 broadcast.py
      

      结果输出以下内容(输出内容有优化,后续不再说明):

      Rank 0 scatter list: [
          tensor([1., 1., 1., 1., 1.], device='cuda:0'), 
          tensor([2., 2., 2., 2., 2.], device='cuda:0'), 
          tensor([3., 3., 3., 3., 3.], device='cuda:0'), 
          tensor([4., 4., 4., 4., 4.], device='cuda:0')
      ]
      Before scatter on rank 2: tensor([0., 0., 0., 0., 0.], device='cuda:2')
      Before scatter on rank 1: tensor([0., 0., 0., 0., 0.], device='cuda:1')
      Before scatter on rank 3: tensor([0., 0., 0., 0., 0.], device='cuda:3')
      Before scatter on rank 0: tensor([0., 0., 0., 0., 0.], device='cuda:0')
      
      After scatter on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
      After scatter on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2')
      After scatter on rank 3: tensor([4., 4., 4., 4., 4.], device='cuda:3')
      After scatter on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1')
      

      broadcast 测试代码:

      # broadcast.py
      def example_broadcast():
          if dist.get_rank() == 0:
              tensor = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).cuda()
          else:
              tensor = torch.zeros(5, dtype=torch.float32).cuda()
          print(f"Before broadcast on rank {dist.get_rank()}: {tensor}")
          dist.broadcast(tensor, src=0)
          print(f"After broadcast on rank {dist.get_rank()}: {tensor}")
      
      
      init_process()
      example_broadcast()
      
      # run with
      # torchrun --nproc_per_node=3 broadcast.py
      

      结果输出:

      Before broadcast on rank 1: tensor([0., 0., 0., 0., 0.], device='cuda:1')
      Before broadcast on rank 2: tensor([0., 0., 0., 0., 0.], device='cuda:2')
      Before broadcast on rank 0: tensor([1., 2., 3., 4., 5.], device='cuda:0')
      
      After broadcast on rank 0: tensor([1., 2., 3., 4., 5.], device='cuda:0')
      After broadcast on rank 1: tensor([1., 2., 3., 4., 5.], device='cuda:1')
      After broadcast on rank 2: tensor([1., 2., 3., 4., 5.], device='cuda:2')
      

      Many-to-one

      Many-to-one 情形下,也可以分为两种:gather 和 reduce, Gather对应one-to-many的scatter操作,负责将多个process的内容汇聚到一起,形成一个完整的向量。而reduce的操作则是通过一个函数 f()f(\cdot) 来把数据进行汇总,常见的函数有求和以及求平均,示意图如下所示

      gather and reduce

      gather 测试代码:

      # gather.py
      def example_gather():
          tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda()
          if dist.get_rank() == 0:
              gather_list = [
                  torch.zeros(5, dtype=torch.float32).cuda()
                  for _ in range(dist.get_world_size())
              ]
              print(f"Rank 0 gather list: {gather_list}")
          else:
              gather_list = None
          
          print(f"Before gather on rank {dist.get_rank()}: {tensor}")
          dist.gather(tensor, gather_list, dst=0)
          if dist.get_rank() == 0:
              print(f"After gather on rank {dist.get_rank()}: {gather_list}")
      
      init_process()
      example_gather()
      
      # run with
      # torchrun --nproc_per_node=4 gather.py
      

      结果输出:

      Before gather on rank 3: tensor([4., 4., 4., 4., 4.], device='cuda:3')
      Before gather on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2')
      Before gather on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
      Before gather on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1')
      
      After gather on rank 0: [
          tensor([1., 1., 1., 1., 1.], device='cuda:0'), 
          tensor([2., 2., 2., 2., 2.], device='cuda:0'), 
          tensor([3., 3., 3., 3., 3.], device='cuda:0'), 
          tensor([4., 4., 4., 4., 4.], device='cuda:0')
      ]
      

      reduce 测试代码:

      # example_reduce.py
      def example_reduce():
          tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda()
          print(f"Before reduce on rank {dist.get_rank()}: {tensor}")
      
          dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
          if dist.get_rank() == 0:
              print(f"After reduce on rank {dist.get_rank()}: {tensor}")
      
      
      init_process()
      example_reduce()
      
      # run with
      # torchrun --nproc_per_node=3 example_reduce.py
      

      这里我们使用求和dist.ReduceOp.SUM作为我们的汇总操作,Pytorch还支持其他的reduce operations. 结果输出以下内容:

      Before reduce on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2')
      Before reduce on rank 3: tensor([4., 4., 4., 4., 4.], device='cuda:3')
      Before reduce on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
      Before reduce on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1')
      
      After reduce on rank 0: tensor([10., 10., 10., 10., 10.], device='cuda:0')
      

      Many-to-many

      Many-to-many 情形下的两种通信方式为:All-Reduce 和 All-Gather,分别是reduce和gather的升级版,all-reduce对所有process都执行一次reduce操作,而all-gather则对所有process执行一次gather操作,其示意图如下所示

      all-gather and all-reduce

      all-gather 测试代码:

      # example_all_gather.py
      def example_all_gather():
          tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda()
          gather_list = [
              torch.zeros(5, dtype=torch.float32).cuda()
              for _ in range(dist.get_world_size())
          ]
          print(f"Before all gather on rank {dist.get_rank()}: {tensor}")
          dist.all_gather(gather_list, tensor)
          print(f"After all gather on rank {dist.get_rank()}: {gather_list}")
      
      
      init_process()
      example_all_gather()
      
      # run with
      # torchrun --nproc_per_node=3 example_all_gather.py
      

      测试输出结果:

      Before all gather on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2')
      Before all gather on rank 3: tensor([4., 4., 4., 4., 4.], device='cuda:3')
      Before all gather on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
      Before all gather on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1')
      
      After all gather on rank 0: [
          tensor([1., 1., 1., 1., 1.], device='cuda:0'), 
          tensor([2., 2., 2., 2., 2.], device='cuda:0'), 
          tensor([3., 3., 3., 3., 3.], device='cuda:0'), 
          tensor([4., 4., 4., 4., 4.], device='cuda:0')]
      After all gather on rank 2: [
          tensor([1., 1., 1., 1., 1.], device='cuda:2'), 
          tensor([2., 2., 2., 2., 2.], device='cuda:2'), 
          tensor([3., 3., 3., 3., 3.], device='cuda:2'), 
          tensor([4., 4., 4., 4., 4.], device='cuda:2')
      ]
      After all gather on rank 3: [
          tensor([1., 1., 1., 1., 1.], device='cuda:3'), 
          tensor([2., 2., 2., 2., 2.], device='cuda:3'), 
          tensor([3., 3., 3., 3., 3.], device='cuda:3'), 
          tensor([4., 4., 4., 4., 4.], device='cuda:3')
      ]
      After all gather on rank 1: [
          tensor([1., 1., 1., 1., 1.], device='cuda:1'), 
          tensor([2., 2., 2., 2., 2.], device='cuda:1'), 
          tensor([3., 3., 3., 3., 3.], device='cuda:1'), 
          tensor([4., 4., 4., 4., 4.], device='cuda:1')
      ]
      
      

      all-reduce 测试代码:

      # example_all_reduce.py
      def example_all_reduce():
          tensor = torch.tensor([dist.get_rank() + 1] * 5, dtype=torch.float32).cuda()
          print(f"Before all reduce on rank {dist.get_rank()}: {tensor}")
          dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
          print(f"After all reduce on rank {dist.get_rank()}: {tensor}")
      
      init_process()
      example_all_reduce()
      
      # run with
      # torchrun --nproc_per_node=3 example_all_reduce.py
      

      测试输出结果

      Before all reduce on rank 1: tensor([2., 2., 2., 2., 2.], device='cuda:1')
      Before all reduce on rank 0: tensor([1., 1., 1., 1., 1.], device='cuda:0')
      Before all reduce on rank 2: tensor([3., 3., 3., 3., 3.], device='cuda:2')
      Before all reduce on rank 3: tensor([4., 4., 4., 4., 4.], device='cuda:3')
      After all reduce on rank 0: tensor([10., 10., 10., 10., 10.], device='cuda:0')
      After all reduce on rank 2: tensor([10., 10., 10., 10., 10.], device='cuda:2')
      After all reduce on rank 3: tensor([10., 10., 10., 10., 10.], device='cuda:3')
      After all reduce on rank 1: tensor([10., 10., 10., 10., 10.], device='cuda:1')
      

      Barrier

      除了之前这些传输数据的方式之外,我们还有Barrier,用于在所有process之间进行同步。Barrier会确保所有的process在同一时间点完成某些操作。其流程为,先让每个process完成各自的任务,然后当process到达barrier时,process会通知系统自己已到达。最后当所有process都到达barrier之后,阻塞会解除,所有process继续执行下一步操作。

      barrier 测试代码

      # example_barrier.py
      def example_barrier():
          import time
          rank = dist.get_rank()
          t_start = time.time()
          print(f"Rank {rank} sleeps {rank} seconds")
          time.sleep(rank)
          dist.barrier()
          print(f"Rank {rank} is done at {time.time() - t_start:.4f} seconds")
      
      init_process()
      example_barrier()
      
      # run with
      # torchrun --nproc_per_node=3 example_barrier.py
      

      结果输出

      Rank 2 sleeps 2 seconds
      Rank 0 sleeps 0 seconds
      Rank 1 sleeps 1 seconds
      Rank 3 sleeps 3 seconds
      
      Rank 3 is done at 3.3046 seconds
      Rank 1 is done at 3.3229 seconds
      Rank 2 is done at 3.8437 seconds
      Rank 0 is done at 3.6613 seconds
      

      可以看到,四个process的到达时间都在3s左右,这是因为rank 3需要3s才能完成当前任务

      Advanced

      除了前面的通信方式之外,还有 Reduce-Scatter和Ring All-Reduce,这两个通信方式等我们学习ZeRO的时候再一并讲解。

      Reference

      1. Colossal-AI
      2. LLaMA 4 blog
      3. Qwen3 blog
      4. Pytorch tutorial
      5. nanotron/ultrascale-playbook