 
 def _test_processing_correctness(
     model_id: str,
-    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
@@ -25,11 +24,6 @@ def _test_processing_correctness(
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
-    limit_mm_per_prompt = {
-        modality: 3 if supports_multi else 1
-        for modality, supports_multi in modalities.items()
-    }
-
     model_config = ModelConfig(
         model_id,
         task="auto",
@@ -40,18 +34,29 @@ def _test_processing_correctness(
         dtype="float16",
         revision=None,
         hf_overrides=model_info.hf_overrides,
-        limit_mm_per_prompt=limit_mm_per_prompt,
     )
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
     ctx = InputProcessingContext(
         model_config,
-        tokenizer=cached_get_tokenizer(model_config.tokenizer),
+        tokenizer=cached_get_tokenizer(
+            model_config.tokenizer,
+            trust_remote_code=model_info.trust_remote_code,
+        ),
     )
     # Ensure that it can fit all of the data
     cache = ProcessingCache(capacity=1 << 30)
 
+    processing_info = factories.info(ctx)
+    supported_mm_limits = processing_info.get_supported_mm_limits()
+    limit_mm_per_prompt = {
+        modality: 3 if limit is None else limit
+        for modality, limit in supported_mm_limits.items()
+    }
+
+    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+
     baseline_processor = factories.build_processor(ctx, cache=None)
     cached_processor = factories.build_processor(ctx, cache=cache)
     dummy_inputs = baseline_processor.dummy_inputs
@@ -82,8 +87,8 @@ def _test_processing_correctness(
         mm_data = {
             k:
             [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
-             for _ in range(rng.randint(limit_mm_per_prompt[k]))]
-            for k in modalities
+             for _ in range(rng.randint(limit))]
+            for k, limit in limit_mm_per_prompt.items()
         }
 
         mm_counts = {k: len(vs) for k, vs in mm_data.items()}
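Note on the hunks above: the per-prompt multimodal limits are no longer supplied by the caller through a `modalities` dict; they are derived from the processor's own `get_supported_mm_limits()`. A minimal sketch of that derivation, outside the diff and using a hypothetical return value for illustration (`None` is taken to mean "no inherent limit", which the test caps at 3):

    # Hypothetical value; real processors report their own limits.
    supported_mm_limits = {"image": None, "video": 1}

    limit_mm_per_prompt = {
        modality: 3 if limit is None else limit
        for modality, limit in supported_mm_limits.items()
    }
    assert limit_mm_per_prompt == {"image": 3, "video": 1}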
@@ -135,53 +140,49 @@ def _test_processing_correctness(
 
 # yapf: disable
 # True if the model supports multiple data items of the modality per request
-@pytest.mark.parametrize(("model_id", "modalities"), [
-    ("rhymes-ai/Aria", {"image": True}),
-    ("Salesforce/blip2-opt-2.7b", {"image": False}),
-    ("facebook/chameleon-7b", {"image": False}),
-    ("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
-    ("adept/fuyu-8b", {"image": False}),
-    ("llava-hf/llava-1.5-7b-hf", {"image": True}),
-    ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
-    ("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
-    ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}),  # noqa: E501
-    ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
-    ("mistral-community/pixtral-12b", {"image": True}),
-    ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
-    ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
-    ("fixie-ai/ultravox-v0_3", {"audio": True}),
+@pytest.mark.parametrize("model_id", [
+    "rhymes-ai/Aria",
+    "Salesforce/blip2-opt-2.7b",
+    "facebook/chameleon-7b",
+    "deepseek-ai/deepseek-vl2-tiny",
+    "adept/fuyu-8b",
+    "llava-hf/llava-1.5-7b-hf",
+    "llava-hf/llava-v1.6-mistral-7b-hf",
+    "llava-hf/LLaVA-NeXT-Video-7B-hf",
+    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+    "TIGER-Lab/Mantis-8B-siglip-llama3",
+    "mistral-community/pixtral-12b",
+    "Qwen/Qwen-VL-Chat",
+    "Qwen/Qwen2-VL-2B-Instruct",
+    "Qwen/Qwen2-Audio-7B-Instruct",
+    "fixie-ai/ultravox-v0_3",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
 def test_processing_correctness(
     model_id: str,
-    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
 ):
     _test_processing_correctness(
         model_id,
-        modalities,
         hit_rate=hit_rate,
         num_batches=num_batches,
         simplify_rate=simplify_rate,
     )
 
 
 # yapf: disable
-@pytest.mark.parametrize(("model_id", "modalities"), [
-    ("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
-])
+@pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
 def test_processing_correctness_phi3v(
     model_id: str,
-    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
@@ -195,7 +196,6 @@ def test_processing_correctness_phi3v(
 
     _test_processing_correctness(
         model_id,
-        modalities,
         hit_rate=hit_rate,
         num_batches=num_batches,
         simplify_rate=simplify_rate,
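For context on how the derived limits feed the per-batch inputs: each batch draws a random item count per modality below its limit, reusing a previously seen input with probability `hit_rate` so the processing cache gets hits. A rough sketch of just the sampling step, assuming `rng` is a NumPy `RandomState` (as the use of `rng.rand()` suggests); note that `rng.randint(limit)` yields a count in [0, limit), so a modality can receive zero items in a given batch:

    import numpy as np

    rng = np.random.RandomState(0)
    limit_mm_per_prompt = {"image": 3, "audio": 2}  # hypothetical derived limits

    # Placeholder strings stand in for the real image/audio inputs.
    mm_data = {
        k: [f"<{k} item>" for _ in range(rng.randint(limit))]
        for k, limit in limit_mm_per_prompt.items()
    }
    mm_counts = {k: len(vs) for k, vs in mm_data.items()}
    print(mm_counts)  # counts in [0, limit) per modality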