@@ -415,15 +415,6 @@ class Phi4MMImagePixelInputs(TypedDict):
415
415
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
416
416
417
417
418
- class Phi4MMImageEmbeddingInputs (TypedDict ):
419
- type : Literal ["image_embeds" ]
420
- data : Union [torch .Tensor , list [torch .Tensor ]]
421
- """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
422
-
423
- `hidden_size` must match the hidden size of language model backbone.
424
- """
425
-
426
-
427
418
class Phi4MMAudioFeatureInputs (TypedDict ):
428
419
type : Literal ["audio_features" ]
429
420
data : Union [torch .Tensor , list [torch .Tensor ]]
@@ -436,7 +427,6 @@ class Phi4MMAudioEmbeddingInputs(TypedDict):
436
427
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)`"""
437
428
438
429
439
- Phi4MMImageInput = Union [Phi4MMImagePixelInputs , Phi4MMImageEmbeddingInputs ]
440
430
Phi4MMAudioInputs = Union [Phi4MMAudioFeatureInputs , Phi4MMAudioEmbeddingInputs ]
441
431
442
432
@@ -1112,15 +1102,13 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
1112
1102
1113
1103
def _process_image_input (
1114
1104
self , image_input : Phi4MMImagePixelInputs ) -> list [torch .Tensor ]:
1115
- if image_input ["type" ] == "image_embeds" :
1116
- image_embeds = image_input ["image_embeds" ].type (self .visual .dtype )
1117
- else :
1118
- dtype = next (self .vision_encoder .parameters ()).dtype
1119
- pixel_values = image_input ['data' ].to (dtype )
1120
- image_sizes = image_input ['image_sizes' ]
1121
- image_attention_mask = image_input ['image_attention_mask' ]
1122
- image_embeds = self .vision_encoder (pixel_values , image_sizes ,
1123
- image_attention_mask )
1105
+
1106
+ dtype = next (self .vision_encoder .parameters ()).dtype
1107
+ pixel_values = image_input ['data' ].to (dtype )
1108
+ image_sizes = image_input ['image_sizes' ]
1109
+ image_attention_mask = image_input ['image_attention_mask' ]
1110
+ image_embeds = self .vision_encoder (pixel_values , image_sizes ,
1111
+ image_attention_mask )
1124
1112
return image_embeds
1125
1113
1126
1114
def get_multimodal_embeddings (
0 commit comments