当前位置: 首页 > news >正文

PyTorch——优化器(9)

优化器根据梯度调整参数,以达到降低误差

import torch.optim
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader# 加载CIFAR10测试数据集,设置transform将图像转换为Tensor
dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# 创建数据加载器,设置批量大小为64
dataloader = DataLoader(dataset, batch_size=64)# 定义卷积神经网络模型
class TY(nn.Module):def __init__(self):super(TY, self).__init__()# 构建网络结构:3个卷积层+池化层组合,2个全连接层self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),    # 输入3通道,输出32通道,卷积核5x5MaxPool2d(2),                   # 最大池化,步长2Conv2d(32, 32, 5, padding=2),   # 第二层卷积MaxPool2d(2),                   # 第二次池化Conv2d(32, 64, 5, padding=2),   # 第三层卷积MaxPool2d(2),                   # 第三次池化Flatten(),                      # 将多维张量展平为向量Linear(1024, 64),               # 全连接层,输入1024维,输出64维Linear(64, 10),                 # 输出层,10个类别对应10个输出)def forward(self, x):# 定义前向传播路径x = self.model1(x)return x# 定义损失函数(交叉熵损失适用于多分类问题)
loss = nn.CrossEntropyLoss()
# 实例化模型
ty = TY()
# 定义优化器(随机梯度下降),设置学习率为0.01
optim = torch.optim.SGD(ty.parameters(), lr=0.01)# 训练20个完整轮次
for epoch in range(20):running_loss = 0.0  # 初始化本轮累计损失# 遍历数据加载器中的每个批次for data in dataloader:imgs, targets = data  # 获取图像和标签outputs = ty(imgs)    # 前向传播result_loss = loss(outputs, targets)  # 计算损失optim.zero_grad()     # 梯度清零,防止累积result_loss.backward()  # 反向传播计算梯度optim.step()          # 更新模型参数running_loss += result_loss  # 累加损失值# 打印本轮训练的累计损失print(f"Epoch {epoch+1}, Loss: {running_loss}")

http://www.lqws.cn/news/135253.html

相关文章:

  • Web开发主流前后端框架总结
  • [特殊字符] Spring Boot底层原理深度解析与高级面试题精析
  • PDF处理控件Aspose.PDF教程:在 C# 中更改 PDF 页面大小
  • (LeetCode 每日一题)3403. 从盒子中找出字典序最大的字符串 I (贪心+枚举)
  • [Java 基础]面向对象-封装
  • STM32上部署AI的两个实用软件——Nanoedge AI Studio和STM32Cube AI
  • C++11 中 auto 和 decltype 的深入解析
  • 服务器中僵尸网络攻击是指什么?
  • 前端css外边距塌陷(Margin Collapse)现象原因和解决方法
  • 编程技能:格式化打印04,sprintf
  • 虚拟斯德哥尔摩症候群:用户为何为缺陷AI辩护?
  • 在CSDN发布AWS Proton解决方案:实现云原生应用的标准化部署
  • AWS DocumentDB vs MongoDB:数据库的技术抉择
  • AWS 成本异常检测IAM策略
  • 【知识点】第6章:组合数据类型
  • idea相关功能
  • sylar--线程模块
  • Java面试题及答案整理( 2025年最新版,持续更新...)
  • 从OCR到Document Parsing,AI时代的非结构化数据处理发生了什么改变?
  • Edge Databases:赋能分布式计算环境
  • Elasticsearch的写入性能优化
  • 旅游微信小程序制作指南
  • 【2025】通过idea把项目到私有仓库(3)
  • OD 算法题 B卷【DNA序列】
  • SQL 中 IN 和 EXISTS 的区别
  • 李飞飞World Labs开源革命性Web端3D渲染器Forge!3D高斯溅射技术首次实现全平台流畅运行
  • 【DeepSeek】【Dify】:用 Dify 对话流+标题关键词注入,让 RAG 准确率飞跃
  • 计算机I/O系统:数据交互的核心桥梁
  • Manus AI 现在可以生成短片了
  • 数据结构期末PTA选择汇总