
Allow for possibly non-pooled embeddings #1380


Merged · 4 commits · Apr 26, 2024
14 changes: 14 additions & 0 deletions llama_cpp/_internals.py
@@ -273,6 +273,10 @@ def n_ctx(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_n_ctx(self.ctx)

    def pooling_type(self) -> int:
        assert self.ctx is not None
        return llama_cpp.llama_pooling_type(self.ctx)

    def kv_cache_clear(self):
        assert self.ctx is not None
        llama_cpp.llama_kv_cache_clear(self.ctx)
@@ -641,6 +645,16 @@ def _should_add_bos(model: _LlamaModel) -> bool:
    return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM


# Embedding functions


def _normalize_embedding(embedding):
    # L2-normalize; return zero vectors unchanged to avoid dividing by zero
    norm = float(np.linalg.norm(embedding))
    if norm == 0.0:
        return embedding
    return [v / norm for v in embedding]


# Python wrappers over common/sampling structs


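As a quick illustration of the new helper (values are made up; _internals.py already imports numpy as np, as the norm call above shows):

unit = _normalize_embedding([3.0, 4.0])
# unit == [0.6, 0.8]; its L2 norm is now 1.0
zero = _normalize_embedding([0.0, 0.0])
# zero vectors come back unchanged, avoiding a division by zero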
57 changes: 39 additions & 18 deletions llama_cpp/llama.py
@@ -50,6 +50,7 @@
    _LlamaTokenDataArray,  # type: ignore
    _LlamaSamplingParams,  # type: ignore
    _LlamaSamplingContext,  # type: ignore
    _normalize_embedding,  # type: ignore
)
from ._logger import set_verbose
from ._utils import suppress_stdout_stderr
@@ -760,7 +761,7 @@ def create_embedding(
        input = input if isinstance(input, list) else [input]

        # get numeric embeddings
        embeds: List[List[float]]
        embeds: Union[List[List[float]], List[List[List[float]]]]
        total_tokens: int
        embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore

@@ -787,7 +788,7 @@ def create_embedding(
    def embed(
        self,
        input: Union[str, List[str]],
        normalize: bool = True,
        normalize: bool = False,
        truncate: bool = True,
        return_count: bool = False,
    ):
@@ -803,6 +804,10 @@ def embed(
        n_embd = self.n_embd()
        n_batch = self.n_batch

        # get pooling information
        pooling_type = self.pooling_type()
        logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE

        if self.context_params.embeddings == False:
            raise RuntimeError(
                "Llama model must be created with embedding=True to call this method"
@@ -820,29 +825,37 @@
        self._batch.reset()

        # decode and fetch embeddings
        data: List[List[float]] = []
        data: Union[List[List[float]], List[List[List[float]]]] = []

        def decode_batch(n_seq: int):
        def decode_batch(seq_sizes: List[int]):
            assert self._ctx.ctx is not None
            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
            self._ctx.decode(self._batch)
            self._batch.reset()

            # store embeddings
            for i in range(n_seq):
                ptr = llama_cpp.llama_get_embeddings_seq(
                    self._ctx.ctx, i
                )
                if not ptr:
                    raise RuntimeError("Failed to get embeddings from sequence; pooling type is not set")
                embedding: List[float] = ptr[:n_embd]
                if normalize:
                    norm = float(np.linalg.norm(embedding))
                    embedding = [v / norm for v in embedding]
                data.append(embedding)
            if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE:
                pos: int = 0
                for i, size in enumerate(seq_sizes):
                    ptr = llama_cpp.llama_get_embeddings(self._ctx.ctx)
                    embedding: List[List[float]] = [
                        ptr[pos + j * n_embd : pos + (j + 1) * n_embd] for j in range(size)
                    ]
                    if normalize:
                        embedding = [_normalize_embedding(e) for e in embedding]
                    data.append(embedding)
                    pos += size * n_embd
            else:
                for i in range(len(seq_sizes)):
                    ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i)
                    embedding: List[float] = ptr[:n_embd]
                    if normalize:
                        embedding = _normalize_embedding(embedding)
                    data.append(embedding)

        # init state
        total_tokens = 0
        s_batch = []
        t_batch = 0
        p_batch = 0
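
A short sketch of the slicing decode_batch performs in the non-pooled branch (sizes made up; flat stands in for the float buffer returned by llama_get_embeddings, which lays all token embeddings out back to back):

n_embd = 4
flat = list(range(5 * n_embd))  # stand-in: 5 tokens' embeddings, back to back
pos = 0
for size in [3, 2]:  # tokens per sequence, i.e. the seq_sizes argument
    seq = [flat[pos + j * n_embd : pos + (j + 1) * n_embd] for j in range(size)]
    pos += size * n_embd  # advance past this sequence's floats

This is why pos must advance by size * n_embd floats per sequence, as in the corrected line above.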

@@ -863,17 +876,21 @@ def decode_batch(n_seq: int):

            # time to eval batch
            if t_batch + n_tokens > n_batch:
                decode_batch(p_batch)
                decode_batch(s_batch)
                s_batch = []
                t_batch = 0
                p_batch = 0

            # add to batch
            self._batch.add_sequence(tokens, p_batch, False)
            self._batch.add_sequence(tokens, p_batch, logits_all)

            # update batch stats
            s_batch.append(n_tokens)
            t_batch += n_tokens
            p_batch += 1

        # handle last batch
        decode_batch(p_batch)
        decode_batch(s_batch)

        if self.verbose:
            llama_cpp.llama_print_timings(self._ctx.ctx)
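
The user-visible effect, as a hedged sketch (the model path is hypothetical): with a pooling model, embed() still returns one vector per input, while a model whose pooling type is LLAMA_POOLING_TYPE_NONE now returns one vector per token:

llm = Llama(model_path="./model.gguf", embedding=True)
vecs = llm.embed(["hello world", "second input"])
# pooled model: List[List[float]], one n_embd-sized vector per input
# non-pooled model: List[List[List[float]]], one n_embd-sized vector per token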
@@ -1845,6 +1862,10 @@ def token_nl(self) -> int:
"""Return the newline token."""
return self._model.token_nl()

def pooling_type(self) -> str:
"""Return the pooling type."""
return self._ctx.pooling_type()

@staticmethod
def logits_to_logprobs(
logits: Union[npt.NDArray[np.single], List], axis: int = -1
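A minimal sketch of checking the pooling type through the new high-level wrapper (path hypothetical; this assumes the constant is re-exported at package level, as llama_cpp's __init__ does for the low-level module):

import llama_cpp
llm = llama_cpp.Llama(model_path="./model.gguf", embedding=True)
if llm.pooling_type() == llama_cpp.LLAMA_POOLING_TYPE_NONE:
    print("embeddings will be per-token, not pooled")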
6 changes: 6 additions & 0 deletions llama_cpp/llama_cpp.py
@@ -1185,6 +1185,12 @@ def llama_rope_type(model: llama_model_p, /) -> int:
    ...


# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_model * model);
@ctypes_function("llama_pooling_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_pooling_type(model: llama_model_p, /) -> int:
    ...


Comment on lines +1192 to +1197
Owner


Suggested change
# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_model * model);
@ctypes_function("llama_pooling_type", [llama_model_p_ctypes], ctypes.c_int)
def llama_pooling_type(model: llama_model_p, /) -> int:
    ...

I added this function elsewhere in the file to match the llama.h order so we can delete this.

# LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
@ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32)
def llama_n_vocab(model: llama_model_p, /) -> int:
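Note that llama.h declares llama_pooling_type over a context, not a model, which matches the _internals.py call above (llama_cpp.llama_pooling_type(self.ctx)) and the owner's suggestion to drop this duplicate. A hedged sketch of the intended low-level call (paths illustrative):

model = llama_cpp.llama_load_model_from_file(b"./model.gguf", llama_cpp.llama_model_default_params())
ctx = llama_cpp.llama_new_context_with_model(model, llama_cpp.llama_context_default_params())
ptype = llama_cpp.llama_pooling_type(ctx)  # plain int, per the ctypes.c_int restype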
2 changes: 1 addition & 1 deletion llama_cpp/llama_types.py
@@ -24,7 +24,7 @@ class EmbeddingUsage(TypedDict):
class Embedding(TypedDict):
    index: int
    object: str
    embedding: List[float]
    embedding: Union[List[float], List[List[float]]]


class CreateEmbeddingResponse(TypedDict):
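With the widened type, a response item may carry either a single pooled vector or a per-token matrix; illustrative instances (numbers made up, assuming Embedding is imported from llama_cpp.llama_types):

pooled: Embedding = {"index": 0, "object": "embedding", "embedding": [0.1, 0.2]}
per_token: Embedding = {"index": 0, "object": "embedding", "embedding": [[0.1, 0.2], [0.3, 0.4]]}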