# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING, List, Optional, Union

from pydantic import Field, PrivateAttr, SecretStr, validate_call

from distilabel.llms.base import AsyncLLM
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType

if TYPE_CHECKING:
    from openai import AsyncOpenAI


_01ai_API_KEY_ENV_VAR_NAME = "01AI_API_KEY"


class OneAI(AsyncLLM):
    """01.AI LLM implementation running the async API client of OpenAI, since the
    01.AI API exposes an OpenAI-compatible chat completions endpoint.

    Attributes:
        model: the model name to use for the LLM, i.e. the name of one of the
            available Yi models.
        base_url: the base URL to use for the 01.AI API requests. Defaults to the
            value of the `01AI_BASE_URL` environment variable, or to the OpenAI
            endpoint when the variable is not set.
        api_key: the API key to authenticate the requests to the 01.AI API. Defaults
            to the value of the `01AI_API_KEY` environment variable.
        max_retries: the maximum number of times to retry the request to the API
            before failing. Defaults to `6`.
        timeout: the maximum time in seconds to wait for a response from the API.
            Defaults to `120`.
        structured_output: the structured output format to use across all the
            generations. Defaults to `None`.
    """

    model: str
    base_url: Optional[RuntimeParameter[str]] = Field(
        # NOTE: falls back to the OpenAI endpoint when `01AI_BASE_URL` is not set.
        default_factory=lambda: os.getenv("01AI_BASE_URL", "https://api.openai.com/v1"),
        description="The base URL to use for the 01.AI API requests.",
    )
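    # The key is read from the environment by default and stored as a `SecretStr`,
    # so it is masked if the field is printed or serialized.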
    api_key: Optional[RuntimeParameter[SecretStr]] = Field(
        default_factory=lambda: os.getenv(_01ai_API_KEY_ENV_VAR_NAME),
        description="The API key to authenticate the requests to the 01.AI API.",
    )
    max_retries: RuntimeParameter[int] = Field(
        default=6,
        description="The maximum number of times to retry the request to the API before"
        " failing.",
    )
    timeout: RuntimeParameter[int] = Field(
        default=120,
        description="The maximum time in seconds to wait for a response from the API.",
    )
    structured_output: Optional[RuntimeParameter[InstructorStructuredOutputType]] = (
        Field(
            default=None,
            description="The structured output format to use across all the generations.",
        )
    )

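    # Private attributes: the env var name referenced in error messages, and the
    # async client that is instantiated in `load`.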
    _api_key_env_var: str = PrivateAttr(_01ai_API_KEY_ENV_VAR_NAME)
    _aclient: Optional["AsyncOpenAI"] = PrivateAttr(...)

    def load(self) -> None:
        """Loads the `AsyncOpenAI` client to benefit from async requests."""
        super().load()

        try:
            from openai import AsyncOpenAI
        except ImportError as ie:
            raise ImportError(
                "OpenAI Python client is not installed. Please install it using"
                " `pip install openai`."
            ) from ie

        if self.api_key is None:
            raise ValueError(
                f"To use `{self.__class__.__name__}` an API key must be provided via `api_key`"
                f" attribute or runtime parameter, or set the environment variable `{self._api_key_env_var}`."
            )

        self._aclient = AsyncOpenAI(
            base_url=self.base_url,
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,  # type: ignore
            timeout=self.timeout,
        )

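        # When `structured_output` is set, `_prepare_structured_output` patches the
        # client (via the `instructor` library for the "openai" framework) so that
        # completions can be parsed into the configured schema.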
        if self.structured_output:
            result = self._prepare_structured_output(
                structured_output=self.structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")  # type: ignore
            if structured_output := result.get("structured_output"):
                self.structured_output = structured_output

    @property
    def model_name(self) -> str:
        """Returns the model name used for the LLM."""
        return self.model

    @validate_call
    async def agenerate(  # type: ignore
        self,
        input: FormattedInput,
        num_generations: int = 1,
        max_new_tokens: int = 128,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        stop: Optional[Union[str, List[str]]] = None,
        response_format: Optional[str] = None,
    ) -> GenerateOutput:
        """Generates `num_generations` responses for the given input using the OpenAI async
        client.

        Args:
            input: a single input in chat format to generate responses for.
            num_generations: the number of generations to create per input. Defaults to
                `1`.
            max_new_tokens: the maximum number of new tokens that the model will generate.
                Defaults to `128`.
            frequency_penalty: the repetition penalty to use for the generation. Defaults
                to `0.0`.
            presence_penalty: the presence penalty to use for the generation. Defaults to
                `0.0`.
            temperature: the temperature to use for the generation. Defaults to `1.0`.
            top_p: the top-p value to use for the generation. Defaults to `1.0`.
            stop: a string or a list of strings to use as a stop sequence for the generation.
                Defaults to `None`.
            response_format: the format of the response to return. Must be one of
                "text" or "json". Read the documentation [here](https://platform.openai.com/docs/guides/text-generation/json-mode)
                for more information on how to use OpenAI's JSON mode. Defaults to `None`,
                which is equivalent to "text".

        Note:
            If `response_format` is set to "json", it is mapped to OpenAI's
            "json_object" response format, and the input must instruct the model to
            produce JSON (e.g. via the system message), as JSON mode requires.

        Returns:
            A list of strings containing the generated responses for the input.
        """

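        # `input` may also be a `(messages, structured_output)` tuple, in which case
        # the per-input structured output takes precedence over the class-level
        # attribute.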
        structured_output = None
        if isinstance(input, tuple):
            input, structured_output = input
            result = self._prepare_structured_output(
                structured_output=structured_output,
                client=self._aclient,
                framework="openai",
            )
            self._aclient = result.get("client")

        if structured_output is None and self.structured_output is not None:
            structured_output = self.structured_output

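        # Arguments forwarded to the OpenAI chat completions endpoint; note that
        # `max_new_tokens` maps to the API's `max_tokens` parameter.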
        kwargs = {
            "messages": input,  # type: ignore
            "model": self.model,
            "max_tokens": max_new_tokens,
            "n": num_generations,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "temperature": temperature,
            "top_p": top_p,
            "stop": stop,
            "timeout": self.timeout,
        }

        if response_format is not None:
            if response_format not in ["text", "json", "json_object"]:
                raise ValueError(
                    f"Invalid response format '{response_format}'. Must be either 'text',"
                    " 'json' or 'json_object'."
                )

            if response_format == "json":
                response_format = "json_object"

            kwargs["response_format"] = response_format

        if structured_output:
            kwargs = self._prepare_kwargs(kwargs, structured_output)

        generations = []
        completion = await self._aclient.chat.completions.create(**kwargs)  # type: ignore

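        # With structured outputs the patched client returns a parsed model instance,
        # which is serialized back to a JSON string instead of reading the raw choices.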
        if structured_output:
            generations.append(completion.model_dump_json())
            return generations

        for choice in completion.choices:
            if (content := choice.message.content) is None:
                self._logger.warning(  # type: ignore
                    f"Received no response using OpenAI client (model: '{self.model}')."
                    f" Finish reason was: {choice.finish_reason}"
                )
            generations.append(content)
        return generations
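

# A minimal usage sketch, kept as a comment since this is library code. It assumes a
# valid `01AI_API_KEY` in the environment; "yi-large" is a placeholder model
# identifier, as the exact model names depend on the 01.AI account:
#
#     llm = OneAI(model="yi-large")
#     llm.load()
#     output = llm.generate(inputs=[[{"role": "user", "content": "Hello!"}]])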