
DAY 40 超大力王爱学Python

Knowledge points recap:

  1. The standard way to write training and testing code for color and grayscale images: encapsulate it in functions
  2. The flatten operation: flatten every dimension except the first (the batch size)
  3. Dropout: randomly drop neurons during training; switch to eval mode at test time so dropout is disabled

Homework: carefully study the logic of the training and testing code below. This is foundational: this code framework will be reused from here on, and the focus will gradually shift toward the model-definition stage.
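To make points 2 and 3 concrete before reading the full framework, here is a minimal standalone sketch (the tensor shape is chosen purely for illustration): flattening keeps only the batch dimension, and switching a module between train() and eval() toggles dropout on and off.

import torch
import torch.nn as nn

x = torch.randn(64, 3, 32, 32)            # a batch of 64 CIFAR10-sized images
flat = x.view(x.size(0), -1)              # keep dim 0 (batch size), flatten everything else
print(flat.shape)                         # torch.Size([64, 3072])

drop = nn.Dropout(p=0.5)
drop.train()                              # training mode: each activation is zeroed with probability 0.5
                                          # (survivors are scaled by 1/(1-p))
print((drop(flat) == 0).float().mean())   # roughly 0.5
drop.eval()                               # eval mode: dropout becomes a no-op
print(torch.equal(drop(flat), flat))      # True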

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from typing import Tuple


# 1. Data preprocessing and loading function
def get_data_loaders(
    dataset_name: str = 'MNIST',   # options: 'MNIST' or 'CIFAR10'
    batch_size: int = 64,
    data_dir: str = './data',
    num_workers: int = 2
) -> Tuple[DataLoader, DataLoader]:
    """Build train and test data loaders; supports grayscale (MNIST) and color (CIFAR10) datasets."""
    # Pick transforms based on the dataset type
    if dataset_name == 'MNIST':
        # Grayscale image transforms
        transform = transforms.Compose([
            transforms.ToTensor(),                      # convert to tensor, scale to [0, 1]
            transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
        ])
        train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
        test_dataset = datasets.MNIST(data_dir, train=False, transform=transform)
    elif dataset_name == 'CIFAR10':
        # Color image transforms
        transform = transforms.Compose([
            transforms.ToTensor(),                                  # convert to tensor, scale to [0, 1]
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize to [-1, 1]
        ])
        train_dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
        test_dataset = datasets.CIFAR10(data_dir, train=False, transform=transform)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # Build the data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


# 2. Generic model definition (supports grayscale and color images)
class Flatten(nn.Module):
    """Custom flatten layer that keeps the batch dimension."""
    def forward(self, x):
        return x.view(x.size(0), -1)  # keep the batch dim, flatten the rest


class ImageClassifier(nn.Module):
    """Generic image classifier for grayscale and color images."""
    def __init__(
        self,
        input_channels: int = 1,   # MNIST: 1, CIFAR10: 3
        input_size: int = 28,      # MNIST: 28, CIFAR10: 32
        hidden_size: int = 128,
        num_classes: int = 10,
        dropout_rate: float = 0.5
    ):
        super().__init__()
        self.model = nn.Sequential(
            Flatten(),  # flatten all dims except batch
            nn.Linear(input_channels * input_size * input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate),  # randomly drop neurons during training
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        return self.model(x)


# 3. Training function
def train(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    log_interval: int = 100
) -> None:
    """Train the model for one epoch."""
    model.train()  # enable training mode (activates dropout, etc.)
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Forward pass
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Log training progress
        if batch_idx % log_interval == 0:
            print(f'Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    # Report the average loss over the epoch
    avg_loss = running_loss / len(train_loader)
    print(f'Epoch {epoch} average loss: {avg_loss:.4f}')


# 4. Test function
def test(
    model: nn.Module,
    test_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Tuple[float, float]:
    """Evaluate the model on the test set."""
    model.eval()  # enable eval mode (disables dropout, etc.)
    test_loss = 0
    correct = 0
    with torch.no_grad():  # skip gradient computation to save memory and compute
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()          # accumulate batch loss
            pred = output.argmax(dim=1, keepdim=True)              # class with the highest score
            correct += pred.eq(target.view_as(pred)).sum().item()  # count correct predictions

    # Average loss (per batch) and accuracy
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')
    return test_loss, accuracy


# 5. Main function: training and testing pipeline
def main(
    dataset_name: str = 'MNIST',
    batch_size: int = 64,
    epochs: int = 5,
    lr: float = 0.001,
    dropout_rate: float = 0.5,
    use_cuda: bool = True
) -> None:
    """Main entry point: wire together data loading, model training, and testing."""
    # Select the device
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
    print(f"Using device: {device}")

    # Data loaders
    train_loader, test_loader = get_data_loaders(dataset_name=dataset_name, batch_size=batch_size)

    # Dataset-specific input parameters
    if dataset_name == 'MNIST':
        input_channels = 1
        input_size = 28
        num_classes = 10
    elif dataset_name == 'CIFAR10':
        input_channels = 3
        input_size = 32
        num_classes = 10
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # Build the model
    model = ImageClassifier(
        input_channels=input_channels,
        input_size=input_size,
        hidden_size=128,
        num_classes=num_classes,
        dropout_rate=dropout_rate
    ).to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train/test loop
    for epoch in range(1, epochs + 1):
        train(model, train_loader, criterion, optimizer, device, epoch)
        test(model, test_loader, criterion, device)

    # Save the trained weights
    torch.save(model.state_dict(), f"{dataset_name}_mlp_model.pth")
    print(f"Model saved as: {dataset_name}_mlp_model.pth")


if __name__ == "__main__":
    # Train the MNIST model
    main(dataset_name='MNIST', batch_size=64, epochs=5)
    # Train a CIFAR10 model (uncomment the line below)
    # main(dataset_name='CIFAR10', batch_size=64, epochs=10)
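Once main() finishes, the saved state dict can be reloaded for inference. A minimal sketch, assuming the MNIST run above has already produced MNIST_mlp_model.pth and that it runs in the same script so ImageClassifier is in scope; the dummy input is only a placeholder for a real preprocessed image:

# Rebuild the architecture with the same MNIST configuration, then load the weights
model = ImageClassifier(input_channels=1, input_size=28, hidden_size=128, num_classes=10)
model.load_state_dict(torch.load("MNIST_mlp_model.pth", map_location="cpu"))
model.eval()  # important: disables dropout before inference

with torch.no_grad():
    dummy = torch.randn(1, 1, 28, 28)      # one fake MNIST-shaped image
    pred = model(dummy).argmax(dim=1)      # predicted class index
    print(f"Predicted class: {pred.item()}")

Note that forgetting model.eval() here would leave dropout active and make predictions noisy, which is exactly the train/eval distinction from point 3 above.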

@浙大疏锦行
