
Commit b2032fd

Add tests
Parent: 6196c15

File tree: 4 files changed, +120 -8 lines

tests/conftest.py
tests/engine/test_custom_executor.py
tests/tokenization/test_tokenizer_group.py
vllm/config.py


tests/conftest.py

Lines changed: 4 additions & 0 deletions
@@ -564,6 +564,10 @@ def get_tokenizer_pool_config(tokenizer_group_type):
         return TokenizerPoolConfig(pool_size=1,
                                    pool_type="ray",
                                    extra_config={})
+    if isinstance(tokenizer_group_type, type):
+        return TokenizerPoolConfig(pool_size=1,
+                                   pool_type=tokenizer_group_type,
+                                   extra_config={})
     raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")
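For context, get_tokenizer_pool_config now accepts a tokenizer group class in addition to None and "ray". A minimal sketch of the new branch in use; MyTokenizerGroup is hypothetical, and the import location of TokenizerPoolConfig is assumed:

from vllm.config import TokenizerPoolConfig  # assumed import location
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

from tests.conftest import get_tokenizer_pool_config  # assumes repo root on sys.path


class MyTokenizerGroup(TokenizerGroup):
    """Hypothetical subclass, used only to exercise the new branch."""


# A class argument yields a pool config whose pool_type is that class.
config = get_tokenizer_pool_config(MyTokenizerGroup)
assert isinstance(config, TokenizerPoolConfig)
assert config.pool_type is MyTokenizerGroup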

tests/engine/test_custom_executor.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+import asyncio
+import os
+
+import pytest
+
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.llm_engine import LLMEngine
+from vllm.executor.gpu_executor import GPUExecutor, GPUExecutorAsync
+from vllm.sampling_params import SamplingParams
+
+
+class Mock:
+    ...
+
+
+class CustomGPUExecutor(GPUExecutor):
+
+    def execute_model(self, *args, **kwargs):
+        # Drop marker to show that this was run
+        with open(".marker", "w"):
+            ...
+        return super().execute_model(*args, **kwargs)
+
+
+class CustomGPUExecutorAsync(GPUExecutorAsync):
+
+    async def execute_model_async(self, *args, **kwargs):
+        with open(".marker", "w"):
+            ...
+        return await super().execute_model_async(*args, **kwargs)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor_type_checking(model):
+    with pytest.raises(ValueError):
+        engine_args = EngineArgs(model=model,
+                                 distributed_executor_backend=Mock)
+        LLMEngine.from_engine_args(engine_args)
+    with pytest.raises(ValueError):
+        engine_args = AsyncEngineArgs(model=model,
+                                      distributed_executor_backend=Mock)
+        AsyncLLMEngine.from_engine_args(engine_args)
+    with pytest.raises(TypeError):
+        engine_args = AsyncEngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutor)
+        AsyncLLMEngine.from_engine_args(engine_args)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor(model, tmpdir):
+    cwd = os.path.abspath(".")
+    os.chdir(tmpdir)
+    try:
+        assert not os.path.exists(".marker")
+
+        engine_args = EngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutor)
+        engine = LLMEngine.from_engine_args(engine_args)
+        sampling_params = SamplingParams(max_tokens=1)
+
+        engine.add_request("0", "foo", sampling_params)
+        engine.step()
+
+        assert os.path.exists(".marker")
+    finally:
+        os.chdir(cwd)
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+def test_custom_executor_async(model, tmpdir):
+    cwd = os.path.abspath(".")
+    os.chdir(tmpdir)
+    try:
+        assert not os.path.exists(".marker")
+
+        engine_args = AsyncEngineArgs(
+            model=model, distributed_executor_backend=CustomGPUExecutorAsync)
+        engine = AsyncLLMEngine.from_engine_args(engine_args)
+        sampling_params = SamplingParams(max_tokens=1)
+
+        async def t():
+            stream = await engine.add_request("0", "foo", sampling_params)
+            async for x in stream:
+                ...
+
+        asyncio.run(t())
+
+        assert os.path.exists(".marker")
+    finally:
+        os.chdir(cwd)
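These tests pin down the new behavior: distributed_executor_backend accepts an ExecutorBase subclass, a class that is not an executor is rejected with ValueError, and passing the synchronous CustomGPUExecutor to AsyncLLMEngine raises TypeError. As a rough sketch of the extension point from a user's perspective (not part of this commit; the timing executor below is hypothetical):

import time

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.gpu_executor import GPUExecutor


class TimingGPUExecutor(GPUExecutor):
    """Hypothetical executor that reports how long each model step takes."""

    def execute_model(self, *args, **kwargs):
        start = time.perf_counter()
        output = super().execute_model(*args, **kwargs)
        print(f"execute_model took {time.perf_counter() - start:.3f}s")
        return output


# The subclass itself (not an instance) is passed as the backend,
# mirroring the tests above.
engine_args = EngineArgs(model="facebook/opt-125m",
                         distributed_executor_backend=TimingGPUExecutor)
engine = LLMEngine.from_engine_args(engine_args)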

tests/tokenization/test_tokenizer_group.py

Lines changed: 17 additions & 4 deletions
@@ -7,17 +7,28 @@
 import pytest
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
-from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
+from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
+                                                     get_tokenizer_group)
 from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
     RayTokenizerGroupPool)
-from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
-    TokenizerGroup)
 
 from ..conftest import get_tokenizer_pool_config
 
 
+class CustomTokenizerGroup(TokenizerGroup):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._i = 0
+
+    def encode(self, *args, **kwargs):
+        self._i += 1
+        return super().encode(*args, **kwargs)
+
+
 @pytest.mark.asyncio
-@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
+@pytest.mark.parametrize("tokenizer_group_type",
+                         [None, "ray", CustomTokenizerGroup])
 async def test_tokenizer_group(tokenizer_group_type):
     reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
     tokenizer_group = get_tokenizer_group(

@@ -36,6 +47,8 @@ async def test_tokenizer_group(tokenizer_group_type):
                       PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
         None) == await tokenizer_group.get_lora_tokenizer_async(None)
+    if tokenizer_group_type is CustomTokenizerGroup:
+        assert tokenizer_group._i > 0
 
 
 @pytest.mark.asyncio
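The new parametrization relies on the conftest change above: passing a TokenizerGroup subclass produces a TokenizerPoolConfig whose pool_type is that class, and get_tokenizer_group then builds an instance of it (the test verifies this by checking that the custom encode() counter advanced). A hedged sketch of that flow outside the test; CountingTokenizerGroup is hypothetical and the keyword arguments to get_tokenizer_group are assumed to match what the test passes, which may differ between versions:

from vllm.transformers_utils.tokenizer_group import (TokenizerGroup,
                                                     get_tokenizer_group)

from tests.conftest import get_tokenizer_pool_config  # assumes repo root on sys.path


class CountingTokenizerGroup(TokenizerGroup):
    """Hypothetical subclass that counts encode() calls, like the test's class."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.calls = 0

    def encode(self, *args, **kwargs):
        self.calls += 1
        return super().encode(*args, **kwargs)


# Assumed keyword arguments, modeled on the test above.
tokenizer_group = get_tokenizer_group(
    get_tokenizer_pool_config(CountingTokenizerGroup),
    tokenizer_id="gpt2",
    enable_lora=False,
    max_num_seqs=1,
    max_input_length=None)
assert isinstance(tokenizer_group, CountingTokenizerGroup)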

vllm/config.py

Lines changed: 8 additions & 4 deletions
@@ -6,7 +6,6 @@
 import torch
 from transformers import PretrainedConfig
 
-from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry

@@ -19,6 +18,7 @@
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
 
+    from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.model_loader.loader import BaseModelLoader
     from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
         BaseTokenizerGroup)

@@ -657,7 +657,7 @@ def __init__(
         ray_workers_use_nsight: bool = False,
         placement_group: Optional["PlacementGroup"] = None,
         distributed_executor_backend: Optional[Union[
-            str, Type[ExecutorBase]]] = None,
+            str, Type["ExecutorBase"]]] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size

@@ -714,6 +714,9 @@ def use_ray(self) -> bool:
            and self.distributed_executor_backend.uses_ray)
 
     def _verify_args(self) -> None:
+        # Lazy import to avoid circular import
+        from vllm.executor.executor_base import ExecutorBase
+
         if (self.pipeline_parallel_size > 1
                 and self.distributed_executor_backend == "mp"):
             raise NotImplementedError("Pipeline parallelism is not supported "

@@ -723,8 +726,9 @@ def _verify_args(self) -> None:
                    self.distributed_executor_backend, type) and issubclass(
                        self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
-                "Unrecognized distributed executor backend. Supported values "
-                "are 'ray' or 'mp'.")
+                "Unrecognized distributed executor backend "
+                f"{self.distributed_executor_backend}. Supported "
+                "values are 'ray', 'mp' or custom ExecutorBase subclass.")
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
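The config change moves the ExecutorBase import under TYPE_CHECKING for the annotation (hence the quoted forward reference Type["ExecutorBase"]) and imports it lazily inside _verify_args, so vllm/config.py no longer imports the executor package at module load time. A minimal, self-contained sketch of the same pattern; BackendConfig is an illustrative stand-in, not vLLM's actual config class:

from typing import TYPE_CHECKING, Optional, Type, Union

if TYPE_CHECKING:
    # Seen only by type checkers, so it cannot create a runtime import cycle.
    from vllm.executor.executor_base import ExecutorBase


class BackendConfig:
    """Illustrative stand-in for the real config class."""

    def __init__(
        self,
        distributed_executor_backend: Optional[Union[
            str, Type["ExecutorBase"]]] = None,
    ) -> None:
        self.distributed_executor_backend = distributed_executor_backend

    def _verify_args(self) -> None:
        # Lazy import at call time avoids the circular import at module load.
        from vllm.executor.executor_base import ExecutorBase

        backend = self.distributed_executor_backend
        if backend is None or backend in ("ray", "mp"):
            return
        if isinstance(backend, type) and issubclass(backend, ExecutorBase):
            return
        raise ValueError(
            f"Unrecognized distributed executor backend {backend}. "
            "Supported values are 'ray', 'mp' or a custom ExecutorBase "
            "subclass.")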
