
Commit 7cbd9ec

Isotr0py and ywang96 authored
[Model] Initialize support for InternVL2 series models (#6514)
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent 3eeb148 commit 7cbd9ec

14 files changed: 1042 additions & 6 deletions


docs/source/models/supported_models.rst

Lines changed: 4 additions & 0 deletions
@@ -200,6 +200,10 @@ Vision Language Models
     - Fuyu
     - :code:`adept/fuyu-8b` etc.
     -
+  * - :code:`InternVLChatModel`
+    - InternVL2
+    - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
+    -
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
     - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
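As a usage note for the new table rows: the InternVL2 repositories ship custom modeling code, so loading one of these checkpoints requires trust_remote_code, mirroring the example script added in this commit. A minimal, hedged sketch (not part of the diff):

from vllm import LLM

# Sketch only: load one of the newly listed InternVL2 checkpoints.
# trust_remote_code=True is needed because the OpenGVLab repos rely on
# remote modeling code.
llm = LLM(model="OpenGVLab/InternVL2-8B", trust_remote_code=True)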

examples/offline_inference_vision_language.py

Lines changed: 15 additions & 0 deletions
@@ -106,6 +106,20 @@ def run_minicpmv(question):
     return llm, prompt
 
 
+# InternVL
+def run_internvl(question):
+    # Generally, InternVL can use chatml template for conversation
+    TEMPLATE = "<|im_start|>User\n{prompt}<|im_end|>\n<|im_start|>Assistant\n"
+    prompt = f"<image>\n{question}\n"
+    prompt = TEMPLATE.format(prompt=prompt)
+    llm = LLM(
+        model="OpenGVLab/InternVL2-4B",
+        trust_remote_code=True,
+        max_num_seqs=5,
+    )
+    return llm, prompt
+
+
 # BLIP-2
 def run_blip2(question):

@@ -125,6 +139,7 @@ def run_blip2(question):
     "chameleon": run_chameleon,
     "minicpmv": run_minicpmv,
     "blip-2": run_blip2,
+    "internvl_chat": run_internvl,
 }
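For context, the (llm, prompt) pair returned by run_internvl is consumed by the example script's shared driver. A hedged sketch of that flow is below; it uses the multi-modal generate API this example relies on, with a placeholder image path and illustrative sampling settings that are not part of the commit.

from PIL import Image

from vllm import SamplingParams

# Sketch only: feed the InternVL prompt plus a PIL image to the
# multi-modal generate API. "example.jpg" is a placeholder path.
llm, prompt = run_internvl("What is the content of this image?")
image = Image.open("example.jpg").convert("RGB")
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)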

examples/openai_vision_api_client.py

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@
         ],
     }],
     model=model,
+    max_tokens=64,
 )
 
 result = chat_completion_from_url.choices[0].message.content
@@ -78,6 +79,7 @@ def encode_image_base64_from_url(image_url: str) -> str:
         ],
     }],
     model=model,
+    max_tokens=64,
 )
 
 result = chat_completion_from_base64.choices[0].message.content
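The two added max_tokens=64 arguments simply bound the length of the example completions. For reference, a hedged sketch of the image-URL request this script issues against a locally running OpenAI-compatible vLLM server; the base URL and image URL are placeholders, not values from the commit.

from openai import OpenAI

# Sketch only: chat completion with an image URL and the newly added
# max_tokens bound. Base URL and image URL are placeholders.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id
chat_completion_from_url = client.chat.completions.create(
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/image.jpg"}},
        ],
    }],
    model=model,
    max_tokens=64,
)
print(chat_completion_from_url.choices[0].message.content)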

requirements-test.txt

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ ray
 sentence-transformers # required for embedding
 sparseml==1.8.0 # required for compressed-tensors
 compressed-tensors==0.4.0 # required for compressed-tensors
+timm # required for internvl test
 
 # Benchmarking
 aiohttp

tests/models/test_internvl.py

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
import types
from typing import List, Optional, Type

import pytest
import torch
from huggingface_hub import snapshot_download
from PIL.Image import Image

from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
                                                 IMG_START,
                                                 image_to_pixel_values)
from vllm.multimodal.utils import rescale_image_size
from vllm.utils import is_cpu

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
    "cherry_blossom":
    "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
})

# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
models = [
    snapshot_download("OpenGVLab/InternVL2-1B"),
    snapshot_download("OpenGVLab/InternVL2-2B"),
    # snapshot_download("OpenGVLab/InternVL2-4B"), # broken
]


class InternVLProcessor:
    """A simple processor for InternVL2 HF model which misses a processor."""

    def __init__(self, hf_runner: HfRunner):
        self.num_image_token = hf_runner.model.num_image_token
        self.tokenizer = hf_runner.tokenizer
        self.dtype = hf_runner.model.dtype

    def __call__(self, text: str, images: Image, **kwargs):
        pixel_values = image_to_pixel_values(images).to(self.dtype)
        num_patches_list = [pixel_values.shape[0]]
        for num_patches in num_patches_list:
            context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
            image_tokens = IMG_START + context_tokens + IMG_END
            text = text.replace('<image>', image_tokens, 1)
        prompt = self.tokenizer(text, return_tensors="pt")
        prompt.update({"pixel_values": pixel_values})
        return prompt


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: torch.FloatTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    **generate_kwargs,
) -> torch.LongTensor:
    """Generate method for InternVL2 model without fixed use_cache."""
    assert self.img_context_token_id is not None
    vit_embeds = self.extract_feature(pixel_values)
    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)

    input_ids = input_ids.reshape(B * N)
    selected = (input_ids == self.img_context_token_id)
    assert selected.sum() != 0
    input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        **generate_kwargs,
    )

    return outputs


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test is under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding vision language config as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     max_model_len=4096,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs_per_image
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
        img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
            "<IMG_CONTEXT>")
        hf_model.model.img_context_token_id = img_context_token_id
        hf_model.processor = InternVLProcessor(hf_model)
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.get_output_embeddings()
        hf_model.model.generate = types.MethodType(generate, hf_model.model)
        eos_token_id = hf_model.tokenizer.eos_token_id
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=hf_images,
                                                    eos_token_id=eos_token_id)
            for prompts, hf_images in inputs_per_image
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )


target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )

vllm/entrypoints/chat_utils.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def _image_token_str(model_config: ModelConfig,
         return None
     if model_type.startswith("llava"):
         return tokenizer.decode(model_config.hf_config.image_token_index)
-    if model_type == "chameleon":
+    if model_type in ("chameleon", "internvl_chat"):
         return "<image>"
     raise TypeError(f"Unknown model type: {model_type}")

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
+    "InternVLChatModel": ("internvl", "InternVLChatModel"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     "LlavaForConditionalGeneration":

0 commit comments
