The DPO loss admits the following identity transformation. The underbraced term in the second line is a constant that can be precomputed. So when optimizing with DPO, we can compute this constant part offline and only need to load $\pi_\theta$ online for training, which makes the training pipeline nearly identical to SFT.

$$
\begin{aligned}
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\pi_{\mathrm{ref}})
&= -\mathbb{E}_{(x,y_w,y_l)\sim\mathcal{D}}\left[\log\sigma\!\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)}-\beta\log\frac{\pi_\theta(y_l\mid x)}{\pi_{\mathrm{ref}}(y_l\mid x)}\right)\right]\\
&= -\mathbb{E}_{(x,y_w,y_l)\sim\mathcal{D}}\left[\log\sigma\!\left(\beta\Bigl(\log\pi_\theta(y_w\mid x)-\log\pi_\theta(y_l\mid x)\Bigr)-\underbrace{\beta\Bigl(\log\pi_{\mathrm{ref}}(y_w\mid x)-\log\pi_{\mathrm{ref}}(y_l\mid x)\Bigr)}_{\text{precomputable constant}}\right)\right]
\end{aligned}
$$

Following this idea, we can do secondary development on top of OpenSFT. The two main pieces involved are offline log_prob inference with the ref_model and online loss computation (sketched below).
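
On the online side, the loss reduces to something like the following minimal PyTorch sketch. This is an illustration of the idea, not OpenSFT's or OpenDPO's actual implementation; the function name, argument names, and toy numbers are made up. It assumes the per-sequence summed log-probs of the policy on the chosen/rejected responses are available as tensors, and that log π_ref(y_w|x) − log π_ref(y_l|x) has been precomputed offline and stored with each sample.

import torch
import torch.nn.functional as F

def dpo_loss_with_precomputed_ref(policy_chosen_logps: torch.Tensor,
                                  policy_rejected_logps: torch.Tensor,
                                  ref_logratio: torch.Tensor,
                                  beta: float = 0.1) -> torch.Tensor:
    """DPO loss in which the reference term is a precomputed constant.

    policy_chosen_logps / policy_rejected_logps: sum of log pi_theta(y|x) over
        the response tokens, shape (batch,), produced online by the policy.
    ref_logratio: log pi_ref(y_w|x) - log pi_ref(y_l|x), shape (batch,),
        computed offline and loaded together with the training sample.
    """
    policy_logratio = policy_chosen_logps - policy_rejected_logps
    logits = beta * (policy_logratio - ref_logratio)
    # -log sigmoid(.) averaged over the batch
    return -F.logsigmoid(logits).mean()

# Toy batch of two pairs; in practice ref_logratio comes from the offline step.
policy_chosen = torch.tensor([-120.3, -80.1])
policy_rejected = torch.tensor([-131.7, -79.5])
ref_logratio = torch.tensor([-9.8, 1.2])
loss = dpo_loss_with_precomputed_ref(policy_chosen, policy_rejected, ref_logratio)

Storing the unscaled log-ratio (rather than β times it) keeps β a training-time hyperparameter, so the offline pass does not have to be rerun when β changes.
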
The offline log_prob inference for the ref_model can follow the code below, which covers two approaches:
1. Given x, generate y and return the per-token log-probabilities of y;
2. Given <x, y>, return the per-token log-probabilities of y.
For the same model and the same y, the two approaches differ slightly in numerical precision, but the impact is small.

from openai import OpenAI

# OpenAI-compatible client pointing at the ref_model served by vLLM
# (prompt_logprobs below is a vLLM-specific extension passed via extra_body).
base_url = ""  # fill in the inference server address
client = OpenAI(base_url=base_url, api_key="EMPTY")
system = '你是一个人工智能助手。'  # "You are an AI assistant."
messages = [
    {"role": "system", "content": system},
    # "How do I compute the first 10 digits of pi? Think step by step, no fewer than 2000 characters."
    {"role": "user", "content": "如何计算圆周率前10位。请一步一步思考,不少于2000字。"}
]

# Method 1: given x, generate y. logprobs=True returns the log-prob of every
# generated token; prompt_logprobs also returns log-probs for the prompt
# tokens, which lets us measure how long the rendered prompt is.
response1 = client.chat.completions.create(
    model="Qwen2.5-72B-Instruct-2",
    messages=messages,
    logprobs=True,
    n=1,
    extra_body={
        "prompt_logprobs": 1
    }
)
prompt_len = len(response1.prompt_logprobs)  # number of tokens in the rendered prompt

# Method 2: append the generated y as an assistant turn and send it back, so y
# becomes part of the prompt; its log-probs are then read from prompt_logprobs.
messages.append({
    "role": "assistant",
    "content": response1.choices[0].message.content
})
response2 = client.chat.completions.create(
    model="Qwen2.5-72B-Instruct-2",
    messages=messages,
    logprobs=True,
    n=1,
    max_tokens=1,  # we only need the prompt log-probs, not a new completion
    extra_body={
        "prompt_logprobs": 1,
    }
)

# Compare the two methods token by token: for each token generated in method 1,
# find the matching entry in method 2's prompt_logprobs (skipping the first
# prompt_len entries, which belong to x) and accumulate both log-probs.
token_logprobs1 = [(token.token, token.logprob) for token in response1.choices[0].logprobs.content]
s1 = 0  # sum of log-probs of y under method 1
s2 = 0  # sum of log-probs of y under method 2
res2 = response2.prompt_logprobs[prompt_len:]
for idx, tp in enumerate(token_logprobs1):
    for token in res2[idx]:
        if res2[idx][token]['decoded_token'] == tp[0]:
            print([tp[0]], tp[1], res2[idx][token]['logprob'])
            s1 += tp[1]
            s2 += res2[idx][token]['logprob']
print(s1, s2)

Example output (truncated). The left column is the per-token log-prob from method 1 (generation) and the right column from method 2 (prompt scoring); the last line prints the two sums s1 and s2:

['计算'] -1.9446721076965332 -1.9449156522750854
['圆'] -3.6954811548639555e-06 -3.6954811548639555e-06
['周'] 0.0 0.0
['率'] 0.0 0.0
['π'] -2.1152522563934326 -1.9596004486083984
['('] -1.702558159828186 -1.8511580228805542
['pi'] -0.2962033748626709 -0.29709821939468384
[')'] -0.00023684080224484205 -0.0002369599969824776
['是一个'] -0.999537467956543 -1.0014550685882568
['历史悠久'] -0.31904473900794983 -0.3210349380970001
['且'] -0.48530423641204834 -0.4875105619430542
['富有'] -2.4599339962005615 -2.4824647903442383
['挑战'] -0.077326700091362 -0.0751587375998497
['性的'] -0.07335645705461502 -0.07453189045190811
['数学'] -0.017074842005968094 -0.017076482996344566
['问题'] -0.010011930949985981 -0.0070079006254673
['。'] -0.004675764590501785 -0.006678522098809481
['π'] -1.4160168170928955 -1.3776277303695679
['是一个'] -0.05739554762840271 -0.06244683265686035
['无'] -0.046009816229343414 -0.04602472856640816
['理'] -6.556489552167477e-06 -6.556489552167477e-06
['数'] 0.0 0.0
[','] -2.098061486321967e-05 -2.0861407392658293e-05
['意味着'] -0.15949402749538422 -0.13373138010501862
['它'] -0.1128290593624115 -0.11284811794757843
['的小'] -0.10516367107629776 -0.12227828800678253
['数'] -4.494089080253616e-05 -3.755022044060752e-05
['部分'] -0.00913783349096775 -0.009083377197384834
['无限'] -0.7951511144638062 -0.8952993154525757
['不'] -0.3520168364048004 -0.35117292404174805
['循环'] -3.576214658096433e-05 -3.576214658096433e-05
['。'] -0.8879746198654175 -0.6933096647262573
['圆'] -5.113936901092529 -5.10016393661499
['周'] -3.3378546504536644e-06 -3.933898824470816e-06
['率'] -1.0728830375228426e-06 -1.7881377516459906e-06
['的'] -1.1737570762634277 -1.0424505472183228
['前'] -2.3638522624969482 -2.3200573921203613
['1'] -0.0019491974962875247 -0.001622551935724914
['0'] 0.0 0.0
['位'] -2.3841855067985307e-07 -1.1920928244535389e-07
['是'] -0.21132780611515045 -0.24167922139167786
['3'] -0.001582085620611906 -0.0017353727016597986
['.'] 0.0 0.0
['1'] 0.0 0.0
['4'] 0.0 0.0
['1'] 0.0 0.0
['5'] -1.1920928244535389e-07 0.0
['9'] 0.0 0.0
['2'] -3.135155202471651e-05 -3.755022044060752e-05
['6'] 0.0 0.0
['5'] 0.0 0.0
['3'] 0.0 -1.1920928244535389e-07
['5'] -0.7434393167495728 -0.5722126364707947
[','] -0.11336615681648254 -0.11261051148176193
['但'] -0.060354750603437424 -0.0634424015879631
['如何'] -0.3800504505634308 -0.3280564844608307
['计算'] -1.0152842998504639 -1.0103530883789062
['这些'] -1.3255386352539062 -1.2008512020111084
['值'] -1.01945960521698 -1.1328415870666504
['呢'] -0.004573124460875988 -0.004484952427446842
['?'] -0.0007839705212973058 -0.0007839705212973058
['本文'] -1.124902367591858 -1.1759260892868042
['将'] -0.00046004203613847494 -0.00032181330607272685
['详细介绍'] -0.6663597226142883 -0.7036594152450562
['几种'] -0.016474291682243347 -0.01726561039686203
['计算'] -0.5865393280982971 -0.6067655086517334
['圆'] -0.4640563726425171 -0.4634779691696167
['周'] 0.0 0.0
['率'] 0.0 0.0
['的方法'] -0.4330849349498749 -0.4287530779838562
[',并'] -1.6447341442108154 -1.6447339057922363
['逐步'] -0.1179928183555603 -0.10228589177131653
['展示'] -1.1397227048873901 -1.1486461162567139
['如何'] -0.016492584720253944 -0.014912204816937447
['利用'] -3.152472734451294 -3.00485897064209
['这些'] -0.0997341200709343 -0.09909205138683319
['方法'] -0.00027104519540444016 -0.00038628268521279097
['计算'] -0.2815181016921997 -0.2980819344520569
['出'] -0.9595966339111328 -1.0622118711471558
['π'] -0.19502663612365723 -0.20271897315979004
['的'] -0.00024279984063468874 -0.00022968991834204644
['前'] -1.537788011773955e-05 -1.8596476365928538e-05
['1'] -1.1920928244535389e-07 -1.1920928244535389e-07
['0'] 0.0 0.0
['位'] 0.0 0.0
['。\n\n'] -0.08976593613624573 -0.09146438539028168
...

-517.1214426652566 -509.62573781293065
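
To wire the offline step into training, one option is to run method 2 once for the chosen and once for the rejected response of each pair, and store the summed reference log-probs (like s2 above) next to the texts, so the online trainer only needs π_θ. A minimal sketch follows; the field names are illustrative and not necessarily OpenDPO's actual schema.

import json

def pack_pair(prompt: str, chosen: str, rejected: str,
              ref_chosen_logp: float, ref_rejected_logp: float) -> str:
    """Serialize one preference pair with its precomputed reference log-probs.

    ref_chosen_logp / ref_rejected_logp are the summed log pi_ref(y|x) values
    obtained offline; the record layout here is illustrative only.
    """
    record = {
        "prompt": prompt,
        "chosen": chosen,
        "rejected": rejected,
        # constant part of the DPO loss: log pi_ref(y_w|x) - log pi_ref(y_l|x)
        "ref_logratio": ref_chosen_logp - ref_rejected_logp,
    }
    return json.dumps(record, ensure_ascii=False)

# with open("dpo_train.jsonl", "w") as f:
#     for sample in pairs:
#         f.write(pack_pair(**sample) + "\n")
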

How should chosen and rejected responses be labeled?
There are four options:
1. Human annotation;
2. Rules;
3. A reward model;
4. Having a top industry model judge GSB (good/same/bad) between the two responses, or score each response against a given reference answer (see the sketch after this list).
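
For option 4, here is a minimal sketch of GSB-style judging with the same OpenAI-compatible client used above. The judge endpoint, model name, prompt wording, and parsing are placeholders and assumptions, not part of OpenDPO.

import re
from openai import OpenAI

# Placeholder judge endpoint and model name -- replace with your own.
judge = OpenAI(base_url="", api_key="EMPTY")
JUDGE_MODEL = "judge-model"  # hypothetical name

GSB_PROMPT = (
    "You are a strict evaluator. Given a question and two answers, reply with a single letter:\n"
    "G if Answer A is better, S if they are about the same, B if Answer B is better.\n\n"
    "Question:\n{question}\n\nAnswer A:\n{a}\n\nAnswer B:\n{b}\n\nVerdict:"
)

def gsb_label(question: str, answer_a: str, answer_b: str):
    """Ask the judge model for a G/S/B verdict and map it to (chosen, rejected)."""
    resp = judge.chat.completions.create(
        model=JUDGE_MODEL,
        messages=[{"role": "user",
                   "content": GSB_PROMPT.format(question=question, a=answer_a, b=answer_b)}],
        temperature=0,
        max_tokens=4,
    )
    verdict = resp.choices[0].message.content.strip().upper()
    m = re.search(r"[GSB]", verdict)
    if m is None or m.group(0) == "S":
        return None  # tie or unparsable verdict -> drop the pair
    if m.group(0) == "G":
        return {"chosen": answer_a, "rejected": answer_b}
    return {"chosen": answer_b, "rejected": answer_a}

A common practice is to query the judge twice with A/B swapped and keep only pairs where the two verdicts agree, to reduce position bias.
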

代码库:https://github.com/mlpod/OpenDPO