
Commit ed04071

aarnphm authored and russellb committed

[V1] Structured Outputs + Thinking compatibility (vllm-project#16577)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

1 parent 96f13e7 commit ed04071

File tree

10 files changed: +233 -75 lines changed


docs/source/features/reasoning_outputs.md
Lines changed: 2 additions & 2 deletions

````diff
@@ -141,10 +141,10 @@ Remember to check whether the `reasoning_content` exists in the response before
 The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now.
 
 ```bash
-VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1
 ```
 
-Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine.
+The following is an example client:
 
 ```python
 from openai import OpenAI
````
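
The docs diff is truncated at the client import. For reference, a minimal client sketch follows; it assumes a server started with the `vllm serve` command above, and uses `guided_json` (via `extra_body`) and the `reasoning_content` message field, both vLLM extensions to the OpenAI API. The actual snippet in the docs may differ.

```python
from openai import OpenAI

# Talk to the local vLLM server started with `vllm serve ...` above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

schema = {
    "type": "object",
    "properties": {"result": {"type": "integer"}},
    "required": ["result"],
}

completion = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    messages=[{"role": "user", "content": "What is 5 * 8 + 2? Reply in JSON."}],
    extra_body={"guided_json": schema},  # vLLM extension: constrain the answer
)

message = completion.choices[0].message
print("reasoning:", message.reasoning_content)  # vLLM extension field
print("content:", message.content)              # JSON conforming to `schema`
```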

tests/v1/entrypoints/llm/test_struct_output_generate.py
Lines changed: 92 additions & 4 deletions

```diff
@@ -1,21 +1,27 @@
+# ruff: noqa: E501
 # SPDX-License-Identifier: Apache-2.0
 
 from __future__ import annotations
 
 import json
 import re
 from enum import Enum
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import jsonschema
 import pytest
 from pydantic import BaseModel
 
+from tests.reasoning.utils import run_reasoning_extraction
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
+from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
+if TYPE_CHECKING:
+    from vllm.config import TokenizerMode
+
 NGRAM_SPEC_CONFIG = {
     "model": "[ngram]",
     "num_speculative_tokens": 5,
@@ -444,7 +450,7 @@ def test_structured_output(
 
     prompt = """
 You have access to the following function to retrieve the weather in a city:
-
+
 {
     "name": "get_weather",
     "parameters": {
@@ -455,7 +461,7 @@ def test_structured_output(
         }
     }
 }
-
+
 If a you choose to call a function ONLY reply in the following format:
 <{start_tag}={function_name}>{parameters}{end_tag}
 where
@@ -476,7 +482,7 @@ def test_structured_output(
 - Always add your sources when using search results to answer the user query
 
 You are a helpful assistant.
-
+
 Given the previous instructions, what is the weather in New York City? \
 Make the response as short as possible.
 """
@@ -514,6 +520,88 @@ def test_structured_output(
             f"{generated_text!r}\nError: {str(e)}")
 
 
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize(
+    "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config",  # noqa: E501
+    [
+        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
+         "deepseek_r1", NGRAM_SPEC_CONFIG),
+        ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None),
+    ],
+)
+def test_structured_output_with_reasoning_matrices(
+    monkeypatch: pytest.MonkeyPatch,
+    guided_decoding_backend: str,
+    tokenizer_mode: TokenizerMode,
+    reasoning_parser: str,
+    model_name: str,
+    speculative_config: dict[str, Any] | None,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    if current_platform.is_tpu() and speculative_config:
+        pytest.skip("TPU does not support speculative decoding")
+
+    # Use a single LLM instance for several scenarios to
+    # speed up the test suite.
+    llm = LLM(
+        model=model_name,
+        # Don't use eager execution on TPUs because we want to test for no
+        # recompilation at runtime
+        enforce_eager=bool(not current_platform.is_tpu()),
+        max_model_len=1024,
+        max_num_seqs=16,
+        guided_decoding_backend=guided_decoding_backend,
+        guided_decoding_disable_any_whitespace=True,
+        tokenizer_mode=tokenizer_mode,
+        reasoning_parser=reasoning_parser,
+        speculative_config=speculative_config,
+    )
+    tokenizer = llm.get_tokenizer(None)
+    reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
+        tokenizer=tokenizer)
+
+    reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?"  # noqa: E501
+    reasoning_schema = {
+        "type": "object",
+        "properties": {
+            "result": {
+                "type": "integer"
+            }
+        },
+        "required": ["result"],
+        "additionalProperties": False
+    }
+    if "Qwen3" in model_name:
+        reasoning_prompt += "<think>\n"
+
+    sampling_params = SamplingParams(
+        temperature=0.1,
+        max_tokens=8192,
+        guided_decoding=GuidedDecodingParams(json=reasoning_schema),
+    )
+    outputs = llm.generate(
+        [reasoning_prompt],
+        sampling_params=sampling_params,
+        use_tqdm=True,
+    )
+
+    assert outputs is not None
+    output = outputs[0]
+    assert output is not None and isinstance(output, RequestOutput)
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    reasoning_content, content = run_reasoning_extraction(
+        reasoner, [generated_text])
+    print(
+        f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}"
+    )
+
+    assert content is not None and reasoning_content is not None
+    output_json = json.loads(content)
+    jsonschema.validate(instance=output_json, schema=reasoning_schema)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("model_name, tokenizer_mode",
                          PARAMS_MODELS_TOKENIZER_MODE)
```
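
The new test exercises the whole compatibility path: the model reasons freely inside `<think>...</think>`, the grammar constrains only what follows, and `run_reasoning_extraction` splits the two. As a rough sketch (not vLLM's implementation), extraction for a DeepSeek-R1-style parser amounts to splitting on the think delimiters:

```python
from __future__ import annotations


# Rough sketch, not vLLM's implementation: what reasoning extraction
# amounts to for <think>-delimited models such as DeepSeek-R1.
def split_reasoning(text: str) -> tuple[str | None, str | None]:
    """Return (reasoning_content, content) from generated text."""
    end_tag = "</think>"
    if end_tag not in text:
        return text, None  # reasoning never ended; no structured content yet
    reasoning, _, content = text.partition(end_tag)
    return reasoning.removeprefix("<think>").strip(), content.strip()


assert split_reasoning('<think>5 * 8 = 40, plus 2.</think>{"result": 42}') == \
    ("5 * 8 = 40, plus 2.", '{"result": 42}')
```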

vllm/config.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -2332,7 +2332,7 @@ class SpeculativeConfig:
     `TypicalAcceptanceSampler`."""
 
     speculative_token_tree: Optional[str] = None
-    """Specifies the tree structure for speculative token generation. 
+    """Specifies the tree structure for speculative token generation.
     """
     # required configuration params passed from engine
     target_model_config: ModelConfig = field(default=None,
@@ -4024,7 +4024,7 @@ class VllmConfig:
     """LoRA configuration."""
     speculative_config: Optional[SpeculativeConfig] = None
     """Speculative decoding configuration."""
-    decoding_config: Optional[DecodingConfig] = None
+    decoding_config: DecodingConfig = field(default_factory=DecodingConfig)
     """Decoding configuration."""
     observability_config: Optional[ObservabilityConfig] = None
     """Observability configuration."""
```

vllm/reasoning/abs_reasoning_parsers.py
Lines changed: 4 additions & 2 deletions

```diff
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import os
 from abc import abstractmethod
 from collections.abc import Sequence
@@ -33,7 +35,7 @@ def vocab(self) -> dict[str, int]:
         return self.model_tokenizer.get_vocab()
 
     @abstractmethod
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         """
         Check if the reasoning content ends in the input_ids.
 
@@ -106,7 +108,7 @@ class ReasoningParserManager:
     reasoning_parsers: dict[str, type] = {}
 
     @classmethod
-    def get_reasoning_parser(cls, name) -> type:
+    def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
         """
         Get reasoning parser by name which is registered by `register_module`.
 
```
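The tightened signatures match how the new test consumes the manager. The sketch below is assembled from calls shown in the test above; the model choice and input string are just for illustration:

```python
from vllm.entrypoints.llm import LLM
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager

llm = LLM(model="Qwen/Qwen3-1.7B", reasoning_parser="deepseek_r1")
tokenizer = llm.get_tokenizer(None)

# get_reasoning_parser now returns type[ReasoningParser], so this
# instantiation type-checks without a cast.
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
reasoner = parser_cls(tokenizer=tokenizer)

# is_reasoning_end accepts any Sequence[int] under the widened signature,
# e.g. a tuple of token ids rather than only a list.
token_ids = tokenizer.encode("<think>thinking...</think>done")
print(reasoner.is_reasoning_end(tuple(token_ids)))
```
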
vllm/v1/core/sched/scheduler.py
Lines changed: 4 additions & 4 deletions

```diff
@@ -758,7 +758,8 @@ def update_from_output(
             # the outer lists can be of length > 1.
             new_logprobs = logprobs.slice(req_index, req_index + 1)
 
-            if new_token_ids and request.use_structured_output:
+            if new_token_ids and self.structured_output_manager.should_advance(
+                    request):
                 # NOTE: structured_output_request
                 # should not be None if use_structured_output, we have
                 # check above, so safe to ignore type warning
@@ -767,11 +768,10 @@ def update_from_output(
 
             # Add newly generated spec token ids to the request.
             if spec_token_ids is not None:
-                if request.use_structured_output:
+                if self.structured_output_manager.should_advance(request):
                     metadata = request.structured_output_request
-                    assert metadata is not None and metadata.grammar is not None
                     # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(
+                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                         spec_token_ids[req_index])
                 else:
                     request.spec_token_ids = spec_token_ids[req_index]
```
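
The scheduler now defers to the structured-output manager instead of checking `request.use_structured_output` directly: with a reasoning parser configured, tokens emitted inside the thinking section must not be fed to (or validated against) the grammar. `should_advance` is introduced by this PR, but the body below is a hypothetical sketch of the gating it expresses, not vLLM's code:

```python
# Hypothetical sketch of the check behind should_advance(); the real
# method lives on vLLM's StructuredOutputManager and may differ.
def should_advance(self, request) -> bool:
    if not request.use_structured_output:
        return False
    # While the model is still "thinking", its tokens are free-form
    # reasoning, so the grammar matcher must not consume them.
    if self.reasoner is not None and not self.reasoner.is_reasoning_end(
            request.all_token_ids):
        return False
    return True
```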
