Self-Attention 细节:从矩阵形状到多头实现
这一篇把 Self-Attention 拆细一点:从输入输出的张量形状,到 Q/K/V 是怎么算出来的,多头注意力又是怎么在工程里实现的。
输入输出长什么样?
先约定一个最常见的形状记法:
- batch 大小:$B$
- 序列长度:$L$
- embedding 维度:$d_{\text{model}}$
对于一层 Self-Attention,它的典型输入是一个张量:
- 形状:$(B, L, d_{\text{model}})$
- 含义:一共有 $B$ 个样本,每个样本是一串长度为 $L$ 的 token,每个 token 用一个 $d_{\text{model}}$ 维的向量表示。
这一层的输出,形状通常和输入是一样的:
- 形状:$(B, L, d_{\text{model}})$
- 含义:对每个位置,重新融合了一遍“整句的信息”之后的表示。
中间过程会经历 Q/K/V 的线性变换、注意力权重计算、加权求和等步骤,但从外部看,就是一个“输入一串向量,输出同形状的一串向量”的算子。
Q / K / V 是怎么算出来的?
对于输入张量 $X \in \mathbb{R}^{B \times L \times d_{\text{model}}}$,Self-Attention 先通过三组线性变换,得到:
- $Q = X W_Q$
- $K = X W_K$
- $V = X W_V$
其中:
- $W_Q, W_K, W_V \in \mathbb{R}^{d_{\text{model}} \times d_k}$ 或 $\mathbb{R}^{d_{\text{model}} \times d_v}$
- $d_k, d_v$ 通常与 $d_{\text{model}}$ 同量级,但可以不同。
在代码里,这通常就是三层 Linear 或一次 Linear 再切分三段。例如(伪代码):
1# X: (B, L, d_model)
2Q = X @ W_Q # (B, L, d_k)
3K = X @ W_K # (B, L, d_k)
4V = X @ W_V # (B, L, d_v)
很多实现会用一个大矩阵一次性做完:
1W = concat(W_Q, W_K, W_V, dim=-1) # (d_model, d_q + d_k + d_v)
2QKV = X @ W # (B, L, d_q + d_k + d_v)
3Q, K, V = split(QKV) # 按最后一个维度切三份
这样可以更好地利用底层矩阵乘法库的性能。
Scaled Dot-Product Attention 的矩阵形状
注意力的核心公式是:
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$以单头的情况为例,忽略 batch 维度时,可以这样理解:
- $Q \in \mathbb{R}^{L_q \times d_k}$
- $K \in \mathbb{R}^{L_k \times d_k}$
- $V \in \mathbb{R}^{L_k \times d_v}$
这里 $L_q$ 是 query 的长度,$L_k$ 是 key/value 的长度:
- 自注意力(Self-Attention):$L_q = L_k = L$;
- 编码器-解码器注意力(Encoder-Decoder Attention):query 来自 decoder,key/value 来自 encoder,长度可以不同。
计算步骤:
相似度矩阵:
$$ S = \frac{QK^\top}{\sqrt{d_k}} \in \mathbb{R}^{L_q \times L_k} $$做 softmax 得到权重:
$$ A = \text{softmax}(S) \in \mathbb{R}^{L_q \times L_k} $$用权重对 V 做加权求和:
$$ O = AV \in \mathbb{R}^{L_q \times d_v} $$
对于自注意力,当 (L_q = L_k = L) 时,上面三个矩阵的形状就是:
- (S \in \mathbb{R}^{L \times L}):每一行是一个位置对所有位置的“注意力分布”;
- (A \in \mathbb{R}^{L \times L}):softmax 后的权重;
- (O \in \mathbb{R}^{L \times d_v}):每个位置融合整句信息后的输出。
在真正的实现里,这一切都会带上 batch 维度 (B),一般写成:
- (Q \in \mathbb{R}^{B \times L_q \times d_k})
- (K \in \mathbb{R}^{B \times L_k \times d_k})
- (V \in \mathbb{R}^{B \times L_k \times d_v})
多头注意力:为什么要多一个 Head 维度?
单头注意力可以理解为:在一个子空间里,学会“谁该看谁”。
多头注意力的想法是:同时在多个子空间里,各自学习一种关系模式。
做法上是:
- 把 (d_{\text{model}}) 划分成 (h) 个头,每头的维度是 (d_h = d_{\text{model}} / h);
- 对每个头都有自己的 (W_Q^{(i)}, W_K^{(i)}, W_V^{(i)});
- 每个头各自算一次 Attention,最后把所有头的输出在最后一个维度上拼接,再做一次线性变换。
如果加上 head 这一维,典型的形状是:
输入:(X \in \mathbb{R}^{B \times L \times d_{\text{model}}})
线性投影后 reshape:
- (Q, K, V \in \mathbb{R}^{B \times h \times L \times d_h})
对每个头分别做 Attention,得到:
- (O \in \mathbb{R}^{B \times h \times L \times d_h})
再把 head 维度和最后一维拼回去:
- reshape 成 (\mathbb{R}^{B \times L \times (h \cdot d_h)} = \mathbb{R}^{B \times L \times d_{\text{model}}})
- 通过一个线性层 (W_O) 映射回输出空间。
在很多库里(比如 PyTorch / TensorFlow),这一段实现大致是:
1# X: (B, L, d_model)
2Q = X @ W_Q # (B, L, h * d_h)
3K = X @ W_K # (B, L, h * d_h)
4V = X @ W_V # (B, L, h * d_h)
5
6# 变换成带 head 维度的形状
7Q = Q.view(B, L, h, d_h).transpose(1, 2) # (B, h, L, d_h)
8K = K.view(B, L, h, d_h).transpose(1, 2) # (B, h, L, d_h)
9V = V.view(B, L, h, d_h).transpose(1, 2) # (B, h, L, d_h)
10
11# 按 head 维做 Attention
12scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_h) # (B, h, L, L)
13weights = softmax(scores, dim=-1) # (B, h, L, L)
14O = weights @ V # (B, h, L, d_h)
15
16# 拼回去
17O = O.transpose(1, 2).contiguous().view(B, L, h * d_h) # (B, L, d_model)
18O = O @ W_O # (B, L, d_model)
Mask 是怎么加进去的?
无论是 padding mask(忽略补零位置),还是因果 mask(不能看未来),本质上都是在 softmax 前对分数矩阵 (S) 做处理。
最常见的做法是:
- 准备一个 mask 张量
mask,与 (S) 的形状兼容(例如(B, 1, 1, L)或(B, 1, L, L)); - 对不允许关注的位置,把对应的分数减去一个很大的数(近似 (-\infty));
- softmax 之后,这些位置的权重就接近 0。
伪代码大致是:
1scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_h) # (B, h, L_q, L_k)
2
3if mask is not None:
4 scores = scores.masked_fill(mask == 0, -1e9)
5
6weights = softmax(scores, dim=-1)
对于 GPT 这类自回归模型,因果 mask 一般是一个上三角矩阵,保证第 (i) 个位置只能看到 (0 \dots i) 的位置。
工程上常见的几个实现细节
缩放因子 (\sqrt{d_k}):
- 没有缩放时,(QK^\top) 的数值范围会随着 (d_k) 增大而变大;
- 过大的分数会让 softmax 非常“尖锐”,梯度容易消失;
- 除以 (\sqrt{d_k}) 是个简单有效的数值稳定技巧。
头数和维度的折中:
- 头数太少,表达能力有限;
- 头数太多,每头的维度 (d_h) 太小,容易“学不到什么内容”;
- 常见配置:(d_{\text{model}} = 768, h = 12) 或 (d_{\text{model}} = 1024, h = 16) 等。
实现里大量用到 reshape / transpose:
- 逻辑上是“加一个 head 维度”,实际就是在 batch、序列、head、通道维度之间来回重排;
- 真正的计算集中在几个大矩阵乘法上,其它都是张量变换。
这篇把 Self-Attention 在实现上的“形状问题”理了一遍:
从 Q/K/V 的线性变换,到多头注意力的 reshape,再到 mask 的介入,基本可以对照着任何一个框架里的 MultiHeadAttention 源码去读,不容易迷路。