预训练与微调
2330 字约 8 分钟
2026-05-20
"预训练-微调"范式是现代 AI 的核心框架。理解大模型从何而来、如何被训练成对话助手,是理解 AI 能力边界的关键。
1. 预训练(Pre-training)
1.1 核心任务:自监督学习
大模型的预训练不需要人工标注,用的是自监督学习。数据本身提供监督信号。
语言模型(Causal LM,GPT 风格):
预测下一个 token——输入序列,预测每个位置的下一个词:
L=−t=1∑TlogP(xt∣x1,x2,...,xt−1)
本质上是极大似然估计:最大化在给定前缀条件下,真实下一词出现的概率。
# 训练时的目标:每个位置预测下一个 token
# 输入:[BOS, token1, token2, ..., tokenN]
# 标签:[token1, token2, ..., tokenN, EOS]
logits = model(input_ids) # (batch, seq_len, vocab_size)
labels = input_ids[:, 1:].clone() # 标签向左移一位
loss = F.cross_entropy(
logits[:, :-1].reshape(-1, vocab_size),
labels.reshape(-1),
ignore_index=pad_token_id # 忽略 padding 位置
)Masked LM(BERT 风格):
随机遮盖15%的 token,预测被遮盖的词:
- 80% 替换为 [MASK]
- 10% 替换为随机词
- 10% 保持原词不变(让模型不确定哪些词需要修正)
1.2 预训练数据
| 数据来源 | 说明 |
|---|---|
| 网页文本(Common Crawl) | 互联网爬取,最大量,需大量清洗 |
| 书籍(Books1/2, Gutenberg) | 高质量,长文本,逻辑连贯 |
| 学术论文(ArXiv) | 专业知识,数学推理 |
| 代码(GitHub) | 提升代码能力,也提升逻辑推理 |
| 百科全书(Wikipedia) | 高质量事实性知识 |
| 对话数据 | 帮助模型理解对话格式 |
数据质量 >> 数据数量。数据清洗(去重、去低质量、去有害内容)是预训练中最重要的工程工作之一。
1.3 Scaling Law
模型能力和三个要素之间存在幂律关系(log-log 线性):
L∝N−αN,L∝D−αD,L∝C−αC
Chinchilla 定律(2022):给定计算预算,最优的参数量和数据量关系:
- 模型参数量 ≈ 训练 tokens 数量 / 20
- GPT-3 (175B 参数) 应配套 3.5T tokens 才是最优,但实际只用了 300B
这推动了 LLaMA 的出现:用更小的参数量(7B、13B、70B)训练更多数据,获得更好的推理时效率。
1.4 训练基础设施
训练 GPT-4 级别的模型需要:
- 数千张 A100/H100 GPU
- 数据并行 + 模型并行 + 流水线并行
- 几个月的训练时间
- 数千万美元的算力成本
数据并行(DDP):不同 GPU 各自有完整模型副本,处理不同数据,同步梯度
模型并行:模型太大放不下一张 GPU,将不同层或不同参数切分到不同 GPU
混合精度训练:用 BF16 存储权重和激活,但保持 FP32 的主权重用于梯度累积
2. 指令微调(Instruction Fine-tuning / SFT)
2.1 基础模型 vs 指令模型
预训练得到的是基础模型(Base Model),它只会"续写文本",不会"对话":
输入: "法国的首都是"
基础模型输出: "巴黎,位于法国北部,塞纳河畔..."(纯续写)指令微调让模型学会遵循指令,按照人类期望的方式响应:
输入: "请问法国的首都是哪里?"
指令模型输出: "法国的首都是巴黎。"2.2 训练数据格式
指令微调用有格式的对话数据:
{
"messages": [
{"role": "system", "content": "你是一个有帮助、无害、诚实的AI助手。"},
{"role": "user", "content": "请解释什么是梯度下降?"},
{"role": "assistant", "content": "梯度下降是一种优化算法...(高质量回答)"}
]
}数据来源:
- 人工撰写的高质量对话
- 强模型生成 + 人工筛选(Self-Instruct)
- 从已有任务改造(Alpaca 等)
2.3 训练细节
SFT 只在 assistant 的回复部分计算损失,system 和 user 的部分只作为上下文(loss = -1 mask掉):
def compute_sft_loss(logits, labels, response_mask):
"""
只在 assistant 回复位置计算 loss
response_mask: 1 表示是 assistant 回复,0 表示其他
"""
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_mask = response_mask[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, logits.size(-1)),
shift_labels.view(-1),
reduction='none'
)
loss = (loss * shift_mask.view(-1)).sum() / shift_mask.sum()
return loss3. RLHF(人类反馈强化学习)
SFT 之后,模型遵循指令的能力有了,但还需要和人类价值观对齐:有帮助、无害、诚实。RLHF 是解决方案。
3.1 三阶段流程
阶段1:SFT(已完成)
阶段2:训练奖励模型(Reward Model)
让人类对同一问题的不同回答排序(A比B好),训练一个能预测人类偏好的打分模型:
LRM=−E(x,yw,yl)[logσ(r(x,yw)−r(x,yl))]
- y_w:更受人类偏好的回答(winner)
- y_l:较差的回答(loser)
- r(x, y):奖励模型给(问题, 回答)对打的分数
class RewardModel(nn.Module):
def __init__(self, base_model):
super().__init__()
self.base = base_model
# 在语言模型头部加一个标量输出层
self.value_head = nn.Linear(base_model.config.hidden_size, 1)
def forward(self, input_ids, attention_mask):
outputs = self.base(input_ids, attention_mask=attention_mask)
# 取最后一个 token 的 hidden state 作为序列表示
last_hidden = outputs.last_hidden_state[:, -1, :]
reward = self.value_head(last_hidden).squeeze(-1)
return reward阶段3:PPO 强化学习
用奖励模型作为信号,用 PPO(Proximal Policy Optimization)算法微调语言模型:
目标:最大化 E[r(x, y)] - β · KL(π_θ || π_ref)- 奖励部分:让回答获得更高的人类偏好分数
- KL 惩罚项:防止模型偏离 SFT 参考模型太远(避免奖励攻击)
PPO 训练中涉及四个模型同时运行(Actor、Critic、Reward、Reference),工程复杂度极高。
3.2 RLHF 的挑战
奖励作弊(Reward Hacking):模型学会骗过奖励模型,而不是真正提升回答质量。
例如:发现奖励模型偏爱长回答 → 生成冗长废话;发现奖励模型偏爱自信语气 → 不确定也说得很肯定。
4. DPO(直接偏好优化)
RLHF 的简化替代,不需要单独训练奖励模型,直接用偏好数据优化语言模型:
LDPO=−E[logσ(βlogπref(yw∣x)πθ(yw∣x)−βlogπref(yl∣x)πθ(yl∣x))]
优势:
- 工程简单(只需一个训练循环)
- 训练稳定(无需 PPO 的复杂超参数调整)
- 效果和 PPO 相当甚至更好
目前大量开源模型使用 DPO(Zephyr、Mistral Instruct 等)。
5. 参数高效微调(PEFT)
全量微调 70B 模型需要 280GB+ 显存,绝大多数人用不起。PEFT 用极少参数实现接近全量微调的效果。
5.1 LoRA(Low-Rank Adaptation)⭐ 最常用
核心思想:大模型的权重更新矩阵是低秩的。冻结原始权重 W,旁路加入低秩分解 BA:
W′=W+ΔW=W+rαBA
- B:(d × r) 矩阵,初始为0
- A:(r × d_in) 矩阵,初始高斯随机
- r:低秩的秩,通常 4-64
- α:缩放因子(通常 = r 或 2r)
参数减少量:原始矩阵 (d × d_in) 参数量为 d·d_in,LoRA 只需 r·(d + d_in),当 r << d 时节省极大。
合并权重:推理时将 BA 加回 W,推理速度不受影响!
from peft import LoraConfig, get_peft_model, TaskType
# LoRA 配置
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=16, # 低秩的秩
lora_alpha=32, # 缩放因子
target_modules=["q_proj", "v_proj", # 应用LoRA的目标层
"k_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
)
# 将 LoRA 注入基础模型
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# 输出: trainable params: 4,194,304 || all params: 6,742,609,920 || 0.062%
# 训练(只有 LoRA 参数有梯度)
for batch in dataloader:
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
# 保存 LoRA 权重(只保存几十MB,而不是整个模型)
model.save_pretrained("lora_weights/")
# 合并权重(推理时)
merged_model = model.merge_and_unload()5.2 QLoRA
在量化(4-bit NF4 格式)的基础模型上应用 LoRA,将显存需求降低 ~4倍:
- LLaMA-3 70B 全量微调:约 140GB 显存
- QLoRA 微调:约 48GB 显存(两张 A6000 或单张 A100 80G)
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4", # NormalFloat4 量化
bnb_4bit_use_double_quant=True, # 嵌套量化进一步节省
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3-70b",
quantization_config=bnb_config,
device_map="auto"
)
# 然后正常使用 LoRA
model = get_peft_model(model, lora_config)5.3 其他 PEFT 方法
| 方法 | 思路 | 特点 |
|---|---|---|
| LoRA | 低秩旁路 | 最通用,效果好 |
| Prefix Tuning | 在 KV 前添加可学习前缀 | 适合 NLU 任务 |
| Prompt Tuning | 学习 soft prompt 向量 | 极少参数,效果略差 |
| IA³ | 学习缩放向量 | 比 LoRA 参数更少 |
6. 模型量化
将模型权重从高精度(FP32/BF16)压缩到低精度,减少内存和加速推理。
6.1 量化格式对比
| 精度 | 每参数字节 | 70B 模型大小 | 精度损失 |
|---|---|---|---|
| FP32 | 4B | 280GB | 基准 |
| BF16/FP16 | 2B | 140GB | 极小 |
| INT8 | 1B | 70GB | 小 |
| INT4/NF4 | 0.5B | 35GB | 较小 |
6.2 GGUF(llama.cpp)
在 CPU 上运行量化模型的事实标准格式:
# 使用 Ollama 本地运行大模型
# ollama pull llama3:8b
# ollama run llama3:8b
from ollama import Client
client = Client()
response = client.chat(
model='llama3:8b',
messages=[{'role': 'user', 'content': '解释一下量子纠缠'}]
)
print(response['message']['content'])7. 完整的微调流程示例
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
TrainingArguments, DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
import torch
# 1. 加载模型
model_name = "meta-llama/Llama-3-8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto"
)
# 2. 配置 LoRA
lora_config = LoraConfig(
r=8, lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05, bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# 3. 训练参数
training_args = TrainingArguments(
output_dir="./fine_tuned_model",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # 等效 batch_size=16
learning_rate=2e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.05,
bf16=True,
logging_steps=10,
save_steps=100,
evaluation_strategy="steps",
eval_steps=100,
load_best_model_at_end=True,
)
# 4. 使用 TRL 的 SFTTrainer(自动处理格式)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text",
max_seq_length=2048,
args=training_args,
)
trainer.train()
trainer.save_model()