
Commit c18d19f

DarkLight1337, ywang96, and Isotr0py authored and committed
[Model] Support NVLM-D and fix QK Norm in InternViT (vllm-project#9045)
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 4ce1077 commit c18d19f

File tree: 12 files changed (+518, -236 lines)

docs/source/models/supported_models.rst

Lines changed: 9 additions & 0 deletions
@@ -315,6 +315,9 @@ Multimodal Language Models
 
 .. _supported_vlms:
 
+Text Generation
+---------------
+
 .. list-table::
   :widths: 25 25 25 25 5 5
   :header-rows: 1
@@ -384,7 +387,13 @@ Multimodal Language Models
     - Image
     - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
     -
+    -
+  * - :code:`NVLM_D_Model`
+    - NVLM-D 1.0
+    - Image\ :sup:`E+`
+    - :code:`nvidia/NVLM-D-72B`, etc.
     -
+    - ✅︎
   * - :code:`PaliGemmaForConditionalGeneration`
     - PaliGemma
     - Image\ :sup:`E`
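The new table row can be exercised offline in the same way as the other models listed in this table. A minimal single-image sketch, assuming the checkpoint's own chat template is used; the image path, sampling settings, and GPU sizing below are illustrative, while the model name and loading flags mirror the run_nvlm_d example added further down in this commit:

from PIL import Image
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

model_name = "nvidia/NVLM-D-72B"

# Illustrative sizing; adjust tensor_parallel_size / max_model_len for your GPUs.
llm = LLM(model=model_name,
          trust_remote_code=True,
          max_model_len=4096,
          tensor_parallel_size=4)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [{"role": "user", "content": "<image>\nWhat is shown in this image?"}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

image = Image.open("example.jpg")  # hypothetical local image
outputs = llm.generate({"prompt": prompt, "multi_modal_data": {"image": image}},
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)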

examples/offline_inference_vision_language.py

Lines changed: 40 additions & 15 deletions
@@ -18,7 +18,7 @@
 
 
 # LLaVA-1.5
-def run_llava(question, modality):
+def run_llava(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"USER: <image>\n{question}\nASSISTANT:"
@@ -29,7 +29,7 @@ def run_llava(question, modality):
 
 
 # LLaVA-1.6/LLaVA-NeXT
-def run_llava_next(question, modality):
+def run_llava_next(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"[INST] <image>\n{question} [/INST]"
@@ -40,7 +40,7 @@ def run_llava_next(question, modality):
 
 # LlaVA-NeXT-Video
 # Currently only support for video input
-def run_llava_next_video(question, modality):
+def run_llava_next_video(question: str, modality: str):
     assert modality == "video"
 
     prompt = f"USER: <video>\n{question} ASSISTANT:"
@@ -50,7 +50,7 @@ def run_llava_next_video(question, modality):
 
 
 # LLaVA-OneVision
-def run_llava_onevision(question, modality):
+def run_llava_onevision(question: str, modality: str):
 
     if modality == "video":
         prompt = f"<|im_start|>user <video>\n{question}<|im_end|> \
@@ -67,7 +67,7 @@ def run_llava_onevision(question, modality):
 
 
 # Fuyu
-def run_fuyu(question, modality):
+def run_fuyu(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"{question}\n"
@@ -77,7 +77,7 @@ def run_fuyu(question, modality):
 
 
 # Phi-3-Vision
-def run_phi3v(question, modality):
+def run_phi3v(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"  # noqa: E501
@@ -112,7 +112,7 @@ def run_phi3v(question, modality):
 
 
 # PaliGemma
-def run_paligemma(question, modality):
+def run_paligemma(question: str, modality: str):
     assert modality == "image"
 
     # PaliGemma has special prompt format for VQA
@@ -123,7 +123,7 @@ def run_paligemma(question, modality):
 
 
 # Chameleon
-def run_chameleon(question, modality):
+def run_chameleon(question: str, modality: str):
     assert modality == "image"
 
     prompt = f"{question}<image>"
@@ -133,7 +133,7 @@ def run_chameleon(question, modality):
 
 
 # MiniCPM-V
-def run_minicpmv(question, modality):
+def run_minicpmv(question: str, modality: str):
     assert modality == "image"
 
     # 2.0
@@ -176,7 +176,7 @@ def run_minicpmv(question, modality):
 
 
 # InternVL
-def run_internvl(question, modality):
+def run_internvl(question: str, modality: str):
     assert modality == "image"
 
     model_name = "OpenGVLab/InternVL2-2B"
@@ -203,8 +203,32 @@ def run_internvl(question, modality):
     return llm, prompt, stop_token_ids
 
 
+# NVLM-D
+def run_nvlm_d(question: str, modality: str):
+    assert modality == "image"
+
+    model_name = "nvidia/NVLM-D-72B"
+
+    # Adjust this as necessary to fit in GPU
+    llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=4096,
+        tensor_parallel_size=4,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    messages = [{'role': 'user', 'content': f"<image>\n{question}"}]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 # BLIP-2
-def run_blip2(question, modality):
+def run_blip2(question: str, modality: str):
     assert modality == "image"
 
     # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
@@ -216,7 +240,7 @@ def run_blip2(question, modality):
 
 
 # Qwen
-def run_qwen_vl(question, modality):
+def run_qwen_vl(question: str, modality: str):
     assert modality == "image"
 
     llm = LLM(
@@ -232,7 +256,7 @@ def run_qwen_vl(question, modality):
 
 
 # Qwen2-VL
-def run_qwen2_vl(question, modality):
+def run_qwen2_vl(question: str, modality: str):
     assert modality == "image"
 
     model_name = "Qwen/Qwen2-VL-7B-Instruct"
@@ -252,8 +276,8 @@ def run_qwen2_vl(question, modality):
     return llm, prompt, stop_token_ids
 
 
-# LLama
-def run_mllama(question, modality):
+# LLama 3.2
+def run_mllama(question: str, modality: str):
     assert modality == "image"
 
     model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
@@ -287,6 +311,7 @@ def run_mllama(question, modality):
     "minicpmv": run_minicpmv,
     "blip-2": run_blip2,
     "internvl_chat": run_internvl,
+    "NVLM_D": run_nvlm_d,
    "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
     "mllama": run_mllama,

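For context, the (llm, prompt, stop_token_ids) triple returned by run_nvlm_d is consumed the same way as every other entry in model_example_map. A hedged sketch of that driver loop; the question and the already-loaded PIL image are placeholders, not taken from the script:

from vllm import SamplingParams

question = "What is the content of this image?"  # placeholder question
llm, prompt, stop_token_ids = run_nvlm_d(question, modality="image")

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=stop_token_ids)
# `image` is assumed to be a PIL.Image loaded elsewhere.
inputs = {"prompt": prompt, "multi_modal_data": {"image": image}}
for output in llm.generate(inputs, sampling_params=sampling_params):
    print(output.outputs[0].text)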
examples/offline_inference_vision_language_multi_image.py

Lines changed: 34 additions & 0 deletions
@@ -144,6 +144,39 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData:
     )
 
 
+def load_nvlm_d(question: str, image_urls: List[str]):
+    model_name = "nvidia/NVLM-D-72B"
+
+    # Adjust this as necessary to fit in GPU
+    llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=8192,
+        tensor_parallel_size=4,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        mm_processor_kwargs={"max_dynamic_patch": 4},
+    )
+
+    placeholders = "\n".join(f"Image-{i}: <image>\n"
+                             for i, _ in enumerate(image_urls, start=1))
+    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+    stop_token_ids = None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     try:
         from qwen_vl_utils import process_vision_info
@@ -204,6 +237,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
 model_example_map = {
     "phi3_v": load_phi3v,
     "internvl_chat": load_internvl,
+    "NVLM_D": load_nvlm_d,
     "qwen2_vl": load_qwen2_vl,
     "qwen_vl_chat": load_qwenvl_chat,
 }
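A sketch of how the ModelRequestData returned by load_nvlm_d might be driven for a two-image prompt, following the same generate pattern as the rest of this example script; the image URLs are placeholders:

from vllm import SamplingParams

image_urls = [
    "https://example.com/cat.jpg",  # placeholder URLs
    "https://example.com/dog.jpg",
]
req = load_nvlm_d("What is the difference between these two images?", image_urls)

outputs = req.llm.generate(
    {
        "prompt": req.prompt,
        "multi_modal_data": {"image": req.image_data},
    },
    sampling_params=SamplingParams(max_tokens=128,
                                   stop_token_ids=req.stop_token_ids),
)
print(outputs[0].outputs[0].text)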

vllm/entrypoints/chat_utils.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def _placeholder_str(self, modality: ModalityStr,
         if model_type.startswith("llava"):
             return self._cached_token_str(self._tokenizer,
                                           hf_config.image_token_index)
-        if model_type in ("chameleon", "internvl_chat"):
+        if model_type in ("chameleon", "internvl_chat", "NVLM_D"):
             return "<image>"
         if model_type == "mllama":
             return "<|image|>"
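With NVLM_D mapped to the generic "<image>" placeholder, the OpenAI-compatible chat endpoint can insert it automatically whenever a message carries an image part. A client-side sketch, assuming a server has been started with something like vllm serve nvidia/NVLM-D-72B --trust-remote-code --tensor-parallel-size 4; the URL, API key, and image link are placeholders:

from openai import OpenAI

# Local vLLM OpenAI-compatible server (default port 8000).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="nvidia/NVLM-D-72B",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            # The server fetches the image and, via _placeholder_str above,
            # inserts "<image>" into the prompt for this model type.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/photo.jpg"}},
        ],
    }],
    max_tokens=64,
)
print(response.choices[0].message.content)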

vllm/model_executor/layers/layernorm.py

Lines changed: 30 additions & 2 deletions
@@ -18,10 +18,16 @@ def __init__(
         self,
         hidden_size: int,
         eps: float = 1e-6,
+        var_hidden_size: Optional[int] = None,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+
+        self.hidden_size = hidden_size
         self.variance_epsilon = eps
+        self.variance_size_override = (None if var_hidden_size == hidden_size
+                                       else var_hidden_size)
+
+        self.weight = nn.Parameter(torch.ones(hidden_size))
 
     def forward_native(
         self,
@@ -35,7 +41,23 @@ def forward_native(
             x = x + residual.to(torch.float32)
             residual = x.to(orig_dtype)
 
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        hidden_size = x.shape[-1]
+        if hidden_size != self.hidden_size:
+            raise ValueError("Expected hidden_size to be "
+                             f"{self.hidden_size}, but found: {hidden_size}")
+
+        if self.variance_size_override is None:
+            x_var = x
+        else:
+            if hidden_size < self.variance_size_override:
+                raise ValueError(
+                    "Expected hidden_size to be at least "
+                    f"{self.variance_size_override}, but found: {hidden_size}")
+
+            x_var = x[:, :, :self.variance_size_override]
+
+        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
+
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         x = x.to(orig_dtype) * self.weight
         if residual is None:
@@ -48,6 +70,9 @@ def forward_cuda(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+
         from vllm import _custom_ops as ops
 
         if residual is not None:
@@ -72,6 +97,9 @@ def forward_xpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if self.variance_size_override is not None:
+            return self.forward_native(x, residual)
+
         from vllm._ipex_ops import ipex_ops as ops
 
         if residual is not None:
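The substance of the QK Norm fix is the new var_hidden_size argument: when it is smaller than hidden_size, the RMS statistic is taken over only the first var_hidden_size channels of the last dimension while the learned scale still spans the full hidden_size, and the custom CUDA/XPU kernels are bypassed in favour of forward_native. A standalone PyTorch sketch of the same semantics; the sizes in the usage lines are illustrative, not taken from InternViT:

import torch
from torch import nn


class PartialVarianceRMSNorm(nn.Module):
    """RMSNorm whose variance is computed over only the first
    `var_hidden_size` channels, mirroring forward_native above."""

    def __init__(self, hidden_size: int, var_hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.var_hidden_size = var_hidden_size
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        # Statistic over a prefix of the channel dimension only.
        variance = x[..., :self.var_hidden_size].pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return x.to(orig_dtype) * self.weight


# Illustrative sizes: normalize a [batch, seq, hidden] tensor, taking the
# statistic over the first 128 of 3200 channels.
norm = PartialVarianceRMSNorm(hidden_size=3200, var_hidden_size=128)
out = norm(torch.randn(2, 16, 3200))
print(out.shape)  # torch.Size([2, 16, 3200])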
