Skip to content

Commit 10cc835

Browse files
authored
refactor: move BagOfTokens model_name / encoding_name parameters to init (#592)
1 parent 3056374 commit 10cc835

File tree

4 files changed

+80
-54
lines changed

4 files changed

+80
-54
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Refacor: move BagOfTokens model_name / encoding_name parameters to init (#592)
56
- Update utils (#590)
67
- Resolve vector_size by PgVectorStore automatically (#588)
78
- Add get_vector_size method to all Embedders (#587)

packages/ragbits-core/src/ragbits/core/embeddings/sparse/bag_of_tokens.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
class BagOfTokensOptions(Options):
1313
"""A dataclass with definition of BOT options"""
1414

15-
model_name: str | None | NotGiven = "gpt-4o"
16-
encoding_name: str | None | NotGiven = NOT_GIVEN
1715
min_token_count: int | None | NotGiven = NOT_GIVEN
1816

1917

@@ -22,31 +20,48 @@ class BagOfTokens(SparseEmbedder[BagOfTokensOptions]):
2220

2321
options_cls = BagOfTokensOptions
2422

25-
async def get_vector_size(self) -> VectorSize:
23+
def __init__(
24+
self,
25+
model_name: str | None = None,
26+
encoding_name: str | None = None,
27+
default_options: BagOfTokensOptions | None = None,
28+
) -> None:
2629
"""
27-
Get the vector size for this BagOfTokens model.
30+
Initialize the BagOfTokens embedder.
2831
29-
For BagOfTokens, this returns the tokenizer vocabulary size.
32+
Args:
33+
model_name: Name of the model to use for tokenization (e.g., "gpt-4o").
34+
encoding_name: Name of the encoding to use for tokenization.
35+
default_options: Default options for the embedder.
3036
31-
Returns:
32-
VectorSize object with is_sparse=True and the vocabulary size.
37+
Raises:
38+
ValueError: If both model_name and encoding_name are provided, or if neither is provided.
3339
"""
34-
merged_options = self.default_options
40+
super().__init__(default_options=default_options)
3541

36-
if merged_options.encoding_name and merged_options.model_name:
42+
if encoding_name and model_name:
3743
raise ValueError("Please specify only one of encoding_name or model_name")
38-
if not (merged_options.encoding_name or merged_options.model_name):
39-
raise ValueError("Either encoding_name or model_name needs to be specified")
40-
41-
if merged_options.encoding_name:
42-
encoder = tiktoken.get_encoding(encoding_name=merged_options.encoding_name)
43-
elif merged_options.model_name:
44-
encoder = tiktoken.encoding_for_model(model_name=merged_options.model_name)
44+
if not (encoding_name or model_name):
45+
# Default to gpt-4o if neither is specified
46+
model_name = "gpt-4o"
47+
48+
if encoding_name:
49+
self._encoder = tiktoken.get_encoding(encoding_name=encoding_name)
50+
elif model_name:
51+
self._encoder = tiktoken.encoding_for_model(model_name=model_name)
4552
else:
4653
raise ValueError("Either encoding_name or model_name needs to be specified")
4754

48-
# Get the vocabulary size from the encoder
49-
vocab_size = encoder.n_vocab
55+
async def get_vector_size(self) -> VectorSize:
56+
"""
57+
Get the vector size for this BagOfTokens model.
58+
59+
For BagOfTokens, this returns the tokenizer vocabulary size.
60+
61+
Returns:
62+
VectorSize object with is_sparse=True and the vocabulary size.
63+
"""
64+
vocab_size = self._encoder.n_vocab
5065
return VectorSize(size=vocab_size, is_sparse=True)
5166

5267
async def embed_text(self, texts: list[str], options: BagOfTokensOptions | None = None) -> list[SparseVector]:
@@ -63,21 +78,9 @@ async def embed_text(self, texts: list[str], options: BagOfTokensOptions | None
6378
vectors = []
6479
merged_options = self.default_options | options if options else self.default_options
6580
with trace(data=texts, options=merged_options.dict()) as outputs:
66-
if merged_options.encoding_name and merged_options.model_name:
67-
raise ValueError("Please specify only one of encoding_name or model_name")
68-
if not (merged_options.encoding_name or merged_options.model_name):
69-
raise ValueError("Either encoding_name or model_name needs to be specified")
70-
71-
if merged_options.encoding_name:
72-
encoder = tiktoken.get_encoding(encoding_name=merged_options.encoding_name)
73-
elif merged_options.model_name:
74-
encoder = tiktoken.encoding_for_model(model_name=merged_options.model_name)
75-
else:
76-
raise ValueError("Either encoding_name or model_name needs to be specified")
77-
7881
min_token_count = merged_options.min_token_count or float("-inf")
7982
for text in texts:
80-
tokens = encoder.encode(text)
83+
tokens = self._encoder.encode(text)
8184
token_counts = Counter(tokens)
8285
non_zero_dims = []
8386
non_zero_vals = []

packages/ragbits-core/tests/unit/embeddings/test_bag_of_tokens.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33
from ragbits.core.embeddings.base import VectorSize
44
from ragbits.core.embeddings.sparse.bag_of_tokens import BagOfTokens, BagOfTokensOptions
5-
from ragbits.core.types import NOT_GIVEN
65

76

87
async def test_bag_of_tokens_get_vector_size_with_encoding():
98
"""Test BagOfTokens get_vector_size method with encoding_name."""
10-
options = BagOfTokensOptions(encoding_name="cl100k_base", model_name=NOT_GIVEN)
11-
embedder = BagOfTokens(default_options=options)
9+
embedder = BagOfTokens(encoding_name="cl100k_base")
1210

1311
vector_size = await embedder.get_vector_size()
1412

@@ -20,8 +18,7 @@ async def test_bag_of_tokens_get_vector_size_with_encoding():
2018

2119
async def test_bag_of_tokens_get_vector_size_with_model():
2220
"""Test BagOfTokens get_vector_size method with model_name."""
23-
options = BagOfTokensOptions(model_name="gpt-3.5-turbo")
24-
embedder = BagOfTokens(default_options=options)
21+
embedder = BagOfTokens(model_name="gpt-3.5-turbo")
2522

2623
vector_size = await embedder.get_vector_size()
2724

@@ -44,26 +41,22 @@ async def test_bag_of_tokens_get_vector_size_default():
4441

4542
async def test_bag_of_tokens_get_vector_size_error_both_specified():
4643
"""Test BagOfTokens get_vector_size raises error when both encoding_name and model_name are specified."""
47-
options = BagOfTokensOptions(encoding_name="cl100k_base", model_name="gpt-3.5-turbo")
48-
embedder = BagOfTokens(default_options=options)
49-
5044
with pytest.raises(ValueError, match="Please specify only one of encoding_name or model_name"):
51-
await embedder.get_vector_size()
45+
BagOfTokens(encoding_name="cl100k_base", model_name="gpt-3.5-turbo")
5246

5347

5448
async def test_bag_of_tokens_get_vector_size_error_none_specified():
5549
"""Test BagOfTokens get_vector_size raises error when neither encoding_name nor model_name are specified."""
56-
options = BagOfTokensOptions(encoding_name=NOT_GIVEN, model_name=NOT_GIVEN)
57-
embedder = BagOfTokens(default_options=options)
58-
59-
with pytest.raises(ValueError, match="Either encoding_name or model_name needs to be specified"):
60-
await embedder.get_vector_size()
50+
# This test is no longer valid since we now default to gpt-4o when nothing is specified
51+
# The constructor will automatically use gpt-4o as default
52+
embedder = BagOfTokens()
53+
vector_size = await embedder.get_vector_size()
54+
assert vector_size.size > 0 # Should succeed with default gpt-4o
6155

6256

6357
async def test_bag_of_tokens_embed_text_consistency():
6458
"""Test that BagOfTokens embeddings are consistent with vector size."""
65-
options = BagOfTokensOptions(encoding_name="cl100k_base", model_name=NOT_GIVEN)
66-
embedder = BagOfTokens(default_options=options)
59+
embedder = BagOfTokens(encoding_name="cl100k_base")
6760

6861
# Get vector size
6962
vector_size = await embedder.get_vector_size()
@@ -79,15 +72,26 @@ async def test_bag_of_tokens_embed_text_consistency():
7972

8073
async def test_bag_of_tokens_different_encodings():
8174
"""Test BagOfTokens with different encodings have different vocabulary sizes."""
82-
options1 = BagOfTokensOptions(encoding_name="cl100k_base", model_name=NOT_GIVEN)
83-
embedder1 = BagOfTokens(default_options=options1)
84-
85-
options2 = BagOfTokensOptions(encoding_name="p50k_base", model_name=NOT_GIVEN)
86-
embedder2 = BagOfTokens(default_options=options2)
75+
embedder1 = BagOfTokens(encoding_name="cl100k_base")
76+
embedder2 = BagOfTokens(encoding_name="p50k_base")
8777

8878
vector_size1 = await embedder1.get_vector_size()
8979
vector_size2 = await embedder2.get_vector_size()
9080

9181
assert vector_size1.size != vector_size2.size
9282
assert vector_size1.is_sparse is True
9383
assert vector_size2.is_sparse is True
84+
85+
86+
async def test_bag_of_tokens_min_token_count_option():
87+
"""Test BagOfTokens with min_token_count option."""
88+
embedder = BagOfTokens(encoding_name="cl100k_base")
89+
options = BagOfTokensOptions(min_token_count=2)
90+
91+
# Test with text that has some repeated tokens
92+
embeddings = await embedder.embed_text(["test test test"], options=options)
93+
94+
# Should have embeddings (non-empty vectors)
95+
assert len(embeddings) == 1
96+
assert len(embeddings[0].indices) > 0
97+
assert len(embeddings[0].values) > 0

packages/ragbits-core/tests/unit/embeddings/test_from_config.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import litellm
2+
import pytest
23

34
from ragbits.core.embeddings import DenseEmbedder, NoopEmbedder
45
from ragbits.core.embeddings.dense import LiteLLMEmbedder, LiteLLMEmbedderOptions
@@ -44,6 +45,7 @@ def test_subclass_from_config_bag_of_tokens():
4445
{
4546
"type": "ragbits.core.embeddings.sparse:BagOfTokens",
4647
"config": {
48+
"model_name": "gpt-4o",
4749
"default_options": {
4850
"option1": "value1",
4951
"option2": "value2",
@@ -54,14 +56,30 @@ def test_subclass_from_config_bag_of_tokens():
5456
embedder: SparseEmbedder = SparseEmbedder.subclass_from_config(config)
5557
assert isinstance(embedder, BagOfTokens)
5658
assert embedder.default_options == BagOfTokensOptions(
57-
model_name="gpt-4o",
58-
encoding_name=NOT_GIVEN,
5959
min_token_count=NOT_GIVEN,
6060
option1="value1",
6161
option2="value2",
6262
) # type: ignore
6363

6464

65+
def test_subclass_from_config_bag_of_tokens_both_specified():
66+
config = ObjectConstructionConfig.model_validate(
67+
{
68+
"type": "ragbits.core.embeddings.sparse:BagOfTokens",
69+
"config": {
70+
"model_name": "gpt-4o",
71+
"encoding_name": "cl100k_base",
72+
"default_options": {
73+
"option1": "value1",
74+
"option2": "value2",
75+
},
76+
},
77+
}
78+
)
79+
with pytest.raises(ValueError, match="Please specify only one of encoding_name or model_name"):
80+
SparseEmbedder.subclass_from_config(config)
81+
82+
6583
def test_from_config_with_router():
6684
config = ObjectConstructionConfig(
6785
type="ragbits.core.embeddings.dense:LiteLLMEmbedder",

0 commit comments

Comments
 (0)