Commit db6a884

Add include_system_prompt attribute to Magpie

Parent: a630dda

File tree

3 files changed: +84 −8 lines

Diff for: src/distilabel/steps/tasks/magpie/base.py (+28 −4)

@@ -26,7 +26,7 @@
 from distilabel.steps.tasks.base import Task
 
 if TYPE_CHECKING:
-    from distilabel.steps.tasks.typing import ChatType, FormattedInput
+    from distilabel.steps.tasks.typing import ChatType
     from distilabel.steps.typing import StepOutput
 
 MAGPIE_MULTI_TURN_SYSTEM_PROMPT = (
@@ -54,6 +54,10 @@ class MagpieBase(RuntimeParametersMixin):
         default=False,
         description="Whether the conversation should end with a user message.",
     )
+    include_system_prompt: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to include the system prompt used in the generated conversation.",
+    )
     only_instruction: RuntimeParameter[bool] = Field(
         default=False,
         description="Whether to generate only the instruction. If this argument"
@@ -67,7 +71,7 @@ class MagpieBase(RuntimeParametersMixin):
 
     def _prepare_inputs_for_instruction_generation(
         self, inputs: List[Dict[str, Any]]
-    ) -> List["FormattedInput"]:
+    ) -> List["ChatType"]:
         """Prepares the inputs adding the system (if required) prompt provided in each row,
         or if the conversations to generate have more than one turn, then adding the system
         prompt for multi-turn conversation from the paper.
@@ -124,10 +128,30 @@ def _generate_instruction(
         )
         return [{"instruction": output[0]} for output in outputs]
 
+    def _prepare_conversation_outputs(
+        self, conversations: List["ChatType"]
+    ) -> List[Dict[str, Any]]:
+        """Prepare the output conversation removing the system prompt if necessary.
+
+        Args:
+            conversations: the list of generated conversations.
+
+        Returns:
+            A list of dictionaries containing a "conversation" key.
+        """
+        outputs = []
+        for conversation in conversations:
+            if not self.include_system_prompt and conversation[0]["role"] == "system":
+                conversation.pop(0)
+            outputs.append({"conversation": conversation})
+        return outputs
+
     def _generate_multi_turn_conversation(
         self, inputs: List[Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
-        conversations = self._prepare_inputs_for_instruction_generation(inputs)
+        conversations: List["ChatType"] = (
+            self._prepare_inputs_for_instruction_generation(inputs)
+        )
 
         for i in range(self.n_turns):  # type: ignore
             # Generate instruction or user message
@@ -161,7 +185,7 @@ def _generate_multi_turn_conversation(
             conversations=conversations,  # type: ignore
         )
 
-        return [{"conversation": conversation} for conversation in conversations]
+        return self._prepare_conversation_outputs(conversations)
 
     def _generate_with_pre_query_template(
         self, inputs: List[Dict[str, Any]]

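The core of the change is `_prepare_conversation_outputs`: unless `include_system_prompt` is set, a leading system message is popped from each generated conversation before the outputs are built. A standalone sketch of that logic (plain Python; the free function and `ChatType` alias are illustrative stand-ins for the method and type in the diff, not importable as-is):

```python
from typing import Any, Dict, List

# Stand-in for distilabel's ChatType: a list of OpenAI-style chat messages.
ChatType = List[Dict[str, str]]


def prepare_conversation_outputs(
    conversations: List[ChatType], include_system_prompt: bool = False
) -> List[Dict[str, Any]]:
    """Drop a leading system message from each conversation unless asked to keep it."""
    outputs = []
    for conversation in conversations:
        if not include_system_prompt and conversation[0]["role"] == "system":
            conversation.pop(0)  # mutates the conversation in place, as in the diff
        outputs.append({"conversation": conversation})
    return outputs


conversations = [
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello Magpie"},
    ]
]
# Default behavior: the system prompt is stripped, matching the updated tests.
assert prepare_conversation_outputs(conversations) == [
    {"conversation": [{"role": "user", "content": "Hello Magpie"}]}
]
```

Note that the default of `False` flips the previous behavior: before this commit the multi-turn system prompt was always included, as the removed assertion lines in the tests below show.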
Diff for: tests/unit/steps/tasks/magpie/test_base.py (+50 −4)

@@ -76,7 +76,6 @@ def test_process_with_n_turns(self) -> None:
         assert next(task.process(inputs=[{}, {}, {}])) == [
             {
                 "conversation": [
-                    {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT},
                     {"role": "user", "content": "Hello Magpie"},
                     {"role": "assistant", "content": "Hello Magpie"},
                     {"role": "user", "content": "Hello Magpie"},
@@ -86,7 +85,6 @@ def test_process_with_n_turns(self) -> None:
             },
             {
                 "conversation": [
-                    {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT},
                     {"role": "user", "content": "Hello Magpie"},
                     {"role": "assistant", "content": "Hello Magpie"},
                     {"role": "user", "content": "Hello Magpie"},
@@ -96,7 +94,6 @@ def test_process_with_n_turns(self) -> None:
             },
             {
                 "conversation": [
-                    {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT},
                     {"role": "user", "content": "Hello Magpie"},
                     {"role": "assistant", "content": "Hello Magpie"},
                     {"role": "user", "content": "Hello Magpie"},
@@ -115,13 +112,50 @@ def test_process_with_end_with_user(self) -> None:
 
         task.load()
 
+        assert next(task.process(inputs=[{}, {}, {}])) == [
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+            {
+                "conversation": [
+                    {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
+                    {"role": "user", "content": "Hello Magpie"},
+                ],
+                "model_name": "test",
+            },
+        ]
+
+    def test_process_with_include_system_prompt(self) -> None:
+        task = Magpie(
+            llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+            n_turns=2,
+            include_system_prompt=True,
+        )
+
+        task.load()
+
         assert next(task.process(inputs=[{}, {}, {}])) == [
             {
                 "conversation": [
                     {"role": "system", "content": MAGPIE_MULTI_TURN_SYSTEM_PROMPT},
                     {"role": "user", "content": "Hello Magpie"},
                     {"role": "assistant", "content": "Hello Magpie"},
                     {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
                 ],
                 "model_name": "test",
             },
@@ -131,6 +165,7 @@ def test_process_with_end_with_user(self) -> None:
                     {"role": "user", "content": "Hello Magpie"},
                     {"role": "assistant", "content": "Hello Magpie"},
                     {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
                 ],
                 "model_name": "test",
             },
@@ -140,13 +175,18 @@ def test_process_with_end_with_user(self) -> None:
                     {"role": "user", "content": "Hello Magpie"},
                     {"role": "assistant", "content": "Hello Magpie"},
                     {"role": "user", "content": "Hello Magpie"},
+                    {"role": "assistant", "content": "Hello Magpie"},
                 ],
                 "model_name": "test",
             },
         ]
 
     def test_process_with_system_prompt_per_row(self) -> None:
-        task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=2)
+        task = Magpie(
+            llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
+            n_turns=2,
+            include_system_prompt=True,
+        )
 
         task.load()
 
@@ -235,6 +275,7 @@ def test_serialization(self) -> None:
             },
             "n_turns": 1,
             "end_with_user": False,
+            "include_system_prompt": False,
             "only_instruction": True,
             "system_prompt": None,
             "name": "magpie_0",
@@ -272,6 +313,11 @@ def test_serialization(self) -> None:
                     "optional": True,
                     "description": "Whether the conversation should end with a user message.",
                 },
+                {
+                    "name": "include_system_prompt",
+                    "optional": True,
+                    "description": "Whether to include the system prompt used in the generated conversation.",
+                },
                 {
                     "name": "only_instruction",
                     "optional": True,

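The new `test_process_with_include_system_prompt` exercises the flag end to end. A hedged usage sketch of the same pattern (`DummyMagpieLLM` is this suite's test double; outside the tests you would substitute a real LLM configured with a `magpie_pre_query_template`):

```python
from distilabel.steps.tasks import Magpie

task = Magpie(
    llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),  # test double from this suite
    n_turns=2,
    include_system_prompt=True,  # new flag introduced by this commit; defaults to False
)
task.load()

result = next(task.process(inputs=[{}]))
# With include_system_prompt=True the system prompt stays as the first message:
assert result[0]["conversation"][0]["role"] == "system"
```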
Diff for: tests/unit/steps/tasks/magpie/test_generator.py (+6 −0)

@@ -54,6 +54,7 @@ def test_serialization(self) -> None:
             },
             "n_turns": 1,
             "end_with_user": False,
+            "include_system_prompt": False,
             "only_instruction": False,
             "system_prompt": None,
             "name": "magpie_generator_0",
@@ -92,6 +93,11 @@ def test_serialization(self) -> None:
                     "optional": True,
                     "description": "Whether the conversation should end with a user message.",
                 },
+                {
+                    "name": "include_system_prompt",
+                    "optional": True,
+                    "description": "Whether to include the system prompt used in the generated conversation.",
+                },
                 {
                     "name": "only_instruction",
                     "optional": True,

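MagpieGenerator shares MagpieBase, so the same runtime parameter shows up in its serialized form. A minimal sketch of the property these assertions pin down (assuming the `dump()` serialization these tests compare against):

```python
from distilabel.steps.tasks import MagpieGenerator

task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
dump = task.dump()

# The new attribute serializes with its default value...
assert dump["include_system_prompt"] is False
# ...and is advertised as an optional runtime parameter.
names = [info["name"] for info in dump["runtime_parameters_info"]]
assert "include_system_prompt" in names
```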