从零实现一个 MiniGPT:用 PyTorch 手写 Transformer 架构全过程

原创 发布日期:
65

一、为什么要从零手写 Transformer

在所有深度学习架构中,Transformer 是当之无愧的基石。从 BERT、GPT 系列到 LLaMA、Qwen,几乎所有前沿大模型都建立在 Transformer(或其变体)之上。PyTorch 虽提供了 nn.Transformer 模块可直接调用,但若不亲手实现一遍,你永远无法真正理解它为何能取代 RNN 成为序列建模的霸主。

MiniGPT 的本质,就是一个最小化的 Encoder-Decoder Transformer。它沿用了《Attention Is All You Need》论文中的经典结构:编码器将输入序列映射为连续表示,解码器自回归地逐 token 生成输出。本文将用纯 PyTorch,从张量操作到完整训练循环,一行一行地构建它。

从零实现一个 MiniGPT:用 PyTorch 手写 Transformer 架构全过程

二、Transformer 架构全景拆解

在动手写代码之前,必须先看清整体结构。Transformer 由 Encoder、Decoder 和连接二者的 Softmax 分类器三部分组成,核心组件如下表所示:

组件 作用 所在位置
Input Embedding + Positional Encoding 将 Token 映射为向量并注入位置信息 Encoder/Decoder 输入端
Multi-Head Self-Attention 计算序列内部每对位置的关系 Encoder 每层 + Decoder 第一子层
Masked Multi-Head Self-Attention 防止解码时"偷看"未来信息 Decoder 第一子层(带因果掩码)
Multi-Head Cross-Attention Decoder 关注 Encoder 输出 Decoder 第二子层
Position-wise Feed-Forward Network 两层线性 + ReLU,逐位置处理 Encoder/Decoder 每层
Add & Norm(残差连接 + LayerNorm) 稳定训练、缓解梯度消失 每个子层之后

Encoder 由 N=6 个相同的层堆叠而成,每层包含两个子层:多头自注意力 + 前馈网络。Decoder 同样由 N=6 层堆叠,每层包含三个子层:带掩码的多头自注意力 + 交叉注意力 + 前馈网络。

三、核心模块的 PyTorch 实现

3.1 位置编码(Positional Encoding)

Transformer 没有循环结构,必须显式注入位置信息。论文采用不同频率的正弦和余弦函数生成位置编码:

PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \quad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
  def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    self.register_buffer('pe', pe.unsqueeze(0)) # [1, max_len, d_model]

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

关键点:register_buffer 注册的张量不参与梯度计算,但会随模型保存/加载,这是位置编码的标准做法。

3.2 多头注意力(Multi-Head Attention)

这是 Transformer 的灵魂。多头注意力将输入拆分成多个"头",每个头独立计算缩放点积注意力(Scaled Dot-Product Attention),最后合并输出。 公式如下:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O

\text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

论文采用 h=8 个并行注意力头,每个头的维度 d_k = d_v = d_{model}/h = 64,总计算量与单头全维度注意力相当。

class MultiHeadAttention(nn.Module):
  def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
    super().__init__()
    assert d_model % num_heads == 0
    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)
    self.W_o = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)
    self.scale = math.sqrt(self.d_k)

  def forward(self, query, key, value, attn_mask=None):
    batch_size = query.size(0)

    # 线性变换 + 分头: [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
    Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    # 缩放点积注意力
    scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale # [batch, h, q_len, k_len]
    if attn_mask is not None:
      scores = scores.masked_fill(attn_mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)
    attn = self.dropout(attn)

    context = torch.matmul(attn, V) # [batch, h, q_len, d_k]
    context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
    output = self.W_o(context)
    return output, attn

三种使用场景

  • Self-Attention:Q、K、V 均来自同一序列(Encoder 内部)

  • Masked Self-Attention:Q、K、V 来自同一序列,但加因果掩码(Decoder 第一子层)

  • Cross-Attention:Q 来自 Decoder,K 和 V 来自 Encoder 输出(Decoder 第二子层)

3.3 位置前馈网络(Position-wise Feed-Forward Network)

由两个全连接层 + ReLU 激活函数组成,等价于两个核大小为 1 的卷积,在不同位置上参数共享,但层与层之间参数不同:

\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2

论文中 d_{model} = 512,内层维度 d_{ff} = 2048

class PositionWiseFeedForward(nn.Module):
  def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
    super().__init__()
    self.fc1 = nn.Linear(d_model, d_ff)
    self.fc2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    return self.fc2(self.dropout(F.relu(self.fc1(x))))

3.4 残差连接与层归一化(Add & Norm)

为避免深层网络的梯度消失和模型退化,Transformer 在每个子层后使用残差连接 + LayerNorm

\text{Output} = \text{LayerNorm}(x + \text{Sublayer}(x))

class LayerNorm(nn.Module):
  def __init__(self, features: int, eps: float = 1e-6):
    super().__init__()
    self.a_2 = nn.Parameter(torch.ones(features))
    self.b_2 = nn.Parameter(torch.zeros(features))
    self.eps = eps

  def forward(self, x):
    mean = x.mean(-1, keepdim=True)
    std = x.std(-1, keepdim=True)
    return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

四、Encoder 与 Decoder 的完整实现

4.1 编码器层(Encoder Layer)

每个 Encoder Layer 包含两个子层:多头自注意力 + 前馈网络,每个子层后接残差连接和 LayerNorm。

class EncoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super().__init__()
    self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
    self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
    self.norm1 = LayerNorm(d_model)
    self.norm2 = LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    # WWW.AIPUZI.CN子层1: 自注意力 + 残差 + 归一化
    attn_out, _ = self.self_attn(x, x, x, mask)
    x = self.norm1(x + self.dropout(attn_out))
    # 子层2: 前馈网络 + 残差 + 归一化
    ff_out = self.feed_forward(x)
    x = self.norm2(x + self.dropout(ff_out))
    return x

class Encoder(nn.Module):
  def __init__(self, layer, N):
    super().__init__()
    self.layers = nn.ModuleList([layer for _ in range(N)])
    self.norm = LayerNorm(layer.d_model if hasattr(layer, 'd_model') else 512)

  def forward(self, x, mask):
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)

4.2 解码器层(Decoder Layer)

Decoder Layer 在 Encoder Layer 两个子层基础上,增加了一个带掩码的多头注意力层(用于防止看到未来信息),共三个子层。

class DecoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
    super().__init__()
    self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
    self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
    self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)
    self.norm1 = LayerNorm(d_model)
    self.norm2 = LayerNorm(d_model)
    self.norm3 = LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, memory, src_mask, tgt_mask):
    # WWW.AIPUZI.CN子层1: 带掩码的自注意力
    attn_out, _ = self.self_attn(x, x, x, tgt_mask)
    x = self.norm1(x + self.dropout(attn_out))
    # 子层2: 交叉注意力 (Q来自Decoder, K,V来自Encoder的memory)
    attn_out, _ = self.cross_attn(x, memory, memory, src_mask)
    x = self.norm2(x + self.dropout(attn_out))
    # 子层3: 前馈网络
    ff_out = self.feed_forward(x)
    x = self.norm3(x + self.dropout(ff_out))
    return x

class Decoder(nn.Module):
  def __init__(self, layer, N):
    super().__init__()
    self.layers = nn.ModuleList([layer for _ in range(N)])
    self.norm = LayerNorm(layer.d_model if hasattr(layer, 'd_model') else 512)

  def forward(self, x, memory, src_mask, tgt_mask):
    for layer in self.layers:
      x = layer(x, memory, src_mask, tgt_mask)
    return self.norm(x)

五、完整 Transformer 模型封装

将 Encoder、Decoder、Embedding 和生成器组合为完整的序列到序列模型:

参数 说明 默认值
d_model 模型隐藏维度 512
nhead 注意力头数 8
num_encoder_layers Encoder 层数 6
num_decoder_layers Decoder 层数 6
d_ff 前馈网络内层维度 2048
dropout Dropout 比率 0.1
vocab_size 词表大小 任务相关
class Generator(nn.Module):
  """线性层 + Softmax,将解码器输出映射为词表概率"""
  def __init__(self, d_model, vocab):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab)

  def forward(self, x):
    return F.log_softmax(self.proj(x), dim=-1)

class Transformer(nn.Module):
  def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
         num_decoder_layers=6, d_ff=2048, dropout=0.1, vocab_size=32000):
    super().__init__()
    self.d_model = d_model

    encoder_layer = EncoderLayer(d_model, nhead, d_ff, dropout)
    decoder_layer = DecoderLayer(d_model, nhead, d_ff, dropout)

    self.encoder = Encoder(encoder_layer, num_encoder_layers)
    self.decoder = Decoder(decoder_layer, num_decoder_layers)
    self.src_embed = nn.Embedding(vocab_size, d_model)
    self.tgt_embed = nn.Embedding(vocab_size, d_model)
    self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
    self.generator = Generator(d_model, vocab_size)

    # 权重共享:Embedding 与 Generator 的线性层共享参数(乘以 sqrt(d_model))
    self.generator.proj.weight = self.tgt_embed.weight

  def forward(self, src, tgt, src_mask=None, tgt_mask=None):
    src_embedded = self.pos_encoder(self.src_embed(src) * math.sqrt(self.d_model))
    tgt_embedded = self.pos_encoder(self.tgt_embed(tgt) * math.sqrt(self.d_model))

    memory = self.encoder(src_embedded, src_mask)
    output = self.decoder(tgt_embedded, memory, src_mask, tgt_mask)
    return self.generator(output)

  def generate_square_subsequent_mask(self, sz):
    """生成因果掩码,防止解码时看到未来 token"""
    mask = torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
    return (mask == 0).unsqueeze(0).unsqueeze(0) # [1, 1, sz, sz]

关键设计:src_embedgenerator.proj 共享权重矩阵,这是 Transformer 的标准做法,可大幅减少参数量并提升训练稳定性。

六、训练全流程

6.1 训练步骤总览

步骤 操作 PyTorch API
1 Tokenization(词元化) tokenizer
2 Embedding + 位置编码nn.Embedding + PositionalEncoding
3 前向传播model(src, tgt, src_mask, tgt_mask)
4 计算损失F.nll_lossF.cross_entropy
5 反向传播loss.backward()
6 梯度裁剪torch.nn.utils.clip_grad_norm_
7 参数更新optimizer.step()

6.2 训练循环核心代码

# 超参数
d_model = 512
nhead = 8
num_layers = 6
d_ff = 2048
dropout = 0.1
vocab_size = 32000
epochs = 10
batch_size = 64

model = Transformer(d_model, nhead, num_layers, num_layers, d_ff, dropout, vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=0) # 0 为 padding token
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

for epoch in range(epochs):
  model.train()
  total_loss = 0
  for src, tgt in dataloader:
    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
    tgt_mask = model.generate_square_subsequent_mask(tgt.size(1)).to(src.device)

    optimizer.zero_grad()
    output = model(src, tgt[:, :-1], src_mask, tgt_mask)
    loss = criterion(output.reshape(-1, vocab_size), tgt[:, 1:].reshape(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    total_loss += loss.item()

  print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

训练要点:解码器输入使用 tgt[:, :-1](去掉最后一个 token),目标使用 tgt[:, 1:](去掉第一个 token),这就是 Teacher Forcing 策略。梯度裁剪阈值设为 1.0 是 Transformer 训练的标准配置。

七、MiniGPT 的推理(生成)过程

训练完成后,模型以自回归方式逐 token 生成输出:每一步将已生成的 token 拼接回输入,预测下一个 token,直到生成结束符。

@torch.no_grad()
def greedy_decode(model, src, max_len=50, start_token=1, end_token=2):
  model.eval()
  memory = model.encoder(model.pos_encoder(model.src_embed(src) * math.sqrt(model.d_model)), None)
  tgt = torch.full((1, 1), start_token, dtype=torch.long, device=src.device)

  for _ in range(max_len):
    tgt_mask = model.generate_square_subsequent_mask(tgt.size(1)).to(src.device)
    out = model.decoder(model.pos_encoder(model.tgt_embed(tgt) * math.sqrt(model.d_model)),
              memory, None, tgt_mask)
    prob = model.generator(out[:, -1, :])
    next_token = prob.argmax(dim=-1)
    tgt = torch.cat([tgt, next_token.unsqueeze(0)], dim=1)
    if next_token.item() == end_token:
      break
  return tgt

八、关键参数对照表

参数 含义 MiniGPT 推荐值 论文原始值
d_model 隐藏维度 256~512 512
nhead 注意力头数 4~8 8
num_layers Encoder/Decoder 层数 2~4 6
d_ff 前馈内层维度 1024~2048 2048
dropout 丢弃率 0.1 0.1
max_len 最大序列长度 128~512 5000

MiniGPT 相比原始 Transformer 大幅缩减了层数和维度,这是它能在消费级显卡上运行的关键。

九、总结

从位置编码到多头注意力,从残差连接到自回归生成,本文完整复现了 Transformer 的每一个核心组件。所有代码均基于 PyTorch 原生 API 实现,不依赖任何高级封装,确保你能看清每一层计算的来龙去脉。

Transformer 的强大不在于复杂度,而在于"注意力即一切"这一简洁思想——它用一种统一的机制同时解决了长距离依赖、并行计算和可解释性三大难题。 手写一遍,你就真正拥有了它。

打赏
THE END
作者头像
AI铺子
关注ai行业发展,专注ai工具推荐