Skip to content

Commit c7e511f

Browse files
authored
SentenceTransformerEmbed support additional args passed to the library. (#33)
* `SentenceTransformerEmbed` support additional args for the library. * Update documentation for `SentenceTransformerEmbed`.
1 parent ccf9d1b commit c7e511f

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

docs/docs/ops/functions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Return type: `Table`, each row represents a chunk, with the following sub fields
3333
The spec takes the following fields:
3434

3535
* `model` (type: `str`, required): The name of the SentenceTransformer model to use.
36+
* `args` (type: `dict[str, Any]`, optional): Additional arguments to pass to the SentenceTransformer constructor. e.g. `{"trust_remote_code": True}`
3637

3738
Input data:
3839

python/cocoindex/functions.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""All builtin functions."""
2-
from typing import Annotated
2+
from typing import Annotated, Any
33

44
import json
55
import sentence_transformers
@@ -13,8 +13,16 @@ class SplitRecursively(op.FunctionSpec):
1313
language: str | None = None
1414

1515
class SentenceTransformerEmbed(op.FunctionSpec):
16-
"""Run the sentence transformer"""
16+
"""
17+
`SentenceTransformerEmbed` embeds a text into a vector space using the [SentenceTransformer](https://huggingface.co/sentence-transformers) library.
18+
19+
Args:
20+
21+
model: The name of the SentenceTransformer model to use.
22+
args: Additional arguments to pass to the SentenceTransformer constructor. e.g. {"trust_remote_code": True}
23+
"""
1724
model: str
25+
args: dict[str, Any] | None = None
1826

1927
@op.executor_class(gpu=True, cache=True, behavior_version=1)
2028
class SentenceTransformerEmbedExecutor:
@@ -24,7 +32,8 @@ class SentenceTransformerEmbedExecutor:
2432
_model: sentence_transformers.SentenceTransformer
2533

2634
def analyze(self, text = None):
27-
self._model = sentence_transformers.SentenceTransformer(self.spec.model, 3)
35+
args = self.spec.args or {}
36+
self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args)
2837
dim = self._model.get_sentence_embedding_dimension()
2938
return Annotated[list[Float32], Vector(dim=dim), TypeAttr("cocoindex.io/vector_origin_text", json.loads(text.analyzed_value))]
3039

0 commit comments

Comments
 (0)