
Commit 390ec88

[Misc] Consolidate Audio tests into multimodal common generation tests (#18214)
Signed-off-by: Isotr0py <2037008807@qq.com>
1 parent: 5418176

File tree: 9 files changed, +282 -215 lines

tests/models/multimodal/generation/test_common.py

Lines changed: 60 additions & 4 deletions
@@ -8,14 +8,14 @@
 from pathlib import PosixPath

 import pytest
-from transformers import (AutoModelForImageTextToText,
+from transformers import (AutoModel, AutoModelForImageTextToText,
                           AutoModelForTextToWaveform, AutoModelForVision2Seq)

 from vllm.platforms import current_platform
 from vllm.utils import identity

-from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
-                          VideoTestAssets, VllmRunner)
+from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner,
+                          ImageTestAssets, VideoTestAssets, VllmRunner)
 from ....utils import (create_new_process_for_each_test, large_gpu_mark,
                        multi_gpu_marks)
 from ...utils import check_outputs_equal
@@ -158,6 +158,17 @@
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "ultravox": VLMTestInfo(
+        models=["fixie-ai/ultravox-v0_5-llama-3_2-1b"],
+        test_type=VLMTestType.AUDIO,
+        prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",  # noqa: E501
+        audio_idx_to_prompt=lambda idx: "<|audio|>",
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModel,
+        hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],
@@ -393,7 +404,6 @@
                 formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
             ),
             limit_mm_per_prompt={"video": 4},
-            runner_mm_key="videos",
         )],
     ),
     "llava_next_video": VLMTestInfo(
@@ -706,6 +716,7 @@ def _mark_splits(
 # - multi-image
 # - image embeddings
 # - video
+# - audio
 # - custom inputs
 @pytest.mark.parametrize(
     "model_type,test_case",
@@ -803,6 +814,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
     )


+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=False,
+    ))
+def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs,
+                      hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
+                      audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(
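
Note: get_parametrized_options expands VLM_TEST_SETTINGS into one pytest case per matching entry, so test_audio_models only picks up settings declared with test_type=VLMTestType.AUDIO (currently just "ultravox"). A simplified stand-in for that selection step, under the assumption that the real helper additionally fans each entry out into ExpandableVLMTestArgs cases:

from enum import Enum, auto

class VLMTestType(Enum):  # stand-in for the real enum in the test harness
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()

def select_model_types(settings: dict, test_type: VLMTestType) -> list[str]:
    # Sketch of the filtering only; the real get_parametrized_options also
    # builds the per-case parameter combinations.
    return [name for name, info in settings.items()
            if info.test_type == test_type]

With that filter in place, something like pytest tests/models/multimodal/generation/test_common.py -k ultravox should select just the consolidated audio cases.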
@@ -930,6 +963,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
     )


+@pytest.mark.parametrize(
+    "model_type,test_case",
+    get_parametrized_options(
+        VLM_TEST_SETTINGS,
+        test_type=VLMTestType.AUDIO,
+        create_new_process_for_each_test=True,
+    ))
+def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
+                            hf_runner: type[HfRunner],
+                            vllm_runner: type[VllmRunner],
+                            audio_assets: AudioTestAssets, monkeypatch):
+    if model_type in REQUIRES_V0_MODELS:
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+    model_test_info = VLM_TEST_SETTINGS[model_type]
+    runners.run_audio_test(
+        model_test_info=model_test_info,
+        test_case=test_case,
+        hf_runner=hf_runner,
+        vllm_runner=vllm_runner,
+        audio_assets=audio_assets,
+    )
+
+
 @pytest.mark.parametrize(
     "model_type,test_case",
     get_parametrized_options(

tests/models/multimodal/generation/test_ultravox.py

Lines changed: 3 additions & 109 deletions
@@ -1,20 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0

 import json
-from typing import Any, Optional
+from typing import Any

 import numpy as np
 import pytest
 import pytest_asyncio
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoTokenizer

-from vllm.multimodal.audio import resample_audio_librosa
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner
+from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner
 from ....utils import RemoteOpenAIServer
 from ...registry import HF_EXAMPLE_MODELS
-from ...utils import check_logprobs_close

 MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"


@@ -88,79 +84,6 @@ def _get_prompt(audio_count, question, placeholder):
                                          add_generation_prompt=True)


-def vllm_to_hf_output(vllm_output: tuple[list[int], str,
-                                         Optional[SampleLogprobs]],
-                      model: str):
-    """Sanitize vllm output to be comparable with hf output."""
-    output_ids, output_str, out_logprobs = vllm_output
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-    eos_token_id = tokenizer.eos_token_id
-
-    hf_output_ids = output_ids[:]
-    hf_output_str = output_str
-    if hf_output_ids[-1] == eos_token_id:
-        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
-
-    return hf_output_ids, hf_output_str, out_logprobs
-
-
-def run_test(
-    hf_runner: type[HfRunner],
-    vllm_runner: type[VllmRunner],
-    prompts_and_audios: list[tuple[str, str, AudioTuple]],
-    model: str,
-    *,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    **kwargs,
-):
-    """Inference result should be the same between hf and vllm."""
-    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
-    model_info.check_available_online(on_fail="skip")
-    model_info.check_transformers_version(on_fail="skip")
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default method).
-
-    with vllm_runner(model, dtype=dtype, enforce_eager=True,
-                     **kwargs) as vllm_model:
-        vllm_outputs_per_audio = [
-            vllm_model.generate_greedy_logprobs([vllm_prompt],
-                                                max_tokens,
-                                                num_logprobs=num_logprobs,
-                                                audios=[audio])
-            for vllm_prompt, _, audio in prompts_and_audios
-        ]
-
-    with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model:
-        hf_outputs_per_audio = [
-            hf_model.generate_greedy_logprobs_limit(
-                [hf_prompt],
-                max_tokens,
-                num_logprobs=num_logprobs,
-                audios=[(resample_audio_librosa(audio[0],
-                                                orig_sr=audio[1],
-                                                target_sr=16000), 16000)])
-            for _, hf_prompt, audio in prompts_and_audios
-        ]
-
-    for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio,
-                                        vllm_outputs_per_audio):
-        check_logprobs_close(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=[
-                vllm_to_hf_output(vllm_output, model)
-                for vllm_output in vllm_outputs
-            ],
-            name_0="hf",
-            name_1="vllm",
-        )
-
-
 def run_multi_audio_test(
     vllm_runner: type[VllmRunner],
     prompts_and_audios: list[tuple[str, list[AudioTuple]]],
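
Note: the deleted run_test resampled every HF input to 16 kHz via vllm.multimodal.audio.resample_audio_librosa before comparing against vLLM; that responsibility now falls to the common harness. For reference, a minimal standalone equivalent of that resampling step using librosa directly:

import librosa

def to_16khz(audio, orig_sr: int):
    # Resample a waveform to the 16 kHz rate the HF Ultravox reference
    # expects, mirroring the per-input call the removed run_test made.
    return librosa.resample(audio, orig_sr=orig_sr, target_sr=16000), 16000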
@@ -194,35 +117,6 @@ def run_multi_audio_test(
     assert all(tokens for tokens, *_ in vllm_outputs)


-@pytest.mark.core_model
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("vllm_kwargs", [
-    pytest.param({}, marks=pytest.mark.cpu_model),
-    pytest.param(CHUNKED_PREFILL_KWARGS),
-])
-def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets,
-                dtype: str, max_tokens: int, num_logprobs: int,
-                vllm_kwargs: dict) -> None:
-    audio_inputs = [(
-        _get_prompt(1, audio, VLLM_PLACEHOLDER),
-        _get_prompt(1, audio, HF_PLACEHOLDER),
-        audio.audio_and_sample_rate,
-    ) for audio in audio_assets]
-
-    run_test(
-        hf_runner,
-        vllm_runner,
-        audio_inputs,
-        MODEL_NAME,
-        dtype=dtype,
-        max_tokens=max_tokens,
-        num_logprobs=num_logprobs,
-        **vllm_kwargs,
-    )
-
-
 @pytest.mark.core_model
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
