从代码学习深度学习 - 自然语言推断与数据集 PyTorch版
文章目录
- 前言
- 什么是自然语言推断 (NLI)?
- SNLI 数据集详解
- 数据格式
- 样本分析
- 第一步:读取与解析原始数据
- 代码实现
- 功能验证
- 第二步:构建通用词汇表工具
- 代码实现 (`utils_for_vocab.py`)
- 第三步:创建 PyTorch 数据集类
- 序列填充与截断
- 自定义 `SNLIDataset` 类
- 第四步:整合流程,加载数据
- 代码实现
- 总结
前言
欢迎来到“从代码学习深度学习”系列!在自然语言处理(NLP)的宏伟蓝图中,理解句子之间的逻辑关系是一项核心挑战。这项任务被称为自然语言推断(Natural Language Inference, NLI),它要求模型判断一个“前提”(Premise)句子能否推断出另一个“假设”(Hypothesis)句子。
为了训练能够胜任此任务的深度学习模型,我们需要一个大规模、高质量的标注数据集。斯坦福自然语言推断(SNLI)数据集正是为此而生,它已成为该领域的基石。
本篇博客的目标,就是带领大家通过 PyTorch,从零开始,一步步处理 SNLI 数据集。我们将深入每一行代码,从读取原始文本文件,到构建词汇表,再到封装成 PyTorch 的 DataLoader
,最终得到可以直接送入模型的、整齐划一的数据批次。无论你是 NLP 新手还是希望夯实基础的实践者,相信都能从中获益。
完整代码:下载链接
什么是自然语言推断 (NLI)?
自然语言推断(NLI),有时也称为文本蕴含(Textual Entailment),是判断两个句子之间逻辑关系的任务。具体来说,给定一个前提句 (Premise) 和一个假设句 (Hypothesis),模型需要确定它们之间的关系属于以下三种之一:
- 蕴含 (Entailment):前提句可以明确地推导出假设句。如果前提为真,那么假设也必然为真。
- 矛盾 (Contradiction):前提句与假设句的含义相互冲突。如果前提为真,那么假设必然为假。
- 中性 (Neutral):两个句子之间没有明确的逻辑关系。前提句的真假无法判断假设句的真假。
为了更直观地理解,我们来看几个例子:
Premise(前提) | Hypothesis(假设) | Label(标签) |
---|---|---|
A man and a woman are walking through a park. | A couple is enjoying a walk in the park. | entailment |
A man inspects the uniform of a figure. | The man is sleeping. | contradiction |
A woman is wearing a pink dress. | The woman is carrying a briefcase. | neutral |
通过训练模型来解决 NLI 任务,我们实际上是在教机器“理解”语言深层的语义和逻辑,这是实现更高级别人工智能的关键一步。
SNLI 数据集详解
SNLI 是由斯坦福大学于2015年发布的、第一个规模达到数十万级别的人工高质量标注 NLI 数据集。它包含约 57 万个带有人工标注的句对,为深度学习模型在 NLI 任务上的发展提供了强大的动力。
数据格式
数据集中的每一条样本都以 JSON 格式存储,但通常为了方便处理,会被转换成制表符分隔的 .txt
文件。一条典型的样本包含以下核心信息:
sentence1
: 前提句 (Premise)。sentence2
: 假设句 (Hypothesis)。gold_label
: 最终确定的标签 (entailment
,contradiction
,neutral
)。
样本分析
让我们分析一下 notebook 中提到的两个例子,来感受一下数据的特点:
样本 1 (中性):
- 前提: “A person on a horse jumps over a broken down airplane.” (一个人骑在马上跳过了一架损坏的飞机。)
- 假设: “A person is training his horse for a competition.” (一个人在训练他的马,为比赛做准备。)
- 分析: “跳飞机”和“为比赛训练”是两个不同的场景,它们不矛盾,也无法相互推断。因此,标签为
neutral
是非常合理的。
样本 2 (矛盾):
- 前提: “A person on a horse jumps over a broken down airplane.” (一个人骑在马上跳过了一架损坏的飞机。)
- 假设: “A person is at a diner, ordering an omelette.” (一个人在餐馆点煎蛋。)
- 分析: 一个人不可能同时在户外骑马跳飞机,又在餐厅里点餐。这两个场景在时空上是完全冲突的。因此,SNLI 将这种情况标注为
contradiction
。
理解了数据集的基本情况后,让我们开始动手编写代码来处理它。
第一步:读取与解析原始数据
我们的第一步是从 .txt
文件中读取数据,并将其解析为三个独立的列表:前提、假设和标签。
代码实现
我们定义一个 read_snli
函数来完成这项工作。
import re # 用于正则表达式文本处理
import os # 用于操作系统路径操作def read_snli(data_dir, is_train):"""将SNLI数据集解析为前提、假设和标签参数:data_dir: str类型,SNLI数据集所在的目录路径is_train: bool类型,True表示读取训练集,False表示读取测试集返回:premises: list类型,维度为[N],包含N个前提句子的字符串列表hypotheses: list类型,维度为[N],包含N个假设句子的字符串列表 labels: list类型,维度为[N],包含N个标签的整数列表(0:蕴含, 1:矛盾, 2:中性)"""def extract_text(s):"""文本预处理函数,清理和标准化输入文本参数:s: str类型,原始文本字符串返回:str类型,处理后的清洁文本"""# 删除所有左括号,这些符号在NLI任务中通常不重要s = re.sub('\\(', '', s)# 删除所有右括号,保持文本简洁s = re.sub('\\)', '', s)# 将两个或多个连续的空格替换为单个空格,标准化空白字符s = re.sub('\\s{2,}', ' ', s)# 去除字符串首尾的空白字符并返回return s.strip()# 标签映射字典:将字符串标签转换为数值标签# entailment(蕴含): 前提能够推出假设# contradiction(矛盾): 前提与假设相矛盾 # neutral(中性): 前提与假设没有明确的逻辑关系label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}# 根据is_train参数选择对应的数据文件# 训练集文件名: snli_1.0_train.txt# 测试集文件名: snli_1.0_test.txtfile_name = os.path.join(data_dir, 'snli_1.0_train.txt'if is_train else 'snli_1.0_test.txt')# 读取数据文件with open(file_name, 'r', encoding='utf-8') as f:# 读取所有行并按制表符分割,跳过第一行(表头)# rows: list类型,维度为[M, K],M为数据行数,K为每行的字段数rows = [row.split('\t') for row in f.readlines()[1:]]# 提取前提句子(第2列,索引为1)# 只保留标签在label_set中的有效数据行# premises: list类型,维度为[N],N为有效样本数量premises = [extract_text(row[1]) for row in rows if row[0] in label_set]# 提取假设句子(第3列,索引为2) # 同样只保留标签有效的数据行# hypotheses: list类型,维度为[N],与premises长度相同hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]# 提取标签(第1列,索引为0)并转换为数值# labels: list类型,维度为[N],包含0,1,2三种整数值labels = [label_set