RNN 与序列模型
2283 字约 8 分钟
2026-05-20
语言、时间序列、音频——这类数据的特点是顺序至关重要,且长度可变。RNN 及其变体是专为这类序列数据设计的架构,也是理解 Transformer 的必要前置知识。
1. 为什么需要序列模型
普通神经网络的局限:
- 输入输出长度固定
- 不考虑时序关系("今天天气很好"和"好天气很今天"对它来说一样)
- 参数不共享(第1个位置和第100个位置用的是不同参数)
序列模型解决的问题:
| 任务类型 | 输入 | 输出 | 示例 |
|---|---|---|---|
| 多对一 | 序列 | 单个值 | 情感分析(文本→情感标签) |
| 一对多 | 单个值 | 序列 | 图像描述(图片→文字) |
| 多对多(等长) | 序列 | 序列 | 词性标注 |
| 多对多(变长) | 序列 | 序列 | 机器翻译 |
2. RNN 基本结构
2.1 核心方程
RNN 在每个时间步 t 维护一个隐藏状态 h_t,综合当前输入和上一时刻的隐藏状态:
ht=tanh(Whht−1+Wxxt+bh)
yt=Wyht+by
权重共享:W_h、W_x、W_y 在所有时间步共享,使 RNN 能处理任意长度序列。
2.2 展开图(Unrolled RNN)
x₁ x₂ x₃ x₄
│ │ │ │
↓ ↓ ↓ ↓
h₀ ──→ h₁ ──→ h₂ ──→ h₃ ──→ h₄
│ │ │
↓ ↓ ↓
y₁ y₂ y₃同一组权重在所有时间步被反复使用,这也是为什么梯度在反向传播时会连乘多次权重矩阵。
2.3 RNN 的致命弱点:梯度消失
反向传播穿越时间(BPTT):
∂h1∂L=∂hT∂Lt=2∏T∂ht−1∂ht
每个时间步的梯度贡献:
∂ht−1∂ht=tanh′(zt)⋅Wh
- tanh 的导数最大值为1,通常 < 1
- 经过 T 步相乘后,梯度指数衰减 → 梯度消失(序列长时记住不了早期信息)
- 如果 W_h 的最大奇异值 > 1,则梯度爆炸
import torch
import torch.nn as nn
# 基础 RNN
rnn = nn.RNN(
input_size=128, # 输入维度
hidden_size=256, # 隐藏状态维度
num_layers=2, # 堆叠 2 层
batch_first=True, # 输入形状 (batch, seq_len, features)
dropout=0.3,
bidirectional=False
)
# 前向传播
x = torch.randn(32, 100, 128) # (batch=32, seq_len=100, input=128)
h0 = torch.zeros(2, 32, 256) # (num_layers, batch, hidden)
output, hn = rnn(x, h0)
# output: (32, 100, 256) — 每个时间步的输出
# hn: (2, 32, 256) — 最后一个时间步的隐藏状态3. LSTM(Long Short-Term Memory)
3.1 设计动机
LSTM 通过引入细胞状态(Cell State)C_t和三个门控机制解决长期依赖问题。
细胞状态像一条"信息高速公路",允许信息以几乎无损的方式传递很多时间步。
3.2 四个核心方程
遗忘门(Forget Gate):决定从上一细胞状态遗忘多少信息
ft=σ(Wf[ht−1,xt]+bf)
输入门(Input Gate):决定当前输入有多少写入细胞状态
it=σ(Wi[ht−1,xt]+bi)
候选细胞状态:当前时间步想写入的新信息
C~t=tanh(WC[ht−1,xt]+bC)
细胞状态更新:
Ct=ft⊙Ct−1+it⊙C~t
输出门(Output Gate):决定输出细胞状态的哪些部分
ot=σ(Wo[ht−1,xt]+bo)
ht=ot⊙tanh(Ct)
3.3 门控机制的直觉
细胞状态 C_t-1 ───────────────────────────────────────────→ C_t
↑(×f_t) ↑(+i_t × C̃_t)
遗忘门: 输入门:
不重要的信息清零 新信息写入
输出 h_t = o_t × tanh(C_t)
↑
输出门:控制输出哪些信息示例理解:解析句子中的动词时态
- 遗忘门:当看到新的主语时,遗忘之前的主语信息
- 输入门:将当前时态信息写入细胞状态
- 输出门:预测当前词时,输出相关的语法信息
lstm = nn.LSTM(
input_size=128,
hidden_size=256,
num_layers=2,
batch_first=True,
dropout=0.3,
bidirectional=True # 双向 LSTM:同时从左往右和从右往左处理
)
x = torch.randn(32, 100, 128)
h0 = torch.zeros(4, 32, 256) # (2 layers × 2 directions, batch, hidden)
c0 = torch.zeros(4, 32, 256)
output, (hn, cn) = lstm(x, (h0, c0))
# 双向 LSTM 输出维度是 2×hidden=512
# output: (32, 100, 512)4. GRU(Gated Recurrent Unit)
GRU 是 LSTM 的简化版,将遗忘门和输入门合并为更新门,去掉了独立的细胞状态:
更新门(结合了遗忘门和输入门的功能):
zt=σ(Wz[ht−1,xt]+bz)
重置门(控制前一隐藏状态有多少参与当前计算):
rt=σ(Wr[ht−1,xt]+br)
候选隐藏状态:
h~t=tanh(W[rt⊙ht−1,xt]+b)
隐藏状态更新:
ht=(1−zt)⊙ht−1+zt⊙h~t
GRU vs LSTM:参数更少,在许多任务上效果相当,但在需要精细控制"遗忘"和"更新"的任务上 LSTM 可能更好。
gru = nn.GRU(input_size=128, hidden_size=256, num_layers=2,
batch_first=True, dropout=0.3)
output, hn = gru(x) # 比 LSTM 简单,不需要 c05. 双向 RNN
标准 RNN 只从左往右处理序列,但很多任务需要双向上下文(理解一个词需要看它后面的内容):
句子:"他买了一台苹果___"(填什么?)
→ 需要左侧上下文(他买了...)和右侧上下文(没有)
→ 如果句子是"他___了一台苹果手机"(填什么?)
→ 需要右侧 "苹果手机" 来理解左侧应该填 "买"双向 LSTM 同时从两个方向处理,拼接输出:
bilstm = nn.LSTM(128, 256, bidirectional=True, batch_first=True)
output, _ = bilstm(x)
# output: (batch, seq_len, 2×hidden=512)
# 前 256 维:左→右的信息,后 256 维:右→左的信息BERT 的本质就是双向(bi-directional)——它的全称是 Bidirectional Encoder Representations from Transformers。
6. Seq2Seq 与注意力机制
6.1 Encoder-Decoder 架构
机器翻译的经典架构:
"I love you"
↓ Encoder (LSTM)
[h₁, h₂, h₃] → 最终隐藏状态 c(上下文向量)
↓ Decoder (LSTM)
"我"→"爱"→"你"致命瓶颈:所有信息压缩到单一的固定长度向量 c。句子越长,信息损失越严重,翻译质量下降。
6.2 注意力机制(Attention)
Bahdanau 注意力(2015):解码每个词时,让解码器能"回头看"编码器的所有输出,动态加权:
对齐分数(相似度计算):
etj=vTtanh(Wsst−1+Whhj)
注意力权重(归一化):
αtj=∑kexp(etk)exp(etj)
上下文向量(加权求和编码器输出):
ct=j∑αtjhj
解码器更新:
st=f(st−1,yt−1,ct)
class BahdanauAttention(nn.Module):
def __init__(self, encoder_dim, decoder_dim, attention_dim):
super().__init__()
self.W_s = nn.Linear(decoder_dim, attention_dim)
self.W_h = nn.Linear(encoder_dim, attention_dim)
self.v = nn.Linear(attention_dim, 1)
def forward(self, decoder_hidden, encoder_outputs):
# decoder_hidden: (batch, decoder_dim)
# encoder_outputs: (batch, seq_len, encoder_dim)
# 计算对齐分数
scores = self.v(torch.tanh(
self.W_s(decoder_hidden).unsqueeze(1) + # (batch, 1, attn)
self.W_h(encoder_outputs) # (batch, seq, attn)
)).squeeze(-1) # (batch, seq)
# 注意力权重
weights = torch.softmax(scores, dim=1) # (batch, seq)
# 上下文向量
context = (weights.unsqueeze(-1) * encoder_outputs).sum(dim=1) # (batch, enc_dim)
return context, weights注意力机制让模型的对齐关系变得可解释——可以可视化哪个源语言词对应哪个目标语言词。
6.3 注意力可视化
import matplotlib.pyplot as plt
import seaborn as sns
# 假设 attention_weights: (target_len, source_len)
plt.figure(figsize=(10, 8))
sns.heatmap(
attention_weights,
xticklabels=source_tokens,
yticklabels=target_tokens,
cmap='YlOrRd',
annot=True, fmt='.2f'
)
plt.xlabel('源语言')
plt.ylabel('目标语言')
plt.title('注意力权重可视化')7. 为什么 Transformer 取代了 RNN
7.1 RNN 的根本问题:串行计算
RNN 必须按时间步逐一计算:h₁ → h₂ → h₃ → ... → h_T,无法并行。
在 GPU 时代,并行计算是核心优势。RNN 无论 GPU 有多强,都用不上大部分算力。
7.2 Transformer 的关键优势
Self-Attention 让每个位置能直接关注所有其他位置,全部并行:
| 维度 | RNN/LSTM | Transformer |
|---|---|---|
| 计算方式 | 时间步串行 | 完全并行 |
| 长程依赖 | 路径长度 O(n),容易消失 | 路径长度 O(1),直接相关 |
| 训练速度 | 慢(GPU 利用率低) | 快(GPU 满负荷) |
| 最大序列长度 | 有效长度受限 | 受上下文窗口限制,但可扩展 |
| 参数效率 | 参数共享但表达受限 | 更多参数,更强表达 |
7.3 RNN 还有哪些用途?
RNN 并没有完全消失:
- 轻量级实时应用:参数量小,延迟低
- 时间序列:对结构化时序数据(金融、传感器)效果仍然不错
- 状态空间模型(SSM,如 Mamba):RNN 的新变种,试图在效率和效果上超越 Transformer
# 完整的文本分类 LSTM 示例
class TextClassifier(nn.Module):
def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, num_layers=2):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers,
batch_first=True, dropout=0.3, bidirectional=True)
self.dropout = nn.Dropout(0.5)
self.fc = nn.Linear(hidden_dim * 2, num_classes) # ×2 因为双向
def forward(self, x, lengths):
# x: (batch, seq_len)
embedded = self.dropout(self.embedding(x))
# 打包变长序列(处理 padding)
packed = nn.utils.rnn.pack_padded_sequence(
embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
)
packed_output, (hn, cn) = self.lstm(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
# 取最后一个有效时间步的输出(双向:拼接两个方向的最终状态)
hn = torch.cat([hn[-2], hn[-1]], dim=1) # (batch, hidden*2)
return self.fc(self.dropout(hn))8. 时间序列预测实战
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
class TimeSeriesDataset(Dataset):
def __init__(self, data, seq_len, pred_len):
self.data = torch.FloatTensor(data)
self.seq_len = seq_len
self.pred_len = pred_len
def __len__(self):
return len(self.data) - self.seq_len - self.pred_len + 1
def __getitem__(self, idx):
x = self.data[idx:idx + self.seq_len]
y = self.data[idx + self.seq_len:idx + self.seq_len + self.pred_len]
return x.unsqueeze(-1), y # 添加特征维度
class LSTMForecaster(nn.Module):
def __init__(self, input_dim=1, hidden_dim=64, num_layers=2, pred_len=1):
super().__init__()
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers,
batch_first=True, dropout=0.2)
self.fc = nn.Linear(hidden_dim, pred_len)
def forward(self, x):
output, (hn, _) = self.lstm(x)
return self.fc(hn[-1]) # 用最后一层的最后时间步