1
1
"""All builtin functions."""
2
- from typing import Annotated
2
+ from typing import Annotated , Any
3
3
4
4
import json
5
5
import sentence_transformers
@@ -13,8 +13,16 @@ class SplitRecursively(op.FunctionSpec):
13
13
language : str | None = None
14
14
15
15
class SentenceTransformerEmbed(op.FunctionSpec):
    """
    `SentenceTransformerEmbed` embeds a text into a vector space using the
    [SentenceTransformer](https://huggingface.co/sentence-transformers) library.

    Args:

        model: The name of the SentenceTransformer model to use.
        args: Additional arguments to pass to the SentenceTransformer
            constructor, e.g. {"trust_remote_code": True}. Defaults to None,
            meaning no additional arguments.
    """
    model: str
    args: dict[str, Any] | None = None
18
26
19
27
@op .executor_class (gpu = True , cache = True , behavior_version = 1 )
20
28
class SentenceTransformerEmbedExecutor :
@@ -24,7 +32,8 @@ class SentenceTransformerEmbedExecutor:
24
32
_model : sentence_transformers .SentenceTransformer
25
33
26
34
def analyze(self, text=None):
    """Load the configured model and return the annotated embedding type.

    Instantiates the SentenceTransformer named by ``self.spec.model``,
    forwarding ``self.spec.args`` (when set) as keyword arguments, and stores
    it on ``self._model``.

    Returns:
        An ``Annotated`` type over ``list[Float32]`` carrying a ``Vector``
        whose ``dim`` is the model's sentence-embedding dimension, and a
        ``cocoindex.io/vector_origin_text`` ``TypeAttr`` built from the
        JSON-decoded ``text.analyzed_value``.
    """
    # ``args`` is optional on the spec; a missing value means "no extra kwargs".
    extra_kwargs = self.spec.args or {}
    self._model = sentence_transformers.SentenceTransformer(
        self.spec.model, **extra_kwargs
    )
    dim = self._model.get_sentence_embedding_dimension()
    # NOTE(review): ``text`` defaults to None but is dereferenced
    # unconditionally below -- confirm callers always supply it.
    origin_text = json.loads(text.analyzed_value)
    return Annotated[
        list[Float32],
        Vector(dim=dim),
        TypeAttr("cocoindex.io/vector_origin_text", origin_text),
    ]
30
39
0 commit comments