当前位置: 首页 > news >正文

7.4-Creating data loaders for an instruction dataset

Chapter 7-Fine-tuning to follow instructions

7.4-Creating data loaders for an instruction dataset

  • 我们只需将InstructionDataset对象和custom_collate_fn函数接入 PyTorch 数据加载器

  • 使用以下代码来初始化设备信息

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Note:
    # Uncommenting the following lines will allow the code to run on Apple Silicon chips, if applicable,
    # which is much faster than on an Apple CPU (as measured on an M3 MacBook Air).
    # However, the resulting loss values may be slightly different.#if torch.cuda.is_available():
    #    device = torch.device("cuda")
    #elif torch.backends.mps.is_available():
    #    device = torch.device("mps")
    #else:
    #    device = torch.device("cpu")print("Device:", device)"""输出"""
    Device: cuda
    

    custom_collate_fn函数中的device参数和allowed_max_length预先设定为变量device1024。这样在后续调用customized_collate_fn时,就不需要再手动传入这两个参数的值了。

    from functools import partialcustomized_collate_fn = partial(custom_collate_fn,device=device,allowed_max_length=1024
    )
    

    接下来,我们设置数据加载器,但是这次,我们将使用我们的自定义排序函数进行批处理过程。

    from torch.utils.data import DataLoadernum_workers = 0
    batch_size = 8torch.manual_seed(123)train_dataset = InstructionDataset(train_data, tokenizer)
    train_loader = DataLoader(train_dataset,batch_size=batch_size,collate_fn=customized_collate_fn,shuffle=True,drop_last=True,num_workers=num_workers
    )val_dataset = InstructionDataset(val_data, tokenizer)
    val_loader = DataLoader(val_dataset,batch_size=batch_size,collate_fn=customized_collate_fn,shuffle=False,drop_last=False,num_workers=num_workers
    )test_dataset = InstructionDataset(test_data, tokenizer)
    test_loader = DataLoader(test_dataset,batch_size=batch_size,collate_fn=customized_collate_fn,shuffle=False,drop_last=False,num_workers=num_workers
    )
    

    让我们看看input 和target批次的维度是什么样的

    print("Train loader:")
    for inputs, targets in train_loader:print(inputs.shape, targets.shape)"""输出"""
    Train loader:
    torch.Size([8, 61]) torch.Size([8, 61])
    torch.Size([8, 76]) torch.Size([8, 76])
    ......
    torch.Size([8, 69]) torch.Size([8, 69])
    

    根据上面的输出,我们可以看到,所有批次的批次大小为8,但长度不同,第一个[8,61]表示,batchsize为8,在当前批次中,每个训练示例中的token数量为61。让我们通过打印“input”批处理中第一个训练示例的内容来仔细检查输入是否包含与tokenID 50256对应的“<|endoftext|>”填充token

    print(inputs[0])"""输出"""
    tensor([21106,   318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,257,  2882,   326, 20431, 32543,   262,  2581,    13,   198,   198,21017, 46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,985,   576,    13,   198,   198, 21017, 23412,    25,   198,   464,5156,   318,   845, 13779,    13,   198,   198, 21017, 18261,    25,198,   464,  5156,   318,   355, 13779,   355,   257,  4936,    13,50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],device='cuda:0')
    

    同样,我们仔细检查target是否包含-100占位符标记

    print(target[0])"""输出"""
    tensor([  318,   281, 12064,   326,  8477,   257,  4876,    13, 19430,   257,2882,   326, 20431, 32543,   262,  2581,    13,   198,   198, 21017,46486,    25,   198, 30003,  6525,   262,  6827,  1262,   257,   985,576,    13,   198,   198, 21017, 23412,    25,   198,   464,  5156,318,   845, 13779,    13,   198,   198, 21017, 18261,    25,   198,464,  5156,   318,   355, 13779,   355,   257,  4936,    13, 50256,-100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],device='cuda:0')
    

http://www.lqws.cn/news/77473.html

相关文章:

  • 【机器学习基础】机器学习入门核心算法:多分类与多标签分类算法
  • 【iOS(swift)笔记-14】App版本不升级时本地数据库sqlite更新逻辑二
  • 如何使用flask做任务调度
  • hot100 -- 6.矩阵系列
  • python打卡day43@浙大疏锦行
  • 3,信号与槽机制
  • Eigen库介绍以及模块划分和相关示例代码
  • NodeJS全栈WEB3面试题——P3Web3.js / Ethers.js 使用
  • Cursor 0.51 全网首歌新功能深度体验:Generate Memories 让 AI 编程助手拥有“记忆“
  • 【DAY37】早停策略和模型权重的保存
  • 微软PowerBI考试 PL-300学习指南
  • 【001】利用github搭建静态网站_essay
  • Go整合Redis2.0发布订阅
  • 6.2本日总结
  • leetcode90.子集II:排序与同层去重的回溯优化策略
  • Python 在金融中的应用- Part 1
  • Pytorch知识点2
  • dify应用探索
  • 【Go语言】Ebiten游戏库开发者文档 (v2.8.8)
  • 字节跳动开源图标库:2000+图标一键换肤的魔法
  • 神经网络中的梯度消失与梯度爆炸
  • 代码随想录60期day54
  • 牛客周赛 Round 94
  • 聚类分析 | MATLAB实现基于SOM自组织特征映射聚类可视化
  • 数据结构之排序
  • 对抗攻击 Adversarial Attack
  • 实现按天更新vintage并热力图可视化
  • 【QT控件】QWidget 常用核心属性介绍 -- 万字详解
  • Python中sys模块详解
  • spring-boot接入websocket教程以及常见问题解决