Commit dadac54

Fix unit tests
1 parent 7b9bde5 commit dadac54

5 files changed, +70 -95 lines changed

Diff for: src/distilabel/llms/huggingface/inference_endpoints.py (+16 -6)

@@ -159,9 +159,19 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided(

         if self.base_url and (self.model_id or self.endpoint_name):
             self._logger.warning(  # type: ignore
-                f"Since the `base_url={self.base_url}` is available and either one of `model_id` or `endpoint_name`"
-                " is also provided, the `base_url` will either be ignored or overwritten with the one generated"
-                " from either of those args, for serverless or dedicated inference endpoints, respectively."
+                f"Since the `base_url={self.base_url}` is available and either one of `model_id`"
+                " or `endpoint_name` is also provided, the `base_url` will either be ignored"
+                " or overwritten with the one generated from either of those args, for serverless"
+                " or dedicated inference endpoints, respectively."
+            )
+
+        if self.model_id and self.tokenizer_id is None:
+            self.tokenizer_id = self.model_id
+
+        if self.use_magpie_template and self.tokenizer_id is None:
+            raise ValueError(
+                "`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. Please,"
+                " set a `tokenizer_id` and try again."
             )

         if self.base_url and not (self.model_id or self.endpoint_name):
@@ -174,9 +184,9 @@ def only_one_of_model_id_endpoint_name_or_base_url_provided(
             return self

         raise ValidationError(
-            "Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is provided too,"
-            " it will be overwritten instead. Found `model_id`={self.model_id}, `endpoint_name`={self.endpoint_name},"
-            f" and `base_url`={self.base_url}."
+            f"Only one of `model_id` or `endpoint_name` must be provided. If `base_url` is"
+            f" provided too, it will be overwritten instead. Found `model_id`={self.model_id},"
+            f" `endpoint_name`={self.endpoint_name}, and `base_url`={self.base_url}."
         )

     def load(self) -> None:  # noqa: C901

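Note on the change above: `tokenizer_id` now falls back to `model_id` when left unset, and enabling the Magpie template without any tokenizer fails fast at validation time. A minimal sketch of the new behavior, mirroring the unit tests updated in this commit (the localhost URL is illustrative; pydantic's ValidationError is a ValueError subclass, so catching ValueError works here):

    from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM

    # `tokenizer_id` is now inferred from `model_id` when not provided.
    llm = InferenceEndpointsLLM(
        model_id="distilabel-internal-testing/tiny-random-mistral"
    )
    assert llm.tokenizer_id == llm.model_id

    # `use_magpie_template=True` without a tokenizer now raises at construction.
    try:
        InferenceEndpointsLLM(
            base_url="http://localhost:8000",  # illustrative endpoint URL
            use_magpie_template=True,
            magpie_pre_query_template="llama3",
        )
    except ValueError as e:
        print(e)  # `use_magpie_template` cannot be `True` if `tokenizer_id` is `None`. ...
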
Diff for: src/distilabel/llms/mixins/magpie.py (+1 -1)

@@ -48,7 +48,7 @@ class MagpieChatTemplateMixin(BaseModel, validate_assignment=True):
     - [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464)
     """

-    use_magpie_template: bool = True
+    use_magpie_template: bool = False
     magpie_pre_query_template: Union[MagpieAvailablePreQueryTemplates, str, None] = None

     @field_validator("magpie_pre_query_template")

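With the default flipped to `False`, Magpie-style prompting becomes opt-in. A short sketch of enabling it explicitly, with values taken from the tests in this commit:

    from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM

    # The mixin no longer defaults to Magpie prompting; request it explicitly.
    llm = InferenceEndpointsLLM(
        model_id="distilabel-internal-testing/tiny-random-mistral",  # also sets `tokenizer_id`
        use_magpie_template=True,
        magpie_pre_query_template="llama3",
    )
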
Diff for: src/distilabel/pipeline/step_wrapper.py (+1 -1)

@@ -16,7 +16,7 @@
 from queue import Queue
 from typing import Any, Dict, List, Optional, Union, cast

-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.pipeline.batch import _Batch
 from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG
 from distilabel.pipeline.typing import StepLoadStatus

Diff for: tests/unit/llms/huggingface/test_inference_endpoints.py (+43 -85)

@@ -14,19 +14,38 @@

 import random
 from unittest import mock
-from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

 import nest_asyncio
 import pytest
 from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM


 @patch("huggingface_hub.AsyncInferenceClient")
-@patch("openai.AsyncOpenAI")
 class TestInferenceEndpointsLLM:
-    def test_load_no_api_key(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+    def test_no_tokenizer_magpie_raise_value_error(
+        self, mock_inference_client: MagicMock
     ) -> None:
+        with pytest.raises(
+            ValueError,
+            match="`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`",
+        ):
+            InferenceEndpointsLLM(
+                base_url="http://localhost:8000",
+                use_magpie_template=True,
+                magpie_pre_query_template="llama3",
+            )
+
+    def test_tokenizer_id_set_if_model_id(
+        self, mock_inference_client: MagicMock
+    ) -> None:
+        llm = InferenceEndpointsLLM(
+            model_id="distilabel-internal-testing/tiny-random-mistral"
+        )
+
+        assert llm.tokenizer_id == llm.model_id
+
+    def test_load_no_api_key(self, mock_inference_client: MagicMock) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral"
         )
@@ -40,12 +59,8 @@ def test_load_no_api_key(
         ):
             llm.load()

-    def test_load_with_cached_token(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
-    ) -> None:
-        llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral"
-        )
+    def test_load_with_cached_token(self, mock_inference_client: MagicMock) -> None:
+        llm = InferenceEndpointsLLM(base_url="http://localhost:8000")

         # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to exist
         with (
@@ -58,7 +73,7 @@ def test_load_with_cached_token(
             llm.load()

     def test_serverless_inference_endpoints_llm(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral"
@@ -68,7 +83,7 @@ def test_serverless_inference_endpoints_llm(
         assert llm.model_name == "distilabel-internal-testing/tiny-random-mistral"

     def test_dedicated_inference_endpoints_llm(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             endpoint_name="tiny-random-mistral",
@@ -79,11 +94,12 @@ def test_dedicated_inference_endpoints_llm(
         assert llm.model_name == "tiny-random-mistral"

     def test_dedicated_inference_endpoints_llm_via_url(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             base_url="https://api-inference.huggingface.co/models/distilabel-internal-testing/tiny-random-mistral"
         )
+        llm.load()

         assert isinstance(llm, InferenceEndpointsLLM)
         assert (
@@ -93,12 +109,12 @@ def test_dedicated_inference_endpoints_llm_via_url(

     @pytest.mark.asyncio
     async def test_agenerate_via_inference_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral"
         )
-        llm._aclient = mock_inference_client
+        llm.load()

         llm._aclient.text_generation = AsyncMock(
             return_value=" Aenean hendrerit aliquam velit. ..."
@@ -113,39 +129,14 @@ async def test_agenerate_via_inference_client(
             ]
         ) == [" Aenean hendrerit aliquam velit. ..."]

-    @pytest.mark.asyncio
-    async def test_agenerate_via_openai_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
-    ) -> None:
-        llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral",
-            use_openai_client=True,
-        )
-        llm._aclient = mock_openai_client
-
-        mocked_completion = Mock(
-            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
-        )
-        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
-
-        assert await llm.agenerate(
-            input=[
-                {"role": "system", "content": ""},
-                {
-                    "role": "user",
-                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
-                },
-            ]
-        ) == [" Aenean hendrerit aliquam velit. ..."]
-
     @pytest.mark.asyncio
     async def test_generate_via_inference_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral"
+            model_id="distilabel-internal-testing/tiny-random-mistral",
         )
-        llm._aclient = mock_inference_client
+        llm.load()

         llm._aclient.text_generation = AsyncMock(
             return_value=" Aenean hendrerit aliquam velit. ..."
@@ -165,45 +156,15 @@ async def test_generate_via_inference_client(
             ]
         ) == [(" Aenean hendrerit aliquam velit. ...",)]

-    @pytest.mark.asyncio
-    async def test_generate_via_openai_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
-    ) -> None:
-        llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral",
-            use_openai_client=True,
-        )
-        llm._aclient = mock_openai_client
-
-        mocked_completion = Mock(
-            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
-        )
-        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
-
-        ...
-        nest_asyncio.apply()
-
-        assert llm.generate(
-            inputs=[
-                [
-                    {"role": "system", "content": ""},
-                    {
-                        "role": "user",
-                        "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
-                    },
-                ]
-            ]
-        ) == [(" Aenean hendrerit aliquam velit. ...",)]
-
     @pytest.mark.asyncio
     async def test_agenerate_with_structured_output(
-        self, mock_inference_client: MagicMock, _: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral",
             structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"},
         )
-        llm._aclient = mock_inference_client
+        llm.load()

         llm._aclient.text_generation = AsyncMock(
             return_value=" Aenean hendrerit aliquam velit. ..."
@@ -223,7 +184,7 @@ async def test_agenerate_with_structured_output(
         ) == [" Aenean hendrerit aliquam velit. ..."]

         kwargs = {
-            "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+            "prompt": "<s>[INST] Lorem ipsum dolor sit amet, consectetur adipiscing elit. [/INST]",
             "max_new_tokens": 128,
             "do_sample": False,
             "typical_p": None,
@@ -235,15 +196,11 @@ async def test_agenerate_with_structured_output(
             "return_full_text": False,
             "watermark": False,
             "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"},
-            "seed": 478163327,  # pre-computed random value with `random.seed(42)`
+            "seed": 2053695854357871005,  # pre-computed random value with `random.seed(42)`
         }
-        mock_inference_client.text_generation.assert_called_with(**kwargs)
+        llm._aclient.text_generation.assert_called_with(**kwargs)

-    def test_serialization(
-        self,
-        mock_inference_client: MagicMock,
-        mock_openai_client: MagicMock,
-    ) -> None:
+    def test_serialization(self, mock_inference_client: MagicMock) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral",
         )
@@ -253,11 +210,12 @@ def test_serialization(
             "endpoint_name": None,
             "endpoint_namespace": None,
             "base_url": None,
-            "tokenizer_id": None,
+            "tokenizer_id": "distilabel-internal-testing/tiny-random-mistral",
             "generation_kwargs": {},
+            "magpie_pre_query_template": None,
             "structured_output": None,
             "model_display_name": None,
-            "use_openai_client": False,
+            "use_magpie_template": False,
             "type_info": {
                 "module": "distilabel.llms.huggingface.inference_endpoints",
                 "name": "InferenceEndpointsLLM",

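The recurring pattern in the updated tests: instead of assigning `llm._aclient` by hand (or patching `openai.AsyncOpenAI`, which is removed along with the OpenAI-client code path), they call `llm.load()` under the class-level `AsyncInferenceClient` patch and then stub `text_generation`. A condensed sketch of that pattern, assuming an HF token is reachable when `load()` runs:

    from unittest.mock import AsyncMock, MagicMock, patch

    from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM

    @patch("huggingface_hub.AsyncInferenceClient")
    def test_sketch(mock_inference_client: MagicMock) -> None:
        llm = InferenceEndpointsLLM(
            model_id="distilabel-internal-testing/tiny-random-mistral"
        )
        llm.load()  # builds `_aclient` from the patched client
        llm._aclient.text_generation = AsyncMock(
            return_value=" Aenean hendrerit aliquam velit. ..."
        )
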
Diff for: tests/unit/steps/tasks/structured_outputs/test_outlines.py (+9 -2)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Type, Union
+from typing import Any, Dict, Literal, Type, Union

 import pytest
 from distilabel.llms.huggingface.transformers import TransformersLLM
@@ -33,6 +33,7 @@ class DummyUserTest(BaseModel):
 DUMP_JSON = {
     "cuda_devices": "auto",
     "generation_kwargs": {},
+    "magpie_pre_query_template": None,
     "structured_output": {
         "format": "json",
         "schema": {
@@ -57,6 +58,7 @@ class DummyUserTest(BaseModel):
     "device": None,
     "device_map": None,
     "token": None,
+    "use_magpie_template": False,
     "type_info": {
         "module": "distilabel.llms.huggingface.transformers",
         "name": "TransformersLLM",
@@ -66,6 +68,7 @@ class DummyUserTest(BaseModel):
 DUMP_REGEX = {
     "cuda_devices": "auto",
     "generation_kwargs": {},
+    "magpie_pre_query_template": None,
     "structured_output": {
         "format": "regex",
         "schema": "((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)",
@@ -81,6 +84,7 @@ class DummyUserTest(BaseModel):
     "device": None,
     "device_map": None,
     "token": None,
+    "use_magpie_template": False,
     "type_info": {
         "module": "distilabel.llms.huggingface.transformers",
         "name": "TransformersLLM",
@@ -149,7 +153,10 @@ def test_generation(
         ],
     )
     def test_serialization(
-        self, format: str, schema: Union[str, Type[BaseModel]], dump: Dict[str, Any]
+        self,
+        format: Literal["json", "regex"],
+        schema: Union[str, Type[BaseModel]],
+        dump: Dict[str, Any],
     ) -> None:
         llm = TransformersLLM(
             model="openaccess-ai-collective/tiny-mistral",

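Because the Transformers LLM now carries the Magpie mixin, its serialized dump gains the two new fields, which is why both expected dumps above change. A quick hedged check, assuming `dump()` works on an unloaded instance as these tests do:

    from distilabel.llms.huggingface.transformers import TransformersLLM

    llm = TransformersLLM(model="openaccess-ai-collective/tiny-mistral")
    dump = llm.dump()

    # The new Magpie fields appear alongside the existing serialized attributes.
    assert dump["use_magpie_template"] is False
    assert dump["magpie_pre_query_template"] is None
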
0 commit comments