Skip to content

Commit 87c11cc

Browse files
committed
Add chat_completion unit tests
1 parent 3863eb3 commit 87c11cc

File tree

1 file changed

+89
-21
lines changed

1 file changed

+89
-21
lines changed

Diff for: tests/unit/llms/huggingface/test_inference_endpoints.py

+89-21
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@
2121
import nest_asyncio
2222
import pytest
2323
from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM
24+
from huggingface_hub import (
25+
ChatCompletionOutput,
26+
ChatCompletionOutputComplete,
27+
ChatCompletionOutputMessage,
28+
ChatCompletionOutputUsage,
29+
)
2430

2531

2632
@pytest.fixture(autouse=True)
@@ -49,25 +55,7 @@ def test_tokenizer_id_set_if_model_id_and_structured_output(
4955
) -> None:
5056
llm = InferenceEndpointsLLM(
5157
model_id="distilabel-internal-testing/tiny-random-mistral",
52-
structured_output={ # type: ignore
53-
"title": "MMORPG Character",
54-
"type": "object",
55-
"properties": {
56-
"name": {"type": "string", "description": "Character's name"},
57-
"level": {
58-
"type": "integer",
59-
"minimum": 1,
60-
"maximum": 100,
61-
"description": "Character's level",
62-
},
63-
"health": {
64-
"type": "integer",
65-
"minimum": 1,
66-
"description": "Character's current health",
67-
},
68-
},
69-
"required": ["name", "level", "health"],
70-
},
58+
structured_output={"format": "regex", "schema": r"\b[A-Z][a-z]*\b"},
7159
)
7260

7361
assert llm.tokenizer_id == llm.model_id
@@ -160,9 +148,89 @@ async def test_agenerate_with_text_generation(
160148
) == [" Aenean hendrerit aliquam velit. ..."]
161149

162150
@pytest.mark.asyncio
163-
async def test_generate_with_text_generation(
151+
async def test_agenerate_with_chat_completion(
152+
self, mock_inference_client: MagicMock
153+
) -> None:
154+
llm = InferenceEndpointsLLM(
155+
model_id="distilabel-internal-testing/tiny-random-mistral",
156+
)
157+
llm.load()
158+
159+
llm._aclient.chat_completion = AsyncMock( # type: ignore
160+
return_value=ChatCompletionOutput( # type: ignore
161+
choices=[
162+
ChatCompletionOutputComplete(
163+
finish_reason="length",
164+
index=0,
165+
message=ChatCompletionOutputMessage(
166+
role="assistant",
167+
content=" Aenean hendrerit aliquam velit. ...",
168+
),
169+
)
170+
],
171+
created=1721045246,
172+
id="",
173+
model="meta-llama/Meta-Llama-3-70B-Instruct",
174+
object="chat.completion",
175+
system_fingerprint="2.1.1-dev0-sha-4327210",
176+
usage=ChatCompletionOutputUsage(
177+
completion_tokens=66, prompt_tokens=18, total_tokens=84
178+
),
179+
)
180+
)
181+
182+
assert await llm.agenerate(
183+
input=[
184+
{
185+
"role": "user",
186+
"content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
187+
},
188+
]
189+
) == [" Aenean hendrerit aliquam velit. ..."]
190+
191+
@pytest.mark.asyncio
192+
async def test_agenerate_with_chat_completion_fails(
164193
self, mock_inference_client: MagicMock
165194
) -> None:
195+
llm = InferenceEndpointsLLM(
196+
model_id="distilabel-internal-testing/tiny-random-mistral",
197+
)
198+
llm.load()
199+
200+
llm._aclient.chat_completion = AsyncMock( # type: ignore
201+
return_value=ChatCompletionOutput( # type: ignore
202+
choices=[
203+
ChatCompletionOutputComplete(
204+
finish_reason="eos_token",
205+
index=0,
206+
message=ChatCompletionOutputMessage(
207+
role="assistant",
208+
content=None,
209+
),
210+
)
211+
],
212+
created=1721045246,
213+
id="",
214+
model="meta-llama/Meta-Llama-3-70B-Instruct",
215+
object="chat.completion",
216+
system_fingerprint="2.1.1-dev0-sha-4327210",
217+
usage=ChatCompletionOutputUsage(
218+
completion_tokens=66, prompt_tokens=18, total_tokens=84
219+
),
220+
)
221+
)
222+
223+
assert await llm.agenerate(
224+
input=[
225+
{
226+
"role": "user",
227+
"content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
228+
},
229+
]
230+
) == [None]
231+
232+
@pytest.mark.asyncio
233+
async def test_generate(self, mock_inference_client: MagicMock) -> None:
166234
llm = InferenceEndpointsLLM(
167235
model_id="distilabel-internal-testing/tiny-random-mistral",
168236
tokenizer_id="distilabel-internal-testing/tiny-random-mistral",
@@ -185,7 +253,7 @@ async def test_generate_with_text_generation(
185253
},
186254
]
187255
]
188-
) == [(" Aenean hendrerit aliquam velit. ...",)]
256+
) == [[" Aenean hendrerit aliquam velit. ..."]]
189257

190258
@pytest.mark.asyncio
191259
async def test_agenerate_with_structured_output(

0 commit comments

Comments (0)