Skip to content

Commit 649a8c3

Browse files
mgoin authored and Yuqi Zhang committed
Fix PixtralHF missing spatial_merge_size (vllm-project#17571)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 61150cf commit 649a8c3

File tree

4 files changed

+18
-25
lines changed

4 files changed

+18
-25
lines changed

vllm/model_executor/models/llava.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -354,9 +354,8 @@ def _get_prompt_updates(
354354
image_token_id = hf_config.image_token_index
355355
image_end_id = vocab[processor.image_end_token]
356356

357-
vision_config = hf_config.vision_config
358-
assert isinstance(vision_config, PixtralVisionConfig)
359-
encoder_info = PixtralHFEncoderInfo(vision_config)
357+
assert isinstance(hf_config.vision_config, PixtralVisionConfig)
358+
encoder_info = PixtralHFEncoderInfo(hf_config)
360359

361360
def get_replacement(item_idx: int):
362361
images = mm_items.get_items("image", ImageProcessorItems)

vllm/model_executor/models/mistral3.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -272,12 +272,8 @@ def _get_prompt_updates(
272272
image_token_id = hf_config.image_token_index
273273
image_end_id = vocab[processor.image_end_token]
274274

275-
vision_config = hf_config.vision_config
276-
assert isinstance(vision_config, PixtralVisionConfig)
277-
# Need to sneak in spatial_merge_size for Mistral3
278-
vision_config.spatial_merge_size = getattr(hf_config,
279-
"spatial_merge_size", 1)
280-
encoder_info = PixtralHFEncoderInfo(vision_config)
275+
assert isinstance(hf_config.vision_config, PixtralVisionConfig)
276+
encoder_info = PixtralHFEncoderInfo(hf_config)
281277

282278
def get_replacement(item_idx: int):
283279
images = mm_items.get_items("image", ImageProcessorItems)

vllm/model_executor/models/pixtral.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -916,8 +916,9 @@ def get_image_size(self) -> int:
916916
return self.vision_config.image_size
917917

918918
def get_patch_size(self) -> int:
919-
return (self.vision_config.patch_size *
920-
self.vision_config.spatial_merge_size)
919+
# spatial_merge_size is needed for Mistral3
920+
spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1)
921+
return self.vision_config.patch_size * spatial_merge_size
921922

922923
def get_patch_grid_length(self) -> int:
923924
image_size, patch_size = self.get_image_size(), self.get_patch_size()

vllm/model_executor/models/vision.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,11 @@
1919

2020
class VisionEncoderInfo(ABC, Generic[_C]):
2121

22-
def __init__(self, vision_config: _C) -> None:
22+
def __init__(self, hf_config: _C) -> None:
2323
super().__init__()
2424

25-
self.vision_config = vision_config
25+
self.hf_config = hf_config
26+
self.vision_config = hf_config.vision_config
2627

2728
@abstractmethod
2829
def get_num_image_tokens(
@@ -57,18 +58,14 @@ def get_vision_encoder_info(
5758
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
5859
from .siglip import SiglipEncoderInfo, SiglipVisionConfig
5960

60-
vision_config = hf_config.vision_config
61-
if isinstance(vision_config, CLIPVisionConfig):
62-
return CLIPEncoderInfo(vision_config)
63-
if isinstance(vision_config, PixtralVisionConfig):
64-
# Need to sneak in spatial_merge_size for Mistral3
65-
vision_config.spatial_merge_size = getattr(hf_config,
66-
"spatial_merge_size", 1)
67-
return PixtralHFEncoderInfo(vision_config)
68-
if isinstance(vision_config, SiglipVisionConfig):
69-
return SiglipEncoderInfo(vision_config)
70-
71-
msg = f"Unsupported vision config: {type(vision_config)}"
61+
if isinstance(hf_config.vision_config, CLIPVisionConfig):
62+
return CLIPEncoderInfo(hf_config)
63+
if isinstance(hf_config.vision_config, PixtralVisionConfig):
64+
return PixtralHFEncoderInfo(hf_config)
65+
if isinstance(hf_config.vision_config, SiglipVisionConfig):
66+
return SiglipEncoderInfo(hf_config)
67+
68+
msg = f"Unsupported vision config: {type(hf_config.vision_config)}"
7269
raise NotImplementedError(msg)
7370

7471

0 commit comments

Comments
 (0)