[V0][V1][Core] Add outlines integration for V1, and update V0 integration. #15975

Open
wants to merge 16 commits into base: main
4 changes: 3 additions & 1 deletion requirements/common.txt
@@ -20,7 +20,9 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
outlines_core == 0.2.10
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64"
typing_extensions >= 4.10
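For context, the diskcache pin above exists because the outlines backend can persist compiled guides on disk (see the VLLM_*_USE_OUTLINES_CACHE flags further down in this diff). Below is a minimal sketch of the dict-style API that diskcache provides; the cache path, key, and stored value are illustrative and not vLLM's actual cache layout:

import diskcache

# Illustrative path; vLLM manages its own cache location.
cache = diskcache.Cache("/tmp/outlines-cache-example")

def compile_guide(regex: str) -> str:
    # Placeholder for an expensive compilation step (e.g. building an FSM).
    return f"compiled::{regex}"

key = r"\d+\.\d+\.\d+\.\d+"
if key not in cache:
    cache[key] = compile_guide(key)
print(cache[key])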
33 changes: 20 additions & 13 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -15,14 +15,18 @@
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = [

# Separate the backends which support grammars from those
# which only support regex-based constraints in the tests.
GRAMMAR_DECODING_BACKENDS = [
# (backend, disable_any_whitespace),
("outlines", False),
("lm-format-enforcer", False),
("xgrammar", True),
("guidance", True),
]

ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS)


@pytest.fixture(scope="module")
def llm():
@@ -38,7 +42,7 @@ def llm():

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
@@ -48,6 +52,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
regex=sample_regex,
backend=guided_decoding_backend,
disable_any_whitespace=disable_any_whitespace))

outputs = llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
@@ -68,7 +73,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -102,7 +107,7 @@ def test_guided_json_completion(sample_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -137,7 +142,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -172,7 +177,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -217,7 +222,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -247,7 +252,7 @@ def test_guided_choice_completion(sample_guided_choice, llm,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
guided_decoding_backend: str,
disable_any_whitespace: bool):
@@ -343,7 +348,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
GRAMMAR_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sampling_params = SamplingParams(
@@ -376,7 +381,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str,

# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
# A list is not what was intended, but it is still valid
# JSON.
assert isinstance(parsed_json, (dict, list))


class CarType(str, Enum):
@@ -394,7 +401,7 @@ class CarDescription(BaseModel):

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
json_schema = CarDescription.model_json_schema()
@@ -426,7 +433,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
GUIDED_DECODING_BACKENDS)
ALL_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
disable_any_whitespace: bool):
sample_output_schema = {
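For orientation, the request-level pattern these tests exercise looks roughly like the sketch below; the prompt and schema are illustrative, the model matches the one used by the tests, and only backend="outlines" on V1 is specific to this PR:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # same model as MODEL_NAME above

schema = {"type": "object",
          "properties": {"name": {"type": "string"}},
          "required": ["name"]}
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema, backend="outlines"))

outputs = llm.generate(
    prompts=["Give an employee profile that fits this schema."],
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)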
34 changes: 12 additions & 22 deletions tests/model_executor/test_guided_processors.py
@@ -39,26 +39,23 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex,
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
regex_LP = RegexLogitsProcessor(sample_regex,
zephyr_7B_tokenzer,
reasoner=None)
reasoner=None,
vocab_size=32000)
json_LP = JSONLogitsProcessor(sample_json_schema,
zephyr_7B_tokenzer,
whitespace_pattern=None,
reasoner=None)
reasoner=None,
vocab_size=32000)

token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
regex_LP(token_ids, tensor)
tensor = regex_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
json_LP(token_ids, tensor)
tensor = json_LP([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

@@ -80,8 +77,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool,
seed=0,
dtype="bfloat16",
)
token_ids = zephyr_7B_tokenzer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

regex_lp = get_local_guided_decoding_logits_processor(
@@ -91,21 +86,19 @@
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
# allowed tokens at state 0
tensor = regex_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

token_ids = zephyr_7B_tokenzer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}"
)
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = await get_guided_decoding_logits_processor(
json_request, zephyr_7B_tokenzer, config)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
tensor = json_lp([], tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

@@ -129,7 +122,6 @@ async def test_guided_logits_processor_with_reasoning(
dtype="bfloat16",
)
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}."
"<think>here is the thinking process")
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)

@@ -142,12 +134,11 @@ async def test_guided_logits_processor_with_reasoning(
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)

token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
@@ -159,14 +150,13 @@ async def test_guided_logits_processor_with_reasoning(
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert torch.allclose(tensor, original_tensor)

# Thinking is over, so the tensor should change.
token_ids = deepseek_r1_qwen_tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}."
"<think>here is the thinking process</think> Then")
"<think>here is the thinking process</think>")
json_request = GuidedDecodingParams(json=sample_json_schema,
backend=backend)
json_lp = get_local_guided_decoding_logits_processor(
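The test changes above reflect two behaviours of the reworked outlines processors: they take an explicit vocab_size at construction, and calling them with an empty prefix returns the mask for the initial FSM state. A rough sketch under the assumption that the import path matches the test module's (not shown in this diff) and that the zephyr_7B_tokenzer fixture corresponds to the zephyr-7b-beta tokenizer:

import torch
from transformers import AutoTokenizer
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    RegexLogitsProcessor)

# Tokenizer assumed; the tests use a zephyr_7B_tokenzer fixture.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
regex_lp = RegexLogitsProcessor(r"\d+\.\d+\.\d+\.\d+",
                                tokenizer,
                                reasoner=None,
                                vocab_size=32000)

logits = torch.rand(32000)
masked = regex_lp([], logits)  # empty prefix -> mask for the start state
assert masked.shape == logits.shape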
2 changes: 1 addition & 1 deletion tests/tool_use/test_tool_choice_required.py
@@ -71,7 +71,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
assert isinstance(schema, dict)

# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
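The import move above follows outlines_core's reorganised module layout (outlines_core.fsm.json_schema -> outlines_core.json_schema). A small sketch of the helper's use, mirroring the test; the schema and sample output are illustrative:

import json
import re

from outlines_core.json_schema import build_regex_from_schema

schema = {"type": "object",
          "properties": {"city": {"type": "string"}},
          "required": ["city"]}
regex = build_regex_from_schema(json.dumps(schema))
compiled = re.compile(regex)

sample = {"city": "Amsterdam"}
print(compiled.fullmatch(json.dumps(sample)) is not None)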
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -3289,7 +3289,8 @@ def get_served_model_name(model: str,

GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar", "guidance"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
GuidedDecodingBackendV1]

7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -109,6 +109,7 @@
VLLM_DP_MASTER_PORT: int = 0
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
@@ -737,6 +738,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_V0_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",

# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V1_USE_OUTLINES_CACHE":
lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1",

# Gap between padding buckets for the forward pass. So we have
# 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP":
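Like its V0 counterpart, the new V1 flag is opt-in. A hedged sketch of enabling it for a single-user, trusted setup; setting the variable before importing vLLM keeps the ordering unambiguous:

import os

# Opt in to the V1 outlines disk cache. Leave unset in multi-tenant or
# untrusted environments: the cache is unbounded and lives on disk.
os.environ["VLLM_V1_USE_OUTLINES_CACHE"] = "1"

from vllm import LLM  # noqa: E402  (import after setting the flag)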
6 changes: 2 additions & 4 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -110,13 +110,12 @@ async def get_guided_decoding_logits_processor(

guided_params = maybe_backend_fallback(guided_params)

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
guided_params, tokenizer, reasoner)
guided_params, tokenizer, reasoner, model_config)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)
@@ -151,13 +150,12 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend)
reasoner = reasoner_class(tokenizer)

# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_params, tokenizer, reasoner)
guided_params, tokenizer, reasoner, model_config)
if guided_params.backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_local_lm_format_enforcer_guided_decoding_logits_processor)