当前位置: 首页 > news >正文

Python打卡训练营学习记录Day46

作业:

  1. 今日代码较多,理解逻辑即可
  2. 对比不同卷积层特征图可视化的结果(可选)

一、CNN特征图可视化实现

import torch
import matplotlib.pyplot as pltdef visualize_feature_maps(model, input_tensor):# 注册钩子获取中间层输出features = []def hook(module, input, output):features.append(output.detach().cpu())# 选择不同卷积层观察target_layers = [model.layer1[0].conv1,model.layer2[0].conv1,model.layer3[0].conv1]handles = []for layer in target_layers:handles.append(layer.register_forward_hook(hook))# 前向传播with torch.no_grad():_ = model(input_tensor.unsqueeze(0))# 移除钩子for handle in handles:handle.remove()# 可视化不同层特征图fig, axes = plt.subplots(len(target_layers), 5, figsize=(20, 10))for i, feat in enumerate(features):for j in range(5):  # 显示前5个通道axes[i,j].imshow(feat[0, j].numpy(), cmap='viridis')axes[i,j].axis('off')plt.show()

二、通道注意力模块示例

class ChannelAttention(nn.Module):def __init__(self, in_channels, reduction=16):super().__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // reduction),nn.ReLU(),nn.Linear(in_channels // reduction, in_channels),nn.Sigmoid())def forward(self, x):# ... existing code ...return x * attention_weights  # 应用注意力权重

三、热力图生成方法

def generate_heatmap(model, input_img):# 前向传播获取梯度model.eval()input_img.requires_grad = Trueoutput = model(input_img)pred_class = output.argmax(dim=1).item()# 反向传播计算梯度model.zero_grad()output[0, pred_class].backward()# 获取最后一个卷积层的梯度gradients = model.layer4[1].conv2.weight.gradpooled_gradients = torch.mean(gradients, dim=[0,2,3])# 生成热力图activations = model.layer4[1].conv2.activations.detach()for i in range(activations.shape[1]):activations[:,i,:,:] *= pooled_gradients[i]heatmap = torch.mean(activations, dim=1).squeeze()return heatmap

@浙大疏锦行

http://www.lqws.cn/news/192889.html

相关文章:

  • 第7篇:中间件全链路监控与 SQL 性能分析实践
  • 微软推出SQL Server 2025技术预览版,深化人工智能应用集成
  • VBA清空数据
  • Python训练营---Day46
  • [大A量化专栏] QMT常见问题QA
  • 5G网络中频段的分配
  • DAY45 可视化
  • 每日算法 -【Swift 算法】电话号码字母组合
  • gvim比较两个文件不同并合并差异
  • 和芯 SL6341 (内置FLASH) 国产USB 3.0HUB芯片 替代 GL3510 VL817
  • Spring Boot + Prometheus 实现应用监控(基于 Actuator 和 Micrometer)
  • Langgraph实战--在Agent中加入人工反馈
  • 13.MySQL用户管理
  • 力扣100-移动0
  • Android Test3 获取的ANDROID_ID值不同
  • ​​TPS3808​​低静态电流、可编程延迟电压监控电路,应用笔记
  • 初识AI Agent
  • Rust 开发环境搭建
  • 精益数据分析(95/126):Socialight的定价转型启示——B2B商业模式的价格策略与利润优化
  • 超声波清洗设备的清洗效果如何?
  • CMA软件产品测试报告在哪申请?
  • AI对测试行业的应用
  • 中医的十问歌和脉象分类
  • 基于KNN算法的入侵检测模型设计与实现【源码+文档】
  • 【深度学习新浪潮】RoPE对大模型的外推性有什么影响?
  • yolov8自训练模型作为预训练权重【增加新类别】注意事项
  • 事件监听 ——CAD C#二次开发
  • react 常见的闭包陷阱深入解析
  • 几何引擎对比:OpenCasCade、ACIS、Parasolid和CGM
  • n皇后问题的 C++ 回溯算法教学攻略