当前位置: 首页 > news >正文

【PyTorch】保存和加载模型

目录

■state_dict

■用于推理的保存和加载模型

保存/加载state_dict

保存/加载整个模型

以 TorchScript 格式导出/加载模型

■保存和加载用于推断和/或恢复训练的一般检查点(Checkpoint)

■将多个模型保存在一个文件中

■使用来自不同模型的参数进行暖启动(Warmstarting)模型

■跨设备保存和加载模型

保存在GPU,加载到CPU

保存在GPU,加载到GPU

保存在CPU,加载到GPU

■保存torch.nn.DataParallel模型



state_dict

在 PyTorch 中,torch.nn.Module模型可学习的参数(即权重和偏差) 包含在模型的参数中 (model.parameters())。state_dict 只是一个 Python 字典对象,将每层映射到其参数张量。 请注意,只有具有可学习参数的图层(卷积层, 线性层等)和注册缓冲区(batchnorm的 running_mean) 在模型的 state_dict 中有条目。优化器物体(torch.optim) 也有 state_dict,其中包含 有关优化器状态以及超参数的信息使用。因为 state_dict 对象是 Python 字典,它们可以很容易地保存,更新,更改和恢复,增加了大量的模块化 PyTorch 模型和优化器。

示例

从 simple module中的state_dict使用的 state_dict训练分类器的教程。

# Define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Initialize model
model = TheModelClass()# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])

输出

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

用于推理的保存和加载模型

保存/加载state_dict

保存:

torch.save(model.state_dict(), PATH)

加载:

model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH, weights_only=True))model.eval()

PyTorch的1.6版本切换torch.save使用一个新的 zip file-based格式。torch.load仍然保留能力, 以旧格式加载文件。如果出于任何原因想要 torch.save使用旧格式,通过kwarg参数_use_new_zipfile_serialization=False

一个常见的PyTorch惯例是使用任一方法保存模型。.pt.pth文件扩展。

记住,在运行推理之前,必须使用model.eval()设置 dropout 和 batch 正则化层到评估模式。不这样做,会产生不一致的推理结果。

注意,load_state_dict()function 需要字典对象,不是保存对象的路径。这意味着必须在将 state_dict 传递给 state_dict 之前,反序列化 load_state_dict()功能。例如,无法使用 model.load_state_dict(PATH)

如果只打算保持最好的模型(根据获得验证损失),不要忘记 best_model_state = model.state_dict()返回对状态的引用,而不是其副本!必须序列化 best_model_state或使用best_model_state = deepcopy(model.state_dict())否则best_model_state将通过后续训练不断更新迭代。因此,最终的模型状态将是超拟合模型的状态。

保存/加载整个模型

保存:

torch.save(model, PATH)

加载:

# Model class must be defined somewheremodel = torch.load(PATH, weights_only=False)model.eval()

以 TorchScript 格式导出/加载模型

使用经过训练的模型进行推理的一种常见方法是使用 TorchScript,一个中间体 PyTorch模型的表示,该模型可以在Python,高性能环境,如C ++中运行。TorchScript实际上是推荐的模型格式 用于缩放的推理和部署。

使用 TorchScript 格式,将能够加载导出的模型和运行推论而不定义模型类

导出:

model_scripted = torch.jit.script(model) # Export to TorchScriptmodel_scripted.save('model_scripted.pt') # Save

加载:

model = torch.jit.load('model_scripted.pt')model.eval()

有关 TorchScript 的更多信息,可访问专用 tutorials教程。将熟悉跟踪转换并学习如何 在 C++ 环境中运行 TorchScript 模块。

保存和加载用于推断和/或恢复训练的一般检查点(Checkpoint)

保存:

torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)

加载:

model = TheModelClass(*args, **kwargs)optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH, weights_only=True)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])epoch = checkpoint['epoch']loss = checkpoint['loss']model.eval()# - or -model.train()

保存一般检查点时,用于推理或 恢复训练,必须保存的不仅仅是模型的 state_dict 的相关内容,保存优化器的 state_dict 也很重要, 因为它包含作为模型更新的缓冲区和记录的训练更新参数。其他可能想要保存的项目是epoch信息,最新记录的训练损失,外部 torch.nn.Embedding层等。因此,这样的检查点通常比单模型大2~3倍。

要保存多个组件,请将它们组织在字典中并使用 torch.save()序列化字典。常见的PyTorch 惯例是使用.tar文件扩展保存这些检查点。

要加载项目,首先初始化模型和优化器,然后加载本地使用的字典torch.load()。从这里,只需查询字典,可以很容易地访问保存的项目。

记住,在运行推理之前,model.eval()设置 dropout 和 batch 正则化层到评估模式。 不这样做会产生不一致的推理结果。如果希望恢复训练,调用model.train()确保这些层处于训练模式。

将多个模型保存在一个文件中

保存:

torch.save({'modelA_state_dict': modelA.state_dict(),'modelB_state_dict': modelB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),...}, PATH)

加载:

modelA = TheModelAClass(*args, **kwargs)modelB = TheModelBClass(*args, **kwargs)optimizerA = TheOptimizerAClass(*args, **kwargs)optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH, weights_only=True)modelA.load_state_dict(checkpoint['modelA_state_dict'])modelB.load_state_dict(checkpoint['modelB_state_dict'])optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])modelA.eval()modelB.eval()# - or -modelA.train()modelB.train()

保存由多个torch.nn.Modules组成的模型时,如GAN,一个序列到序列模型,或一个模型的集合体,遵循与保存一般检查点时相同的方法。换句话说,保存每个模型state_dict 的字典和相应的优化器。如前所述,通过简单地附加它们到字典,可以在恢复训练中保存任何其他的条目。

使用来自不同模型的参数进行暖启动(Warmstarting)模型

保存:

torch.save(modelA.state_dict(), PATH)

加载:

modelB = TheModelBClass(*args, **kwargs)modelB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)

部分加载模型或加载部分模型是常见的迁移学习或训练新复杂模型时的场景。 利用经过训练的参数,即使只有少数可用,也会有所帮助暖启动训练过程,并希望帮助模型收敛,比从零开始训练快得多。

是否从部分state_dict加载这是缺少一些键,或者用比正在加载的模型更多的键,在load_state_dict()函数忽略非匹配键,可以设置strict为 False 。

如果要将参数从一层加载到另一层,但需要一些键不匹配。只需更改正在加载的state_dict中参数键的名称,使其与正在加载的模型中的键相匹配。

跨设备保存和加载模型

保存在GPU,加载到CPU

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device('cpu')model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))

在使用GPU训练的模型在CPU上加载时,通过将 torch.device('cpu')传递给torch.load()函数中的map_location参数。在这种情况下,使用map_location参数将张量底层的存储动态地重新映射到CPU设备。

保存在GPU,加载到GPU

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH, weights_only=True))model.to(device)# Make sure to call input = input.to(device) on any input tensors that you feed to the model

在GPU上加载经过训练并保存在GPU上的模型时,只需使用 model.to(torch.device('cuda'))转换初始化model为 CUDA 优化模型。一定要在所有模型输入上准备模型的数据使用 .to(torch.device('cuda'))函数。请注意,调用 my_tensor.to(device)返回一个my_tensor新的副本在GPU上。不覆盖 my_tensor。因此,记得手动覆盖张量: my_tensor=my_tensor.to(torch.device('cuda'))

保存在CPU,加载到GPU

保存:

torch.save(model.state_dict(), PATH)

加载:

device = torch.device("cuda")model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH, weights_only=True, map_location="cuda:0"))  # Choose whatever GPU device number you wantmodel.to(device)# Make sure to call input = input.to(device) on any input tensors that you feed to the model

当在GPU上加载经过训练并保存在CPU上的模型时,将torch.load()函数中的map_location参数设置为cuda:device_id。这将模型加载到给定的GPU设备上。接下来,确保调用model.to(torch.device('cuda'))将模型的参数张量转换为cuda张量。最后,确保在所有模型输入上使用.to(torch.device('cuda'))函数,为cuda优化模型准备数据。注意,调用my_tensor.to(device)会在GPU上返回一个my_tensor的新副本。它不会覆盖my_tensor。因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda')))

保存torch.nn.DataParallel模型

保存:

torch.save(model.module.state_dict(), PATH)

加载:

# Load to whatever device you want

torch.nn.DataParallel是模型包装器,它支持并行GPU的使用。要通用地保存DataParallel模型,请保存model.module.state_dict()。这样,就可以灵活地以任何方式将模型加载到任何设备上。

至此,本文分享的内容就结束了。

http://www.lqws.cn/news/521785.html

相关文章:

  • 【cursor实战】分析python下并行、串行计算性能
  • <六> k8s + promtail + loki + grafana初探
  • 深度学习入门--(二)感知机
  • 利用代理IP爬取Shopee网页数据
  • C/C++中调用Java实现
  • keil5 cannot copy license file to “Download“ folder
  • 阿里云Web应用防火墙3.0使用CNAME接入传统负载均衡CLB
  • 量学云讲堂王岩江宇龙2025年第58期视频 主课正课系统课+收评
  • 【EDA软件】【应用功能子模块网表提供和加载编译方法】
  • Web层注解
  • 浙大/浙工大合作iMeta(1区 | IF 33.2):单微生物RNA-seq + 聚类解析肠道关键种代谢功能
  • MySQL常用函数性能优化及索引影响分析
  • ES和 Kafka 集群搭建过程中的典型问题、配置规范及最佳实践
  • C++11原子操作:从入门到精通
  • Fisco Bcos学习 - 搭建第一个区块链网络
  • selenium UI自动化元素定位中classname和CSS区别
  • Spring Boot中日志管理与异常处理
  • 【评估指标】MAP@k (目标检测)
  • docker start mysql失败,解决方案
  • 深入理解Redis整数集合(intset)的升级策略:内存优化的核心魔法
  • FPGA笔记——ZYNQ-7020运行PS端的USB 2.0端口作为硬盘
  • 基于大数据的社会治理与决策支持方案PPT(66页)
  • IE浏览器使用
  • 系统思考:预防重于治疗
  • 如何搭建CDN服务器?
  • 将 Docker的存储目录迁移到空间更大的磁盘
  • 搭建自己的WEB应用防火墙
  • mbedtls ssl handshake error,res:-0x2700
  • 数据库数据恢复—SQL Server数据库被加密如何恢复?
  • Fisco Bcos学习 - 搭建星形拓扑组网