
Commit 7328512

[Hotfix][Pixtral] Fix multiple images bugs (vllm-project#8415)

patrickvonplaten authored and garg-amit committed
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>

1 parent: 883ae2a

File tree

4 files changed: +195 -76 lines changed

Two binary files (20.4 KB each), not shown. These appear to be the logprob fixtures loaded by the new tests: tests/models/fixtures/pixtral_chat.pickle and tests/models/fixtures/pixtral_chat_engine.pickle.

tests/models/test_pixtral.py

Lines changed: 146 additions & 42 deletions

@@ -2,63 +2,167 @@

 Run `pytest tests/models/test_mistral.py`.
 """
+import pickle
+import uuid
+from typing import Any, Dict, List
+
 import pytest
+from mistral_common.protocol.instruct.messages import ImageURLChunk
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
+
+from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
+from vllm.multimodal import MultiModalDataBuiltins

-from vllm.sampling_params import SamplingParams
+from .utils import check_logprobs_close

 pytestmark = pytest.mark.vlm

 MODELS = ["mistralai/Pixtral-12B-2409"]
+IMG_URLS = [
+    "https://picsum.photos/id/237/400/300",
+    "https://picsum.photos/id/231/200/300",
+    "https://picsum.photos/id/27/500/500",
+    "https://picsum.photos/id/17/150/600",
+]
+PROMPT = "Describe each image in one short sentence."
+
+
+def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
+    return [{
+        "role":
+        "user",
+        "content": [{
+            "type": "text",
+            "text": PROMPT,
+        }] + [{
+            "type": "image_url",
+            "image_url": {
+                "url": url
+            }
+        } for url in urls],
+    }]
+
+
+def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
+    msg = _create_msg_format(urls)
+
+    tokenizer = MistralTokenizer.from_model("pixtral")
+
+    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
+    tokenized = tokenizer.encode_chat_completion(request)
+
+    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)
+
+    images = []
+    for chunk in request.messages[0].content:
+        if isinstance(chunk, ImageURLChunk):
+            images.append(image_from_chunk(chunk))
+
+    mm_data = MultiModalDataBuiltins(image=images)
+    engine_inputs["multi_modal_data"] = mm_data
+
+    return engine_inputs
+
+
+MSGS = [
+    _create_msg_format(IMG_URLS[:1]),
+    _create_msg_format(IMG_URLS[:2]),
+    _create_msg_format(IMG_URLS),
+]
+ENGINE_INPUTS = [
+    _create_engine_inputs(IMG_URLS[:1]),
+    _create_engine_inputs(IMG_URLS[:2]),
+    _create_engine_inputs(IMG_URLS),
+]
+
+SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+LIMIT_MM_PER_PROMPT = dict(image=4)
+
+MAX_MODEL_LEN = [8192, 65536]
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+
+
+def load_logprobs(filename: str) -> Any:
+    with open(filename, 'rb') as f:
+        return pickle.load(f)


 @pytest.mark.skip(
     reason=
     "Model is too big, test passed on A100 locally but will OOM on CI machine."
 )
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
+def test_chat(
     vllm_runner,
-    example_prompts,
+    max_model_len: int,
     model: str,
     dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
 ) -> None:
-    image_urls = [
-        "https://picsum.photos/id/237/200/300",
-        "https://picsum.photos/seed/picsum/200/300"
-    ]
-    expected = [
-        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa
-        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa
-    ]
-    prompt = "Describe the image in one short sentence."
-
-    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
-
-    with vllm_runner(model, dtype=dtype,
-                     tokenizer_mode="mistral") as vllm_model:
-
-        for i, image_url in enumerate(image_urls):
-            messages = [
-                {
-                    "role":
-                    "user",
-                    "content": [{
-                        "type": "text",
-                        "text": prompt
-                    }, {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": image_url
-                        }
-                    }]
-                },
-            ]
-
-            outputs = vllm_model.model.chat(messages,
-                                            sampling_params=sampling_params)
-            assert outputs[0].outputs[0].text == expected[i]
+    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            enable_chunked_prefill=False,
+            max_model_len=max_model_len,
+            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+    ) as vllm_model:
+        outputs = []
+        for msg in MSGS:
+            output = vllm_model.model.chat(msg,
+                                           sampling_params=SAMPLING_PARAMS)
+
+            outputs.extend(output)
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
+
+
+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on A100 locally but will OOM on CI machine."
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
+    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    args = EngineArgs(
+        model=model,
+        tokenizer_mode="mistral",
+        enable_chunked_prefill=False,
+        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+        dtype=dtype,
+    )
+    engine = LLMEngine.from_engine_args(args)
+
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
+
+    outputs = []
+    count = 0
+    while True:
+        out = engine.step()
+        count += 1
+        for request_output in out:
+            if request_output.finished:
+                outputs.append(request_output)
+
+        if count == 2:
+            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
+                               SAMPLING_PARAMS)
+        if not engine.has_unfinished_requests():
+            break
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
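The rewritten chat test now covers prompts with one, two, and four images and compares logprobs against pickled H100 reference outputs rather than exact strings. Outside the test harness, the same multi-image request can be reproduced with the minimal sketch below; it assumes local access to the mistralai/Pixtral-12B-2409 weights and enough GPU memory, and it mirrors the message layout built by _create_msg_format above.

    # Minimal multi-image chat sketch (assumption: the Pixtral weights fit on
    # the local GPU; the message layout mirrors the test above).
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="mistralai/Pixtral-12B-2409",
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": 4},  # allow up to four images per prompt
    )

    urls = [
        "https://picsum.photos/id/237/400/300",
        "https://picsum.photos/id/231/200/300",
    ]
    messages = [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": "Describe each image in one short sentence.",
        }] + [{
            "type": "image_url",
            "image_url": {"url": u},
        } for u in urls],
    }]

    outputs = llm.chat(messages,
                       sampling_params=SamplingParams(max_tokens=512,
                                                      temperature=0.0))
    print(outputs[0].outputs[0].text)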

vllm/model_executor/models/pixtral.py

Lines changed: 49 additions & 34 deletions

@@ -1,4 +1,3 @@
-import math
 from array import array
 from dataclasses import dataclass, fields
 from itertools import tee
@@ -15,11 +14,12 @@

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         tokenizer_mode=ctx.model_config.tokenizer_mode)
-    mm_encoder = tokenizer.instruct.mm_encoder

-    mm_config = ctx.model_config.multimodal_config
-    max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
+    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+    patch_size = mm_encoder.mm_config.image_patch_size
+    image_token_id = mm_encoder.special_ids.img

-    # approximate image size
-    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
+    mm_config = ctx.model_config.multimodal_config
+    num_images = mm_config.limit_per_prompt.get("image", 1)

+    # dummy size
+    size = 256
     image = Image.new("RGB", (size, size), color=0)
-    img_chunk = ImageChunk(image=image)

-    tokens = mm_encoder(img_chunk).tokens
-    token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                                   tokens)
+    image_feature_size = (size**2) // (patch_size**2)
+
+    num_image_tokens = image_feature_size * num_images
+
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * num_image_tokens
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - num_image_tokens)

     seq_data = SequenceData(token_ids)
-    mm_data = {"image": max_num_images_per_request * [image]}
+    mm_data = {"image": num_images * [image]}
     return seq_data, mm_data


@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})


-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                image_features: Optional[List[torch.Tensor]],
-                                image_id: int) -> torch.Tensor:
-    text_locations = input_ids != image_id
-    image_locations = input_ids == image_id
-
-    seq_len = input_ids.shape[0]
+def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is not None and "image" in multi_modal_data:
+        tokenizer = cached_get_tokenizer(
+            ctx.model_config.tokenizer,
+            tokenizer_mode=ctx.model_config.tokenizer_mode)

-    N_txt = text_locations.sum().item()
-    _, D_txt = inputs_embeds.shape
-    N_img, D_img = image_features.shape
+        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+        image_token_id = mm_encoder.special_ids.img

-    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
-                              "to image features dim {D_img}")
-    assert (seq_len == N_txt +
-            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
-                     f"{(N_txt, N_img, image_locations.sum().item())}")
+        if image_token_id not in llm_inputs['prompt_token_ids']:
+            raise ValueError(
+                (f"You've passed {llm_inputs=} without {image_token_id=}"
+                 " Make sure to process your input via mistral_common's"
+                 " tokenizer or pass a chat completion request. For more"
+                 " For more info, see: "
+                 "https://github.com/vllm-project/vllm/issues/8411."))

-    inputs_embeds[image_locations, :] = image_features
-    return inputs_embeds
+    return llm_inputs


 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):

     def __init__(self,
@@ -201,11 +206,21 @@ def _parse_and_validate_image_input(
             return None

         if isinstance(images, torch.Tensor):
-            # always take last images
-            images = [images[-1][i] for i in range(images.size(1))]
+            # if passed as batch take all images
+            N, B, C, W, H = images.shape
+            images = images.reshape(N * B, C, W, H)
+            images = [images[i] for i in range(images.size(0))]
         elif isinstance(images, list):
-            # always take last images
-            images = [images[-1][i] for i in range(len(images[0]))]
+            # if passed as list flatten lists of tensors
+            flatten_images = []
+            for imgs_per_req in images:
+                imgs_per_req = [
+                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
+                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
+
+                flatten_images.extend(imgs_per_req)
+
+            images = flatten_images

         return images
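The decisive change for multiple images is the last hunk above: _parse_and_validate_image_input used to keep only images[-1] ("always take last images") and silently dropped the rest, whereas it now flattens every image from every request, whether they arrive as one stacked tensor or as ragged per-request lists. A standalone sketch of that flattening, using dummy tensors and the same N/B/C/W/H shape names as the diff, behaves like this:

    # Hedged sketch of the new flattening behavior with made-up shapes.
    import torch


    def flatten_images(images):
        if isinstance(images, torch.Tensor):
            # Batched tensor: N requests with B images each
            # -> N*B individual image tensors.
            n, b, c, w, h = images.shape
            images = images.reshape(n * b, c, w, h)
            return [images[i] for i in range(images.size(0))]
        # List input: each entry may be a stacked tensor or a list of tensors.
        flat = []
        for imgs_per_req in images:
            if isinstance(imgs_per_req, torch.Tensor):
                imgs_per_req = [
                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
                ]
            flat.extend(imgs_per_req)
        return flat


    batch = torch.zeros(2, 3, 3, 64, 64)  # 2 requests with 3 images each
    assert len(flatten_images(batch)) == 6  # all six images kept, not just the last request's

The tensor branch handles inputs the mapper could stack into one batch, and the list branch handles ragged per-request inputs, which is why both branches in the diff are rewritten.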
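The dummy_data_for_pixtral hunk changes profiling in the same spirit: instead of encoding a square image whose side is derived from math.sqrt(seq_len), it uses a fixed 256x256 dummy image and counts image tokens directly. Assuming an image_patch_size of 16 for Pixtral (an assumption about mm_encoder.mm_config, not something stated in the diff), the sizing works out as in this quick check:

    # Back-of-the-envelope check of the new dummy-data sizing.
    # patch_size=16 is an assumed value for Pixtral's mm_config.image_patch_size;
    # num_images=4 matches LIMIT_MM_PER_PROMPT in the test, seq_len=8192 the
    # smaller MAX_MODEL_LEN entry.
    size, patch_size, num_images, seq_len = 256, 16, 4, 8192
    image_feature_size = (size**2) // (patch_size**2)  # 256 tokens per dummy image
    num_image_tokens = image_feature_size * num_images  # 1024 image-token ids
    padding = seq_len - num_image_tokens  # 7168 trailing zero ids
    print(image_feature_size, num_image_tokens, padding)  # 256 1024 7168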