Transformer 架构
2219 字约 7 分钟
2026-05-20
2017年,Google 发表《Attention is All You Need》,提出 Transformer 架构。这是过去十年最重要的 AI 论文,它彻底替代了 RNN,并成为所有现代大语言模型的基础。
核心思想只有一句话:完全基于注意力机制,不使用循环或卷积。
1. 整体架构
1.1 原始 Encoder-Decoder 结构
原始 Transformer 用于机器翻译,包含:
输入序列(源语言)
↓
[Tokenization + Embedding + Positional Encoding]
↓
Encoder(N=6层)
├── Multi-Head Self-Attention
├── Add & Norm
├── Feed-Forward Network
└── Add & Norm
↓
Encoder 输出(Memory)
↓
Decoder(N=6层)
├── Masked Multi-Head Self-Attention(自回归掩码)
├── Add & Norm
├── Multi-Head Cross-Attention(关注 Encoder 输出)
├── Add & Norm
├── Feed-Forward Network
└── Add & Norm
↓
Linear + Softmax → 输出概率分布1.2 现代大模型:Decoder-only
GPT 系列、Claude、LLaMA 等都只用 Decoder,去掉了 Encoder 和 Cross-Attention:
输入 tokens
↓
Embedding + Positional Encoding
↓
N × Transformer Block:
├── LayerNorm(Pre-Norm,放在前面)
├── Multi-Head Self-Attention(因果掩码)
├── Residual Connection
├── LayerNorm
├── Feed-Forward Network(SwiGLU)
└── Residual Connection
↓
LayerNorm
↓
Linear → Softmax → 下一个 token 的概率2. Tokenization(分词)
2.1 为什么不直接用字符或单词?
- 字符级:序列太长,难以学习长程依赖
- 词级:词汇表爆炸(处理不了未登录词,如新词、代码、数学符号)
- 子词级:在两者之间,现在的标准
2.2 BPE(Byte Pair Encoding)
- 初始:每个字符是一个 token
- 统计相邻 token 对的频率
- 合并频率最高的对
- 重复直到达到预设词汇表大小
"lower" → "low" + "er";"newer" → "new" + "er"(共享"er"子词)
GPT 系列用 BPE,Claude 也用类似方法。中文通常每个字是一个 token。
import tiktoken # OpenAI 的 tokenizer
enc = tiktoken.get_encoding("cl100k_base") # GPT-4 的 tokenizer
tokens = enc.encode("Hello, 你好世界!")
print(tokens) # token IDs
print(len(tokens)) # token 数量
text = enc.decode(tokens) # 还原文本3. 嵌入(Embedding)
将每个 token 映射到连续的高维向量空间:
ei=WE[tokeni]
W_E 是 Embedding 矩阵,形状为 (vocab_size, d_model),是可学习参数。
embedding = nn.Embedding(vocab_size=50257, embedding_dim=768)
# 输入: token IDs (batch, seq_len)
# 输出: (batch, seq_len, 768) — 每个 token 变成768维向量4. 位置编码(Positional Encoding)
Self-Attention 本身是置换不变的(对"猫吃鱼"和"鱼吃猫"的处理相同),必须显式注入位置信息。
4.1 原始 Transformer:正弦位置编码
PE(pos,2i)=sin(100002i/dmodelpos)
PE(pos,2i+1)=cos(100002i/dmodelpos)
固定的、不可学习,但实验证明效果很好,且对超出训练长度的位置有一定泛化能力。
import torch
import numpy as np
import matplotlib.pyplot as plt
def positional_encoding(max_len, d_model):
PE = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-np.log(10000.0) / d_model))
PE[:, 0::2] = torch.sin(position * div_term)
PE[:, 1::2] = torch.cos(position * div_term)
return PE # (max_len, d_model)
# 可视化:每一行是一个位置的编码,颜色表示值的大小
pe = positional_encoding(50, 64)
plt.figure(figsize=(12, 4))
plt.pcolormesh(pe.T, cmap='RdBu')
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.colorbar()4.2 RoPE(旋转位置编码)— 现代大模型标准
LLaMA、GPT-NeoX、Claude 等现代大模型使用 RoPE。
核心思想:不是把位置编码加到 Embedding 上,而是在计算注意力时,将位置信息以旋转的方式编码到 Q 和 K 向量中:
qm=Rmq,kn=Rnk
score(qm,kn)=(Rmq)T(Rnk)=qTRn−mk
注意力分数只依赖于相对位置 (n-m),而不是绝对位置,这让模型对位置外推有更好的能力。
5. 自注意力机制(Self-Attention)
5.1 核心方程
Attention(Q,K,V)=softmax(dkQKT)V
- Q(Query):当前 token "想要什么信息"
- K(Key):每个 token "能提供什么信息"
- V(Value):每个 token 的"实际内容"
Q = X W_Q,K = X W_K,V = X W_V,三个投影矩阵是可学习参数。
5.2 为什么除以 √d_k?
当 d_k 较大时,Q·K 的点积方差也较大,经过 softmax 后会导致梯度消失(softmax 在极端值处梯度趋近0)。除以 √d_k 将方差归一化。
5.3 因果掩码(Causal Mask)
Decoder-only 模型在训练时用因果掩码,确保位置 i 只能看到位置 ≤ i 的信息(不能"偷看未来"):
maskij={0−∞i≥ji<j
加到注意力分数上后,softmax 会让未来位置的权重变为0。
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q, K, V: (batch, heads, seq_len, d_k)
"""
d_k = Q.size(-1)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k) # (batch, heads, seq, seq)
if mask is not None:
# 因果掩码:将未来位置填充 -inf
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1) # 沿序列维度归一化
return weights @ V, weights
# 生成因果掩码
def causal_mask(seq_len, device):
return torch.tril(torch.ones(seq_len, seq_len, device=device))5.4 注意力复杂度
- 时间复杂度:O(n²d),n 是序列长度,d 是维度
- 空间复杂度:O(n²),需要存储注意力矩阵
这是 Transformer 处理长序列的瓶颈——64K tokens 的注意力矩阵内存需求极大。
6. 多头注意力(Multi-Head Attention)
6.1 思路
用 h 组不同的(Q, K, V)参数并行计算注意力,每个"头"学习不同类型的关联关系:
- 头1:可能关注句法关系(主谓关系)
- 头2:可能关注语义关系(同义词、反义词)
- 头3:可能关注指代关系("它"指代哪个名词)
- ...
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Q, K, V, 输出的投影矩阵
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def split_heads(self, x):
# (batch, seq, d_model) → (batch, heads, seq, d_k)
batch, seq, _ = x.size()
return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, x, mask=None):
batch, seq, _ = x.size()
Q = self.split_heads(self.W_Q(x))
K = self.split_heads(self.W_K(x))
V = self.split_heads(self.W_V(x))
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 合并各头的输出
attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq, self.d_model)
return self.W_O(attn_output)7. 前馈网络(Feed-Forward Network)
每个 Transformer Block 的第二个子层,对每个位置独立处理:
原始:
FFN(x)=ReLU(xW1+b1)W2+b2
中间维度通常是 d_model 的 4 倍。
现代(SwiGLU,LLaMA 等使用):
SwiGLU(x)=(SiLU(xW1)⊙xW3)W2
其中 SiLU(x) = x × sigmoid(x)(也叫 Swish)。三个投影矩阵,带门控机制,效果优于 ReLU FFN。
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
# SwiGLU 风格
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
self.w3 = nn.Linear(d_model, d_ff, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# SwiGLU: silu(xW1) * xW3 → W2
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))8. 残差连接与归一化
8.1 Pre-Norm vs Post-Norm
Post-Norm(原始论文):
x=LayerNorm(x+SubLayer(x))
Pre-Norm(现代大模型标准):
x=x+SubLayer(LayerNorm(x))
Pre-Norm 训练更稳定,不需要精细的学习率 warmup,深层网络更容易优化。
8.2 LayerNorm vs RMSNorm
LayerNorm:对每个样本的所有特征归一化(减均值,除标准差)
RMSNorm(LLaMA 使用):只除以均方根,不减均值,更简单更快:
RMSNorm(x)=d1∑xi2+ϵx⋅γ
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
return x * rms * self.weight9. 完整 Transformer Block
class TransformerBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, num_heads)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# Pre-Norm + 残差连接
x = x + self.dropout(self.attention(self.norm1(x), mask))
x = x + self.dropout(self.feed_forward(self.norm2(x)))
return x
class GPT(nn.Module):
def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_len, dropout=0.1):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
# 学习的位置嵌入(简单版,实际用 RoPE)
self.pos_embedding = nn.Embedding(max_seq_len, d_model)
self.dropout = nn.Dropout(dropout)
self.blocks = nn.ModuleList([
TransformerBlock(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = RMSNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# 权重绑定:输入 Embedding 和输出 Linear 共享权重
self.lm_head.weight = self.embedding.weight
def forward(self, input_ids):
batch, seq_len = input_ids.shape
device = input_ids.device
# 因果掩码
mask = causal_mask(seq_len, device)
# Token Embedding + Positional Embedding
positions = torch.arange(seq_len, device=device).unsqueeze(0)
x = self.dropout(self.embedding(input_ids) + self.pos_embedding(positions))
# N 层 Transformer Block
for block in self.blocks:
x = block(x, mask)
x = self.norm(x)
logits = self.lm_head(x) # (batch, seq_len, vocab_size)
return logits10. Flash Attention
标准注意力的内存瓶颈:需要存储 O(n²) 的注意力矩阵。
Flash Attention(Dao et al. 2022):重写注意力计算的内存访问模式,分块计算,避免将完整注意力矩阵写入 HBM(显存主存储器):
- 速度提升:2-4 倍
- 内存节省:O(n²) → O(n)
- 数值结果完全一致(非近似)
# PyTorch 2.0+ 已内置 Flash Attention
with torch.backends.cuda.sdp_kernel(enable_flash=True):
output = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)11. KV Cache(推理加速)
自回归生成时,每次生成新 token 都需要重新计算所有之前 token 的 K 和 V。KV Cache 将已计算的 K、V 缓存起来,每次只计算新 token 的 K、V:
- 推理速度提升:显著,生成 1000 token 时快约 1000 倍
- 代价:内存占用增大(长上下文时 KV Cache 很大)
GQA(分组查询注意力):多个 Q 头共享一组 K、V,减小 KV Cache:
MHA:8个Q头 → 8组独立的K, V
MQA:8个Q头 → 共享1组K, V(极度节省内存,但效果略差)
GQA:8个Q头 → 共享2组K, V(折中,LLaMA 3等使用)