Skip to content

Commit 3d4984b

Browse files
authored
update prompts for open orca to match the paper (axolotl-ai-cloud#317)
fix the test for the updated system tokenizer
1 parent ff7f18d commit 3d4984b

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

src/axolotl/prompt_strategies/alpaca_w_system.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,34 @@ def build_prompt_w_system(
6666
) -> Generator[str, None, None]:
6767
# returns the full prompt from instruction and optional input
6868
# if a label (=response, =output) is provided, it's also appended.
69+
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
6970
if input:
70-
res = system + self.turn_format.format(instruction=instruction, input=input)
71+
res = formatted_sys_prompt + self.turn_format.format(
72+
instruction=instruction, input=input
73+
)
7174
else:
72-
res = system + self.turn_no_input_format.format(instruction=instruction)
75+
res = formatted_sys_prompt + self.turn_no_input_format.format(
76+
instruction=instruction
77+
)
7378
if output:
7479
res = f"{res}{output}"
7580
yield res
7681

7782

83+
class OpenOrcaSystemDataPrompter(SystemDataPrompter):
    """
    Alpaca-style prompter that takes its system prompt from the dataset and
    formats each turn with the OpenOrca prompt layout.
    """

    def match_prompt_style(self):
        """
        Set ``turn_format`` and ``turn_no_input_format`` for the active prompt style.

        Each known style is checked in turn (instruct first, then chat); when
        ``self.prompt_style`` matches neither, both attributes are left untouched.
        """
        # (style value, template with extra input, template without extra input)
        style_templates = (
            (
                PromptStyle.INSTRUCT.value,
                "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n",
                "### User:\n{instruction}\n\n### Assistant:\n",
            ),
            (
                PromptStyle.CHAT.value,
                "USER: {instruction}\n{input}\nASSISTANT:",
                "USER: {instruction}\nASSISTANT:",
            ),
        )
        for style_value, with_input, without_input in style_templates:
            if self.prompt_style == style_value:
                self.turn_format = with_input
                self.turn_no_input_format = without_input
95+
96+
7897
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
7998
"""
8099
Tokenizing strategy for OpenOrca datasets
@@ -113,7 +132,7 @@ def load_chat(tokenizer, cfg):
113132

114133
def load_open_orca(tokenizer, cfg):
115134
return OpenOrcaPromptTokenizingStrategy(
116-
SystemDataPrompter(PromptStyle.INSTRUCT.value),
135+
OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),
117136
tokenizer,
118137
cfg.train_on_inputs,
119138
cfg.sequence_len,

tests/test_prompt_tokenizers.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ def test_system_alpaca(self):
130130
"output": "Hi! How can I help?",
131131
}
132132
example = strat.tokenize_prompt(sample)
133-
assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
134-
assert example["input_ids"][3] == 11889 # USER
133+
assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
134+
assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
135+
assert example["input_ids"][9] == 11889 # USER
135136

136137

137138
if __name__ == "__main__":

tests/test_prompters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_system_prompt(self):
7070
)
7171
)
7272
assert "use cot" in res
73-
assert res.startswith("use cot")
73+
assert res.startswith("### System:")
7474
assert "### Instruction:" not in res
7575
assert "### Input:" not in res
7676
assert "alpacas" in res

0 commit comments

Comments
 (0)