当前位置: 首页 > news >正文

P27:RNN实现阿尔茨海默病诊断

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、过程解读

PyTorch 实战:阿尔茨海默病数据预测模型

今天,我将带大家一起探索一个基于 PyTorch 的深度学习小项目——利用 RNN 模型对阿尔茨海默病数据进行预测。这个实例不仅涵盖了数据预处理、模型构建、训练与保存,还包含了如何在其他代码中调用保存的模型进行预测。通过这个完整的流程,我们可以学到很多实用技能,下面让我们一一进行详细解读。

数据准备与预处理:奠定模型基础

数据是深度学习模型的基石,因此数据的准备和预处理步骤至关重要。在这个实例中,我们使用 pandas 读取 CSV 格式的阿尔茨海默病数据,并进行了一些简单的预处理操作,比如删除第一列和最后一列,以及将特征数据标准化为标准正态分布。通过这些操作,我们确保了输入模型的数据质量和一致性,为模型的训练打下了良好的基础。

模型构建:搭建 RNN 网络

接下来是构建模型的核心环节。在这个实例中,我们设计了一个简单的循环神经网络(RNN)模型,它由一个 RNN 层和两个全连接层组成。RNN 层能够处理序列数据,这对于时间序列预测等任务非常有用。全连接层则用于将 RNN 的输出映射到最终的分类结果。通过定义模型的结构,我们不仅学习了如何使用 PyTorch 构建自定义神经网络,还了解了 RNN 和全连接层的基本原理和应用场景。

模型训练:优化模型参数

模型构建完成后,我们进入训练阶段。在这个实例中,我们定义了训练函数和测试函数,分别用于计算模型在训练集和测试集上的损失和准确率,并更新模型参数。通过设置损失函数(如交叉熵损失)和优化器(如 Adam 优化器),我们能够有效地优化模型的性能。此外,我们还学习了如何使用数据加载器来批量加载数据,以及如何在训练过程中记录和输出模型的训练进度和指标,这些技巧对于监控和调整模型的训练过程非常有帮助。

模型保存与加载:实现模型的持久化和复用

当模型训练完成后,我们将其保存到本地文件中。这样做的好处是可以避免每次使用模型时都要重新训练,节省了大量的时间和计算资源。在实例中,我们使用 PyTorch 提供的 torch.savetorch.load 函数来保存和加载模型的参数。通过这种方式,我们可以在其他代码中轻松地加载模型,并直接使用它进行预测,而无需关心模型的训练过程。

模型调用与预测:将模型付诸实践

最后,我们在其他代码中加载了之前保存的模型,并使用它对新的数据进行预测。在这个过程中,我们学习了如何将输入数据转换为模型所需的格式,以及如何调用模型进行预测并获取结果。此外,我们还通过打印模型的预测结果来验证模型的性能,并使用混淆矩阵来评估模型的分类效果。这些步骤展示了如何将深度学习模型应用到实际问题中,并为后续的模型优化和改进提供了依据。
当然!循环神经网络(Recurrent Neural Network, RNN)是一种常用于处理序列数据的神经网络架构。与传统的前馈神经网络(Feedforward Neural Network)不同,RNN 具有“记忆”功能,能够处理长度可变的输入序列,并通过内部状态捕获序列中的时间依赖关系。下面将对 RNN 的基本原理、结构和应用场景进行详细讲解。

RNN网络的基本原理

  1. 序列数据的处理

    • RNN 适用于处理序列数据,如时间序列、文本序列、语音序列等。序列数据具有时间上的依赖关系,后续数据点与前面的数据点相关。RNN 通过循环结构,将前面时间步的隐藏状态传递到当前时间步,实现对序列信息的累积和记忆。
  2. 循环结构

    • RNN 的核心是循环(recurrence)结构,它允许信息在神经网络中沿时间步进行传递。基本的 RNN 单元在每个时间步接收两个输入:当前时间步的输入 ( x_t ) 和前一时间步的隐藏状态 ( h_{t-1} ),然后计算当前时间步的隐藏状态 ( h_t ) 和输出 ( o_t )。
  3. 状态更新公式

    • RNN 的状态更新通常使用以下公式:

      其中:

      • ( x_t ) 是当前时间步的输入。
      • ( h_{t-1} ) 是前一时间步的隐藏状态。
      • ( W_{xh} ) 和 ( W_{hh} ) 分别是输入到隐藏层和隐藏层到隐藏层的权重矩阵。
      • ( b_h ) 是隐藏层的偏置项。
      • ( \sigma ) 是激活函数,如 tanh 或 ReLU。
      • ( o_t ) 是当前时间步的输出,( W_{ho} ) 和 ( b_o ) 是隐藏层到输出层的权重矩阵和偏置项。

RNN网络的基本结构

  1. 单向 RNN

    • 单向 RNN 只能利用过去的上下文信息来预测当前时间步的输出。它的信息流动是单向的,从过去到未来。
  2. 双向 RNN

    • 双向 RNN 可以同时利用过去的上下文信息和未来的上下文信息来预测当前时间步的输出。它包含两个隐藏层,一个处理正向的时间序列,另一个处理反向的时间序列。
  3. 多层 RNN

    • 多层 RNN 是将多个 RNN 层堆叠在一起,每个 RNN 层的输出作为下一层的输入。这种结构可以学习到更复杂的特征表示。

RNN网络的训练

  1. 反向传播

    • RNN 的训练通常使用反向传播算法,称为“随时间反向传播”(Backpropagation Through Time, BPTT)。它将 RNN 展开为一个沿时间步的计算图,然后对每个时间步的损失进行反向传播,以更新网络的参数。
  2. 梯度消失与梯度爆炸问题

    • 在训练 RNN 时,可能会遇到梯度消失或梯度爆炸的问题。这是因为长序列中的梯度在反向传播时会不断乘以权重矩阵的导数,导致梯度变得非常小或非常大。为了解决这些问题,可以使用梯度裁剪(Gradient Clipping)、LSTM(Long Short-Term Memory)或 GRU(Gated Recurrent Unit)等改进的 RNN 变体。

RNN网络的应用场景

  1. 序列预测

    • 预测时间序列的未来值,如股票价格预测、天气预测等。
  2. 自然语言处理

    • 在文本生成、机器翻译、情感分析等任务中,RNN 能够捕获文本中的上下文信息。
  3. 语音识别

    • 将语音信号转换为文字,RNN 可以处理语音信号的时间序列特征。
  4. 生物信息学

    • 分析 DNA 序列、蛋白质序列等生物数据。

RNN网络的变体

  1. 长短期记忆网络(LSTM)

    • LSTM 是一种改进的 RNN 变体,能够更好地处理长序列数据。它引入了“记忆单元”和多个“门”(输入门、遗忘门和输出门),可以控制信息的流动,从而缓解梯度消失问题。
  2. 门控循环单元(GRU)

    • GRU 是另一种改进的 RNN 变体,结构相对 LSTM 更简单。它将遗忘门和输入门合并为一个“更新门”,并移除了记忆单元,直接使用隐藏状态来存储长期信息。
  3. 深度 RNN

    • 深度 RNN 是将多个 RNN 层堆叠在一起,形成一个多层的 RNN 架构。每一层的输出作为下一层的输入,可以学习到更复杂的特征表示。

二、代码实现

1.导入库函数

import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import warnings
from datetime import datetime

2.导入数据

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
df = pd.read_csv("./data/alzheimers_disease_data.csv")
# 删除第一列和最后一列
df = df.iloc[:, 1:-1]

3.标准化

X = df.iloc[:, :-1]
y = df.iloc[:, -1]# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
sc = StandardScaler()
X = sc.fit_transform(X)

4.数据集构建

X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.1,random_state=1)train_dl = DataLoader(TensorDataset(X_train, y_train),batch_size=64,shuffle=False)test_dl = DataLoader(TensorDataset(X_test, y_test),batch_size=64,shuffle=False)

5.模型构建

class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200,num_layers=1, batch_first=True)self.fc0 = nn.Linear(200, 50)self.fc1 = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
model_rnn((rnn0): RNN(32, 200, batch_first=True)(fc0): Linear(in_features=200, out_features=50, bias=True)(fc1): Linear(in_features=50, out_features=2, bias=True)
)

6.构建测试训练

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, train_acc = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

7. 构建训练函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, test_acc = 0, 0with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)target_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

8.训练模型并保存

loss_fn = nn.CrossEntropyLoss()
learn_rate = 5e-5
opt = torch.optim.Adam(model.parameters(), lr=learn_rate)
epochs = 50train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)
#保存模型
torch.save(model.state_dict(), "./model_rnn.pth")  # 保存模型参数
print("模型已保存到 ./model_rnn.pth")
Epoch: 1, Train_acc:62.9%, Train_loss:0.673, Test_acc:71.6%, Test_loss:0.655, Lr:5.00E-05
Epoch: 2, Train_acc:70.1%, Train_loss:0.644, Test_acc:71.2%, Test_loss:0.629, Lr:5.00E-05
Epoch: 3, Train_acc:69.7%, Train_loss:0.617, Test_acc:67.9%, Test_loss:0.603, Lr:5.00E-05
Epoch: 4, Train_acc:67.6%, Train_loss:0.593, Test_acc:66.5%, Test_loss:0.584, Lr:5.00E-05
Epoch: 5, Train_acc:67.6%, Train_loss:0.574, Test_acc:67.9%, Test_loss:0.570, Lr:5.00E-05
Epoch: 6, Train_acc:69.9%, Train_loss:0.555, Test_acc:68.8%, Test_loss:0.556, Lr:5.00E-05
Epoch: 7, Train_acc:73.0%, Train_loss:0.537, Test_acc:70.7%, Test_loss:0.542, Lr:5.00E-05
Epoch: 8, Train_acc:75.3%, Train_loss:0.518, Test_acc:73.0%, Test_loss:0.527, Lr:5.00E-05
Epoch: 9, Train_acc:77.7%, Train_loss:0.498, Test_acc:74.9%, Test_loss:0.513, Lr:5.00E-05
Epoch:10, Train_acc:79.7%, Train_loss:0.479, Test_acc:77.2%, Test_loss:0.499, Lr:5.00E-05
Epoch:11, Train_acc:80.9%, Train_loss:0.461, Test_acc:77.7%, Test_loss:0.486, Lr:5.00E-05
Epoch:12, Train_acc:81.8%, Train_loss:0.444, Test_acc:78.6%, Test_loss:0.473, Lr:5.00E-05
Epoch:13, Train_acc:82.6%, Train_loss:0.428, Test_acc:79.1%, Test_loss:0.462, Lr:5.00E-05
Epoch:14, Train_acc:82.9%, Train_loss:0.414, Test_acc:78.1%, Test_loss:0.452, Lr:5.00E-05
Epoch:15, Train_acc:83.4%, Train_loss:0.401, Test_acc:79.1%, Test_loss:0.444, Lr:5.00E-05
Epoch:16, Train_acc:83.7%, Train_loss:0.390, Test_acc:78.6%, Test_loss:0.436, Lr:5.00E-05
Epoch:17, Train_acc:84.1%, Train_loss:0.380, Test_acc:79.5%, Test_loss:0.430, Lr:5.00E-05
Epoch:18, Train_acc:84.9%, Train_loss:0.372, Test_acc:80.0%, Test_loss:0.425, Lr:5.00E-05
Epoch:19, Train_acc:85.3%, Train_loss:0.364, Test_acc:80.0%, Test_loss:0.420, Lr:5.00E-05
Epoch:20, Train_acc:85.6%, Train_loss:0.358, Test_acc:79.1%, Test_loss:0.417, Lr:5.00E-05
Epoch:21, Train_acc:85.9%, Train_loss:0.352, Test_acc:79.1%, Test_loss:0.414, Lr:5.00E-05
Epoch:22, Train_acc:85.8%, Train_loss:0.347, Test_acc:79.5%, Test_loss:0.412, Lr:5.00E-05
Epoch:23, Train_acc:86.0%, Train_loss:0.343, Test_acc:78.6%, Test_loss:0.410, Lr:5.00E-05
Epoch:24, Train_acc:86.3%, Train_loss:0.339, Test_acc:78.1%, Test_loss:0.409, Lr:5.00E-05
Epoch:25, Train_acc:86.7%, Train_loss:0.335, Test_acc:78.6%, Test_loss:0.408, Lr:5.00E-05
Epoch:26, Train_acc:86.7%, Train_loss:0.332, Test_acc:77.7%, Test_loss:0.408, Lr:5.00E-05
Epoch:27, Train_acc:86.8%, Train_loss:0.329, Test_acc:77.2%, Test_loss:0.408, Lr:5.00E-05
Epoch:28, Train_acc:86.8%, Train_loss:0.327, Test_acc:77.2%, Test_loss:0.408, Lr:5.00E-05
Epoch:29, Train_acc:86.8%, Train_loss:0.324, Test_acc:77.2%, Test_loss:0.408, Lr:5.00E-05
Epoch:30, Train_acc:87.0%, Train_loss:0.322, Test_acc:76.7%, Test_loss:0.409, Lr:5.00E-05
Epoch:31, Train_acc:87.2%, Train_loss:0.320, Test_acc:76.3%, Test_loss:0.409, Lr:5.00E-05
Epoch:32, Train_acc:87.3%, Train_loss:0.318, Test_acc:75.8%, Test_loss:0.410, Lr:5.00E-05
Epoch:33, Train_acc:87.7%, Train_loss:0.316, Test_acc:75.8%, Test_loss:0.411, Lr:5.00E-05
Epoch:34, Train_acc:87.7%, Train_loss:0.314, Test_acc:75.8%, Test_loss:0.412, Lr:5.00E-05
Epoch:35, Train_acc:88.0%, Train_loss:0.312, Test_acc:75.8%, Test_loss:0.413, Lr:5.00E-05
Epoch:36, Train_acc:88.1%, Train_loss:0.310, Test_acc:75.3%, Test_loss:0.414, Lr:5.00E-05
Epoch:37, Train_acc:88.3%, Train_loss:0.309, Test_acc:76.3%, Test_loss:0.416, Lr:5.00E-05
Epoch:38, Train_acc:88.4%, Train_loss:0.307, Test_acc:76.3%, Test_loss:0.417, Lr:5.00E-05
Epoch:39, Train_acc:88.3%, Train_loss:0.305, Test_acc:76.3%, Test_loss:0.418, Lr:5.00E-05
Epoch:40, Train_acc:88.3%, Train_loss:0.304, Test_acc:76.3%, Test_loss:0.420, Lr:5.00E-05
Epoch:41, Train_acc:88.4%, Train_loss:0.302, Test_acc:76.7%, Test_loss:0.421, Lr:5.00E-05
Epoch:42, Train_acc:88.4%, Train_loss:0.301, Test_acc:77.7%, Test_loss:0.423, Lr:5.00E-05
Epoch:43, Train_acc:88.4%, Train_loss:0.299, Test_acc:77.7%, Test_loss:0.425, Lr:5.00E-05
Epoch:44, Train_acc:88.5%, Train_loss:0.297, Test_acc:77.7%, Test_loss:0.426, Lr:5.00E-05
Epoch:45, Train_acc:88.6%, Train_loss:0.296, Test_acc:77.7%, Test_loss:0.428, Lr:5.00E-05
Epoch:46, Train_acc:88.8%, Train_loss:0.294, Test_acc:78.1%, Test_loss:0.430, Lr:5.00E-05
Epoch:47, Train_acc:88.9%, Train_loss:0.293, Test_acc:78.6%, Test_loss:0.432, Lr:5.00E-05
Epoch:48, Train_acc:88.9%, Train_loss:0.291, Test_acc:78.6%, Test_loss:0.435, Lr:5.00E-05
Epoch:49, Train_acc:89.0%, Train_loss:0.290, Test_acc:78.1%, Test_loss:0.437, Lr:5.00E-05
Epoch:50, Train_acc:89.0%, Train_loss:0.288, Test_acc:78.6%, Test_loss:0.439, Lr:5.00E-05
==================== Done ====================
模型已保存到 ./model_rnn.pth

9.模型评估

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 200current_time = datetime.now()epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

10.混淆矩阵

print("===============输入数据Shape为===============")
print("X_test.shape: ", X_test.shape)
print("y_test.shape: ", y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print("\n===============输出数据Shape为===============")
print("pred.shape: ", pred.shape)
===============输入数据Shape为===============
X_test.shape:  torch.Size([215, 32])
y_test.shape:  torch.Size([215])===============输出数据Shape为===============
pred.shape:  (215,)
cm = confusion_matrix(y_test, pred)
plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("Confusion Matrix", fontsize=12)
plt.xlabel("Predicted Label", fontsize=10)
plt.ylabel("True Label", fontsize=10)plt.tight_layout()
plt.show()

在这里插入图片描述

11.调用模型进行预测

import torch
from torch import nn
import numpy as np
import pandas as pd
from datetime import datetime
import P27_阿尔兹海默症 as ar# 设置设备(GPU或CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型结构
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200,num_layers=1, batch_first=True)self.fc0 = nn.Linear(200, 50)self.fc1 = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return out# 加载模型
model = model_rnn().to(device)
model.load_state_dict(torch.load("./model_rnn.pth"))  # 加载保存的模型参数
model.eval()  # 设置为评估模式# 假设X_test是测试数据,已经转换为torch.Tensor
# X_test = ...  # 实际数据加载代码# 对单个样本进行预测
test_X = ar.X_test[0].reshape(1, -1).to(device)  # X_test[0]即我们的输入数据
pred = model(test_X).argmax(1).item()
print("模型预测结果为:", pred)
print("=="*20)
print("0: 未患病")
print("1: 已患病")
模型预测结果为: 0
========================================
0: 未患病
1: 已患病
http://www.lqws.cn/news/550207.html

相关文章:

  • 华为云Flexus+DeepSeek征文|基于Dify+ModelArts开发AI智能会议助手
  • 本地部署 WordPress 博客完整指南(基于 XAMPP)
  • nt!MiFlushSectionInternal函数分析从nt!IoSynchronousPageWrite函数到Ntfs!NtfsFsdWrite函数
  • 三阶落地:腾讯云Serverless+Spring Cloud的微服务实战架构
  • React中的ErrorBoundary
  • 【经验】新版Chrome中Proxy SwitchyOmega2已实效,改为ZeroOmega
  • 车载诊断架构 --- 诊断与ECU平台工作说明书
  • SQL Server for Linux 如何实现高可用架构
  • 【策划所需编程知识】
  • 中国双非高校经费TOP榜数据分析
  • 【记录】Ubuntu|Ubuntu服务器挂载新的硬盘的流程(开机自动挂载)
  • SQL学习笔记4
  • MFC获取本机所有IP、局域网所有IP、本机和局域网可连接IP
  • 一起endpoint迷路的问题排查总结
  • 浅谈Apache HttpClient的相关配置和使用
  • git add 报错UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0xaf in position 42
  • SOCKS 协议版本 5 (RFC 1928)
  • 【stm32】HAL库开发——CubeMX配置串口通讯(中断方式)
  • VUE使用过程中的碰到问题记录
  • 自动对焦技术助力TGV检测 半导体检测精度大突破
  • 工作台-02.代码开发
  • Linux信号机制:从入门到精通
  • [Python]-基础篇1- 从零开始的Python入门指南
  • 微调大语言模型(生成任务),怎么评估它到底“变好”了?
  • Python网安-zip文件暴力破解
  • Java:链接mysql数据库报错:CommunicationsException: Communications link failure
  • Coze API如何上传文件能得到文件的file_url
  • 缓解停车难的城市密码:4G地磁检测器如何重构车位资源分配
  • Discrete Audio Tokens: More Than a Survey
  • TensorRT-LLM的深度剖析:关键问题与核心局限性