@@ -2193,6 +2193,8 @@ def __init__(self, clip_model_path: str, verbose: bool = False):
 
         self._llava_cpp = llava_cpp  # TODO: Fix
         self._exit_stack = ExitStack()
+        self._last_image_embed: Optional[llava_cpp.CtypesPointer[llava_cpp.llava_image_embed]] = None
+        self._last_image_hash: Optional[int] = None
 
         if not os.path.exists(clip_model_path):
             raise ValueError(f"Clip model path does not exist: {clip_model_path}")
@@ -2212,6 +2214,14 @@ def clip_free():
                 self._llava_cpp.clip_free(self.clip_ctx)
 
         self._exit_stack.callback(clip_free)
+
+        def last_image_embed_free():
+            with suppress_stdout_stderr(disable=self.verbose):
+                if self._last_image_embed is not None:
+                    self._llava_cpp.llava_image_embed_free(self._last_image_embed)
+                    self._last_image_embed = None
+
+        self._exit_stack.callback(last_image_embed_free)
 
     def load_image(self, image_url: str) -> bytes:
         return self._load_image(image_url)
@@ -2270,6 +2280,22 @@ def __call__(
         text = template.render(messages=messages, add_generation_prompt=True)
         split_text = self.split_text_on_image_urls(text, image_urls)
 
+        def embed_image_bytes(image_bytes: bytes):
+            if self._last_image_embed is not None and self._last_image_hash is not None and hash(image_bytes) == self._last_image_hash:
+                return self._last_image_embed
+            with suppress_stdout_stderr(disable=self.verbose):
+                embed = (
+                    self._llava_cpp.llava_image_embed_make_with_bytes(
+                        self.clip_ctx,
+                        llama.context_params.n_threads_batch,
+                        (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
+                        len(image_bytes),
+                    )
+                )
+                self._last_image_embed = embed
+                self._last_image_hash = hash(image_bytes)
+                return embed
+
         # Evaluate prompt
         llama.reset()
         for i, (type_, value) in enumerate(split_text):
@@ -2280,20 +2306,7 @@ def __call__(
                 llama.eval(tokens)
             else:
                 image_bytes = self.load_image(value)
-                exit_stack = ExitStack()
-                with suppress_stdout_stderr(disable=self.verbose):
-                    embed = (
-                        self._llava_cpp.llava_image_embed_make_with_bytes(
-                            self.clip_ctx,
-                            llama.context_params.n_threads_batch,
-                            (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
-                            len(image_bytes),
-                        )
-                    )
-                def free_embed():
-                    with suppress_stdout_stderr(disable=self.verbose):
-                        self._llava_cpp.llava_image_embed_free(embed)
-                exit_stack.callback(free_embed)
+                embed = embed_image_bytes(image_bytes)
                 if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx():
                     raise ValueError("Prompt exceeds n_ctx")  # TODO: Fix
                 n_past = ctypes.c_int(llama.n_tokens)
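
The diff above replaces the per-image `ExitStack` allocate/free cycle with a single cached embedding keyed by `hash(image_bytes)`: prompts that reuse the same image bytes get `self._last_image_embed` back instead of re-running CLIP, and the cached embed is released by the `last_image_embed_free` callback on the handler's exit stack. Below is a minimal, self-contained sketch of that last-value caching pattern; `LastEmbedCache` and `_make_embed` are hypothetical stand-ins for the handler and for `llava_image_embed_make_with_bytes`, not the library's API.

```python
from typing import Optional, Tuple


class LastEmbedCache:
    """Hypothetical sketch of the caching pattern in the diff:
    keep only the most recent embedding, keyed by a hash of the raw bytes."""

    def __init__(self) -> None:
        self._last_image_embed: Optional[Tuple[float, ...]] = None
        self._last_image_hash: Optional[int] = None

    def embed_image_bytes(self, image_bytes: bytes) -> Tuple[float, ...]:
        # Same image bytes as last time: return the cached embedding.
        if (
            self._last_image_embed is not None
            and self._last_image_hash is not None
            and hash(image_bytes) == self._last_image_hash
        ):
            return self._last_image_embed
        # New image: compute, then remember it for the next call.
        embed = self._make_embed(image_bytes)
        self._last_image_embed = embed
        self._last_image_hash = hash(image_bytes)
        return embed

    def _make_embed(self, image_bytes: bytes) -> Tuple[float, ...]:
        # Placeholder "embedding" (simple byte statistics); the real handler
        # calls llava_image_embed_make_with_bytes through the C bindings.
        return (float(len(image_bytes)), sum(image_bytes) / max(len(image_bytes), 1))


cache = LastEmbedCache()
first = cache.embed_image_bytes(b"\x89PNG...fake...")
second = cache.embed_image_bytes(b"\x89PNG...fake...")
assert first is second  # second call is served from the cache, no recompute
```

The real handler also has to free the C-allocated embed, which is why the diff registers `last_image_embed_free` on the exit stack; a pure-Python sketch like this can rely on garbage collection instead.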