
Week N5: Getting Started with Text Classification in PyTorch

  • 🍨 This post is a learning-record blog from the 🔗365天深度学习训练营 (365-day deep learning training camp)
  • 🍖 Original author: K同学啊

I. Preparation

1. Load the Data
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")  # suppress warning messages

# On a Win10 machine with a GPU you could use:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
device

device(type='cpu')
from torchtext.datasets import AG_NEWS

train_iter = list(AG_NEWS(split='train'))  # load the AG_NEWS dataset
num_class  = len(set([label for (label, text) in train_iter]))
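As a quick sanity check (my own addition, not part of the original notebook), you can inspect one raw sample and confirm the class count; AG_NEWS carries four news categories, labeled 1-4 in the raw iterator:

# Illustrative inspection of the loaded data.
label, text = train_iter[0]
print(label)      # an integer in 1..4
print(text[:50])  # first 50 characters of the news item
print(num_class)  # 4: World, Sports, Business, Sci/Tech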
2. Build the Vocabulary
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')  # returns a tokenizer function

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  # default index returned when a word is not found

vocab(['here', 'is', 'an', 'example'])

 [475, 21, 30, 5297]

text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

text_pipeline('here is the an example')

 [475, 21, 2, 30, 5297]

label_pipeline('10')

 9

3. Generate Data Batches and an Iterator
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        # label list
        label_list.append(label_pipeline(_label))
        # text list
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        # offset, i.e. the token count of this sentence
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list  = torch.cat(text_list)
    offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0)  # cumulative sum of the lengths along dim
    return label_list.to(device), text_list.to(device), offsets.to(device)

# data loader
dataloader = DataLoader(train_iter, batch_size=8,
                        shuffle=False, collate_fn=collate_batch)
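nn.EmbeddingBag (used in the model below) consumes one flat token tensor per batch instead of a padded matrix, so offsets tells it where each sentence starts. A toy illustration of the cumsum trick above (my own example):

# Toy batch: three "sentences" of lengths 4, 2 and 3.
lengths = torch.tensor([4, 2, 3])
offsets = torch.cat([torch.tensor([0]), lengths[:-1]]).cumsum(dim=0)
print(offsets)  # tensor([0, 4, 6]): the sentences start at positions 0, 4 and 6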

II. Preparing the Model

1. Define the Model
from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size,   # size of the vocabulary
                                         embed_dim,    # embedding dimension
                                         sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
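The mean aggregation happens inside nn.EmbeddingBag, whose default mode is 'mean': it looks up the embedding of every token in a sentence and averages them in a single fused step. A small check that this matches a manual lookup-then-mean (my own illustration):

# Illustration: EmbeddingBag(mode='mean') == embedding lookup followed by mean.
bag = nn.EmbeddingBag(10, 4)                    # default mode is 'mean'
emb = nn.Embedding.from_pretrained(bag.weight)  # share the same weight matrix
tokens  = torch.tensor([1, 2, 3], dtype=torch.int64)
offsets = torch.tensor([0])                     # a single sentence starting at position 0
print(torch.allclose(bag(tokens, offsets), emb(tokens).mean(0, keepdim=True)))  # True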
2. Instantiate the Model
num_class  = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
em_size    = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)
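A quick shape check on one batch from the dataloader built earlier (my own addition); with batch_size=8 and four classes, the logits should come out as [8, 4]:

label, text, offsets = next(iter(dataloader))
print(model(text, offsets).shape)  # torch.Size([8, 4])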
3. Define the Training and Evaluation Functions
import time

def train(dataloader, model, optimizer, criterion, epoch):
    model.train()  # switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)

        optimizer.zero_grad()                     # reset gradients
        loss = criterion(predicted_label, label)  # compute the loss
        loss.backward()                           # backpropagate
        optimizer.step()                          # update the weights

        # accumulate training statistics
        total_acc   += (predicted_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(
                      epoch, idx, len(dataloader),
                      total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader, model, criterion):
    model.eval()  # switch to evaluation mode
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)  # compute the loss

            # accumulate evaluation statistics
            total_acc   += (predicted_label.argmax(1) == label).sum().item()
            train_loss  += loss.item()
            total_count += label.size(0)

    return total_acc / total_count, train_loss / total_count

III. Training the Model

1. Split the Dataset and Run the Model
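The driver code for this step is a minimal sketch consistent with the log below: the 120,000 AG_NEWS training samples are split 95/5 into training and validation sets (114,000 samples / 64 per batch ≈ 1782 batches per epoch, matching the log), the model trains for 10 epochs, and the learning rate is decayed whenever validation accuracy stops improving. The hyperparameter values (plain SGD with LR=5, BATCH_SIZE=64, StepLR with gamma=0.1) are assumptions in the style of the standard torchtext tutorial, not values recoverable from the log.

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Assumed hyperparameters (see the note above).
EPOCHS     = 10
LR         = 5
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter, test_iter = AG_NEWS()  # reload the train and test splits
train_dataset = to_map_style_dataset(train_iter)
test_dataset  = to_map_style_dataset(test_iter)
num_train     = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(train_dataset,
                                          [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader  = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader, model, optimizer, criterion, epoch)
    val_acc, val_loss = evaluate(valid_dataloader, model, criterion)

    # decay the learning rate when validation accuracy stops improving
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc

    print('-' * 69)
    print('| epoch {:1d} | time:{:5.2f}s | '
          'valid_acc {:4.3f} valid_loss {:5.3f}'.format(
              epoch, time.time() - epoch_start_time, val_acc, val_loss))
    print('-' * 69)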
| epoch 1 |  500/1782 batches | train_acc 0.904 train_loss 0.00450
| epoch 1 | 1000/1782 batches | train_acc 0.903 train_loss 0.00455
| epoch 1 | 1500/1782 batches | train_acc 0.904 train_loss 0.00443
---------------------------------------------------------------------
| epoch 1 | time:11.72s | valid_acc 0.901 valid_loss 0.005
---------------------------------------------------------------------
| epoch 2 |  500/1782 batches | train_acc 0.918 train_loss 0.00379
| epoch 2 | 1000/1782 batches | train_acc 0.920 train_loss 0.00377
| epoch 2 | 1500/1782 batches | train_acc 0.913 train_loss 0.00399
---------------------------------------------------------------------
| epoch 2 | time:11.52s | valid_acc 0.907 valid_loss 0.005
---------------------------------------------------------------------
| epoch 3 |  500/1782 batches | train_acc 0.930 train_loss 0.00323
| epoch 3 | 1000/1782 batches | train_acc 0.925 train_loss 0.00345
| epoch 3 | 1500/1782 batches | train_acc 0.925 train_loss 0.00350
---------------------------------------------------------------------
| epoch 3 | time:11.77s | valid_acc 0.915 valid_loss 0.004
---------------------------------------------------------------------
| epoch 4 |  500/1782 batches | train_acc 0.937 train_loss 0.00294
| epoch 4 | 1000/1782 batches | train_acc 0.931 train_loss 0.00317
| epoch 4 | 1500/1782 batches | train_acc 0.927 train_loss 0.00332
---------------------------------------------------------------------
| epoch 4 | time:11.81s | valid_acc 0.914 valid_loss 0.004
---------------------------------------------------------------------
| epoch 5 |  500/1782 batches | train_acc 0.951 train_loss 0.00243
| epoch 5 | 1000/1782 batches | train_acc 0.950 train_loss 0.00243
| epoch 5 | 1500/1782 batches | train_acc 0.949 train_loss 0.00245
---------------------------------------------------------------------
| epoch 5 | time:11.94s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 6 |  500/1782 batches | train_acc 0.951 train_loss 0.00236
| epoch 6 | 1000/1782 batches | train_acc 0.951 train_loss 0.00241
| epoch 6 | 1500/1782 batches | train_acc 0.951 train_loss 0.00241
---------------------------------------------------------------------
| epoch 6 | time:11.69s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
| epoch 7 |  500/1782 batches | train_acc 0.952 train_loss 0.00233
| epoch 7 | 1000/1782 batches | train_acc 0.952 train_loss 0.00236
| epoch 7 | 1500/1782 batches | train_acc 0.952 train_loss 0.00235
---------------------------------------------------------------------
| epoch 7 | time:11.88s | valid_acc 0.920 valid_loss 0.004
---------------------------------------------------------------------
| epoch 8 |  500/1782 batches | train_acc 0.953 train_loss 0.00233
| epoch 8 | 1000/1782 batches | train_acc 0.954 train_loss 0.00226
| epoch 8 | 1500/1782 batches | train_acc 0.953 train_loss 0.00229
---------------------------------------------------------------------
| epoch 8 | time:11.92s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
| epoch 9 |  500/1782 batches | train_acc 0.956 train_loss 0.00223
| epoch 9 | 1000/1782 batches | train_acc 0.955 train_loss 0.00219
| epoch 9 | 1500/1782 batches | train_acc 0.955 train_loss 0.00223
---------------------------------------------------------------------
| epoch 9 | time:11.78s | valid_acc 0.919 valid_loss 0.004
---------------------------------------------------------------------
| epoch 10 |  500/1782 batches | train_acc 0.955 train_loss 0.00226
| epoch 10 | 1000/1782 batches | train_acc 0.954 train_loss 0.00223
| epoch 10 | 1500/1782 batches | train_acc 0.955 train_loss 0.00221
---------------------------------------------------------------------
| epoch 10 | time:11.82s | valid_acc 0.919 valid_loss 0.004
---------------------------------------------------------------------
2. Evaluate the Model on the Test Set
print('Checking the results of the test dataset.')
test_acc, test_loss = evaluate(test_dataloader, model, criterion)
print('test accuracy {:8.3f}'.format(test_acc))
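Beyond aggregate accuracy, the same pipelines let you classify a raw news string. The predict helper and the ag_news_label mapping below are my own sketch, not part of the original code; the label names are the four AG_NEWS categories:

ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tech"}

def predict(text_str):
    with torch.no_grad():
        text    = torch.tensor(text_pipeline(text_str), dtype=torch.int64).to(device)
        offsets = torch.tensor([0]).to(device)
        output  = model(text, offsets)
        return ag_news_label[output.argmax(1).item() + 1]

print(predict("The stock market rallied after the central bank's announcement."))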

IV. Reflections

This week I additionally installed the portalocker library, downloaded the AG_NEWS dataset, and built the TextClassificationModel: it first embeds the tokens of each text, then mean-pools the token embeddings of each sentence, and finally maps the pooled vector to a class, completing the text-classification task. A few problems that came up during training were resolved effectively.

