Self-Attention 细节:从矩阵形状到多头实现

这一篇把 Self-Attention 拆细一点:从输入输出的张量形状,到 Q/K/V 是怎么算出来的,多头注意力又是怎么在工程里实现的。

先约定一个最常见的形状记法:

  • batch 大小:$B$
  • 序列长度:$L$
  • embedding 维度:$d_{\text{model}}$

对于一层 Self-Attention,它的典型输入是一个张量:

  • 形状:$(B, L, d_{\text{model}})$
  • 含义:一共有 $B$ 个样本,每个样本是一串长度为 $L$ 的 token,每个 token 用一个 $d_{\text{model}}$ 维的向量表示。

这一层的输出,形状通常和输入是一样的:

  • 形状:$(B, L, d_{\text{model}})$
  • 含义:对每个位置,重新融合了一遍“整句的信息”之后的表示。

中间过程会经历 Q/K/V 的线性变换、注意力权重计算、加权求和等步骤,但从外部看,就是一个“输入一串向量,输出同形状的一串向量”的算子。

对于输入张量 $X \in \mathbb{R}^{B \times L \times d_{\text{model}}}$,Self-Attention 先通过三组线性变换,得到:

  • $Q = X W_Q$
  • $K = X W_K$
  • $V = X W_V$

其中:

  • $W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times d_k}$ 或 $\mathbb{R}^{d_{\text{model}} \times d_v}$
  • $d_k, d_v$ 通常与 $d_{\text{model}}$ 同量级,但可以不同。

在代码里,这通常就是三层 Linear 或一次 Linear 再切分三段。例如(伪代码):

1# X: (B, L, d_model)
2Q = X @ W_Q  # (B, L, d_k)
3K = X @ W_K  # (B, L, d_k)
4V = X @ W_V  # (B, L, d_v)

很多实现会用一个大矩阵一次性做完:

1W = concat(W_Q, W_K, W_V, dim=-1)   # (d_model, d_q + d_k + d_v)
2QKV = X @ W                         # (B, L, d_q + d_k + d_v)
3Q, K, V = split(QKV)                # 按最后一个维度切三份

这样可以更好地利用底层矩阵乘法库的性能。

注意力的核心公式是:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$

以单头的情况为例,忽略 batch 维度时,可以这样理解:

  • $Q \in \mathbb{R}^{L_q \times d_k}$
  • $K \in \mathbb{R}^{L_k \times d_k}$
  • $V \in \mathbb{R}^{L_k \times d_v}$

这里 $L_q$ 是 query 的长度,$L_k$ 是 key/value 的长度:

  • 自注意力(Self-Attention):$L_q = L_k = L$;
  • 编码器-解码器注意力(Encoder-Decoder Attention):query 来自 decoder,key/value 来自 encoder,长度可以不同。

计算步骤:

  1. 相似度矩阵:

    $$ S = \frac{QK^\top}{\sqrt{d_k}} \in \mathbb{R}^{L_q \times L_k} $$
  2. 做 softmax 得到权重:

    $$ A = \text{softmax}(S) \in \mathbb{R}^{L_q \times L_k} $$
  3. 用权重对 V 做加权求和:

    $$ O = AV \in \mathbb{R}^{L_q \times d_v} $$

对于自注意力,当 (L_q = L_k = L) 时,上面三个矩阵的形状就是:

  • (S \in \mathbb{R}^{L \times L}):每一行是一个位置对所有位置的“注意力分布”;
  • (A \in \mathbb{R}^{L \times L}):softmax 后的权重;
  • (O \in \mathbb{R}^{L \times d_v}):每个位置融合整句信息后的输出。

在真正的实现里,这一切都会带上 batch 维度 (B),一般写成:

  • (Q \in \mathbb{R}^{B \times L_q \times d_k})
  • (K \in \mathbb{R}^{B \times L_k \times d_k})
  • (V \in \mathbb{R}^{B \times L_k \times d_v})

单头注意力可以理解为:在一个子空间里,学会“谁该看谁”。
多头注意力的想法是:同时在多个子空间里,各自学习一种关系模式

做法上是:

  • 把 (d_{\text{model}}) 划分成 (h) 个头,每头的维度是 (d_h = d_{\text{model}} / h);
  • 对每个头都有自己的 (W_Q^{(i)}, W_K^{(i)}, W_V^{(i)});
  • 每个头各自算一次 Attention,最后把所有头的输出在最后一个维度上拼接,再做一次线性变换。

如果加上 head 这一维,典型的形状是:

  • 输入:(X \in \mathbb{R}^{B \times L \times d_{\text{model}}})

  • 线性投影后 reshape:

    • (Q, K, V \in \mathbb{R}^{B \times h \times L \times d_h})
  • 对每个头分别做 Attention,得到:

    • (O \in \mathbb{R}^{B \times h \times L \times d_h})
  • 再把 head 维度和最后一维拼回去:

    • reshape 成 (\mathbb{R}^{B \times L \times (h \cdot d_h)} = \mathbb{R}^{B \times L \times d_{\text{model}}})
    • 通过一个线性层 (W_O) 映射回输出空间。

在很多库里(比如 PyTorch / TensorFlow),这一段实现大致是:

 1# X: (B, L, d_model)
 2Q = X @ W_Q  # (B, L, h * d_h)
 3K = X @ W_K  # (B, L, h * d_h)
 4V = X @ W_V  # (B, L, h * d_h)
 5
 6# 变换成带 head 维度的形状
 7Q = Q.view(B, L, h, d_h).transpose(1, 2)  # (B, h, L, d_h)
 8K = K.view(B, L, h, d_h).transpose(1, 2)  # (B, h, L, d_h)
 9V = V.view(B, L, h, d_h).transpose(1, 2)  # (B, h, L, d_h)
10
11# 按 head 维做 Attention
12scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_h)  # (B, h, L, L)
13weights = softmax(scores, dim=-1)                    # (B, h, L, L)
14O = weights @ V                                      # (B, h, L, d_h)
15
16# 拼回去
17O = O.transpose(1, 2).contiguous().view(B, L, h * d_h)  # (B, L, d_model)
18O = O @ W_O                                             # (B, L, d_model)

无论是 padding mask(忽略补零位置),还是因果 mask(不能看未来),本质上都是在 softmax 前对分数矩阵 (S) 做处理。

最常见的做法是:

  • 准备一个 mask 张量 mask,与 (S) 的形状兼容(例如 (B, 1, 1, L)(B, 1, L, L));
  • 对不允许关注的位置,把对应的分数减去一个很大的数(近似 (-\infty));
  • softmax 之后,这些位置的权重就接近 0。

伪代码大致是:

1scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_h)  # (B, h, L_q, L_k)
2
3if mask is not None:
4  scores = scores.masked_fill(mask == 0, -1e9)
5
6weights = softmax(scores, dim=-1)

对于 GPT 这类自回归模型,因果 mask 一般是一个上三角矩阵,保证第 (i) 个位置只能看到 (0 \dots i) 的位置。

  • 缩放因子 (\sqrt{d_k})

    • 没有缩放时,(QK^\top) 的数值范围会随着 (d_k) 增大而变大;
    • 过大的分数会让 softmax 非常“尖锐”,梯度容易消失;
    • 除以 (\sqrt{d_k}) 是个简单有效的数值稳定技巧。
  • 头数和维度的折中

    • 头数太少,表达能力有限;
    • 头数太多,每头的维度 (d_h) 太小,容易“学不到什么内容”;
    • 常见配置:(d_{\text{model}} = 768, h = 12) 或 (d_{\text{model}} = 1024, h = 16) 等。
  • 实现里大量用到 reshape / transpose

    • 逻辑上是“加一个 head 维度”,实际就是在 batch、序列、head、通道维度之间来回重排;
    • 真正的计算集中在几个大矩阵乘法上,其它都是张量变换。

这篇把 Self-Attention 在实现上的“形状问题”理了一遍:
从 Q/K/V 的线性变换,到多头注意力的 reshape,再到 mask 的介入,基本可以对照着任何一个框架里的 MultiHeadAttention 源码去读,不容易迷路。