当前位置: 首页 > news >正文

Day.46

通道注意力机制:

class ChannelAttention(nn.Module):

    def __init__(self, in_channels, reduction_ratio=16):

            in_channels: 输入特征图的通道数

            reduction_ratio: 降维比例,用于减少参数量

        super(ChannelAttention, self).__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

            self.fc = nn.Sequential(

            nn.Linear(in_channels, in_channels // reduction_ratio, bias=False),

            nn.ReLU(inplace=True),

            nn.Linear(in_channels // reduction_ratio, in_channels, bias=False),

            nn.Sigmoid()

        )

    def forward(self, x):

        """

        参数:

            x: 输入特征图,形状为 [batch_size, channels, height, width]

       

        返回:

            加权后的特征图,形状不变

        """

        batch_size, channels, height, width = x.size()

        avg_pool_output = self.avg_pool(x)

        avg_pool_output = avg_pool_output.view(batch_size, channels)

        channel_weights = self.fc(avg_pool_output)

        channel_weights = channel_weights.view(batch_size, channels, 1, 1)

        return x * channel_weights

模型定义:

class CNN(nn.Module):

    def __init__(self):

        super(CNN, self).__init__()  

       

        # ---------------------- 第一个卷积块 ----------------------

        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)

        self.bn1 = nn.BatchNorm2d(32)

        self.relu1 = nn.ReLU()

        self.ca1 = ChannelAttention(in_channels=32, reduction_ratio=16)  

        self.pool1 = nn.MaxPool2d(2, 2)  

       

        # ---------------------- 第二个卷积块 ----------------------

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)

        self.bn2 = nn.BatchNorm2d(64)

        self.relu2 = nn.ReLU()

        self.ca2 = ChannelAttention(in_channels=64, reduction_ratio=16)  

        self.pool2 = nn.MaxPool2d(2)  

       

        # ---------------------- 第三个卷积块 ----------------------

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)

        self.bn3 = nn.BatchNorm2d(128)

        self.relu3 = nn.ReLU()

        self.ca3 = ChannelAttention(in_channels=128, reduction_ratio=16)  

        self.pool3 = nn.MaxPool2d(2)  

       

        # ---------------------- 全连接层(分类器) ----------------------

        self.fc1 = nn.Linear(128 * 4 * 4, 512)

        self.dropout = nn.Dropout(p=0.5)

        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):

        # ---------- 卷积块1处理 ----------

        x = self.conv1(x)      

        x = self.bn1(x)        

        x = self.relu1(x)      

        x = self.ca1(x)  

        x = self.pool1(x)      

       

        # ---------- 卷积块2处理 ----------

        x = self.conv2(x)      

        x = self.bn2(x)        

        x = self.relu2(x)      

        x = self.ca2(x)  

        x = self.pool2(x)      

       

        # ---------- 卷积块3处理 ----------

        x = self.conv3(x)      

        x = self.bn3(x)        

        x = self.relu3(x)      

        x = self.ca3(x)  

        x = self.pool3(x)      

       

        # ---------- 展平与全连接层 ----------

        x = x.view(-1, 128 * 4 * 4)  

        x = self.fc1(x)          

        x = self.relu3(x)        

        x = self.dropout(x)      

        x = self.fc2(x)          

       

        return x  

# 重新初始化模型,包含通道注意力模块

model = CNN()

model = model.to(device)  

criterion = nn.CrossEntropyLoss()  

optimizer = optim.Adam(model.parameters(), lr=0.001)  

scheduler.step()

scheduler = optim.lr_scheduler.ReduceLROnPlateau(

    optimizer,        

    mode='min',     

    patience=3,      

    factor=0.5        @浙大疏锦行

)

http://www.lqws.cn/news/571105.html

相关文章:

  • 水果维生素含量排名详表
  • 【硬核数学】9. 驯服“梯度下降”:深度学习中的优化艺术与正则化技巧《从零构建机器学习、深度学习到LLM的数学认知》
  • 【JavaSE】反射学习笔记
  • 中州养老:学会设计数据库表
  • WebRTC(十三):信令服务器
  • Spring事件驱动模型核心:ApplicationEventMulticaster初始化全解析
  • 图书管理系统练习项目源码-前后端分离-使用node.js来做后端开发
  • NV064NV065美光固态闪存NV067NV076
  • 申论审题训练
  • DEPTHPRO:一秒内实现清晰的单目度量深度估计
  • 云端可视化耦合电磁场:麦克斯韦方程组的应用-AI云计算数值分析和代码验证
  • Leetcode百题斩-双指针
  • 电容屏触摸不灵敏及跳点问题分析
  • PyEcharts教程(010):天猫订单数据可视化项目
  • ISP Pipeline(9):Noise Filter for Chroma 色度去噪
  • H3C-路由器DHCPV6V4配置标准
  • 如何通过自动化减少重复性工作
  • GitHub vs GitLab 全面对比报告(2025版)
  • Java面试宝典:基础三
  • Vue中keep-alive结合router实现部分页面缓存
  • Spring生态创新应用
  • 【Redis#4】Redis 数据结构 -- String类型
  • 用户行为序列建模(篇七)-【阿里】DIN
  • AlphaFold3安装报错
  • 【系统分析师】2021年真题:论文及解题思路
  • GitLab详细分析
  • ​19.自动补全功能
  • 机器学习7——神经网络上
  • SpringCloud系列(40)--SpringCloud Gateway的Filter的简介及使用
  • 基于YOLO的目标检测图形界面应用(适配于YOLOv5、YOLOv6、YOLOv8、YOLOv9、YOLOv10、YOLOv11、YOLOv12)