|
| 1 | +import types |
| 2 | +from typing import List, Optional, Type |
| 3 | + |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | +from huggingface_hub import snapshot_download |
| 7 | +from PIL.Image import Image |
| 8 | + |
| 9 | +from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END, |
| 10 | + IMG_START, |
| 11 | + image_to_pixel_values) |
| 12 | +from vllm.multimodal.utils import rescale_image_size |
| 13 | +from vllm.utils import is_cpu |
| 14 | + |
| 15 | +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets |
| 16 | +from .utils import check_logprobs_close |
| 17 | + |
| 18 | +pytestmark = pytest.mark.vlm |
| 19 | + |
| 20 | +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ |
| 21 | + "stop_sign": |
| 22 | + "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 |
| 23 | + "cherry_blossom": |
| 24 | + "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 |
| 25 | +}) |
| 26 | + |
| 27 | +# we use snapshot_download to prevent conflicts between |
| 28 | +# dynamic_module and trust_remote_code for hf_runner |
| 29 | +models = [ |
| 30 | + snapshot_download("OpenGVLab/InternVL2-1B"), |
| 31 | + snapshot_download("OpenGVLab/InternVL2-2B"), |
| 32 | + # snapshot_download("OpenGVLab/InternVL2-4B"), # broken |
| 33 | +] |
| 34 | + |
| 35 | + |
| 36 | +class InternVLProcessor: |
| 37 | + """A simple processor for InternVL2 HF model which misses a processor.""" |
| 38 | + |
| 39 | + def __init__(self, hf_runner: HfRunner): |
| 40 | + self.num_image_token = hf_runner.model.num_image_token |
| 41 | + self.tokenizer = hf_runner.tokenizer |
| 42 | + self.dtype = hf_runner.model.dtype |
| 43 | + |
| 44 | + def __call__(self, text: str, images: Image, **kwargs): |
| 45 | + pixel_values = image_to_pixel_values(images).to(self.dtype) |
| 46 | + num_patches_list = [pixel_values.shape[0]] |
| 47 | + for num_patches in num_patches_list: |
| 48 | + context_tokens = IMG_CONTEXT * self.num_image_token * num_patches |
| 49 | + image_tokens = IMG_START + context_tokens + IMG_END |
| 50 | + text = text.replace('<image>', image_tokens, 1) |
| 51 | + prompt = self.tokenizer(text, return_tensors="pt") |
| 52 | + prompt.update({"pixel_values": pixel_values}) |
| 53 | + return prompt |
| 54 | + |
| 55 | + |
| 56 | +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py |
| 57 | +def generate( |
| 58 | + self, |
| 59 | + pixel_values: torch.FloatTensor, |
| 60 | + input_ids: torch.FloatTensor, |
| 61 | + attention_mask: Optional[torch.LongTensor] = None, |
| 62 | + **generate_kwargs, |
| 63 | +) -> torch.LongTensor: |
| 64 | + """Generate method for InternVL2 model without fixed use_cache.""" |
| 65 | + assert self.img_context_token_id is not None |
| 66 | + vit_embeds = self.extract_feature(pixel_values) |
| 67 | + input_embeds = self.language_model.get_input_embeddings()(input_ids) |
| 68 | + B, N, C = input_embeds.shape |
| 69 | + input_embeds = input_embeds.reshape(B * N, C) |
| 70 | + |
| 71 | + input_ids = input_ids.reshape(B * N) |
| 72 | + selected = (input_ids == self.img_context_token_id) |
| 73 | + assert selected.sum() != 0 |
| 74 | + input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) |
| 75 | + |
| 76 | + input_embeds = input_embeds.reshape(B, N, C) |
| 77 | + |
| 78 | + outputs = self.language_model.generate( |
| 79 | + inputs_embeds=input_embeds, |
| 80 | + attention_mask=attention_mask, |
| 81 | + **generate_kwargs, |
| 82 | + ) |
| 83 | + |
| 84 | + return outputs |
| 85 | + |
| 86 | + |
| 87 | +def run_test( |
| 88 | + hf_runner: Type[HfRunner], |
| 89 | + vllm_runner: Type[VllmRunner], |
| 90 | + image_assets: _ImageAssets, |
| 91 | + model: str, |
| 92 | + *, |
| 93 | + size_factors: List[float], |
| 94 | + dtype: str, |
| 95 | + max_tokens: int, |
| 96 | + num_logprobs: int, |
| 97 | + tensor_parallel_size: int, |
| 98 | + distributed_executor_backend: Optional[str] = None, |
| 99 | +): |
| 100 | + """Inference result should be the same between hf and vllm. |
| 101 | +
|
| 102 | + All the image fixtures for the test is under tests/images. |
| 103 | + For huggingface runner, we provide the PIL images as input. |
| 104 | + For vllm runner, we provide MultiModalDataDict objects |
| 105 | + and corresponding vision language config as input. |
| 106 | + Note, the text input is also adjusted to abide by vllm contract. |
| 107 | + The text output is sanitized to be able to compare with hf. |
| 108 | + """ |
| 109 | + images = [asset.pil_image for asset in image_assets] |
| 110 | + |
| 111 | + inputs_per_image = [( |
| 112 | + [prompt for _ in size_factors], |
| 113 | + [rescale_image_size(image, factor) for factor in size_factors], |
| 114 | + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] |
| 115 | + |
| 116 | + # NOTE: take care of the order. run vLLM first, and then run HF. |
| 117 | + # vLLM needs a fresh new process without cuda initialization. |
| 118 | + # if we run HF first, the cuda initialization will be done and it |
| 119 | + # will hurt multiprocessing backend with fork method (the default method). |
| 120 | + |
| 121 | + # max_model_len should be greater than image_feature_size |
| 122 | + with vllm_runner(model, |
| 123 | + max_model_len=4096, |
| 124 | + dtype=dtype, |
| 125 | + tensor_parallel_size=tensor_parallel_size, |
| 126 | + distributed_executor_backend=distributed_executor_backend, |
| 127 | + enforce_eager=True) as vllm_model: |
| 128 | + vllm_outputs_per_image = [ |
| 129 | + vllm_model.generate_greedy_logprobs(prompts, |
| 130 | + max_tokens, |
| 131 | + num_logprobs=num_logprobs, |
| 132 | + images=images) |
| 133 | + for prompts, images in inputs_per_image |
| 134 | + ] |
| 135 | + |
| 136 | + with hf_runner(model, dtype=dtype) as hf_model: |
| 137 | + img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids( |
| 138 | + "<IMG_CONTEXT>") |
| 139 | + hf_model.model.img_context_token_id = img_context_token_id |
| 140 | + hf_model.processor = InternVLProcessor(hf_model) |
| 141 | + hf_model.model.get_output_embeddings = lambda: \ |
| 142 | + hf_model.model.language_model.get_output_embeddings() |
| 143 | + hf_model.model.generate = types.MethodType(generate, hf_model.model) |
| 144 | + eos_token_id = hf_model.tokenizer.eos_token_id |
| 145 | + hf_outputs_per_image = [ |
| 146 | + hf_model.generate_greedy_logprobs_limit(prompts, |
| 147 | + max_tokens, |
| 148 | + num_logprobs=num_logprobs, |
| 149 | + images=hf_images, |
| 150 | + eos_token_id=eos_token_id) |
| 151 | + for prompts, hf_images in inputs_per_image |
| 152 | + ] |
| 153 | + |
| 154 | + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, |
| 155 | + vllm_outputs_per_image): |
| 156 | + # TODO: Check whether using original CLIPVisionModel can improve |
| 157 | + # consistency against HF |
| 158 | + check_logprobs_close( |
| 159 | + outputs_0_lst=hf_outputs, |
| 160 | + outputs_1_lst=vllm_outputs, |
| 161 | + name_0="hf", |
| 162 | + name_1="vllm", |
| 163 | + ) |
| 164 | + |
| 165 | + |
| 166 | +target_dtype = "half" |
| 167 | +if is_cpu(): |
| 168 | + target_dtype = "bfloat16" |
| 169 | + |
| 170 | + |
| 171 | +@pytest.mark.parametrize("model", models) |
| 172 | +@pytest.mark.parametrize( |
| 173 | + "size_factors", |
| 174 | + [ |
| 175 | + # No image |
| 176 | + [], |
| 177 | + # Single-scale |
| 178 | + [1.0], |
| 179 | + # Single-scale, batched |
| 180 | + [1.0, 1.0, 1.0], |
| 181 | + # Multi-scale |
| 182 | + [0.25, 0.5, 1.0], |
| 183 | + ], |
| 184 | +) |
| 185 | +@pytest.mark.parametrize("dtype", [target_dtype]) |
| 186 | +@pytest.mark.parametrize("max_tokens", [128]) |
| 187 | +@pytest.mark.parametrize("num_logprobs", [5]) |
| 188 | +@torch.inference_mode() |
| 189 | +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, |
| 190 | + dtype: str, max_tokens: int, num_logprobs: int) -> None: |
| 191 | + run_test( |
| 192 | + hf_runner, |
| 193 | + vllm_runner, |
| 194 | + image_assets, |
| 195 | + model, |
| 196 | + size_factors=size_factors, |
| 197 | + dtype=dtype, |
| 198 | + max_tokens=max_tokens, |
| 199 | + num_logprobs=num_logprobs, |
| 200 | + tensor_parallel_size=1, |
| 201 | + ) |
0 commit comments