from __future__ import annotations

import os
import ctypes
import typing
import contextlib

import numpy as np

import llama_cpp
import llama_cpp.llava_cpp as llava_cpp


class LlavaEmbedding:
    """Wraps a llava_image_embed pointer and frees it when the exit stack closes."""

    def __init__(self, embedding: ctypes._Pointer[llava_cpp.llava_image_embed]):
        self._embedding = embedding
        self._exit_stack = contextlib.ExitStack()

        def llava_image_embed_free():
            llava_cpp.llava_image_embed_free(self._embedding)

        self._exit_stack.callback(llava_image_embed_free)

    @property
    def n_image_pos(self) -> int:
        # Number of positions (image tokens) this embedding occupies in the context.
        return self._embedding.contents.n_image_pos

    def embed(
        self, llama_ctx: llama_cpp.llama_context_p, n_tokens: int, n_batch: int
    ) -> int:
        # Decode the image embedding into the llama context starting at n_tokens.
        # llava_eval_image_embed advances n_past in place, so return its new value.
        n_past = ctypes.c_int(n_tokens)
        n_past_p = ctypes.pointer(n_past)
        llava_cpp.llava_eval_image_embed(
            llama_ctx,
            self._embedding,
            n_batch,
            n_past_p,
        )
        return n_past.value

    def numpy_view(self, shape: typing.Tuple[int, int]) -> np.ndarray:
        # Zero-copy view of the raw float buffer; only valid while the
        # underlying embedding has not been freed.
        return np.ctypeslib.as_array(
            self._embedding.contents.embed, shape=shape
        )
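# A minimal sketch of consuming a LlavaEmbedding, assuming `llama_ctx` is a raw
# llama_cpp.llama_context_p obtained elsewhere and `embedding` came from
# LlavaModel.embed_bytes below (both names are illustrative):
#
#   n_past = embedding.embed(llama_ctx, n_tokens=0, n_batch=512)
#   view = embedding.numpy_view((embedding.n_image_pos, n_embd))  # n_embd is model-dependent
#
# After embed() returns, text decoding can resume from position n_past.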


class LlavaModel:
    """Loads a LLaVA CLIP/projector model and produces image embeddings."""

    def __init__(self, path: str, n_threads: int = 1):
        self._path = path
        self._n_threads = n_threads
        self._exit_stack = contextlib.ExitStack()

        if not os.path.exists(self._path):
            raise ValueError(f"Clip model path does not exist: {self._path}")

        # The second argument is the verbosity level passed to clip_model_load.
        clip_ctx = llava_cpp.clip_model_load(self._path.encode(), 0)

        if clip_ctx is None:
            raise ValueError(f"Failed to load clip model: {self._path}")

        self._clip_ctx = clip_ctx

        def clip_free():
            llava_cpp.clip_free(self._clip_ctx)
            print("Clip model freed")

        self._exit_stack.callback(clip_free)

    def embed_bytes(self, image_bytes: bytes) -> LlavaEmbedding:
        # Copy the image bytes into a ctypes buffer and let llava compute the
        # embedding with the configured number of threads.
        embed = llava_cpp.llava_image_embed_make_with_bytes(
            self._clip_ctx,
            self._n_threads,
            (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
            len(image_bytes),
        )
        return LlavaEmbedding(embed)

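
# Usage sketch: load a projector and embed an image. The file names
# "mmproj.gguf" and "image.jpg" are assumptions for illustration only.
if __name__ == "__main__":
    llava = LlavaModel("mmproj.gguf", n_threads=4)
    with open("image.jpg", "rb") as f:
        embedding = llava.embed_bytes(f.read())
    print("image occupies", embedding.n_image_pos, "positions in the context")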