Skip to content

Commit 2fd4404

Browse files
authored
refactor: make litellm and litellmembedder consistent (#463)
1 parent: 51c7781 · commit: 2fd4404

File tree

11 files changed

+66
-35
lines changed

11 files changed

+66
-35
lines changed

docs/api_reference/core/embeddings.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,7 @@
55
::: ragbits.core.embeddings.local.LocalEmbedder
66

77
::: ragbits.core.embeddings.litellm.LiteLLMEmbedder
8+
9+
::: ragbits.core.embeddings.fastembed.FastEmbedEmbedder
10+
11+
::: ragbits.core.embeddings.fastembed.FastEmbedSparseEmbedder

docs/how-to/llms/use_local_llms.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ from ragbits.core.prompt.base import SimplePrompt
7171

7272

7373
async def main() -> None:
74-
llm = LiteLLM(model_name="openai/local", api_key="<api_key>", base_url="http://127.0.0.1:8080")
74+
llm = LiteLLM(model_name="openai/local", api_key="<api_key>", api_base="http://127.0.0.1:8080")
7575
prompt = SimplePrompt("Tell me a joke about software developers.")
7676
response = await llm.generate(prompt)
7777
print(response)
@@ -99,7 +99,7 @@ from ragbits.core.prompt.base import SimplePrompt
9999

100100

101101
async def main() -> None:
102-
llm = LiteLLM(model_name="hosted_vllm/<model_name>", base_url="http://127.0.0.1:8000/v1")
102+
llm = LiteLLM(model_name="hosted_vllm/<model_name>", api_base="http://127.0.0.1:8000/v1")
103103
prompt = SimplePrompt("Tell me a joke about software developers.")
104104
response = await llm.generate(prompt)
105105
print(response)
@@ -123,7 +123,7 @@ from ragbits.core.embeddings.litellm import LiteLLMEmbedder
123123

124124

125125
async def main() -> None:
126-
embedder = LiteLLMEmbedder(model="hosted_vllm/<model_name>", api_base="http://127.0.0.1:8000/v1")
126+
embedder = LiteLLMEmbedder(model_name="hosted_vllm/<model_name>", api_base="http://127.0.0.1:8000/v1")
127127
embeddings = await embedder.embed_text(["Hello"])
128128
print(len(embeddings[0]))
129129

examples/document-search/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async def main() -> None:
6969
Run the example.
7070
"""
7171
embedder = LiteLLMEmbedder(
72-
model="text-embedding-3-small",
72+
model_name="text-embedding-3-small",
7373
)
7474
vector_store = InMemoryVectorStore(embedder=embedder)
7575
document_search = DocumentSearch(

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Unreleased
44
- Make the score in VectorStoreResult consistent (always bigger is better)
55
- Add router option to LiteLLMEmbedder (#440)
6+
- Make LLM / Embedder APIs consistent (#463)
67
- New methods in Prompt class for appending conversation history (#480)
78
- Fix: make unflatten_dict symmetric to flatten_dict (#461)
89
- Cost and capabilities config for custom litellm models (#481)

packages/ragbits-core/src/ragbits/core/embeddings/fastembed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def embed_text(self, data: list[str], options: EmbedderOptionsT | None = N
5757
"""
5858
merged_options = (self.default_options | options) if options else self.default_options
5959
with trace(
60-
data=data, model_name=self.model_name, model=repr(self._model), options=merged_options.dict()
60+
data=data, model_name=self.model_name, model_obj=repr(self._model), options=merged_options.dict()
6161
) as outputs:
6262
embeddings = [[float(x) for x in result] for result in self._model.embed(data, **merged_options.dict())]
6363
outputs.embeddings = embeddings
@@ -104,7 +104,7 @@ async def embed_text(self, data: list[str], options: EmbedderOptionsT | None = N
104104
"""
105105
merged_options = (self.default_options | options) if options else self.default_options
106106
with trace(
107-
data=data, model_name=self.model_name, model=repr(self._model), options=merged_options.dict()
107+
data=data, model_name=self.model_name, model_obj=repr(self._model), options=merged_options.dict()
108108
) as outputs:
109109
outputs.embeddings = [
110110
SparseVector(values=[float(x) for x in result.values], indices=[int(x) for x in result.indices])

packages/ragbits-core/src/ragbits/core/embeddings/litellm.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ class LiteLLMEmbedder(Embedder[LiteLLMEmbedderOptions]):
3636

3737
def __init__(
3838
self,
39-
model: str = "text-embedding-3-small",
39+
model_name: str = "text-embedding-3-small",
4040
default_options: LiteLLMEmbedderOptions | None = None,
41+
*,
4142
api_base: str | None = None,
43+
base_url: str | None = None, # Alias for api_base
4244
api_key: str | None = None,
4345
api_version: str | None = None,
4446
router: litellm.Router | None = None,
@@ -47,19 +49,21 @@ def __init__(
4749
Constructs the LiteLLMEmbeddingClient.
4850
4951
Args:
50-
model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
52+
model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
5153
to be used. Default is "text-embedding-3-small".
5254
default_options: Default options to pass to the LiteLLM API.
5355
api_base: The API endpoint you want to call the model with.
54-
api_key: API key to be used. API key to be used. If not specified, an environment variable will be used,
56+
base_url: Alias for api_base. If both are provided, api_base takes precedence.
57+
api_key: API key to be used. If not specified, an environment variable will be used,
5558
for more information, follow the instructions for your specific vendor in the\
5659
[LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
5760
api_version: The API version for the call.
5861
router: Router to be used to [route requests](https://docs.litellm.ai/docs/routing) to different models.
5962
"""
6063
super().__init__(default_options=default_options)
61-
self.model = model
62-
self.api_base = api_base
64+
65+
self.model_name = model_name
66+
self.api_base = api_base or base_url
6367
self.api_key = api_key
6468
self.api_version = api_version
6569
self.router = router
@@ -85,7 +89,7 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | No
8589

8690
with trace(
8791
data=data,
88-
model=self.model,
92+
model=self.model_name,
8993
api_base=self.api_base,
9094
api_version=self.api_version,
9195
options=merged_options.dict(),
@@ -94,7 +98,7 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | No
9498
entrypoint = self.router or litellm
9599
response = await entrypoint.aembedding(
96100
input=data,
97-
model=self.model,
101+
model=self.model_name,
98102
api_base=self.api_base,
99103
api_key=self.api_key,
100104
api_version=self.api_version,
@@ -132,4 +136,9 @@ def from_config(cls, config: dict[str, Any]) -> Self:
132136
if "router" in config:
133137
router = litellm.router.Router(model_list=config["router"])
134138
config["router"] = router
139+
140+
# Map base_url to api_base if present
141+
if "base_url" in config and "api_base" not in config:
142+
config["api_base"] = config.pop("base_url")
143+
135144
return super().from_config(config)

packages/ragbits-core/src/ragbits/core/embeddings/local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def embed_text(self, data: list[str], options: LocalEmbedderOptions | None
7373
with trace(
7474
data=data,
7575
model_name=self.model_name,
76-
model=repr(self.model),
76+
model_obj=repr(self.model),
7777
tokenizer=repr(self.tokenizer),
7878
device=self.device,
7979
options=merged_options.dict(),

packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ class VertexAIMultimodelEmbedder(Embedder[LiteLLMEmbedderOptions]):
3030

3131
def __init__(
3232
self,
33-
model: str = "multimodalembedding",
33+
model_name: str = "multimodalembedding",
3434
api_base: str | None = None,
35+
base_url: str | None = None, # Alias for api_base
3536
api_key: str | None = None,
3637
concurency: int = 10,
3738
default_options: LiteLLMEmbedderOptions | None = None,
@@ -40,8 +41,9 @@ def __init__(
4041
Constructs the embedding client for multimodal VertexAI models.
4142
4243
Args:
43-
model: One of the VertexAI multimodal models to be used. Default is "multimodalembedding".
44+
model_name: One of the VertexAI multimodal models to be used. Default is "multimodalembedding".
4445
api_base: The API endpoint you want to call the model with.
46+
base_url: Alias for api_base. If both are provided, api_base takes precedence.
4547
api_key: API key to be used. If not specified, an environment variable will be used.
4648
concurency: The number of concurrent requests to make to the API.
4749
default_options: Additional options to pass to the API.
@@ -54,17 +56,18 @@ def __init__(
5456
raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models")
5557

5658
super().__init__(default_options=default_options)
57-
if model.startswith(self.VERTEX_AI_PREFIX):
58-
model = model[len(self.VERTEX_AI_PREFIX) :]
5959

60-
self.model = model
61-
self.api_base = api_base
60+
if model_name.startswith(self.VERTEX_AI_PREFIX):
61+
model_name = model_name[len(self.VERTEX_AI_PREFIX) :]
62+
63+
self.model_name = model_name
64+
self.api_base = api_base or base_url
6265
self.api_key = api_key
6366
self.concurency = concurency
6467

6568
supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
66-
if model not in supported_models:
67-
raise ValueError(f"Model {model} is not supported by VertexAI multimodal embeddings")
69+
if model_name not in supported_models:
70+
raise ValueError(f"Model {model_name} is not supported by VertexAI multimodal embeddings")
6871

6972
async def _embed(self, data: list[dict], options: LiteLLMEmbedderOptions | None = None) -> list[dict]:
7073
"""
@@ -86,7 +89,7 @@ async def _embed(self, data: list[dict], options: LiteLLMEmbedderOptions | None
8689
merged_options = (self.default_options | options) if options else self.default_options
8790
with trace(
8891
data=data,
89-
model=self.model,
92+
model=self.model_name,
9093
api_base=self.api_base,
9194
options=merged_options.dict(),
9295
) as outputs:
@@ -123,7 +126,7 @@ async def _call_litellm(
123126
async with semaphore:
124127
response = await litellm.aembedding(
125128
input=[instance],
126-
model=f"{self.VERTEX_AI_PREFIX}{self.model}",
129+
model=f"{self.VERTEX_AI_PREFIX}{self.model_name}",
127130
api_base=self.api_base,
128131
api_key=self.api_key,
129132
**options.dict(),

packages/ragbits-core/src/ragbits/core/llms/litellm.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(
5252
model_name: str = "gpt-3.5-turbo",
5353
default_options: LiteLLMOptions | None = None,
5454
*,
55-
base_url: str | None = None,
55+
api_base: str | None = None,
56+
base_url: str | None = None, # Alias for api_base
5657
api_key: str | None = None,
5758
api_version: str | None = None,
5859
use_structured_output: bool = False,
@@ -66,7 +67,8 @@ def __init__(
6667
model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/providers) to be used.\
6768
Default is "gpt-3.5-turbo".
6869
default_options: Default options to be used.
69-
base_url: Base URL of the LLM API.
70+
api_base: Base URL of the LLM API.
71+
base_url: Alias for api_base. If both are provided, api_base takes precedence.
7072
api_key: API key to be used. API key to be used. If not specified, an environment variable will be used,
7173
for more information, follow the instructions for your specific vendor in the\
7274
[LiteLLM documentation](https://docs.litellm.ai/docs/providers).
@@ -81,7 +83,7 @@ def __init__(
8183
for more information.
8284
"""
8385
super().__init__(model_name, default_options)
84-
self.base_url = base_url
86+
self.api_base = api_base or base_url
8587
self.api_key = api_key
8688
self.api_version = api_version
8789
self.use_structured_output = use_structured_output
@@ -187,7 +189,7 @@ async def _call_streaming(
187189
with trace(
188190
messages=prompt.chat,
189191
model=self.model_name,
190-
base_url=self.base_url,
192+
base_url=self.api_base,
191193
api_version=self.api_version,
192194
response_format=response_format,
193195
options=options.dict(),
@@ -222,7 +224,7 @@ async def _get_litellm_response(
222224
response = await entrypoint.acompletion(
223225
messages=conversation,
224226
model=self.model_name,
225-
base_url=self.base_url,
227+
base_url=self.api_base,
226228
api_key=self.api_key,
227229
api_version=self.api_version,
228230
response_format=response_format,
@@ -250,6 +252,13 @@ def _get_response_format(
250252
response_format = {"type": "json_object"}
251253
return response_format
252254

255+
@property
256+
def base_url(self) -> str | None:
257+
"""
258+
Returns the base URL of the LLM API. Alias for `api_base`.
259+
"""
260+
return self.api_base
261+
253262
@classmethod
254263
def from_config(cls, config: dict[str, Any]) -> Self:
255264
"""
@@ -264,13 +273,18 @@ def from_config(cls, config: dict[str, Any]) -> Self:
264273
if "router" in config:
265274
router = litellm.router.Router(model_list=config["router"])
266275
config["router"] = router
276+
277+
# Map base_url to api_base if present
278+
if "base_url" in config and "api_base" not in config:
279+
config["api_base"] = config.pop("base_url")
280+
267281
return super().from_config(config)
268282

269283
def __reduce__(self) -> tuple[Callable, tuple]:
270284
config = {
271285
"model_name": self.model_name,
272286
"default_options": self.default_options.dict(),
273-
"base_url": self.base_url,
287+
"api_base": self.api_base,
274288
"api_key": self.api_key,
275289
"api_version": self.api_version,
276290
"use_structured_output": self.use_structured_output,

packages/ragbits-core/tests/unit/embeddings/test_from_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_subclass_from_config_litellm():
1212
{
1313
"type": "ragbits.core.embeddings.litellm:LiteLLMEmbedder",
1414
"config": {
15-
"model": "some_model",
15+
"model_name": "some_model",
1616
"default_options": {
1717
"option1": "value1",
1818
"option2": "value2",
@@ -22,7 +22,7 @@ def test_subclass_from_config_litellm():
2222
)
2323
embedder: Embedder = Embedder.subclass_from_config(config)
2424
assert isinstance(embedder, LiteLLMEmbedder)
25-
assert embedder.model == "some_model"
25+
assert embedder.model_name == "some_model"
2626
assert embedder.default_options == LiteLLMEmbedderOptions(
2727
dimensions=NOT_GIVEN,
2828
timeout=NOT_GIVEN,
@@ -66,7 +66,7 @@ def test_from_config_with_router():
6666
config = ObjectConstructionConfig(
6767
type="ragbits.core.embeddings.litellm:LiteLLMEmbedder",
6868
config={
69-
"model": "text-embedding-3-small",
69+
"model_name": "text-embedding-3-small",
7070
"api_key": "test_api_key",
7171
"router": [
7272
{
@@ -91,7 +91,7 @@ def test_from_config_with_router():
9191
embedder: Embedder = Embedder.subclass_from_config(config)
9292
assert isinstance(embedder, LiteLLMEmbedder)
9393
assert embedder.api_base is None
94-
assert embedder.model == "text-embedding-3-small"
94+
assert embedder.model_name == "text-embedding-3-small"
9595
assert embedder.api_key == "test_api_key"
9696
assert isinstance(embedder.router, litellm.router.Router)
9797
assert len(embedder.router.model_list) == 2

packages/ragbits-core/tests/unit/llms/test_from_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_from_config_with_router():
6565

6666
llm: LLM = LLM.subclass_from_config(config)
6767
assert isinstance(llm, LiteLLM)
68-
assert llm.base_url is None
68+
assert llm.api_base is None
6969
assert llm.model_name == "gpt-4-turbo"
7070
assert llm.api_key == "test_api_key"
7171
assert isinstance(llm.router, litellm.router.Router)

0 commit comments

Comments (0)