Skip to content

Commit 1df43de

Browse files
authored
[bug fix] Fix llava next feature size calculation. (#6339)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
1 parent 52b7fcb commit 1df43de

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

tests/models/test_llava_next.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import List, Optional, Tuple
22

33
import pytest
4-
from transformers import AutoTokenizer
4+
from transformers import AutoConfig, AutoTokenizer
55

6+
from vllm.model_executor.models.llava_next import (
7+
get_llava_next_image_feature_size)
68
from vllm.multimodal.utils import rescale_image_size
79
from vllm.sequence import SampleLogprobs
810

@@ -120,3 +122,13 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
120122
name_0="hf",
121123
name_1="vllm",
122124
)
125+
126+
127+
@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
128+
(183, 488, 776)])
129+
def test_image_feature_size(height_and_width_and_result):
130+
height, width, result = height_and_width_and_result
131+
config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
132+
assert get_llava_next_image_feature_size(config,
133+
input_height=height,
134+
input_width=width) == result

vllm/model_executor/models/llava_next.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,19 +74,21 @@ def _get_llava_next_num_unpadded_features(
7474
) -> Tuple[int, int]:
7575
current_height = npatches * num_patch_height
7676
current_width = npatches * num_patch_width
77+
current_height = torch.tensor(current_height).to("cuda")
78+
current_width = torch.tensor(current_width).to("cuda")
7779

7880
aspect_ratio: float = width / height
7981
current_aspect_ratio: float = current_width / current_height
8082
if aspect_ratio > current_aspect_ratio:
81-
new_height = (height * current_width) // width
82-
if new_height % 2 == 1:
83-
new_height += 1
84-
current_height = new_height
83+
scale_factor = current_width / width
84+
new_height = int(height * scale_factor)
85+
padding = (current_height - new_height) // 2
86+
current_height -= padding * 2
8587
else:
86-
new_width = (width * current_height) // height
87-
if new_width % 2 == 1:
88-
new_width += 1
89-
current_width = new_width
88+
scale_factor = current_height / height
89+
new_width = int(width * scale_factor)
90+
padding = (current_width - new_width) // 2
91+
current_width -= padding * 2
9092

9193
unpadded_features = current_height * current_width
9294
newline_features = current_height

0 commit comments

Comments
 (0)