Skip to content

Commit 1eafdfb

Browse files
feat: cost and capabilities config for custom litellm models (#481)
1 parent 01b1859 commit 1eafdfb

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
- Make the score in VectorStoreResult consistent (always bigger is better)
55
- Add router option to LiteLLMEmbedder (#440)
66
- Fix: make unflatten_dict symmetric to flatten_dict (#461)
7+
- Cost and capabilities config for custom litellm models (#481)
78

89
## 0.12.0 (2025-03-25)
910
- Allow Prompt class to accept the asynchronous response_parser. Change the signature of parse_response method.

packages/ragbits-core/src/ragbits/core/llms/litellm.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import AsyncGenerator
1+
from collections.abc import AsyncGenerator, Callable
22
from typing import Any
33

44
import litellm
@@ -57,6 +57,7 @@ def __init__(
5757
api_version: str | None = None,
5858
use_structured_output: bool = False,
5959
router: litellm.Router | None = None,
60+
custom_model_cost_config: dict | None = None,
6061
) -> None:
6162
"""
6263
Constructs a new LiteLLM instance.
@@ -74,13 +75,20 @@ def __init__(
7475
[structured output](https://docs.litellm.ai/docs/completion/json_mode#pass-in-json_schema)
7576
from the model. Default is False. Can only be combined with models that support structured output.
7677
router: Router to be used to [route requests](https://docs.litellm.ai/docs/routing) to different models.
78+
custom_model_cost_config: Custom cost and capabilities configuration for the model.
79+
Necessary for custom model cost and capabilities tracking in LiteLLM.
80+
See the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/token_usage#9-register_model)
81+
for more information.
7782
"""
7883
super().__init__(model_name, default_options)
7984
self.base_url = base_url
8085
self.api_key = api_key
8186
self.api_version = api_version
8287
self.use_structured_output = use_structured_output
8388
self.router = router
89+
self.custom_model_cost_config = custom_model_cost_config
90+
if custom_model_cost_config:
91+
litellm.register_model(custom_model_cost_config)
8492

8593
def count_tokens(self, prompt: BasePrompt) -> int:
8694
"""
@@ -257,3 +265,17 @@ def from_config(cls, config: dict[str, Any]) -> Self:
257265
router = litellm.router.Router(model_list=config["router"])
258266
config["router"] = router
259267
return super().from_config(config)
268+
269+
def __reduce__(self) -> tuple[Callable, tuple]:
    """Make instances picklable by rebuilding them through ``from_config``.

    A live litellm Router is not picklable, so only its ``model_list`` is
    captured; ``from_config`` constructs a fresh Router from that list when
    the object is restored.
    """
    state = {
        "model_name": self.model_name,
        "default_options": self.default_options.dict(),
        "base_url": self.base_url,
        "api_key": self.api_key,
        "api_version": self.api_version,
        "use_structured_output": self.use_structured_output,
        "custom_model_cost_config": self.custom_model_cost_config,
    }
    if self.router:
        # Persist just the declarative model list, not the Router object.
        state["router"] = self.router.model_list
    return self.from_config, (state,)

packages/ragbits-core/tests/unit/llms/test_litellm.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import pickle
2+
from unittest.mock import patch
3+
14
import pytest
5+
from litellm import Router
26
from pydantic import BaseModel
37

48
from ragbits.core.llms.exceptions import LLMNotSupportingImagesError
@@ -176,3 +180,85 @@ async def test_generation_without_image_support():
176180
prompt = MockPromptWithImage("Hello, what is on this image?")
177181
with pytest.raises(LLMNotSupportingImagesError):
178182
await llm.generate(prompt)
183+
184+
185+
async def test_pickling():
    """A LiteLLM instance survives a pickle round-trip with all settings intact."""
    cost_config = {
        "gpt-3.5-turbo": {
            "support_vision": True,
        }
    }
    original = LiteLLM(
        model_name="gpt-3.5-turbo",
        default_options=LiteLLMOptions(mock_response="I'm fine, thank you."),
        custom_model_cost_config=cost_config,
        use_structured_output=True,
        router=Router(),
        base_url="https://api.litellm.ai",
        api_key="test_key",
        api_version="v1",
    )

    restored = pickle.loads(pickle.dumps(original))  # noqa: S301

    assert restored.model_name == "gpt-3.5-turbo"
    assert restored.default_options.mock_response == "I'm fine, thank you."
    assert restored.custom_model_cost_config == cost_config
    assert restored.use_structured_output
    # The router is rebuilt from its (empty) model list, not pickled directly.
    assert restored.router.model_list == []
    assert restored.base_url == "https://api.litellm.ai"
    assert restored.api_key == "test_key"
    assert restored.api_version == "v1"
214+
215+
216+
async def test_init_registers_model_with_custom_cost_config():
    """Passing a cost config at construction time registers it with LiteLLM."""
    cost_config = {
        "some_model": {
            "support_vision": True,
            "input_cost_per_token": 0.0015,
            "output_cost_per_token": 0.002,
            "max_tokens": 4096,
        }
    }

    with patch("litellm.register_model") as register_mock:
        LiteLLM(
            model_name="some_model",
            custom_model_cost_config=cost_config,
        )

    # Construction must forward the exact config to litellm.register_model.
    register_mock.assert_called_once_with(cost_config)
236+
237+
238+
async def test_init_does_not_register_model_if_no_cost_config_is_provided():
    """Without a cost config, construction never touches litellm.register_model."""
    with patch("litellm.register_model") as register_mock:
        LiteLLM(model_name="some_model")

    register_mock.assert_not_called()
245+
246+
247+
async def test_pickling_registers_model_with_custom_cost_config():
    """Unpickling re-registers the custom cost config with LiteLLM."""
    cost_config = {
        "some_model": {
            "support_vision": True,
            "input_cost_per_token": 0.0015,
            "output_cost_per_token": 0.002,
            "max_tokens": 4096,
        }
    }
    llm = LiteLLM(
        model_name="some_model",
        custom_model_cost_config=cost_config,
    )

    # Registration happens again on load, because unpickling goes through __init__.
    with patch("litellm.register_model") as register_mock:
        restored = pickle.loads(pickle.dumps(llm))  # noqa: S301

    assert restored.custom_model_cost_config == cost_config
    register_mock.assert_called_once_with(cost_config)

0 commit comments

Comments
 (0)