大规模 Transformer:参数规模、并行与内存优化

这一篇从“训得起、跑得动”的角度看大模型:参数规模和显存的关系,常见的并行手段,以及几种典型的内存优化思路。

粗略地看,一个 Transformer 模型的参数量主要来自两块:

  • 词嵌入(Embedding)
  • 各层里的线性层(注意力投影 + FFN)

忽略常数项,可以用一个非常粗糙的估算公式:

$$ \text{Params} \approx 12 \times L \times d_{\text{model}}^2 $$

其中:

  • $L$ 是层数;
  • $d_{\text{model}}$ 是隐藏维度。

显存占用又和两部分有关:

  • 参数本身:参数量 × 每个参数的字节数(FP16/FP32 等);
  • 激活值(Activation):前向过程中每一层的中间结果,为了反向传播需要暂存。

所以一个 10B 级别的模型,即使用 FP16,参数本身也要二十多 GB 显存,再加上激活值、优化器状态,如果不用并行和优化,是根本放不下的。

为了撑起更大的模型和 batch,一般会同时用几种并行方式。

做法是:

  • 把 batch 拆成多份,分到多张卡上;
  • 每张卡存一份完整的模型副本,算各自的前向/反向;
  • 反向结束后通过 All-Reduce 把梯度同步,再一起更新参数。

好处:

  • 容易实现,几乎所有框架都有一键封装;
  • 不改变模型本身结构。

代价:

  • 模型必须能放进单卡显存;
  • 当模型本身太大时就不够用了。

当单卡放不下整套模型时,只能把模型的参数横着切开:

  • Tensor Parallel:按列/行把大矩阵拆成几块,分别丢到不同设备上;
  • Layer Parallel / Pipeline Parallel:把不同层分布在不同设备上,数据像“流水线”一样依次经过。

Tensor Parallel 适合解决“单个线性层太宽”的问题;
Pipeline Parallel 则更像是“前几层在卡 0,后几层在卡 1”,要处理好流水线上的 bubble(空转)。

在真实的大模型训练里,很少只用一种并行方式。常见组合是:

  • Data Parallel + Tensor Parallel;
  • Data Parallel + Tensor Parallel + Pipeline Parallel。

这样既能扩参数规模,又能扩 batch,还能把显存压力摊到多张卡上。

除了并行,还可以直接从“占多少内存”下手做优化。

原始做法是:前向时把每一层的中间结果都存下来,反向时直接用。
激活检查点的思路是:只保存少数关键节点,其它在反向时重新算一遍

这样做:

  • 显存占用大幅下降(因为少存了很多中间张量);
  • 计算量略有增加(反向时要多算几次前向)。

在大模型场景里,多算一点前向通常比多占显存要划算得多。

像 Adam 这样的优化器,会为每个参数存两份额外状态(动量、一阶二阶矩),内存几乎翻几倍。
ZeRO 系列方法的核心是:把参数、梯度和优化器状态在多卡之间分片存储,只在需要的时候做聚合。

大致可以理解为:

  • 每张卡只保存自己负责的一部分参数/状态;
  • 需要做前向/反向/更新时,通过通信把需要的碎片聚在一起;
  • 做完再丢回各自的卡。

这对“大模型 + Data Parallel”来说非常关键。

在硬件支持的前提下,还可以从“更小精度”和“更少内存访问”上继续挖。

  • 更小精度

    • 从 FP32 → FP16 → BF16 → INT8 / FP8;
    • 训练一般用 FP16/BF16,推理可以考虑 INT8/FP8 量化来进一步省显存和带宽。
  • 算子融合(Kernel Fusion)

    • 把一些本来分散的操作(例如线性层 + Bias + 激活)合并成一个自定义 CUDA kernel;
    • 目标是减少中间张量的读写次数,把“算力瓶颈”变成“带宽瓶颈”之前榨干设备能力。

很多领先的框架(如各家加速库、推理引擎)都会针对 Transformer 做这类深度优化,比如 fused MLP、fused attention 等。

在线上推理时,还要考虑:

  • 批量(Batching)

    • 多个用户请求合并成一个大 batch,一次性跑完,再拆分结果;
    • 提升吞吐,但会带来额外延迟。
  • 并发与 KV Cache 复用

    • 对于多轮对话/流式输出,可以复用上文的 KV Cache,避免重复算;
    • 多请求之间的 batch 拼接要兼顾“上下文长度差异”和“padding 成本”。

一个常见做法是:

  • 对请求做简单调度,把“长度相近”的请求分在同一批;
  • 同时控制“最大 batch 大小”和“最大排队等待时间”,在吞吐和延迟之间找平衡。
  • 参数规模和显存开销大致随 $L \times d_{\text{model}}^2$ 增长,不做并行和优化会很快撑爆显存;
  • Data / Model / Pipeline 三种并行方式组合使用,是把模型做大的基础;
  • 激活检查点、ZeRO 等技术,从“存什么”和“存在哪里”两个方向节省了大量显存;
  • 更小精度和算子融合则持续压缩每次前向/反向的算力和内存访存成本;
  • 在线推理端,还需要在批量、并发、KV Cache 复用之间做工程上的权衡。