
Commit 74cc09e

Update LLMs to support prompt logprobs use-case (#1099)
1 parent 5257600 commit 74cc09e

7 files changed: +288 -70 lines changed


src/distilabel/models/base_clients/inference_endpoints.py

+2 -3

@@ -16,7 +16,6 @@
 from typing import (
     TYPE_CHECKING,
     Optional,
-    Union,
 )
 
 from pydantic import (
@@ -143,9 +142,9 @@ def load(self) -> None:  # noqa: C901
             self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id)
 
     @property
-    def model_name(self) -> Union[str, None]:  # type: ignore
+    def model_name(self) -> str:
         """Returns the model name used for the model."""
-        return (
+        return (  # type: ignore
             self.model_display_name
             or self._model_name
             or self.model_id

src/distilabel/models/llms/huggingface/inference_endpoints.py

+35 -28

@@ -273,7 +273,7 @@ def _get_structured_output(
 
     async def _generate_with_text_generation(
         self,
-        input: FormattedInput,
+        input: str,
         max_new_tokens: int = 128,
         repetition_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
@@ -287,13 +287,12 @@ async def _generate_with_text_generation(
         return_full_text: bool = False,
         seed: Optional[int] = None,
         watermark: bool = False,
+        structured_output: Union[Dict[str, Any], None] = None,
     ) -> GenerateOutput:
-        input, structured_output = self._get_structured_output(input)
-        prompt = self.prepare_input(input)
         generation: Union["TextGenerationOutput", None] = None
         try:
             generation = await self._aclient.text_generation(  # type: ignore
-                prompt=prompt,
+                prompt=input,
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 typical_p=typical_p,
@@ -319,7 +318,9 @@ async def _generate_with_text_generation(
             )
         return prepare_output(
             generations=[generation.generated_text] if generation else [None],
-            input_tokens=[compute_tokens(prompt, self._tokenizer.encode)],  # type: ignore
+            input_tokens=[
+                compute_tokens(input, self._tokenizer.encode) if self._tokenizer else -1
+            ],
             output_tokens=[
                 generation.details.generated_tokens
                 if generation and generation.details
@@ -544,37 +545,43 @@ async def agenerate(  # type: ignore
         """
         stop_sequences = self._check_stop_sequences(stop_sequences)
 
-        if self.tokenizer_id is None:
-            return await self._generate_with_chat_completion(
-                input=input,  # type: ignore
+        if isinstance(input, str) or self.tokenizer_id is not None:
+            structured_output = None
+            if not isinstance(input, str):
+                input, structured_output = self._get_structured_output(input)
+                input = self.prepare_input(input)
+
+            return await self._generate_with_text_generation(
+                input=input,
                 max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                typical_p=typical_p,
+                repetition_penalty=repetition_penalty,
                 frequency_penalty=frequency_penalty,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                presence_penalty=presence_penalty,
-                seed=seed,
-                stop_sequences=stop_sequences,
                 temperature=temperature,
-                tool_choice=tool_choice,
-                tool_prompt=tool_prompt,
-                tools=tools,
-                top_logprobs=top_logprobs,
+                top_n_tokens=top_n_tokens,
                 top_p=top_p,
+                top_k=top_k,
+                stop_sequences=stop_sequences,
+                return_full_text=return_full_text,
+                seed=seed,
+                watermark=watermark,
+                structured_output=structured_output,
             )
 
-        return await self._generate_with_text_generation(
-            input=input,
+        return await self._generate_with_chat_completion(
+            input=input,  # type: ignore
             max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            typical_p=typical_p,
-            repetition_penalty=repetition_penalty,
            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop_sequences=stop_sequences,
            temperature=temperature,
-            top_n_tokens=top_n_tokens,
+            tool_choice=tool_choice,
+            tool_prompt=tool_prompt,
+            tools=tools,
+            top_logprobs=top_logprobs,
            top_p=top_p,
-            top_k=top_k,
-            stop_sequences=stop_sequences,
-            return_full_text=return_full_text,
-            seed=seed,
-            watermark=watermark,
        )
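
With this change, `agenerate` routes any plain string input (and any input when `tokenizer_id` is set) through the Inference Endpoints `text_generation` API rather than the chat-completion one, applying `prepare_input` and structured-output handling only to conversation-style inputs. Below is a minimal sketch of the new string-prompt path; the model ID, prompt, and sampling arguments are illustrative assumptions rather than values from this commit, and a valid `HF_TOKEN` is assumed to be configured.

import asyncio

from distilabel.models.llms import InferenceEndpointsLLM

# Hypothetical serverless Inference Endpoints model; any text-generation model works.
llm = InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm.load()

# A raw `str` input now skips chat formatting and is forwarded as `prompt=` to
# `AsyncInferenceClient.text_generation`.
result = asyncio.run(
    llm.agenerate(
        input="The capital of France is",
        max_new_tokens=8,
        do_sample=True,
        temperature=0.7,
    )
)
print(result["generations"])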

src/distilabel/models/llms/openai.py

+130 -13

@@ -16,7 +16,7 @@
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
 
 import orjson
-from pydantic import PositiveInt, validate_call
+from pydantic import NonNegativeInt, PositiveInt, validate_call
 
 from distilabel import envs
 from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
@@ -29,10 +29,18 @@
     from openai.types import Batch as OpenAIBatch
     from openai.types import FileObject as OpenAIFileObject
     from openai.types.chat import ChatCompletion as OpenAIChatCompletion
-    from openai.types.chat.chat_completion import Choice as OpenAIChoice
+    from openai.types.chat.chat_completion import Choice as OpenAIChatCompletionChoice
     from openai.types.completion import Completion as OpenAICompletion
+    from openai.types.completion_choice import (
+        CompletionChoice as OpenAICompletionChoice,
+    )
 
-    from distilabel.typing import LLMStatistics, Logprob
+    from distilabel.typing.models import (
+        LLMStatistics,
+        Logprob,
+        StandardInput,
+        StructuredInput,
+    )
 
 
 _OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
@@ -148,15 +156,17 @@ async def agenerate(  # type: ignore
         self,
         input: FormattedInput,
         num_generations: int = 1,
-        max_new_tokens: int = 128,
+        max_new_tokens: NonNegativeInt = 128,
         logprobs: bool = False,
         top_logprobs: Optional[PositiveInt] = None,
+        echo: bool = False,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         stop: Optional[Union[str, List[str]]] = None,
         response_format: Optional[Dict[str, str]] = None,
+        extra_body: Optional[Dict[str, Any]] = None,
     ) -> GenerateOutput:
         """Generates `num_generations` responses for the given input using the OpenAI async
         client.
@@ -170,6 +180,8 @@ async def agenerate(  # type: ignore
             logprobs: whether to return the log probabilities or not. Defaults to `False`.
             top_logprobs: the number of top log probabilities to return per output token
                 generated. Defaults to `None`.
+            echo: whether to echo the input in the response or not. It's only used if the
+                `input` argument is an `str`. Defaults to `False`.
             frequency_penalty: the repetition penalty to use for the generation. Defaults
                 to `0.0`.
             presence_penalty: the presence penalty to use for the generation. Defaults to
@@ -182,14 +194,115 @@ async def agenerate(  # type: ignore
                 "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                 for more information on how to use the JSON model from OpenAI. Defaults to None
                 which returns text. To return JSON, use {"type": "json_object"}.
-
-        Note:
-            If response_format
+            extra_body: an optional dictionary containing extra body parameters that will
+                be sent to the OpenAI API endpoint. Defaults to `None`.
 
         Returns:
             A list of lists of strings containing the generated responses for each input.
         """
 
+        if isinstance(input, str):
+            return await self._generate_completion(
+                input=input,
+                num_generations=num_generations,
+                max_new_tokens=max_new_tokens,
+                echo=echo,
+                top_logprobs=top_logprobs,
+                frequency_penalty=frequency_penalty,
+                presence_penalty=presence_penalty,
+                temperature=temperature,
+                top_p=top_p,
+                extra_body=extra_body,
+            )
+
+        return await self._generate_chat_completion(
+            input=input,
+            num_generations=num_generations,
+            max_new_tokens=max_new_tokens,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            stop=stop,
+            response_format=response_format,
+            extra_body=extra_body,
+        )
+
+    async def _generate_completion(
+        self,
+        input: str,
+        num_generations: int = 1,
+        max_new_tokens: int = 128,
+        echo: bool = False,
+        top_logprobs: Optional[PositiveInt] = None,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        extra_body: Optional[Dict[str, Any]] = None,
+    ) -> GenerateOutput:
+        completion = await self._aclient.completions.create(
+            prompt=input,
+            echo=echo,
+            model=self.model,
+            n=num_generations,
+            max_tokens=max_new_tokens,
+            logprobs=top_logprobs,
+            frequency_penalty=frequency_penalty,
+            presence_penalty=presence_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            extra_body=extra_body,
+        )
+
+        generations = []
+        logprobs = []
+        for choice in completion.choices:
+            generations.append(choice.text)
+            if choice_logprobs := self._get_logprobs_from_completion_choice(choice):
+                logprobs.append(choice_logprobs)
+
+        statistics = self._get_llm_statistics(completion)
+        return prepare_output(
+            generations=generations,
+            input_tokens=statistics["input_tokens"],
+            output_tokens=statistics["output_tokens"],
+            logprobs=logprobs,
+        )
+
+    def _get_logprobs_from_completion_choice(
+        self, choice: "OpenAICompletionChoice"
+    ) -> Union[List[Union[List["Logprob"], None]], None]:
+        if choice.logprobs is None or choice.logprobs.top_logprobs is None:
+            return None
+
+        return [
+            [
+                {"token": token, "logprob": token_logprob}
+                for token, token_logprob in logprobs.items()
+            ]
+            if logprobs is not None
+            else None
+            for logprobs in choice.logprobs.top_logprobs
+        ]
+
+    async def _generate_chat_completion(
+        self,
+        input: Union["StandardInput", "StructuredInput"],
+        num_generations: int = 1,
+        max_new_tokens: int = 128,
+        logprobs: bool = False,
+        top_logprobs: Optional[PositiveInt] = None,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        stop: Optional[Union[str, List[str]]] = None,
+        response_format: Optional[Dict[str, str]] = None,
+        extra_body: Optional[Dict[str, Any]] = None,
+    ) -> GenerateOutput:
         structured_output = None
         if isinstance(input, tuple):
             input, structured_output = input
@@ -215,9 +328,11 @@ async def agenerate(  # type: ignore
             "temperature": temperature,
             "top_p": top_p,
             "stop": stop,
+            "extra_body": extra_body,
         }
-        # Check if it's a vision generation task, in that case "stop" cannot be used or raises
-        # an error in the API.
+
+        # Checks if any message contains an image, in that case "stop" cannot be used or
+        # raises an error in the API.
         if isinstance(
             [row for row in input if row["role"] == "user"][0]["content"], list
         ):
@@ -235,7 +350,7 @@ async def agenerate(  # type: ignore
             # NOTE: `instructor` doesn't work with `n` parameter, so it will always return
             # only 1 choice.
             statistics = self._get_llm_statistics(completion._raw_response)
-            if choice_logprobs := self._get_logprobs_from_choice(
+            if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
                 completion._raw_response.choices[0]
            ):
                output_logprobs = [choice_logprobs]
@@ -270,7 +385,9 @@ def _generations_from_openai_completion(
                     f" Finish reason was: {choice.finish_reason}"
                 )
             generations.append(content)
-            if choice_logprobs := self._get_logprobs_from_choice(choice):
+            if choice_logprobs := self._get_logprobs_from_chat_completion_choice(
+                choice
+            ):
                 logprobs.append(choice_logprobs)
 
         statistics = self._get_llm_statistics(completion)
@@ -281,8 +398,8 @@ def _generations_from_openai_completion(
             logprobs=logprobs,
         )
 
-    def _get_logprobs_from_choice(
-        self, choice: "OpenAIChoice"
+    def _get_logprobs_from_chat_completion_choice(
+        self, choice: "OpenAIChatCompletionChoice"
     ) -> Union[List[List["Logprob"]], None]:
         if choice.logprobs is None or choice.logprobs.content is None:
             return None
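
The new completions route is what enables the prompt logprobs use-case from the PR title: when `input` is a plain string, the request goes to the OpenAI (or OpenAI-compatible) completions endpoint, where `echo=True` combined with `max_new_tokens=0` (now permitted by the `NonNegativeInt` annotation) returns log probabilities for the prompt tokens themselves rather than for sampled tokens, and `_get_logprobs_from_completion_choice` reshapes the per-token `top_logprobs` dictionaries into distilabel's `Logprob` entries (`{"token": ..., "logprob": ...}`). A rough sketch of that use-case using the `openai` client directly, with a placeholder model and prompt (note that not every OpenAI-compatible server supports `echo` on its completions endpoint):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY (or a compatible base_url) is configured

completion = client.completions.create(
    model="gpt-3.5-turbo-instruct",  # placeholder completions-capable model
    prompt="The capital of France is Paris.",
    max_tokens=0,  # generate nothing new...
    echo=True,     # ...but echo the prompt back, annotated with logprobs
    logprobs=1,    # also include the top-1 alternative per position
)

choice = completion.choices[0]
# `tokens` and `token_logprobs` are aligned per prompt token; the first entry
# has no logprob because there is no preceding context to condition on.
for token, logprob in zip(choice.logprobs.tokens, choice.logprobs.token_logprobs):
    print(f"{token!r}: {logprob}")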
