Skip to content

Commit 364dfb1

Browse files
authored
fix: from config for LiteLLM (#454)
1 parent a3b4ee4 commit 364dfb1

File tree

25 files changed

+141
-80
lines changed

25 files changed

+141
-80
lines changed

packages/ragbits-conversations/src/ragbits/conversations/history/stores/sql.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ragbits.conversations.history.stores.base import HistoryStore
1111
from ragbits.core.options import Options
1212
from ragbits.core.prompt import ChatFormat
13-
from ragbits.core.utils.config_handling import ObjectContructionConfig
13+
from ragbits.core.utils.config_handling import ObjectConstructionConfig
1414

1515

1616
class _Base(DeclarativeBase):
@@ -190,6 +190,6 @@ def from_config(cls, config: dict) -> Self:
190190
Returns:
191191
An instance of the class initialized with the provided configuration.
192192
"""
193-
engine_options = ObjectContructionConfig.model_validate(config["sqlalchemy_engine"])
193+
engine_options = ObjectConstructionConfig.model_validate(config["sqlalchemy_engine"])
194194
config["sqlalchemy_engine"] = create_async_engine(engine_options.config["url"])
195195
return cls(**config)

packages/ragbits-core/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
## 0.12.0 (2025-03-25)
66
- Allow Prompt class to accept the asynchronous response_parser. Change the signature of parse_response method.
7-
7+
- Fix from_config for LiteLLM class (#441)
88
- Fix Qdrant vector store serialization (#419)
99

1010
## 0.11.0 (2025-03-25)

packages/ragbits-core/src/ragbits/core/llms/litellm.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from collections.abc import AsyncGenerator
2+
from typing import Any
23

34
import litellm
45
from litellm.utils import CustomStreamWrapper, ModelResponse
56
from pydantic import BaseModel
7+
from typing_extensions import Self
68

79
from ragbits.core.audit import trace
810
from ragbits.core.llms.base import LLM
@@ -239,3 +241,19 @@ def _get_response_format(
239241
elif json_mode:
240242
response_format = {"type": "json_object"}
241243
return response_format
244+
245+
@classmethod
246+
def from_config(cls, config: dict[str, Any]) -> Self:
247+
"""
248+
Creates and returns a LiteLLM instance.
249+
250+
Args:
251+
config: A configuration object containing the configuration for initializing the LiteLLM instance.
252+
253+
Returns:
254+
LiteLLM: An initialized LiteLLM instance.
255+
"""
256+
if "router" in config:
257+
router = litellm.router.Router(model_list=config["router"])
258+
config["router"] = router
259+
return super().from_config(config)

packages/ragbits-core/src/ragbits/core/utils/config_handling.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class NoPreferredConfigError(InvalidConfigError):
3131
def import_by_path(path: str, default_module: ModuleType | None = None) -> Any: # noqa: ANN401
3232
"""
3333
Retrieves and returns an object based on the string in the format of "module.submodule:object_name".
34-
If the first part is ommited, the default module is used.
34+
If the first part is omitted, the default module is used.
3535
3636
Args:
3737
path: A string representing the path to the object. This can either be a
@@ -63,7 +63,7 @@ def import_by_path(path: str, default_module: ModuleType | None = None) -> Any:
6363
raise InvalidConfigError(f"{path} not found in module {default_module}") from err
6464

6565

66-
class ObjectContructionConfig(BaseModel):
66+
class ObjectConstructionConfig(BaseModel):
6767
"""
6868
A model for object construction configuration.
6969
"""
@@ -87,7 +87,7 @@ class WithConstructionConfig(abc.ABC):
8787
configuration_key: ClassVar[str]
8888

8989
@classmethod
90-
def subclass_from_config(cls, config: ObjectContructionConfig) -> Self:
90+
def subclass_from_config(cls, config: ObjectConstructionConfig) -> Self:
9191
"""
9292
Initializes the class with the provided configuration. May return a subclass of the class,
9393
if requested by the configuration.
@@ -151,7 +151,7 @@ def preferred_subclass(
151151
if yaml_path_override:
152152
preferences = get_config_from_yaml(yaml_path_override)
153153
if type_config := preferences.get(cls.configuration_key):
154-
return cls.subclass_from_config(ObjectContructionConfig.model_validate(type_config))
154+
return cls.subclass_from_config(ObjectConstructionConfig.model_validate(type_config))
155155

156156
if factory_path_override:
157157
return cls.subclass_from_factory(factory_path_override)
@@ -160,7 +160,7 @@ def preferred_subclass(
160160
return cls.subclass_from_factory(preferred_factory)
161161

162162
if preferred_config := config.preferred_instances_config.get(cls.configuration_key):
163-
return cls.subclass_from_config(ObjectContructionConfig.model_validate(preferred_config))
163+
return cls.subclass_from_config(ObjectConstructionConfig.model_validate(preferred_config))
164164

165165
raise NoPreferredConfigError(f"Could not find preferred factory or configuration for {cls.configuration_key}")
166166

@@ -195,7 +195,7 @@ def __init__(self, default_options: OptionsT | None = None) -> None:
195195
self.default_options: OptionsT = default_options or self.options_cls()
196196

197197
@classmethod
198-
def from_config(cls, config: dict[str, Any]) -> ConfigurableComponent:
198+
def from_config(cls, config: dict[str, Any]) -> Self:
199199
"""
200200
Initializes the class with the provided configuration.
201201

packages/ragbits-core/src/ragbits/core/vector_stores/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ragbits.core import vector_stores
1111
from ragbits.core.embeddings.base import Embedder
1212
from ragbits.core.options import Options
13-
from ragbits.core.utils.config_handling import ConfigurableComponent, ObjectContructionConfig
13+
from ragbits.core.utils.config_handling import ConfigurableComponent, ObjectConstructionConfig
1414
from ragbits.core.utils.pydantic import SerializableBytes
1515

1616
WhereQuery = dict[str, str | int | float | bool]
@@ -196,6 +196,6 @@ def from_config(cls, config: dict) -> Self:
196196
options = cls.options_cls(**default_options) if default_options else None
197197

198198
embedder_config = config.pop("embedder")
199-
embedder: Embedder = Embedder.subclass_from_config(ObjectContructionConfig.model_validate(embedder_config))
199+
embedder: Embedder = Embedder.subclass_from_config(ObjectConstructionConfig.model_validate(embedder_config))
200200

201201
return cls(**config, default_options=options, embedder=embedder)

packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from ragbits.core.audit import trace
1010
from ragbits.core.embeddings.base import Embedder
11-
from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path
11+
from ragbits.core.utils.config_handling import ObjectConstructionConfig, import_by_path
1212
from ragbits.core.utils.dict_transformations import flatten_dict, unflatten_dict
1313
from ragbits.core.vector_stores.base import (
1414
EmbeddingType,
@@ -78,7 +78,7 @@ def from_config(cls, config: dict) -> Self:
7878
Returns:
7979
An instance of the class initialized with the provided configuration.
8080
"""
81-
client_options = ObjectContructionConfig.model_validate(config["client"])
81+
client_options = ObjectConstructionConfig.model_validate(config["client"])
8282
client_cls = import_by_path(client_options.type, chromadb)
8383
config["client"] = client_cls(**client_options.config)
8484
return super().from_config(config)

packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from ragbits.core.audit import trace
1313
from ragbits.core.embeddings.base import Embedder
14-
from ragbits.core.utils.config_handling import ObjectContructionConfig, import_by_path
14+
from ragbits.core.utils.config_handling import ObjectConstructionConfig, import_by_path
1515
from ragbits.core.utils.dict_transformations import flatten_dict
1616
from ragbits.core.vector_stores.base import (
1717
EmbeddingType,
@@ -105,7 +105,7 @@ def from_config(cls, config: dict) -> Self:
105105
Returns:
106106
An instance of the class initialized with the provided configuration.
107107
"""
108-
client_options = ObjectContructionConfig.model_validate(config["client"])
108+
client_options = ObjectConstructionConfig.model_validate(config["client"])
109109
client_cls = import_by_path(client_options.type, qdrant_client)
110110
if "limits" in client_options.config:
111111
limits = httpx.Limits(**client_options.config["limits"])

packages/ragbits-core/tests/unit/embeddings/test_from_config.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
from ragbits.core.embeddings.litellm import LiteLLMEmbedder, LiteLLMEmbedderOptions
33
from ragbits.core.embeddings.sparse import BagOfTokens, BagOfTokensOptions, SparseEmbedder
44
from ragbits.core.types import NOT_GIVEN
5-
from ragbits.core.utils.config_handling import ObjectContructionConfig
5+
from ragbits.core.utils.config_handling import ObjectConstructionConfig
66

77

88
def test_subclass_from_config_litellm():
9-
config = ObjectContructionConfig.model_validate(
9+
config = ObjectConstructionConfig.model_validate(
1010
{
1111
"type": "ragbits.core.embeddings.litellm:LiteLLMEmbedder",
1212
"config": {
@@ -32,13 +32,13 @@ def test_subclass_from_config_litellm():
3232

3333

3434
def test_subclass_from_config_default_path_litellm():
35-
config = ObjectContructionConfig.model_validate({"type": "NoopEmbedder"})
35+
config = ObjectConstructionConfig.model_validate({"type": "NoopEmbedder"})
3636
embedder: Embedder = Embedder.subclass_from_config(config)
3737
assert isinstance(embedder, NoopEmbedder)
3838

3939

4040
def test_subclass_from_config_bag_of_tokens():
41-
config = ObjectContructionConfig.model_validate(
41+
config = ObjectConstructionConfig.model_validate(
4242
{
4343
"type": "ragbits.core.embeddings.sparse:BagOfTokens",
4444
"config": {

packages/ragbits-core/tests/unit/llms/test_from_config.py

+44-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
import litellm
2+
13
from ragbits.core.llms import LLM
24
from ragbits.core.llms.litellm import LiteLLM, LiteLLMOptions
3-
from ragbits.core.utils.config_handling import ObjectContructionConfig
5+
from ragbits.core.utils.config_handling import ObjectConstructionConfig
46

57

68
def test_subclass_from_config():
7-
config = ObjectContructionConfig.model_validate(
9+
config = ObjectConstructionConfig.model_validate(
810
{
911
"type": "ragbits.core.llms.litellm:LiteLLM",
1012
"config": {
@@ -27,6 +29,45 @@ def test_subclass_from_config():
2729

2830

2931
def test_subclass_from_config_default_path():
30-
config = ObjectContructionConfig.model_validate({"type": "LiteLLM"})
32+
config = ObjectConstructionConfig.model_validate({"type": "LiteLLM"})
33+
llm: LLM = LLM.subclass_from_config(config)
34+
assert isinstance(llm, LiteLLM)
35+
36+
37+
def test_from_config_with_router():
38+
config = ObjectConstructionConfig(
39+
type="ragbits.core.llms.litellm:LiteLLM",
40+
config={
41+
"model_name": "gpt-4-turbo",
42+
"api_key": "test_api_key",
43+
"router": [
44+
{
45+
"model_name": "gpt-4o",
46+
"litellm_params": {
47+
"model": "azure/gpt-4o-eval-1",
48+
"api_key": "test_api_key",
49+
"api_version": "2024-07-19-test",
50+
"api_base": "https://test-api.openai.azure.com",
51+
},
52+
},
53+
{
54+
"model_name": "gpt-4o",
55+
"litellm_params": {
56+
"model": "azure/gpt-4o-eval-2",
57+
"api_key": "test_api_key",
58+
"api_version": "2024-07-19-test",
59+
"api_base": "https://test-api.openai.azure.com",
60+
},
61+
},
62+
],
63+
},
64+
)
65+
3166
llm: LLM = LLM.subclass_from_config(config)
3267
assert isinstance(llm, LiteLLM)
68+
assert llm.base_url is None
69+
assert llm.model_name == "gpt-4-turbo"
70+
assert llm.api_key == "test_api_key"
71+
assert isinstance(llm.router, litellm.router.Router)
72+
assert len(llm.router.model_list) == 2
73+
assert llm.router.model_list[0]["model_name"] == "gpt-4o"

packages/ragbits-core/tests/unit/utils/test_config_handling.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ragbits.core.config import CoreConfig, core_config
77
from ragbits.core.utils._pyproject import get_config_instance
8-
from ragbits.core.utils.config_handling import InvalidConfigError, ObjectContructionConfig, WithConstructionConfig
8+
from ragbits.core.utils.config_handling import InvalidConfigError, ObjectConstructionConfig, WithConstructionConfig
99

1010
projects_dir = Path(__file__).parent / "testprojects"
1111

@@ -40,7 +40,7 @@ def test_default_from_config():
4040

4141

4242
def test_subclass_from_config():
43-
config = ObjectContructionConfig.model_validate(
43+
config = ObjectConstructionConfig.model_validate(
4444
{
4545
"type": "ExampleSubclass",
4646
"config": {"foo": "foo", "bar": 1},
@@ -53,7 +53,7 @@ def test_subclass_from_config():
5353

5454

5555
def test_incorrect_subclass_from_config():
56-
config = ObjectContructionConfig.model_validate(
56+
config = ObjectConstructionConfig.model_validate(
5757
{
5858
"type": "ExampleWithNoDefaultModule", # Not a subclass of ExampleClassWithConfigMixin
5959
"config": {"foo": "foo", "bar": 1},
@@ -64,7 +64,7 @@ def test_incorrect_subclass_from_config():
6464

6565

6666
def test_no_default_module():
67-
config = ObjectContructionConfig.model_validate(
67+
config = ObjectConstructionConfig.model_validate(
6868
{
6969
"type": "ExampleWithNoDefaultModule",
7070
"config": {"foo": "foo", "bar": 1},

packages/ragbits-core/tests/unit/vector_stores/test_from_config.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from qdrant_client import AsyncQdrantClient
44
from qdrant_client.local.async_qdrant_local import AsyncQdrantLocal
55

6-
from ragbits.core.utils.config_handling import ObjectContructionConfig
6+
from ragbits.core.utils.config_handling import ObjectConstructionConfig
77
from ragbits.core.vector_stores.base import VectorStore, VectorStoreOptions
88
from ragbits.core.vector_stores.chroma import ChromaVectorStore
99
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
1010
from ragbits.core.vector_stores.qdrant import QdrantVectorStore
1111

1212

1313
def test_subclass_from_config():
14-
config = ObjectContructionConfig.model_validate(
14+
config = ObjectConstructionConfig.model_validate(
1515
{
1616
"type": "ragbits.core.vector_stores:InMemoryVectorStore",
1717
"config": {
@@ -33,7 +33,7 @@ def test_subclass_from_config():
3333

3434

3535
def test_subclass_from_config_default_path():
36-
config = ObjectContructionConfig.model_validate(
36+
config = ObjectConstructionConfig.model_validate(
3737
{
3838
"type": "InMemoryVectorStore",
3939
"config": {
@@ -46,7 +46,7 @@ def test_subclass_from_config_default_path():
4646

4747

4848
def test_subclass_from_config_chroma_client():
49-
config = ObjectContructionConfig.model_validate(
49+
config = ObjectConstructionConfig.model_validate(
5050
{
5151
"type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
5252
"config": {
@@ -69,7 +69,7 @@ def test_subclass_from_config_chroma_client():
6969

7070

7171
def test_subclass_from_config_qdrant_client():
72-
config = ObjectContructionConfig.model_validate(
72+
config = ObjectConstructionConfig.model_validate(
7373
{
7474
"type": "ragbits.core.vector_stores.qdrant:QdrantVectorStore",
7575
"config": {

packages/ragbits-document-search/src/ragbits/document_search/_main.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ragbits.core.utils._pyproject import get_config_from_yaml
1313
from ragbits.core.utils.config_handling import (
1414
NoPreferredConfigError,
15-
ObjectContructionConfig,
15+
ObjectConstructionConfig,
1616
WithConstructionConfig,
1717
)
1818
from ragbits.core.vector_stores import VectorStore
@@ -49,12 +49,12 @@ class DocumentSearchConfig(BaseModel):
4949
Schema for the dict taken by DocumentSearch.from_config method.
5050
"""
5151

52-
vector_store: ObjectContructionConfig
53-
rephraser: ObjectContructionConfig = ObjectContructionConfig(type="NoopQueryRephraser")
54-
reranker: ObjectContructionConfig = ObjectContructionConfig(type="NoopReranker")
55-
ingest_strategy: ObjectContructionConfig = ObjectContructionConfig(type="SequentialIngestStrategy")
56-
parser_router: dict[str, ObjectContructionConfig] = {}
57-
enricher_router: dict[str, ObjectContructionConfig] = {}
52+
vector_store: ObjectConstructionConfig
53+
rephraser: ObjectConstructionConfig = ObjectConstructionConfig(type="NoopQueryRephraser")
54+
reranker: ObjectConstructionConfig = ObjectConstructionConfig(type="NoopReranker")
55+
ingest_strategy: ObjectConstructionConfig = ObjectConstructionConfig(type="SequentialIngestStrategy")
56+
parser_router: dict[str, ObjectConstructionConfig] = {}
57+
enricher_router: dict[str, ObjectConstructionConfig] = {}
5858

5959

6060
class DocumentSearch(WithConstructionConfig):
@@ -158,7 +158,7 @@ def preferred_subclass(
158158

159159
# Look for explicit document search configuration
160160
if type_config := preferences.get(cls.configuration_key):
161-
return cls.subclass_from_config(ObjectContructionConfig.model_validate(type_config))
161+
return cls.subclass_from_config(ObjectConstructionConfig.model_validate(type_config))
162162

163163
# Instantiate the class with the preferred configuration for each component
164164
return cls.from_config(preferences)
@@ -172,7 +172,7 @@ def preferred_subclass(
172172
if config.component_preference_config_path is not None:
173173
# Look for explicit document search configuration
174174
if preferred_config := config.preferred_instances_config.get(cls.configuration_key):
175-
return cls.subclass_from_config(ObjectContructionConfig.model_validate(preferred_config))
175+
return cls.subclass_from_config(ObjectConstructionConfig.model_validate(preferred_config))
176176

177177
# Instantiate the class with the preferred configuration for each component
178178
return cls.from_config(config.preferred_instances_config)

0 commit comments

Comments
 (0)