当前位置: 首页 > news >正文

Python训练营---Day44

DAY 44 预训练模型

知识点回顾:

  1. 预训练的概念
  2. 常见的分类预训练模型
  3. 图像预训练模型的发展史
  4. 预训练的策略
  5. 预训练代码实战:resnet18

作业:

  1. 尝试在cifar10对比如下其他的预训练模型,观察差异,尽可能和他人选择的不同
  2. 尝试通过ctrl进入resnet的内部,观察残差究竟是什么

选用 DenseNet121预训练模型,注意DenseNet121 模型的最后分类层名为classifier,而不是 ResNet 中的fc

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from torchvision.models import resnet18, densenet121, vgg16# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理(训练集增强,测试集标准化)
train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])# 2. 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./cifar_data',train=True,download=True,transform=train_transform
)test_dataset = datasets.CIFAR10(root='./cifar_data',train=False,transform=test_transform
)# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义DenseNet121模型
def create_densenet121(pretrained=True, num_classes=10):model = models.densenet121(pretrained=pretrained)# 修改最后一层全连接层in_features = model.classifier.in_featuresmodel.classifier = nn.Linear(in_features, num_classes) # DenseNet121 的最后一层分类器名称是classifierreturn model.to(device)# 5. 冻结/解冻模型层的函数
# 这种设计允许我们在迁移学习中保留预训练模型的特征提取部分(卷积层),只训练新添加的分类层(全连接层)。
def freeze_model(model, freeze=True):"""冻结或解冻模型的卷积层参数"""# 冻结/解冻除fc层外的所有参数for name, param in model.named_parameters():if 'classifier' not in name:    #排除名称中包含 "fc" 的参数,这些通常是全连接层的参数param.requires_grad = not freeze    #param.requires_grad是 PyTorch 中控制参数是否参与反向传播和梯度更新的标志# 打印冻结状态frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)   #统计所有requires_grad=False的参数数量total_params = sum(p.numel() for p in model.parameters())if freeze:print(f"已冻结模型卷积层参数 ({frozen_params}/{total_params} 参数)")else:print(f"已解冻模型所有参数 ({total_params}/{total_params} 参数可训练)")return model# 6. 训练函数(支持阶段式训练)
def train_with_freeze_schedule(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs, freeze_epochs=5):"""前freeze_epochs轮冻结卷积层,之后解冻所有层进行训练"""train_loss_history = []test_loss_history = []train_acc_history = []test_acc_history = []all_iter_losses = []iter_indices = []# 初始冻结卷积层if freeze_epochs > 0:model = freeze_model(model, freeze=True)for epoch in range(epochs):# 解冻控制:在指定轮次后解冻所有层if epoch == freeze_epochs:model = freeze_model(model, freeze=False)# 解冻后调整优化器(可选)optimizer.param_groups[0]['lr'] = 1e-4  # 降低学习率防止过拟合model.train()  # 设置为训练模式running_loss = 0.0correct_train = 0total_train = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()# 记录Iteration损失iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)# 统计训练指标running_loss += iter_loss_, predicted = output.max(1)total_train += target.size(0)correct_train += predicted.eq(target).sum().item()# 每100批次打印进度if (batch_idx + 1) % 100 == 0:print(f"Epoch {epoch+1}/{epochs} | Batch {batch_idx+1}/{len(train_loader)} "f"| 单Batch损失: {iter_loss:.4f}")# 计算 epoch 级指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct_train / total_train# 测试阶段model.eval()correct_test = 0total_test = 0test_loss = 0.0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_test# 记录历史数据train_loss_history.append(epoch_train_loss)test_loss_history.append(epoch_test_loss)train_acc_history.append(epoch_train_acc)test_acc_history.append(epoch_test_acc)# 更新学习率调度器if scheduler is not None:scheduler.step(epoch_test_loss)# 打印 epoch 结果print(f"Epoch {epoch+1} 完成 | 训练损失: {epoch_train_loss:.4f} "f"| 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%")# 绘制损失和准确率曲线plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc  # 返回最终测试准确率# 7. 绘制Iteration损失曲线
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7)plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('训练过程中的Iteration损失变化')plt.grid(True)plt.show()# 8. 绘制Epoch级指标曲线
def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 5))# 准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('准确率随Epoch变化')plt.legend()plt.grid(True)# 损失曲线plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('损失值随Epoch变化')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 主函数:训练模型
def main():# 参数设置epochs = 40  # 总训练轮次freeze_epochs = 5  # 冻结卷积层的轮次learning_rate = 1e-3  # 初始学习率weight_decay = 1e-4  # 权重衰减# 创建DenseNet121模型(加载预训练权重)model = create_densenet121(pretrained=True, num_classes=10)# 定义优化器和损失函数optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)criterion = nn.CrossEntropyLoss()# 定义学习率调度器scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)# 开始训练(前5轮冻结卷积层,之后解冻)final_accuracy = train_with_freeze_schedule(model=model,train_loader=train_loader,test_loader=test_loader,criterion=criterion,optimizer=optimizer,scheduler=scheduler,device=device,epochs=epochs,freeze_epochs=freeze_epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")# # 保存模型# torch.save(model.state_dict(), 'resnet18_cifar10_finetuned.pth')# print("模型已保存至: resnet18_cifar10_finetuned.pth")if __name__ == "__main__":main()

http://www.lqws.cn/news/153757.html

相关文章:

  • 捍卫低空安全!-中科固源发现无人机MavLink协议远程内存泄漏漏洞
  • VisDrone无人机视觉挑战赛观察解析2025.6.5
  • [Zynq] Zynq Linux 环境下 AXI UART Lite 使用方法详解(代码示例)
  • 免费wordpress模板下载
  • ES 学习总结一 基础内容
  • MPNet:旋转机械轻量化故障诊断模型详解python代码复现
  • electron主进程和渲染进程之间的通信
  • mysql跨库关联查询及视图创建
  • IDEA 开发PHP配置调试插件XDebug
  • 人脸识别技术应用备案材料详细解析
  • 【数据集】MODIS 8日合成1公里地表温度LST产品
  • 虎扑正式易主,迅雷完成收购会带来什么变化?
  • 理解电池的极化:极化内阻与欧姆内阻解析
  • 第一章:数据结构概述
  • uniapp运行在微信开发者工具中流程
  • 云服务器Xshell登录拒绝访问排查
  • std::conditional_t一个用法
  • HikariCP数据库连接池原理解析
  • 智能照明系统:具备认知能力的“光神经网络”
  • Python-内置函数
  • 【SSM】SpringBoot笔记2:整合Junit、MyBatis
  • 「Java教案」选择结构
  • 解决 Git 访问 GitHub 时的 SSL 错误
  • 软考 系统架构设计师系列知识点之杂项集萃(81)
  • 大陆4D毫米波雷达ARS548调试
  • 线程的基础知识
  • 基于eclipse进行Birt报表开发
  • MySQL间隙锁入手,拿下间隙锁面试与实操
  • python变量
  • Java-IO流之转换流详解