当前位置: 首页 > news >正文

组相对策略优化(GRPO):原理及源码解析

文章目录

    • PPO vs GRPO
    • PPO的目标函数
    • GRPO的目标函数
      • KL散度约束与估计
      • ORM监督RL的结果
      • PRM监督RL的过程
      • 迭代RL
      • 算法流程
    • GRPO损失的不同版本
    • GRPO源码解析

  • DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models

PPO vs GRPO

在这里插入图片描述

PPO的目标函数

J P P O ( θ ) = E [ q ∼ P ( Q ) , o ∼ π θ old  ( O ∣ q ) ] 1 ∣ o ∣ ∑ t = 1 ∣ o ∣ min ⁡ [ π θ ( o t ∣ q , o < t ) π θ old  ( o t ∣ q , o < t ) A t , clip ⁡ ( π θ ( o t ∣ q , o < t ) π θ old  ( o t ∣ q , o < t ) , 1 − ε , 1 + ε ) A t ] \begin{align*} \mathcal{J}_{P P O}(\theta) &=\mathbb{E}\left[q \sim P(Q), o \sim \pi_{\theta_{\text {old }}}(O \mid q)\right]\\ &\frac{1}{|o|}\sum_{t=1}^{|o|} \min \left[\frac{\pi_\theta\left(o_t \mid q, o_{<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_t \mid q, o_{<t}\right)} A_t, \operatorname{clip}\left(\frac{\pi_\theta\left(o_t \mid q, o_{<t}\right)}{\pi_{\theta_{\text {old }}}\left(o_t \mid q, o_{<t}\right)}, 1-\varepsilon, 1+\varepsilon\right) A_t\right] \end{align*} JPPO(θ)=E[qP(Q),oπθold (Oq)]o1t=1omin[πθold (otq,o<t)πθ(otq,o<t)At,clip(πθold (otq,o<t)πθ(otq,o<t),1ε,1+ε)At]

A t A_t At是使用广义优势估计(GAE)基于奖励 { r ≥ t } \{r_{\ge t}\} {rt}和状态价值 V ψ V_{\psi} Vψ计算的优势值,需联合训练策略模型和状态价值模型。通常为避免奖励模型被过度拟合而产生异常输出,标准做法为每一个token的奖励添加策略模型和参考模型的KL惩罚。
r t = r φ ( q , o ≤ t ) − β log ⁡ π θ ( o t ∣ q , o < t ) π r e f ( o t ∣ q , o < t ) r_t=r_{\varphi}\left(q, o_{\leq t}\right)-\beta \log \frac{\pi_\theta\left(o_t \mid q, o_{<t}\right)}{\pi_{r e f}\left(o_t \mid q, o_{<t}\right)} rt=rφ(q,ot)βlogπref(otq,o<t)πθ(otq,o<t)

GRPO的目标函数

PPO算法使用价值模型输出作为优势的baseline,指导策略模型更新。价值模型一般与策略模型同尺寸,训练时占显存、耗算力。在LLM生成场景下,奖励函数给出整个response的分数,再加到最后一个token的奖励上,价值模型要预测token-level的奖励,比较困难。

GRPO通过对单个query采样多个response,取平均奖励作为baseline不需要使用价值模型(foregoes critic model),目标函数为:

J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) ] 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ { min ⁡ [ π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) A ^ i , t , clip ⁡ ( π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) , 1 − ε , 1 + ε ) A ^ i , t ] − β D K L [ π θ ∣ ∣ π r e f ] } \begin{align*} \mathcal{J}_{G R P O}(\theta) & =\mathbb{E}\left[q \sim P(Q),\left\{o_i\right\}_{i=1}^G \sim \pi_{\theta_{o l d}}(O \mid q)\right] \\ & \frac{1}{G} \sum_{i=1}^G \frac{1}{\left|o_i\right|} \sum_{t=1}^{\left|o_i\right|}\left\{\min \left[\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{o l d}}\left(o_{i, t} \mid q, o_{i,<t}\right)} \hat{A}_{i, t}, \operatorname{clip}\left(\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{o l d}}\left(o_{i, t} \mid q, o_{i,<t}\right)}, 1-\varepsilon, 1+\varepsilon\right) \hat{A}_{i, t}\right]-\beta \mathbb{D}_{K L}\left[\pi_\theta| | \pi_{r e f}\right]\right\} \end{align*} JGRPO(θ)=E[qP(Q),{oi}i=1Gπθold(Oq)]G1i=1Goi1t=1oi{min[πθold(oi,tq,oi,<t)πθ(oi,tq,oi,<t)A^i,t,clip(πθold(oi,tq,oi,<t)πθ(oi,tq,oi,<t),1ε,1+ε)A^i,t]βDKL[πθ∣∣πref]}

建立组内竞争机制,不需要外部独立的Critic。比组内平均分高的响应获得正分数,低的获得负分数,鼓励模型生成比平均水平更好的响应,使得平均得分越来越高。

KL散度约束与估计

KL散度项用于约束策略更新幅度,我们使用k3型的KL散度估计:
D K L [ π θ ∣ ∣ π r e f ] = π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − log ⁡ π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − 1 \mathbb{D}_{K L}\left[\pi_\theta| | \pi_{r e f}\right]=\frac{\pi_{r e f}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-\log \frac{\pi_{r e f}\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}-1 DKL[πθ∣∣πref]=πθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)logπθ(oi,tq,oi,<t)πref(oi,tq,oi,<t)1

解释: 奖励模型经比较/偏好数据集训练,使用相对优势的RL方法与奖励模型也比较匹配。PPO方法将策略模型和参考模型的KL散度作为奖励的惩罚,GRPO不惩罚奖励,而是将KL惩罚直接放在策略损失里面,避免在 A i , t A_{i,t} Ai,t中引入复杂的计算。

通常 x x x无法穷举,一般通过多次采样求平均方式估计期望,即无偏估计,KL散度的定义及无偏估计为
K L [ p ∣ ∣ q ] = ∑ x p ( x ) log ⁡ ( p ( x ) q ( x ) ) = E x ∼ p [ p ( x ) q ( x ) ] ≈ 1 N log ⁡ ( p ( x ) q ( x ) ) KL[p||q]=\sum_x p(x)\log\left(\dfrac{p(x)}{q(x)}\right)=\mathbb E_{x\sim p}\left[\frac{p(x)}{q(x)}\right]\approx\frac{1}{N}\log\left(\frac{p(x)}{q(x)}\right) KL[p∣∣q]=xp(x)log(q(x)p(x))=Exp[q(x)p(x)]N1log(q(x)p(x))

采样与期望: 如果p中有n个不同的x,从中随机采样m个x,m>>n,则重复x的个数除以m就近似为概率p(x)。

r = q ( x ) / p ( x ) r=q(x)/p(x) r=q(x)/p(x),几种KL散度采样估计:

  • k1 − log ⁡ r -\log r logr无偏、高方差,半数样本为负(KL为正),偏差比较高。
  • k2 1 2 ( log ⁡ r ) 2 \dfrac{1}{2}(\log r)^2 21(logr)2有偏、低方差,始终为正,明确反映出分布之间的偏离程度。
  • k3 − log ⁡ r + ( r − 1 ) -\log r + (r - 1) logr+(r1)无偏、低方差,始终为正。启发式设计,k1加上期望为0,并且与其负相关的项。
    • p ( x ) p(x) p(x) q ( x ) q(x) q(x)分步接近时, r r r的期望为1,新增项 r − 1 r-1 r1为0;
    • r r r增大,k1 − log ⁡ ( r ) -\log(r) log(r)减小,新增项 ( r − 1 ) (r-1) (r1)增加;
    • 直观表达, l o g ( p / q ) + ( q / p − 1 ) log(p/q)+(q/p-1) log(p/q)+(q/p1) p ( x ) p(x) p(x)大于 q ( x ) q(x) q(x)时,k1大于0,新增修正项小于1;

ORM监督RL的结果

对于每个query q q q,从 π θ o l d \pi_{\theta_{old}} πθold中采样一组输出 G = { o 1 , o 2 , ⋯ , o G } G=\{o_1,o_2,\cdots,o_{G}\} G={o1,o2,,oG},奖励模型对这些输出(或者说结果Outcome)打分 r = { r 1 , r 2 , ⋯ , r G } {\bf r}=\{r_1,r_2,\cdots,r_{G}\} r={r1,r2,,rG},将这些奖励标准化可作为每个输出 o i o_i oi在结束位置的组内相对优势
A ^ i , t = r ~ i = r i − mean ⁡ ( r ) std ⁡ ( r ) \hat{A}_{i, t}=\widetilde{r}_i=\frac{r_i-\operatorname{mean}(\mathbf{r})}{\operatorname{std}(\mathbf{r})} A^i,t=r i=std(r)rimean(r)

PRM监督RL的过程

结果监督仅提供了每个输出在结束位置的奖励,不足以监督复杂的数学推理任务。

为进行过程监督,对每个推理步骤打分:
R = { { r 1 i n d e x ( 1 ) , ⋯ , r 1 i n d e x ( K 1 ) } , ⋯ , { r G i n d e x ( 1 ) , ⋯ , r G i n d e x ( K G ) } } \mathbf{R}=\left\{\left\{r_1^{{index}(1)}, \cdots, r_1^{{index}\left(K_1\right)}\right\}, \cdots,\left\{r_G^{{index}(1)}, \cdots, r_G^{{index}\left(K_G\right)}\right\}\right\} R={{r1index(1),,r1index(K1)},,{rGindex(1),,rGindex(KG)}}

其中 i n d e x ( j ) index(j) index(j)表示第 j j j步的结束token,标准化的步骤奖励为
r ~ i i n d e x ( j ) = r i i n d e x ( j ) − mean ⁡ ( R ) std ⁡ ( R ) \tilde{r}_i^{{index}(j)}=\frac{r_i^{{index}(j)}-\operatorname{mean}(\mathbf{R})}{\operatorname{std}(\mathbf{R})} r~iindex(j)=std(R)riindex(j)mean(R)

每一个token的优势等于之后所有步骤的标准化奖励和:
A ^ i , t = ∑ i n d e x ( j ) ≥ t r ~ i i n d e x ( j ) \hat A_{i,t}=\sum_{index(j)\ge t}\tilde r_i^{index(j)} A^i,t=index(j)tr~iindex(j)

迭代RL

随着策略模型更新,奖励模型可能不足以监督策略模型。GRPO使用迭代的方式,从新的策略模型中采样数据,加上10%的历史数据,以继续训练方式更新奖励模型。之后,将最新的策略模型设置为参考模型,继续训练策略模型,重复上述过程。

算法流程

在这里插入图片描述

奖励模型使用base模型初始化
奖励模型训练数据

GRPO损失的不同版本

GRPO目标可以定义为
L G R P O ( θ ) = − 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ l i , t , w . t . l i , t = π θ ( o i , t ∣ q , o i , < t ) [ π θ ( o i , t ∣ q , o i , < t ) ] n o g r a d A ^ i , t − β D K L [ π θ ∥ π r e f ] \mathcal{L}_{\mathrm{GRPO}}(\theta)=-\frac{1}{G} \sum_{i=1}^G \frac{1}{\left|o_i\right|} \sum_{t=1}^{\left|o_i\right|} l_{i, t}, \quad w.t.\ l_{i, t}=\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\left[\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)\right]_{\mathrm{no} \mathrm{grad}}} \hat{A}_{i, t}-\beta \mathbb{D}_{\mathrm{KL}}\left[\pi_\theta \| \pi_{\mathrm{ref}}\right] LGRPO(θ)=G1i=1Goi1t=1oili,t,w.t. li,t=[πθ(oi,tq,oi,<t)]nogradπθ(oi,tq,oi,<t)A^i,tβDKL[πθπref]

DAPO指出,GRPO使用sample-level损失,在long-COT场景下,long-response惩罚不足,导致其输出质量比较低。DAPO使用token-level损失,所有response中的每个token的奖励更加平衡,不受response长度的影响。
L D A P O ( θ ) = − 1 ∑ i = 1 G ∣ o i ∣ ∑ i = 1 G ∑ t = 1 ∣ o i ∣ l i , t \mathcal{L}_{\mathrm{DAPO}}(\theta)=-\frac{1}{\sum_{i=1}^G\left|o_i\right|} \sum_{i=1}^G \sum_{t=1}^{\left|o_i\right|} l_{i, t} LDAPO(θ)=i=1Goi1i=1Gt=1oili,t

Dr. GRPO指出,DAPO没有完全消除不同response长度偏差的影响,为了更彻底的消除,其使用常数替代序列长度:
L Dr. GRPO ( θ ) = − 1 L G ∑ i = 1 G ∑ t = 1 ∣ o i ∣ l i , t \mathcal{L}_{\text{Dr. GRPO}}(\theta) = -\frac{1}{LG} \sum_{i=1}^{G} \sum_{t=1}^{|o_i|} l_{i, t} LDr. GRPO(θ)=LG1i=1Gt=1oili,t

GRPO源码解析

代码库trl中GRPOTrainer的实现,继承于Transformers Trainer,重载_prepare_inputscompute_loss方法

源码在这里:https://github.com/huggingface/trl/blob/v0.18.1/trl/trainer/grpo_trainer.py

算法过程

  1. 构造批次输入prompts
    • 使用自定义的RepeatSampler采样批次,保证每个prompt能重复采样多次,并且能跨进程同步分组;
    • 风格为generatechat_completions,执行左padding,左truncate;
  2. 采样completions_prepare_inputs中调用_generate_and_score_completions,参数为temperature=0.9top_p=1.0max_new_tokens=256
    • 若使用vllm server:
      • 权重同步:确保policy model和vllm model的参数同步;
      • 数据并行采样:主进程上gather所有进程上的prompts,为每个不重复的prompt生成num_generations个completions;
      • 广播分配:主进程上broadcast所有completions到其它进程,所有进程截取自己prompts的completions;
    • 若使用transformers标准的model.generate:
      • 独立生成每个prompt的completion,包含重复的prompt(同一prompt多次prefill),计算低效;
  3. 处理completion padding
    • 根据completion中EOS的位置计算completion长度,并mask首个EOS后的token,只保留有效的completion token;
    • mask所有没有EOS的completion,避免异常completion对loss影响过大(可选);
  4. 计算old_logprobs:若使用相同completion多次迭代优化,计算当前policy model的logprobs作为old_logprobs,用于后续epoch中计算概率比率;
  5. 计算scores:每个reward model/reward func计算每条prompt+completion的score并加权,得到每条sentence的score;
  6. 计算advantages:gather所有进程上的scores,分组标准化,即奖励 - 奖励均值 / 奖励标准差(可选)
  7. 计算loss
    • 计算policy model的logprobs;
    • 计算reference model的ref_logprobs;
    • 计算policy model和reference model之间在每个completion token的kl散度,使用k3无偏估计:kl=log(p/q)+(q/p-1),如果p和q都是对数概率,则kl=p-q+exp(q-p)-1,即kl损失
    • 使用logprobs和old_logprobs计算概率比率并裁剪,限制参数更新幅度(重要性采样,PPO算法的核心),利用裁剪后概率比率clamped_ratio、advantage和completion mask,计算每个token的策略损失
    • 损失加权求和:加权求和token-level的策略损失和kl损失,kl损失权重小,非主导;
    • 损失均值化:loss有多种求和/平均方式,bnpo loss不考虑每条样本的completion长度的影响,取所有token的平均loss。grpo_loss对每条completion依次在token-level、sample-level上求和平均,对长completion的惩罚不足;
    • 使用梯度下降更新policy model;
http://www.lqws.cn/news/99487.html

相关文章:

  • 从测试角度看待CI/CD,敏捷开发
  • tauri项目绕开plugin-shell直接调用可执行文件并携带任意参数
  • OpenCV C++ 学习笔记(五):颜色空间转换、数值类型转换、图像混合、图像缩放
  • redis数据过期策略
  • 垂起固定翼无人机应用及技术分析
  • [特殊字符] Unity UI 性能优化终极指南 — ScrollRect篇
  • 如何提高工作效率
  • 日语学习-日语知识点小记-构建基础-JLPT-N4阶段(31):そう
  • 第十三章 Java基础-特殊处理
  • 【鸿蒙】HarmonyOS NEXT之如何正常加载地图组件
  • HTTP连接管理——短连接,长连接,HTTP 流水线
  • 常见的七种排序算法 ——直接插入排序
  • Vue-ref 与 props
  • 数据的评估与清洗篇---评估数据
  • TSN 中的 CBS(Credit-Based Shaper)功能详解
  • 低谷才是出成绩
  • C#对象扩展方法:提升对象操作的灵活性与效率
  • 【Web应用】若依框架:基础篇13 源码阅读-前端代码分析
  • 物联网数据归档方案选择分析
  • 24.【.NET8 实战--孢子记账--从单体到微服务--转向微服务】--单体转微服务--认证微服务
  • 华为盘古 Ultra MoE 模型:国产 AI 的技术突破与行业影响
  • 更新已打包好的 Spring Boot JAR 文件中的 class 文件
  • Vue.js教学第十八章:Vue 与后端交互(二):Axios 拦截器与高级应用
  • 从汇编的角度揭秘C++引用,豁然开朗
  • 硬件工程师笔记——555定时器应用Multisim电路仿真实验汇总
  • CRM管理软件的数据可视化功能使用技巧:让数据驱动决策
  • SpringBoot 之 JWT
  • 8.RV1126-OPENCV 视频中添加LOGO
  • Web后端快速入门(Maven)
  • OSCP备战-BSides-Vancouver-2018-Workshop靶机详细步骤