@@ -7,7 +7,7 @@
 """
 import os
 from dataclasses import dataclass
-from typing import List, NamedTuple, Optional
+from typing import List, Literal, NamedTuple, Optional

 import pytest

@@ -97,22 +97,23 @@ def iter_params(self, model_name: str):
                        self.trust_remote_code, self.tokenizer_mode)


+# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
+# The values displayed here are only a rough indicator of the size of the model
+
 # yapf: disable
 GENERATION_MODEL_SETTINGS = {
     # [DETAILED TESTS]
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
     # [FAST TESTS]
     # Uses Llama
     # "BAAI/AquilaChat-7B": PPTestSettings.fast(),
-    # TODO: Test on larger GPU
-    # "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
+    "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True),  # noqa: E501
     "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
     "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
     "bigscience/bloomz-1b1": PPTestSettings.fast(),
     "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
     "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True),  # noqa: E501
-    # TODO: Test on larger GPU
-    # "databricks/dbrx-instruct": PPTestSettings.fast(),
+    "databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8),
     "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
     "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
     "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True),  # noqa: E501
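The two `tp_base=8` promotions above (arctic, dbrx) replace TODOs with an explicit GPU footprint. As a minimal, hedged sketch of how such an entry expands into parametrized cases — `fast_setups` and its default values are illustrative assumptions, not the real `PPTestSettings.fast` internals; only the `ParallelSetup` field names are taken from the `_compare_tp` unpack later in this diff:

```python
from typing import List, NamedTuple


class ParallelSetup(NamedTuple):
    # Field names mirror the unpack in _compare_tp further down this diff.
    tp_size: int
    pp_size: int
    eager_mode: bool
    chunked_prefill: bool


def fast_setups(tp_base: int = 1, pp_base: int = 2) -> List[ParallelSetup]:
    # Assumption for illustration: a "fast" entry exercises a single
    # tp_base x pp_base layout, so tp_base=8 implies an 8-GPU
    # tensor-parallel group per pipeline stage.
    return [ParallelSetup(tp_base, pp_base, eager_mode=True,
                          chunked_prefill=False)]


for setup in fast_setups(tp_base=8):
    print(setup, "->", setup.tp_size * setup.pp_size, "GPUs total")
```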
@@ -161,8 +162,9 @@ def iter_params(self, model_name: str):

 EMBEDDING_MODEL_SETTINGS = {  # type: ignore[var-annotated]
     # [FAST TESTS]
-    # Uses Llama
-    # "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
+    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
+    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
+    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True),  # noqa: E501
 }

 MULTIMODAL_MODEL_SETTINGS = {
@@ -192,40 +194,35 @@ def iter_params(self, model_name: str):
 }
 # yapf: enable

-MODEL_SETTINGS = {
-    **GENERATION_MODEL_SETTINGS,
-    **EMBEDDING_MODEL_SETTINGS,
-    **MULTIMODAL_MODEL_SETTINGS,
-}
-
-# You can update this on your local machine to run specific tests
+# NOTE: You can update this on your local machine to run specific tests
 TEST_MODELS = [
+    # [LANGUAGE GENERATION]
     "meta-llama/Meta-Llama-3-8B",
-    "facebook/chameleon-7b",
+    "ibm/PowerLM-3b",
+    # [LANGUAGE EMBEDDING]
+    "intfloat/e5-mistral-7b-instruct",
+    "BAAI/bge-multilingual-gemma2",
+    # [MULTIMODAL GENERATION]
     "OpenGVLab/InternVL2-1B",
     "microsoft/Phi-3-vision-128k-instruct",
-    "mistralai/Pixtral-12B-2409",
     "fixie-ai/ultravox-v0_3",
 ]


-@pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend",
-     "trust_remote_code", "tokenizer_mode"),
-    [
-        params for model_name, settings in MODEL_SETTINGS.items()
-        for params in settings.iter_params(model_name)
-        if model_name in TEST_MODELS
-    ],
-)
-@fork_new_process_for_each_test
-def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
-                    distributed_backend: str, trust_remote_code: bool,
-                    tokenizer_mode: Optional[str], num_gpus_available):
+def _compare_tp(
+    model_name: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    trust_remote_code: bool,
+    tokenizer_mode: Optional[str],
+    num_gpus_available: int,
+    *,
+    method: Literal["generate", "encode"] = "encode",
+):
     tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup

-    if num_gpus_available < tp_size:
-        pytest.skip(f"Need at least {tp_size} GPUs to run the test")
+    if num_gpus_available < tp_size * pp_size:
+        pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
     if VLLM_MULTI_NODE and distributed_backend == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
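The rewritten skip guard is the key correctness fix here: each of the `pp_size` pipeline stages holds its own tensor-parallel group of `tp_size` GPUs, so the requirement is their product, and the old `< tp_size` check would have let a tp=2, pp=2 case start with only two GPUs available. A quick arithmetic sketch (values are illustrative):

```python
tp_size, pp_size = 2, 2      # parallel layout under test
num_gpus_available = 2       # assumed machine, for illustration

# World size is tp_size * pp_size: every pipeline stage replicates the
# tensor-parallel group, so a 2 x 2 layout needs 4 GPUs, not 2.
required = tp_size * pp_size
if num_gpus_available < required:
    print(f"Need at least {tp_size} x {pp_size} GPUs")  # test gets skipped
```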
@@ -286,10 +283,95 @@ def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
     ]

     try:
-        compare_two_settings(model_name, pp_args, tp_args, pp_env)
+        compare_two_settings(model_name,
+                             pp_args,
+                             tp_args,
+                             pp_env,
+                             method=method)
     except Exception:
         if pp_env is None:
             raise
         else:
             # Ray ADAG tests are flaky, so we don't want to fail the test
             logger.exception("Ray ADAG tests failed")
+
+
+@pytest.mark.parametrize(
+    ("model_name", "parallel_setup", "distributed_backend",
+     "trust_remote_code", "tokenizer_mode"),
+    [
+        params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
+        for params in settings.iter_params(model_name)
+        if model_name in TEST_MODELS
+    ],
+)
+@fork_new_process_for_each_test
+def test_tp_language_generation(
+    model_name: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    trust_remote_code: bool,
+    tokenizer_mode: Optional[str],
+    num_gpus_available,
+):
+    _compare_tp(model_name,
+                parallel_setup,
+                distributed_backend,
+                trust_remote_code,
+                tokenizer_mode,
+                num_gpus_available,
+                method="generate")
+
+
+@pytest.mark.parametrize(
+    ("model_name", "parallel_setup", "distributed_backend",
+     "trust_remote_code", "tokenizer_mode"),
+    [
+        params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
+        for params in settings.iter_params(model_name)
+        if model_name in TEST_MODELS
+    ],
+)
+@fork_new_process_for_each_test
+def test_tp_language_embedding(
+    model_name: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    trust_remote_code: bool,
+    tokenizer_mode: Optional[str],
+    num_gpus_available,
+):
+    _compare_tp(model_name,
+                parallel_setup,
+                distributed_backend,
+                trust_remote_code,
+                tokenizer_mode,
+                num_gpus_available,
+                method="encode")
+
+
+@pytest.mark.parametrize(
+    ("model_name", "parallel_setup", "distributed_backend",
+     "trust_remote_code", "tokenizer_mode"),
+    [
+        params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
+        for params in settings.iter_params(model_name)
+        if model_name in TEST_MODELS
+    ],
+)
+@fork_new_process_for_each_test
+def test_tp_multimodal_generation(
+    model_name: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    trust_remote_code: bool,
+    tokenizer_mode: Optional[str],
+    num_gpus_available,
+):
+    _compare_tp(model_name,
+                parallel_setup,
+                distributed_backend,
+                trust_remote_code,
+                tokenizer_mode,
+                num_gpus_available,
+                method="generate")
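All three suites reuse the same comprehension, so a model only contributes cases when it also appears in `TEST_MODELS`; trimming that list locally trims the run, per the NOTE above. A self-contained sketch of the filter, with dummy list values standing in for the real `PPTestSettings` objects:

```python
# Dummy stand-ins: in the real file the dict values are PPTestSettings
# objects whose iter_params() yields the parameter tuples.
TEST_MODELS = ["intfloat/e5-mistral-7b-instruct"]
EMBEDDING_MODEL_SETTINGS = {
    "intfloat/e5-mistral-7b-instruct": ["case-a", "case-b"],
    "BAAI/bge-multilingual-gemma2": ["case-c"],
}

params = [
    case
    for model_name, cases in EMBEDDING_MODEL_SETTINGS.items()
    for case in cases
    if model_name in TEST_MODELS
]
print(params)  # ['case-a', 'case-b'] -- bge-multilingual-gemma2 filtered out
```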