当前位置: 首页 > news >正文

DAY 38 Dataset和Dataloader类

知识点回顾:

  1. Dataset类的__getitem__和__len__方法(本质是python的特殊方法)
  2. Dataloader类
  3. minist手写数据集的了解

作业:了解下cifar数据集,尝试获取其中一张图片

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具
from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)

一、Dataset类

现在我们想要取出来一个图片,看看长啥样,因为datasets.MNIST本质上集成了torch.utils.data.Dataset,所以自然需要有对应的方法。

import matplotlib.pyplot as plt# 随机选择一张图片,可以重复运行,每次都会随机选择
sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引
# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字
image, label = train_dataset[sample_idx] # 获取图片和标签

__getitem__方法

# 示例代码
class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __getitem__(self, idx):return self.data[idx]# 创建类的实例
my_list_obj = MyList()
# 此时可以使用索引访问元素,这会自动调用__getitem__方法
print(my_list_obj[2])  # 输出:30

__len__方法

class MyList:def __init__(self):self.data = [10, 20, 30, 40, 50]def __len__(self):return len(self.data)# 创建类的实例
my_list_obj = MyList()
# 使用len()函数获取元素数量,这会自动调用__len__方法
print(len(my_list_obj))  # 输出:5

# minist数据集的简化版本
class MNIST(Dataset):def __init__(self, root, train=True, transform=None):# 初始化:加载图片路径和标签self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签self.transform = transform # 预处理操作def __len__(self): return len(self.data)  # 返回样本总数def __getitem__(self, idx): # 获取指定索引的样本# 获取指定索引的图像和标签img, target = self.data[idx], self.targets[idx]# 应用图像预处理(如ToTensor、Normalize)if self.transform is not None: # 如果有预处理操作img = self.transform(img) # 转换图像格式# 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化return img, target  # 返回处理后的图像和标签
# 可视化原始图像(需要反归一化)
def imshow(img):img = img * 0.3081 + 0.1307  # 反标准化npimg = img.numpy()plt.imshow(npimg[0], cmap='gray') # 显示灰度图像plt.show()print(f"Label: {label}")
imshow(image)

二、Dataloader类

# 3. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)

@浙大疏锦行

http://www.lqws.cn/news/464869.html

相关文章:

  • 分布式锁的四种实现方式:从原理到实践
  • 高云GW5AT-LV60 FPGA图像处理板
  • React Native自定义底部弹框
  • Docker高级管理--容器通信技术与数据持久化
  • 华为云Flexus+DeepSeek征文|体验华为云ModelArts快速搭建Dify-LLM应用开发平台并创建b站视频总结大模型
  • Java ArrayList集合和HashSet集合详解
  • 【自动鼠标键盘控制器|支持图像识别】
  • 从代码学习深度学习 - 预训练BERT PyTorch版
  • 文本分类与聚类:让信息“各归其位”的实用方法
  • 最具有实际意义价值的比赛项目
  • CMS与G1的并发安全秘籍:如何在高并发的垃圾回收中保持正确性?
  • 【开源初探】基于 Qwen2.5VL的文档解析工具:docext
  • 【Linux-shell】探索Dialog 工具在 Shell 图形化编程中的高效范式重构
  • synchronized 和 ReentrantLock 的区别
  • 探索 Oracle Database 23ai 中的 SQL 功能
  • 团结引擎 1.5.0 更新 | OpenHarmony 平台开发体验全面升级,突破游戏类应用帧率限制
  • CertiK联创顾荣辉将于港大活动发表演讲,分享Web3安全与发展新视角
  • (LeetCode 面试经典 150 题) 80. 删除有序数组中的重复项 II (双指针、栈)
  • AI与SEO关键词协同进化
  • SQL关键字三分钟入门:INSERT INTO —— 插入数据详解
  • Armbian 开机启动点灯脚本
  • 【C++特殊工具与技术】局部类
  • 从事登高架设作业需要注意哪些安全事项?
  • 57-Oracle SQL Profile(23ai)实操
  • 内容搜索软件AnyTXT.Searcher忘记文件名也能搜,全文检索 1 秒定位文件
  • Java求职者面试指南:微服务技术与源码原理深度解析
  • Rabbitmq集成springboot,手动确认消息basicAck、basicNack、basicReject的使用
  • 在微信小程序wxml文件调用函数实现时间转换---使用wxs模块实现
  • WevServer实现:异步日志写与HTTP连接
  • Zephyr 调试实用指南:日志系统、Shell CLI 与 GDB 全面解析