Transformer 训练与推理细节:损失函数、Mask 与并行
这一篇把训练和推理阶段里容易被略过但又非常关键的细节拆开:损失怎么写、Mask 真正起什么作用,以及为什么 Transformer 特别适合大规模并行。
训练目标:本质就是“预测下一个 token / 填空”
大多数 Transformer 模型的训练目标可以归到两个范式里:
- 自回归(Autoregressive):预测下一个 token(GPT 一类);
- 掩码语言模型(Masked LM):预测被挖掉的 token(BERT 一类)。
自回归最典型的形式是:
$$ \mathcal{L} = -\sum_{t=1}^{T} \log p(x_t \mid x_1, \dots, x_{t-1}) $$实现上,就是把一句话整体右移一位,当成“标签”,对每个位置做一次分类交叉熵:
1# logits: (B, T, V) 模型输出
2# target: (B, T) 下一个 token
3loss = cross_entropy(logits.view(-1, V), target.view(-1), ignore_index=pad_id)
Masked LM 则是只在被标记为 [MASK] 的位置计算损失,其它位置忽略:
1mask = (input_ids == mask_id) # 只在被 mask 的位置算 loss
2loss = cross_entropy(
3 logits[mask], # 这些位置的预测
4 target_ids[mask] # 对应的真实 token
5)
从这两种形式可以看到,本质都是把“理解/生成”的问题统一成多分类问题,只是“预测的目标位置集合”不同。
Mask 的两种核心用途:Padding 与 Causal
训练时常见的 Mask,主要有两种:
- Padding Mask:忽略序列里为了对齐而补的
PAD; - Causal Mask:保证当前 token 看不到未来(用于自回归)。
Padding Mask:不要让补零影响注意力
当 batch 里句子长短不一时,会用 PAD 把它们补到同一长度。如果不做处理,注意力会把这些 padding 位也当成“正常 token”去看。
解决办法是在算注意力分数时对这些位置打上遮罩:
1scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k) # (B, h, L_q, L_k)
2
3if padding_mask is not None:
4 # padding_mask: (B, 1, 1, L_k),padding 位置为 0
5 scores = scores.masked_fill(padding_mask == 0, -1e9)
6
7weights = softmax(scores, dim=-1)
softmax 之后,被设为 $-10^9$ 的位置对应权重几乎为 0,相当于“看不见”。
Causal Mask:不能偷看未来
对于 GPT 这类自回归模型,第 $i$ 个位置只能看到 $0 \dots i$ 的 token,需要用一个上三角矩阵 mask 掉未来的信息:
1L = seq_len
2causal_mask = torch.tril(torch.ones(L, L)) # 下三角为 1,上三角为 0
3
4scores = scores.masked_fill(causal_mask == 0, -1e9)
5weights = softmax(scores, dim=-1)
有了 Causal Mask,同一个 batch 里所有位置仍然可以并行计算,但语义上等价于“逐个位置往右生成”。
训练时的并行性:为什么 Transformer 跑得快?
相比 RNN,Transformer 的一个巨大优势是 可以在时间维度上并行:
- 在 RNN 里,每一步都依赖上一时刻的隐藏状态,只能串行;
- 在 Transformer 里,同一层里所有位置的计算只依赖这一层的输入,可以同时算。
训练时通常会有多层并行叠加:
- batch 维度并行:一个 batch 里放很多序列;
- 序列维度并行:同一句话里的所有位置在同一层里可以一起算;
- 头维度并行:多头注意力会把 head 维度也一起展开放到大矩阵乘法里。
这让 Transformer 非常适合 GPU/TPU 上的大矩阵运算,用几次巨大的 GEMM(通用矩阵乘法)把大部分算力吃满。
推理阶段:KV Cache 带来的巨大加速
训练时,为了利用并行性,一般一次性把一整句喂进去;推理时如果也这么做,每生成一个新 token 都要重新算整句的 Attention,成本太高。
解决办法是 缓存之前的 K/V(KV Cache):
- 第一次生成时,按正常流程算出当前所有位置的 Q/K/V,并把每一层的 K/V 存起来;
- 之后每生成一个新 token,只需要:
- 对这个新 token 算它自己的 Q/K/V;
- 把新的 K/V 拼到缓存后面;
- 只对“当前这个位置”算 Attention;
- 整个序列之前的位置不需要重复计算。
伪代码大致是:
1def forward_step(x_t, past_kv):
2 # x_t: 当前这一步输入的 token 向量
3 # past_kv: 每一层历史的 K/V
4 new_past_kv = []
5 h = x_t
6 for layer, (k_cache, v_cache) in zip(layers, past_kv):
7 q, k, v = layer.attn_proj(h) # 只对当前 token 做投影
8 k = concat(k_cache, k, dim=seq_dim) # 拼到历史后面
9 v = concat(v_cache, v, dim=seq_dim)
10 h = layer.attn(q, k, v) # 只算最后一个位置的输出
11 new_past_kv.append((k, v))
12 h = layer.ffn(h)
13 return h, new_past_kv
这就是实际系统里“流式生成”能做到的关键:每个 token 的增量计算成本基本与上下文长度无关。
精度与稳定性:FP16 / BF16、梯度裁剪与规范化
大模型训练里,为了在有限显存里塞进更大的 batch 和模型,通常会用 半精度:
- FP16:早期常见,但对数值范围比较敏感,需要混合精度训练(weights 用 FP16,累加/某些层用 FP32);
- BF16:更适合深度学习的数值分布,很多新卡已经原生支持。
配合半精度时,常见的一些稳定性技巧包括:
- Loss Scaling:把 loss 放大到一定倍数再反传,避免梯度在 FP16 下下溢;
- 梯度裁剪(Gradient Clipping):比如按全局范数裁剪到某个上限,防止梯度爆炸;
- 合适的归一化方式:Pre-LN(LayerNorm 在 Block 前)通常比 Post-LN 更好训。
这些细节都不改变 Transformer 的“表达能力”,但对能不能在大数据上稳定收敛影响很大。
小结:训练/推理中的关键点
- 训练目标可以统一理解为“对一堆 token 做分类交叉熵”,自回归和 Masked LM 只是目标位置集合不同;
- Padding Mask 负责“忽略补零”,Causal Mask 负责“禁止偷看未来”,本质都是在 softmax 前把不合法的位置权重打到接近 0;
- 训练阶段的并行性来自 batch × 序列 × 头三个维度的大矩阵乘法;
- 推理阶段的 KV Cache 让生成成本从“每次重算整句”变成“只算增量”;
- 半精度训练 + 一些数值稳定性技巧,是把 Transformer 规模做大的基础设施。