当前位置: 首页 > news >正文

强化学习系列--dpo损失函数

DPO 概要

  1. DPO(Direct Preference Optimization,直接偏好优化)是由斯坦福大学等研究团队于2023年提出的一种偏好优化算法,可用于LLM、VLM与MLLM的对齐训练。

  2. 算法基于PPO的RLHF基础上进行了大幅简化。DPO算法跳过了训练奖励模型这一中间过程,直接(Direct)优化策略模型 ——这正是DPO命名中“D(Direct)”的含义所在。

主要流程

  1. 数据收集: 基于SFT训练的模型作为推理模型,用户输入prompt,模型多次推理,找到好的答案和不好的答案。如果都是不好(rejected)的答案,则人工修改把不好的答案变为好的答案。

    标数据收集
  2. 主要包含两个基础模型,策略模型&参考模型(不需要Reward模型)。 在trl强化学习框架中,只需要传入策略模型,参考模型会复制一份策略模型。

    1. 策略模型是DPO需要训练的模型,后用在项目中的模型。策略模型的权重直接复制SFT阶段微调模型的权重

    2. 参考模型是策略模型的帮衬,其权重参数冻结不变。主要两个作用,其一协助其计算reward loss,其二计算kl正则项,防止其训练偏移初始SFT模型太远,由一个β参数控制。

  3. β参数控制含义

    1. 较大 beta(如 1.0):放大 reward 或 logp 的差异,使模型更“自信”地倾向于较优样本,但容易过拟合或 reward 震荡。

    2. 较小 beta(如 0.1):差异被压缩,模型训练更稳定,但收敛较慢、辨别力较弱。

    3. 极小 beta(趋近于 0):差异几乎无效,模型无法区分好坏样本,退化为随机训练

  4.  整体流程如下:

  5. 具体流程

    DPO训练流程细节

九个损失函数解析

"loss": 1.8678"rewards/chosen": 42.519317626953125"rewards/rejected": -33.865535736083984"rewards/accuracies": 0.865429699420929"rewards/margins": 76.38734436035156"logps/chosen": -948.4149780273438"logps/rejected": -1285.1175537109375"logits/chosen": 5.363300800323486"logits/rejected": 4.879658222198486
  1. logps/chosen和logps/rejected: logps 是模型生成 token 概率,在归一化后(softmax)取 log 后的值(log prob)。

    #1 把 prompt 和 response 拼接起来作为输入
    input = prompt + response
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch# 加载 tokenizer 和模型
    tokenizer = AutoTokenizer.from_pretrained("your-model-name")
    model = AutoModelForCausalLM.from_pretrained("your-model-name").cuda()# 设置 prompt 和 response
    prompt = "你今天心情怎么样?"
    response = "我今天很开心,太阳出来了,我们一起去玩吧!"# 拼接输入
    full_input = prompt + response
    encodings = tokenizer(full_input, return_tensors="pt").to("cuda")
    input_ids = encodings["input_ids"]# 找到 response 的起始位置
    prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
    response_start = prompt_ids.shape[-1]# 前向推理,获取 logits
    with torch.no_grad():outputs = model(**encodings)logits = outputs.logits# 计算 log probabilities
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)# 获取 response 部分 token 的 log probability
    response_token_ids = input_ids[:, response_start:]
    response_logits = log_probs[:, response_start - 1:-1, :]  # 对应 shift
    response_logp = torch.gather(response_logits, 2, response_token_ids.unsqueeze(-1)).squeeze(-1)# 平均 log probability(整个 response)
    logp_response = response_logp.mean()logps_chosen = compute_logp(prompt, chosen, actor_model)
    logps_rejected = compute_logp(prompt, rejected, actor_model)
    logps_ref_chosen = compute_logp(prompt, chosen, ref_model)
    logps_ref_rejected = compute_logp(prompt, rejected, ref_model)
  2. logits/chosen和logits/rejected: 模型输出的raw score(未进行归一化)求平均

    # 模型输出:logits = [batch_size, seq_len, vocab_size]
    # 获取 chosen 的最后一个 token 的 logit:
    logit_chosen = logits[:, -1, :]  # 通常是这个位置
    logits/chosen = logit_chosen.mean().item()
    # 拿出 chosen response 部分的 token 对应的 logit 向量
    logits_response = logits[:, prompt_len:, :]  # mask 掉 prompt 部分
    logits/chosen = logits_response.mean().item()
  3. reward 计算方法

    chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
    rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
    reward_accuracies = (chosen_rewards > rejected_rewards).float()
    metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
    metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
    metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
    metrics[f"{prefix}rewards/margins"] = (
    self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item()
  4. Loss 计算方法

    本次默认使用sigmoidlogratios = chosen_logps - rejected_logpsref_logratios = ref_chosen_logps - ref_rejected_logps                logratios = logratios.to(self.accelerator.device)ref_logratios = ref_logratios.to(self.accelerator.device)logits = logratios - ref_logratios losses = (-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)- F.logsigmoid(-self.beta * logits) * self.label_smoothing )
    其他计算方法如下(后续介绍):"hinge","ipo",
    "exo_pair","nca_pair","robust","bco_pair",
    "sppo_hard","aot","apo_down""aot_pair","apo_zero","discopop",
  5. 关系理解

    指标

    含义

    关系

    logits

    每个 token 的原始输出分数(未归一化)

    模型输出的raw score(未进行归一化)求平均

    logps

    所有 token 的 log 概率之和(对 logit softmax 后求 log,token-wise 累加)

    来自 logits → softmax → log(prob) → sum over tokens

    rewards

    在 logp-based reward 情况下,reward 就是 sum(logps)/len(tokens)

    eval_rewards/chosen == eval_logps/chosen/len(tokens)

  6. 主要关注指标

    指标名

    含义

    影响

    loss

    当前 batch 的 DPO/IPO 损失值

    反映训练是否有效收敛,是否有发散/震荡

    rewards/margins

    reward_chosen - reward_rejected 的平均值

    反映模型区分正负样本的能力是否提升

    rewards/accuracies

    reward_chosen > reward_rejected 的比例

    反映偏好判断正确率是否提高

    logs/chosen& logs/rejected

    每个 sample 的对数似然总和

    趋势变化判断 token-level 拟合趋势

其他思考

1.  logps/chosen是负的合理吗

logps(y_{chosen}|x})logps(y_{chosen}|x}) 是模型对生成chosen回复时,每个token的概率取对数后加总, 由于每一个token的概率 ,所以。p(yt,y<t)∈(0,1),所以logp(yt)<0。 所以累加一段文本后,整个logp通常是一个比较大的负值。

2. reward为负值

因为是 rchosen=logπθ(ychosen|x) ,如果没有额外reward打分模型,则 r=sum(logps)/len(logps)

http://www.lqws.cn/news/590221.html

相关文章:

  • 齿轮的齿厚极限偏差如何确定?一起学习一下
  • C++基础
  • 目前最火的agent方向-A2A快速实战构建(二): AutoGen模型集成指南:从OpenAI到本地部署的全场景LLM解决方案
  • 《Python 架构之美:三大设计模式实战指南》
  • 【FR801xH】富芮坤FR801xH之UART
  • 【javaAI】SpringAI快速入门
  • 【C#】如果有一个数值如 168.0000100,如何去除末尾的无效零,只显示有效的小数位数,让DeepSeek给我们解答
  • 半加器和全加器
  • Disruptor架构哲学
  • 【机器学习2】正则化regularizaiton(降低模型过拟合)
  • 设备管理的11个指标、七大误区、六大特征
  • muduo
  • 数据结构——线性表的链式存储
  • QT笔记---环境和编译出现的问题
  • Golang的代码结构设计原则与实践与模式应用
  • helm安装配置jenkins
  • 百度轮岗:任命新CFO,崔珊珊退居业务二线
  • Redis-7.4.3-Windows-x64下载安装使用
  • 时空数据挖掘五大革新方向详解篇!
  • 我认知的AI宇宙系列第三期
  • 强化学习概述及学习流程
  • 3D词云图
  • 虚拟机配置过程中的知识点
  • shardingsphere5.2.1与SpringBoot3.X的版本冲突问题
  • 华为云Flexus+DeepSeek征文 | ​​华为云ModelArts Studio大模型与企业AI会议纪要场景的对接方案
  • 具身智能环境的构建和工作(具身智能入门四)
  • Oracle 进阶语法实战:从多维分析到数据清洗的深度应用​(第四课)
  • 贪心算法在C++中的应用与实践
  • Monorepo+Pnpm+Turborepo
  • 数据结构:链表