Commit 45354f7

fix mplug-owl3 infer (#2175)

1 parent 0c2294a commit 45354f7


swift/llm/utils/template.py

Lines changed: 11 additions & 8 deletions
@@ -3560,6 +3560,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         labels = inputs['labels']
         idx_list = _findall(input_ids, -100)
         processor = self.tokenizer.processor
+        inputs = {'_data': {}}
         if images:
             image_inputs = processor.image_processor(images, cut_enable=cut_enable, return_tensors='pt')
             added_tokens_len = 0
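
The new `inputs = {'_data': {}}` line pre-creates the `_data` payload before the `if images:` branch, so text-only samples also hand a well-formed dict to `_post_encode` below. For orientation, `idx_list` collects the positions of the `-100` image placeholders in `input_ids`; a minimal sketch of that lookup, using a hypothetical stand-in for swift's `_findall` helper (its exact signature is an assumption here):

# Hypothetical stand-in for swift's _findall helper: return every index
# at which the image placeholder token (-100) occurs in input_ids.
def findall(token_ids, target):
    return [i for i, t in enumerate(token_ids) if t == target]

input_ids = [1, -100, 5, 6, -100, 7]
print(findall(input_ids, -100))  # [1, 4] -> two image placeholders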
@@ -3579,21 +3580,23 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             _range = torch.arange(len(input_ids))[:, None]
             matrix = (_range > image_token_idx).sum(dim=1)
             media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
-            inputs['_data'] = {'pixel_values': image_inputs['pixel_values']}
-            inputs['media_offset'] = media_offset
-            inputs['num_images'] = image_inputs['pixel_values'].shape[0]
-        inputs['input_ids'] = input_ids
+            inputs['_data'].update({
+                'pixel_values': image_inputs['pixel_values'],
+                'media_offset': media_offset,
+            })
+        inputs['_data']['input_ids'] = input_ids
         inputs['labels'] = labels
         return inputs, {}

     def _post_encode(self, model, data: Any) -> Dict[str, Any]:
-        image_embeds = model.forward_image(data['pixel_values'])
-        return {'image_embeds': image_embeds}
+        if 'pixel_values' in data:
+            pixel_values = data.pop('pixel_values')
+            data['image_embeds'] = model.forward_image(pixel_values)
+        return data

     def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
         res = super().data_collator(batch, padding_to)
         image_embeds = [b['image_embeds'] for b in batch if 'image_embeds' in b]
-        num_images = [b['num_images'] if 'num_images' in b else 0 for b in batch]
         if image_embeds:
             res['image_embeds'] = torch.concat(image_embeds)
             media_offset = []
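
Two things happen in this hunk. First, `media_offset` records, for each token position, how many image slots precede it (stacked with a zero column and given a leading batch dim), and it now travels inside `_data` next to `pixel_values` and `input_ids` instead of at the top level. A toy run of that computation, assuming the expanded `<|image|>` tokens sit at positions 1 and 4 of a 6-token sequence:

import torch

image_token_idx = torch.tensor([1, 4])          # expanded <|image|> positions
_range = torch.arange(6)[:, None]               # (6, 1) column of token positions
matrix = (_range > image_token_idx).sum(dim=1)  # images strictly before each token
media_offset = torch.stack(
    [torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
print(matrix.tolist())     # [0, 0, 1, 1, 1, 2]
print(media_offset.shape)  # torch.Size([1, 6, 2])

Second, the old `_post_encode` returned only `{'image_embeds': ...}`, discarding everything else in `data`, and its unconditional `data['pixel_values']` lookup would raise a `KeyError` on text-only inputs; the rewrite pops `pixel_values` only when present and passes the rest of the dict through.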
@@ -3609,7 +3612,7 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
                     curr_media_offset.shape[2])
                 curr_media_offset = torch.concat([curr_media_offset, padding], dim=1)
             media_offset.append(curr_media_offset + cusum_offset)
-            cusum_offset += num_images[bi]
+            cusum_offset += image_embeds[bi].shape[0]

        # media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]

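The collator fix: each sample's `media_offset` is shifted by the number of image embeddings contributed by the samples before it, so the offsets index correctly into the concatenated `res['image_embeds']`. Since `num_images` no longer exists in the encoded inputs after the refactor above, the count is taken from the embeddings themselves. A toy illustration (tensor shapes are assumptions based on the diff):

import torch

# Two samples contributing 3 and 2 image embeddings respectively.
image_embeds = [torch.zeros(3, 8), torch.zeros(2, 8)]
media_offsets = [torch.tensor([[[0, 1], [0, 2]]]),  # (1, seq_len, 2) per sample
                 torch.tensor([[[0, 1], [0, 2]]])]

cusum_offset, shifted = 0, []
for bi, mo in enumerate(media_offsets):
    shifted.append(mo + cusum_offset)
    # Counting from image_embeds (not a separate num_images field) keeps the
    # running offset in sync with the embeddings actually concatenated.
    cusum_offset += image_embeds[bi].shape[0]

print(shifted[1].tolist())  # [[[3, 4], [3, 5]]] -- shifted past sample 0's 3 images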