Commit dcafa35

iamlemec authored and abetlen committed
feat: Allow for possibly non-pooled embeddings (abetlen#1380)
* allow for possibly non-pooled embeddings
* add more to embeddings section in README.md

Co-authored-by: Andrei <abetlen@gmail.com>

1 parent: a410adf
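For context, a minimal usage sketch of what this change enables (the model path and inputs are placeholders, not part of the commit): embed() now inspects the model's pooling type, returning one vector per input for pooled models and one vector per token for models whose pooling type is LLAMA_POOLING_TYPE_NONE.

    from llama_cpp import Llama

    # Placeholder path and inputs, for illustration only.
    llm = Llama(model_path="models/embedding-model.gguf", embedding=True)

    vectors = llm.embed(["first input", "second input"], normalize=True)

    # Pooled model: each entry is a flat List[float].
    # Un-pooled model (LLAMA_POOLING_TYPE_NONE): each entry is a
    # List[List[float]] with one embedding per token of that input.
    print(len(vectors))  # 2, one entry per input string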

File tree: 4 files changed (+60, -19 lines)

* llama_cpp/_internals.py
* llama_cpp/llama.py
* llama_cpp/llama_cpp.py
* llama_cpp/llama_types.py


llama_cpp/_internals.py

Lines changed: 14 additions & 0 deletions

@@ -273,6 +273,10 @@ def n_ctx(self) -> int:
         assert self.ctx is not None
         return llama_cpp.llama_n_ctx(self.ctx)
 
+    def pooling_type(self) -> int:
+        assert self.ctx is not None
+        return llama_cpp.llama_pooling_type(self.ctx)
+
     def kv_cache_clear(self):
         assert self.ctx is not None
         llama_cpp.llama_kv_cache_clear(self.ctx)

@@ -641,6 +645,16 @@ def _should_add_bos(model: _LlamaModel) -> bool:
     return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM
 
 
+# Embedding functions
+
+
+def _normalize_embedding(embedding):
+    norm = float(np.linalg.norm(embedding))
+    if norm == 0.0:
+        return embedding
+    return [v / norm for v in embedding]
+
+
 # Python wrappers over common/sampling structs
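The new _normalize_embedding helper L2-normalizes a vector and returns an all-zero vector unchanged so it never divides by zero. A standalone copy of the helper, shown here only to illustrate what it computes:

    import numpy as np

    # Standalone copy for illustration: divide each component by the L2 norm,
    # leaving an all-zero vector untouched to avoid division by zero.
    def _normalize_embedding(embedding):
        norm = float(np.linalg.norm(embedding))
        if norm == 0.0:
            return embedding
        return [v / norm for v in embedding]

    print(_normalize_embedding([3.0, 4.0]))  # [0.6, 0.8], unit length
    print(_normalize_embedding([0.0, 0.0]))  # [0.0, 0.0], returned as-is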

llama_cpp/llama.py

Lines changed: 39 additions & 18 deletions

@@ -50,6 +50,7 @@
     _LlamaTokenDataArray,  # type: ignore
     _LlamaSamplingParams,  # type: ignore
     _LlamaSamplingContext,  # type: ignore
+    _normalize_embedding,  # type: ignore
 )
 from ._logger import set_verbose
 from ._utils import suppress_stdout_stderr

@@ -760,7 +761,7 @@ def create_embedding(
         input = input if isinstance(input, list) else [input]
 
         # get numeric embeddings
-        embeds: List[List[float]]
+        embeds: Union[List[List[float]], List[List[List[float]]]]
         total_tokens: int
         embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore
 
@@ -787,7 +788,7 @@ def create_embedding(
     def embed(
         self,
         input: Union[str, List[str]],
-        normalize: bool = True,
+        normalize: bool = False,
         truncate: bool = True,
         return_count: bool = False,
     ):

@@ -803,6 +804,10 @@ def embed(
         n_embd = self.n_embd()
         n_batch = self.n_batch
 
+        # get pooling information
+        pooling_type = self.pooling_type()
+        logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
+
         if self.context_params.embeddings == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"

@@ -820,29 +825,37 @@ def embed(
         self._batch.reset()
 
         # decode and fetch embeddings
-        data: List[List[float]] = []
+        data: Union[List[List[float]], List[List[List[float]]]] = []
 
-        def decode_batch(n_seq: int):
+        def decode_batch(seq_sizes: List[int]):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i in range(n_seq):
-                ptr = llama_cpp.llama_get_embeddings_seq(
-                    self._ctx.ctx, i
-                )
-                if not ptr:
-                    raise RuntimeError("Failed to get embeddings from sequence pooling type is not set")
-                embedding: List[float] = ptr[:n_embd]
-                if normalize:
-                    norm = float(np.linalg.norm(embedding))
-                    embedding = [v / norm for v in embedding]
-                data.append(embedding)
+            if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
+                pos: int = 0
+                for i, size in enumerate(seq_sizes):
+                    ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
+                    embedding: List[List[float]] = [
+                        ptr[pos + j * n_embd : pos + (j + 1) * n_embd] for j in range(size)
+                    ]
+                    if normalize:
+                        embedding = [_normalize_embedding(e) for e in embedding]
+                    data.append(embedding)
+                    pos += size
+            else:
+                for i in range(len(seq_sizes)):
+                    ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
+                    embedding: List[float] = ptr[:n_embd]
+                    if normalize:
+                        embedding = _normalize_embedding(embedding)
+                    data.append(embedding)
 
         # init state
         total_tokens = 0
+        s_batch = []
         t_batch = 0
         p_batch = 0
 
@@ -863,17 +876,21 @@ def decode_batch(n_seq: int):
 
             # time to eval batch
             if t_batch + n_tokens > n_batch:
-                decode_batch(p_batch)
+                decode_batch(s_batch)
+                s_batch = []
                 t_batch = 0
                 p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, p_batch, False)
+            self._batch.add_sequence(tokens, p_batch, logits_all)
+
+            # update batch stats
+            s_batch.append(n_tokens)
             t_batch += n_tokens
             p_batch += 1
 
         # hanlde last batch
-        decode_batch(p_batch)
+        decode_batch(s_batch)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
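decode_batch now receives the list of per-sequence token counts (s_batch) instead of just a sequence count, because with LLAMA_POOLING_TYPE_NONE there is no pooled per-sequence vector to fetch: the flat buffer exposed by llama_get_embeddings has to be cut into one n_embd-float vector per token. A simplified sketch of that slicing over plain Python lists; the helper name is made up for illustration and the index bookkeeping is simplified relative to the diff above:

    from typing import List

    # "flat" stands in for the buffer llama_get_embeddings exposes, laid out
    # token-major with n_embd floats per token; seq_sizes holds the token
    # count of each sequence in the decoded batch.
    def split_token_embeddings(
        flat: List[float], seq_sizes: List[int], n_embd: int
    ) -> List[List[List[float]]]:
        out: List[List[List[float]]] = []
        pos = 0
        for size in seq_sizes:
            out.append(
                [flat[(pos + j) * n_embd : (pos + j + 1) * n_embd] for j in range(size)]
            )
            pos += size
        return out

    # Two sequences of 2 and 1 tokens, n_embd = 3.
    flat = [float(x) for x in range(9)]
    print(split_token_embeddings(flat, [2, 1], 3))
    # [[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [[6.0, 7.0, 8.0]]]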
@@ -1845,6 +1862,10 @@ def token_nl(self) -> int:
         """Return the newline token."""
         return self._model.token_nl()
 
+    def pooling_type(self) -> str:
+        """Return the pooling type."""
+        return self._ctx.pooling_type()
+
     @staticmethod
     def logits_to_logprobs(
         logits: Union[npt.NDArray[np.single], List], axis: int = -1
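The new high-level Llama.pooling_type() simply forwards the context's value; note the annotation says str, while the forwarded value is the integer enum that embed() itself compares against llama_cpp.LLAMA_POOLING_TYPE_NONE. A hedged usage sketch (placeholder model path):

    import llama_cpp
    from llama_cpp import Llama

    llm = Llama(model_path="models/embedding-model.gguf", embedding=True)  # placeholder path

    # Decide up front whether embed() will return per-token matrices.
    if llm.pooling_type() == llama_cpp.LLAMA_POOLING_TYPE_NONE:
        print("model is un-pooled: expect a List[List[float]] per input")
    else:
        print("model pools embeddings: expect a List[float] per input")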

llama_cpp/llama_cpp.py

Lines changed: 6 additions & 0 deletions

@@ -1189,6 +1189,12 @@ def llama_rope_type(model: llama_model_p, /) -> int:
     ...
 
 
+# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_model * model);
+@ctypes_function("llama_pooling_type", [llama_model_p_ctypes], ctypes.c_int)
+def llama_pooling_type(model: llama_model_p, /) -> int:
+    ...
+
+
 # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
 @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
 def llama_n_vocab(model: llama_model_p, /) -> int:
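In _internals.py above, _LlamaContext.pooling_type() passes self.ctx (a context pointer) to this binding even though it is declared with llama_model_p_ctypes; ctypes does not check the pointer type, and upstream llama.cpp exposes the pooling type on the context. A sketch of calling the raw binding the same way through a high-level Llama object (placeholder path; _ctx is a private attribute and is used here only for illustration, llm.pooling_type() being the supported route):

    import llama_cpp
    from llama_cpp import Llama

    # Placeholder model path, for illustration only.
    llm = Llama(model_path="models/embedding-model.gguf", embedding=True)

    # Same value that llm.pooling_type() forwards from the context.
    pooling = llama_cpp.llama_pooling_type(llm._ctx.ctx)
    print(pooling == llama_cpp.LLAMA_POOLING_TYPE_NONE)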

llama_cpp/llama_types.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ class EmbeddingUsage(TypedDict):
 class Embedding(TypedDict):
     index: int
     object: str
-    embedding: List[float]
+    embedding: Union[List[float], List[List[float]]]
 
 
 class CreateEmbeddingResponse(TypedDict):
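Because the embedding field can now be either a single vector or a per-token matrix, consumers of create_embedding responses may need to branch on the shape. A small sketch of one way to do that; the helper name is made up for illustration:

    from typing import List, Union

    # Hypothetical helper: report the dimensionality of an Embedding's
    # "embedding" field, whichever of the two shapes it has.
    def embedding_dim(embedding: Union[List[float], List[List[float]]]) -> int:
        if embedding and isinstance(embedding[0], list):
            return len(embedding[0])  # non-pooled: one vector per token
        return len(embedding)  # pooled: a single flat vector

    print(embedding_dim([0.1, 0.2, 0.3]))           # 3
    print(embedding_dim([[0.1, 0.2], [0.3, 0.4]]))  # 2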
