当前位置: 首页 > news >正文

从代码学习深度学习 - 情感分析:使用循环神经网络 PyTorch版

文章目录

  • 前言
  • 1. 加载与预处理数据集
    • 数据读取与词元化
    • 构建词汇表
    • 截断、填充与数据迭代器
  • 2. 构建循环神经网络模型
    • 双向RNN模型(BiRNN)详解
    • 权重初始化
  • 3. 加载预训练词向量
    • 构建词向量加载器
    • 将预训练向量注入模型
  • 4. 训练与评估模型
    • 定义训练函数
    • 可视化训练过程
  • 5. 模型预测
    • 编写预测函数
    • 实例测试
  • 6. 总结


前言

在信息爆炸的时代,从海量的文本数据中提取有价值的信息变得至关重要。无论是电商网站的商品评论、社交媒体上的用户反馈,还是新闻文章中的观点倾向,理解文本背后的情感色彩——即情感分析——都有着广泛的应用。

循环神经网络(RNN)由于其对序列数据的强大建模能力,天然地适用于处理文本这类具有时序特征的数据。在本篇博客中,我们将从零开始,使用PyTorch框架构建一个基于双向循环神经网络(Bi-RNN)的情感分析模型。我们不仅会详细讲解数据预处理、模型构建、训练评估的全过程,还将引入预训练的GloVe词向量来提升模型的性能。

这篇博客的目标是“从代码学习深度学习”。因此,我们将完整地展示每一个模块的代码,并配以详尽的解释,力求让读者不仅能看懂代码,更能理解每一行代码背后的原理和设计思想。无论您是深度学习初学者,还是希望系统学习PyTorch在自然语言处理中应用的开发者,相信都能从中获益。

让我们一起踏上这场代码与思想的探索之旅吧!

完整代码:下载链接


1. 加载与预处理数据集

任何成功的NLP项目都始于坚实的数据处理。我们的任务是分析IMDb电影评论的情感,这是一个经典的二分类问题(正面/负面)。在这一步,我们将完成从原始文本文件到PyTorch数据迭代器的全部转换过程。

主逻辑由load_data_imdb函数驱动,它调用了一系列辅助函数来完成任务。

# 情感分析:使用循环神经网络.ipynbimport torch
import utils_for_data
from torch import nnbatch_size = 64
train_iter, test_iter, vocab = utils_for_data.load_data_imdb(batch_size)

上面的代码是我们的入口,它调用utils_for_data.load_data_imdb来获取训练/测试数据迭代器和词汇表。现在,让我们深入utils_for_data.pyutils_for_vocab.py,看看这一切是如何实现的。

数据读取与词元化

首先,我们需要从压缩包中读取IMDb数据集的文本和标签。read_imdb函数负责遍历指定目录,读取每个评论文件并为其打上正面(1)或负面(0)的标签。

# utils_for_data.pyimport os
import zipfile
import tarfile
import utils_for_vocab
import torch.utils.data as data
import torchdef extract(name, folder=None):"""下载并解压zip/tar文件参数:name (str): 要解压的文件名/路径,维度: [字符串]folder (str, optional): 指定的文件夹名称,维度: [字符串] 或 None返回:str: 解压后的目录路径,维度: [字符串]"""base_dir = os.path.dirname(name)data_dir, ext = os.path.splitext(name)if ext == '.zip':fp = zipfile.ZipFile(name, 'r')elif ext in ('.tar', '.gz'):fp = tarfile.open(name, 'r')else:assert False, '只有zip/tar文件可以被解压缩'fp.extractall(base_dir)fp.close()return os.path.join(base_dir, folder) if folder else data_dirdef read_imdb(data_dir, is_train):"""读取IMDb评论数据集文本序列和标签参数:data_dir (str): 数据集根目录路径is_train (bool): 是否读取训练集,True为训练集,False为测试集返回:tuple: (data, labels)data (list): 评论文本列表,维度为 [样本数量]labels (list): 标签列表,维度为 [样本数量],1表示正面评价,0表示负面评价"""data = []labels = []for label in ('pos', 'neg'):folder_name = os.path.join(data_dir, 'train' if is_train else 'test', label)for file in os.listdir(folder_name):file_path = os.path.join(folder_name, file)with open(file_path, 'rb') as f:review = f.read().decode('utf-8').replace('\n', '')data.append(review)labels.append(1 if label == 'pos' else 0)return data, labels

拿到原始文本后,我们需要将其分解为模型可以理解的基本单元——词元(Token)。这个过程称为词元化(Tokenization)。tokenize函数可以按单词或字符进行分割。

# utils_for_vocab.pyimport torch
import torch.utils.data
from collections import Counterdef tokenize(lines, token='word'):"""将文本行拆分为单词或字符词元参数:lines (list): 文本行列表,维度: [行数],每个元素为字符串token (str): 词元化类型,维度: [标量],'word'表示按单词分割,'char'表示按字符分割返回:tokenized_lines (list): 词元化后的文本,维度: [行数 × 词元数],嵌套列表结构"""if token == 'word':return [line.split() for line in lines]elif token == 'char':return [list(line) for line in lines]else:print('错误:未知词元类型:' + token)

构建词汇表

计算机无法直接处理文本,我们需要将词元映射为数字索引。Vocab类就是为此设计的。它会统计所有词元的频率,并只保留那些出现频率高于min_freq的词元,其余的都归为未知词元<unk>。这不仅能减小词汇表的大小,还能过滤掉噪音。

# utils_for_vocab.pydef count_corpus(tokens):"""统计词元出现频率参数:tokens (list): 词元列表,维度: [词元数] 或 [序列数 × 词元数](嵌套列表)返回:counter (Counter): 词元频率统计对象,键为词元,值为出现次数"""if len(tokens) == 0 or isinstance(tokens[0], list):tokens = [token for line in tokens for token in line]return Counter(tokens)class Vocab:"""文本词汇表类,用于管理词元到索引的映射关系"""def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):"""初始化词汇表参数:tokens (list): 词元列表,维度: [词元数] 或 [序列数 × 词元数]min_freq (int): 最小词频阈值,维度: [标量],低于此频率的词元将被忽略reserved_tokens (list): 保留词元列表,维度: [保留词元数],如特殊标记"""if tokens is None:tokens = []if reserved_tokens is None:reserved_tokens = []counter = count_corpus(tokens)self._token_freqs = sorted(counter.items()<
http://www.lqws.cn/news/504073.html

相关文章:

  • 国产安路FPGA纯verilog视频图像去雾,基于暗通道先验算法实现,提供5套TD工程源码和技术支持
  • 帮助装修公司拓展客户资源的微信装修小程序怎么做?
  • 开篇-认识Gin——Go语言Web框架的性能王者
  • 接口自动化测试之 pytest 接口关联框架封装
  • Qt 中使用 gtest 做单元测试
  • 如何一次性将 iPhone 中的联系人转移到 PC
  • Learning to See in the Dark 论文阅读
  • 安卓android com.google.android.material.tabs.TabLayout 设置下拉图标无法正常显示
  • ubuntu虚拟机扩容
  • 【计算机网络】期末复习
  • centos 7 mysql 8 离线部署
  • (3)ROS2:6-dof前馈+PD / 阻抗控制器
  • 【Vue】 keep-alive缓存组件实战指南
  • C# VB.NET中Tuple轻量级数据结构和固定长度数组
  • 第五课:大白话教你用K邻近算法做分类和回归
  • 从零学习linux(2)——管理
  • 战地2042(战地风云)因安全启动(Secure Boot)无法启动的解决方案以及其他常见的启动或闪退问题
  • iOS 抓包实战:时间戳偏差导致的数据同步异常排查记录
  • spring-ai 1.0.0 学习(十四)——向量数据库
  • 【机器学习深度学习】反向传播机制
  • 使用argparse封装python程序为命令行工具
  • C++ 第二阶段:模板编程 - 第一节:函数模板与类模板
  • Linux线程概念及常用接口(1)
  • 数据分箱:科学分类的简单指南
  • 轻量级小程序自定义tabbar组件封装的实现与使用
  • MediaMarktSaturn EDI 对接指南:欧洲零售卖场的数字化协同范例
  • 火山引擎向量数据库 Milvus 版正式开放
  • 竹云受邀出席华为开发者大会,与华为联合发布海外政务数字化解决方案
  • 【MATLAB代码】基于MVC的EKF和经典EKF对三维非线性状态的滤波,提供滤波值对比、误差对比,应对跳变的观测噪声进行优化
  • 安全报告:LLM 模型在无显性攻击提示下的越狱行为分析