-# Copyright 2023-present, Argilla, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-from typing import TYPE_CHECKING, List, Optional, Union
-
-from pydantic import Field, PrivateAttr, SecretStr, validate_call
-
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
-from distilabel.mixins.runtime_parameters import RuntimeParameter
-from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
-
-if TYPE_CHECKING:
-    from openai import AsyncOpenAI
-
-
-_01ai_API_KEY_ENV_VAR_NAME = "01AI_API_KEY"
-
-
-class OneAI(AsyncLLM):
-    """WIP"""
-
-    model: str
-    base_url: Optional[RuntimeParameter[str]] = Field(
-        default_factory=lambda: os.getenv(
-            "01AI_BASE_URL", "https://api.openai.com/v1"
-        ),
-        description="The base URL to use for the 01AI API requests.",
-    )
-    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
-        default_factory=lambda: os.getenv(_01ai_API_KEY_ENV_VAR_NAME),
-        description="The API key to authenticate the requests to the 01ai API.",
-    )
-    max_retries: RuntimeParameter[int] = Field(
-        default=6,
-        description="The maximum number of times to retry the request to the API before"
-        " failing.",
-    )
-    timeout: RuntimeParameter[int] = Field(
-        default=120,
-        description="The maximum time in seconds to wait for a response from the API.",
-    )
-    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
-        Field(
-            default=None,
-            description="The structured output format to use across all the generations.",
-        )
-    )
-
-    _api_key_env_var: str = PrivateAttr(_01ai_API_KEY_ENV_VAR_NAME)
-    _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...)
-
-    def load(self) -> None:
-        """Loads the `AsyncOpenAI` client to benefit from async requests."""
-        super().load()
-
-        try:
-            from openai import AsyncOpenAI
-        except ImportError as ie:
-            raise ImportError(
-                "OpenAI Python client is not installed. Please install it using"
-                " `pip install openai`."
-            ) from ie
-
-        if self.api_key is None:
-            raise ValueError(
-                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
-                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
-            )
-
-        self._aclient = AsyncOpenAI(
-            base_url=self.base_url,
-            api_key=self.api_key.get_secret_value(),
-            max_retries=self.max_retries,  # type: ignore
-            timeout=self.timeout,
-        )
-
-        if self.structured_output:
-            result = self._prepare_structured_output(
-                structured_output=self.structured_output,
-                client=self._aclient,
-                framework="openai",
-            )
-            self._aclient = result.get("client")  # type: ignore
-            if structured_output := result.get("structured_output"):
-                self.structured_output = structured_output
-
-    @property
-    def model_name(self) -> str:
-        """Returns the model name used for the LLM."""
-        return self.model
-
-    @validate_call
-    async def agenerate(  # type: ignore
-        self,
-        input: FormattedInput,
-        num_generations: int = 1,
-        max_new_tokens: int = 128,
-        frequency_penalty: float = 0.0,
-        presence_penalty: float = 0.0,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        stop: Optional[Union[str, List[str]]] = None,
-        response_format: Optional[str] = None,
-    ) -> GenerateOutput:
-        """Generates `num_generations` responses for the given input using the OpenAI async
-        client.
-
-        Args:
-            input: a single input in chat format to generate responses for.
-            num_generations: the number of generations to create per input. Defaults to
-                `1`.
-            max_new_tokens: the maximum number of new tokens that the model will generate.
-                Defaults to `128`.
-            frequency_penalty: the repetition penalty to use for the generation. Defaults
-                to `0.0`.
-            presence_penalty: the presence penalty to use for the generation. Defaults to
-                `0.0`.
-            temperature: the temperature to use for the generation. Defaults to `0.1`.
-            top_p: the top-p value to use for the generation. Defaults to `1.0`.
-            stop: a string or a list of strings to use as a stop sequence for the generation.
-                Defaults to `None`.
-            response_format: the format of the response to return. Must be one of
-                "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
-                for more information on how to use the JSON model from OpenAI. Defaults to `text`.
-
-        Note:
-            If response_format
-
-        Returns:
-            A list of lists of strings containing the generated responses for each input.
-        """
-
-        structured_output = None
-        if isinstance(input, tuple):
-            input, structured_output = input
-            result = self._prepare_structured_output(
-                structured_output=structured_output,
-                client=self._aclient,
-                framework="openai",
-            )
-            self._aclient = result.get("client")
-
-        if structured_output is None and self.structured_output is not None:
-            structured_output = self.structured_output
-
-        kwargs = {
-            "messages": input,  # type: ignore
-            "model": self.model,
-            "max_tokens": max_new_tokens,
-            "n": num_generations,
-            "frequency_penalty": frequency_penalty,
-            "presence_penalty": presence_penalty,
-            "temperature": temperature,
-            "top_p": top_p,
-            "stop": stop,
-            "timeout": 50,
-        }
-
-        if response_format is not None:
-            if response_format not in ["text", "json", "json_object"]:
-                raise ValueError(
-                    f"Invalid response format '{response_format}'. Must be either 'text'"
-                    " or 'json'."
-                )
-
-            if response_format == "json":
-                response_format = "json_object"
-
-            kwargs["response_format"] = response_format
-
-        if structured_output:
-            kwargs = self._prepare_kwargs(kwargs, structured_output)
-
-        generations = []
-        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore
-
-        if structured_output:
-            generations.append(completion.model_dump_json())
-            return generations
-
-        for choice in completion.choices:
-            if (content := choice.message.content) is None:
-                self._logger.warning(  # type: ignore
-                    f"Received no response using OpenAI client (model: '{self.model}')."
-                    f" Finish reason was: {choice.finish_reason}"
-                )
-            generations.append(content)
-        return generations
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import TYPE_CHECKING, List, Optional, Union
+
+from pydantic import Field, PrivateAttr, SecretStr, validate_call
+
+from distilabel.llms.base import AsyncLLM
+from distilabel.llms.typing import GenerateOutput
+from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
+
+if TYPE_CHECKING:
+    from openai import AsyncOpenAI
+
+_ONEAI_API_KEY_ENV_VAR_NAME = "01AI_API_KEY"
+
+
+class OneAI(AsyncLLM):
+    """OneAI LLM implementation running the async API client of OpenAI.
+
+    Attributes:
+        model: the model name to use for the LLM, e.g., `google/gemma-7b-it`.
+        base_url: the base URL to use for the OneAI API requests. Defaults to `None`,
+            which means that the value set for the environment variable `01AI_BASE_URL`
+            will be used, or "https://api.01.ai/v1/chat/completions" if not set.
+        api_key: the API key to authenticate the requests to the OneAI API. Defaults to
+            `None`, which means that the value set for the environment variable
+            `01AI_API_KEY` will be used, or `None` if not set.
+        max_retries: the maximum number of times to retry the request to the API before failing.
+        timeout: the maximum time in seconds to wait for a response from the API.
+        structured_output: the structured output format to use across all the generations.
+        _api_key_env_var: the name of the environment variable to use for the API key.
+            It is meant to be used internally.
+
+    Examples:
+
+        Generate text:
+
+        ```python
+        from distilabel.llms import OneAI
+
+        llm = OneAI(model="google/gemma-7b-it", api_key="api.key")
+
+        llm.load()
+
+        output = llm.generate(inputs=[[{"role": "user", "content": "Hello world!"}]])
+        ```
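+
+        Generate structured data (a sketch, not part of the original example: it
+        assumes `structured_output={"schema": ...}` behaves here as it does for
+        distilabel's other instructor-backed LLMs):
+
+        ```python
+        from pydantic import BaseModel
+        from distilabel.llms import OneAI
+
+        class User(BaseModel):
+            name: str
+            last_name: str
+
+        llm = OneAI(
+            model="google/gemma-7b-it",
+            api_key="api.key",
+            # hypothetical: mirrors the structured output config of other distilabel LLMs
+            structured_output={"schema": User},
+        )
+
+        llm.load()
+
+        output = llm.generate(inputs=[[{"role": "user", "content": "Create a user profile."}]])
+        ```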
+    """
+
+    model: str
+    base_url: Optional[RuntimeParameter[str]] = Field(
+        default_factory=lambda: os.getenv(
+            "01AI_BASE_URL", "https://api.01.ai/v1/chat/completions"
+        ),
+        description="The base URL to use for the OneAI API requests.",
+    )
+    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
+        default_factory=lambda: os.getenv(_ONEAI_API_KEY_ENV_VAR_NAME),
+        description="The API key to authenticate the requests to the OneAI API.",
+    )
+    max_retries: RuntimeParameter[int] = Field(
+        default=6,
+        description="The maximum number of times to retry the request to the API before failing.",
+    )
+    timeout: RuntimeParameter[int] = Field(
+        default=120,
+        description="The maximum time in seconds to wait for a response from the API.",
+    )
+    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = Field(
+        default=None,
+        description="The structured output format to use across all the generations.",
+    )
+
+    _api_key_env_var: str = PrivateAttr(_ONEAI_API_KEY_ENV_VAR_NAME)
+    _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...)
+
+    def load(self) -> None:
+        """Loads the `AsyncOpenAI` client to benefit from async requests."""
+        super().load()
+        try:
+            from openai import AsyncOpenAI
+        except ImportError as ie:
+            raise ImportError(
+                "OpenAI Python client is not installed. Please install it using `pip install openai`."
+            ) from ie
+
+        if self.api_key is None:
+            raise ValueError(
+                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
+                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
+            )
+
+        self._aclient = AsyncOpenAI(
+            base_url=self.base_url,
+            api_key=self.api_key.get_secret_value(),
+            max_retries=self.max_retries,
+            timeout=self.timeout,
+        )
+
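+        # Assumption (based on distilabel's other OpenAI-compatible LLMs): when
+        # structured output is configured, `_prepare_structured_output` patches the
+        # client with `instructor` so completions are parsed into the given schema.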
+        if self.structured_output:
+            result = self._prepare_structured_output(
+                structured_output=self.structured_output,
+                client=self._aclient,
+                framework="openai",
+            )
+            self._aclient = result.get("client")
+            if structured_output := result.get("structured_output"):
+                self.structured_output = structured_output
+
+    @property
+    def model_name(self) -> str:
+        """Returns the model name used for the LLM."""
+        return self.model
+
+    @validate_call
+    async def agenerate(
+        self,
+        input: FormattedInput,
+        num_generations: int = 1,
+        max_new_tokens: int = 128,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        stop: Optional[Union[str, List[str]]] = None,
+        response_format: Optional[str] = None,
+    ) -> GenerateOutput:
+        """Generates text using the OneAI LLM."""
+
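+        # `FormattedInput` may be a plain chat, or a `(chat, structured_output)` tuple
+        # carrying a per-input structured output spec that takes precedence over
+        # `self.structured_output`.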
+        structured_output = None
+        if isinstance(input, tuple):
+            input, structured_output = input
+            result = self._prepare_structured_output(
+                structured_output=structured_output,
+                client=self._aclient,
+                framework="openai",
+            )
+            self._aclient = result.get("client")
+        if structured_output is None and self.structured_output is not None:
+            structured_output = self.structured_output
+
+        kwargs = {
+            "messages": input,
+            "model": self.model,
+            "max_tokens": max_new_tokens,
+            "n": num_generations,
+            "frequency_penalty": frequency_penalty,
+            "presence_penalty": presence_penalty,
+            "temperature": temperature,
+            "top_p": top_p,
+            "stop": stop,
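+            # Per-request timeout fixed at 50s, independent of the client-level
+            # `timeout` runtime parameter passed to `AsyncOpenAI` in `load()`.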
+            "timeout": 50,
+        }
+
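+        # "json" is accepted as an alias and normalised below to the API's
+        # "json_object" value.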
+        if response_format is not None:
+            if response_format not in ["text", "json", "json_object"]:
+                raise ValueError(
+                    f"Invalid response format '{response_format}'. Must be one of"
+                    " 'text', 'json' or 'json_object'."
+                )
+            if response_format == "json":
+                response_format = "json_object"
+            kwargs["response_format"] = response_format
+
+        if structured_output:
+            kwargs = self._prepare_kwargs(kwargs, structured_output)
+
+        generations = []
+        completion = await self._aclient.chat.completions.create(**kwargs)
+        if structured_output:
+            generations.append(completion.model_dump_json())
+            return generations
+
+        for choice in completion.choices:
+            if (content := choice.message.content) is None:
+                self._logger.warning(
+                    f"Received no response using OpenAI client (model: '{self.model}')."
+                    f" Finish reason was: {choice.finish_reason}"
+                )
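+            # `content` can be `None` when generation stopped without output; it is
+            # appended as-is so callers can detect and handle missing generations.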
+            generations.append(content)
+        return generations