Skip to content

Commit 2147cf6

Browse files
winglian and Nero10578 authored
* add dpo llama3 * fix dpo bos and eos * bos token gets added automatically by the tokenizer * explicit <|end_of_text|> not needed, as eot_id is sufficient --------- Co-authored-by: Nero10578 <owenarliawan@gmail.com>
1 parent 50421c8 commit 2147cf6

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed
+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
DPO strategies for llama-3 chat template
3+
"""
4+
5+
6+
def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    llama-3 DPO transform for samples with instruction, chosen_response,
    rejected_response and an optional system field

    `cfg` and `kwargs` are accepted but not used here. Returns a function
    that rewrites a single sample in place and returns that same sample with
    `prompt`, `chosen` and `rejected` set to llama-3 formatted strings.
    """

    def transform_fn(sample):
        # The user turn ends by opening the assistant header so the
        # chosen/rejected completions can be appended directly after it.
        prompt = (
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        # A missing, None or empty system message produces no system turn.
        if "system" in sample and sample["system"]:
            prompt = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                + prompt
            )
        sample["prompt"] = prompt
        # Each completion is terminated by the end-of-turn token only.
        sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
        return sample

    return transform_fn
25+
26+
27+
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/dpo-mix-7k conversations

    Reads `chosen` and `rejected` as lists of message dicts with a `content`
    key: the first `chosen` message supplies the user turn and the second
    message of each list supplies the completion. `cfg` and `kwargs` are
    accepted but not used here. The returned function rewrites the sample in
    place and returns it.
    """

    def transform_fn(sample):
        chosen_messages = sample["chosen"]
        rejected_messages = sample["rejected"]
        # The prompt is built from the chosen conversation before
        # sample["chosen"] is overwritten with the flattened completion.
        sample["prompt"] = (
            f"<|start_header_id|>user<|end_header_id|>\n\n{chosen_messages[0]['content']}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["chosen"] = f"{chosen_messages[1]['content']}<|eot_id|>"
        sample["rejected"] = f"{rejected_messages[1]['content']}<|eot_id|>"
        return sample

    return transform_fn
44+
45+
46+
def icr(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    llama-3 transforms for datasets with system, input, chosen, rejected
    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs

    `cfg` and `kwargs` are accepted but not used here. The returned function
    rewrites the sample in place and returns it.
    """

    def transform_fn(sample):
        # The user turn ends by opening the assistant header so the
        # chosen/rejected completions can be appended directly after it.
        prompt = (
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        # A missing, None or empty system message produces no system turn.
        if "system" in sample and sample["system"]:
            prompt = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                + prompt
            )
        sample["prompt"] = prompt
        sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
        return sample

    return transform_fn
70+
71+
72+
def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca DPO Pairs

    Reads `question`, `chosen`, `rejected` and an optional `system` field.
    `cfg` and `kwargs` are accepted but not used here. The returned function
    rewrites the sample in place and returns it.
    """

    def transform_fn(sample):
        # The user turn ends by opening the assistant header so the
        # chosen/rejected completions can be appended directly after it.
        prompt = (
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        # A missing, None or empty system message produces no system turn.
        if "system" in sample and sample["system"]:
            prompt = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                + prompt
            )
        sample["prompt"] = prompt
        sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
        return sample

    return transform_fn
92+
93+
94+
def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    llama-3 DPO transform for samples with prompt, chosen, rejected and an
    optional system field

    `cfg` and `kwargs` are accepted but not used here. The returned function
    rewrites the sample in place (the raw `prompt` is replaced by the
    formatted one) and returns it.
    """

    def transform_fn(sample):
        # Read the raw prompt up front; sample["prompt"] is overwritten below.
        raw_prompt = sample["prompt"]
        prompt = (
            f"<|start_header_id|>user<|end_header_id|>\n\n{raw_prompt}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        # A missing, None or empty system message produces no system turn.
        if "system" in sample and sample["system"]:
            prompt = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                + prompt
            )
        sample["prompt"] = prompt
        sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
        return sample

    return transform_fn
112+
113+
114+
def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations

    Reads `prompt` as the user text and takes the completion from the second
    message (`[1]['content']`) of the `chosen` and `rejected` message lists;
    an optional `system` field adds a system turn. `cfg` and `kwargs` are
    accepted but not used here. The returned function rewrites the sample in
    place and returns it.
    """

    def transform_fn(sample):
        # Read the raw prompt up front; sample["prompt"] is overwritten below.
        raw_prompt = sample["prompt"]
        prompt = (
            f"<|start_header_id|>user<|end_header_id|>\n\n{raw_prompt}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        # A missing, None or empty system message produces no system turn.
        if "system" in sample and sample["system"]:
            prompt = (
                f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
                + prompt
            )
        sample["prompt"] = prompt
        sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
        sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
        return sample

    return transform_fn

0 commit comments

Comments
 (0)