|
| 1 | +import pickle |
| 2 | +from unittest.mock import patch |
| 3 | + |
1 | 4 | import pytest
|
| 5 | +from litellm import Router |
2 | 6 | from pydantic import BaseModel
|
3 | 7 |
|
4 | 8 | from ragbits.core.llms.exceptions import LLMNotSupportingImagesError
|
@@ -176,3 +180,85 @@ async def test_generation_without_image_support():
|
176 | 180 | prompt = MockPromptWithImage("Hello, what is on this image?")
|
177 | 181 | with pytest.raises(LLMNotSupportingImagesError):
|
178 | 182 | await llm.generate(prompt)
|
| 183 | + |
| 184 | + |
async def test_pickling():
    """Round-trip a fully configured LiteLLM instance through pickle.

    Verifies that every constructor argument — model name, default options,
    custom cost config, structured-output flag, router, and API settings —
    survives serialization and deserialization intact.
    """
    cost_config = {
        "gpt-3.5-turbo": {
            "support_vision": True,
        }
    }
    llm = LiteLLM(
        model_name="gpt-3.5-turbo",
        default_options=LiteLLMOptions(mock_response="I'm fine, thank you."),
        custom_model_cost_config=cost_config,
        use_structured_output=True,
        router=Router(),
        base_url="https://api.litellm.ai",
        api_key="test_key",
        api_version="v1",
    )

    restored = pickle.loads(pickle.dumps(llm))  # noqa: S301

    assert restored.model_name == "gpt-3.5-turbo"
    assert restored.default_options.mock_response == "I'm fine, thank you."
    assert restored.custom_model_cost_config == cost_config
    assert restored.use_structured_output
    assert restored.router.model_list == []
    assert restored.base_url == "https://api.litellm.ai"
    assert restored.api_key == "test_key"
    assert restored.api_version == "v1"
| 214 | + |
| 215 | + |
async def test_init_registers_model_with_custom_cost_config():
    """Constructing a LiteLLM with a custom cost config must forward it to litellm.

    The constructor is expected to call ``litellm.register_model`` exactly once,
    passing the cost config through unchanged.
    """
    custom_config = {
        "some_model": {
            "support_vision": True,
            "input_cost_per_token": 0.0015,
            "output_cost_per_token": 0.002,
            "max_tokens": 4096,
        }
    }

    with patch("litellm.register_model") as register_mock:
        # Instantiation alone should trigger the registration side effect.
        LiteLLM(model_name="some_model", custom_model_cost_config=custom_config)

    # The mock retains its call record after the context exits.
    register_mock.assert_called_once_with(custom_config)
| 236 | + |
| 237 | + |
async def test_init_does_not_register_model_if_no_cost_config_is_provided():
    """Without a custom cost config, __init__ must not call litellm.register_model."""
    with patch("litellm.register_model") as register_mock:
        LiteLLM(model_name="some_model")

    register_mock.assert_not_called()
| 245 | + |
| 246 | + |
async def test_pickling_registers_model_with_custom_cost_config():
    """Unpickling a LiteLLM with a custom cost config must re-register the model.

    Deserialization should replay the ``litellm.register_model`` side effect so
    a restored instance behaves like a freshly constructed one.
    """
    custom_config = {
        "some_model": {
            "support_vision": True,
            "input_cost_per_token": 0.0015,
            "output_cost_per_token": 0.002,
            "max_tokens": 4096,
        }
    }
    llm = LiteLLM(model_name="some_model", custom_model_cost_config=custom_config)

    # Patch around the full dump/load round trip so only the re-registration
    # performed during unpickling is captured.
    with patch("litellm.register_model") as register_mock:
        restored = pickle.loads(pickle.dumps(llm))  # noqa: S301
        assert restored.custom_model_cost_config == custom_config
        register_mock.assert_called_once_with(custom_config)
0 commit comments