21
21
import nest_asyncio
22
22
import pytest
23
23
from distilabel .llms .huggingface .inference_endpoints import InferenceEndpointsLLM
24
+ from huggingface_hub import (
25
+ ChatCompletionOutput ,
26
+ ChatCompletionOutputComplete ,
27
+ ChatCompletionOutputMessage ,
28
+ ChatCompletionOutputUsage ,
29
+ )
24
30
25
31
26
32
@pytest .fixture (autouse = True )
@@ -49,25 +55,7 @@ def test_tokenizer_id_set_if_model_id_and_structured_output(
49
55
) -> None :
50
56
llm = InferenceEndpointsLLM (
51
57
model_id = "distilabel-internal-testing/tiny-random-mistral" ,
52
- structured_output = { # type: ignore
53
- "title" : "MMORPG Character" ,
54
- "type" : "object" ,
55
- "properties" : {
56
- "name" : {"type" : "string" , "description" : "Character's name" },
57
- "level" : {
58
- "type" : "integer" ,
59
- "minimum" : 1 ,
60
- "maximum" : 100 ,
61
- "description" : "Character's level" ,
62
- },
63
- "health" : {
64
- "type" : "integer" ,
65
- "minimum" : 1 ,
66
- "description" : "Character's current health" ,
67
- },
68
- },
69
- "required" : ["name" , "level" , "health" ],
70
- },
58
+ structured_output = {"format" : "regex" , "schema" : r"\b[A-Z][a-z]*\b" },
71
59
)
72
60
73
61
assert llm .tokenizer_id == llm .model_id
@@ -160,9 +148,89 @@ async def test_agenerate_with_text_generation(
160
148
) == [" Aenean hendrerit aliquam velit. ..." ]
161
149
162
150
@pytest .mark .asyncio
163
- async def test_generate_with_text_generation (
151
+ async def test_agenerate_with_chat_completion (
152
+ self , mock_inference_client : MagicMock
153
+ ) -> None :
154
+ llm = InferenceEndpointsLLM (
155
+ model_id = "distilabel-internal-testing/tiny-random-mistral" ,
156
+ )
157
+ llm .load ()
158
+
159
+ llm ._aclient .chat_completion = AsyncMock ( # type: ignore
160
+ return_value = ChatCompletionOutput ( # type: ignore
161
+ choices = [
162
+ ChatCompletionOutputComplete (
163
+ finish_reason = "length" ,
164
+ index = 0 ,
165
+ message = ChatCompletionOutputMessage (
166
+ role = "assistant" ,
167
+ content = " Aenean hendrerit aliquam velit. ..." ,
168
+ ),
169
+ )
170
+ ],
171
+ created = 1721045246 ,
172
+ id = "" ,
173
+ model = "meta-llama/Meta-Llama-3-70B-Instruct" ,
174
+ object = "chat.completion" ,
175
+ system_fingerprint = "2.1.1-dev0-sha-4327210" ,
176
+ usage = ChatCompletionOutputUsage (
177
+ completion_tokens = 66 , prompt_tokens = 18 , total_tokens = 84
178
+ ),
179
+ )
180
+ )
181
+
182
+ assert await llm .agenerate (
183
+ input = [
184
+ {
185
+ "role" : "user" ,
186
+ "content" : "Lorem ipsum dolor sit amet, consectetur adipiscing elit." ,
187
+ },
188
+ ]
189
+ ) == [" Aenean hendrerit aliquam velit. ..." ]
190
+
191
+ @pytest .mark .asyncio
192
+ async def test_agenerate_with_chat_completion_fails (
164
193
self , mock_inference_client : MagicMock
165
194
) -> None :
195
+ llm = InferenceEndpointsLLM (
196
+ model_id = "distilabel-internal-testing/tiny-random-mistral" ,
197
+ )
198
+ llm .load ()
199
+
200
+ llm ._aclient .chat_completion = AsyncMock ( # type: ignore
201
+ return_value = ChatCompletionOutput ( # type: ignore
202
+ choices = [
203
+ ChatCompletionOutputComplete (
204
+ finish_reason = "eos_token" ,
205
+ index = 0 ,
206
+ message = ChatCompletionOutputMessage (
207
+ role = "assistant" ,
208
+ content = None ,
209
+ ),
210
+ )
211
+ ],
212
+ created = 1721045246 ,
213
+ id = "" ,
214
+ model = "meta-llama/Meta-Llama-3-70B-Instruct" ,
215
+ object = "chat.completion" ,
216
+ system_fingerprint = "2.1.1-dev0-sha-4327210" ,
217
+ usage = ChatCompletionOutputUsage (
218
+ completion_tokens = 66 , prompt_tokens = 18 , total_tokens = 84
219
+ ),
220
+ )
221
+ )
222
+
223
+ assert await llm .agenerate (
224
+ input = [
225
+ {
226
+ "role" : "user" ,
227
+ "content" : "Lorem ipsum dolor sit amet, consectetur adipiscing elit." ,
228
+ },
229
+ ]
230
+ ) == [None ]
231
+
232
+ @pytest .mark .asyncio
233
+ async def test_generate (self , mock_inference_client : MagicMock ) -> None :
166
234
llm = InferenceEndpointsLLM (
167
235
model_id = "distilabel-internal-testing/tiny-random-mistral" ,
168
236
tokenizer_id = "distilabel-internal-testing/tiny-random-mistral" ,
@@ -185,7 +253,7 @@ async def test_generate_with_text_generation(
185
253
},
186
254
]
187
255
]
188
- ) == [( " Aenean hendrerit aliquam velit. ..." ,) ]
256
+ ) == [[ " Aenean hendrerit aliquam velit. ..." ] ]
189
257
190
258
@pytest .mark .asyncio
191
259
async def test_agenerate_with_structured_output (
0 commit comments