当前位置: 首页 > news >正文

信号处理学习——文献精读与code复现之TFN——嵌入时频变换的可解释神经网络(下)

书接上文:

信号处理学习——文献精读与code复现之TFN——嵌入时频变换的可解释神经网络(上)-CSDN博客

接下来是重要的代码复现!!!GitHub - ChenQian0618/TFN: this is the open code of paper entitled "TFN: An Interpretable Neural Network With Time Frequency Transform Embedded for Intelligent Fault Diagnosis".



一. 准备工作

因为我的论文中所使用的数据的样本量是2048,而不是TFN文献中的1024,所以有些地方需要调整一下。

先看TFN-main\Models\BackboneCNN.py中的代码,查看是否需要调整。(不需要)

再看TFN-main\Models\TFconvlayer.py中的代码,也同样的(不需要)。

具体可去github上查看作者大大们的完整代码。

模块是否依赖输入长度?原因
TFconv_* 类中的 forward❌ 不依赖它接受的输入是 [B, C, L],不限制 L(你的2048没问题)
weightforward()❌ 不依赖只与 kernel_sizesuperparams 有关,与你输入的信号长度无关
AdaptiveMaxPool1d in CNN❌ 不依赖输入长度自动调整为固定输出维度,兼容任意长度
T = torch.arange(...)❌ 但与 kernel_size 相关这个是内部卷积核构造,与输入数据 2048 无关

二. 设置

参考文献中的部分设置,分类损失函数采用交叉熵,训练优化器选用Adam,动量参数设为0.9,初始学习率为0.001,总训练周期为50次。总共重复10次平均实验。

三.Code

1.数据划分

这里用到的数据集是比CWRU要稍微难识别一些的变速轴承信号数据集,加拿大渥太华数据集。

import scipy.io as sio
import numpy as np
import random
import os# 定义基础路径
base_path = 'D:/0A_Gotoyourdream/00_BOSS_WHQ/A_Code/A_Data/'# 定义各类别对应的mat文件
file_mapping = {'H': 'H-B-1.mat','I': 'I-B-1.mat','B': 'B-B-1.mat','O': 'O-B-2.mat','C': 'C-B-2.mat'
}# 定义每个类别需要抽取的数量
sample_limit = {'H': 200,'I': 200,'B': 200,'O': 200,'C': 200
}# 保存最终数据
X_list = []
y_list = []# 固定参数
fs = 200000
window_size = 2048
step_size = int(fs * 0.015)  # 步长 0.015秒# 类别编码
#label_mapping = {'H': 0, 'I': 1, 'B': 3, 'O': 2, 'C': 4}  # 注意和你之前保持一致label_mapping = {'H': 0, 'I': 1, 'O': 2, 'B': 3, 'C': 4}
inverse_mapping = {v: k for k, v in label_mapping.items()}
labels = [inverse_mapping[i] for i in range(len(inverse_mapping))]
# 再替换这些缩写为全名
label_fullnames = {'H': 'Health','I': 'F_Inner','O': 'F_Outer','B': 'F_Ball','C': 'F_Combined'
}
labels = [label_fullnames[c] for c in labels]# 创建保存目录(可选)
output_dir = os.path.join(base_path, "ClassBD-Processed_Samples")
os.makedirs(output_dir, exist_ok=True)# 遍历每一类数据
for label_name, file_name in file_mapping.items():print(f"正在处理类别 {label_name}...")mat_path = os.path.join(base_path, file_name)dataset = sio.loadmat(mat_path)# 提取振动信号并去直流分量vib_data = np.array(dataset["Channel_1"].flatten().tolist()[:fs * 10])vib_data = vib_data - np.mean(vib_data)# 滑窗切分样本vib_samples = []start = 0while start + window_size <= len(vib_data):sample = vib_data[start:start + window_size].astype(np.float32)  # 降低内存占用vib_samples.append(sample)start += step_sizevib_samples = np.array(vib_samples)print(f"共切分得到 {vib_samples.shape[0]} 个样本")# 抽样if vib_samples.shape[0] < sample_limit[label_name]:raise ValueError(f"类别 {label_name} 样本不足(仅 {vib_samples.shape[0]}),无法抽取 {sample_limit[label_name]} 个")selected_indices = random.sample(range(vib_samples.shape[0]), sample_limit[label_name])selected_X = vib_samples[selected_indices]selected_y = np.full(sample_limit[label_name], label_mapping[label_name], dtype=np.int64)# 保存save_path_X = os.path.join(output_dir, f"X_{label_name}.mat")save_path_y = os.path.join(output_dir, f"y_{label_name}.mat")sio.savemat(save_path_X, {'X': selected_X})sio.savemat(save_path_y, {'y': selected_y})print(f"已保存类别 {label_name} 的数据:{save_path_X}, {save_path_y}")

2. 存储为dataloder

import os
import scipy.io as sio
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader# ========== 1. 读取四类数据 ==========
base_path = "D:/0A_Gotoyourdream/00_BOSS_WHQ/A_Code/A_Data/ClassBD-Processed_Samples"def load_data(label):X = sio.loadmat(os.path.join(base_path, f"X_{label}.mat"))["X"]y = sio.loadmat(os.path.join(base_path, f"y_{label}.mat"))["y"].flatten()return X.astype(np.float32), y.astype(np.int64)X_H, y_H = load_data("H")
X_I, y_I = load_data("I")
X_B, y_B = load_data("B")
X_O, y_O = load_data("O")
X_C, y_C = load_data("C")# ========== 2. 合并数据 + reshape ==========
X_all = np.concatenate([X_H, X_I, X_B, X_O, X_C], axis=0)
y_all = np.concatenate([y_H, y_I, y_B, y_O, y_C], axis=0)
X_all = X_all[:, np.newaxis, :]  # (N, 1, 200000)# ========== 3. 划分训练/测试集 ==========
X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=0.4, stratify=y_all, random_state=42)# ========== 4. DataLoader ==========
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

3. 定义模型及一些设置

需要注意的部分在代码的注释中有写

from Models.TFN import TFN_STTF  # 你也可以换成 TFN_Chirplet、TFN_Morlet
model = TFN_STTF(in_channels=1, out_channels=5, kernel_size=15)  # out_channels = 类别数device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

4. 训练及测试

# 开始训练
for epoch in range(1, 51):model.train()running_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# TFN模型支持返回多个输出(output, _, _)outputs = model(inputs)if isinstance(outputs, tuple):outputs = outputs[0]loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()#scheduler.step()train_acc = correct / total * 100# 测试集评估model.eval()correct_test = 0total_test = 0with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)if isinstance(outputs, tuple):outputs = outputs[0]_, predicted = outputs.max(1)total_test += labels.size(0)correct_test += predicted.eq(labels).sum().item()test_acc = correct_test / total_test * 100print(f"Epoch {epoch:03d}: Loss={running_loss:.4f}, Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

四. 结果

代码复现成功!!!

接着后面就是拿来做对比实验啦~~~

(感恩大佬们提供github代码!!!)

 

 

http://www.lqws.cn/news/558325.html

相关文章:

  • 给定一个整型矩阵map,求最大的矩形区域为1的数量
  • Insar 相位展开真实的数据集的生成与下载(随机矩阵放大,zernike 仿真包裹相位)
  • Launcher3中的CellLayout 和ShortcutAndWidgetContainer 的联系和各自职责
  • 剑指offer50_0到n-1中缺失的数字
  • python -日期与天数的转换
  • autoas/as 工程的RTE静态消息总线实现与端口数据交换机制详解
  • 解决flash-attn安装报错的问题
  • 【C】陷波滤波器
  • 鸿蒙开发:资讯项目实战之底部导航封装
  • MySQL之MVCC实现原理深度解析
  • 类和对象(中)
  • springboot+Vue驾校管理系统
  • 开疆智能ModbusTCP转CClinkIE网关连接台达DVP-ES3 PLC配置案例
  • Java-正则表达式
  • 测量 Linux 中进程上下文切换需要的时间
  • cocos creator 3.8 - 精品源码 - 挪车超人(挪车消消乐)
  • 同步日志系统深度解析【链式调用】【宏定义】【固定缓冲区】【线程局部存储】【RAII】
  • 蚂蚁百宝箱体验:如何快速创建“旅游小助手”AI智能体
  • LINUX628 NFS 多web;主从dns;ntp;samba
  • AlphaGenome:基因组学领域的人工智能革命
  • Linux离线搭建Redis (centos7)详细操作步骤
  • 深入解析 Electron 核心模块:构建跨平台桌面应用的关键
  • 《Go语言高级编程》玩转RPC
  • Vue.js 中的 v-model 和 :value:理解父子组件的数据绑定
  • 网络 : 传输层【UDP协议】
  • (线性代数)矩阵的奇异值Singular Value
  • WPS之PPT镂空效果实现
  • 笔记07:网表的输出与导入
  • spring中maven缺少包如何重新加载,报错java: 程序包org.springframework.web.reactive.function不存在
  • FPGA产品