
Commit 4fc569d

gabrielmbmb and plaguss authored
Add Magpie and MagpieGenerator tasks (#778)
* Move `CudaDevicePlacementMixin` to new module
* Initial work for implementing Magpie
* Simplify magpie implementation
* Remove `use_open_ai` and add `MagpieChatTemplateMixin` to `InferenceEndpointsLLM`
* Add `MagpieChatTemplateMixin` to `vLLM`
* Add `MagpieGenerator` task
* Fix unit tests
* Fix docstrings
* Mock `HF_TOKEN` environment variable
* Fix list index out of range
* Fix `MagpieGenerator` last batch
* Add `only_instruction` attribute
* Update categories
* testing
* Worth trying
* Add examples
* Add magpie unit tests
* Fix docstring
* Update docstrings
* Apply suggestions from code review

  Co-authored-by: Agus <agustin@argilla.io>

* Update to `huggingface_hub >= 0.22.0`
* Add generation with `chat_completion`
* Update `agenerate` arguments
* Update unit tests
* Fix `tools` were not being used
* Update unit tests
* Fix list of tuples instead of list of list
* Add missing docstring
* Add `chat_completion` unit tests
* Fix `GroqLLM.generate` unit test after updating `_agenerate`

---------

Co-authored-by: Agus <agustin@argilla.io>
1 parent 86d4e80 commit 4fc569d

28 files changed: +1785 −253 lines changed

Diff for: pyproject.toml

+1 −1

@@ -74,7 +74,7 @@ anthropic = ["anthropic >= 0.20.0"]
 argilla = ["argilla >= 1.29.0"]
 cohere = ["cohere >= 5.2.0"]
 groq = ["groq >= 0.4.1"]
-hf-inference-endpoints = ["huggingface_hub >= 0.19.0"]
+hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]
 hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
 instructor = ["instructor >= 1.2.3"]
 litellm = ["litellm >= 1.30.0"]

Diff for: src/distilabel/llms/__init__.py

+1 −1

@@ -22,7 +22,7 @@
 from distilabel.llms.litellm import LiteLLM
 from distilabel.llms.llamacpp import LlamaCppLLM
 from distilabel.llms.mistral import MistralLLM
-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.llms.moa import MixtureOfAgentsLLM
 from distilabel.llms.ollama import OllamaLLM
 from distilabel.llms.openai import OpenAILLM

Diff for: src/distilabel/llms/azure.py

+1 −1

@@ -45,7 +45,7 @@ class AzureOpenAILLM(OpenAILLM):
             `None` if not set.
 
     Icon:
-        `:simple-microsoftazure:`
+        `:material-microsoft-azure:`
 
     Examples:

Diff for: src/distilabel/llms/base.py

+4 −1

@@ -329,7 +329,10 @@ async def _agenerate(
             for _ in range(num_generations)
         ]
         outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
-        return list(grouper(outputs, n=num_generations, incomplete="ignore"))
+        return [
+            list(group)
+            for group in grouper(outputs, n=num_generations, incomplete="ignore")
+        ]
 
     def generate(
         self,
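
This is the "list of tuples instead of list of list" fix from the commit message: `grouper` yields tuples, so each group is now converted explicitly. A self-contained sketch of the difference, using the itertools `grouper` recipe as a stand-in for distilabel's helper (an assumption on my part, the helper itself is not shown in this diff):

```python
# Stand-in for distilabel's `grouper` helper (itertools recipe); it yields
# tuples, which is why the new code wraps each group in `list(...)`.
from itertools import zip_longest
from typing import Iterable, Iterator, Tuple


def grouper(iterable: Iterable[str], n: int, incomplete: str = "fill") -> Iterator[Tuple[str, ...]]:
    args = [iter(iterable)] * n
    if incomplete == "ignore":
        return zip(*args)
    return zip_longest(*args)


outputs = ["gen-1", "gen-2", "gen-3", "gen-4"]
num_generations = 2

# Before: [('gen-1', 'gen-2'), ('gen-3', 'gen-4')]  -- tuples
before = list(grouper(outputs, n=num_generations, incomplete="ignore"))

# After: [['gen-1', 'gen-2'], ['gen-3', 'gen-4']]  -- lists, as downstream code expects
after = [list(group) for group in grouper(outputs, n=num_generations, incomplete="ignore")]

print(before, after)
```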

Diff for: src/distilabel/llms/huggingface/inference_endpoints.py

+311 −152
Large diffs are not rendered by default.
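
Although this diff is not rendered, the commit message notes that `InferenceEndpointsLLM` gains generation via `chat_completion`, which is why `pyproject.toml` bumps `huggingface_hub` to `>= 0.22.0`. A rough sketch of the underlying `huggingface_hub` call (not the distilabel wrapper itself); the model name and token handling are placeholders:

```python
# Rough sketch of the huggingface_hub `chat_completion` API that the
# (unrendered) InferenceEndpointsLLM changes build on. Requires
# `huggingface_hub >= 0.22.0`; model name and token are placeholders.
import os

from huggingface_hub import InferenceClient

client = InferenceClient(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    token=os.environ.get("HF_TOKEN"),
)

completion = client.chat_completion(
    messages=[{"role": "user", "content": "What is synthetic data?"}],
    max_tokens=256,
    temperature=0.7,
)

print(completion.choices[0].message.content)
```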

Diff for: src/distilabel/llms/huggingface/transformers.py

+27 −8

@@ -19,7 +19,8 @@
 
 from distilabel.llms.base import LLM
 from distilabel.llms.chat_templates import CHATML_TEMPLATE
-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
@@ -32,7 +33,7 @@
     from distilabel.llms.typing import HiddenState
 
 
-class TransformersLLM(LLM, CudaDevicePlacementMixin):
+class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
     """Hugging Face `transformers` library LLM implementation using the text generation
     pipeline.
 
@@ -64,6 +65,12 @@ class TransformersLLM(LLM, CudaDevicePlacementMixin):
             local configuration will be used. Defaults to `None`.
         structured_output: a dictionary containing the structured output configuration or if more
             fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
+        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
+            template. Defaults to `False`.
+        magpie_pre_query_template: the pre-query template to be applied to the prompt or
+            sent to the LLM to generate an instruction or a follow up user message. Valid
+            values are "llama3", "qwen2" or another pre-query template provided. Defaults
+            to `None`.
 
     Icon:
         `:hugging:`
@@ -157,14 +164,25 @@ def model_name(self) -> str:
         return self.model
 
     def prepare_input(self, input: "StandardInput") -> str:
-        """Prepares the input by applying the chat template to the input, which is formatted
-        as an OpenAI conversation, and adding the generation prompt.
+        """Prepares the input (applying the chat template and tokenization) for the provided
+        input.
+
+        Args:
+            input: the input list containing chat items.
+
+        Returns:
+            The prompt to send to the LLM.
         """
-        return self._pipeline.tokenizer.apply_chat_template(  # type: ignore
-            input,  # type: ignore
-            tokenize=False,
-            add_generation_prompt=True,
+        prompt: str = (
+            self._pipeline.tokenizer.apply_chat_template(  # type: ignore
+                input,  # type: ignore
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            if input
+            else ""
         )
+        return super().apply_magpie_pre_query_template(prompt, input)
 
     @validate_call
     def generate(  # type: ignore
@@ -209,6 +227,7 @@ def generate( # type: ignore
             do_sample=do_sample,
             num_return_sequences=num_generations,
             prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
+            pad_token_id=self._pipeline.tokenizer.eos_token_id,  # type: ignore
         )
         return [
             [generation["generated_text"] for generation in output]
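
To see what this buys, here is a minimal, self-contained sketch of the Magpie flow the new `prepare_input` enables (pure logic, no model loaded; the llama3 pre-query string is copied from the `mixins/magpie.py` diff below):

```python
# Minimal sketch of the Magpie code path added above: with an empty
# conversation, the chat-template branch yields "" and the pre-query template
# becomes the whole prompt, so an instruct model "completes" it with a user
# instruction. No model is loaded here.
LLAMA3_PRE_QUERY = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"

conversation = []  # Magpie starts from scratch: no chat items yet
use_magpie_template = True

# Mirrors the new code: apply_chat_template(...) if input else ""
prompt = "<chat template output would go here>" if conversation else ""

# Mirrors apply_magpie_pre_query_template: append unless the last turn is a user turn.
if use_magpie_template and not (conversation and conversation[-1]["role"] == "user"):
    prompt += LLAMA3_PRE_QUERY

print(repr(prompt))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n'
```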

Diff for: src/distilabel/llms/mixins/__init__.py

+14

@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
File renamed without changes.

Diff for: src/distilabel/llms/mixins/magpie.py

+89

@@ -0,0 +1,89 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, Dict, Literal, Union
+
+from pydantic import BaseModel, field_validator, model_validator
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    from distilabel.steps.tasks.typing import StandardInput
+
+MagpieAvailablePreQueryTemplates = Literal["llama3", "qwen2"]
+"""The available predefined pre-query templates."""
+
+MAGPIE_PRE_QUERY_TEMPLATES: Dict[MagpieAvailablePreQueryTemplates, str] = {
+    "llama3": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n",
+    "qwen2": "<|im_start|>user\n",
+}
+
+
+class MagpieChatTemplateMixin(BaseModel, validate_assignment=True):
+    """A simple mixin that adds the required logic to apply the pre-query template that
+    allows to an instruct fine-tuned LLM to generate user instructions as described in
+    the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'.
+
+    This mixin is meant to be used in combination with the [Magpie][distilabel.steps.tasks.magpie.base.Magpie]
+    task.
+
+    Attributes:
+        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
+            template. Defaults to `False`.
+        magpie_pre_query_template: the pre-query template to be applied to the prompt or
+            sent to the LLM to generate an instruction or a follow up user message. Valid
+            values are "llama3", "qwen2" or another pre-query template provided. Defaults
+            to `None`.
+
+    References:
+        - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464)
+    """
+
+    use_magpie_template: bool = False
+    magpie_pre_query_template: Union[MagpieAvailablePreQueryTemplates, str, None] = None
+
+    @field_validator("magpie_pre_query_template")
+    @classmethod
+    def magpie_pre_query_template_validator(cls, value: str) -> str:
+        """Resolves the pre-query template alias if it exists, otherwise, returns the
+        value with no modification."""
+        if value in MAGPIE_PRE_QUERY_TEMPLATES:
+            return MAGPIE_PRE_QUERY_TEMPLATES[value]
+        return value
+
+    @model_validator(mode="after")
+    def use_magpie_template_validation(self) -> Self:
+        """Checks that there is a pre-query template set if Magpie is going to be used."""
+        if self.use_magpie_template and self.magpie_pre_query_template is None:
+            raise ValueError(
+                f"Cannot set `use_magpie_template=True` if `magpie_pre_query_template` is"
+                f" `None`. To use Magpie with `{self.__class__.__name__}` you need to set"
+                f" the `magpie_pre_query_template` attribute."
+            )
+        return self
+
+    def apply_magpie_pre_query_template(
+        self, prompt: str, input: "StandardInput"
+    ) -> str:
+        """Applies the pre-query template to the prompt if Magpie is going to be used.
+
+        Args:
+            prompt: the prompt to which the pre-query template has to be applied.
+            input: the list with the chat items that were used to generate the prompt.
+
+        Returns:
+            The prompt with the pre-query template applied if needed.
+        """
+        if not self.use_magpie_template or (input and input[-1]["role"] == "user"):
+            return prompt
+        return prompt + self.magpie_pre_query_template  # type: ignore
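
For orientation, a quick usage sketch of the mixin in isolation (its real consumers are `TransformersLLM`, `vLLM` and, per the commit message, `InferenceEndpointsLLM`); the `DummyLLM` subclass is made up purely for illustration:

```python
# Usage sketch of the mixin on its own; `DummyLLM` is a made-up subclass used
# only for illustration.
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin


class DummyLLM(MagpieChatTemplateMixin):
    pass


llm = DummyLLM(use_magpie_template=True, magpie_pre_query_template="llama3")

# The field validator resolves the "llama3" alias to the actual template string.
print(repr(llm.magpie_pre_query_template))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n'

# With no trailing user turn, the template is appended to the prompt so the
# model generates a user instruction next.
print(repr(llm.apply_magpie_pre_query_template("", input=[])))

# Enabling the template without providing one fails the model validator.
try:
    DummyLLM(use_magpie_template=True)
except ValueError as exc:
    print(exc)
```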

Diff for: src/distilabel/llms/vllm.py

+31 −12

@@ -30,7 +30,8 @@
 
 from distilabel.llms.base import LLM
 from distilabel.llms.chat_templates import CHATML_TEMPLATE
-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.llms.typing import GenerateOutput
 from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType
@@ -39,11 +40,13 @@
     from transformers import PreTrainedTokenizer
     from vllm import LLM as _vLLM
 
+    from distilabel.steps.tasks.typing import StandardInput
+
 
 SamplingParams = None
 
 
-class vLLM(LLM, CudaDevicePlacementMixin):
+class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
     """`vLLM` library LLM implementation.
 
     Attributes:
@@ -75,6 +78,12 @@ class vLLM(LLM, CudaDevicePlacementMixin):
         _tokenizer: the tokenizer instance used to format the prompt before passing it to
             the `LLM`. This attribute is meant to be used internally and should not be
             accessed directly. It will be set in the `load` method.
+        use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
+            template. Defaults to `False`.
+        magpie_pre_query_template: the pre-query template to be applied to the prompt or
+            sent to the LLM to generate an instruction or a follow up user message. Valid
+            values are "llama3", "qwen2" or another pre-query template provided. Defaults
+            to `None`.
 
     References:
         - https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
@@ -213,15 +222,26 @@ def model_name(self) -> str:
         """Returns the model name used for the LLM."""
         return self.model
 
-    def prepare_input(self, input: "FormattedInput") -> str:
-        """Prepares the input by applying the chat template to the input, which is formatted
-        as an OpenAI conversation, and adding the generation prompt.
+    def prepare_input(self, input: "StandardInput") -> str:
+        """Prepares the input (applying the chat template and tokenization) for the provided
+        input.
+
+        Args:
+            input: the input list containing chat items.
+
+        Returns:
+            The prompt to send to the LLM.
         """
-        return self._tokenizer.apply_chat_template(  # type: ignore
-            input,  # type: ignore
-            tokenize=False,
-            add_generation_prompt=True,  # type: ignore
+        prompt: str = (
+            self._tokenizer.apply_chat_template(  # type: ignore
+                input,  # type: ignore
+                tokenize=False,
+                add_generation_prompt=True,  # type: ignore
+            )
+            if input
+            else ""
         )
+        return super().apply_magpie_pre_query_template(prompt, input)
 
     def _prepare_batches(
         self, inputs: List[FormattedInput]
@@ -304,14 +324,13 @@ def generate( # type: ignore
         if extra_sampling_params is None:
             extra_sampling_params = {}
         structured_output = None
-        needs_sorting = False
 
         if isinstance(inputs[0], tuple):
             prepared_batches, sorted_indices = self._prepare_batches(inputs)
-            needs_sorting = True
         else:
             # Simulate a batch without the structured output content
             prepared_batches = [([self.prepare_input(input) for input in inputs], None)]
+            sorted_indices = None
 
         # In case we have a single structured output for the dataset, we can
         logits_processors = None
@@ -348,7 +367,7 @@
 
         # If logits_processor is set, we need to sort the outputs back to the original order
         # (would be needed only if we have multiple structured outputs in the dataset)
-        if needs_sorting:
+        if sorted_indices is not None:
             batched_outputs = _sort_batches(
                 batched_outputs, sorted_indices, num_generations=num_generations
             )
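
The last two hunks replace the separate `needs_sorting` flag with `sorted_indices` itself acting as the sentinel (`None` means "no structured-output batching, nothing to restore"). A tiny self-contained sketch of that pattern; the reordering below is a toy stand-in for distilabel's `_prepare_batches` and `_sort_batches` helpers:

```python
# Toy illustration of the sentinel refactor: None replaces the old
# `needs_sorting` boolean, so there is one source of truth.
from typing import List, Optional


def restore_order(outputs: List[str], sorted_indices: Optional[List[int]]) -> List[str]:
    # Replaces the old `if needs_sorting:` check.
    if sorted_indices is not None:
        restored = [""] * len(outputs)
        for position, index in enumerate(sorted_indices):
            restored[index] = outputs[position]
        return restored
    return outputs


# Structured-output case: batching shuffled the inputs, so indices are kept.
print(restore_order(["out-b", "out-a"], sorted_indices=[1, 0]))  # ['out-a', 'out-b']
# Plain case: nothing was shuffled, None skips the restore step.
print(restore_order(["out-a", "out-b"], sorted_indices=None))    # ['out-a', 'out-b']
```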

Diff for: src/distilabel/pipeline/step_wrapper.py

+1 −1

@@ -16,7 +16,7 @@
 from queue import Queue
 from typing import Any, Dict, List, Optional, Union, cast
 
-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.pipeline.batch import _Batch
 from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG
 from distilabel.pipeline.typing import StepLoadStatus

Diff for: src/distilabel/steps/tasks/__init__.py

+4

@@ -35,6 +35,8 @@
 from distilabel.steps.tasks.instruction_backtranslation import (
     InstructionBacktranslation,
 )
+from distilabel.steps.tasks.magpie.base import Magpie
+from distilabel.steps.tasks.magpie.generator import MagpieGenerator
 from distilabel.steps.tasks.pair_rm import PairRM
 from distilabel.steps.tasks.prometheus_eval import PrometheusEval
 from distilabel.steps.tasks.quality_scorer import QualityScorer
@@ -64,6 +66,8 @@
     "GenerateTextRetrievalData",
     "MonolingualTripletGenerator",
     "InstructionBacktranslation",
+    "Magpie",
+    "MagpieGenerator",
     "PairRM",
     "PrometheusEval",
     "QualityScorer",

Diff for: src/distilabel/steps/tasks/evol_instruct/generator.py

+1

@@ -280,6 +280,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
         instructions = []
         mutation_no = 0
 
+        # TODO: update to take into account `offset`
         iter_no = 0
         while len(instructions) < self.num_instructions:
             prompts = self._apply_random_mutation(iter_no=iter_no)

Diff for: src/distilabel/steps/tasks/magpie/__init__.py

+14

@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+

0 commit comments