当前位置: 首页 > news >正文

Python Day44

Task:
1.预训练的概念
2.常见的分类预训练模型
3.图像预训练模型的发展史
4.预训练的策略
5.预训练代码实战:resnet18


1. 预训练的概念

预训练(Pre-training)是指在大规模数据集上,先训练模型以学习通用的特征表示,然后将其用于特定任务的微调。这种方法可以显著提高模型在目标任务上的性能,减少训练时间和所需数据量。

核心思想:

  • 在大规模、通用的数据(如ImageNet)上训练模型,学习丰富的特征表示。
  • 将预训练模型应用于任务特定的细调(fine-tuning),使模型适应目标任务。

优势:

  • 提升模型性能
  • 缩短训练时间
  • 需要较少的标注数据
  • 提供良好的特征初始化

2. 常见的分类预训练模型

常见的分类预训练模型主要包括:

模型名称提出年份特色与应用
AlexNet2012标志深度学习重返计算机视觉的起点
VGG(VGG16/19)2014简洁结构,深层网络,广泛用于特征提取
ResNet(Residual Network)2015引入残差连接,解决深层网络退化问题
Inception(GoogLeNet)2014多尺度特征提取,复杂模块设计
DenseNet2017密集连接,加深网络而不增加参数
MobileNet2017轻量级模型,适合移动端应用
EfficientNet2019根据模型宽度、深度和分辨率优化设计

这些模型在ImageNet等大规模数据集上预训练,成为计算机视觉各种任务的基础。


3. 图像预训练模型的发展史

  1. AlexNet (2012)
    首次使用深度卷积神经网络大规模应用于ImageNet,显著提升分类效果。

  2. VGG系列 (2014)
    简单堆叠卷积和池化层,深度逐步增加,提高表现。

  3. GoogLeNet/Inception (2014)
    引入Inception模块,进行多尺度特征提取,有效提升效率。

  4. ResNet (2015)
    通过残差连接解决深层网络的退化问题,使网络深度大幅提升(如ResNet-50,ResNet-101等)。

  5. DenseNet (2017)
    特色是密集连接,增强特征传播,改善梯度流。

  6. MobileNet, EfficientNet (2017-2019)
    追求轻量级和高效率,适应移动端和资源有限场景。

总的趋势:

  • 从浅层逐步向深层网络发展
  • 引入残差、密集连接等结构解决深层网络训练难题
  • 注重模型效率与性能平衡

4. 预训练的策略

常用的预训练策略包括:

1. 直接使用预训练模型进行微调(Fine-tuning)

  • 加载预训练权重
  • 替换最后的分类层以适应新任务(如类别数不同)
  • 选择性冻结部分层(如只训练最后几层)或全部训练

2. 特征提取(Feature Extraction)

  • 使用预训练模型的固定特征提取器,从中提取特征
  • 在这些特征基础上训练简单的分类器(如SVM或线性层)

3. 逐层逐步微调(Layer-wise Fine-tuning)

  • 先冻结底层特征层,只训练高层
  • 再逐步解冻低层,进行全层微调

4. 迁移学习(Transfer Learning)

  • 利用预训练模型迁移到相似领域任务中
  • 通过微调适应不同数据分布和任务需求

5. 预训练代码实战:ResNet18

以下是基于PyTorch框架的ResNet18预训练模型加载和微调的示例代码:

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader# 1. 加载预训练ResNet18模型
model = models.resnet18(pretrained=True)# 2. 替换分类层以适应新任务(比如有10个类别)
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)# 3. 冻结前面层,只训练最后的全连接层(可选)
for param in model.parameters():param.requires_grad = False  # 冻结所有参数# 只训练最后一层参数
for param in model.fc.parameters():param.requires_grad = True# 4. 定义数据变换
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
])# 5. 加载数据集
train_dataset = ImageFolder('path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataset = ImageFolder('path_to_val_data', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)# 6. 设置优化器(只优化可训练参数)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()# 7. 训练环节
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)for epoch in range(10):model.train()total_loss = 0for images, labels in train_loader:images = images.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")# 8. 评估
model.eval()
correct = 0
total = 0
with torch.no_grad():for images, labels in val_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()total += labels.size(0)
print(f'Validation Accuracy: {100 * correct / total:.2f}%')

总结

  • 预训练是一种利用大规模数据学习通用特征,从而在目标任务中快速获得优秀表现的技术。
  • 常用的分类预训练模型包括ResNet、VGG、Inception等,发展经历了从浅层到深层、从视觉到效率的不断演变。
  • 预训练策略多样,适应不同场景,微调与特征提取是常用手段。
  • 实战中,可以利用PyTorch提供的模型接口快速加载预训练模型,并进行微调以满足具体需求。
http://www.lqws.cn/news/166969.html

相关文章:

  • 设计模式(代理设计模式)
  • NLP学习路线图(二十六):自注意力机制
  • Wireshark使用教程(含安装包和安装教程)
  • JS深入学习 — 循环、函数、数组、字符串、Date对象,Math对象
  • 哈希算法实战全景:安全加密到分布式系统的“核心引擎”
  • 深入理解Java多态性:原理、实现与应用实例
  • 【Linux手册】冯诺依曼体系结构
  • day34- 系统编程之 网络编程(TCP)
  • ObjectMapper 在 Spring 统一响应处理中的作用详解
  • AI Agent 项目 SUNA 部署环境搭建 - 基于 MSYS2 的 Poetry+Python3.11 虚拟环境
  • 【操作系统】死锁
  • JSON Web Token (JWT) 详解:由来、原理与应用实践
  • 在 Ubuntu 24.04 LTS 上安装 Jenkins 并配置全局工具(Git、JDK、Maven)
  • LeetCode-70. 爬楼梯
  • 八、Python模块、包
  • QT中使用libcurl库实现到ftp服务器的上传和下载
  • C语言 — 编译和链接
  • 体制内 AI写作:推荐材料星 AI文章修改润色
  • 11. vue pinia 和react redux、jotai对比
  • 互联网大厂Java求职面试:AI与大模型技术在企业知识库中的深度应用
  • minimatch 详解:功能、语法与应用场景
  • uniapp+vue3实现CK通信协议(基于jjc-tcpTools)
  • IDEA 包分层显示设置
  • BT Panel密码修改
  • 【Redis】类型补充
  • ROS2--导航仿真
  • sumatraPDF设置深色界面
  • YOLOv11 | 注意力机制篇 | 可变形大核注意力Deformable-LKA与C2PSA机制
  • JTAG与SWD的功能辩解有和相关
  • Mysql主从复制原理分析