
Python check-in, day 43 @浙大疏锦行

Assignment:

Find an image dataset on Kaggle, train a CNN on it, and visualize the result with Grad-CAM.

Advanced: split the code into multiple files.

1. Configuration file (config.py)

import torch


class Config:
    # Dataset settings
    DATASET_PATH = "/path/to/kaggle/dataset"
    IMAGE_SIZE = 224
    BATCH_SIZE = 32

    # Model settings
    NUM_CLASSES = 10
    PRETRAINED = True

    # Training settings
    EPOCHS = 10
    LR = 0.001
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

2. Data loading (dataset.py)

from torchvision import transforms, datasets
from config import Config


def get_dataloaders():
    # Note: despite its name, this returns Dataset objects;
    # train.py wraps them in DataLoader instances.
    # Training transform: resize, random horizontal flip, ImageNet normalization
    train_transform = transforms.Compose([
        transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Test transform: same preprocessing without augmentation
    test_transform = transforms.Compose([
        transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # ImageFolder expects one subdirectory per class under train/ and test/
    train_set = datasets.ImageFolder(
        f"{Config.DATASET_PATH}/train", transform=train_transform
    )
    test_set = datasets.ImageFolder(
        f"{Config.DATASET_PATH}/test", transform=test_transform
    )
    return train_set, test_set
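The Normalize step in dataset.py shifts pixel values out of the [0, 1] range, so a normalized tensor cannot be passed directly to plt.imshow or used as the original_image for the Grad-CAM overlay later on. The following is a minimal sketch of how the normalization could be undone. The helper name denormalize and the two constants are hypothetical and not part of the original code; the sketch assumes the ImageNet mean and std values used above.

```python
import numpy as np

# ImageNet statistics, the same values passed to transforms.Normalize in dataset.py
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def denormalize(image_chw):
    """Undo Normalize on a (3, H, W) array and return an (H, W, 3) array in [0, 1].

    Hypothetical helper for illustration; not part of the original post.
    """
    image_chw = np.asarray(image_chw, dtype=np.float32)
    # Move channels last so the per-channel statistics broadcast over H and W
    image_hwc = np.transpose(image_chw, (1, 2, 0))
    image_hwc = image_hwc * IMAGENET_STD + IMAGENET_MEAN
    # Clip small numerical overshoot so the result stays a valid image
    return np.clip(image_hwc, 0.0, 1.0)
```

With a tensor taken from the dataset, a call might look like original_image = denormalize(image_tensor.cpu().numpy()).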

3. CNN model definition (model.py)

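import torch.nn as nn
from torchvision import models
from config import Config


class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=` flag (torchvision >= 0.13)
        weights = models.ResNet18_Weights.DEFAULT if Config.PRETRAINED else None
        base_model = models.resnet18(weights=weights)
        # Replace the classification head to match the number of classes
        num_features = base_model.fc.in_features
        base_model.fc = nn.Linear(num_features, Config.NUM_CLASSES)
        self.model = base_model

    def forward(self, x):
        return self.model(x)

    def get_feature_maps(self):
        # Last convolutional layer of ResNet-18, used as the Grad-CAM target layer
        return self.model.layer4[-1].conv2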

4. Training script (train.py)

import torch
import torch.nn as nn  # missing in the original; needed for nn.CrossEntropyLoss
from torch.utils.data import DataLoader
from dataset import get_dataloaders
from model import CNNModel
from config import Config
from utils import save_checkpoint


def train():
    train_set, test_set = get_dataloaders()
    train_loader = DataLoader(train_set, batch_size=Config.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=Config.BATCH_SIZE)

    model = CNNModel().to(Config.DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LR)

    for epoch in range(Config.EPOCHS):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(Config.DEVICE), labels.to(Config.DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation code...

        # Save one checkpoint per epoch
        save_checkpoint(model, epoch)


if __name__ == "__main__":
    train()
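train.py leaves the validation step as a placeholder comment. As a rough sketch of what such a step might look like, the function below computes the mean loss and accuracy over a loader. The name evaluate and its signature are hypothetical, not the author's code, and the sketch assumes a criterion that returns the batch mean (the default for nn.CrossEntropyLoss).

```python
import torch
import torch.nn as nn


def evaluate(model, loader, criterion, device):
    """Return (mean loss, accuracy) of `model` over all samples in `loader`.

    Hypothetical helper for illustration; assumes `criterion` averages over the batch.
    """
    model.eval()  # disable dropout and use running BatchNorm statistics
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Weight the batch-mean loss by the batch size before summing
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen
```

Inside the epoch loop this could be called as val_loss, val_acc = evaluate(model, test_loader, criterion, Config.DEVICE). Because the function switches the model to eval mode, the existing model.train() call at the top of each epoch remains necessary.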

5. Grad-CAM implementation (gradcam.py)

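import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from config import Config  # missing in the original; needed for Config.DEVICE


class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.gradients = None
        self.activations = None
        # Capture the target layer's output and the gradient flowing back into it
        target_layer.register_forward_hook(self.save_activations)
        # register_backward_hook is deprecated; the full variant needs PyTorch >= 1.8
        target_layer.register_full_backward_hook(self.save_gradients)

    def save_activations(self, module, input, output):
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def __call__(self, x, class_idx=None):
        # Forward pass (expects a batch containing a single image)
        output = self.model(x)
        if class_idx is None:
            class_idx = int(torch.argmax(output, dim=1)[0])

        # Backward pass from the score of the target class only
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, class_idx] = 1
        output.backward(gradient=one_hot)

        # Channel weights: gradients averaged over the spatial dimensions
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(self.activations * weights, dim=1)
        cam = torch.relu(cam)

        # Post-processing: resize to the input resolution and scale to [0, 1]
        cam = cam.squeeze().cpu().numpy()
        cam = cv2.resize(cam, (x.shape[3], x.shape[2]))
        # The small epsilon avoids division by zero when the map is constant
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        return cam


def visualize_gradcam(model, image_tensor, original_image):
    # original_image is assumed to be an H x W x 3 RGB float array in [0, 1]
    # with the same height and width as image_tensor
    model.eval()  # use running BatchNorm statistics for the single-image batch
    target_layer = model.get_feature_maps()
    gradcam = GradCAM(model, target_layer)
    cam = gradcam(image_tensor.unsqueeze(0).to(Config.DEVICE))

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # OpenCV returns BGR, while matplotlib expects RGB
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    # Overlay the heatmap on the image and rescale to a maximum of 1
    superimposed_img = heatmap + np.float32(original_image)
    superimposed_img = superimposed_img / np.max(superimposed_img)

    plt.imshow(superimposed_img)
    plt.axis('off')
    plt.show()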
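The weighting step inside GradCAM.__call__ can be easier to follow without hooks and tensors on a device. The sketch below reproduces the same arithmetic (spatially averaged gradients as channel weights, a weighted sum of the feature maps, ReLU, min-max scaling) on plain NumPy arrays for a single image. The function name gradcam_from_arrays and the eps parameter are hypothetical and only serve as illustration.

```python
import numpy as np


def gradcam_from_arrays(activations, gradients, eps=1e-8):
    """Compute a Grad-CAM map from one image's feature maps and their gradients.

    activations, gradients: arrays of shape (C, H, W).
    Returns an (H, W) array scaled to [0, 1].
    Hypothetical helper for illustration; not part of the original post.
    """
    activations = np.asarray(activations, dtype=np.float32)
    gradients = np.asarray(gradients, dtype=np.float32)
    # Channel weights: average of the gradients over the spatial dimensions
    weights = gradients.mean(axis=(1, 2))                  # shape (C,)
    # Weighted sum of the feature maps over the channel axis
    cam = np.tensordot(weights, activations, axes=(0, 0))  # shape (H, W)
    # ReLU keeps only regions that push the class score up
    cam = np.maximum(cam, 0.0)
    # Min-max scaling; eps guards against a constant map
    return (cam - cam.min()) / (cam.max() - cam.min() + eps)
```

A channel whose averaged gradient is negative lowers the map wherever that channel is active, and the ReLU then clips those locations to zero, which is why only regions supporting the chosen class remain highlighted.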

6. Utility functions (utils.py)

import torch
import os


def save_checkpoint(model, epoch, path="checkpoints"):
    # Create the checkpoint directory if it does not exist yet
    os.makedirs(path, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
    }, f"{path}/checkpoint_{epoch}.pth")


def load_checkpoint(model, path):
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
