从零实现一个 MiniGPT:用 PyTorch 手写 Transformer 架构全过程
一、为什么要从零手写 Transformer
在所有深度学习架构中,Transformer 是当之无愧的基石。从 BERT、GPT 系列到 LLaMA、Qwen,几乎所有前沿大模型都建立在 Transformer(或其变体)之上。PyTorch 虽提供了 nn.Transformer 模块可直接调用,但若不亲手实现一遍,你永远无法真正理解它为何能取代 RNN 成为序列建模的霸主。
MiniGPT 的本质,就是一个最小化的 Encoder-Decoder Transformer。它沿用了《Attention Is All You Need》论文中的经典结构:编码器将输入序列映射为连续表示,解码器自回归地逐 token 生成输出。本文将用纯 PyTorch,从张量操作到完整训练循环,一行一行地构建它。

二、Transformer 架构全景拆解
在动手写代码之前,必须先看清整体结构。Transformer 由 Encoder、Decoder 和连接二者的 Softmax 分类器三部分组成,核心组件如下表所示:
| 组件 | 作用 | 所在位置 |
|---|---|---|
| Input Embedding + Positional Encoding | 将 Token 映射为向量并注入位置信息 | Encoder/Decoder 输入端 |
| Multi-Head Self-Attention | 计算序列内部每对位置的关系 | Encoder 每层 + Decoder 第一子层 |
| Masked Multi-Head Self-Attention | 防止解码时"偷看"未来信息 | Decoder 第一子层(带因果掩码) |
| Multi-Head Cross-Attention | Decoder 关注 Encoder 输出 | Decoder 第二子层 |
| Position-wise Feed-Forward Network | 两层线性 + ReLU,逐位置处理 | Encoder/Decoder 每层 |
| Add & Norm(残差连接 + LayerNorm) | 稳定训练、缓解梯度消失 | 每个子层之后 |
Encoder 由 N=6 个相同的层堆叠而成,每层包含两个子层:多头自注意力 + 前馈网络。Decoder 同样由 N=6 层堆叠,每层包含三个子层:带掩码的多头自注意力 + 交叉注意力 + 前馈网络。
三、核心模块的 PyTorch 实现
3.1 位置编码(Positional Encoding)
Transformer 没有循环结构,必须显式注入位置信息。论文采用不同频率的正弦和余弦函数生成位置编码:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0)) # [1, max_len, d_model]
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)关键点:register_buffer 注册的张量不参与梯度计算,但会随模型保存/加载,这是位置编码的标准做法。
3.2 多头注意力(Multi-Head Attention)
这是 Transformer 的灵魂。多头注意力将输入拆分成多个"头",每个头独立计算缩放点积注意力(Scaled Dot-Product Attention),最后合并输出。 公式如下:
论文采用 h=8 个并行注意力头,每个头的维度 ,总计算量与单头全维度注意力相当。
class MultiHeadAttention(nn.Module): def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1): super().__init__() assert d_model % num_heads == 0 self.d_model = d_model self.num_heads = num_heads self.d_k = d_model // num_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) self.scale = math.sqrt(self.d_k) def forward(self, query, key, value, attn_mask=None): batch_size = query.size(0) # 线性变换 + 分头: [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k] Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # 缩放点积注意力 scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # [batch, h, q_len, k_len] if attn_mask is not None: scores = scores.masked_fill(attn_mask == 0, -1e9) attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) context = torch.matmul(attn, V) # [batch, h, q_len, d_k] context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model) output = self.W_o(context) return output, attn
三种使用场景:
Self-Attention:Q、K、V 均来自同一序列(Encoder 内部)
Masked Self-Attention:Q、K、V 来自同一序列,但加因果掩码(Decoder 第一子层)
Cross-Attention:Q 来自 Decoder,K 和 V 来自 Encoder 输出(Decoder 第二子层)
3.3 位置前馈网络(Position-wise Feed-Forward Network)
由两个全连接层 + ReLU 激活函数组成,等价于两个核大小为 1 的卷积,在不同位置上参数共享,但层与层之间参数不同:
论文中 ,内层维度
。
class PositionWiseFeedForward(nn.Module): def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1): super().__init__() self.fc1 = nn.Linear(d_model, d_ff) self.fc2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward(self, x): return self.fc2(self.dropout(F.relu(self.fc1(x))))
3.4 残差连接与层归一化(Add & Norm)
为避免深层网络的梯度消失和模型退化,Transformer 在每个子层后使用残差连接 + LayerNorm:
class LayerNorm(nn.Module): def __init__(self, features: int, eps: float = 1e-6): super().__init__() self.a_2 = nn.Parameter(torch.ones(features)) self.b_2 = nn.Parameter(torch.zeros(features)) self.eps = eps def forward(self, x): mean = x.mean(-1, keepdim=True) std = x.std(-1, keepdim=True) return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
四、Encoder 与 Decoder 的完整实现
4.1 编码器层(Encoder Layer)
每个 Encoder Layer 包含两个子层:多头自注意力 + 前馈网络,每个子层后接残差连接和 LayerNorm。
class EncoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout=0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads, dropout) self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, mask): # WWW.AIPUZI.CN子层1: 自注意力 + 残差 + 归一化 attn_out, _ = self.self_attn(x, x, x, mask) x = self.norm1(x + self.dropout(attn_out)) # 子层2: 前馈网络 + 残差 + 归一化 ff_out = self.feed_forward(x) x = self.norm2(x + self.dropout(ff_out)) return x class Encoder(nn.Module): def __init__(self, layer, N): super().__init__() self.layers = nn.ModuleList([layer for _ in range(N)]) self.norm = LayerNorm(layer.d_model if hasattr(layer, 'd_model') else 512) def forward(self, x, mask): for layer in self.layers: x = layer(x, mask) return self.norm(x)
4.2 解码器层(Decoder Layer)
Decoder Layer 在 Encoder Layer 两个子层基础上,增加了一个带掩码的多头注意力层(用于防止看到未来信息),共三个子层。
class DecoderLayer(nn.Module): def __init__(self, d_model, num_heads, d_ff, dropout=0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads, dropout) self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout) self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.norm3 = LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, memory, src_mask, tgt_mask): # WWW.AIPUZI.CN子层1: 带掩码的自注意力 attn_out, _ = self.self_attn(x, x, x, tgt_mask) x = self.norm1(x + self.dropout(attn_out)) # 子层2: 交叉注意力 (Q来自Decoder, K,V来自Encoder的memory) attn_out, _ = self.cross_attn(x, memory, memory, src_mask) x = self.norm2(x + self.dropout(attn_out)) # 子层3: 前馈网络 ff_out = self.feed_forward(x) x = self.norm3(x + self.dropout(ff_out)) return x class Decoder(nn.Module): def __init__(self, layer, N): super().__init__() self.layers = nn.ModuleList([layer for _ in range(N)]) self.norm = LayerNorm(layer.d_model if hasattr(layer, 'd_model') else 512) def forward(self, x, memory, src_mask, tgt_mask): for layer in self.layers: x = layer(x, memory, src_mask, tgt_mask) return self.norm(x)
五、完整 Transformer 模型封装
将 Encoder、Decoder、Embedding 和生成器组合为完整的序列到序列模型:
| 参数 | 说明 | 默认值 |
|---|---|---|
| d_model | 模型隐藏维度 | 512 |
| nhead | 注意力头数 | 8 |
| num_encoder_layers | Encoder 层数 | 6 |
| num_decoder_layers | Decoder 层数 | 6 |
| d_ff | 前馈网络内层维度 | 2048 |
| dropout | Dropout 比率 | 0.1 |
| vocab_size | 词表大小 | 任务相关 |
class Generator(nn.Module):
"""线性层 + Softmax,将解码器输出映射为词表概率"""
def __init__(self, d_model, vocab):
super().__init__()
self.proj = nn.Linear(d_model, vocab)
def forward(self, x):
return F.log_softmax(self.proj(x), dim=-1)
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
num_decoder_layers=6, d_ff=2048, dropout=0.1, vocab_size=32000):
super().__init__()
self.d_model = d_model
encoder_layer = EncoderLayer(d_model, nhead, d_ff, dropout)
decoder_layer = DecoderLayer(d_model, nhead, d_ff, dropout)
self.encoder = Encoder(encoder_layer, num_encoder_layers)
self.decoder = Decoder(decoder_layer, num_decoder_layers)
self.src_embed = nn.Embedding(vocab_size, d_model)
self.tgt_embed = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
self.generator = Generator(d_model, vocab_size)
# 权重共享:Embedding 与 Generator 的线性层共享参数(乘以 sqrt(d_model))
self.generator.proj.weight = self.tgt_embed.weight
def forward(self, src, tgt, src_mask=None, tgt_mask=None):
src_embedded = self.pos_encoder(self.src_embed(src) * math.sqrt(self.d_model))
tgt_embedded = self.pos_encoder(self.tgt_embed(tgt) * math.sqrt(self.d_model))
memory = self.encoder(src_embedded, src_mask)
output = self.decoder(tgt_embedded, memory, src_mask, tgt_mask)
return self.generator(output)
def generate_square_subsequent_mask(self, sz):
"""生成因果掩码,防止解码时看到未来 token"""
mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
return (mask == 0).unsqueeze(0).unsqueeze(0) # [1, 1, sz, sz]关键设计:src_embed 与 generator.proj 共享权重矩阵,这是 Transformer 的标准做法,可大幅减少参数量并提升训练稳定性。
六、训练全流程
6.1 训练步骤总览
| 步骤 | 操作 | PyTorch API |
|---|---|---|
| 1 | Tokenization(词元化) | tokenizer |
| 2 | Embedding + 位置编码 | nn.Embedding + PositionalEncoding |
| 3 | 前向传播 | model(src, tgt, src_mask, tgt_mask) |
| 4 | 计算损失 | F.nll_loss 或 F.cross_entropy |
| 5 | 反向传播 | loss.backward() |
| 6 | 梯度裁剪 | torch.nn.utils.clip_grad_norm_ |
| 7 | 参数更新 | optimizer.step() |
6.2 训练循环核心代码
# 超参数
d_model = 512
nhead = 8
num_layers = 6
d_ff = 2048
dropout = 0.1
vocab_size = 32000
epochs = 10
batch_size = 64
model = Transformer(d_model, nhead, num_layers, num_layers, d_ff, dropout, vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=0) # 0 为 padding token
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
for epoch in range(epochs):
model.train()
total_loss = 0
for src, tgt in dataloader:
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
tgt_mask = model.generate_square_subsequent_mask(tgt.size(1)).to(src.device)
optimizer.zero_grad()
output = model(src, tgt[:, :-1], src_mask, tgt_mask)
loss = criterion(output.reshape(-1, vocab_size), tgt[:, 1:].reshape(-1))
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")训练要点:解码器输入使用 tgt[:, :-1](去掉最后一个 token),目标使用 tgt[:, 1:](去掉第一个 token),这就是 Teacher Forcing 策略。梯度裁剪阈值设为 1.0 是 Transformer 训练的标准配置。
七、MiniGPT 的推理(生成)过程
训练完成后,模型以自回归方式逐 token 生成输出:每一步将已生成的 token 拼接回输入,预测下一个 token,直到生成结束符。
@torch.no_grad() def greedy_decode(model, src, max_len=50, start_token=1, end_token=2): model.eval() memory = model.encoder(model.pos_encoder(model.src_embed(src) * math.sqrt(model.d_model)), None) tgt = torch.full((1, 1), start_token, dtype=torch.long, device=src.device) for _ in range(max_len): tgt_mask = model.generate_square_subsequent_mask(tgt.size(1)).to(src.device) out = model.decoder(model.pos_encoder(model.tgt_embed(tgt) * math.sqrt(model.d_model)), memory, None, tgt_mask) prob = model.generator(out[:, -1, :]) next_token = prob.argmax(dim=-1) tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1) if next_token.item() == end_token: break return tgt
八、关键参数对照表
| 参数 | 含义 | MiniGPT 推荐值 | 论文原始值 |
|---|---|---|---|
| d_model | 隐藏维度 | 256~512 | 512 |
| nhead | 注意力头数 | 4~8 | 8 |
| num_layers | Encoder/Decoder 层数 | 2~4 | 6 |
| d_ff | 前馈内层维度 | 1024~2048 | 2048 |
| dropout | 丢弃率 | 0.1 | 0.1 |
| max_len | 最大序列长度 | 128~512 | 5000 |
MiniGPT 相比原始 Transformer 大幅缩减了层数和维度,这是它能在消费级显卡上运行的关键。
九、总结
从位置编码到多头注意力,从残差连接到自回归生成,本文完整复现了 Transformer 的每一个核心组件。所有代码均基于 PyTorch 原生 API 实现,不依赖任何高级封装,确保你能看清每一层计算的来龙去脉。
Transformer 的强大不在于复杂度,而在于"注意力即一切"这一简洁思想——它用一种统一的机制同时解决了长距离依赖、并行计算和可解释性三大难题。 手写一遍,你就真正拥有了它。
版权及免责申明:本文由@AI铺子原创发布。该文章观点仅代表作者本人,不代表本站立场。本站不承担任何相关法律责任。
如若转载,请注明出处:https://www.aipuzi.cn/ai-tutorial/from-scratch-minigpt-pytorch-transformer.html

