20
20
from transformers import AutoProcessor
21
21
from safetensors import safe_open
22
22
from transformers .utils import TensorType
23
+ from lightllm .server .multimodal_params import MultimodalParams , ImageItem
23
24
from lightllm .models .qwen2_vl .qwen2_visual import PatchEmbed , VisionRotaryEmbedding
24
25
25
26
# adapted from
@@ -509,17 +510,17 @@ def load_model(self, weight_dir):
509
510
510
511
self .load_state_dict (weight_dict )
511
512
512
- def encode (self , image_uuids : List ):
513
+ def encode (self , images : List [ ImageItem ] ):
513
514
img_tensors = []
514
515
valid_ids = []
515
516
valid_id = 0
516
517
img_grids = []
517
518
uuids = []
518
519
519
- for i , url in enumerate (image_uuids ):
520
- if isinstance (url , int ):
521
- uuids .append (url )
522
- image_data = read_shm (get_shm_name_data (url ))
520
+ for i , img in enumerate (images ):
521
+ if isinstance (img , ImageItem ):
522
+ uuids .append (img . uuid )
523
+ image_data = read_shm (get_shm_name_data (img . uuid ))
523
524
image_data = Image .open (BytesIO (image_data ))
524
525
image_data = get_image (image_data )
525
526
image_inputs = self .processor .preprocess (images = image_data , return_tensors = "pt" )
@@ -528,7 +529,7 @@ def encode(self, image_uuids: List):
528
529
img_tensors .append (pixel_values )
529
530
img_grids .append (image_grid_thw )
530
531
else :
531
- raise Exception ("Unsupport input types: {} for {}" .format (type (url ), url ))
532
+ raise Exception ("Unsupport input types: {} for {}" .format (type (img ), img ))
532
533
533
534
# must devide merge_length
534
535
cur_num = img_tensors [- 1 ].shape [0 ] // (self .spatial_merge_size ** 2 )
0 commit comments