Skip to content

Commit b4c064b

Browse files
authored
feat: add router option to LiteLLMEmbedder (#457)
1 parent bd025d0 commit b4c064b

File tree

3 files changed

+67
-1
lines changed

3 files changed

+67
-1
lines changed

packages/ragbits-core/CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Add router option to LiteLLMEmbedder (#440)
6+
57
## 0.12.0 (2025-03-25)
68
- Allow Prompt class to accept the asynchronous response_parser. Change the signature of parse_response method.
79
- Fix from_config for LiteLLM class (#441)

packages/ragbits-core/src/ragbits/core/embeddings/litellm.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from typing import Any
2+
13
import litellm
4+
from typing_extensions import Self
25

36
from ragbits.core.audit import trace
47
from ragbits.core.embeddings import Embedder
@@ -38,6 +41,7 @@ def __init__(
3841
api_base: str | None = None,
3942
api_key: str | None = None,
4043
api_version: str | None = None,
44+
router: litellm.Router | None = None,
4145
) -> None:
4246
"""
4347
Constructs the LiteLLMEmbeddingClient.
@@ -51,12 +55,14 @@ def __init__(
5155
for more information, follow the instructions for your specific vendor in the\
5256
[LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
5357
api_version: The API version for the call.
58+
router: Router to be used to [route requests](https://docs.litellm.ai/docs/routing) to different models.
5459
"""
5560
super().__init__(default_options=default_options)
5661
self.model = model
5762
self.api_base = api_base
5863
self.api_key = api_key
5964
self.api_version = api_version
65+
self.router = router
6066

6167
async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | None = None) -> list[list[float]]:
6268
"""
@@ -85,7 +91,8 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | No
8591
options=merged_options.dict(),
8692
) as outputs:
8793
try:
88-
response = await litellm.aembedding(
94+
entrypoint = self.router or litellm
95+
response = await entrypoint.aembedding(
8996
input=data,
9097
model=self.model,
9198
api_base=self.api_base,
@@ -110,3 +117,19 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | No
110117
outputs.total_tokens = response.usage.total_tokens
111118

112119
return outputs.embeddings
120+
121+
@classmethod
def from_config(cls, config: dict[str, Any]) -> Self:
    """
    Creates and returns a LiteLLMEmbedder instance.

    When the configuration contains a "router" entry, its value is passed as `model_list`
    to a LiteLLM Router, and the resulting Router object takes the place of the raw value
    before the configuration is handed to the parent `from_config`.

    Args:
        config: A dictionary with the arguments for initializing the LiteLLMEmbedder instance.
            This method does not modify the dictionary itself.

    Returns:
        LiteLLMEmbedder: An initialized LiteLLMEmbedder instance.
    """
    if "router" in config:
        # Build a shallow copy instead of assigning into `config`: the caller's dict keeps
        # its raw model list, so the same configuration can be reused for another call.
        config = {**config, "router": litellm.router.Router(model_list=config["router"])}
    return super().from_config(config)

packages/ragbits-core/tests/unit/embeddings/test_from_config.py

+41
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import litellm
2+
13
from ragbits.core.embeddings import Embedder, NoopEmbedder
24
from ragbits.core.embeddings.litellm import LiteLLMEmbedder, LiteLLMEmbedderOptions
35
from ragbits.core.embeddings.sparse import BagOfTokens, BagOfTokensOptions, SparseEmbedder
@@ -58,3 +60,42 @@ def test_subclass_from_config_bag_of_tokens():
5860
option1="value1",
5961
option2="value2",
6062
) # type: ignore
63+
64+
65+
def test_from_config_with_router():
    """Check that a "router" model list in the config ends up as a litellm Router on the embedder."""
    # Two router entries; the first one also carries an explicit `dimensions` parameter.
    small_entry = {
        "model_name": "small",
        "litellm_params": {
            "model": "text-embedding-3-small",
            "dimensions": 3000,
            "api_key": "test_api_key",
        },
    }
    large_entry = {
        "model_name": "large",
        "litellm_params": {
            "model": "text-embedding-3-large",
            "api_key": "test_api_key",
        },
    }
    config = ObjectConstructionConfig(
        type="ragbits.core.embeddings.litellm:LiteLLMEmbedder",
        config={
            "model": "text-embedding-3-small",
            "api_key": "test_api_key",
            "router": [small_entry, large_entry],
        },
    )

    embedder: Embedder = Embedder.subclass_from_config(config)

    # Plain constructor arguments are passed through unchanged.
    assert isinstance(embedder, LiteLLMEmbedder)
    assert embedder.api_base is None
    assert embedder.model == "text-embedding-3-small"
    assert embedder.api_key == "test_api_key"

    # The raw list has been turned into a Router holding both entries, in order.
    router = embedder.router
    assert isinstance(router, litellm.router.Router)
    assert len(router.model_list) == 2
    first, second = router.model_list
    assert first["model_name"] == "small"
    assert first["litellm_params"]["dimensions"] == 3000
    assert second["model_name"] == "large"
    assert second["litellm_params"]["model"] == "text-embedding-3-large"

0 commit comments

Comments
 (0)