当前位置: 首页 > news >正文

Embedding模型微调实战(ms-swift框架)

目录

简介

1. 创建虚拟环境

2 安装ms-swift

3安装其他依赖库

4. 下载数据集

5.开始embedding模型训练

6. 自定义数据格式和对应的Loss类型

(1) infoNCE损失     

(2)余弦相似度损失

(3)对比学习损失

(4).在线对比学习损失

(5)损失函数总结


简介

ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL3、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。

🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、GKD、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM、SGLang和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。

https://github.com/modelscope/ms-swift?tab=readme-ov-file

1. 创建虚拟环境

conda create -n swift_venv python=3.12 -y
conda init bash && source /root/.bashrc
conda activate swift_venv
conda install ipykernel
ipython kernel install --user --name=swift_venv

2 安装ms-swift

#1使用pip安装,把ms-swift作为一个库,安装到anaconda的虚拟环境中
pip install ms-swift -U#2从源码克隆安装(推荐)pip install git+https://github.com/modelscope/ms-swift.git#3从源码安装
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e .

3安装其他依赖库

pip install deepspeed liger-kernel 
pip install scikit-learn
pip install -U sentence-transformers
# 建议科学上⽹后,再执⾏下⾯的命令
pip install flash-attn --no-build-isolation

4. 下载数据集

# 设置下⾯命令,⽆需科学上⽹
export HF_ENDPOINT=https://hf-mirror.com
pip install --upgrade huggingface_hub
huggingface-cli download --repo-type dataset --resume-download microsoft/ms_marco --local-dir 下载保存路径

数据格式:

5.开始embedding模型训练

# 把下面命令,保存为train.sh格式,   运行bash train.sh命令,启动训练CUDA_VISIBLE_DEVICES=0 \
swift sft \--model /data/qwen3_embedding/Qwen3-Embedding-0.6B \--task_type embedding \--train_type full \     #训练模式full lora  --torch_dtype bfloat16 \--num_train_epochs 100 \               #训练轮数--per_device_train_batch_size 16 \   #训练批次--per_device_eval_batch_size 16\--learning_rate 1e-4 \--lora_rank 8 \--lora_alpha 8 \--target_modules all-linear \     #目标模块    all-attention--gradient_accumulation_steps 1 \--eval_steps 1 \--save_steps 1 \--save_total_limit 2 \--logging_steps 5 \--max_length 512 \--output_dir output \        #输出路径--warmup_ratio 0.05 \--dataloader_num_workers 8 \--model_author swift \--model_name swift-robot \--split_dataset_ratio 0.3 \   #train 和val分割比例--dataset /home/dataset \    #数据集路径--loss_type infonce       # 损失函数3种类型   contrastive  cosine_similarity infonce

 使用swift sft --help命令查询有哪些训练设置参数。

6. 自定义数据格式和对应的Loss类型

(1) infoNCE损失     

  --loss_type  infonce

对⽐学习损失函数,最⼤化正样

本对相似度,最⼩化负样本对相似度 .

使⽤批内对⽐学习策略,将同批次内其他样本作为负样本.

数据格式:

[{"query": "如何学习编程?","response": "可以从Python语言开始入门,它语法简单适合初学者。","rejected_response": ["随便看看书就会了", "编程很难学不会的"]},{"query": "推荐一款性价比高的手机","response": "Redmi Note系列在2000元价位段表现均衡,值得考虑。","rejected_response": ["越贵的手机越好", "苹果手机永远是最好的"]}
]

 

(2)余弦相似度损失

 --loss_type cosine_similarity

直接优化预测相似度与真实相似度标签的差异 ,使⽤ MSE 损失计算 ||input_label - cosine_sim(u,v)||_2

数据格式:

[{"query": "A dog is barking loudly.","response": "The canine is making loud barking noises.","label": 0.8},{"query": "Children are playing in the park.","response": "Kids are playing in the playground.","label": 1.0},{"query": "The sun is shining brightly.","response": "Bright sunlight is visible.","label": 0.7}
]

(3)对比学习损失

 --loss_type contrastive

经典的对⽐学习损失,正样本拉近,负样本推远 需要设置 margin 参数。

[{"query": "A dog is barking loudly.","response": "The canine is making loud barking noises.","label": 1},{"query": "Children are playing in the park.","response": "Kids are playing in the playground.","label": 1}]

(4).在线对比学习损失

--loss_type online_contrastive

对⽐学习的改进版本,选择困难正样本和困难负样本 通常⽐标准对⽐学习效果更好。

(5)损失函数总结

 

http://www.lqws.cn/news/565759.html

相关文章:

  • 2025年IOTJ SCI2区TOP,动态协同鲸鱼优化算法DCWOA+多车车联网路径规划,深度解析+性能实测
  • 从RDS MySQL到Aurora:能否实现真正的无缝迁移?
  • OpenCV学习3
  • 设计模式之装饰者模式
  • 企业级路由器技术全解析:从基础原理到实战开发
  • promise深入理解和使用
  • 线性相关和线性无关
  • 【数据挖掘】聚类算法学习—K-Means
  • Windows 4625日志类别解析:未成功的账户登录事件
  • 节点小宝:告别公网IP,重塑你的远程连接体验
  • 数据库 DML 语句详解:语法与注意事项
  • Android大图加载优化:BitmapRegionDecoder深度解析与实战
  • 【分布式 ID】生成唯一 ID 的几种方式
  • 面试150 螺旋矩阵
  • 模拟工作队列 - 华为OD机试真题(JavaScript卷)
  • llama.cpp学习笔记:后端加载
  • Windows系统安装鸿蒙模拟器
  • 接口自动化测试(Python+pytest+PyMySQL+Jenkins)
  • OpenLayers 全屏控件介绍
  • Wpf布局之StackPanel!
  • Mac电脑手动安装原版Stable Diffusion,开启本地API调用生成图片
  • 在Mac上查找并删除Java 21.0.5
  • 【Canvas与标志】圆规脚足球俱乐部标志
  • Spring Cloud Gateway 实战:从网关搭建到过滤器与跨域解决方案
  • 浮油 - 3 相分层和自由表面流 CFX 模拟
  • 医疗AI智能基础设施构建:向量数据库矩阵化建设流程分析
  • js 基础
  • PCB工艺学习与总结-20250628
  • JVM——垃圾回收
  • Kafka4.0初体验