Skip to content

Commit 7960342

Browse files
committed
add 01ai client using extant openai client
1 parent a058fe3 commit 7960342

File tree

1 file changed

+193
-201
lines changed

1 file changed

+193
-201
lines changed

src/distilabel/llms/oneai.py

+193-201
Original file line numberDiff line numberDiff line change
@@ -1,201 +1,193 @@
1-
# Copyright 2023-present, Argilla, Inc.
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
15-
import os
16-
from typing import TYPE_CHECKING, List, Optional, Union
17-
18-
from pydantic import Field, PrivateAttr, SecretStr, validate_call
19-
20-
from distilabel.llms.base import AsyncLLM
21-
from distilabel.llms.typing import GenerateOutput
22-
from distilabel.mixins.runtime_parameters import RuntimeParameter
23-
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
24-
25-
if TYPE_CHECKING:
26-
from openai import AsyncOpenAI
27-
28-
29-
_01ai_API_KEY_ENV_VAR_NAME = "01AI_API_KEY"
30-
31-
32-
class OneAI(AsyncLLM):
    """OneAI (01.AI) LLM implementation running the async API client of OpenAI.

    Attributes:
        model: the model name to use for the LLM.
        base_url: the base URL to use for the 01AI API requests. Defaults to the value of
            the environment variable `01AI_BASE_URL`, or "https://api.01.ai/v1" if not set.
        api_key: the API key to authenticate the requests to the 01AI API. Defaults to the
            value of the environment variable `01AI_API_KEY`, or `None` if not set.
        max_retries: the maximum number of times to retry the request to the API before
            failing.
        timeout: the maximum time in seconds to wait for a response from the API.
        structured_output: the structured output format to use across all the generations.
        _api_key_env_var: the name of the environment variable holding the API key.
            It is meant to be used internally.
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        # NOTE(review): previously defaulted to "https://api.openai.com/v1", which is the
        # OpenAI endpoint, not the 01.AI one.
        default_factory=lambda: os.getenv("01AI_BASE_URL", "https://api.01.ai/v1"),
        description="The base URL to use for the 01AI API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_01ai_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the 01ai API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    _api_key_env_var: str = PrivateAttr(_01ai_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncOpenAI` client to benefit from async requests.

        Raises:
            ImportError: if the `openai` Python client is not installed.
            ValueError: if no API key was provided via the `api_key` attribute, as a
                runtime parameter, or through the environment variable.
        """
        super().load()

        try:
            from openai import AsyncOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install openai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncOpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[str] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OpenAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return. Must be one of "text",
                "json" or "json_object". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use the JSON mode. Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        structured_output = None
        if isinstance(input, tuple):
            # The input carries a per-input structured-output configuration.
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "max_tokens": max_new_tokens,
            "n": num_generations,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "temperature": temperature,
            "top_p": top_p,
            "stop": stop,
            # Use the configured runtime parameter; previously hard-coded to 50,
            # silently overriding the `timeout` attribute (default 120).
            "timeout": self.timeout,
        }

        if response_format is not None:
            if response_format not in ["text", "json", "json_object"]:
                raise ValueError(
                    f"Invalid response format '{response_format}'. Must be either 'text',"
                    " 'json' or 'json_object'."
                )

            # The OpenAI API expects `json_object` rather than `json`.
            if response_format == "json":
                response_format = "json_object"

            kwargs["response_format"] = response_format

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        generations = []
        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore

        if structured_output:
            generations.append(completion.model_dump_json())
            return generations

        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using OpenAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations
1+
# Copyright 2023-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from typing import TYPE_CHECKING, List, Optional, Union
17+
18+
from pydantic import Field, PrivateAttr, SecretStr, validate_call
19+
20+
from distilabel.llms.base import AsyncLLM
21+
from distilabel.llms.typing import GenerateOutput
22+
from distilabel.mixins.runtime_parameters import RuntimeParameter
23+
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
24+
25+
if TYPE_CHECKING:
26+
from openai import AsyncOpenAI
27+
28+
_ONEAI_API_KEY_ENV_VAR_NAME = "01AI_API_KEY"
29+
30+
31+
class OneAI(AsyncLLM):
    """OneAI LLM implementation running the async API client of OpenAI.

    Attributes:
        model: the model name to use for the LLM, e.g., `google/gemma-7b-it`.
        base_url: the base URL to use for the OneAI API requests. Defaults to `None`, which
            means that the value set for the environment variable `01AI_BASE_URL` will be used, or
            "https://api.01.ai/v1" if not set.
        api_key: the API key to authenticate the requests to the OneAI API. Defaults to `None` which
            means that the value set for the environment variable `01AI_API_KEY` will be used, or
            `None` if not set.
        max_retries: the maximum number of times to retry the request to the API before failing.
        timeout: the maximum time in seconds to wait for a response from the API.
        structured_output: the structured output format to use across all the generations.
        _api_key_env_var: the name of the environment variable to use for the API key.
            It is meant to be used internally.

    Examples:

        Generate text:

        ```python
        from distilabel.llms import OneAI

        llm = OneAI(model="google/gemma-7b-it", api_key="api.key")

        llm.load()

        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
        ```
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        # NOTE(review): the default must be the API root — the `AsyncOpenAI` client
        # appends `/chat/completions` itself, so including it here would duplicate
        # the path segment.
        default_factory=lambda: os.getenv("01AI_BASE_URL", "https://api.01.ai/v1"),
        description="The base URL to use for the OneAI API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_ONEAI_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the OneAI API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = Field(
        default=None,
        description="The structured output format to use across all the generations.",
    )

    _api_key_env_var: str = PrivateAttr(_ONEAI_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncOpenAI` client to benefit from async requests.

        Raises:
            ImportError: if the `openai` Python client is not installed.
            ValueError: if no API key was provided via the `api_key` attribute, as a
                runtime parameter, or through the environment variable.
        """
        super().load()
        try:
            from openai import AsyncOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using `pip install openai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncOpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,
            timeout=self.timeout,
        )

        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[str] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OneAI LLM.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return. Must be one of "text",
                "json" or "json_object". Defaults to `None`.

        Returns:
            A list of lists of strings containing the generated responses for each input.
        """
        structured_output = None
        if isinstance(input, tuple):
            # The input carries a per-input structured-output configuration.
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")
        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,
            "model": self.model,
            "max_tokens": max_new_tokens,
            "n": num_generations,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "temperature": temperature,
            "top_p": top_p,
            "stop": stop,
            # Use the configured runtime parameter; previously hard-coded to 50,
            # silently overriding the `timeout` attribute (default 120).
            "timeout": self.timeout,
        }

        if response_format is not None:
            if response_format not in ["text", "json", "json_object"]:
                raise ValueError(
                    f"Invalid response format '{response_format}'. Must be either 'text',"
                    " 'json' or 'json_object'."
                )
            # The OpenAI API expects `json_object` rather than `json`.
            if response_format == "json":
                response_format = "json_object"
            kwargs["response_format"] = response_format

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        generations = []
        completion = await self._aclient.chat.completions.create(**kwargs)
        if structured_output:
            generations.append(completion.model_dump_json())
            return generations

        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(
                    f"Received no response using OpenAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations

0 commit comments

Comments
 (0)