50
50
_LlamaTokenDataArray , # type: ignore
51
51
_LlamaSamplingParams , # type: ignore
52
52
_LlamaSamplingContext , # type: ignore
53
+ _normalize_embedding , # type: ignore
53
54
)
54
55
from ._logger import set_verbose
55
56
from ._utils import suppress_stdout_stderr
@@ -760,7 +761,7 @@ def create_embedding(
760
761
input = input if isinstance (input , list ) else [input ]
761
762
762
763
# get numeric embeddings
763
- embeds : List [List [float ]]
764
+ embeds : Union [ List [List [float ]], List [ List [ List [ float ]] ]]
764
765
total_tokens : int
765
766
embeds , total_tokens = self .embed (input , return_count = True ) # type: ignore
766
767
@@ -787,7 +788,7 @@ def create_embedding(
787
788
def embed (
788
789
self ,
789
790
input : Union [str , List [str ]],
790
- normalize : bool = True ,
791
+ normalize : bool = False ,
791
792
truncate : bool = True ,
792
793
return_count : bool = False ,
793
794
):
@@ -803,6 +804,10 @@ def embed(
803
804
n_embd = self .n_embd ()
804
805
n_batch = self .n_batch
805
806
807
+ # get pooling information
808
+ pooling_type = self .pooling_type ()
809
+ logits_all = pooling_type == llama_cpp .LLAMA_POOLING_TYPE_NONE
810
+
806
811
if self .context_params .embeddings == False :
807
812
raise RuntimeError (
808
813
"Llama model must be created with embedding=True to call this method"
@@ -820,29 +825,37 @@ def embed(
820
825
self ._batch .reset ()
821
826
822
827
# decode and fetch embeddings
823
- data : List [List [float ]] = []
828
+ data : Union [ List [List [float ]], List [ List [ List [ float ]] ]] = []
824
829
825
- def decode_batch (n_seq : int ):
830
+ def decode_batch (seq_sizes : List [ int ] ):
826
831
assert self ._ctx .ctx is not None
827
832
llama_cpp .llama_kv_cache_clear (self ._ctx .ctx )
828
833
self ._ctx .decode (self ._batch )
829
834
self ._batch .reset ()
830
835
831
836
# store embeddings
832
- for i in range (n_seq ):
833
- ptr = llama_cpp .llama_get_embeddings_seq (
834
- self ._ctx .ctx , i
835
- )
836
- if not ptr :
837
- raise RuntimeError ("Failed to get embeddings from sequence pooling type is not set" )
838
- embedding : List [float ] = ptr [:n_embd ]
839
- if normalize :
840
- norm = float (np .linalg .norm (embedding ))
841
- embedding = [v / norm for v in embedding ]
842
- data .append (embedding )
837
+ if pooling_type == llama_cpp .LLAMA_POOLING_TYPE_NONE :
838
+ pos : int = 0
839
+ for i , size in enumerate (seq_sizes ):
840
+ ptr = llama_cpp .llama_get_embeddings (self ._ctx .ctx )
841
+ embedding : List [List [float ]] = [
842
+ ptr [pos + j * n_embd : pos + (j + 1 ) * n_embd ] for j in range (size )
843
+ ]
844
+ if normalize :
845
+ embedding = [_normalize_embedding (e ) for e in embedding ]
846
+ data .append (embedding )
847
+ pos += size
848
+ else :
849
+ for i in range (len (seq_sizes )):
850
+ ptr = llama_cpp .llama_get_embeddings_seq (self ._ctx .ctx , i )
851
+ embedding : List [float ] = ptr [:n_embd ]
852
+ if normalize :
853
+ embedding = _normalize_embedding (embedding )
854
+ data .append (embedding )
843
855
844
856
# init state
845
857
total_tokens = 0
858
+ s_batch = []
846
859
t_batch = 0
847
860
p_batch = 0
848
861
@@ -863,17 +876,21 @@ def decode_batch(n_seq: int):
863
876
864
877
# time to eval batch
865
878
if t_batch + n_tokens > n_batch :
866
- decode_batch (p_batch )
879
+ decode_batch (s_batch )
880
+ s_batch = []
867
881
t_batch = 0
868
882
p_batch = 0
869
883
870
884
# add to batch
871
- self ._batch .add_sequence (tokens , p_batch , False )
885
+ self ._batch .add_sequence (tokens , p_batch , logits_all )
886
+
887
+ # update batch stats
888
+ s_batch .append (n_tokens )
872
889
t_batch += n_tokens
873
890
p_batch += 1
874
891
875
892
# hanlde last batch
876
- decode_batch (p_batch )
893
+ decode_batch (s_batch )
877
894
878
895
if self .verbose :
879
896
llama_cpp .llama_print_timings (self ._ctx .ctx )
@@ -1845,6 +1862,10 @@ def token_nl(self) -> int:
1845
1862
"""Return the newline token."""
1846
1863
return self ._model .token_nl ()
1847
1864
1865
+ def pooling_type (self ) -> str :
1866
+ """Return the pooling type."""
1867
+ return self ._ctx .pooling_type ()
1868
+
1848
1869
@staticmethod
1849
1870
def logits_to_logprobs (
1850
1871
logits : Union [npt .NDArray [np .single ], List ], axis : int = - 1
0 commit comments