当前位置: 首页 > news >正文

day49 python 注意力热图

目录

一、注意力热图简介

二、类激活图(CAM)原理

三、基于PyTorch的CAM实现

(一)加载库与模型

(二)输入图片预处理与模型预测

(三)生成注意力热图

(四)可视化与保存结果

四、实验结果与分析


一、注意力热图简介

注意力热图是一种强大的可视化工具,能够直观地展示神经网络在处理输入图像时的关注区域。它可以帮助我们理解模型是如何做出决策的,从而更好地优化和改进模型。在实际应用中,注意力热图广泛应用于图像分类、目标检测等领域,为研究人员提供了宝贵的洞察。

通过阅读大量相关资料,我发现大多数方法都是基于神经网络的输出特征图来生成注意力热图。具体来说,可以使用任意层的特征图,但通常选择最后一个卷积层的输出特征图。将特征图调整到输入图像的大小后,通过特定的函数将其叠加到原图像上,即可得到注意力热图。虽然这个过程看似简单,但在实际操作中,仍有许多细节需要注意。在本次实验中,我主要采用了类激活图(CAM)方法来生成注意力热图。

二、类激活图(CAM)原理

类激活图(CAM)方法是由论文《Learning Deep Features for Discriminative Localization》提出的一种经典方法。其核心思想是利用神经网络的卷积层特征图和分类层权重来生成注意力热图。以下是CAM方法的具体步骤:

  1. 获取输出特征图:从神经网络中提取输出特征图,其形状为[B, C, H, W],其中B为批量大小,C为最后一个卷积层的输出通道数,H和W分别为特征图的宽度和高度。如果输入一张图片,则B=1。

  2. 获取分类层权重:提取训练好的模型的分类头权重classifier.weight,注意分类层的输入通道数必须与输出特征图的通道数匹配。

  3. 加权求和生成注意力热图:将每个通道的特征图与分类层权重进行加权求和,最终得到每一类的注意力热图。

三、基于PyTorch的CAM实现

以下是使用PyTorch实现CAM的完整代码,代码中包含了详细的注释,方便读者理解每个步骤的具体操作。

(一)加载库与模型

import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from PIL import Image# 加载自己的网络
from model import modelclass_num = 5model_ft = model(num_classes=class_num)
model_ft.load_state_dict(torch.load('pretrain.pth', map_location=lambda storage, loc: storage))model_features = nn.Sequential(*list(model_ft.children())[:-2])
fc_weights = model_ft.state_dict()['classifier.weight'].cpu().numpy()
class_ = {0: 'car', 1: 'bird', 2: 'tree', 3: 'sky', 4: 'person'}
model_ft.eval()
model_features.eval()

(二)输入图片预处理与模型预测

data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}img_path = '/data/test.jpg'  # 单张测试
_, img_name = os.path.split(img_path)
features_blobs = []
img = Image.open(img_path).convert('RGB')
img_tensor = data_transform['val'](img).unsqueeze(0)  # [1,3,224,224]
features = model_features(img_tensor).detach().cpu().numpy()  # [1,960,7,7]logit = model_ft(img_tensor)  # [1,2] -> [ 3.3207, -2.9495]
h_x = torch.nn.functional.softmax(logit, dim=1).data.squeeze()  # tensor([0.9981, 0.0019])probs, idx = h_x.sort(0, True)  # 按概率从大到小排列
probs = probs.cpu().numpy()
idx = idx.cpu().numpy()for i in range(class_num):print('{:.3f} -> {}'.format(probs[i], class_[idx[i]]))  # 打印预测结果

(三)生成注意力热图

def returnCAM(feature_conv, weight_softmax, class_idx):bz, nc, h, w = feature_conv.shapeoutput_cam = []for idx in class_idx:feature_conv = feature_conv.reshape((nc, h * w))cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h * w)))cam = cam.reshape(h, w)cam_img = (cam - cam.min()) / (cam.max() - cam.min())cam_img = np.uint8(255 * cam_img)output_cam.append(cam_img)return output_camCAMs = returnCAM(features, fc_weights, idx)  # 输出预测概率最大的特征图集对应的CAM
print(img_name + ' output for the top1 prediction: %s' % class_[idx[0]])

(四)可视化与保存结果

img = cv2.imread(img_path)
height, width, _ = img.shape
heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)
result = heatmap * 0.3 + img * 0.5text = '%s %.2f%%' % (class_[idx[0]], probs[0] * 100)
cv2.putText(result, text, (210, 40), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.9,color=(123, 222, 238), thickness=2, lineType=cv2.LINE_AA)CAM_RESULT_PATH = r'/data/heatmap/'
if not os.path.exists(CAM_RESULT_PATH):os.mkdir(CAM_RESULT_PATH)
image_name_ = img_name.split(".")[-2]
cv2.imwrite(os.path.join(CAM_RESULT_PATH, image_name_ + '_heatmap.jpg'), result)

四、实验结果与分析

通过上述代码,我成功生成了输入图像的注意力热图,并将其与原图叠加显示。从结果可以看出,注意力热图清晰地标注出了模型在做出预测时关注的区域。例如,在对“car”类别进行预测时,热图主要集中在车辆的轮廓和关键部位,这表明模型能够准确地识别出车辆的特征区域。

@浙大疏锦行

http://www.lqws.cn/news/207577.html

相关文章:

  • 将单体架构项目拆分成微服务时的两种工程结构
  • Spring Cloud Hystrix熔断机制:构建高可用微服务的利器
  • OkHttp 3.0源码解析:从设计理念到核心实现
  • 向日葵远程控制debian无法进入控制画面的解决方法
  • Git开发实战
  • ELK日志管理框架介绍
  • WPS中将在线链接转为图片
  • JAVA实战开源项目:信息技术知识赛系统 (Vue+SpringBoot) 附源码
  • 一.设计模式的基本概念
  • 八、【ESP32开发全栈指南:UDP客户端】
  • CSS 预处理器与工具
  • 1.4 Node.js 的 TCP 和 UDP
  • [HCTF 2018]admin 1
  • n8n + AI Agent:AI 自动化生成测试用例并支持导出 Excel
  • NPOI Excel用OLE对象的形式插入文件附件以及插入图片
  • Model Context Protocol (MCP) 是一个前沿框架
  • 多文化软件团队的协作之道:在认知差异中寻找协同的支点
  • 基于Scala实现Flink的三种基本时间窗口操作
  • 20250607-在Ubuntu中使用Anaconda创建新环境并使用本地的备份文件yaml进行配置
  • 网络协议通俗易懂详解指南
  • 交叉熵损失函数和极大似然估计是什么,区别是什么
  • 【数据结构初阶】--算法复杂度的深度解析
  • Canal环境搭建并实现和ES数据同步
  • Web前端基础:JavaScript
  • Go语言堆内存管理
  • 设计模式-建造者模式
  • 备份还原打印机驱动
  • Linux【4】------RK3568启动和引导顺序
  • grep、wc 与管道符快速上手指南
  • 10.Linux进程信号