Skip to content

Commit be5606f

Browse files
committed
Merge branch 'develop' of https://github.com/argilla-io/distilabel into update-pre-commit
2 parents 5720979 + f7f9e26 commit be5606f

File tree

18 files changed

+1341
-12
lines changed

18 files changed

+1341
-12
lines changed

Diff for: docs/sections/how_to_guides/advanced/checkpointing.md

+30
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,33 @@ The final datasets can be found in the following links:
5757
- Checkpoint dataset: [distilabel-internal-testing/streaming_test_1](https://huggingface.co/datasets/distilabel-internal-testing/streaming_test_1)
5858

5959
- Final distiset: [distilabel-internal-testing/streaming_test](https://huggingface.co/datasets/distilabel-internal-testing/streaming_test)
60+
61+
### Read back the data
62+
63+
In case we want to take a look at a given filename we can take advantage of the `huggingface_hub` library. We will use the `HfFileSystem` to list all the `jsonl` files in the dataset repository, and download onle of them to show how it works:
64+
65+
```python
66+
from huggingface_hub import HfFileSystem, hf_hub_download
67+
68+
dataset_name = "distilabel-internal-testing/streaming_test_1"
69+
fs = HfFileSystem()
70+
filenames = fs.glob(f"datasets/{dataset_name}/**/*.jsonl")
71+
72+
filename = hf_hub_download(repo_id="distilabel-internal-testing/streaming_test_1", filename="config-0/train-00000.jsonl", repo_type="dataset")
73+
```
74+
75+
The filename will be downloaded to the default cache, and to read the data we can just proceed as with any other jsonlines file:
76+
77+
```python
78+
import json
79+
data = []
80+
81+
with open(filename, "r") as f:
82+
data = [json.loads(line) for line in f.readlines()]
83+
84+
# [{'a': 1, 'b': 5},
85+
# {'a': 2, 'b': 6},
86+
# {'a': 3, 'b': 7},
87+
# ...
88+
```
89+

Diff for: examples/exam_questions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ class ExamQuestions(BaseModel):
5959
name="load_instructions",
6060
data=[
6161
{
62-
"page": page.content,
62+
"instruction": page.content,
6363
}
6464
],
6565
)
6666

6767
text_generation = TextGeneration(
6868
name="exam_generation",
6969
system_prompt=SYSTEM_PROMPT,
70-
template="Generate a list of answers and questions about the document. Document:\n\n{{ page }}",
70+
template="Generate a list of answers and questions about the document. Document:\n\n{{ instruction }}",
7171
llm=InferenceEndpointsLLM(
7272
model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
7373
tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
@@ -95,4 +95,4 @@ class ExamQuestions(BaseModel):
9595
},
9696
use_cache=False,
9797
)
98-
distiset.push_to_hub("USERNAME/exam_questions")
98+
# distiset.push_to_hub("USERNAME/exam_questions")

Diff for: pyproject.toml

+11-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ argilla = ["argilla >= 2.0.0", "ipython"]
7979
cohere = ["cohere >= 5.2.0"]
8080
groq = ["groq >= 0.4.1"]
8181
hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]
82-
hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
82+
hf-transformers = ["transformers == 4.48.3", "torch >= 2.0.0"]
8383
instructor = ["instructor >= 1.2.3"]
8484
litellm = ["litellm >= 1.30.0"]
8585
llama-cpp = ["llama-cpp-python >= 0.2.0"]
@@ -107,6 +107,16 @@ vision = ["Pillow >= 10.3.0"] # To work with images.
107107
# minhash
108108
minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]
109109

110+
sglang = ["sglang[all]>=0.4.3.post2", "transformers == 4.48.3"]
111+
112+
[tool.hatch.envs.default]
113+
dependencies = [
114+
"sglang[all]>=0.4.3.post2",
115+
"transformers == 4.48.3",
116+
]
117+
installer = "pip"
118+
pip-args = ["--find-links", "https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python"]
119+
110120
[project.urls]
111121
Documentation = "https://distilabel.argilla.io/"
112122
Issues = "https://github.com/argilla/distilabel/issues"

Diff for: src/distilabel/embeddings.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@
2727
from distilabel.models.embeddings.sentence_transformers import (
2828
SentenceTransformerEmbeddings,
2929
)
30-
from distilabel.models.embeddings.vllm import vLLMEmbeddings
30+
from distilabel.models.embeddings.vllm import SGLangEmbeddings, vLLMEmbeddings
3131

3232
__all__ = [
3333
"Embeddings",
34+
"SGLangEmbeddings",
3435
"SentenceTransformerEmbeddings",
3536
"vLLMEmbeddings",
3637
]

Diff for: src/distilabel/llms.py

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from distilabel.models.llms.moa import MixtureOfAgentsLLM
3838
from distilabel.models.llms.ollama import OllamaLLM
3939
from distilabel.models.llms.openai import OpenAILLM
40+
from distilabel.models.llms.sglang import ClientSGLang, SGLang
4041
from distilabel.models.llms.together import TogetherLLM
4142
from distilabel.models.llms.vertexai import VertexAILLM
4243
from distilabel.models.llms.vllm import ClientvLLM, vLLM
@@ -49,6 +50,7 @@
4950
"AnyscaleLLM",
5051
"AsyncLLM",
5152
"AzureOpenAILLM",
53+
"ClientSGLang",
5254
"ClientvLLM",
5355
"CohereLLM",
5456
"CudaDevicePlacementMixin",
@@ -63,6 +65,7 @@
6365
"MlxLLM",
6466
"OllamaLLM",
6567
"OpenAILLM",
68+
"SGLang",
6669
"TogetherLLM",
6770
"TransformersLLM",
6871
"VertexAILLM",

Diff for: src/distilabel/models/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from distilabel.models.embeddings.sentence_transformers import (
1919
SentenceTransformerEmbeddings,
2020
)
21+
from distilabel.models.embeddings.sglang import SGLangEmbeddings
2122
from distilabel.models.embeddings.vllm import vLLMEmbeddings
2223
from distilabel.models.image_generation.base import (
2324
AsyncImageGenerationModel,
@@ -41,6 +42,7 @@
4142
from distilabel.models.llms.moa import MixtureOfAgentsLLM
4243
from distilabel.models.llms.ollama import OllamaLLM
4344
from distilabel.models.llms.openai import OpenAILLM
45+
from distilabel.models.llms.sglang import ClientSGLang, SGLang
4446
from distilabel.models.llms.together import TogetherLLM
4547
from distilabel.models.llms.vertexai import VertexAILLM
4648
from distilabel.models.llms.vllm import ClientvLLM, vLLM
@@ -54,6 +56,7 @@
5456
"AsyncImageGenerationModel",
5557
"AsyncLLM",
5658
"AzureOpenAILLM",
59+
"ClientSGLang",
5760
"ClientvLLM",
5861
"CohereLLM",
5962
"CudaDevicePlacementMixin",
@@ -73,6 +76,8 @@
7376
"OllamaLLM",
7477
"OpenAIImageGeneration",
7578
"OpenAILLM",
79+
"SGLang",
80+
"SGLangEmbeddings",
7681
"SentenceTransformerEmbeddings",
7782
"TogetherLLM",
7883
"TransformersLLM",

Diff for: src/distilabel/models/base_clients/inference_endpoints.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ def load(self) -> None: # noqa: C901
108108
f"Model {self.model_id} is not currently deployed or is not running the TGI framework"
109109
)
110110

111-
self.base_url = client._resolve_url(
112-
model=self.model_id, task="text-generation"
113-
)
111+
self._base_url = client.base_url
114112

115113
if self.endpoint_name is not None:
116114
client = get_inference_endpoint(

Diff for: src/distilabel/models/embeddings/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
from distilabel.models.embeddings.sentence_transformers import (
1818
SentenceTransformerEmbeddings,
1919
)
20+
from distilabel.models.embeddings.sglang import SGLangEmbeddings
2021
from distilabel.models.embeddings.vllm import vLLMEmbeddings
2122

2223
__all__ = [
2324
"Embeddings",
2425
"LlamaCppEmbeddings",
26+
"SGLangEmbeddings",
2527
"SentenceTransformerEmbeddings",
2628
"vLLMEmbeddings",
2729
]

Diff for: src/distilabel/models/embeddings/sglang.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2023-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
16+
17+
from pydantic import Field, PrivateAttr
18+
19+
from distilabel.mixins.runtime_parameters import RuntimeParameter
20+
from distilabel.models.embeddings.base import Embeddings
21+
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
22+
23+
if TYPE_CHECKING:
24+
from sglang import Engine
25+
26+
27+
class SGLangEmbeddings(Embeddings, CudaDevicePlacementMixin):
28+
"""`sglang` library implementation for embedding generation.
29+
30+
Attributes:
31+
model: the model Hugging Face Hub repo id or a path to a directory containing the
32+
model weights and configuration files.
33+
dtype: the data type to use for the model. Defaults to `auto`.
34+
trust_remote_code: whether to trust the remote code when loading the model. Defaults
35+
to `False`.
36+
quantization: the quantization mode to use for the model. Defaults to `None`.
37+
revision: the revision of the model to load. Defaults to `None`.
38+
seed: the seed to use for the random number generator. Defaults to `0`.
39+
extra_kwargs: additional dictionary of keyword arguments that will be passed to the
40+
`Engine` class of `sglang` library. Defaults to `{}`.
41+
_model: the `SGLang` model instance. This attribute is meant to be used internally
42+
and should not be accessed directly. It will be set in the `load` method.
43+
44+
References:
45+
- https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py
46+
47+
Examples:
48+
Generating sentence embeddings:
49+
50+
```python
51+
if __name__ == "__main__":
52+
53+
from distilabel.models import SGLangEmbeddings
54+
embeddings = SGLangEmbeddings(model="intfloat/e5-mistral-7b-instruct")
55+
embeddings.load()
56+
results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
57+
print(results)
58+
# [
59+
# [0.0203704833984375, -0.0060882568359375, ...],
60+
# [0.02398681640625, 0.0177001953125 ...],
61+
# ]
62+
```
63+
"""
64+
65+
model: str
66+
dtype: str = "auto"
67+
trust_remote_code: bool = False
68+
quantization: Optional[str] = None
69+
revision: Optional[str] = None
70+
71+
seed: int = 0
72+
73+
extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
74+
default_factory=dict,
75+
description="Additional dictionary of keyword arguments that will be passed to the"
76+
" `Engine` class of `sglang` library. See all the supported arguments at: "
77+
"https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/engine.py",
78+
)
79+
80+
_model: "Engine" = PrivateAttr(None)
81+
82+
def load(self) -> None:
83+
"""Loads the `sglang` model using either the path or the Hugging Face Hub repository id."""
84+
super().load()
85+
86+
CudaDevicePlacementMixin.load(self)
87+
88+
try:
89+
from sglang import Engine
90+
except ImportError as err:
91+
raise ImportError(
92+
"sglang is not installed. Please install it with sglang document https://docs.sglang.ai/start/install.html."
93+
) from err
94+
95+
self._model = Engine(
96+
model_path=self.model,
97+
dtype=self.dtype,
98+
trust_remote_code=self.trust_remote_code,
99+
quantization=self.quantization,
100+
revision=self.revision,
101+
random_seed=self.seed,
102+
**self.extra_kwargs, # type: ignore
103+
)
104+
105+
def unload(self) -> None:
106+
"""Unloads the `SGLang` model."""
107+
self._model = None
108+
CudaDevicePlacementMixin.unload(self)
109+
super().unload()
110+
111+
@property
112+
def model_name(self) -> str:
113+
"""Returns the name of the model."""
114+
return self.model
115+
116+
def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
117+
"""Generates embeddings for the provided inputs.
118+
119+
Args:
120+
inputs: a list of texts for which an embedding has to be generated.
121+
122+
Returns:
123+
The generated embeddings.
124+
"""
125+
return [output["embedding"] for output in self._model.encode(inputs)]

Diff for: src/distilabel/models/llms/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from distilabel.models.llms.moa import MixtureOfAgentsLLM
2727
from distilabel.models.llms.ollama import OllamaLLM
2828
from distilabel.models.llms.openai import OpenAILLM
29+
from distilabel.models.llms.sglang import ClientSGLang, SGLang
2930
from distilabel.models.llms.together import TogetherLLM
3031
from distilabel.models.llms.vertexai import VertexAILLM
3132
from distilabel.models.llms.vllm import ClientvLLM, vLLM
@@ -38,6 +39,7 @@
3839
"AnyscaleLLM",
3940
"AsyncLLM",
4041
"AzureOpenAILLM",
42+
"ClientSGLang",
4143
"ClientvLLM",
4244
"CohereLLM",
4345
"CudaDevicePlacementMixin",
@@ -52,6 +54,7 @@
5254
"MlxLLM",
5355
"OllamaLLM",
5456
"OpenAILLM",
57+
"SGLang",
5558
"TogetherLLM",
5659
"TransformersLLM",
5760
"VertexAILLM",

Diff for: src/distilabel/models/llms/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def _prepare_kwargs(
476476
Args:
477477
arguments: The arguments that would be passed to the LLM as **kwargs.
478478
to update with the structured output configuration.
479-
structured_outputs: The structured output configuration to update the arguments.
479+
structured_output: The structured output configuration to update the arguments.
480480
481481
Returns:
482482
kwargs updated with the special arguments used by `instructor`.

0 commit comments

Comments
 (0)