Skip to content

Commit ada155a

Browse files
committed
improve tests
1 parent 0e2f15b commit ada155a

File tree

3 files changed

+165
-173
lines changed

3 files changed

+165
-173
lines changed

docs/api/llm/01ai.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# 01ai API
22

3-
::: distilabel.llms.oneai
3+
::: distilabel.llms.oneai

docs/sections/how_to_guides/basic/llm/index.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,4 @@ To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchron
142142

143143
## Available LLMs
144144

145-
[Our LLM gallery](/distilabel/components-gallery/llms/) shows a list of the available LLMs that can be used within the `distilabel` library.
145+
[Our LLM gallery](/distilabel/components-gallery/llms/) shows a list of the available LLMs that can be used within the `distilabel` library.

tests/unit/llms/test_01ai.py

+163-171
Original file line numberDiff line numberDiff line change
@@ -1,171 +1,163 @@
# tests/unit/llms/test_01ai.py
"""Unit tests for the `OneAI` (01.ai) LLM wrapper.

FIX vs. original: `os` and `sys` were used (`os.environ`, `sys.version_info`)
but never imported, causing a NameError at test time; both imports added.
"""

import os
import sys
from typing import Any, Dict
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import nest_asyncio
import pytest

from distilabel.llms.oneai import OneAI

from .utils import DummyUserDetail


@patch("openai.AsyncOpenAI")
class TestOneAI:
    """Tests for `OneAI`. `openai.AsyncOpenAI` is patched at class level so
    no real HTTP client is ever constructed."""

    model_id: str = "gpt-4"

    def test_oneai_llm(self, _: MagicMock) -> None:
        # Basic construction: the wrapper reports the configured model name.
        llm = OneAI(model=self.model_id, api_key="api.key")  # type: ignore

        assert isinstance(llm, OneAI)
        assert llm.model_name == self.model_id

    def test_oneai_llm_env_vars(self, _: MagicMock) -> None:
        # API key and base URL must be picked up from env vars when not
        # passed explicitly; `clear=True` isolates from the host environment.
        with mock.patch.dict(os.environ, clear=True):
            os.environ["01AI_API_KEY"] = "another.api.key"
            os.environ["01AI_BASE_URL"] = "https://example.com"

            llm = OneAI(model=self.model_id)

            assert isinstance(llm, OneAI)
            assert llm.model_name == self.model_id
            assert llm.base_url == "https://example.com"
            assert llm.api_key.get_secret_value() == "another.api.key"  # type: ignore

    @pytest.mark.asyncio
    async def test_agenerate(self, mock_openai: MagicMock) -> None:
        llm = OneAI(model=self.model_id, api_key="api.key")  # type: ignore
        llm._aclient = mock_openai

        mocked_completion = Mock(
            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
        )
        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)

        await llm.agenerate(
            input=[
                {"role": "system", "content": ""},
                {
                    "role": "user",
                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                },
            ]
        )

    @pytest.mark.asyncio
    async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
        # With `structured_output`, the generation is the serialized pydantic
        # model returned by the (mocked) client.
        llm = OneAI(
            model=self.model_id,
            api_key="api.key",
            structured_output={
                "schema": DummyUserDetail,
                "mode": "tool_call",
                "max_retries": 1,
            },
        )  # type: ignore
        llm._aclient = mock_openai

        sample_user = DummyUserDetail(name="John Doe", age=30)

        llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user)

        generation = await llm.agenerate(
            input=[
                {"role": "system", "content": ""},
                {
                    "role": "user",
                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                },
            ]
        )
        assert generation[0] == sample_user.model_dump_json()

    # NOTE(review): the reason string mentions `mistralai`, which looks like a
    # copy/paste leftover in a 01ai test — confirm the actual 3.9 dependency.
    @pytest.mark.skipif(
        sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher"
    )
    @pytest.mark.asyncio
    async def test_generate(self, mock_openai: MagicMock) -> None:
        llm = OneAI(model=self.model_id, api_key="api.key")  # type: ignore
        llm._aclient = mock_openai

        mocked_completion = Mock(
            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
        )
        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)

        # `generate` drives the async API through an event loop; nest_asyncio
        # lets that nest inside pytest-asyncio's already-running loop.
        nest_asyncio.apply()

        llm.generate(
            inputs=[
                [
                    {"role": "system", "content": ""},
                    {
                        "role": "user",
                        "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                    },
                ]
            ]
        )

        with pytest.raises(ValueError):
            llm.generate(
                inputs=[
                    [
                        {"role": "system", "content": ""},
                        {
                            "role": "user",
                            "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                        },
                    ]
                ],
                # Intentionally invalid format (typo preserved from original).
                response_format="unkown_format",
            )

    @pytest.mark.parametrize(
        "structured_output, dump",
        [
            (
                None,
                {
                    "model": "gpt-4",
                    "generation_kwargs": {},
                    "max_retries": 6,
                    "base_url": "https://api.openai.com/v1",
                    "timeout": 120,
                    "structured_output": None,
                    "type_info": {
                        "module": "distilabel.llms.oneai",
                        "name": "OneAI",
                    },
                },
            ),
            (
                {
                    "schema": DummyUserDetail.model_json_schema(),
                    "mode": "tool_call",
                    "max_retries": 1,
                },
                {
                    "model": "gpt-4",
                    "generation_kwargs": {},
                    "max_retries": 6,
                    "base_url": "https://api.openai.com/v1",
                    "timeout": 120,
                    "structured_output": {
                        "schema": DummyUserDetail.model_json_schema(),
                        "mode": "tool_call",
                        "max_retries": 1,
                    },
                    "type_info": {
                        "module": "distilabel.llms.oneai",
                        "name": "OneAI",
                    },
                },
            ),
        ],
    )
    def test_serialization(
        self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
    ) -> None:
        # Round-trip: dump must match the expected dict and be loadable back.
        llm = OneAI(model=self.model_id, structured_output=structured_output)

        assert llm.dump() == dump
        assert isinstance(OneAI.from_dict(dump), OneAI)
# tests/unit/llms/test_01ai.py
"""Unit tests for the `OneAI` (01.ai) LLM wrapper."""

import os
import sys
from typing import Any, Dict
from unittest import mock
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import nest_asyncio
import pytest

from distilabel.llms.oneai import OneAI

from .utils import DummyUserDetail


@patch("openai.AsyncOpenAI")
class TestOneAI:
    """Tests for `OneAI`. `openai.AsyncOpenAI` is patched at class level so
    no real HTTP client is ever constructed; async calls are driven through
    mocked completions."""

    model_id: str = "yi-large"

    def test_oneai_llm(self, _: MagicMock) -> None:
        # Basic construction: the wrapper reports the configured model name.
        llm = OneAI(model=self.model_id, api_key="api.key")  # type: ignore
        assert isinstance(llm, OneAI)
        assert llm.model_name == self.model_id

    def test_oneai_llm_env_vars(self, _: MagicMock) -> None:
        # API key and base URL must be picked up from env vars when not
        # passed explicitly; `clear=True` isolates from the host environment.
        with mock.patch.dict(os.environ, clear=True):
            os.environ["01AI_API_KEY"] = "another.api.key"
            os.environ["01AI_BASE_URL"] = "https://api.01.ai/v1/chat/completions"
            llm = OneAI(model=self.model_id)
            assert isinstance(llm, OneAI)
            assert llm.model_name == self.model_id
            assert llm.base_url == "https://api.01.ai/v1/chat/completions"
            assert llm.api_key.get_secret_value() == "another.api.key"  # type: ignore

    @pytest.mark.asyncio
    async def test_agenerate(self, mock_openai: MagicMock) -> None:
        llm = OneAI(model=self.model_id, api_key="api.key")  # type: ignore
        llm._aclient = mock_openai
        mocked_completion = Mock(
            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
        )
        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
        await llm.agenerate(
            input=[
                {"role": "system", "content": ""},
                {
                    "role": "user",
                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                },
            ]
        )

    @pytest.mark.asyncio
    async def test_agenerate_structured(self, mock_openai: MagicMock) -> None:
        # With `structured_output`, the generation is the serialized pydantic
        # model returned by the (mocked) client.
        llm = OneAI(
            model=self.model_id,
            api_key="api.key",
            structured_output={
                "schema": DummyUserDetail,
                "mode": "tool_call",
                "max_retries": 1,
            },
        )  # type: ignore
        llm._aclient = mock_openai
        sample_user = DummyUserDetail(name="John Doe", age=30)
        llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user)
        generation = await llm.agenerate(
            input=[
                {"role": "system", "content": ""},
                {
                    "role": "user",
                    "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                },
            ]
        )
        assert generation[0] == sample_user.model_dump_json()

    @pytest.mark.skipif(
        sys.version_info < (3, 9), reason="`01ai` requires Python 3.9 or higher"
    )
    @pytest.mark.asyncio
    async def test_generate(self, mock_openai: MagicMock) -> None:
        llm = OneAI(model=self.model_id, api_key="api.key")  # type: ignore
        llm._aclient = mock_openai
        mocked_completion = Mock(
            choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))]
        )
        llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion)
        # `generate` drives the async API through an event loop; nest_asyncio
        # lets that nest inside pytest-asyncio's already-running loop.
        nest_asyncio.apply()
        llm.generate(
            inputs=[
                [
                    {"role": "system", "content": ""},
                    {
                        "role": "user",
                        "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                    },
                ]
            ]
        )
        with pytest.raises(ValueError):
            llm.generate(
                inputs=[
                    [
                        {"role": "system", "content": ""},
                        {
                            "role": "user",
                            "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
                        },
                    ]
                ],
                response_format="unknown_format",
            )

    @pytest.mark.parametrize(
        "structured_output, dump",
        [
            (
                None,
                {
                    "model": "yi-large",
                    "generation_kwargs": {},
                    "max_retries": 6,
                    "base_url": "https://api.01.ai/v1/chat/completions",
                    "timeout": 120,
                    "structured_output": None,
                    "type_info": {
                        "module": "distilabel.llms.oneai",
                        "name": "OneAI",
                    },
                },
            ),
            (
                {
                    "schema": DummyUserDetail.model_json_schema(),
                    "mode": "tool_call",
                    "max_retries": 1,
                },
                {
                    # FIX: was "gpt-4", inconsistent with `model_id`
                    # ("yi-large") used to build the llm under test — the
                    # `llm.dump() == dump` assertion could never pass.
                    "model": "yi-large",
                    "generation_kwargs": {},
                    "max_retries": 6,
                    "base_url": "https://api.01.ai/v1/chat/completions",
                    "timeout": 120,
                    "structured_output": {
                        "schema": DummyUserDetail.model_json_schema(),
                        "mode": "tool_call",
                        "max_retries": 1,
                    },
                    "type_info": {
                        "module": "distilabel.llms.oneai",
                        "name": "OneAI",
                    },
                },
            ),
        ],
    )
    def test_serialization(
        self, _: MagicMock, structured_output: Dict[str, Any], dump: Dict[str, Any]
    ) -> None:
        # Round-trip: dump must match the expected dict and be loadable back.
        llm = OneAI(model=self.model_id, structured_output=structured_output)
        assert llm.dump() == dump
        assert isinstance(OneAI.from_dict(dump), OneAI)

0 commit comments

Comments
 (0)