Commit d3b444b

Author: Mu Huai
Merge branch 'main' of github.com:RichardoMrMu/vllm
2 parents: 6683d89 + 7ea6cb2

File tree: 36 files changed, +1232 -243 lines

.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh

Lines changed: 6 additions & 3 deletions
@@ -32,9 +32,12 @@ function cpu_tests() {
   set -e
   pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
   pip install sentence-transformers datamodel_code_generator
-  pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
-  pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
-  pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
+  pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model
+  pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2]
+  pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
+  pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
+  pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
+  pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]"
 }

 # All of CPU tests are expected to be finished less than 40 mins.

benchmarks/benchmark_throughput.py

Lines changed: 4 additions & 3 deletions
@@ -146,9 +146,10 @@ async def run_vllm_async(

     async with build_async_engine_client_from_engine_args(
             engine_args, disable_frontend_multiprocessing) as llm:
+        model_config = await llm.get_model_config()
         assert all(
-            llm.model_config.max_model_len >= (request.prompt_len +
-                                               request.expected_output_len)
+            model_config.max_model_len >= (request.prompt_len +
+                                           request.expected_output_len)
             for request in requests), (
             "Please ensure that max_model_len is greater than the sum of"
             " prompt_len and expected_output_len for all requests.")
@@ -599,7 +600,7 @@ def validate_args(args):
         "--lora-path",
         type=str,
         default=None,
-        help="Path to the lora adapters to use. This can be an absolute path, "
+        help="Path to the LoRA adapters to use. This can be an absolute path, "
         "a relative path, or a Hugging Face model identifier.")
     parser.add_argument(
         "--prefix-len",

docs/source/features/tool_calling.md

Lines changed: 7 additions & 0 deletions
@@ -236,6 +236,13 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included sup

 Flags: `--tool-call-parser hermes`

+### DeepSeek-V3 Models (`deepseek_v3`)
+
+Supported models:
+* `deepseek-ai/DeepSeek-V3-0324`
+
+Flags: `--tool-call-parser deepseek_v3 --chat-template examples/tool_chat_template_deepseekv3.jinja`
+
 ### Models with Pythonic Tool Calls (`pythonic`)

 A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
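
For reference, a minimal client-side sketch of exercising the new `deepseek_v3` parser (not part of this commit). It assumes a server launched with the flags above, e.g. `vllm serve deepseek-ai/DeepSeek-V3-0324 --enable-auto-tool-choice --tool-call-parser deepseek_v3 --chat-template examples/tool_chat_template_deepseekv3.jinja`, and the `get_weather` tool schema is a made-up example:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, not from this commit
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V3-0324",
    messages=[{"role": "user", "content": "What's the weather in Berlin?"}],
    tools=tools,
)
# Tool calls extracted by the deepseek_v3 parser, if the model emitted any.
print(response.choices[0].message.tool_calls)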

docs/source/models/supported_models.md

Lines changed: 3 additions & 3 deletions
@@ -1045,10 +1045,10 @@ Specified using `--task generate`.
   *
   * ✅︎
   * ✅︎
-- * `Ovis2ForConditionalGeneration`<sup>^</sup>
-  * Ovis2
+- * `Ovis`
+  * Ovis2, Ovis1.6
   * T + I<sup>+</sup>
-  * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc.
+  * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc.
   *
   *
   * ✅︎

examples/offline_inference/vision_language.py

Lines changed: 12 additions & 9 deletions
@@ -725,8 +725,8 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
     )


-# Ovis2
-def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
+# Ovis
+def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"

     model_name = "AIDC-AI/Ovis2-1B"
@@ -737,15 +737,18 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
         max_num_seqs=2,
         trust_remote_code=True,
         dtype="half",
-        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
         limit_mm_per_prompt={modality: 1},
     )

-    placeholder = "<image>\n"
-    prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-                f"<|im_start|>user\n{placeholder}"
-                f"{question}<|im_end|>\n"
-                "<|im_start|>assistant\n") for question in questions]
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    messages = [[{
+        'role': 'user',
+        'content': f"<image>\n{question}"
+    }] for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
+                                            tokenize=False,
+                                            add_generation_prompt=True)

     return ModelRequestData(
         engine_args=engine_args,
@@ -1069,7 +1072,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     "llama4": run_llama4,
     "molmo": run_molmo,
     "NVLM_D": run_nvlm_d,
-    "ovis2": run_ovis2,
+    "ovis": run_ovis,
     "paligemma": run_paligemma,
     "paligemma2": run_paligemma2,
     "phi3_v": run_phi3v,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 12 additions & 10 deletions
@@ -436,8 +436,8 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
     )


-# Ovis2
-def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
+# Ovis
+def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "AIDC-AI/Ovis2-1B"

     engine_args = EngineArgs(
@@ -447,15 +447,17 @@ def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
         trust_remote_code=True,
         dtype="half",
         limit_mm_per_prompt={"image": len(image_urls)},
-        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
     )

-    placeholder = '\n'.join(
-        [f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n'
-    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-              f"<|im_start|>user\n{placeholder}"
-              f"{question}<|im_end|>\n"
-              "<|im_start|>assistant\n")
+    placeholders = "\n".join(f"Image-{i}: <image>\n"
+                             for i, _ in enumerate(image_urls, start=1))
+    messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name,
+                                              trust_remote_code=True)
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)

     return ModelRequestData(
         engine_args=engine_args,
@@ -713,7 +715,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
     "mistral3": load_mistral3,
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
-    "ovis2": load_ovis2,
+    "ovis": load_ovis,
     "phi3_v": load_phi3v,
     "phi4_mm": load_phi4mm,
     "pixtral_hf": load_pixtral_hf,

examples/tool_chat_template_deepseekv3.jinja

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+{% if not add_generation_prompt is defined %}
+{% set add_generation_prompt = false %}
+{% endif %}
+
+{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}
+
+{%- for message in messages %}
+{%- if message['role'] == 'system' %}
+{%- if ns.is_first_sp %}
+{% set ns.system_prompt = ns.system_prompt + message['content'] %}
+{% set ns.is_first_sp = false %}
+{%- else %}
+{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
+{%- endif %}
+{%- endif %}
+{%- endfor %}
+
+{{ bos_token }}
+{{ ns.system_prompt }}
+{%- if tools %}
+{{"\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }}
+{%- for tool in tools %}
+{{- "\n" }}
+{{- tool | tojson }}
+{%- endfor %}
+{{"\n</tools>\n\n"}}
+
+{{"For function call returns, you should first print <|tool▁calls▁begin|>"}}
+
+{{"For each function call, you should return object like:\n" }}
+{{"<|tool▁call▁begin|>function<|tool▁sep|><function_name>\n```json\n<function_arguments_in_json_format>\n```<|tool▁call▁end|>"}}
+
+{{"At the end of function call returns, you should print <|tool▁calls▁end|><|end▁of▁sentence|>"}}
+{%- endif %}
+
+{%- for message in messages %}
+{%- if message['role'] == 'user' %}
+{%- set ns.is_tool = false -%}
+{%- set ns.is_first = false -%}
+{%- set ns.is_last_user = true -%}
+{{'<|User|>' + message['content'] + '<|Assistant|>'}}
+{%- endif %}
+
+{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
+{%- set ns.is_last_user = false -%}
+{%- if ns.is_tool %}
+{{'<|tool▁outputs▁end|>'}}
+{%- endif %}
+{%- set ns.is_first = false %}
+{%- set ns.is_tool = false -%}
+{%- set ns.is_output_first = true %}
+
+{%- for tool in message['tool_calls'] %}
+{%- if not ns.is_first %}
+{%- if message['content'] is none %}
+{{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
+{%- else %}
+{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
+{%- endif %}
+{%- set ns.is_first = true -%}
+{%- else %}
+{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
+{%- endif %}
+{%- endfor %}
+{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
+{%- endif %}
+{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}
+{%- set ns.is_last_user = false -%}
+{%- if ns.is_tool %}
+{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}
+{%- set ns.is_tool = false -%}
+{%- else %}
+{% set content = message['content'] %}
+{{content + '<|end▁of▁sentence|>'}}
+{%- endif %}
+{%- endif %}
+
+{%- if message['role'] == 'tool' %}
+{%- set ns.is_last_user = false -%}
+{%- set ns.is_tool = true -%}
+{%- if ns.is_output_first %}
+{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
+{%- set ns.is_output_first = false %}
+{%- else %}
+{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
+{%- endif %}
+{%- endif %}
+{%- endfor -%}
+
+{% if ns.is_tool %}
+{{'<|tool▁outputs▁end|>'}}
+{% endif %}
+
+{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}
+{{'<|Assistant|>'}}
+{% endif %}

tests/conftest.py

Lines changed: 7 additions & 1 deletion
@@ -355,10 +355,16 @@ def __init__(
             **model_kwargs,
         )

+        # in case some unquantized custom models are not in same dtype
+        if (getattr(model, "quantization_method", None) is None
+                and any(p.dtype != self.dtype
+                        for p in model.parameters())):
+            model = model.to(dtype=self.dtype)
+
         if (getattr(model, "quantization_method", None) != "bitsandbytes"
                 and len({p.device
                          for p in model.parameters()}) < 2):
-            model = model.to(self.device)
+            model = model.to(device=self.device)

         self.model = model

tests/models/multimodal/generation/test_common.py

Lines changed: 26 additions & 1 deletion
@@ -476,6 +476,31 @@
         max_num_seqs=2,
         patch_hf_runner=model_utils.molmo_patch_hf_runner,
     ),
+    "ovis1_6-gemma2": VLMTestInfo(
+        models=["AIDC-AI/Ovis1.6-Gemma2-9B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        dtype="half",
+        # use sdpa mode for hf runner since ovis2 didn't work with flash_attn
+        hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+        patch_hf_runner=model_utils.ovis_patch_hf_runner,
+        marks=[large_gpu_mark(min_gb=32)],
+    ),
+    "ovis1_6": VLMTestInfo(
+        models=["AIDC-AI/Ovis1.6-Llama3.2-3B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        dtype="half",
+        # use sdpa mode for hf runner since ovis2 didn't work with flash_attn
+        hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+        patch_hf_runner=model_utils.ovis_patch_hf_runner,
+    ),
     "ovis2": VLMTestInfo(
         models=["AIDC-AI/Ovis2-1B"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
@@ -486,7 +511,7 @@
         dtype="half",
         # use sdpa mode for hf runner since ovis2 didn't work with flash_attn
         hf_model_kwargs={"llm_attn_implementation": "sdpa"},
-        patch_hf_runner=model_utils.ovis2_patch_hf_runner,
+        patch_hf_runner=model_utils.ovis_patch_hf_runner,
     ),
     "phi3v": VLMTestInfo(
         models=["microsoft/Phi-3.5-vision-instruct"],

tests/models/multimodal/generation/vlm_utils/model_utils.py

Lines changed: 11 additions & 6 deletions
@@ -678,20 +678,25 @@ def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
     return hf_model


-def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+def ovis_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for Ovis2."""
-    hf_model.model.visual_tokenizer.to(hf_model.dtype)
-    hf_model.model.vte.to(hf_model.dtype)
-    hf_model.model.llm.to(hf_model.dtype)
-
     hf_model.model.get_output_embeddings = lambda: \
         hf_model.model.llm.get_output_embeddings()

     def processor(*args, text="", images=None, **kwargs):
         text_tokenizer = hf_model.model.get_text_tokenizer()
         images = [images] if isinstance(images, Image) else images

-        text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0]
+        prompt_start_and_end = {
+            "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "llama":
+            ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
+            "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+        }
+        for start, end in prompt_start_and_end.values():
+            if start in text and end in text:
+                text = text.split(start)[1].split(end)[0]
+                break

         prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs(
             text_or_conversations=text, images=images)
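
To illustrate what the lookup above does (a runnable toy sketch; the sample prompt is hypothetical, not taken from the test suite), the processor strips whichever chat-template wrapper formatted the prompt, so one patch function covers the Qwen2-, Llama-, and Gemma2-based Ovis checkpoints:

# Toy illustration of the wrapper-stripping logic; the sample text below is
# a made-up Gemma2-style prompt.
prompt_start_and_end = {
    "qwen2": ("<|im_start|>user\n", "<|im_end|>\n"),
    "llama": ("<|start_header_id|>user<|end_header_id|>\n\n", "<|eot_id|>"),
    "gemma2": ("<start_of_turn>user\n", "<end_of_turn>\n"),
}

text = ("<bos><start_of_turn>user\n<image>\nDescribe the image."
        "<end_of_turn>\n<start_of_turn>model\n")
for start, end in prompt_start_and_end.values():
    if start in text and end in text:
        text = text.split(start)[1].split(end)[0]
        break

print(text)  # "<image>\nDescribe the image."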

tests/models/multimodal/processing/test_common.py

Lines changed: 4 additions & 1 deletion
@@ -146,7 +146,8 @@ def _test_processing_correctness_hf(
     batch_idx: int,
     ignore_mm_keys: Optional[set[str]] = None,
 ):
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+    if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
+                                             "whisper"):
         # For some multimodal models, tokenizer will always add bos_token
         # at the beginning of prompt by default, causing hf_processor outputs
         # incorrect token ids. So we need use `add_special_tokens=False` here
@@ -274,6 +275,8 @@ def _test_processing_correctness_mistral(
     "allenai/Molmo-7B-D-0924",
     "allenai/Molmo-7B-O-0924",
     "nvidia/NVLM-D-72B",
+    "AIDC-AI/Ovis1.6-Gemma2-9B",
+    "AIDC-AI/Ovis1.6-Llama3.2-3B",
     "AIDC-AI/Ovis2-1B",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",

tests/models/registry.py

Lines changed: 3 additions & 3 deletions
@@ -355,9 +355,9 @@ def check_available_online(
                                          max_transformers_version="4.48",
                                          transformers_version_reason="Use of deprecated imports which have been removed.",  # noqa: E501
                                          extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}),  # noqa: E501
-    "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
-                                     trust_remote_code=True,
-                                     hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}),  # noqa: E501
+    "Ovis": _HfExamplesInfo("AIDC-AI/Ovis2-1B", trust_remote_code=True,
+                            extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
+                                    "1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}),  # noqa: E501
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
