@@ -1,4 +1,3 @@
-import math
 from array import array
 from dataclasses import dataclass, fields
 from itertools import tee
@@ -15,11 +14,12 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import merge_multimodal_embeddings
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         tokenizer_mode=ctx.model_config.tokenizer_mode)
-    mm_encoder = tokenizer.instruct.mm_encoder
 
-    mm_config = ctx.model_config.multimodal_config
-    max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
+    mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+    patch_size = mm_encoder.mm_config.image_patch_size
+    image_token_id = mm_encoder.special_ids.img
 
-    # approximate image size
-    size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
+    mm_config = ctx.model_config.multimodal_config
+    num_images = mm_config.limit_per_prompt.get("image", 1)
 
+    # dummy size
+    size = 256
     image = Image.new("RGB", (size, size), color=0)
-    img_chunk = ImageChunk(image=image)
 
-    tokens = mm_encoder(img_chunk).tokens
-    token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                                                   tokens)
+    image_feature_size = (size**2) // (patch_size**2)
+
+    num_image_tokens = image_feature_size * num_images
+
+    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                      [image_token_id]) * num_image_tokens
+    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
+                       [0]) * (seq_len - num_image_tokens)
 
     seq_data = SequenceData(token_ids)
-    mm_data = {"image": max_num_images_per_request * [image]}
+    mm_data = {"image": num_images * [image]}
     return seq_data, mm_data
 
 
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})
 
 
-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                image_features: Optional[List[torch.Tensor]],
-                                image_id: int) -> torch.Tensor:
-    text_locations = input_ids != image_id
-    image_locations = input_ids == image_id
-
-    seq_len = input_ids.shape[0]
+def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is not None and "image" in multi_modal_data:
+        tokenizer = cached_get_tokenizer(
+            ctx.model_config.tokenizer,
+            tokenizer_mode=ctx.model_config.tokenizer_mode)
 
-    N_txt = text_locations.sum().item()
-    _, D_txt = inputs_embeds.shape
-    N_img, D_img = image_features.shape
+        mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
+        image_token_id = mm_encoder.special_ids.img
 
-    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
-                              "to image features dim {D_img}")
-    assert (seq_len == N_txt +
-            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
-                     f"{(N_txt, N_img, image_locations.sum().item())}")
+        if image_token_id not in llm_inputs['prompt_token_ids']:
+            raise ValueError(
+                (f"You've passed {llm_inputs=} without {image_token_id=}."
+                 " Make sure to process your input via mistral_common's"
+                 " tokenizer or pass a chat completion request."
+                 " For more info, see: "
+                 "https://github.com/vllm-project/vllm/issues/8411."))
 
-    inputs_embeds[image_locations, :] = image_features
-    return inputs_embeds
+    return llm_inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
 class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def __init__(self,
@@ -201,11 +206,21 @@ def _parse_and_validate_image_input(
             return None
 
         if isinstance(images, torch.Tensor):
-            # always take last images
-            images = [images[-1][i] for i in range(images.size(1))]
+            # if passed as batch take all images
+            N, B, C, W, H = images.shape
+            images = images.reshape(N * B, C, W, H)
+            images = [images[i] for i in range(images.size(0))]
         elif isinstance(images, list):
-            # always take last images
-            images = [images[-1][i] for i in range(len(images[0]))]
+            # if passed as list flatten lists of tensors
+            flatten_images = []
+            for imgs_per_req in images:
+                imgs_per_req = [
+                    imgs_per_req[i] for i in range(imgs_per_req.size(0))
+                ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
+
+                flatten_images.extend(imgs_per_req)
+
+            images = flatten_images
 
         return images
 
|