12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
15
+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Union
16
16
17
17
from pydantic import Field , PositiveInt
18
18
26
26
from distilabel .steps .tasks .base import Task
27
27
28
28
if TYPE_CHECKING :
29
- from distilabel .steps .tasks .typing import ChatType , FormattedInput
29
+ from distilabel .steps .tasks .typing import ChatType
30
30
from distilabel .steps .typing import StepOutput
31
31
32
32
MAGPIE_MULTI_TURN_SYSTEM_PROMPT = (
@@ -50,6 +50,14 @@ class MagpieBase(RuntimeParametersMixin):
50
50
default = 1 ,
51
51
description = "The number of turns to generate for the conversation." ,
52
52
)
53
+ end_with_user : RuntimeParameter [bool ] = Field (
54
+ default = False ,
55
+ description = "Whether the conversation should end with a user message." ,
56
+ )
57
+ include_system_prompt : RuntimeParameter [bool ] = Field (
58
+ default = False ,
59
+ description = "Whether to include the system prompt used in the generated conversation." ,
60
+ )
53
61
only_instruction : RuntimeParameter [bool ] = Field (
54
62
default = False ,
55
63
description = "Whether to generate only the instruction. If this argument"
@@ -63,7 +71,7 @@ class MagpieBase(RuntimeParametersMixin):
63
71
64
72
def _prepare_inputs_for_instruction_generation (
65
73
self , inputs : List [Dict [str , Any ]]
66
- ) -> List ["FormattedInput " ]:
74
+ ) -> List ["ChatType " ]:
67
75
"""Prepares the inputs adding the system (if required) prompt provided in each row,
68
76
or if the conversations to generate have more than one turn, then adding the system
69
77
prompt for multi-turn conversation from the paper.
@@ -106,7 +114,8 @@ def _append_messages_to_conversations(
106
114
The updated conversations.
107
115
"""
108
116
for instruction , conversation in zip (messages , conversations ):
109
- conversation .append ({"role" : role , "content" : instruction })
117
+ if instruction is not None :
118
+ conversation .append ({"role" : role , "content" : instruction })
110
119
return conversations
111
120
112
121
def _generate_instruction (
@@ -120,41 +129,83 @@ def _generate_instruction(
120
129
)
121
130
return [{"instruction" : output [0 ]} for output in outputs ]
122
131
132
def _prepare_conversation_outputs(
    self, conversations: List["ChatType"]
) -> List[Dict[str, Any]]:
    """Prepare the output conversations, removing the system prompt if necessary.

    When `include_system_prompt` is falsy, a leading message whose role is
    "system" is left out of the emitted conversation. The conversations received
    are not modified: each output holds a shallow copy of the (possibly trimmed)
    list of messages. Empty conversations are emitted as empty lists instead of
    failing when looking at their first message.

    Args:
        conversations: the list of generated conversations.

    Returns:
        A list of dictionaries containing a "conversation" key.
    """
    strip_system_prompt = not self.include_system_prompt
    outputs = []
    for conversation in conversations:
        # A conversation can be empty (no system prompt and no generated message),
        # so check it has messages before peeking at the first one.
        starts_with_system = bool(conversation) and conversation[0]["role"] == "system"
        first_kept = 1 if strip_system_prompt and starts_with_system else 0
        # Slicing copies the list, so the caller's conversation is left untouched.
        outputs.append({"conversation": conversation[first_kept:]})
    return outputs
149
+
150
def _generate_conversation_turn(
    self, role: str, conversations: List["ChatType"], active_indices: List[int]
) -> Tuple[List["ChatType"], List[int]]:
    """Generate the next message with the given role for the active conversations.

    Only the conversations whose position is listed in `active_indices` are sent
    to the `LLM`. The first generation of each output is appended to its
    conversation via `_append_messages_to_conversations`, and the conversations
    for which that generation is `None` are dropped from the returned indices.

    Args:
        role: the role to assign to the generated messages.
        conversations: all the conversations being generated.
        active_indices: the positions in `conversations` of the conversations
            for which a new message has to be generated.

    Returns:
        A tuple with the conversations (updated at the active positions) and the
        indices of the conversations that are still active after this turn.
    """
    # Nothing to generate: avoid calling the `LLM` with an empty batch.
    if not active_indices:
        return conversations, []

    active_conversations = [conversations[idx] for idx in active_indices]

    # Generate an output for the conversations that are still active (no previous `None`s)
    outputs = self.llm.generate(
        inputs=active_conversations,
        num_generations=1,
        **self.llm.generation_kwargs,  # type: ignore
    )
    messages = [output[0] for output in outputs]

    updated_conversations = self._append_messages_to_conversations(
        role=role,
        messages=messages,
        conversations=active_conversations,
    )

    # Write the updated conversations back to their original positions.
    for idx, conversation in zip(active_indices, updated_conversations):
        conversations[idx] = conversation

    # A `None` generation means the conversation cannot continue.
    still_active = [
        idx for idx, message in zip(active_indices, messages) if message is not None
    ]

    return conversations, still_active
175
+
123
176
def _generate_multi_turn_conversation(
    self, inputs: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Generate conversations by alternating user and assistant turns.

    Up to `n_turns` rounds are generated. Each round produces a user message
    followed by an assistant message, except that the assistant message of the
    last round is skipped when `end_with_user` is set. Generation stops early
    once no conversation remains active.

    Args:
        inputs: the rows from which the initial conversations are prepared.

    Returns:
        The result of `_prepare_conversation_outputs` for the generated
        conversations.
    """
    conversations: List["ChatType"] = (
        self._prepare_inputs_for_instruction_generation(inputs)
    )
    # Positions of the conversations that can still be extended; a conversation
    # is dropped from here when the `LLM` returns `None` for one of its turns.
    active_indices = list(range(len(conversations)))

    for turn in range(self.n_turns):  # type: ignore
        if not active_indices:
            break

        # User message
        conversations, active_indices = self._generate_conversation_turn(
            role="user",
            conversations=conversations,
            active_indices=active_indices,
        )

        # Skip the assistant reply if the conversation has to end with the user
        # message of the last turn, or if there is nothing left to reply to.
        if (
            turn == self.n_turns - 1 and self.end_with_user  # type: ignore
        ) or not active_indices:
            break

        # Assistant message
        conversations, active_indices = self._generate_conversation_turn(
            role="assistant",
            conversations=conversations,
            active_indices=active_indices,
        )

    return self._prepare_conversation_outputs(conversations)
158
209
159
210
def _generate_with_pre_query_template (
160
211
self , inputs : List [Dict [str , Any ]]
@@ -196,6 +247,11 @@ class Magpie(Task, MagpieBase):
196
247
197
248
Attributes:
198
249
n_turns: the number of turns that the generated conversation will have.
250
+ Defaults to `1`.
251
+ end_with_user: whether the conversation should end with a user message.
252
+ Defaults to `False`.
253
+ include_system_prompt: whether to include the system prompt used in the generated
254
+ conversation. Defaults to `False`.
199
255
only_instruction: whether to generate only the instruction. If this argument is
200
256
`True`, then `n_turns` will be ignored. Defaults to `False`.
201
257
system_prompt: an optional system prompt that can be used to steer the LLM to generate
@@ -204,7 +260,12 @@ class Magpie(Task, MagpieBase):
204
260
one from the column will be used. Defaults to `None`.
205
261
206
262
Runtime parameters:
207
- - `n_turns`: the number of turns that the generated conversation will have.
263
+ - `n_turns`: the number of turns that the generated conversation will have. Defaults
264
+ to `1`.
265
+ - `end_with_user`: whether the conversation should end with a user message.
266
+ Defaults to `False`.
267
+ - `include_system_prompt`: whether to include the system prompt used in the generated
268
+ conversation. Defaults to `False`.
208
269
- `only_instruction`: whether to generate only the instruction. If this argument is
209
270
`True`, then `n_turns` will be ignored. Defaults to `False`.
210
271
- `system_prompt`: an optional system prompt that can be used to steer the LLM to
0 commit comments