from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union

import orjson
-from pydantic import PositiveInt, validate_call
+from pydantic import NonNegativeInt, PositiveInt, validate_call

from distilabel import envs
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from openai.types import Batch as OpenAIBatch
from openai.types import FileObject as OpenAIFileObject
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
-from openai.types.chat.chat_completion import Choice as OpenAIChoice
+from openai.types.chat.chat_completion import Choice as OpenAIChatCompletionChoice
from openai.types.completion import Completion as OpenAICompletion
+from openai.types.completion_choice import (
+    CompletionChoice as OpenAICompletionChoice,
+)

-from distilabel.typing import LLMStatistics, Logprob
+from distilabel.typing.models import (
+    LLMStatistics,
+    Logprob,
+    StandardInput,
+    StructuredInput,
+)


_OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
@@ -148,15 +156,17 @@ async def agenerate( # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
-        max_new_tokens: int = 128,
+        max_new_tokens: NonNegativeInt = 128,
        logprobs: bool = False,
        top_logprobs: Optional[PositiveInt] = None,
+        echo: bool = False,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[Dict[str, str]] = None,
+        extra_body: Optional[Dict[str, Any]] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OpenAI async
        client.
@@ -170,6 +180,8 @@ async def agenerate( # type: ignore
            logprobs: whether to return the log probabilities or not. Defaults to `False`.
            top_logprobs: the number of top log probabilities to return per output token
                generated. Defaults to `None`.
+            echo: whether to echo the input in the response or not. It's only used if the
+                `input` argument is a `str`. Defaults to `False`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
@@ -182,14 +194,115 @@ async def agenerate( # type: ignore
                "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use the JSON model from OpenAI. Defaults to None
                which returns text. To return JSON, use {"type": "json_object"}.
-
-            Note:
-                If response_format
+            extra_body: an optional dictionary containing extra body parameters that will
+                be sent to the OpenAI API endpoint. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
+        if isinstance(input, str):
+            return await self._generate_completion(
+                input=input,
+                num_generations=num_generations,
+                max_new_tokens=max_new_tokens,
+                echo=echo,
+                top_logprobs=top_logprobs,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
+                temperature=temperature,
+                top_p=top_p,
+                extra_body=extra_body,
+            )
+
+        return await self._generate_chat_completion(
+            input=input,
+            num_generations=num_generations,
+            max_new_tokens=max_new_tokens,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+            response_format=response_format,
+            extra_body=extra_body,
+        )
+
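With this dispatch in place, `agenerate` routes on the input type: a bare string goes to the legacy completions endpoint (where `echo` applies), while chat-formatted messages keep going to chat completions (where `stop` and `response_format` apply). A minimal sketch of the two call shapes; the import path and model name are assumptions for illustration, not taken from this commit:

```python
import asyncio

from distilabel.models import OpenAILLM  # assumed public import path

llm = OpenAILLM(model="gpt-4o-mini")  # illustrative model choice
llm.load()

async def demo() -> None:
    # A plain string routes to the completions endpoint.
    completion_out = await llm.agenerate(input="Once upon a time", echo=True)
    # A list of chat messages routes to the chat completions endpoint.
    chat_out = await llm.agenerate(
        input=[{"role": "user", "content": "Tell me a story."}]
    )
    print(completion_out, chat_out)

asyncio.run(demo())
```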
+    async def _generate_completion(
+        self,
+        input: str,
+        num_generations: int = 1,
+        max_new_tokens: int = 128,
+        echo: bool = False,
+        top_logprobs: Optional[PositiveInt] = None,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        extra_body: Optional[Dict[str, Any]] = None,
+    ) -> GenerateOutput:
+        completion = await self._aclient.completions.create(
+            prompt=input,
+            echo=echo,
+            model=self.model,
+            n=num_generations,
+            max_tokens=max_new_tokens,
+            logprobs=top_logprobs,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            extra_body=extra_body,
+        )
+
+        generations = []
+        logprobs = []
+        for choice in completion.choices:
+            generations.append(choice.text)
+            if choice_logprobs := self._get_logprobs_from_completion_choice(choice):
+                logprobs.append(choice_logprobs)
+
+        statistics = self._get_llm_statistics(completion)
+        return prepare_output(
+            generations=generations,
+            input_tokens=statistics["input_tokens"],
+            output_tokens=statistics["output_tokens"],
+            logprobs=logprobs,
+        )
+
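One subtlety in the call above: the legacy completions endpoint takes `logprobs` as an integer (how many top alternatives to return per position), so the method forwards `top_logprobs` under that name, and `echo=True` makes the API prepend the prompt to `choice.text`. A standalone sketch against the raw OpenAI client, with an assumed model that still serves the completions API:

```python
import asyncio

from openai import AsyncOpenAI

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment

async def raw_completion() -> None:
    completion = await client.completions.create(
        model="gpt-3.5-turbo-instruct",  # assumption: a completions-capable model
        prompt="Once upon a time",
        echo=True,  # the returned text starts with the prompt itself
        logprobs=2,  # integer: top-2 alternatives per token position
        max_tokens=16,
    )
    print(completion.choices[0].text)

asyncio.run(raw_completion())
```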
+    def _get_logprobs_from_completion_choice(
+        self, choice: "OpenAICompletionChoice"
+    ) -> Union[List[Union[List["Logprob"], None]], None]:
+        if choice.logprobs is None or choice.logprobs.top_logprobs is None:
+            return None
+
+        return [
+            [
+                {"token": token, "logprob": token_logprob}
+                for token, token_logprob in logprobs.items()
+            ]
+            if logprobs is not None
+            else None
+            for logprobs in choice.logprobs.top_logprobs
+        ]
+
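The completions endpoint reports `top_logprobs` as one `{token: logprob}` mapping per generated position, with `None` for positions that carry no data, and the helper above flattens each mapping into distilabel `Logprob` dicts. A worked example of that transformation with invented values:

```python
# Shape of `choice.logprobs.top_logprobs` for a two-position generation:
top_logprobs = [{"Hello": -0.1, "Hi": -2.3}, None]

converted = [
    [{"token": tok, "logprob": lp} for tok, lp in pos.items()]
    if pos is not None
    else None
    for pos in top_logprobs
]
assert converted == [
    [{"token": "Hello", "logprob": -0.1}, {"token": "Hi", "logprob": -2.3}],
    None,
]
```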
+    async def _generate_chat_completion(
+        self,
+        input: Union["StandardInput", "StructuredInput"],
+        num_generations: int = 1,
+        max_new_tokens: int = 128,
+        logprobs: bool = False,
+        top_logprobs: Optional[PositiveInt] = None,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        stop: Optional[Union[str, List[str]]] = None,
+        response_format: Optional[Dict[str, str]] = None,
+        extra_body: Optional[Dict[str, Any]] = None,
+    ) -> GenerateOutput:
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
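The tuple branch handles distilabel's structured-generation inputs, where the chat messages travel together with a structured-output spec. Assuming the `StructuredInput` shape implied by the new `distilabel.typing.models` import, the unpacking deals with inputs along these lines (the exact spec keys are an assumption here):

```python
# (messages, structured_output) pair; the second element may also be None.
structured_input = (
    [{"role": "user", "content": "Invent a character."}],
    {"format": "json", "schema": {"type": "object", "properties": {"name": {"type": "string"}}}},
)
messages, structured_output = structured_input
```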
@@ -215,9 +328,11 @@ async def agenerate( # type: ignore
            "temperature": temperature,
            "top_p": top_p,
            "stop": stop,
+            "extra_body": extra_body,
        }
-        # Check if it's a vision generation task, in that case "stop" cannot be used or raises
-        # an error in the API.
+
+        # Checks if any message contains an image, in that case "stop" cannot be used or
+        # raises an error in the API.
        if isinstance(
            [row for row in input if row["role"] == "user"][0]["content"], list
        ):
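For context on the check above: a vision-style chat message carries its user content as a list of typed parts rather than a plain string, which is exactly what the `isinstance(..., list)` test detects. The payload below follows OpenAI's documented multimodal chat format:

```python
vision_input = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/photo.png"}},
        ],
    }
]
# `content` is a list here, so per the comment in the diff the `stop`
# parameter must not be sent with the request.
```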
@@ -235,7 +350,7 @@ async def agenerate( # type: ignore
        # NOTE: `instructor` doesn't work with `n` parameter, so it will always return
        # only 1 choice.
        statistics = self._get_llm_statistics(completion._raw_response)
-        if choice_logprobs := self._get_logprobs_from_choice(
+        if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
            completion._raw_response.choices[0]
        ):
            output_logprobs = [choice_logprobs]
@@ -270,7 +385,9 @@ def _generations_from_openai_completion(
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
-            if choice_logprobs := self._get_logprobs_from_choice(choice):
+            if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
+                choice
+            ):
                logprobs.append(choice_logprobs)

        statistics = self._get_llm_statistics(completion)
@@ -281,8 +398,8 @@ def _generations_from_openai_completion(
            logprobs=logprobs,
        )

-    def _get_logprobs_from_choice(
-        self, choice: "OpenAIChoice"
+    def _get_logprobs_from_chat_completion_choice(
+        self, choice: "OpenAIChatCompletionChoice"
    ) -> Union[List[List["Logprob"]], None]:
        if choice.logprobs is None or choice.logprobs.content is None:
            return None
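For contrast with the completions-side helper, the chat endpoint exposes logprobs as `choice.logprobs.content`: a list of per-token entries, each carrying its own `top_logprobs` alternatives, which is what the guard above checks for. The rest of the method's body falls outside this hunk; roughly, the structure it consumes looks like this (field names from the `openai` SDK types, values invented):

```python
# choice.logprobs.content ~ [
#     ChatCompletionTokenLogprob(
#         token="Hello",
#         logprob=-0.1,
#         top_logprobs=[TopLogprob(token="Hello", logprob=-0.1), ...],
#     ),
#     ...
# ]
```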