Commit 0ef3f70

Add end_with_user and include_system_prompt flags to Magpie tasks and handle Nones. (#784)
* Add `end_with_user` flag
* Add `include_system_prompt` attribute to `Magpie`
* Update docstrings
* Update `MagpieBase` to handle `None`s
* Fix `InferenceEndpointsLLM` unit tests after release of `huggingface_hub==0.24.0`
1 parent b22a494 commit 0ef3f70

File tree: 5 files changed (+251 -36)
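For context, a minimal usage sketch of the two new flags. A sketch only: the model id, `tokenizer_id` and `magpie_pre_query_template` values are illustrative, and the API shape is as of distilabel at the time of this commit.

from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import Magpie

magpie = Magpie(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3-70B-Instruct",  # illustrative model
        tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
        magpie_pre_query_template="llama3",
    ),
    n_turns=2,
    end_with_user=True,           # new: stop after the last user turn
    include_system_prompt=False,  # new: drop the system prompt from outputs (default)
)
magpie.load()
result = next(magpie.process(inputs=[{"system_prompt": "You are a helpful assistant."}]))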

Diff for: src/distilabel/steps/tasks/magpie/base.py (+90 -29)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from pydantic import Field, PositiveInt
 
@@ -26,7 +26,7 @@
 from distilabel.steps.tasks.base import Task
 
 if TYPE_CHECKING:
-    from distilabel.steps.tasks.typing import ChatType, FormattedInput
+    from distilabel.steps.tasks.typing import ChatType
     from distilabel.steps.typing import StepOutput
 
 MAGPIE_MULTI_TURN_SYSTEM_PROMPT = (
@@ -50,6 +50,14 @@ class MagpieBase(RuntimeParametersMixin):
         default=1,
         description="The number of turns to generate for the conversation.",
     )
+    end_with_user: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether the conversation should end with a user message.",
+    )
+    include_system_prompt: RuntimeParameter[bool] = Field(
+        default=False,
+        description="Whether to include the system prompt used in the generated conversation.",
+    )
     only_instruction: RuntimeParameter[bool] = Field(
         default=False,
         description="Whether to generate only the instruction. If this argument"
@@ -63,7 +71,7 @@
 
     def _prepare_inputs_for_instruction_generation(
         self, inputs: List[Dict[str, Any]]
-    ) -> List["FormattedInput"]:
+    ) -> List["ChatType"]:
         """Prepares the inputs adding the system (if required) prompt provided in each row,
         or if the conversations to generate have more than one turn, then adding the system
         prompt for multi-turn conversation from the paper.
@@ -106,7 +114,8 @@ def _append_messages_to_conversations(
             The updated conversations.
         """
         for instruction, conversation in zip(messages, conversations):
-            conversation.append({"role": role, "content": instruction})
+            if instruction is not None:
+                conversation.append({"role": role, "content": instruction})
         return conversations
 
     def _generate_instruction(
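A self-contained sketch of the `None`-skipping behavior introduced here (hypothetical `append_messages` helper, not distilabel code; a minimal mirror of the patched method above):

from typing import Dict, List, Optional

def append_messages(
    role: str,
    messages: List[Optional[str]],
    conversations: List[List[Dict[str, str]]],
) -> List[List[Dict[str, str]]]:
    # A None message (a failed generation) leaves its conversation untouched.
    for message, conversation in zip(messages, conversations):
        if message is not None:
            conversation.append({"role": role, "content": message})
    return conversations

conversations = [[{"role": "user", "content": "Hi"}], [{"role": "user", "content": "Hey"}]]
append_messages("assistant", ["Hello!", None], conversations)
# conversations[0] gained an assistant turn; conversations[1] is unchanged.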
@@ -120,41 +129,83 @@ def _generate_instruction(
         )
         return [{"instruction": output[0]} for output in outputs]
 
+    def _prepare_conversation_outputs(
+        self, conversations: List["ChatType"]
+    ) -> List[Dict[str, Any]]:
+        """Prepare the output conversation removing the system prompt if necessary.
+
+        Args:
+            conversations: the list of generated conversations.
+
+        Returns:
+            A list of dictionaries containing a "conversation" key.
+        """
+        outputs = []
+        for conversation in conversations:
+            if not self.include_system_prompt and conversation[0]["role"] == "system":
+                conversation.pop(0)
+            outputs.append({"conversation": conversation})
+        return outputs
+
+    def _generate_conversation_turn(
+        self, role: str, conversations: List["ChatType"], active_indices: List[int]
+    ) -> Tuple[List["ChatType"], List[int]]:
+        # Generate an output for the conversations that are still active (no previous `None`s)
+        outputs = self.llm.generate(
+            inputs=[conversations[idx] for idx in active_indices],
+            num_generations=1,
+            **self.llm.generation_kwargs,  # type: ignore
+        )
+
+        active_conversations = [conversations[idx] for idx in active_indices]
+        updated_conversations = self._append_messages_to_conversations(
+            role=role,
+            messages=[output[0] for output in outputs],
+            conversations=active_conversations,
+        )
+
+        for idx, conv in zip(active_indices, updated_conversations):
+            conversations[idx] = conv
+
+        new_active_indices = [
+            idx for idx, output in zip(active_indices, outputs) if output[0] is not None
+        ]
+
+        return conversations, new_active_indices
+
     def _generate_multi_turn_conversation(
         self, inputs: List[Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
-        conversations = self._prepare_inputs_for_instruction_generation(inputs)
-
-        for _ in range(self.n_turns):  # type: ignore
-            # Generate instruction or user message
-            outputs = self.llm.generate(
-                inputs=conversations,
-                num_generations=1,
-                **self.llm.generation_kwargs,  # type: ignore
-            )
+        conversations: List["ChatType"] = (
+            self._prepare_inputs_for_instruction_generation(inputs)
+        )
+        # Keep track of the active conversations, as it could happen that for some conversation
+        # we can't generate the next turn because the `LLM` returned `None`.
+        active_indices = list(range(len(conversations)))
+
+        for i in range(self.n_turns):  # type: ignore
+            if not active_indices:
+                break
 
-            conversations = self._append_messages_to_conversations(
-                role="user",
-                messages=[output[0] for output in outputs],
-                conversations=conversations,  # type: ignore
+            # Generate user message
+            conversations, active_indices = self._generate_conversation_turn(
+                role="user", conversations=conversations, active_indices=active_indices
             )
 
-            # TODO: handle potential previous `None`s
+            if i == self.n_turns - 1 and self.end_with_user:  # type: ignore
+                break
 
-            # Generate response
-            outputs = self.llm.generate(
-                inputs=conversations,
-                num_generations=1,
-                **self.llm.generation_kwargs,  # type: ignore
-            )
+            if not active_indices:
+                break
 
-            conversations = self._append_messages_to_conversations(
+            # Generate assistant message
+            conversations, active_indices = self._generate_conversation_turn(
                 role="assistant",
-                messages=[output[0] for output in outputs],
-                conversations=conversations,  # type: ignore
+                conversations=conversations,
+                active_indices=active_indices,
             )
 
-        return [{"conversation": conversation} for conversation in conversations]
+        return self._prepare_conversation_outputs(conversations)
 
     def _generate_with_pre_query_template(
         self, inputs: List[Dict[str, Any]]
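The active-indices bookkeeping above can be illustrated in isolation. A minimal sketch, assuming a hypothetical `flaky_generate` stub in place of `LLM.generate` (not distilabel code):

import random
from typing import List, Optional

def flaky_generate(prompts: List[str]) -> List[Optional[str]]:
    # Stand-in for an LLM that occasionally fails and returns None.
    return [None if random.random() < 0.2 else f"reply to {p}" for p in prompts]

conversations = ["conv-0", "conv-1", "conv-2", "conv-3"]
active_indices = list(range(len(conversations)))

for turn in range(3):
    if not active_indices:
        break  # every conversation has already failed once
    outputs = flaky_generate([conversations[i] for i in active_indices])
    # Only conversations whose generation succeeded stay active; the rest
    # keep whatever turns they had accumulated so far.
    active_indices = [i for i, out in zip(active_indices, outputs) if out is not None]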
@@ -196,6 +247,11 @@ class Magpie(Task, MagpieBase):
 
     Attributes:
         n_turns: the number of turns that the generated conversation will have.
+            Defaults to `1`.
+        end_with_user: whether the conversation should end with a user message.
+            Defaults to `False`.
+        include_system_prompt: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
         only_instruction: whether to generate only the instruction. If this argument is
             `True`, then `n_turns` will be ignored. Defaults to `False`.
         system_prompt: an optional system prompt that can be used to steer the LLM to generate
@@ -204,7 +260,12 @@ class Magpie(Task, MagpieBase):
             one from the column will be used. Defaults to `None`.
 
     Runtime parameters:
-        - `n_turns`: the number of turns that the generated conversation will have.
+        - `n_turns`: the number of turns that the generated conversation will have. Defaults
+            to `1`.
+        - `end_with_user`: whether the conversation should end with a user message.
+            Defaults to `False`.
+        - `include_system_prompt`: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
        - `only_instruction`: whether to generate only the instruction. If this argument is
            `True`, then `n_turns` will be ignored. Defaults to `False`.
        - `system_prompt`: an optional system prompt that can be used to steer the LLM to
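Since the new flags are runtime parameters, they can also be supplied at `pipeline.run` time rather than in the constructor. A minimal sketch; the step name "magpie_0" is hypothetical and depends on how the step was instantiated:

# Assumed: `pipeline` contains a Magpie step named "magpie_0".
distiset = pipeline.run(
    parameters={
        "magpie_0": {
            "n_turns": 2,
            "end_with_user": True,
            "include_system_prompt": True,
        }
    }
)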

Diff for: src/distilabel/steps/tasks/magpie/generator.py (+16 -4)
@@ -42,6 +42,11 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
 
     Attributes:
         n_turns: the number of turns that the generated conversation will have.
+            Defaults to `1`.
+        end_with_user: whether the conversation should end with a user message.
+            Defaults to `False`.
+        include_system_prompt: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
         only_instruction: whether to generate only the instruction. If this argument is
             `True`, then `n_turns` will be ignored. Defaults to `False`.
         system_prompt: an optional system prompt that can be used to steer the LLM to generate
@@ -51,11 +56,18 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
         num_rows: the number of rows to be generated.
 
     Runtime parameters:
-        - `n_turns`: the number of turns that the generated conversation will have.
-        - `only_instruction`: whether to generate only the instruction. If this argument
-            is `True`, then `n_turns` will be ignored. Defaults to `False`.
+        - `n_turns`: the number of turns that the generated conversation will have. Defaults
+            to `1`.
+        - `end_with_user`: whether the conversation should end with a user message.
+            Defaults to `False`.
+        - `include_system_prompt`: whether to include the system prompt used in the generated
+            conversation. Defaults to `False`.
+        - `only_instruction`: whether to generate only the instruction. If this argument is
+            `True`, then `n_turns` will be ignored. Defaults to `False`.
         - `system_prompt`: an optional system prompt that can be used to steer the LLM to
-            generate content of certain topic, guide the style, etc. Defaults to `None`.
+            generate content of certain topic, guide the style, etc. If the provided inputs
+            contains a `system_prompt` column, then this runtime parameter will be ignored
+            and the one from the column will be used. Defaults to `None`.
         - `num_rows`: the number of rows to be generated.
 
     Output columns:
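For completeness, a minimal sketch of driving `MagpieGenerator` inside a pipeline with the new flags. The pipeline name, model id and pre-query template are illustrative, and the API shape is as of distilabel at the time of this commit:

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import MagpieGenerator

with Pipeline(name="magpie-demo") as pipeline:
    MagpieGenerator(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3-70B-Instruct",  # illustrative model
            tokenizer_id="meta-llama/Meta-Llama-3-70B-Instruct",
            magpie_pre_query_template="llama3",
        ),
        n_turns=2,
        end_with_user=True,
        include_system_prompt=False,
        num_rows=10,
    )

distiset = pipeline.run()  # the flags above are also settable as runtime parameters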

Diff for: tests/unit/llms/huggingface/test_inference_endpoints.py (-2)
@@ -171,7 +171,6 @@ async def test_agenerate_with_chat_completion(
             created=1721045246,
             id="",
             model="meta-llama/Meta-Llama-3-70B-Instruct",
-            object="chat.completion",
             system_fingerprint="2.1.1-dev0-sha-4327210",
             usage=ChatCompletionOutputUsage(
                 completion_tokens=66, prompt_tokens=18, total_tokens=84
@@ -212,7 +211,6 @@ async def test_agenerate_with_chat_completion_fails(
             created=1721045246,
             id="",
             model="meta-llama/Meta-Llama-3-70B-Instruct",
-            object="chat.completion",
             system_fingerprint="2.1.1-dev0-sha-4327210",
             usage=ChatCompletionOutputUsage(
                 completion_tokens=66, prompt_tokens=18, total_tokens=84
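Per the commit message, these fixtures track `huggingface_hub==0.24.0`, whose regenerated `ChatCompletionOutput` dataclass appears to no longer accept an `object` keyword, hence its removal from the mocked outputs.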
