AI大模型基础:预训练与微调(迁移学习与微调策略)

猿享天开 发布日期:
11

AI大模型基础:预训练与微调(迁移学习与微调策略)

预训练与微调是现代AI大模型(如BERT、GPT、ViT)的核心技术,基于迁移学习范式,通过在大规模数据集上预训练模型并在特定任务上微调,显著提升性能和效率。本文将深入讲解预训练与微调的原理、实现方法及在实际场景中的应用,适合对AI大模型感兴趣的读者。

AI

一、预训练与微调概述

1.1 定义与目标

  • 预训练:在大规模、无标注或弱标注数据上训练模型,学习通用特征(如语言模式、视觉特征),构建强大的初始参数。

  • 微调:在特定下游任务(如文本分类、图像分割)上调整预训练模型参数,适配任务需求。

  • 目标

    • 泛化能力:预训练捕捉通用知识,微调适配特定场景。

    • 高效性:减少从头训练成本,适合小数据集。

    • 高性能:提升任务精度,尤其在数据有限的领域(如医学影像)。

1.2 迁移学习背景

  • 迁移学习:将从一个任务(源任务)学到的知识应用于另一个任务(目标任务)。预训练与微调是迁移学习的典型实现。

  • 发展历程

    • 早期:特征提取(如SIFT、HOG)+简单分类器(如SVM)。

    • 深度学习时代:CNN(如VGG、ResNet)预训练于ImageNet,微调下游任务。

    • Transformer时代:BERT、GPT、ViT通过自监督预训练,革新NLP和CV。

  • 优势

    • 利用大规模数据(如Wikipedia、ImageNet)学习通用表示。

    • 小数据集也能实现高性能,适合医学影像等场景。

1.3 重要性

  • NLP:预训练(如BERT的MLM)捕获语义,微调适配情感分析、问答等任务。

  • CV:ViT预训练于ImageNet,微调用于肿瘤检测、器官分割。

  • 医学影像:数据稀缺,预训练模型(如ViT)通过迁移学习显著提升分类精度。

1.4 挑战

  • 计算成本:预训练需大量GPU/TPU资源。

  • 过拟合风险:微调时小数据集可能导致过拟合。

  • 任务适配:不同任务需不同微调策略(如全参数 vs. 部分微调)。

  • 可解释性:微调后模型行为难以解释,医学领域需谨慎。

二、预训练原理与实现

2.1 原理

预训练通过自监督或弱监督任务,在大规模数据上学习通用表示,无需特定任务的标注数据。

核心机制
  • 自监督学习(SSL)

    • NLP:掩码语言模型(MLM,如BERT)、自回归语言建模(如GPT)。

    • CV:图像分类(如ViT)、对比学习(如SimCLR)。

  • 数据规模:需大规模语料(如Wikipedia、BooksCorpus)或图像数据集(如ImageNet、JFT-300M)。

  • 模型结构

    • Transformer:基于自注意力机制,捕获全局依赖。

    • 多层设计:如BERT的12/24层,ViT的Patch嵌入+编码器。

预训练任务
  1. 掩码语言模型(MLM)

    • 随机掩盖输入词(15%),预测被掩盖词,学习双向上下文(BERT)。

    • 公式:最大化条件概率 AI大模型基础:预训练与微调(迁移学习与微调策略)

  2. 自回归语言建模

    • 预测下一个词,基于前文(GPT)。

    • 公式:最大化AI大模型基础:预训练与微调(迁移学习与微调策略)

  3. 图像分类

    • ViT在ImageNet上预测类别,学习视觉特征。

    • 公式:最小化交叉熵损失AI大模型基础:预训练与微调(迁移学习与微调策略)

  4. 对比学习

    • 增强图像(如旋转、裁剪),使相同图像的表示靠近(SimCLR)。

数学基础
  • 自注意力

    AI大模型基础:预训练与微调(迁移学习与微调策略)

    其中 Q,V, K 为查询、键、值向量,AI大模型基础:预训练与微调(迁移学习与微调策略)为键维度。

  • 损失函数

    • MLM:交叉熵损失,预测掩盖词。

    • 自回归:最大化序列似然。

    • 分类:交叉熵或对比损失。

优缺点
  • 优点:学习通用表示,减少下游任务数据需求。

  • 缺点:计算成本高,需大规模数据支持。

  • 适用场景:NLP(文本理解/生成)、CV(图像分类/分割)。

2.2 实现示例(Python)

以下以BERT的MLM预训练为例,使用Hugging Face模拟小规模预训练:

from transformers import BertTokenizer, BertForMaskedLM
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments
from datasets import load_dataset
import torch

# 加载数据(示例:WikiText)
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
texts = [text for text in dataset['text'] if len(text) > 0][:1000]  # 取前1000条

# 加载分词器和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# 预处理数据
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['text'])
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

# 设置训练参数
training_args = TrainingArguments(
    output_dir='./pretrain_results',
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=100,
    save_steps=500
)

# 训练模型
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)
trainer.train()

代码注释

  • load_dataset:加载WikiText数据集,模拟大规模语料。

  • BertTokenizer:将文本转换为输入ID和注意力掩码。

  • BertForMaskedLM:BERT模型,带MLM头。

  • DataCollatorForLanguageModeling:动态掩盖15%输入词,生成MLM任务。

  • Trainer:Hugging Face训练接口,简化预训练流程。

  • 注意:实际预训练需更大数据集(如Wikipedia)和更多计算资源。

三、微调策略

3.1 原理

微调通过调整预训练模型参数,适配特定下游任务,分为以下策略:

1. 全参数微调
  • 调整模型所有参数,适合数据充足场景。

  • 优点:充分利用预训练知识,性能最佳。

  • 缺点:计算成本高,易过拟合(小数据集)。

2. 部分微调
  • 冻结部分层(如低层编码器),仅微调顶层或分类头。

  • 优点:降低计算成本,减少过拟合风险。

  • 缺点:可能损失部分预训练知识。

3. 参数高效微调(PEFT)
  • 方法

    • LoRA(Low-Rank Adaptation):在权重矩阵上添加低秩更新,调整少量参数。

    • Prompt Tuning:添加可训练的提示向量,冻结模型参数。

    • Adapter:在每层插入小型适配器模块。

  • 公式(LoRA)

    AI大模型基础:预训练与微调(迁移学习与微调策略)

    其中 AI大模型基础:预训练与微调(迁移学习与微调策略)为预训练权重,AI大模型基础:预训练与微调(迁移学习与微调策略)为低秩更新,AI大模型基础:预训练与微调(迁移学习与微调策略)为小矩阵。

  • 优点:参数量少(<1%),适合资源受限场景。

  • 缺点:性能略低于全参数微调。

优缺点
  • 优点:灵活适配任务,降低训练成本。

  • 缺点:需根据任务选择策略,调试复杂。

  • 适用场景:小数据集(如医学影像)、资源受限环境。

3.2 实现示例(Python)

以下以BERT全参数微调(文本分类)和LoRA微调为例:

from transformers import BertTokenizer, BertForSequenceClassification
from peft import LoraConfig, get_peft_model
from transformers import Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset

# 自定义数据集
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# 数据准备(示例:情感分析)
texts = ["I love this movie!", "This movie is terrible."]
labels = [1, 0]  # 1: 正向,0: 负向
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = TextDataset(texts, labels, tokenizer)

# 全参数微调
model_full = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
training_args_full = TrainingArguments(
    output_dir='./full_finetune_results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    logging_steps=10,
    save_steps=100
)
trainer_full = Trainer(model=model_full, args=training_args_full, train_dataset=dataset)
trainer_full.train()

# LoRA微调
model_lora = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"])
model_lora = get_peft_model(model_lora, lora_config)
training_args_lora = TrainingArguments(
    output_dir='./lora_finetune_results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    logging_steps=10,
    save_steps=100
)
trainer_lora = Trainer(model=model_lora, args=training_args_lora, train_dataset=dataset)
trainer_lora.train()

# 推理(以全参数微调为例)
model_full.eval()
text = "This is a great film!"
inputs = tokenizer(text, return_tensors='pt', max_length=128, padding=True, truncation=True)
outputs = model_full(**inputs)
predictions = torch.argmax(outputs.logits, dim=-1)
print("全参数微调预测:", "正向" if predictions.item() == 1 else "负向")

代码注释

  • TextDataset:处理文本和标签,适配BERT输入格式。

  • BertForSequenceClassification:预训练BERT,添加分类头(2类)。

  • LoraConfig:配置LoRA参数,r=8控制低秩矩阵维度,target_modules指定微调自注意力层。

  • get_peft_model:应用LoRA,冻结大部分参数,仅训练低秩更新。

  • Trainer:Hugging Face接口,简化微调流程。

  • 注意:LoRA显著减少参数量(约0.1%),适合小数据集或低资源场景。

四、应用案例

4.1 NLP:文本分类

  • 任务:情感分析(正向/负向),如电影评论分类。

  • 预训练模型:BERT,学习通用语义表示。

  • 微调策略

    • 全参数微调:适配大数据集(如IMDB)。

    • LoRA:适合小数据集(如医疗报告情感分析)。

  • 代码:见3.2实现,微调BERT进行二分类。

4.2 CV:图像分类

  • 任务:肿瘤检测(如乳腺癌X光片分类)。

  • 预训练模型:ViT,预训练于ImageNet。

  • 微调策略

    • 全参数微调:适配大型医学影像数据集。

    • Adapter:添加小型适配器,适合小数据集。

  • 实现示例

from transformers import ViTImageProcessor, ViTForImageClassification
from peft import LoraConfig, get_peft_model
from PIL import Image
import torch

# 加载数据(示例:医学影像)
image = Image.open("tumor_image.jpg").convert('RGB')
labels = [1]  # 1: 恶性,0: 良性

# 预处理
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
inputs = processor(images=image, return_tensors='pt')

# LoRA微调
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224', num_labels=2)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["query", "value"])
model = get_peft_model(model, lora_config)

# 训练(假设数据集)
# 类似TextDataset,省略训练代码
# 推理
model.eval()
outputs = model(**inputs)
predictions = torch.argmax(outputs.logits, dim=-1)
print("预测结果:", "恶性" if predictions.item() == 1 else "良性")

代码注释

  • ViTImageProcessor:处理图像,分块并归一化。

  • ViTForImageClassification:ViT模型,带分类头。

  • LoraConfig:应用LoRA微调,减少参数量。

4.3 医学影像领域

  • 任务:肺癌CT分类、脑部MRI分割。

  • 挑战:样本少、类别不平衡、误诊成本高。

  • 解决方案

    • 预训练:ViT在ImageNet或医学影像数据集(如LIDC-IDRI)上预训练。

    • 微调:使用LoRA或Adapter,适配小数据集,关注召回率(减少漏诊)。

  • 优势:迁移学习利用通用视觉特征,提升小数据集性能。

五、流程图与图表

5.1 预训练与微调流程图

以下是流程图:

流程图

说明

  • A(数据准备):大规模文本(Wikipedia)或图像(ImageNet)。

  • B(预训练):MLM(BERT)、自回归(GPT)、分类(ViT)。

  • C(预训练模型):生成通用表示的模型。

  • D(微调):全参数、LoRA或Adapter适配任务。

  • E(下游任务):分类(情感分析、肿瘤检测)、生成(对话)、分割(器官)。

  • F(输出):任务特定结果,如类别标签或生成文本。

5.2 图表:微调策略性能对比

以下为全参数微调与LoRA在分类任务上的性能对比折线图(假设数据)。

在这里插入图片描述

说明

  • 图表类型:折线图,比较全参数微调与LoRA的准确率。

  • X轴:数据集大小(100、1000、10000样本)。

  • Y轴:准确率,范围0.7-1.0。

  • 数据:假设数据,显示全参数微调在大数据集上更优,LoRA在小数据集上接近。

  • 医学意义:LoRA适合医学影像小数据集,平衡性能与资源。

六、总结与展望

6.1 总结

  • 预训练:通过自监督任务(如MLM、自回归)学习通用表示,降低下游任务数据需求。

  • 微调

    • 全参数微调:适合大数据集,性能最佳。

    • 部分微调/PEFT(如LoRA):适合小数据集或低资源场景。

  • 应用

    • NLP:情感分析、医学报告分类。

    • CV:肿瘤检测、器官分割。

    • capped at 100 samples for demonstration; actual pretraining requires much larger datasets.

6.2 展望

  • 高效预训练:探索更高效的自监督任务(如MAE for ViT),减少数据需求。

  • 自动化微调:开发自动化微调框架,动态选择最佳策略(如AutoML)。

  • 多模态迁移:结合文本和图像预训练,适配医学影像+报告任务。

  • 可解释性:结合SHAP或注意力可视化,解释微调后模型行为。

预训练与微调作为迁移学习的典型实现,已经在NLP和CV领域取得了显著成果。预训练模型通过大规模数据学习通用表示,而微调则通过调整参数适配具体任务,显著提升了模型的性能和泛化能力。无论是全参数微调还是参数高效微调(如LoRA),都能在不同场景下提供高效的解决方案。尤其在医学影像等数据稀缺的领域,预训练与微调技术展现出巨大的潜力。未来,随着更高效的预训练任务和自动化微调策略的发展,这类技术将在多模态、可解释性等方面进一步拓展其应用边界。

打赏
THE END
作者头像
AI铺子
关注ai行业发展,专注ai工具推荐