大规模 Transformer:参数规模、并行与内存优化
这一篇从“训得起、跑得动”的角度看大模型:参数规模和显存的关系,常见的并行手段,以及几种典型的内存优化思路。
参数规模与显存:大概是个什么量级?
粗略地看,一个 Transformer 模型的参数量主要来自两块:
- 词嵌入(Embedding)
- 各层里的线性层(注意力投影 + FFN)
忽略常数项,可以用一个非常粗糙的估算公式:
$$ \text{Params} \approx 12 \times L \times d_{\text{model}}^2 $$其中:
- $L$ 是层数;
- $d_{\text{model}}$ 是隐藏维度。
显存占用又和两部分有关:
- 参数本身:参数量 × 每个参数的字节数(FP16/FP32 等);
- 激活值(Activation):前向过程中每一层的中间结果,为了反向传播需要暂存。
所以一个 10B 级别的模型,即使用 FP16,参数本身也要二十多 GB 显存,再加上激活值、优化器状态,如果不用并行和优化,是根本放不下的。
并行方式:Data / Model / Pipeline
为了撑起更大的模型和 batch,一般会同时用几种并行方式。
Data Parallel:最直观,也最常用
做法是:
- 把 batch 拆成多份,分到多张卡上;
- 每张卡存一份完整的模型副本,算各自的前向/反向;
- 反向结束后通过 All-Reduce 把梯度同步,再一起更新参数。
好处:
- 容易实现,几乎所有框架都有一键封装;
- 不改变模型本身结构。
代价:
- 模型必须能放进单卡显存;
- 当模型本身太大时就不够用了。
Model Parallel:把模型拆开放到不同卡上
当单卡放不下整套模型时,只能把模型的参数横着切开:
- Tensor Parallel:按列/行把大矩阵拆成几块,分别丢到不同设备上;
- Layer Parallel / Pipeline Parallel:把不同层分布在不同设备上,数据像“流水线”一样依次经过。
Tensor Parallel 适合解决“单个线性层太宽”的问题;
Pipeline Parallel 则更像是“前几层在卡 0,后几层在卡 1”,要处理好流水线上的 bubble(空转)。
混合并行:实际系统里的组合拳
在真实的大模型训练里,很少只用一种并行方式。常见组合是:
- Data Parallel + Tensor Parallel;
- Data Parallel + Tensor Parallel + Pipeline Parallel。
这样既能扩参数规模,又能扩 batch,还能把显存压力摊到多张卡上。
内存优化:激活检查点、ZeRO 等
除了并行,还可以直接从“占多少内存”下手做优化。
激活检查点(Activation Checkpointing)
原始做法是:前向时把每一层的中间结果都存下来,反向时直接用。
激活检查点的思路是:只保存少数关键节点,其它在反向时重新算一遍。
这样做:
- 显存占用大幅下降(因为少存了很多中间张量);
- 计算量略有增加(反向时要多算几次前向)。
在大模型场景里,多算一点前向通常比多占显存要划算得多。
ZeRO:分解优化器状态和梯度
像 Adam 这样的优化器,会为每个参数存两份额外状态(动量、一阶二阶矩),内存几乎翻几倍。
ZeRO 系列方法的核心是:把参数、梯度和优化器状态在多卡之间分片存储,只在需要的时候做聚合。
大致可以理解为:
- 每张卡只保存自己负责的一部分参数/状态;
- 需要做前向/反向/更新时,通过通信把需要的碎片聚在一起;
- 做完再丢回各自的卡。
这对“大模型 + Data Parallel”来说非常关键。
精度和算子融合:再往下挖一点性能
在硬件支持的前提下,还可以从“更小精度”和“更少内存访问”上继续挖。
更小精度:
- 从 FP32 → FP16 → BF16 → INT8 / FP8;
- 训练一般用 FP16/BF16,推理可以考虑 INT8/FP8 量化来进一步省显存和带宽。
算子融合(Kernel Fusion):
- 把一些本来分散的操作(例如线性层 + Bias + 激活)合并成一个自定义 CUDA kernel;
- 目标是减少中间张量的读写次数,把“算力瓶颈”变成“带宽瓶颈”之前榨干设备能力。
很多领先的框架(如各家加速库、推理引擎)都会针对 Transformer 做这类深度优化,比如 fused MLP、fused attention 等。
推理场景:批量、并发与延迟的权衡
在线上推理时,还要考虑:
批量(Batching):
- 多个用户请求合并成一个大 batch,一次性跑完,再拆分结果;
- 提升吞吐,但会带来额外延迟。
并发与 KV Cache 复用:
- 对于多轮对话/流式输出,可以复用上文的 KV Cache,避免重复算;
- 多请求之间的 batch 拼接要兼顾“上下文长度差异”和“padding 成本”。
一个常见做法是:
- 对请求做简单调度,把“长度相近”的请求分在同一批;
- 同时控制“最大 batch 大小”和“最大排队等待时间”,在吞吐和延迟之间找平衡。
总结:规模化背后的一些关键点
- 参数规模和显存开销大致随 $L \times d_{\text{model}}^2$ 增长,不做并行和优化会很快撑爆显存;
- Data / Model / Pipeline 三种并行方式组合使用,是把模型做大的基础;
- 激活检查点、ZeRO 等技术,从“存什么”和“存在哪里”两个方向节省了大量显存;
- 更小精度和算子融合则持续压缩每次前向/反向的算力和内存访存成本;
- 在线推理端,还需要在批量、并发、KV Cache 复用之间做工程上的权衡。