Commit 0e182be

Cache last image embed

1 parent 94fe4bc
llama_cpp/llama_chat_format.py (27 additions, 14 deletions)
@@ -2193,6 +2193,8 @@ def __init__(self, clip_model_path: str, verbose: bool = False):
 
         self._llava_cpp = llava_cpp  # TODO: Fix
         self._exit_stack = ExitStack()
+        self._last_image_embed: Optional[llava_cpp.CtypesPointer[llava_cpp.llava_image_embed]] = None
+        self._last_image_hash: Optional[int] = None
 
         if not os.path.exists(clip_model_path):
             raise ValueError(f"Clip model path does not exist: {clip_model_path}")
@@ -2212,6 +2214,14 @@ def clip_free():
                 self._llava_cpp.clip_free(self.clip_ctx)
 
         self._exit_stack.callback(clip_free)
+
+        def last_image_embed_free():
+            with suppress_stdout_stderr(disable=self.verbose):
+                if self._last_image_embed is not None:
+                    self._llava_cpp.llava_image_embed_free(self._last_image_embed)
+                    self._last_image_embed = None
+
+        self._exit_stack.callback(last_image_embed_free)
 
     def load_image(self, image_url: str) -> bytes:
         return self._load_image(image_url)
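The new finalizer follows the existing clip_free idiom: contextlib.ExitStack.callback registers a cleanup function that runs when the stack is closed, so the cached embed is released together with the CLIP context rather than after every call. A small self-contained sketch of the pattern, with free_handle standing in for the ctypes free function:

from contextlib import ExitStack

handle = {"freed": False}  # stand-in for a natively allocated resource

def free_handle():
    handle["freed"] = True  # e.g. llava_image_embed_free(ptr)

stack = ExitStack()
stack.callback(free_handle)
# ... use handle for the owner's whole lifetime ...
stack.close()  # runs free_handle exactly once
assert handle["freed"]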
@@ -2270,6 +2280,22 @@ def __call__(
         text = template.render(messages=messages, add_generation_prompt=True)
         split_text = self.split_text_on_image_urls(text, image_urls)
 
+        def embed_image_bytes(image_bytes: bytes):
+            if self._last_image_embed is not None and self._last_image_hash is not None and hash(image_bytes) == self._last_image_hash:
+                return self._last_image_embed
+            with suppress_stdout_stderr(disable=self.verbose):
+                embed = (
+                    self._llava_cpp.llava_image_embed_make_with_bytes(
+                        self.clip_ctx,
+                        llama.context_params.n_threads_batch,
+                        (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
+                        len(image_bytes),
+                    )
+                )
+                self._last_image_embed = embed
+                self._last_image_hash = hash(image_bytes)
+                return embed
+
         # Evaluate prompt
         llama.reset()
         for i, (type_, value) in enumerate(split_text):
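The cache key is Python's built-in hash() over the raw bytes, which is value-based: two equal payloads hash the same even when they are distinct objects, so replaying the same image across chat turns hits the cache (within one process; bytes hashing is randomized per interpreter run). Note that the check compares hashes only, not the payloads themselves, trading an astronomically unlikely collision for not having to retain the previous bytes. For example:

data1 = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16  # fake image payload
data2 = bytes(bytearray(data1))              # distinct object, equal contents
assert data1 is not data2
assert hash(data1) == hash(data2)            # same cache key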
@@ -2280,20 +2306,7 @@ def __call__(
                 llama.eval(tokens)
             else:
                 image_bytes = self.load_image(value)
-                exit_stack = ExitStack()
-                with suppress_stdout_stderr(disable=self.verbose):
-                    embed = (
-                        self._llava_cpp.llava_image_embed_make_with_bytes(
-                            self.clip_ctx,
-                            llama.context_params.n_threads_batch,
-                            (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
-                            len(image_bytes),
-                        )
-                    )
-                def free_embed():
-                    with suppress_stdout_stderr(disable=self.verbose):
-                        self._llava_cpp.llava_image_embed_free(embed)
-                exit_stack.callback(free_embed)
+                embed = embed_image_bytes(image_bytes)
                 if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                 n_past = ctypes.c_int(llama.n_tokens)
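The saving shows up in multi-turn vision chats, where each create_chat_completion call re-evaluates the full message list, image included; with the cache, only the first call pays the CLIP encoding cost. A usage sketch, assuming local LLaVA model files (both paths below are placeholders):

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

chat_handler = Llava15ChatHandler(clip_model_path="mmproj-model-f16.gguf")  # placeholder path
llm = Llama(
    model_path="llava-v1.5-7b.Q4_K_M.gguf",  # placeholder path
    chat_handler=chat_handler,
    n_ctx=2048,  # large enough to hold the image embedding
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "file:///tmp/example.png"}},
            {"type": "text", "text": "What is in this image?"},
        ],
    },
]

llm.create_chat_completion(messages=messages)  # embeds the image via CLIP
llm.create_chat_completion(messages=messages)  # same bytes: embed served from the cache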
