Commit 08cc0b8

Add Paligemma support
1 parent 7c4aead commit 08cc0b8

3 files changed (+242, −1 lines)

llama_cpp/llama_chat_format.py

Lines changed: 200 additions & 0 deletions
@@ -28,9 +28,11 @@
 import numpy as np
 import numpy.typing as npt

+import llama_cpp
 import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar
+import llama_cpp._internals as internals

 from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton
@@ -3350,6 +3352,204 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
     )


+class PaligemmaChatHandler(Llava15ChatHandler):
+    def __call__(
+        self,
+        *,
+        llama: llama.Llama,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+        temperature: float = 0.2,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
+        stream: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+        logits_processor: Optional[llama.LogitsProcessorList] = None,
+        grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        **kwargs,  # type: ignore
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        assert self.clip_ctx is not None
+
+        if len(messages) != 1:
+            raise ValueError("PaligemmaChatHandler only supports single-turn conversations.")
+
+        image_urls = self.get_image_urls(messages)
+
+        if len(image_urls) > 1:
+            raise ValueError("PaligemmaChatHandler only supports a single image per turn.")
+
+        text = "<s>answer en "
+        message = messages[0]
+        if isinstance(message["content"], str):
+            text = message["content"]
+        elif isinstance(message["content"], list):
+            for content in message["content"]:
+                if content["type"] == "text":
+                    text += content["text"]
+        text += "\n"
+
+        if self.verbose:
+            print(text, file=sys.stderr)
+
+        tokens = llama.tokenize(text.encode("utf-8"), special=True)
+        embedding_dim = llama_cpp.llama_n_embd(llama.model)
+        tokens_np = np.array(tokens).astype(np.int32)
+        token_embedding = np.empty((len(tokens), embedding_dim), dtype=np.single)
+        llama_cpp.llama_token_inp_embd(
+            llama.ctx,
+            tokens_np.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+            len(tokens),
+            token_embedding.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+        )
+
+        if len(image_urls) > 0:
+            image_embedding = self._embed_image_bytes(self.load_image(image_urls[0]))
+            n_image_pos = image_embedding.contents.n_image_pos
+            embeds = np.concatenate(
+                [
+                    np.ctypeslib.as_array(
+                        image_embedding.contents.embed,
+                        shape=(n_image_pos, embedding_dim),
+                    ),
+                    token_embedding,
+                ],
+                axis=0,
+            )
+            n_tokens = n_image_pos + len(tokens)
+            llama.input_ids[:n_tokens] = (
+                llama.tokenize(b"<image>", add_bos=False, special=True) * n_image_pos
+                + tokens
+            )
+        else:
+            n_tokens = len(tokens)
+            llama.input_ids[:n_tokens] = tokens
+            embeds = token_embedding
+
+        n_batch = 512
+        batch = internals.LlamaBatch(n_tokens=n_batch, embd=embedding_dim, n_seq_max=1)
+
+        batch.batch.n_tokens = n_tokens
+
+        np.ctypeslib.as_array(batch.batch.embd, shape=(n_batch, embedding_dim))[
+            :n_tokens, :
+        ] = embeds
+        np.ctypeslib.as_array(batch.batch.pos, shape=(n_batch,))[:n_tokens] = np.arange(n_tokens)
+        np.ctypeslib.as_array(batch.batch.n_seq_id, shape=(n_batch,))[:] = 1
+        np.ctypeslib.as_array(batch.batch.logits, shape=(n_batch,))[:] = False
+        np.ctypeslib.as_array(batch.batch.logits, shape=(n_batch,))[n_tokens - 1] = True
+
+        for i in range(n_tokens):
+            batch.batch.seq_id[i][0] = 0
+
+        # Evaluate prompt
+        llama.reset()
+        llama._ctx.kv_cache_clear()
+        llama_cpp.llama_set_causal_attn(llama._ctx.ctx, False)
+        llama._ctx.decode(batch)
+        llama.n_tokens += n_tokens
+        llama_cpp.llama_set_causal_attn(llama._ctx.ctx, True)
+
+        # Get prompt tokens to avoid a cache miss
+        prompt = llama.input_ids[: llama.n_tokens].tolist()
+
+        if response_format is not None and response_format["type"] == "json_object":
+            grammar = _grammar_for_response_format(response_format)
+
+        # Convert legacy functions to tools
+        if functions is not None:
+            tools = [
+                {
+                    "type": "function",
+                    "function": function,
+                }
+                for function in functions
+            ]
+
+        # Convert legacy function_call to tool_choice
+        if function_call is not None:
+            if isinstance(function_call, str) and (
+                function_call == "none" or function_call == "auto"
+            ):
+                tool_choice = function_call
+            if isinstance(function_call, dict) and "name" in function_call:
+                tool_choice = {
+                    "type": "function",
+                    "function": {
+                        "name": function_call["name"],
+                    },
+                }
+
+        tool = None
+        if (
+            tool_choice is not None
+            and isinstance(tool_choice, dict)
+            and tools is not None
+        ):
+            name = tool_choice["function"]["name"]
+            tool = next((t for t in tools if t["function"]["name"] == name), None)
+            if tool is None:
+                raise ValueError(f"Tool choice '{name}' not found in tools.")
+            schema = tool["function"]["parameters"]
+            try:
+                # create grammar from json schema
+                grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                    json.dumps(schema), verbose=llama.verbose
+                )
+            except Exception as e:
+                if llama.verbose:
+                    print(str(e), file=sys.stderr)
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF, verbose=llama.verbose
+                )
+
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
+            stream=stream,
+            stop=stop,
+            seed=seed,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            logit_bias=logit_bias,
+        )
+        if tool is not None:
+            tool_name = tool["function"]["name"]
+            return _convert_completion_to_chat_function(
+                tool_name, completion_or_chunks, stream
+            )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
+
 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
     llama: llama.Llama,
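
Usage sketch for the new handler, assuming it is exported from llama_cpp.llama_chat_format like the other Llava-style handlers it inherits from. The model and projector file names, the image URL, and n_ctx below are placeholders, and the request keeps to the single-turn, single-image constraint the handler enforces.

# Illustrative sketch only: file names, URL, and n_ctx are placeholders.
from llama_cpp import Llama
from llama_cpp.llama_chat_format import PaligemmaChatHandler

# Constructed like the other Llava-style handlers: it needs the multimodal
# projector (CLIP) weights alongside the language model.
chat_handler = PaligemmaChatHandler(
    clip_model_path="./paligemma-mmproj-f16.gguf",  # placeholder projector file
    verbose=False,
)

llm = Llama(
    model_path="./paligemma-3b-mix-224-q4_k_m.gguf",  # placeholder model file
    chat_handler=chat_handler,
    n_ctx=2048,  # room for the image embedding positions plus the text prompt
)

# Single turn with a single image, matching the constraints enforced above.
response = llm.create_chat_completion(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                {"type": "text", "text": "What is in this picture?"},
            ],
        }
    ],
    max_tokens=64,
)
print(response["choices"][0]["message"]["content"])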

llama_cpp/llama_cpp.py

Lines changed: 41 additions & 0 deletions
@@ -2836,6 +2836,47 @@ def llama_detokenize(
     ...


+# // @details Get the input embeddings for a sequence of tokens
+# // @param tokens The tokens to embed
+# // @param n_tokens The number of tokens
+# // @param embeddings The embeddings pointer must be large enough to hold the resulting embeddings
+# // @param n_embd The number of embeddings per token
+# // @return Returns a negative number on failure
+# LLAMA_API int32_t llama_token_inp_embd(
+#           struct llama_context * ctx,
+#                     llama_token * tokens,
+#                         int32_t   n_tokens,
+#                           float * embeddings);
+@ctypes_function(
+    "llama_token_inp_embd",
+    [
+        llama_context_p_ctypes,
+        llama_token_p,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_float),
+    ],
+    ctypes.c_int32,
+)
+def llama_token_inp_embd(
+    ctx: llama_context_p,
+    tokens: CtypesArray[llama_token],
+    n_tokens: Union[ctypes.c_int32, int],
+    embeddings: CtypesArray[ctypes.c_float],
+    /,
+) -> int:
+    """Get the input embeddings for a sequence of tokens.
+
+    Args:
+        ctx: The model context.
+        tokens: The tokens to embed.
+        n_tokens: The number of tokens.
+        embeddings: Output pointer; must be large enough to hold the resulting embeddings.
+
+    Returns:
+        A negative number on failure."""
+    ...
+
+
 # //
 # // Chat templates
 # //
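
A rough usage sketch for the new binding, mirroring how PaligemmaChatHandler uses it above: tokenize some text, allocate a float buffer of shape (n_tokens, n_embd), and pass ctypes pointers into it. The model path is a placeholder, and the call assumes the vendored llama.cpp revision referenced by this commit, which is where llama_token_inp_embd is expected to be implemented.

# Sketch: fetch input embeddings for a token sequence via the new binding.
import ctypes

import numpy as np

import llama_cpp
from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # placeholder model file

tokens = llm.tokenize(b"answer en What is in this picture?\n", special=True)
n_embd = llama_cpp.llama_n_embd(llm.model)

tokens_np = np.array(tokens, dtype=np.int32)
embeddings = np.empty((len(tokens), n_embd), dtype=np.single)

ret = llama_cpp.llama_token_inp_embd(
    llm.ctx,
    tokens_np.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
    len(tokens),
    embeddings.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
)
if ret < 0:
    raise RuntimeError("llama_token_inp_embd failed")

# embeddings now holds one row of length n_embd per input token, which is how
# PaligemmaChatHandler builds the text half of its embedding batch.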

vendor/llama.cpp

Lines changed: 1 addition & 1 deletion (submodule update)
