当前位置: 首页 > news >正文

DAY 39 图像数据与显存

@浙大疏锦行https://blog.csdn.net/weixin_45655710
知识点回顾
  1. 图像数据的格式:灰度和彩色数据
  2. 模型的定义
  3. 显存占用的4种地方
    1. 模型参数+梯度参数
    2. 优化器参数
    3. 数据批量所占显存
    4. 神经元输出中间状态
  4. batchisize和训练的关系

作业:今日代码较少,理解内容即可

黑白图像模型的定义

# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,))  # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])
import matplotlib.pyplot as plt# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# %%
# 定义两层MLP神经网络
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.flatten = nn.Flatten()  # 将28x28的图像展平为784维向量self.layer1 = nn.Linear(784, 128)  # 第一层:784个输入,128个神经元self.relu = nn.ReLU()  # 激活函数self.layer2 = nn.Linear(128, 10)  # 第二层:128个输入,10个输出(对应10个数字类别)def forward(self, x):x = self.flatten(x)  # 展平图像x = self.layer1(x)   # 第一层线性变换x = self.relu(x)     # 应用ReLU激活函数x = self.layer2(x)   # 第二层线性变换,输出logitsreturn x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 将模型移至GPU(如果可用)from torchsummary import summary  # 导入torchsummary库
print("\n模型结构信息:")
summary(model, input_size=(1, 28, 28))  # 输入尺寸为MNIST图像尺寸

彩色图像模型的定义

class MLP(nn.Module):def __init__(self, input_size=3072, hidden_size=128, num_classes=10):super(MLP, self).__init__()# 展平层:将3×32×32的彩色图像转为一维向量# 输入尺寸计算:3通道 × 32高 × 32宽 = 3072self.flatten = nn.Flatten()# 全连接层self.fc1 = nn.Linear(input_size, hidden_size)  # 第一层self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)  # 输出层def forward(self, x):x = self.flatten(x)  # 展平:[batch, 3, 32, 32] → [batch, 3072]x = self.fc1(x)      # 线性变换:[batch, 3072] → [batch, 128]x = self.relu(x)     # 激活函数x = self.fc2(x)      # 输出层:[batch, 128] → [batch, 10]return x# 初始化模型
model = MLP()device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)  # 将模型移至GPU(如果可用)from torchsummary import summary  # 导入torchsummary库
print("\n模型结构信息:")
summary(model, input_size=(3, 32, 32))  # CIFAR-10 彩色图像(3×32×32)

核心逻辑可以总结为以下这条线:

图像数据的特殊性 → 如何用代码表示图像 → 如何构建网络处理图像 → 如何高效训练网络

1. 图像数据的特殊性:从一维到多维
  • 之前:我们处理的表格数据,每个样本是一行,可以看作一个一维特征向量,例如 (特征1, 特征2, ..., 特征N)

  • 现在:图像数据具有空间结构。一张图片不仅有像素值,还有这些像素在二维空间中的位置关系。因此,它不能被简单地看作一维向量,而是用一个三维或更高维的张量 (Tensor) 来表示。

    • 形状约定 (PyTorch)(通道数, 高度, 宽度)
      • 灰度图 (如MNIST):只有一个颜色通道,形状为 (1, 28, 28)
      • 彩色图 (如CIFAR-10):有红(R)、绿(G)、蓝(B)三个颜色通道,形状为 (3, 32, 32)
2. 如何用代码表示图像:Datasettransforms

为了让模型能“读懂”这些图像,我们需要进行预处理。PyTorch通过torchvision库提供了标准化的流程。

  • transforms (预处理管道):像一个流水线工厂,对每一张原始图片进行加工。
    1. transforms.ToTensor():将图片从常规格式(如PIL Image)转换为PyTorch张量,并顺便将像素值从0-255的范围归一化0-1
    2. transforms.Normalize(均值, 标准差):对归一化后的数据进行标准化,使其数据分布的均值为0,标准差为1。这能帮助神经网络更快、更稳定地学习。
  • Dataset (数据管理员):这个类负责管理整个数据集。它的核心是两个“协议”:
    1. __len__:告诉我们数据集中一共有多少张图片。
    2. __getitem__(i):当我们向它索要第 i 张图片时,它负责从硬盘读取这张图片,用上面的transform流水线进行加工,然后返回加工好的图片数据和它的标签。
3. 如何构建网络处理图像:nn.Module 的适配

我们之前用于处理表格数据的MLP(多层感知机)模型需要做一些调整才能处理图像。

  • nn.Flatten() (展平层):MLP的输入层(nn.Linear)只能接收一维向量。但我们的图像数据是三维的 (通道, 高, 宽)。因此,在送入第一个全连接层之前,必须用nn.Flatten()将图像“拉平”成一个长长的一维向量。
    • MNIST (1x28x28) → 展平为 784 维向量。
    • CIFAR-10 (3x32x32) → 展平为 3072 维向量。
  • 输入/输出层定义
    • 输入层:神经元数量必须等于展平后向量的维度(如784或3072)。
    • 输出层:神经元数量必须等于任务的类别总数(如MNIST和CIFAR-10都是10类)。
4. 如何高效训练网络:DataLoader 与显存管理

图像数据通常很大,不可能一次性全部放入GPU的显存 (VRAM) 中进行训练,否则会立刻导致内存溢出 (Out of Memory, OOM),也就是我们之前遇到的“内核死亡”的最终原因之一。

  • DataLoader (自动上菜服务员)

    • 它从 Dataset(数据仓库)中取出数据。
    • batch_size (批量大小)DataLoader最重要的参数。它决定了每次从仓库中取出多少张图片,打包成一“批”(一个batch),然后送入GPU进行训练。
    • shuffle=True (打乱数据):在每个训练周期(epoch)开始前,它会像洗牌一样把数据顺序完全打乱,这能有效避免模型学到数据顺序的巧合,增强泛化能力。
  • 显存占用与 batch_size 的权衡

    • 显存占用主要来自:模型参数梯度(与参数一样大)、优化器状态(Adam等会额外占用2倍参数大小)以及当前批次的数据和中间计算结果
    • batch_size 越大
      • 优点:能更充分地利用GPU的并行计算能力,训练更快;每个批次的梯度是对更多样本的平均,方向更稳定,有助于模型收敛。
      • 缺点:占用的显存越多。
    • batch_size 越小
      • 优点:占用显存少,不易OOM。
      • 缺点:训练速度慢,梯度更新频繁且不稳定。
  • 最佳实践:通过 nvidia-smi 命令监控显存使用情况,从小到大尝试不同的batch_size,找到一个既能充分利用显存、又不会导致OOM的最佳值。

总结下来的一条线就是: 我们认识到图像数据是多维的,所以我们用Datasettransforms定义和预处理它;接着我们调整了神经网络,用nn.Flatten适配这种多维输入;最后,因为数据量和模型变大,我们引入了DataLoaderbatch_size的概念,来高效、安全地分批训练,从而解决了显存占用的核心问题。

http://www.lqws.cn/news/477145.html

相关文章:

  • ProtoBuf:通讯录4.0实现 序列化能⼒对⽐验证
  • Rust 引用与借用
  • 47.第二阶段x64游戏实战-封包-分析打怪call
  • winform mvvm
  • 关于存储与网络基础的详细讲解(从属GESP二级内容)
  • 【机器学习四大核心任务类型详解】分类、回归、聚类、降维都是什么?
  • 人工智能、机器人最容易取哪些体力劳动和脑力劳动
  • AWS 使用图形化界面创建 EKS 集群(零基础教程)
  • Spring AI 项目实战(十):Spring Boot + AI + DeepSeek 构建智能合同分析技术实践(附完整源码)
  • java中HashMap和ConcurrentHashMap的共性以及区别
  • 《高等数学》(同济大学·第7版)第五章 定积分 第四节反常积分
  • 用可观测工具高效定位和查找设计中深度隐藏的bug
  • 网络安全智能体:重塑重大赛事安全保障新范式
  • 啥是 SaaS
  • [xiaozhi-esp32] 构建智能AI设备 | 开发板抽象层 | 通信协议层
  • 【ELK(Elasticsearch+Logstash+Kibana) 从零搭建实战记录:日志采集与可视化】
  • Elasticsearch Kibana (一)
  • spring碎片
  • 针对数据仓库方向的大数据算法工程师面试经验总结
  • 点点(小红书AI搜索):生活场景的智能搜索助手
  • Typecho博客3D彩色标签云插件(Handsome主题优化版)
  • 2.jupyter切换使用conda虚拟环境的最佳方法
  • 【DataWhale组队学习】AI办公实践与应用
  • Mysql—锁相关面试题(全局锁,表级锁,行级锁)
  • SpringCloudGateway(spel)漏洞复现 Spring + Swagger 接口泄露问题
  • 大零售生态下开源链动2+1模式、AI智能名片与S2B2C商城小程序的协同创新研究
  • Python 前端框架/工具合集
  • python实战项目77:足球运动员数据分析
  • 《高等数学》(同济大学·第7版)第五章 定积分 第三节积分的换元法和分部积分法
  • 在windows上使用file命令