当前位置: 首页 > news >正文

Python Day42

Task: Grad-CAM与Hook函数
1.回调函数
2.lambda函数
3.hook函数的模块钩子和张量钩子
4.Grad-CAM的示例

1. 回调函数

  • 定义:回调函数是作为参数传入到其他函数中的函数,在特定事件发生时被调用。
  • 特点
    • 便于扩展和自定义程序行为。
    • 常用于训练过程中的监控、日志记录、模型保存等场景。
  • 示例
    def callback_function():print("Epoch completed!")# 传递回调函数
    trainer.train(callback=callback_function)
    

2. lambda函数

  • 定义:匿名函数,也叫lambda表达式,用于快速定义简单函数。
  • 语法
    lambda 参数: 表达式
    
  • 示例
    square = lambda x: x ** 2
    print(square(5))  # 输出25
    
  • 用途:简洁地定义短小函数,常用于回调、数据处理等场景。

3. Hook函数的模块钩子和张量钩子

  • Hook函数的作用:用来“钩住”模型内部的中间变量或梯度,从而监控或修改它们。

  • 模块钩子(Module Hooks)

    • 用途:挂载在模型的子模块(如卷积层、线性层)上,捕获或修改输出或梯度。
    • 注册方法
      handle = layer.register_forward_hook(hook_fn)
      
    • 例子
      def hook_fn(module, input, output):print(output.shape)handle = model.layer1[0].register_forward_hook(hook_fn)
      
    • 移除钩子
      handle.remove()
      
  • 张量钩子(Tensor Hooks)

    • 用途:挂载在特定的Tensor上,用于捕获梯度信息。
    • 注册方法
      tensor.register_hook(hook_fn)
      
    • 应用场景:常用于Grad-CAM中,捕获目标层的梯度信息。

4. Grad-CAM的示例

  • 目标:通过反向传播得到特定类别的梯度,结合层的激活值,生成热力图,直观展现模型关注区域。

  • 基本流程

    1. 前向传播,获取目标类别的输出。
    2. 反向传播,计算目标类别对应的梯度。
    3. 使用Hook捕获目标层的激活特征(前向输出)和梯度(反向传播梯度)。
    4. 计算加权和(平均梯度作为权重)以生成热力图。
    5. 叠加热力图到原始图像上。
  • 代码示例(简化版)

    import torch
    import torch.nn.functional as F
    import torchvision.models as models
    import cv2
    import numpy as npmodel = models.resnet50(pretrained=True)
    model.eval()# 定义目标层(例如最后的卷积层)
    target_layer = model.layer4[-1]activations = None
    gradients = None# 定义前向钩子
    def forward_hook(module, input, output):global activationsactivations = output# 定义反向钩子
    def backward_hook(module, grad_input, grad_output):global gradientsgradients = grad_output[0]# 注册钩子
    handle_fw = target_layer.register_forward_hook(forward_hook)
    handle_bw = target_layer.register_backward_hook(backward_hook)# 输入图片(预处理)
    input_img = preprocess_image('path_to_image.jpg')# 前向传播
    output = model(input_img)# 目标类别(例如类别索引)
    class_idx = torch.argmax(output)
    # 反向传播
    model.zero_grad()
    one_hot = torch.zeros_like(output)
    one_hot[0][class_idx] = 1
    output.backward(gradient=one_hot)# 获取权重
    weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
    # 计算Grad-CAM
    cam = torch.sum(weights * activations, dim=1).squeeze()# 后处理,绘制热力图
    heatmap = cam.detach().cpu().numpy()
    heatmap = np.maximum(heatmap, 0)
    heatmap /= np.max(heatmap)# 叠加到原图
    # ...# 移除钩子
    handle_fw.remove()
    handle_bw.remove()
    

备注

  • Hook函数在实现Grad-CAM中非常关键,能够方便地捕获中间层的特征与梯度。
  • 使用正确的钩子注册和移除非常重要,避免内存泄漏。
  • Grad-CAM适用于模型可解释性,帮助理解模型做出决策的依据。
http://www.lqws.cn/news/106075.html

相关文章:

  • xmake的简易学习
  • 一、无参数的函数调用- RSP,EAX寄存器,全局变量,INT类型和MOV,INC,SHL指令
  • Python中os模块详解
  • Spring Boot 自动配置原理:从入门到精通
  • webstrom中git插件勾选提交部分文件时却出现提交全部问题怎么解决
  • UGUI Text/TextMeshPro字体组件
  • Activity
  • 6.3本日总结
  • agent mode 代理模式,整体要求,系统要求, 系统指令
  • ABP-Book Store Application中文讲解 - Part 7: Authors: Database Integration
  • 『uniapp』把接口的内容下载为txt本地保存 / 读取本地保存的txt文件内容(详细图文注释)
  • WPS word 已有多级列表序号
  • 免费批量文件重命名软件
  • AI健康小屋+微高压氧舱:科技如何重构我们的健康防线?
  • KITTI数据集(计算机视觉和自动驾驶领域)
  • mobilnet v4 部署笔记
  • go语言基础|slice入门
  • C语言学习—数据类型20250603
  • 2025.6.3总结
  • Jpom:Java开发者的一站式自动化运维平台详解
  • Java编程之建造者模式
  • 深度学习入门Day2--鱼书学习(1)
  • 【Typst】4.导入、包含和读取
  • Spring AI Alibaba + Nacos 动态 MCP Server 代理方案
  • Playwright定位器详解:自动化测试的核心工具
  • 集合类基础概念
  • 2023年12月四级真题作文的分析总结
  • 704. 二分查找 (力扣)
  • 十五、【测试执行篇】异步与并发:使用 Celery 实现测试任务的后台执行与结果回调
  • GaLore:基于梯度低秩投影的大语言模型高效训练方法详解一