@@ -3560,6 +3560,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
3560
3560
labels = inputs ['labels' ]
3561
3561
idx_list = _findall (input_ids , - 100 )
3562
3562
processor = self .tokenizer .processor
3563
+ inputs = {'_data' : {}}
3563
3564
if images :
3564
3565
image_inputs = processor .image_processor (images , cut_enable = cut_enable , return_tensors = 'pt' )
3565
3566
added_tokens_len = 0
@@ -3579,21 +3580,23 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
3579
3580
_range = torch .arange (len (input_ids ))[:, None ]
3580
3581
matrix = (_range > image_token_idx ).sum (dim = 1 )
3581
3582
media_offset = torch .stack ([torch .zeros (matrix .shape [0 ], dtype = torch .long ), matrix ], dim = - 1 )[None ]
3582
- inputs ['_data' ] = {'pixel_values' : image_inputs ['pixel_values' ]}
3583
- inputs ['media_offset' ] = media_offset
3584
- inputs ['num_images' ] = image_inputs ['pixel_values' ].shape [0 ]
3585
- inputs ['input_ids' ] = input_ids
3583
+ inputs ['_data' ].update ({
3584
+ 'pixel_values' : image_inputs ['pixel_values' ],
3585
+ 'media_offset' : media_offset ,
3586
+ })
3587
+ inputs ['_data' ]['input_ids' ] = input_ids
3586
3588
inputs ['labels' ] = labels
3587
3589
return inputs , {}
3588
3590
3589
3591
def _post_encode (self , model , data : Any ) -> Dict [str , Any ]:
3590
- image_embeds = model .forward_image (data ['pixel_values' ])
3591
- return {'image_embeds' : image_embeds }
3592
+ if 'pixel_values' in data :
3593
+ pixel_values = data .pop ('pixel_values' )
3594
+ data ['image_embeds' ] = model .forward_image (pixel_values )
3595
+ return data
3592
3596
3593
3597
def data_collator (self , batch : List [Dict [str , Any ]], padding_to : Optional [int ] = None ) -> Dict [str , Any ]:
3594
3598
res = super ().data_collator (batch , padding_to )
3595
3599
image_embeds = [b ['image_embeds' ] for b in batch if 'image_embeds' in b ]
3596
- num_images = [b ['num_images' ] if 'num_images' in b else 0 for b in batch ]
3597
3600
if image_embeds :
3598
3601
res ['image_embeds' ] = torch .concat (image_embeds )
3599
3602
media_offset = []
@@ -3609,7 +3612,7 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
3609
3612
curr_media_offset .shape [2 ])
3610
3613
curr_media_offset = torch .concat ([curr_media_offset , padding ], dim = 1 )
3611
3614
media_offset .append (curr_media_offset + cusum_offset )
3612
- cusum_offset += num_images [bi ]
3615
+ cusum_offset += image_embeds [bi ]. shape [ 0 ]
3613
3616
3614
3617
# media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]
3615
3618
0 commit comments