@@ -14,19 +14,38 @@
 
 import random
 from unittest import mock
-from unittest.mock import AsyncMock, MagicMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import nest_asyncio
 import pytest
 from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM
 
 
 @patch("huggingface_hub.AsyncInferenceClient")
-@patch("openai.AsyncOpenAI")
 class TestInferenceEndpointsLLM:
-    def test_load_no_api_key(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+    def test_no_tokenizer_magpie_raise_value_error(
+        self, mock_inference_client: MagicMock
     ) -> None:
+        with pytest.raises(
+            ValueError,
+            match="`use_magpie_template` cannot be `True` if `tokenizer_id` is `None`",
+        ):
+            InferenceEndpointsLLM(
+                base_url="http://localhost:8000",
+                use_magpie_template=True,
+                magpie_pre_query_template="llama3",
+            )
+
+    def test_tokenizer_id_set_if_model_id(
+        self, mock_inference_client: MagicMock
+    ) -> None:
+        llm = InferenceEndpointsLLM(
+            model_id="distilabel-internal-testing/tiny-random-mistral"
+        )
+
+        assert llm.tokenizer_id == llm.model_id
+
+    def test_load_no_api_key(self, mock_inference_client: MagicMock) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral"
         )
@@ -40,12 +59,8 @@ def test_load_no_api_key(
         ):
             llm.load()
 
-    def test_load_with_cached_token(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
-    ) -> None:
-        llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral"
-        )
+    def test_load_with_cached_token(self, mock_inference_client: MagicMock) -> None:
+        llm = InferenceEndpointsLLM(base_url="http://localhost:8000")
 
         # Mock `huggingface_hub.constants.HF_TOKEN_PATH` to exist
         with (
@@ -58,7 +73,7 @@ def test_load_with_cached_token(
             llm.load()
 
     def test_serverless_inference_endpoints_llm(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral"
@@ -68,7 +83,7 @@ def test_serverless_inference_endpoints_llm(
         assert llm.model_name == "distilabel-internal-testing/tiny-random-mistral"
 
     def test_dedicated_inference_endpoints_llm(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             endpoint_name="tiny-random-mistral",
@@ -79,11 +94,12 @@ def test_dedicated_inference_endpoints_llm(
         assert llm.model_name == "tiny-random-mistral"
 
     def test_dedicated_inference_endpoints_llm_via_url(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             base_url="https://api-inference.huggingface.co/models/distilabel-internal-testing/tiny-random-mistral"
         )
+        llm.load()
 
         assert isinstance(llm, InferenceEndpointsLLM)
         assert (
@@ -93,12 +109,12 @@ def test_dedicated_inference_endpoints_llm_via_url(
 
     @pytest.mark.asyncio
     async def test_agenerate_via_inference_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral"
         )
-        llm._aclient = mock_inference_client
+        llm.load()
 
         llm._aclient.text_generation = AsyncMock(
             return_value=" Aenean hendrerit aliquam velit. ..."
@@ -113,39 +129,14 @@ async def test_agenerate_via_inference_client(
             ]
         ) == [" Aenean hendrerit aliquam velit. ..."]
 
-    @pytest.mark.asyncio
-    async def test_agenerate_via_openai_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
-    ) -> None:
-        llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral",
-            use_openai_client=True,
-        )
-        llm._aclient = mock_openai_client
-
-        mocked_completion = Mock(
-            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
-        )
-        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
-
-        assert await llm.agenerate(
-            input=[
-                {"role": "system", "content": ""},
-                {
-                    "role": "user",
-                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
-                },
-            ]
-        ) == [" Aenean hendrerit aliquam velit. ..."]
-
     @pytest.mark.asyncio
     async def test_generate_via_inference_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral"
+            model_id="distilabel-internal-testing/tiny-random-mistral",
         )
-        llm._aclient = mock_inference_client
+        llm.load()
 
         llm._aclient.text_generation = AsyncMock(
             return_value=" Aenean hendrerit aliquam velit. ..."
@@ -165,45 +156,15 @@ async def test_generate_via_inference_client(
             ]
         ) == [(" Aenean hendrerit aliquam velit. ...",)]
 
-    @pytest.mark.asyncio
-    async def test_generate_via_openai_client(
-        self, mock_inference_client: MagicMock, mock_openai_client: MagicMock
-    ) -> None:
-        llm = InferenceEndpointsLLM(
-            model_id="distilabel-internal-testing/tiny-random-mistral",
-            use_openai_client=True,
-        )
-        llm._aclient = mock_openai_client
-
-        mocked_completion = Mock(
-            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
-        )
-        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
-
-        ...
-        nest_asyncio.apply()
-
-        assert llm.generate(
-            inputs=[
-                [
-                    {"role": "system", "content": ""},
-                    {
-                        "role": "user",
-                        "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
-                    },
-                ]
-            ]
-        ) == [(" Aenean hendrerit aliquam velit. ...",)]
-
     @pytest.mark.asyncio
     async def test_agenerate_with_structured_output(
-        self, mock_inference_client: MagicMock, _: MagicMock
+        self, mock_inference_client: MagicMock
     ) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral",
             structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"},
         )
-        llm._aclient = mock_inference_client
+        llm.load()
 
         llm._aclient.text_generation = AsyncMock(
             return_value=" Aenean hendrerit aliquam velit. ..."
@@ -223,7 +184,7 @@ async def test_agenerate_with_structured_output(
         ) == [" Aenean hendrerit aliquam velit. ..."]
 
         kwargs = {
-            "prompt": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+            "prompt": "<s>[INST] Lorem ipsum dolor sit amet, consectetur adipiscing elit. [/INST] ",
             "max_new_tokens": 128,
             "do_sample": False,
             "typical_p": None,
@@ -235,15 +196,11 @@ async def test_agenerate_with_structured_output(
             "return_full_text": False,
             "watermark": False,
             "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"},
-            "seed": 478163327,  # pre-computed random value with `random.seed(42)`
+            "seed": 2053695854357871005,  # pre-computed random value with `random.seed(42)`
         }
-        mock_inference_client.text_generation.assert_called_with(**kwargs)
+        llm._aclient.text_generation.assert_called_with(**kwargs)
 
-    def test_serialization(
-        self,
-        mock_inference_client: MagicMock,
-        mock_openai_client: MagicMock,
-    ) -> None:
+    def test_serialization(self, mock_inference_client: MagicMock) -> None:
         llm = InferenceEndpointsLLM(
             model_id="distilabel-internal-testing/tiny-random-mistral",
         )
@@ -253,11 +210,12 @@ def test_serialization(
             "endpoint_name": None,
             "endpoint_namespace": None,
             "base_url": None,
-            "tokenizer_id": None,
+            "tokenizer_id": "distilabel-internal-testing/tiny-random-mistral",
             "generation_kwargs": {},
+            "magpie_pre_query_template": None,
             "structured_output": None,
             "model_display_name": None,
-            "use_openai_client": False,
+            "use_magpie_template": False,
             "type_info": {
                 "module": "distilabel.llms.huggingface.inference_endpoints",
                 "name": "InferenceEndpointsLLM",
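For reference, a minimal standalone sketch of the mocking pattern every test above now shares: patch `huggingface_hub.AsyncInferenceClient` so that `llm.load()` assigns the mock to `llm._aclient`, then stub `text_generation` per test. This is not part of the diff; the test name, stubbed return value, and the assumption that an HF token is resolvable when `load()` runs are illustrative only.

# Sketch, not taken from the PR: illustrates patch + load() + AsyncMock stubbing.
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM


@pytest.mark.asyncio
@patch("huggingface_hub.AsyncInferenceClient")
async def test_example_agenerate(mock_inference_client: MagicMock) -> None:
    llm = InferenceEndpointsLLM(
        model_id="distilabel-internal-testing/tiny-random-mistral"
    )
    # `load()` instantiates the client; since the class is patched above,
    # `llm._aclient` ends up being a mock instance.
    llm.load()

    # Stub the only endpoint method `agenerate` calls for text generation.
    llm._aclient.text_generation = AsyncMock(return_value="mocked completion")

    result = await llm.agenerate(input=[{"role": "user", "content": "Hello"}])
    assert result == ["mocked completion"]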