LLM Memory Computation

In this post, we walk through how to compute the memory an LLM needs during training and inference, and briefly introduce the corresponding optimization techniques.

Published: 2026-01-17 10:04:32+0800

Motivation

Let's start with a simple question:

Suppose I have an 80GB GPU and want to train or serve a 4B model. What batch size and sequence length should I set?

This tutorial builds on that question and considers its more general form:

Motivation: How much memory does an LLM need during training and inference, and how can the memory footprint be optimized?

To answer this, we first derive the memory calculation for the training and inference stages, then analyze the parts that can be optimized and introduce the corresponding algorithms.

Background

Transformer Architecture

Taking Qwen3 [1] as an example, a modern LLM stacks multiple Transformer blocks, with individual modules varying from model to model. The figure below shows the corresponding architecture.

Notation

The notation matches that used for the parameter-count and FLOPs calculations; see LLM parameter analysis for the derivation of the parameter count $P$.

| Symbol | Meaning |
|---|---|
| $P$ | number of parameters |
| $L$ | number of layers |
| $V$ | vocabulary size |
| $d$ | hidden size |
| $d_{\text{ff}}$ | FFN hidden size |
| $s$ | sequence length |
| $b$ | batch size |
| $h$ | number of attention heads |
| $d_h$ | attention head dimension |

Assumptions

  1. Unless stated otherwise, BF16/FP16 is used, i.e. 2 bytes per parameter.
  2. No dropout (consistent with modern LLM practice).
  3. Attention is the original multi-head attention.
  4. The FFN uses SwiGLU.

Training Memory Analysis

Training Memory Components

Training memory consists of four components:

$$\text{training\_memory} = \text{weight} + \text{activation} + \text{optimizer} + \text{gradient}$$

Optimizer States

The AdamW optimizer maintains two moment states per parameter; its update rule [2] is:

$$\begin{aligned} m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\ v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\ \hat{m}_t &\leftarrow \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t \leftarrow \frac{v_t}{1 - \beta_2^t} \\ \theta_t &\leftarrow \theta_{t-1} - \alpha \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_{t-1} \right) \end{aligned}$$
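As a concrete (if simplified) sketch, the update fits in a few lines of PyTorch; note that `m` and `v` are two persistent tensors with the same shape as the weights, which is exactly where the optimizer-state memory comes from:

```python
import torch

def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update, in place. m and v persist across steps:
    two extra tensors per parameter = the optimizer-state memory."""
    m.mul_(beta1).add_(grad, alpha=1 - beta1)             # m_t
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)   # v_t
    m_hat = m / (1 - beta1 ** t)                          # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta.add_(m_hat / (v_hat.sqrt() + eps) + weight_decay * theta, alpha=-lr)
```

With both states kept in BF16 (assumption 1), this contributes the $4P$ optimizer term used below.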

Activation

Activations are the intermediate results produced during the forward pass; they are kept so that gradients can be computed in the backward pass.

We derive the requirement for a linear layer only:

$$\begin{aligned} \text{forward:} \quad & \mathbf{z}_\ell = W_\ell \mathbf{a}_{\ell-1} + b_\ell, \quad \mathbf{a}_\ell = \phi(\mathbf{z}_\ell) \\ \text{backward:} \quad & \frac{\partial \mathcal{L}}{\partial W_\ell} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}_\ell} \cdot \frac{\partial \mathbf{z}_\ell}{\partial W_\ell} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}_\ell} \cdot \boxed{\mathbf{a}_{\ell-1}} \end{aligned}$$

Computing the gradient of layer $\ell$ with respect to $W_\ell$ requires that layer's input $\mathbf{a}_{\ell-1}$, so training must keep the input of every module: these are the activations.
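A tiny PyTorch check (hypothetical shapes) makes the boxed term concrete: autograd stores `a_prev` during the forward pass precisely because the weight gradient is a matrix product with it:

```python
import torch

tokens, d_in, d_out = 5, 4, 3
W = torch.randn(d_out, d_in, requires_grad=True)
a_prev = torch.randn(tokens, d_in)   # this layer's input = the activation
z = a_prev @ W.T                     # forward; autograd saves a_prev for W's grad
z.sum().backward()

# dL/dW = (dL/dz)^T @ a_prev, with dL/dz = all-ones here
manual = torch.ones(tokens, d_out).T @ a_prev
assert torch.allclose(W.grad, manual)
```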

Activation — Attention

Following the computation graph (with no optimization), the activations that must be stored are the shared QKV input ($2bsd$), the Q/K/V projection outputs ($3 \times 2bsd$), the input to the output projection ($2bsd$), plus the pre-softmax scores and the softmax output ($2bhs^2$ each).

Attention total: $\boxed{10bsd + 4bhs^2}$

Activation — FFN & LayerNorm

FFN (SwiGLU, assuming $d_{\text{ff}} = 4d$): the block input ($2bsd$) plus the two $d_{\text{ff}}$-wide intermediate projections ($2bs \cdot 4d = 8bsd$ each) give $\boxed{18bsd}$.

LayerNorm: store the input → $\boxed{2bsd}$

Activation — Output

The output part stores the final-norm input ($2bsd$), the LM-head input ($2bsd$), and the logits ($2bsV$).

Total: $\boxed{4bsd + 2bsV}$

Activation — Total

Summing the results above:

$$\begin{aligned} \text{activation} &= L \cdot \text{transformer\_block} + \text{output} \\ &= L \cdot (\text{Pre\_Norm} + \textcolor{red}{\text{Attention}} + \text{Post\_Norm} + \text{FFN}) + \text{output} \\ &= \boxed{bs(32dL + \textcolor{red}{4hsL} + 4d + 2V)} \\ &\approx bsL(32d + \textcolor{red}{4hs}) \\ &\approx \textcolor{red}{4bs^2hL} \end{aligned}$$

Note: for Qwen3, $2V / 32dL \approx 6\%$ and $32dL / 4hsL \approx 2.5\%$.

So without optimization, $\text{activation} \propto bs^2$. The $s^2$ term comes mainly from attention; Flash Attention later targets exactly this.

Total Training Memory

Putting the pieces together:

$$\begin{aligned} \text{training\_memory} &= \text{weight} + \text{activation} + \text{optimizer} + \text{gradient} \\ &= 2P + bs(32dL + 4hsL + 4d + 2V) + 4P + 2P \\ &= 8P + bs(32dL + 4hsL + 4d + 2V) \quad \text{(exact)} \\ &\approx 8P + 4bs^2hL \end{aligned}$$

Training memory thus splits into a static part ($8P$) and a dynamic part ($4bs^2hL$), where the dynamic part is dominated by attention activations.
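The exact formula is easy to turn into a predictor. A minimal sketch (the Qwen3-4B configuration values are those used in the table below, 80 GB is treated as $8 \times 10^{10}$ bytes, and no optimizations or framework overheads are modeled):

```python
def training_memory_bytes(P, L, V, d, h, s, b):
    """Static 8P (weights 2P + gradients 2P + AdamW states 4P, all BF16)
    plus the unoptimized activation term derived above."""
    return 8 * P + b * s * (32 * d * L + 4 * h * s * L + 4 * d + 2 * V)

# Qwen3-4B-style config: P=4e9, L=36, V=151936, d=2560, h=32, s=512
budget = 80e9
b = 0
while training_memory_bytes(4e9, 36, 151_936, 2560, 32, 512, b + 1) <= budget:
    b += 1
print(b)  # -> 16, the "predicted b" in the table below
```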

Experiments

For an 80GB GPU, we compute the largest feasible configuration for each Qwen3 model:

| Model | $P$ (B) | $L$ | $h$ | $s$ | predicted $b$ | actual $b$ |
|---|---|---|---|---|---|---|
| Qwen3-0.6B | 0.6 | 28 | 16 | 512 | 68 | 34 |
| Qwen3-1.7B | 1.7 | 28 | 16 | 512 | 42 | 28 |
| Qwen3-4B | 4 | 36 | 32 | 512 | 16 | 12 |
| Qwen3-8B | 8.1 | 36 | 32 | 512 | 4 | 2 |

Here predicted $b$ comes from the exact formula above, and actual $b$ is measured experimentally. The prediction ignores all optimizations and miscellaneous overheads, hence the gap from the measured values.

Case Study

We profile Qwen3-4B and Qwen3-8B ($b=1$, $s=512$); see PyTorch memory visualization and snapshot analysis for the methodology.
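The snapshots come from PyTorch's allocator history; a minimal recording recipe follows (these `_`-prefixed functions are the ones PyTorch currently documents for this workflow, but they are private APIs and may change):

```python
import torch

torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a few training steps of the model under study here ...

torch.cuda.memory._dump_snapshot("snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording

# Open https://pytorch.org/memory_viz and drop snapshot.pickle in to see
# weights, gradients, optimizer states, and activations over time.
```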

Inference Memory Analysis

Inference Components

Inference memory is mainly determined by the model weights and the KV cache:

$$\text{Inference\_Memory} = \text{weight} + \text{activation} + \text{KV cache}$$

KV Cache Mechanism

The KV cache is a space-for-time mechanism used in LLM inference to avoid recomputing the keys and values of past tokens.

Generation is autoregressive, one token at a time; each step's attention has the form below, where $\mathbf{q}_t$ is the current query and $\mathbf{k}_{:,t}$ / $\mathbf{v}_{:,t}$ are the historical keys/values:

$$\mathbf{q}_t = W_Q \mathbf{x}_t, \quad \mathbf{k}_{:,t} = W_K[\mathbf{x}_1, \ldots, \mathbf{x}_t], \quad \mathbf{v}_{:,t} = W_V[\mathbf{x}_1, \ldots, \mathbf{x}_t]$$

When processing the next token $\mathbf{x}_{t+1}$, we only append the current step to the cached results:

$$\mathbf{k}_{:,t+1} = [\mathbf{k}_{:,t},\, W_K \mathbf{x}_{t+1}], \quad \mathbf{v}_{:,t+1} = [\mathbf{v}_{:,t},\, W_V \mathbf{x}_{t+1}]$$
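A minimal single-head, batch-one sketch of this mechanism (shapes simplified; real implementations cache per layer and per head):

```python
import torch

d = 64
W_Q, W_K, W_V = (torch.randn(d, d) for _ in range(3))
k_cache = torch.empty(0, d)   # grows by one row per generated token
v_cache = torch.empty(0, d)

def decode_step(x_t):
    """x_t: (d,) embedding of the current token."""
    global k_cache, v_cache
    # Append this step's K/V instead of recomputing the whole history.
    k_cache = torch.cat([k_cache, (W_K @ x_t).unsqueeze(0)])
    v_cache = torch.cat([v_cache, (W_V @ x_t).unsqueeze(0)])
    q_t = W_Q @ x_t
    attn = torch.softmax(k_cache @ q_t / d ** 0.5, dim=0)  # (t,)
    return attn @ v_cache                                  # (d,)

for _ in range(4):
    out = decode_step(torch.randn(d))
```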

KV Cache Memory

For multi-head attention, the KV cache occupies:

$$\text{Memory}(\text{KV cache}) = s \times 2 \times 2 \times L \times h \times d_h = \boxed{4sLhd_h}$$

Factor by factor: $s$ is the sequence length, the first $2$ accounts for K and V, the second $2$ is the 2 bytes of BF16, and $L$, $h$, $d_h$ are the number of layers, heads, and the head dimension.

Remark: for models using GQA, such as Qwen3, $h$ is replaced by the number of KV heads $h_{kv}$, so the cache shrinks to $4sLh_{kv}d_h$.

Total Inference Memory

Combining the analysis above, total inference memory is:

$$\boxed{\text{Inference\_Memory} \approx 2.4P + 4sLhd_h}$$

Inference memory likewise has a static part (weights plus activations, $\approx 2.4P$) and a dynamic part (the KV cache).

Dynamic vs. Static

Since Qwen3's KV cache is $4sLh_{kv}d_h$ and only $L$ differs across model sizes ($h_{kv}$ and $d_h$ are fixed), larger models need a longer context before the KV cache overtakes the weight memory.
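A quick back-of-the-envelope check of the crossover length $s^*$ where $4 s^* L h_{kv} d_h = 2P$ (the Qwen3 configurations below are assumptions taken from public model cards, with $h_{kv} = 8$ and $d_h = 128$ for all sizes):

```python
# s* where the BF16 KV cache equals the BF16 weights: 4*s*L*h_kv*d_h == 2*P
configs = {"Qwen3-4B": (4.0e9, 36), "Qwen3-8B": (8.1e9, 36),
           "Qwen3-32B": (32.8e9, 64)}   # name: (params, layers)
h_kv, d_h = 8, 128
for name, (P, L) in configs.items():
    s_star = 2 * P / (4 * L * h_kv * d_h)
    print(f"{name}: KV cache overtakes weights at ~{s_star:,.0f} tokens")
# Larger models cross over later, matching the observation above.
```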

Optimization

Overview

| Stage | Focus | Representative techniques |
|---|---|---|
| Training | memory footprint and efficiency | Activation Checkpointing, Mixed Precision Training, Flash Attention, ZeRO, Pipeline/Model/Data Parallelism |
| Inference | long sequences and speed | KV Cache Optimization, Paged/Radix Attention, Faster Attention, Quantization |

Mixed Precision Training

Use low precision for the compute-heavy parts and high precision for the compute-light parts: low precision does the bulk of the arithmetic, while high precision avoids overflow/underflow.

The table below gives the memory analysis for mixed-precision training setups of the kind used by DeepSeek-V3 [3]:

| | BF16 (no AMP) | FP32 (w/ AMP) | BF16 (w/ AMP) |
|---|---|---|---|
| Weights | BF16 (2) | FP32 (4) | BF16 (2) |
| Master weights | – | – | FP32 (4) |
| Gradients | BF16 (2) | FP32 (4) | BF16 (2) |
| Adam $m$ | BF16 (2) | FP32 (4) | FP32 (4) |
| Adam $v$ | BF16 (2) | FP32 (4) | FP32 (4) |
| Static total (bytes/param) | 8 | 16 | 16 |

Remark:

  1. BF16 (w/ AMP) and FP32 (w/ AMP) have the same static footprint, but BF16 (w/ AMP) uses less dynamic memory.
  2. Mainstream frameworks essentially all train with BF16/FP8 (w/ AMP).
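In PyTorch, AMP is essentially one context manager; a minimal sketch (BF16 autocast over FP32 parameters and FP32 Adam states, i.e. without a separate BF16 weight copy):

```python
import torch

model = torch.nn.Linear(1024, 1024, device="cuda")    # toy FP32 model
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # FP32 m and v
x = torch.randn(8, 1024, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).square().mean()   # matmuls run in BF16
loss.backward()                       # grads land in FP32 on FP32 params
opt.step()
opt.zero_grad()
```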

ZeRO

Shard the optimizer states / gradients / weights across GPUs and all-gather them into full tensors only when a computation needs them. Each GPU then holds only its own shard, greatly reducing per-GPU memory [4].

ZeRO Stages:

ZeRO-1: $\text{Training\_Memory} = \text{weight} + \text{activation} + \frac{\text{optimizer}}{\#\text{GPUs}} + \text{gradient}$

ZeRO-2: $\text{Training\_Memory} = \text{weight} + \text{activation} + \frac{\text{optimizer} + \text{gradient}}{\#\text{GPUs}}$

ZeRO-3: $\text{Training\_Memory} = \text{activation} + \frac{\text{weight} + \text{optimizer} + \text{gradient}}{\#\text{GPUs}}$

ZeRO-3 dramatically lowers the per-GPU memory ceiling, at the cost of extra communication.
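The three stages as a static-memory sketch (bytes per GPU, BF16 bytes-per-parameter as above; activations are untouched by ZeRO and omitted here):

```python
def zero_static_bytes_per_gpu(P, n_gpus, stage):
    """Per-GPU static training memory under ZeRO: weights 2P,
    gradients 2P, AdamW states 4P, sharded according to the stage."""
    weight, grad, opt = 2 * P, 2 * P, 4 * P
    if stage == 1:
        return weight + grad + opt / n_gpus
    if stage == 2:
        return weight + (grad + opt) / n_gpus
    if stage == 3:
        return (weight + grad + opt) / n_gpus
    raise ValueError(stage)

for stage in (1, 2, 3):   # e.g. Qwen3-8B (P=8.1e9) on 8 GPUs
    gb = zero_static_bytes_per_gpu(8.1e9, 8, stage) / 1e9
    print(f"ZeRO-{stage}: {gb:.1f} GB/GPU")  # roughly 36, 22, and 8 GB
```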

Model Parallelism

Split the model across GPUs: dispatch the inputs, compute locally, then assemble the final result via all-gather or similar collectives [5]. Splitting schemes include PP (Pipeline Parallelism), TP (Tensor Parallelism), and EP (Expert Parallelism).

$$\text{Training\_Memory} = \frac{\text{Memory}(\text{weight})}{\text{PP degree} \times \text{TP degree}}$$

When combining ZeRO-1 with model parallelism (the TP-related part of the activations shrinks with the TP degree):

$$\text{Memory}_{\text{train}} \approx \frac{\text{weight}}{\text{PP} \times \text{TP}} + \frac{\text{activation}}{\text{TP}} + \frac{\text{optimizer}}{\#\text{GPUs}} + \frac{\text{gradient}}{\text{PP}}$$

Activation Checkpointing

Recompute the needed inputs during the backward pass instead of storing them, trading time for memory [6].

| | No ckpt | Selective ckpt | Full ckpt |
|---|---|---|---|
| Memory | very high | medium | very low ($\sim 2bsd$) |
| Extra compute | – | medium | very high ($\sim 2Pbs$) |

In practice, model parallelism is typically combined with selective checkpointing to strike the trade-off.
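In PyTorch, full checkpointing of a block is a one-line change: only the wrapped function's inputs are kept, and its internals are recomputed during the backward pass (a sketch with a stand-in block):

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(          # stand-in for one transformer block
    torch.nn.Linear(512, 2048), torch.nn.GELU(), torch.nn.Linear(2048, 512))
x = torch.randn(4, 512, requires_grad=True)

y = checkpoint(block, x, use_reentrant=False)  # saves x, drops intermediates
y.sum().backward()                             # block is re-run here
```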

Flash Attention

Tile the attention computation to improve memory-access efficiency and shrink the activations required by the backward pass [7].

Flash Attention uses tiling and an online softmax to reduce this memory while improving speed (see notes on Flash Attention), bringing the attention activations down from $\text{activation} \propto bs^2$ to $\text{activation} \propto bs$.

Theorem: Flash Attention computes $O = \text{softmax}(QK^T)V$ exactly (correctness), with time complexity $\mathcal{O}(s^2 d)$ and $\mathcal{O}(s)$ extra memory beyond inputs and outputs (memory savings).
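In PyTorch this is exposed through `scaled_dot_product_attention`, which can dispatch to a FlashAttention kernel so the $s \times s$ score matrix is never materialized; a sketch forcing that backend (assumes PyTorch ≥ 2.3 and a supported GPU):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

b, h, s, d_h = 1, 32, 4096, 128
q, k, v = (torch.randn(b, h, s, d_h, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Naive attention would store a (b, h, s, s) score tensor and its softmax:
# 2 * 32 * 4096 * 4096 * 2 bytes ~ 2.1 GB of activations for this one layer.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```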

KV Cache Optimization

$$\text{Memory}(\text{KV cache}) = s \times 2 \times 2 \times L \times h \times d_h$$

Optimization directions target each factor of the formula [8]; a sketch of the $h \times d_h$ direction follows the list:

  1. $s$: KV cache compression, eviction, selection
  2. $2$ (bytes): KV cache quantization
  3. $2$ (K+V): key-value sharing, MLA [9]
  4. $h \times d_h$: MQA [10], GQA [11], MLA
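For instance, along the $h \times d_h$ axis, GQA shrinks the cache by $h / h_{kv}$ while the rest of the formula is unchanged (Qwen3-4B-style numbers assumed):

```python
def kv_cache_bytes(s, L, n_kv_heads, d_h, bytes_per_elem=2):
    # s * (K and V) * bytes * layers * KV heads * head dim
    return s * 2 * bytes_per_elem * L * n_kv_heads * d_h

s, L, d_h = 32_768, 36, 128
mha = kv_cache_bytes(s, L, 32, d_h)  # full multi-head: h_kv = h = 32
gqa = kv_cache_bytes(s, L, 8, d_h)   # GQA with 8 KV heads: 4x smaller
print(f"MHA: {mha / 1e9:.1f} GB, GQA: {gqa / 1e9:.1f} GB")  # 19.3 vs 4.8
```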

Weight Quantization

Represent high-precision values with low-precision formats to reduce memory and speed up computation.

| Quantization stage | Representative work |
|---|---|
| Post-training quantization (PTQ) | GPTQ [12], AWQ [13], SmoothQuant [14], GGUF [15] |
| Quantization-aware training (QAT) | LLM-QAT [16], PEQA [17] |

Activation Offloading

Store part of the parameters / optimizer states / activations on the CPU and load them onto the GPU only when needed.

| Offloading scenario | Representative work |
|---|---|
| Training-stage offloading | ZeRO-Offload [18] / ZeRO-Infinity, FSDP [19] CPU Offload |
| Inference-stage offloading | FlexGen [20], vLLM [21] KV Cache Offload |
| MoE offloading | KTransformers [22], DeepSpeed-MoE [23] |

Real Systems

Training Frameworks

| Framework | Memory optimizations |
|---|---|
| Megatron-LM [5] | TP, SP, Activation Checkpointing |
| DeepSpeed [24] | ZeRO-1/2/3, CPU/NVMe Offload, Activation Checkpointing |
| FSDP [19] | Full parameter sharding, Gradient sharding, CPU Offload |
| Colossal-AI [25] | ZeRO, TP, PP, Activation Checkpointing |

Inference Frameworks

| Framework | Key techniques | Memory optimizations |
|---|---|---|
| vLLM [21] | Paged Attention | KV cache paging, Continuous batching |
| SGLang [26] | Radix Attention | KV cache reuse, Efficient scheduling |
| TensorRT-LLM [27] | Kernel fusion | Weight quantization, KV cache optimization |

Conclusion

Takeaway

| Component | Training | Inference | Optimization |
|---|---|---|---|
| weights | $2P$ | $2P$ | quantization |
| optimizer states | $4P$ | 0 | ZeRO, offloading |
| gradients | $2P$ | 0 | ZeRO |
| activations | $\sim 4Lhs^2b$ | $\sim 0.4P$ | ckpt, offloading, flash attention |
| KV cache | 0 | $4sLhd_h$ | KV optimization, attention |
| TOTAL | $8P + 4Lhs^2b$ | $2.4P + 4sLhd_h$ | |

Future Directions

  1. More efficient architecture (attention, MoE).
  2. Scalable training/inference framework.
  3. Software-hardware co-design algorithms.

References

    1. An Yang et al., “Qwen3 Technical Report,” arXiv:2505.09388, 2025.
    2. Ilya Loshchilov and Frank Hutter, “Decoupled Weight Decay Regularization,” arXiv:1711.05101, 2019.
    3. DeepSeek-AI, “DeepSeek-V3 Technical Report,” arXiv:2412.19437, 2025.
    4. Samyam Rajbhandari et al., “ZeRO: Memory Optimizations Toward Training Trillion Parameter Models,” arXiv:1910.02054, 2020.
    5. Mohammad Shoeybi et al., “Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism,” arXiv:1909.08053, 2020.
    6. Vijay Korthikanti et al., “Reducing Activation Recomputation in Large Transformer Models,” arXiv:2205.05198, 2022.
    7. Tri Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,” arXiv:2205.14135, 2022.
    8. Haoyang Li et al., “A Survey on Large Language Model Acceleration based on KV Cache Management,” arXiv:2412.19442, 2025.
    9. DeepSeek-AI, “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model,” arXiv:2405.04434, 2024.
    10. Noam Shazeer, “Fast Transformer Decoding: One Write-Head is All You Need,” arXiv:1911.02150, 2019.
    11. Joshua Ainslie et al., “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,” arXiv:2305.13245, 2023.
    12. Elias Frantar et al., “GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers,” arXiv:2210.17323, 2023.
    13. Ji Lin et al., “AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration,” arXiv:2306.00978, 2024.
    14. Guangxuan Xiao et al., “SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models,” arXiv:2211.10438, 2024.
    15. Georgi Gerganov, “ggml: Tensor library for machine learning,” GitHub, 2023.
    16. Zechun Liu et al., “LLM-QAT: Data-Free Quantization Aware Training for Large Language Models,” arXiv:2305.17888, 2023.
    17. Jeonghoon Kim et al., “Memory-Efficient Fine-Tuning of Compressed Large Language Models via sub-4-bit Integer Quantization,” arXiv:2305.14152, 2023.
    18. Jie Ren et al., “ZeRO-Offload: Democratizing Billion-Scale Model Training,” arXiv:2101.06840, 2021.
    19. Yanli Zhao et al., “PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel,” arXiv:2304.11277, 2023.
    20. Ying Sheng et al., “FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU,” arXiv:2303.06865, 2023.
    21. Woosuk Kwon et al., “Efficient Memory Management for Large Language Model Serving with PagedAttention,” SOSP, 2023.
    22. Hongtao Chen et al., “KTransformers: Unleashing the Full Potential of CPU/GPU Hybrid Inference for MoE Models,” SOSP, 2025.
    23. Samyam Rajbhandari et al., “DeepSpeed-MoE: Advancing Mixture-of-Experts Inference and Training to Power Next-Generation AI Scale,” arXiv:2201.05596, 2022.
    24. Jeff Rasley et al., “DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters,” KDD, 2020.
    25. Shenggui Li et al., “Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training,” ICPP, 2023.
    26. Lianmin Zheng et al., “SGLang: Efficient Execution of Structured Language Model Programs,” NeurIPS, 2024.
    27. NVIDIA Corporation, “TensorRT-LLM: A TensorRT Toolset for Optimizing LLM Inference,” GitHub, 2023.