当前位置: 首页 > news >正文

打卡Day45

使用PyTorch在CIFAR10数据集上微调ResNet18,并用TensorBoard监控训练过程

1. 环境准备

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import numpy as np
import os

2. 数据预处理与加载

# Data augmentation and normalization.
# NOTE: these are the commonly used CIFAR-10 per-channel statistics (not the
# ImageNet ones). Both the training and evaluation pipelines share them.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop (4-pixel padding) and random horizontal flip
# as augmentation, followed by tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, same normalization as training.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Load the datasets (download=True is passed, files are stored under ./data).
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)

# Shuffle only the training data; evaluation order is kept fixed.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=2)

3. 模型准备(ResNet18微调)

# Start from the pretrained ResNet18 and adapt it to CIFAR-10.
model = torchvision.models.resnet18(pretrained=True)

# Replace the first layer to suit 32x32 inputs (the original network targets
# 224x224 images). The new convolution is freshly constructed here, so it
# does not carry pretrained weights.
model.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
# Remove the initial max-pool by turning it into a no-op.
model.maxpool = nn.Identity()

# Replace the final fully connected layer: CIFAR-10 has 10 classes.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=10)

# Move the model to the first GPU when CUDA is available, else stay on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)

4. 训练配置

# Loss function: cross-entropy between model outputs and integer targets.
criterion = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum and weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Scheduler: scale the learning rate by 0.1 every 30 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=30,
    gamma=0.1,
)

# Create the TensorBoard writer; event files go under this run directory.
writer = SummaryWriter(log_dir='runs/resnet18_cifar10_finetune')

5. 训练循环(集成TensorBoard日志)

def train(epoch):
    """Run one pass over ``train_loader``, updating the model and logging.

    Args:
        epoch: Epoch index. Used directly as the step for epoch-level scalars
            and combined with the batch index to form the step for
            batch-level scalars.

    Returns:
        A tuple ``(acc, avg_loss)``: accuracy in percent over all samples
        seen in this pass, and the mean of the per-batch losses.
    """
    model.train()
    running_loss = 0
    num_correct = 0
    num_seen = 0
    steps_per_epoch = len(train_loader)

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard optimisation step: clear grads, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate loss and correct-prediction counts for the epoch summary.
        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        num_seen += targets.size(0)
        num_correct += (predicted == targets).sum().item()

        # Log batch-level data every 100 batches. The accuracy value is the
        # running accuracy accumulated so far in this epoch.
        if batch_idx % 100 == 0:
            global_step = epoch * steps_per_epoch + batch_idx
            writer.add_scalar('Training/Loss (batch)', loss.item(), global_step)
            writer.add_scalar(
                'Training/Accuracy (batch)',
                100. * num_correct / num_seen,
                global_step,
            )

    # Log epoch-level data.
    avg_loss = running_loss / steps_per_epoch
    acc = 100. * num_correct / num_seen
    writer.add_scalar('Training/Loss (epoch)', avg_loss, epoch)
    writer.add_scalar('Training/Accuracy (epoch)', acc, epoch)

    print(f'Epoch: {epoch} | Train Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss


def test(epoch):
    """Evaluate the model on ``test_loader`` and log validation metrics.

    Gradients are disabled during evaluation. Besides loss and accuracy, the
    current learning rate of the optimizer's first parameter group is logged.

    Args:
        epoch: Epoch index, used as the step for every scalar logged here.

    Returns:
        A tuple ``(acc, avg_loss)``: accuracy in percent over all evaluated
        samples, and the mean of the per-batch losses.
    """
    model.eval()
    running_loss = 0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()

            predicted = outputs.max(1)[1]
            num_seen += targets.size(0)
            num_correct += (predicted == targets).sum().item()

    # Log validation results.
    avg_loss = running_loss / len(test_loader)
    acc = 100. * num_correct / num_seen
    writer.add_scalar('Validation/Loss', avg_loss, epoch)
    writer.add_scalar('Validation/Accuracy', acc, epoch)

    # Log the learning rate.
    current_lr = optimizer.param_groups[0]['lr']
    writer.add_scalar('Learning Rate', current_lr, epoch)

    print(f'Test Loss: {avg_loss:.3f} | Acc: {acc:.2f}%')
    return acc, avg_loss


# Main training loop
# Best validation accuracy seen so far. It must be defined before the loop,
# otherwise the first `test_acc > best_acc` comparison raises NameError.
best_acc = 0.0

try:
    for epoch in range(100):
        train_acc, train_loss = train(epoch)
        test_acc, test_loss = test(epoch)
        # Advance the learning-rate schedule once per epoch.
        scheduler.step()

        # Save the best model: checkpoint whenever validation accuracy improves.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_model.pth')
finally:
    # Close the TensorBoard writer even if training raises or is interrupted,
    # so buffered events are not lost.
    writer.close()
http://www.lqws.cn/news/168715.html

相关文章:

  • 从0开始学习R语言--Day17--Cox回归
  • ES集群磁盘空间超水位线不可写的应急处理
  • 【AI News | 20250605】每日AI进展
  • K8S认证|CKS题库+答案| 2. Pod 指定 ServiceAccount
  • 七彩喜智慧养老平台:科技赋能下的市场蓝海,满足多样化养老服务需求
  • OpenStack组件:放置服务(Placement)安装
  • 数据可视化大屏案例落地实战指南:捷码平台7天交付方法论
  • 看板中“进行中”任务过多如何优化
  • 单精度浮点数值 和 双精度浮点数值
  • 基于51单片机的车内防窒息检测报警系统
  • 【运维心得】内存占用虚标真相
  • vue-19(Vuex异步操作和变更)
  • 使用ArcPy进行栅格数据分析(2)
  • JAVA之 Lambda
  • 【赵渝强老师】Docker的图形化管理工具
  • 【JavaEE】万字详解HTTP协议
  • 残月个人拟态主页
  • RADIUS 协议 (Remote Authentication Dial-In User Service)
  • 华为交换机vlan配置步骤
  • 《最长公共子序列》题集
  • 8086寻址解剖图:7种武器解锁x86内存访问的基因密码
  • Linux --环境变量,虚拟地址空间
  • 直线导轨微型化技术难点在哪里?
  • Python基于方差-协方差方法实现投资组合风险管理的VaR与ES模型项目实战
  • Java并发编程实战 Day 10:原子操作类详解
  • 边缘计算应用实践心得
  • P10909 [蓝桥杯 2024 国 B] 立定跳远
  • Python Einops库:深度学习中的张量操作革命
  • 使用 uv 工具快速部署并管理 vLLM 推理环境
  • 前端面试四之Fetch API同步和异步