Skip to content

Commit ba6f71f

Browse files
committed
add 01ai client
1 parent 3ad2cac commit ba6f71f

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed

src/distilabel/llms/oneai.py

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright 2023-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from typing import TYPE_CHECKING, List, Optional, Union
17+
18+
from pydantic import Field, PrivateAttr, SecretStr, validate_call
19+
20+
from distilabel.llms.base import AsyncLLM
21+
from distilabel.llms.typing import GenerateOutput
22+
from distilabel.mixins.runtime_parameters import RuntimeParameter
23+
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
24+
25+
if TYPE_CHECKING:
26+
from openai import AsyncOpenAI
27+
28+
29+
# Name of the environment variable read to supply the default 01.AI API key.
_01ai_API_KEY_ENV_VAR_NAME = "01AI_API_KEY"
30+
31+
32+
class OneAI(AsyncLLM):
    """01.AI LLM that talks to an OpenAI-compatible endpoint via the `AsyncOpenAI` client.

    Attributes:
        model: the model name to use for the generations.
        base_url: the base URL to use for the 01.AI API requests. Defaults to the
            `01AI_BASE_URL` environment variable, if set.
        api_key: the API key to authenticate the requests to the 01.AI API. Defaults to
            the `01AI_API_KEY` environment variable, if set.
        max_retries: the maximum number of times to retry the request to the API before
            failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API.
            Defaults to `120`.
        structured_output: the structured output format to use across all the
            generations. Defaults to `None`.
    """

    model: str
    # NOTE(review): the fallback base URL points at OpenAI's API, not a 01.AI endpoint —
    # confirm whether this default should be the 01.AI API URL instead.
    base_url: Optional[RuntimeParameter[str]] = Field(
        default_factory=lambda: os.getenv(
            "01AI_BASE_URL", "https://api.openai.com/v1"
        ),
        description="The base URL to use for the 01AI API requests.",
    )
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_01ai_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the 01ai API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

    # Private attributes: the env-var name used in error messages, and the lazily
    # created async client (populated in `load`).
    _api_key_env_var: str = PrivateAttr(_01ai_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncOpenAI` client to benefit from async requests.

        Raises:
            ImportError: if the `openai` Python client is not installed.
            ValueError: if no API key was provided via attribute, runtime parameter,
                or the `01AI_API_KEY` environment variable.
        """
        super().load()

        try:
            from openai import AsyncOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install openai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncOpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

        if self.structured_output:
            # Wrap/replace the client so generations follow the configured
            # structured output format.
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[str] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OpenAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return. Must be one of
                "text", "json" or "json_object". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use the JSON mode from OpenAI. Defaults
                to `None`, which leaves the provider's default (text) in place.

        Raises:
            ValueError: if `response_format` is provided and is not one of "text",
                "json" or "json_object".

        Returns:
            A list of strings containing the generated responses for the input.
        """
        structured_output = None
        if isinstance(input, tuple):
            # The input carries a per-input structured output spec: unpack it and
            # re-prepare the client for this request.
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")

        # Fall back to the LLM-level structured output when no per-input one is given.
        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "max_tokens": max_new_tokens,
            "n": num_generations,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "temperature": temperature,
            "top_p": top_p,
            "stop": stop,
            # Honor the configured runtime parameter rather than a hard-coded value.
            "timeout": self.timeout,
        }

        if response_format is not None:
            if response_format not in ["text", "json", "json_object"]:
                raise ValueError(
                    f"Invalid response format '{response_format}'. Must be either"
                    " 'text', 'json' or 'json_object'."
                )

            # "json" is an accepted alias for the OpenAI "json_object" mode.
            if response_format == "json":
                response_format = "json_object"

            # NOTE(review): OpenAI's API documents `response_format` as an object like
            # `{"type": "json_object"}` — confirm a bare string is accepted here.
            kwargs["response_format"] = response_format

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        generations = []
        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore

        if structured_output:
            # With structured output the completion is the parsed model itself;
            # serialize it to JSON and return a single generation.
            generations.append(completion.model_dump_json())
            return generations

        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using OpenAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations

0 commit comments

Comments
 (0)