@@ -28,9 +28,11 @@
 import numpy as np
 import numpy.typing as npt
 
+import llama_cpp
 import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar
+import llama_cpp._internals as internals
 
 from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton
@@ -3350,6 +3352,204 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
 )


+class PaligemmaChatHandler(Llava15ChatHandler):
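+    """Chat handler for PaliGemma-style multimodal models.
+
+    The prompt is the fixed task prefix ``<s>answer en `` plus the user text and a
+    trailing newline. The image embedding from the vision encoder (``clip_ctx``) is
+    prepended to the token embeddings, and this prefix is decoded with causal
+    attention disabled: PaliGemma attends bidirectionally over the image and prompt
+    prefix and causally only over the generated suffix. Only single-turn requests
+    with at most one image are supported.
+
+    Usage sketch (model and projector paths are placeholders, not shipped defaults)::
+
+        from llama_cpp import Llama
+        from llama_cpp.llama_chat_format import PaligemmaChatHandler
+
+        handler = PaligemmaChatHandler(clip_model_path="mmproj.gguf")
+        llm = Llama(model_path="paligemma.gguf", chat_handler=handler, n_ctx=2048)
+        llm.create_chat_completion(messages=[{"role": "user", "content": [
+            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
+            {"type": "text", "text": "What is in this image?"},
+        ]}])
+    """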
+    def __call__(
+        self,
+        *,
+        llama: llama.Llama,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+        temperature: float = 0.2,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
+        stream: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+        logits_processor: Optional[llama.LogitsProcessorList] = None,
+        grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        **kwargs,  # type: ignore
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        assert self.clip_ctx is not None
+
+        if len(messages) != 1:
+            raise ValueError("PaligemmaChatHandler only supports single-turn conversations.")
+
+        image_urls = self.get_image_urls(messages)
+
+        if len(image_urls) > 1:
+            raise ValueError("PaligemmaChatHandler only supports a single image per turn.")
+
+        # Build the PaliGemma prompt: task prefix, user text, trailing newline.
+        text = "<s>answer en "
+        message = messages[0]
+        if isinstance(message["content"], str):
+            text += message["content"]
+        elif isinstance(message["content"], list):
+            for content in message["content"]:
+                if content["type"] == "text":
+                    text += content["text"]
+        text += "\n"
+
+        if self.verbose:
+            print(text, file=sys.stderr)
+
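+        # Embed the text prompt manually: the model is fed one sequence of raw
+        # embeddings (image patches followed by token embeddings) instead of
+        # token ids, so the token embeddings are looked up here.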
+        tokens = llama.tokenize(text.encode("utf-8"), special=True)
+        embedding_dim = llama_cpp.llama_n_embd(llama.model)
+        tokens_np = np.array(tokens).astype(np.int32)
+        token_embedding = np.empty((len(tokens), embedding_dim), dtype=np.single)
+        llama_cpp.llama_token_inp_embd(
+            llama.ctx,
+            tokens_np.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+            len(tokens),
+            token_embedding.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+        )
+
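+        # Prepend the image embedding (if any) to the token embeddings, and fill
+        # input_ids with the <image> token repeated once per image position followed
+        # by the prompt tokens, so the cached prompt matches what gets decoded.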
+        if len(image_urls) > 0:
+            image_embedding = self._embed_image_bytes(self.load_image(image_urls[0]))
+            n_image_pos = image_embedding.contents.n_image_pos
+            embeds = np.concatenate(
+                [
+                    np.ctypeslib.as_array(
+                        image_embedding.contents.embed,
+                        shape=(n_image_pos, embedding_dim),
+                    ),
+                    token_embedding,
+                ],
+                axis=0,
+            )
+            n_tokens = n_image_pos + len(tokens)
+            llama.input_ids[:n_tokens] = (
+                llama.tokenize(b"<image>", add_bos=False, special=True) * n_image_pos
+                + tokens
+            )
+        else:
+            n_tokens = len(tokens)
+            llama.input_ids[:n_tokens] = tokens
+            embeds = token_embedding
+
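+        # Build a raw llama_batch that carries embeddings (embd=embedding_dim)
+        # rather than token ids: a single sequence, positions 0..n_tokens-1, and
+        # logits requested only for the final position.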
+        n_batch = max(512, n_tokens)  # the allocated batch must hold the full prefix
+        batch = internals.LlamaBatch(n_tokens=n_batch, embd=embedding_dim, n_seq_max=1)
+
+        batch.batch.n_tokens = n_tokens
+
+        np.ctypeslib.as_array(batch.batch.embd, shape=(n_batch, embedding_dim))[
+            :n_tokens, :
+        ] = embeds
+        np.ctypeslib.as_array(batch.batch.pos, shape=(n_batch,))[:n_tokens] = np.arange(
+            n_tokens
+        )
+        np.ctypeslib.as_array(batch.batch.n_seq_id, shape=(n_batch,))[:] = 1
+        np.ctypeslib.as_array(batch.batch.logits, shape=(n_batch,))[:] = False
+        np.ctypeslib.as_array(batch.batch.logits, shape=(n_batch,))[n_tokens - 1] = True
+
+        for i in range(n_tokens):
+            batch.batch.seq_id[i][0] = 0
+        # Evaluate prompt
+        llama.reset()
+        llama._ctx.kv_cache_clear()
+        llama_cpp.llama_set_causal_attn(llama._ctx.ctx, False)
+        llama._ctx.decode(batch)
+        llama.n_tokens += n_tokens
+        llama_cpp.llama_set_causal_attn(llama._ctx.ctx, True)
+
+        # Get prompt tokens to avoid a cache miss
+        prompt = llama.input_ids[: llama.n_tokens].tolist()
+
+        if response_format is not None and response_format["type"] == "json_object":
+            grammar = _grammar_for_response_format(response_format)
+
+        # Convert legacy functions to tools
+        if functions is not None:
+            tools = [
+                {
+                    "type": "function",
+                    "function": function,
+                }
+                for function in functions
+            ]
+
+        # Convert legacy function_call to tool_choice
+        if function_call is not None:
+            if isinstance(function_call, str) and (
+                function_call == "none" or function_call == "auto"
+            ):
+                tool_choice = function_call
+            if isinstance(function_call, dict) and "name" in function_call:
+                tool_choice = {
+                    "type": "function",
+                    "function": {
+                        "name": function_call["name"],
+                    },
+                }
+
+        tool = None
+        if (
+            tool_choice is not None
+            and isinstance(tool_choice, dict)
+            and tools is not None
+        ):
+            name = tool_choice["function"]["name"]
+            tool = next((t for t in tools if t["function"]["name"] == name), None)
+            if tool is None:
+                raise ValueError(f"Tool choice '{name}' not found in tools.")
+            schema = tool["function"]["parameters"]
+            try:
+                # create grammar from json schema
+                grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                    json.dumps(schema), verbose=llama.verbose
+                )
+            except Exception as e:
+                if llama.verbose:
+                    print(str(e), file=sys.stderr)
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF, verbose=llama.verbose
+                )
+
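+        # Generation resumes from the already-decoded prefix: create_completion is
+        # handed the same token ids that were written to input_ids above, so the
+        # prompt is served from cache rather than re-evaluated.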
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
+            stream=stream,
+            stop=stop,
+            seed=seed,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            logit_bias=logit_bias,
+        )
+        if tool is not None:
+            tool_name = tool["function"]["name"]
+            return _convert_completion_to_chat_function(
+                tool_name, completion_or_chunks, stream
+            )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
+
+
 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
     llama: llama.Llama,