当前位置: 首页 > news >正文

BERT 模型准备与转换详细操作流程

在尝试复现极客专栏《PyTorch 深度学习实战|24 | 文本分类:如何使用BERT构建文本分类模型?》时候,构建模型这一步骤专栏老师一笔带过,对于新手有些不友好,经过一阵摸索,终于调通了,现在总结一下整体流程。

1. 获取必要脚本文件

首先,我们需要从 Transformers 的 GitHub 仓库中找到相关文件:

# 克隆 Transformers 仓库
git clone https://github.com/huggingface/transformers.git
cd transformers

在仓库中,我们需要找到以下关键文件:

  • src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py(用于 TF1.x 模型)
  • src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py(用于 TF2.x 模型)
  • src/transformers/models/bert/modeling_bert.py(BERT 的 PyTorch 实现)

2. 下载预训练模型

接下来,我们需要下载 Google 提供的预训练 BERT 模型。根据你的需求,我们选择"BERT-Base, Multilingual Cased"版本,它支持104种语言。

访问 Google 的 BERT GitHub 页面:https://github.com/google-research/bert

在该页面中找到"BERT-Base, Multilingual Cased"的下载链接,或直接使用以下命令下载:

mkdir bert-base-multilingual-cased
cd bert-base-multilingual-cased# 下载模型文件
wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip
unzip multi_cased_L-12_H-768_A-12.zip

解压后,你会得到以下文件:

  • bert_model.ckpt.data-00000-of-00001
  • bert_model.ckpt.index
  • bert_model.ckpt.meta
  • bert_config.json
  • vocab.txt

3. 模型转换

现在,我们使用之前找到的转换脚本将 TensorFlow 模型转换为 PyTorch 格式:

# 回到 transformers 目录
cd ../transformers# 执行转换脚本(针对 TF2.x 模型)
python src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/pytorch_model.bin

如果你下载的是 TF1.x 模型,则使用:

python src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py \--tf_checkpoint_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_model.ckpt \--bert_config_file ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/bert_config.json \--pytorch_dump_path ../bert-base-multilingual-cased/multi_cased_L-12_H-768_A-12/pytorch_model.bin

注意,此处需要安装tensorflow。

4. 准备完整的 PyTorch 模型目录

转换完成后,我们需要确保模型目录包含所有必要文件:

cd ../bert-base-multilingual-cased# 复制 bert_config.json 为 config.json(Transformers 库需要)
cp bert_config.json config.json

现在,你的模型目录应该包含以下三个关键文件:

  1. config.json:模型配置文件,包含了所有用于训练的参数设置
  2. pytorch_model.bin:转换后的 PyTorch 模型权重文件
  3. vocab.txt:词表文件,用于识别模型支持的各种语言的字符

5. 验证模型转换成功

为了验证模型转换是否成功,我们可以编写一个简单的脚本来加载模型并进行测试:

from transformers import BertTokenizer, BertModel# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertModel.from_pretrained(model_path)# 测试多语言能力
texts = ["Hello, how are you?",  # 英语"你好,最近怎么样?",    # 中文"Hola, ¿cómo estás?"   # 西班牙语
]for text in texts:inputs = tokenizer(text, return_tensors="pt")outputs = model(**inputs)print(f"Text: {text}")print(f"Shape of last hidden states: {outputs.last_hidden_state.shape}")print("---")

6. 使用模型进行下游任务

现在你可以使用这个转换好的模型进行各种下游任务,如文本分类、命名实体识别等:

from transformers import BertTokenizer, BertForSequenceClassification
import torch# 加载模型和分词器
model_path = "path/to/bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(model_path)# 初始化分类模型(假设有2个类别)
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)# 准备输入
text = "这是一个测试文本"
inputs = tokenizer(text, return_tensors="pt")# 前向传播
outputs = model(**inputs)
logits = outputs.logits# 获取预测结果
predicted_class = torch.argmax(logits, dim=1).item()
print(f"预测类别: {predicted_class}")

注意事项

  1. 模型文件大小:BERT-Base 模型文件通常较大(约400MB+),请确保有足够的磁盘空间和内存。

  2. 路径问题:在执行转换脚本时,确保正确指定了所有文件的路径。

  3. 命名约定:Transformers 库期望配置文件名为 config.json,而不是 bert_config.json,所以需要进行复制或重命名。

  4. TensorFlow 版本:根据你下载的模型版本(TF1.x 或 TF2.x),选择正确的转换脚本。

  5. checkpoint 文件:转换脚本中的 --tf_checkpoint_path 参数应该指向不带后缀的 checkpoint 文件名(如 bert_model.ckpt),而不是具体的 .index.data 文件。

通过以上步骤,你就可以成功地将 Google 预训练的 BERT 模型转换为 PyTorch 格式,并在你的项目中使用它了。这个多语言版本的 BERT 模型支持 104 种语言,非常适合多语言自然语言处理任务。

http://www.lqws.cn/news/482509.html

相关文章:

  • Bytemd@Bytemd/react详解(编辑器实现基础AST、插件、跨框架)
  • Macbook M4芯片 MUMU模拟器安装使用burpsuit抓包教程APP
  • WEB3合约开发以太坊中货币单位科普
  • 应急推进器和辅助推进器诊断函数封装
  • 媒体AI关键技术研究
  • linux----------------进程VS线程
  • 零基础学习Redis(14) -- Spring中使用Redis
  • RA4M2开发IOT(9)----动态显示MEMS数据
  • 深入理解Spring MVC:构建灵活Web应用的基石
  • 【SQL语法汇总】
  • Python 商务数据分析—— NumPy 学习笔记Ⅰ
  • 由浅入深详解前缀树-Trie树
  • 数智管理学(二十四)
  • Flink Connector Kafka深度剖析与进阶实践指南
  • ELMo 说明解析及用法
  • Netty Channel 详解
  • 【递归,搜索与回溯算法】记忆化搜索(二)
  • 【CSS】CSS3媒体查询全攻略
  • 基于Vue.js的图书管理系统前端界面设计
  • 【分布式技术】Bearer Token以及MAC Token深入理解
  • 大模型应用:如何使用Langchain+Qwen部署一套Rag检索系统
  • 制造业B端登录页案例:生产数据安全入口的权限分级设计
  • AMAT P5000 CVDFDT CVDMAINT Precision 5000 Mark 操作 电气原理 PCB图 电路图等
  • 【Datawhale组队学习202506】YOLO-Master task03 IOU总结
  • 防御悬垂指针:C++的多维度安全实践指南
  • 【前后前】导入Excel文件闭环模型:Vue3前端上传Excel文件,【Java后端接收、解析、返回数据】,Vue3前端接收展示数据
  • hot100 -- 16.多维动态规划
  • 分布式ID生成方式及优缺点详解
  • Azure Devops
  • 时序数据库IoTDB的架构、安装启动方法与数据模式总结