From 7f9b1741c86f851b4c5d477707e914e9b2972757 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 14 Apr 2025 06:33:05 +0000 Subject: [PATCH 01/40] chore: migrate tokenizer init to manager only Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 25 ++++++++++--- vllm/v1/structured_output/backend_guidance.py | 20 +++------- vllm/v1/structured_output/backend_types.py | 13 +++++++ vllm/v1/structured_output/backend_xgrammar.py | 37 ++++++++----------- 4 files changed, 54 insertions(+), 41 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 0fd66c07296..08e20537dee 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,9 +7,11 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar) +from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: import numpy as np @@ -46,13 +48,26 @@ def grammar_init(self, request: Request) -> None: # backends on a per-request basis in V1 (for now, anyway...). if self.backend is None: backend_name = request.sampling_params.guided_decoding.backend_name + tokenizer_group = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + parallel_config=self.vllm_config.parallel_config, + lora_config=self.vllm_config.lora_config) + tokenizer_group.ping() + tokenizer = tokenizer_group.get_lora_tokenizer(None) + vocab_size = self.vllm_config.model_config.get_vocab_size() if backend_name == "xgrammar": - from vllm.v1.structured_output.backend_xgrammar import ( - XgrammarBackend) - - self.backend = XgrammarBackend(self.vllm_config) + self.backend = XgrammarBackend( + self.vllm_config, + tokenizer=tokenizer, + vocab_size=vocab_size, + ) elif backend_name == "guidance": - self.backend = GuidanceBackend(self.vllm_config) + self.backend = GuidanceBackend( + self.vllm_config, + tokenizer=tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend_name}") diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 6d2ccd4019d..603083f65b2 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ b/vllm/v1/structured_output/backend_guidance.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import copy import json import os @@ -8,10 +10,8 @@ import torch -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import GuidedDecodingParams, SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar, @@ -54,21 +54,14 @@ def process_for_additional_properties( return guide_json_obj +@dataclass class GuidanceBackend(StructuredOutputBackend): - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) # type: ignore[arg-type] - self.vllm_config = 
vllm_config - self.vocab_size = vllm_config.model_config.get_vocab_size() - + def __post_init__(self): self.disable_any_whitespace = False self.no_additional_properties = False backend_options = GuidedDecodingParams( - backend=vllm_config.decoding_config.guided_decoding_backend + backend=self.vllm_config.decoding_config.guided_decoding_backend ).backend_options() for option in backend_options: if option == "disable-any-whitespace": @@ -79,9 +72,8 @@ def __init__(self, vllm_config: VllmConfig): raise ValueError( f"Unsupported option for the guidance backend: {option}") - tokenizer = tokenizer_group.get_lora_tokenizer(None) self.ll_tokenizer = llguidance_hf.from_tokenizer( - tokenizer, self.vocab_size) + self.tokenizer, self.vocab_size) def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py index 306e4aa0196..7ac8e40f93e 100644 --- a/vllm/v1/structured_output/backend_types.py +++ b/vllm/v1/structured_output/backend_types.py @@ -1,10 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import enum from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING import torch +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.transformers_utils.tokenizer import AnyTokenizer + class StructuredOutputOptions(enum.Enum): JSON = enum.auto() @@ -60,9 +68,14 @@ def reset(self): """ +@dataclass class StructuredOutputBackend(ABC): """Engine-level backend for structured output requests.""" + vllm_config: VllmConfig + tokenizer: AnyTokenizer + vocab_size: int + @abstractmethod def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index bb7c7edc278..957283ee316 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import json from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -7,10 +9,8 @@ import torch import vllm.envs -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import GuidedDecodingParams, SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -28,18 +28,13 @@ logger = init_logger(__name__) +@dataclass class XgrammarBackend(StructuredOutputBackend): - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) # type: ignore[arg-type] - + def __post_init__(self): self.disable_any_whitespace = False backend_options = GuidedDecodingParams( - backend=vllm_config.decoding_config.guided_decoding_backend + backend=self.vllm_config.decoding_config.guided_decoding_backend ).backend_options() for option in backend_options: if option == "disable-any-whitespace": @@ -48,44 +43,42 @@ def __init__(self, vllm_config: VllmConfig): raise ValueError( f"Unsupported option 
for the xgrammar backend: {option}") - tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.vocab_size = vllm_config.model_config.get_vocab_size() - if isinstance(tokenizer, MistralTokenizer): + if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 try: - if tokenizer.is_tekken: - encoded_vocab = tokenizer._vocab + if self.tokenizer.is_tekken: + encoded_vocab = self.tokenizer._vocab else: encoded_vocab = [ token for token, _ in sorted( - tokenizer.get_vocab().items(), + self.tokenizer.get_vocab().items(), key=lambda x: x[1], ) ] stop_token_ids = None if hasattr( - tokenizer, + self.tokenizer, "eos_token_id", - ) and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + ) and self.tokenizer.eos_token_id is not None: + stop_token_ids = [self.tokenizer.eos_token_id] except AttributeError as e: raise ValueError( f"Cannot get the vocabulary of the tokenizer " - f"{type(tokenizer)}. The tokenizer should have a " + f"{type(self.tokenizer)}. The tokenizer should have a " "get_vocab method.") from e tokenizer_info = xgr.TokenizerInfo( # type: ignore encoded_vocab=encoded_vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW - if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, vocab_size=self.vocab_size, stop_token_ids=stop_token_ids, add_prefix_space=True, ) else: tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer, + self.tokenizer, vocab_size=self.vocab_size, ) self.compiler = xgr.GrammarCompiler( From 023807da8e0989f931217708c3de2b889306d069 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 14 Apr 2025 06:44:59 +0000 Subject: [PATCH 02/40] chore: init reasoning_parser on manager Signed-off-by: Aaron Pham --- vllm/reasoning/abs_reasoning_parsers.py | 2 +- vllm/v1/structured_output/__init__.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 454167a0dc9..ca672561574 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -106,7 +106,7 @@ class ReasoningParserManager: reasoning_parsers: dict[str, type] = {} @classmethod - def get_reasoning_parser(cls, name) -> type: + def get_reasoning_parser(cls, name: str) -> type[ReasoningParser]: """ Get reasoning parser by name which is registered by `register_module`. 
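The hunk above only tightens the lookup's return type; the next hunk consumes it. A minimal sketch (not part of the patch) of how the typed lookup is meant to be used — the manager resolves the parser *class* by name and instantiates it with the tokenizer it already owns; the parser name and model below are example values:

```python
from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParserManager

# Example inputs; any registered parser name / compatible tokenizer works.
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

# The lookup returns the parser class; the caller instantiates it with the
# tokenizer it already holds (here, the structured-output manager's tokenizer).
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
reasoner = parser_cls(tokenizer=tokenizer)

# The scheduler can then ask whether the thinking section has closed.
print(reasoner.is_reasoning_end(tokenizer.encode("</think>")))
```
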
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 08e20537dee..3770596d9de 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,6 +7,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -18,6 +19,7 @@ import numpy.typing as npt import torch + from vllm.reasoning import ReasoningParser from vllm.v1.request import Request logger = init_logger(__name__) @@ -28,6 +30,7 @@ class StructuredOutputManager: def __init__(self, vllm_config: VllmConfig): self.backend: Optional[StructuredOutputBackend] = None + self.reasoner: Optional[ReasoningParser] = None self.vllm_config = vllm_config self._grammar_bitmask: Optional[torch.Tensor] = None @@ -72,6 +75,12 @@ def grammar_init(self, request: Request) -> None: raise ValueError( f"Unsupported structured output backend: {backend_name}") + if (reasoning_backend := + self.vllm_config.decoding_config.reasoning_backend + ) is not None and self.reasoner is None: + self.reasoner = ReasoningParserManager.get_reasoning_parser( + reasoning_backend)(tokenizer=tokenizer) + grammar = self.executor.submit(self._async_create_grammar, request) request.structured_output_request.grammar = grammar # type: ignore[assignment] From 92527f6561c820c08bea11b3d4743c28d63d4b31 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 14 Apr 2025 06:59:33 +0000 Subject: [PATCH 03/40] feat: support parsing thinking tokens Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 83 ++++++++++++++++--- vllm/reasoning/abs_reasoning_parsers.py | 2 +- vllm/v1/core/sched/scheduler.py | 19 +++-- vllm/v1/request.py | 1 + vllm/v1/structured_output/__init__.py | 28 +++++-- 5 files changed, 109 insertions(+), 24 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 1e4a8053997..755e9884df5 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -16,14 +16,37 @@ from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams -PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace", - "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace", - "auto"), - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace", - "mistral"), - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"), +PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER = [ + ( + "mistralai/Ministral-8B-Instruct-2410", + "xgrammar:disable-any-whitespace", + "auto", + None, + ), + ( + "mistralai/Ministral-8B-Instruct-2410", + "guidance:disable-any-whitespace", + "auto", + None, + ), + ( + "mistralai/Ministral-8B-Instruct-2410", + "xgrammar:disable-any-whitespace", + "mistral", + None, + ), + ( + "Qwen/Qwen2.5-1.5B-Instruct", + "xgrammar:disable-any-whitespace", + "auto", + None, + ), + ( + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "xgrammar:disable-any-whitespace", + "auto", + "deepseek_r1", + ), #FIXME: This test is flaky on CI thus disabled #("Qwen/Qwen2.5-1.5B-Instruct", 
"guidance:disable-any-whitespace", "auto"), ] @@ -48,8 +71,9 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode", - PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) +@pytest.mark.parametrize( + "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser", + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER) def test_structured_output( monkeypatch: pytest.MonkeyPatch, sample_json_schema: dict[str, Any], @@ -60,6 +84,7 @@ def test_structured_output( sample_guided_choice: str, guided_decoding_backend: str, tokenizer_mode: str, + reasoning_parser: str | None, model_name: str, ): monkeypatch.setenv("VLLM_USE_V1", "1") @@ -73,7 +98,9 @@ def test_structured_output( enforce_eager=enforce_eager, max_model_len=1024, guided_decoding_backend=guided_decoding_backend, - tokenizer_mode=tokenizer_mode) + tokenizer_mode=tokenizer_mode, + enable_reasoning=reasoning_parser is not None, + reasoning_parser=reasoning_parser) # # Test 1: Generate JSON output based on a provided schema @@ -368,6 +395,40 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) + # + # Test 11: Generate structured output with reasoning step + # + if reasoning_parser is not None: + reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Problem: What is 5 * 8 + 2?" # noqa: E501 + reasoning_schema = { + "type": "object", + "properties": { + "result": { + "type": "integer" + } + }, + "required": ["result"] + } + + sampling_params = SamplingParams( + temperature=0.1, # Low temp for deterministic reasoning + max_tokens=200, + guided_decoding=GuidedDecodingParams(json=reasoning_schema)) + outputs = llm.generate(prompts=[reasoning_prompt], + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + output = outputs[0] + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=reasoning_schema) + @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("model_name, tokenizer_mode", diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index ca672561574..1f742354c24 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -106,7 +106,7 @@ class ReasoningParserManager: reasoning_parsers: dict[str, type] = {} @classmethod - def get_reasoning_parser(cls, name: str) -> type[ReasoningParser]: + def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]: """ Get reasoning parser by name which is registered by `register_module`. 
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index adec4462963..0196ad1de09 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -705,11 +705,20 @@ def update_from_output( new_logprobs = logprobs.slice(req_index, req_index + 1) if new_token_ids and request.use_structured_output: - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # check above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + advance_fsm = False + reasoner = self.structured_output_manager.reasoner + if reasoner is None or request.reasoning_ended: + advance_fsm = True + elif reasoner.is_reasoning_end(request.all_token_ids): + request.reasoning_ended = True + advance_fsm = True + + if advance_fsm: + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6be72431dde..941672f79d8 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -37,6 +37,7 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = structured_output_request + self.reasoning_ended: bool = False self.status = (RequestStatus.WAITING_FOR_FSM if sampling_params.guided_decoding is not None else diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 3770596d9de..b8f80d6c8a0 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -119,18 +119,32 @@ def grammar_bitmask( # position in the batch. Resize the bitmask down to the size of # the batch. bitmask_tensor = self._grammar_bitmask + # Reset the relevant part of the bitmask before filling + if batch_len > 0: + bitmask_tensor[:batch_len].fill_(-1) + for req_id, batch_index in structured_output_request_ids.items(): - request = requests[req_id].structured_output_request - assert request is not None and request.grammar is not None - if not request.grammar.is_terminated(): - request.grammar.fill_bitmask(bitmask_tensor, batch_index) - if batch_len < self._grammar_bitmask.shape[0]: - bitmask_tensor = self._grammar_bitmask[:batch_len] + full_request = requests[req_id] + so_request = full_request.structured_output_request + assert so_request is not None and so_request.grammar is not None + + apply_bitmask = (self.reasoner is None + or full_request.reasoning_ended + or self.reasoner.is_reasoning_end( + full_request.all_token_ids)) + + if apply_bitmask and not so_request.grammar.is_terminated(): + so_request.grammar.fill_bitmask(bitmask_tensor, batch_index) + + if batch_len < bitmask_tensor.shape[0]: + final_bitmask_tensor = bitmask_tensor[:batch_len] + else: + final_bitmask_tensor = bitmask_tensor # After finishing with the xgrammar operations, we convert to # np.ndarray, because that is much more efficient for serialization # and deserialization when sending this to the GPU workers. 
- return bitmask_tensor.numpy() + return final_bitmask_tensor.numpy() def clear_backend(self) -> None: if self.backend is not None: From e50ea4044ec908ca12412c511f77c4ca6ba5274a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 14 Apr 2025 09:33:42 +0000 Subject: [PATCH 04/40] chore: add a check to make sure that the reasoning token is not being advanced Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0196ad1de09..28a214c0109 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -707,13 +707,26 @@ def update_from_output( if new_token_ids and request.use_structured_output: advance_fsm = False reasoner = self.structured_output_manager.reasoner + is_reasoning_end_this_step = False # Flag the transition + if reasoner is None or request.reasoning_ended: + # Reasoning was already off or never active advance_fsm = True - elif reasoner.is_reasoning_end(request.all_token_ids): - request.reasoning_ended = True - advance_fsm = True + else: + # Reasoning is active, check if it ends now + if reasoner.is_reasoning_end(request.all_token_ids): + request.reasoning_ended = True + is_reasoning_end_this_step = True + # Don't advance FSM in the step the transition occurs, + # as new_token_ids might contain the end marker. + advance_fsm = False + else: + # Reasoning continues, don't advance FSM + advance_fsm = False - if advance_fsm: + # Only advance FSM if reasoning was already off OR + # if we are not in the specific step where reasoning just ended. + if advance_fsm and not is_reasoning_end_this_step: # NOTE: structured_output_request # should not be None if use_structured_output, we have # check above, so safe to ignore type warning From 7542582f0bdc73f3ea861d640ff9938f527f4449 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 14 Apr 2025 06:45:05 -0400 Subject: [PATCH 05/40] chore: update docs Signed-off-by: Aaron Pham --- docs/source/features/reasoning_outputs.md | 46 +++++++++++------------ 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 3a0be69f8e1..75bb3f5f862 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -10,11 +10,11 @@ Reasoning models return an additional `reasoning_content` field in their outputs vLLM currently supports the following reasoning models: -| Model Series | Parser Name | Structured Output Support | Tool Calling | -|--------------|-------------|------------------|-------------| -| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | -| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | -| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | +| Model Series | Parser Name | Structured Output Support | Tool Calling | +| ------------------------------------------------------------------------------------------------------------------------------------- | ------------- | ----------------------------- | ------------ | +| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, 
`guided_regex` | ❌ | +| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | +| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | - IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. @@ -64,22 +64,22 @@ Streaming chat completions are also supported for reasoning models. The `reasoni ```json { - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1694268190, - "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", - "system_fingerprint": "fp_44709d6fcb", - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "reasoning_content": "is", - }, - "logprobs": null, - "finish_reason": null - } - ] + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1694268190, + "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": "is" + }, + "logprobs": null, + "finish_reason": null + } + ] } ``` @@ -139,12 +139,10 @@ Remember to check whether the `reasoning_content` exists in the response before The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now. ```bash -VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ --enable-reasoning --reasoning-parser deepseek_r1 ``` -Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine. - ```python from openai import OpenAI from pydantic import BaseModel From fa6da3f823fb93d72d970a9aca61500df3e18316 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 17 Apr 2025 15:02:53 -0400 Subject: [PATCH 06/40] chore: move reasoning_ended to so_request Signed-off-by: Aaron Pham --- vllm/reasoning/abs_reasoning_parsers.py | 4 +++- vllm/v1/core/sched/scheduler.py | 12 +++++------- vllm/v1/request.py | 1 - vllm/v1/structured_output/__init__.py | 2 +- vllm/v1/structured_output/request.py | 1 + 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 1f742354c24..9dd5191da91 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import os from abc import abstractmethod from collections.abc import Sequence @@ -33,7 +35,7 @@ def vocab(self) -> dict[str, int]: return self.model_tokenizer.get_vocab() @abstractmethod - def is_reasoning_end(self, input_ids: list[int]) -> bool: + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: """ Check if the reasoning content ends in the input_ids. 
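For context on the widened `Sequence[int]` signature, a concrete implementation of this hook is expected to look roughly like the sketch below. It mirrors a DeepSeek-R1-style parser that scans for the closing `</think>` token; the class and attribute names are illustrative, the other abstract methods of `ReasoningParser` are omitted, and none of this is part of the patch:

```python
from collections.abc import Sequence

from vllm.reasoning import ReasoningParser


class ThinkTagReasoningParser(ReasoningParser):
    """Sketch: treats everything before </think> as reasoning content.

    The remaining abstract methods of ReasoningParser (content extraction)
    are omitted here for brevity.
    """

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        # Assumes the tokenizer emits </think> as a single special token.
        self.think_end_token_id = self.vocab["</think>"]

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        # Once the closing tag has appeared, the scheduler may start applying
        # the grammar bitmask and advancing the structured-output FSM.
        return self.think_end_token_id in input_ids
```
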
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 28a214c0109..0f96dff0815 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -707,21 +707,19 @@ def update_from_output( if new_token_ids and request.use_structured_output: advance_fsm = False reasoner = self.structured_output_manager.reasoner - is_reasoning_end_this_step = False # Flag the transition + so_request = request.structured_output_request + is_reasoning_end_this_step = False - if reasoner is None or request.reasoning_ended: - # Reasoning was already off or never active + if reasoner is None or so_request.reasoning_ended: # type: ignore[union-attr] advance_fsm = True - else: - # Reasoning is active, check if it ends now + else: # type: ignore[union-attr] if reasoner.is_reasoning_end(request.all_token_ids): - request.reasoning_ended = True + so_request.reasoning_ended = True # type: ignore[union-attr] is_reasoning_end_this_step = True # Don't advance FSM in the step the transition occurs, # as new_token_ids might contain the end marker. advance_fsm = False else: - # Reasoning continues, don't advance FSM advance_fsm = False # Only advance FSM if reasoning was already off OR diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 941672f79d8..6be72431dde 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -37,7 +37,6 @@ def __init__( self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = structured_output_request - self.reasoning_ended: bool = False self.status = (RequestStatus.WAITING_FOR_FSM if sampling_params.guided_decoding is not None else diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index b8f80d6c8a0..85d5c0597f7 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -129,7 +129,7 @@ def grammar_bitmask( assert so_request is not None and so_request.grammar is not None apply_bitmask = (self.reasoner is None - or full_request.reasoning_ended + or so_request.reasoning_ended or self.reasoner.is_reasoning_end( full_request.all_token_ids)) diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 9e54b8bf028..edcea60888e 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -20,6 +20,7 @@ class StructuredOutputRequest: sampling_params: SamplingParams _grammar: Optional[Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]] = None + reasoning_ended: bool = False def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports From 061ee09305ea5bb8a4fd48d58a1004972c180fd0 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 17 Apr 2025 19:28:38 +0000 Subject: [PATCH 07/40] chore: reduce diff Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 85d5c0597f7..6aa6f7e63f5 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -119,9 +119,6 @@ def grammar_bitmask( # position in the batch. Resize the bitmask down to the size of # the batch. 
bitmask_tensor = self._grammar_bitmask - # Reset the relevant part of the bitmask before filling - if batch_len > 0: - bitmask_tensor[:batch_len].fill_(-1) for req_id, batch_index in structured_output_request_ids.items(): full_request = requests[req_id] @@ -137,14 +134,12 @@ def grammar_bitmask( so_request.grammar.fill_bitmask(bitmask_tensor, batch_index) if batch_len < bitmask_tensor.shape[0]: - final_bitmask_tensor = bitmask_tensor[:batch_len] - else: - final_bitmask_tensor = bitmask_tensor + bitmask_tensor = self._grammar_bitmask[:batch_len] # After finishing with the xgrammar operations, we convert to # np.ndarray, because that is much more efficient for serialization # and deserialization when sending this to the GPU workers. - return final_bitmask_tensor.numpy() + return bitmask_tensor.numpy() def clear_backend(self) -> None: if self.backend is not None: From 5eecdbbf1da3437cf56e365ffce2b075e55aa666 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 17 Apr 2025 19:45:15 +0000 Subject: [PATCH 08/40] chore: move up checker logics Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 28 ++++++++++++++------------- vllm/v1/structured_output/__init__.py | 13 ++++++------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0f96dff0815..5be7219f497 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -5,7 +5,7 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.factory import ( @@ -707,20 +707,22 @@ def update_from_output( if new_token_ids and request.use_structured_output: advance_fsm = False reasoner = self.structured_output_manager.reasoner - so_request = request.structured_output_request is_reasoning_end_this_step = False - if reasoner is None or so_request.reasoning_ended: # type: ignore[union-attr] + # NOTE: use_structured_output implies + # structured_output_request is not None, + # but type checker isn't smart enough to know this. + # This only affect type runtime, not actual runtime. + # assert is also not recommended on perf-sensitive runtime path. + if TYPE_CHECKING: + assert request.structured_output_request is not None + assert request.structured_output_request.grammar is not None + + if reasoner is None or request.structured_output_request.reasoning_ended: # noqa: E501 advance_fsm = True - else: # type: ignore[union-attr] - if reasoner.is_reasoning_end(request.all_token_ids): - so_request.reasoning_ended = True # type: ignore[union-attr] - is_reasoning_end_this_step = True - # Don't advance FSM in the step the transition occurs, - # as new_token_ids might contain the end marker. - advance_fsm = False - else: - advance_fsm = False + elif reasoner.is_reasoning_end(request.all_token_ids): + request.structured_output_request.reasoning_ended = True + is_reasoning_end_this_step = True # Only advance FSM if reasoning was already off OR # if we are not in the specific step where reasoning just ended. 
@@ -728,7 +730,7 @@ def update_from_output( # NOTE: structured_output_request # should not be None if use_structured_output, we have # check above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + request.structured_output_request.grammar.accept_tokens( req_id, new_token_ids) # Get prompt logprobs for this request. diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 6aa6f7e63f5..26f3dc5722c 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -122,18 +122,17 @@ def grammar_bitmask( for req_id, batch_index in structured_output_request_ids.items(): full_request = requests[req_id] - so_request = full_request.structured_output_request - assert so_request is not None and so_request.grammar is not None + request = full_request.structured_output_request + assert request is not None and request.grammar is not None - apply_bitmask = (self.reasoner is None - or so_request.reasoning_ended + apply_bitmask = (self.reasoner is None or request.reasoning_ended or self.reasoner.is_reasoning_end( full_request.all_token_ids)) - if apply_bitmask and not so_request.grammar.is_terminated(): - so_request.grammar.fill_bitmask(bitmask_tensor, batch_index) + if apply_bitmask and not request.grammar.is_terminated(): + request.grammar.fill_bitmask(bitmask_tensor, batch_index) - if batch_len < bitmask_tensor.shape[0]: + if batch_len < self._grammar_bitmask.shape[0]: bitmask_tensor = self._grammar_bitmask[:batch_len] # After finishing with the xgrammar operations, we convert to From 873b08bb3dfb56123be2bd339963726861c46de3 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 06:37:42 +0000 Subject: [PATCH 09/40] chore: update correct function imports Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 26f3dc5722c..f9e2631b4bc 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -54,8 +54,8 @@ def grammar_init(self, request: Request) -> None: tokenizer_group = init_tokenizer_from_configs( model_config=self.vllm_config.model_config, scheduler_config=self.vllm_config.scheduler_config, - parallel_config=self.vllm_config.parallel_config, - lora_config=self.vllm_config.lora_config) + lora_config=self.vllm_config.lora_config, + ) tokenizer_group.ping() tokenizer = tokenizer_group.get_lora_tokenizer(None) vocab_size = self.vllm_config.model_config.get_vocab_size() From 218ad9c44fcc8f6c687e1b82e5fafdff7546e1f7 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 06:43:01 +0000 Subject: [PATCH 10/40] chore: remove incorrect function Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index f9e2631b4bc..89624eba673 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -51,13 +51,11 @@ def grammar_init(self, request: Request) -> None: # backends on a per-request basis in V1 (for now, anyway...). 
if self.backend is None: backend_name = request.sampling_params.guided_decoding.backend_name - tokenizer_group = init_tokenizer_from_configs( + tokenizer = init_tokenizer_from_configs( model_config=self.vllm_config.model_config, scheduler_config=self.vllm_config.scheduler_config, lora_config=self.vllm_config.lora_config, - ) - tokenizer_group.ping() - tokenizer = tokenizer_group.get_lora_tokenizer(None) + ).get_lora_tokenizer(None) vocab_size = self.vllm_config.model_config.get_vocab_size() if backend_name == "xgrammar": self.backend = XgrammarBackend( From 1ec89289d0cf9e0f5a5fcea1a7a1e1249e0d0cda Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 10:16:26 +0000 Subject: [PATCH 11/40] fix: make sure to reset the bitmask before update Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 3 +++ vllm/v1/structured_output/__init__.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5be7219f497..c3b411b80c7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -723,6 +723,9 @@ def update_from_output( elif reasoner.is_reasoning_end(request.all_token_ids): request.structured_output_request.reasoning_ended = True is_reasoning_end_this_step = True + advance_fsm = False + else: + advance_fsm = False # Only advance FSM if reasoning was already off OR # if we are not in the specific step where reasoning just ended. diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 89624eba673..964208fbf6d 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -117,6 +117,9 @@ def grammar_bitmask( # position in the batch. Resize the bitmask down to the size of # the batch. bitmask_tensor = self._grammar_bitmask + # Reset the relevant part of the bitmask before filling + if batch_len > 0: + bitmask_tensor[:batch_len].fill_(-1) for req_id, batch_index in structured_output_request_ids.items(): full_request = requests[req_id] From 9b6f4e8391469eecbe395ab53093c1f8c4c4921a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 12:13:02 +0000 Subject: [PATCH 12/40] chore: make sure non reasoning case works Signed-off-by: Aaron Pham --- vllm/engine/arg_utils.py | 6 ++++-- vllm/v1/core/sched/scheduler.py | 29 +++++++++++++-------------- vllm/v1/structured_output/__init__.py | 4 +--- 3 files changed, 19 insertions(+), 20 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 31971a51aed..7c8d92eef0b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -73,7 +73,7 @@ def optional_float(val: str) -> Optional[float]: def nullable_kvs(val: str) -> Optional[dict[str, int]]: """NOTE: This function is deprecated, args should be passed as JSON strings instead. - + Parses a string containing comma separate key [str] to value [int] pairs into a dictionary. @@ -264,6 +264,8 @@ def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model + self.enable_reasoning = self.reasoning_parser is not None + # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object @@ -1706,7 +1708,7 @@ def _warn_or_fallback(feature_name: str) -> bool: def human_readable_int(value): """Parse human-readable integers like '1k', '2M', etc. Including decimal values with decimal multipliers. 
- + Examples: - '1k' -> 1,000 - '1K' -> 1,024 diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c3b411b80c7..196736f2f2d 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -705,8 +705,8 @@ def update_from_output( new_logprobs = logprobs.slice(req_index, req_index + 1) if new_token_ids and request.use_structured_output: - advance_fsm = False reasoner = self.structured_output_manager.reasoner + advance_fsm = reasoner is None is_reasoning_end_this_step = False # NOTE: use_structured_output implies @@ -718,23 +718,22 @@ def update_from_output( assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None - if reasoner is None or request.structured_output_request.reasoning_ended: # noqa: E501 - advance_fsm = True - elif reasoner.is_reasoning_end(request.all_token_ids): - request.structured_output_request.reasoning_ended = True - is_reasoning_end_this_step = True - advance_fsm = False - else: - advance_fsm = False + if reasoner is not None: + if request.structured_output_request.reasoning_ended: # noqa: E501 + advance_fsm = True + elif reasoner.is_reasoning_end(request.all_token_ids): + request.structured_output_request.reasoning_ended = True + is_reasoning_end_this_step = True + advance_fsm = False + else: + advance_fsm = False # Only advance FSM if reasoning was already off OR # if we are not in the specific step where reasoning just ended. - if advance_fsm and not is_reasoning_end_this_step: - # NOTE: structured_output_request - # should not be None if use_structured_output, we have - # check above, so safe to ignore type warning - request.structured_output_request.grammar.accept_tokens( - req_id, new_token_ids) + # yapf: off + if advance_fsm and (not is_reasoning_end_this_step if reasoner is not None else True): # noqa: E501 + request.structured_output_request.grammar.accept_tokens(req_id, new_token_ids) # noqa: E501 + # yapf: on # Get prompt logprobs for this request. 
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 964208fbf6d..d1b5081685c 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -126,9 +126,7 @@ def grammar_bitmask( request = full_request.structured_output_request assert request is not None and request.grammar is not None - apply_bitmask = (self.reasoner is None or request.reasoning_ended - or self.reasoner.is_reasoning_end( - full_request.all_token_ids)) + apply_bitmask = request.reasoning_ended if self.reasoner is not None else True if apply_bitmask and not request.grammar.is_terminated(): request.grammar.fill_bitmask(bitmask_tensor, batch_index) From 63eecbf61a8c15346518d0ec4a442512b87c350d Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Sat, 26 Apr 2025 12:34:50 +0000 Subject: [PATCH 13/40] fix: remove unused check Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 196736f2f2d..01411c3720a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -707,7 +707,6 @@ def update_from_output( if new_token_ids and request.use_structured_output: reasoner = self.structured_output_manager.reasoner advance_fsm = reasoner is None - is_reasoning_end_this_step = False # NOTE: use_structured_output implies # structured_output_request is not None, @@ -723,17 +722,13 @@ def update_from_output( advance_fsm = True elif reasoner.is_reasoning_end(request.all_token_ids): request.structured_output_request.reasoning_ended = True - is_reasoning_end_this_step = True advance_fsm = False else: advance_fsm = False - # Only advance FSM if reasoning was already off OR - # if we are not in the specific step where reasoning just ended. - # yapf: off - if advance_fsm and (not is_reasoning_end_this_step if reasoner is not None else True): # noqa: E501 - request.structured_output_request.grammar.accept_tokens(req_id, new_token_ids) # noqa: E501 - # yapf: on + if advance_fsm: + request.structured_output_request.grammar.accept_tokens( + req_id, new_token_ids) # Get prompt logprobs for this request. 
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) From 910ee0c76f4766a72e285d77d24839f89235afe8 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 03:08:18 +0000 Subject: [PATCH 14/40] chore: fix pre-comimt Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index d1b5081685c..dfc38c8a860 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -126,7 +126,7 @@ def grammar_bitmask( request = full_request.structured_output_request assert request is not None and request.grammar is not None - apply_bitmask = request.reasoning_ended if self.reasoner is not None else True + apply_bitmask = request.reasoning_ended if self.reasoner is not None else True # noqa: E501 if apply_bitmask and not request.grammar.is_terminated(): request.grammar.fill_bitmask(bitmask_tensor, batch_index) From 327a0d0b889ce810bfc25f686240d1e80e19debb Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 17:24:25 -0400 Subject: [PATCH 15/40] revert: bad merge and remove inlines Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 6 +- vllm/config.py | 476 +++++++----------- vllm/v1/structured_output/backend_xgrammar.py | 29 +- 3 files changed, 198 insertions(+), 313 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index a67250e98c5..b9219be6962 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -72,7 +72,8 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser", - PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER) + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER, +) def test_structured_output( monkeypatch: pytest.MonkeyPatch, sample_json_schema: dict[str, Any], @@ -100,7 +101,6 @@ def test_structured_output( guided_decoding_backend=guided_decoding_backend, guided_decoding_disable_any_whitespace=True, tokenizer_mode=tokenizer_mode, - enable_reasoning=reasoning_parser is not None, reasoning_parser=reasoning_parser, ) @@ -510,7 +510,7 @@ def test_structured_output( } sampling_params = SamplingParams( - temperature=0.1, # Low temp for deterministic reasoning + temperature=0.1, max_tokens=200, guided_decoding=GuidedDecodingParams(json=reasoning_schema)) outputs = llm.generate(prompts=[reasoning_prompt], diff --git a/vllm/config.py b/vllm/config.py index 4ba9afb9962..abe59734e2d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -70,16 +70,8 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 -TaskOption = Literal[ - "auto", - "generate", - "embedding", - "embed", - "classify", - "score", - "reward", - "transcription", -] +TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", + "score", "reward", "transcription"] _ResolvedTask = Literal["generate", "embed", "classify", "score", "reward", "draft", "transcription"] @@ -437,9 +429,8 @@ def __init__( self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) - if ((backend := envs.VLLM_ATTENTION_BACKEND) - and backend == "FLASHINFER" - and find_spec("flashinfer") is None): + if (backend := envs.VLLM_ATTENTION_BACKEND + ) and backend == "FLASHINFER" and find_spec("flashinfer") is 
None: raise ValueError( "VLLM_ATTENTION_BACKEND is set to FLASHINFER, but flashinfer " "module was not found. See " @@ -467,13 +458,9 @@ def __init__( raise ValueError( "Sleep mode is not supported on current platform.") - hf_config = get_config( - self.hf_config_path or self.model, - trust_remote_code, - revision, - code_revision, - config_format, - ) + hf_config = get_config(self.hf_config_path or self.model, + trust_remote_code, revision, code_revision, + config_format) if hf_overrides_kw: logger.info("Overriding HF config with %s", hf_overrides_kw) @@ -503,11 +490,9 @@ def __init__( isinstance(sliding_window, list) or (self.hf_text_config.model_type in interleaved_attn_models)) - if not self.disable_sliding_window and has_interleaved_attention: - if (backend := envs.VLLM_ATTENTION_BACKEND) in ( - "XFORMERS", - "FLASHINFER", - ): + if (not self.disable_sliding_window and has_interleaved_attention): + if (backend := + envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"): sliding_window_len_min = get_min_sliding_window( self.hf_text_config.sliding_window) @@ -534,8 +519,7 @@ def __init__( disable_sliding_window=self.disable_sliding_window, sliding_window_len=self.get_hf_config_sliding_window(), spec_target_max_model_len=spec_target_max_model_len, - encoder_config=self.encoder_config, - ) + encoder_config=self.encoder_config) self.served_model_name = get_served_model_name(model, served_model_name) self.multimodal_config = self._init_multimodal_config( @@ -621,13 +605,11 @@ def _init_multimodal_config( ) if limit_mm_per_prompt: - raise ValueError( - "`limit_mm_per_prompt` is only supported for multimodal models." - ) + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") if mm_processor_kwargs: - raise ValueError( - "`mm_processor_kwargs` is only supported for multimodal models." - ) + raise ValueError("`mm_processor_kwargs` is only supported for " + "multimodal models.") if disable_mm_preprocessor_cache: raise ValueError("`disable_mm_preprocessor_cache` is only " "supported for multimodal models.") @@ -642,6 +624,7 @@ def _init_pooler_config( self, override_pooler_config: Optional["PoolerConfig"], ) -> Optional["PoolerConfig"]: + if self.runner_type == "pooling": user_config = override_pooler_config or PoolerConfig() @@ -757,10 +740,7 @@ def _resolve_task( logger.info( "This model supports multiple tasks: %s. " - "Defaulting to '%s'.", - supported_tasks, - selected_task, - ) + "Defaulting to '%s'.", supported_tasks, selected_task) else: # Aliases if task_option == "embedding": @@ -796,19 +776,9 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = QUANTIZATION_METHODS optimized_quantization_methods = [ - "fp8", - "marlin", - "modelopt", - "gptq_marlin_24", - "gptq_marlin", - "awq_marlin", - "fbgemm_fp8", - "compressed-tensors", - "experts_int8", - "quark", - "nvfp4", - "bitblas", - "gptq_bitblas", + "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", + "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", + "quark", "nvfp4", "bitblas", "gptq_bitblas" ] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -879,29 +849,24 @@ def _verify_quantization(self) -> None: f"Unknown quantization method: {self.quantization}. 
Must " f"be one of {supported_quantization}.") from vllm.platforms import current_platform - current_platform.verify_quantization(self.quantization) if self.quantization not in optimized_quantization_methods: logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " - "non-quantized models.", - self.quantization, - ) + "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: self.max_seq_len_to_capture = self.max_model_len self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) - ROCM_UNSUPPORTED_MODELS = ["mllama"] + ROCM_UNSUPPORTED_MODELS = ['mllama'] if (self.hf_config.model_type in ROCM_UNSUPPORTED_MODELS and not self.enforce_eager and current_platform.is_rocm()): logger.warning( "CUDA graph is not supported for %s on ROCm yet, fallback " - "to the eager mode.", - self.hf_config.model_type, - ) + "to the eager mode.", self.hf_config.model_type) self.enforce_eager = True def _verify_bnb_config(self) -> None: @@ -958,7 +923,6 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/features/compatibility_matrix.md # If the feature combo become valid from vllm.platforms import current_platform - if not current_platform.is_async_output_supported(self.enforce_eager): self.use_async_output_proc = False return @@ -981,6 +945,7 @@ def verify_with_parallel_config( self, parallel_config: "ParallelConfig", ) -> None: + if parallel_config.distributed_executor_backend == "external_launcher": assert self.seed is not None, ( "Seed must be set when using external launcher backend to " @@ -1009,7 +974,7 @@ def verify_with_parallel_config( self.use_async_output_proc = False def get_hf_config_sliding_window( - self, ) -> Union[Optional[int], list[Optional[int]]]: + self) -> Union[Optional[int], list[Optional[int]]]: """Get the sliding window size, or None if disabled.""" # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in @@ -1021,7 +986,8 @@ def get_hf_config_sliding_window( return getattr(self.hf_text_config, "sliding_window", None) def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: - """Get the sliding window size, or None if disabled.""" + """Get the sliding window size, or None if disabled. + """ # If user disables sliding window, return None. 
if self.disable_sliding_window: return None @@ -1038,18 +1004,15 @@ def get_hidden_size(self) -> int: def is_deepseek_mla(self) -> bool: if not hasattr(self.hf_text_config, "model_type"): return False - elif self.hf_text_config.model_type in ( - "deepseek_v2", - "deepseek_v3", - "deepseek_mtp", - ): + elif self.hf_text_config.model_type in \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'): return self.hf_text_config.kv_lora_rank is not None - elif self.hf_text_config.model_type == "eagle": + elif self.hf_text_config.model_type == 'eagle': # if the model is an EAGLE module, check for the # underlying architecture - return (self.hf_text_config.model.model_type - in ("deepseek_v2", "deepseek_v3") - and self.hf_text_config.kv_lora_rank is not None) + return self.hf_text_config.model.model_type in \ + ('deepseek_v2', 'deepseek_v3') \ + and self.hf_text_config.kv_lora_rank is not None return False def get_head_size(self) -> int: @@ -1101,17 +1064,14 @@ def get_total_num_kv_heads(self) -> int: return self.hf_config.attn_config["kv_n_heads"] return self.hf_config.num_attention_heads if self.hf_config.model_type == "dbrx": - return getattr( - self.hf_config.attn_config, - "kv_n_heads", - self.hf_config.num_attention_heads, - ) + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) if self.hf_config.model_type == "nemotron-nas": for block in self.hf_config.block_configs: if not block.attention.no_op: - return (self.hf_config.num_attention_heads // - block.attention.n_heads_in_group) + return self.hf_config.num_attention_heads \ + // block.attention.n_heads_in_group raise RuntimeError("Couldn't determine number of kv heads") @@ -1158,7 +1118,6 @@ def get_num_attention_heads(self, def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices - if self.hf_text_config.model_type == "deepseek_mtp": total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) @@ -1184,8 +1143,9 @@ def get_num_layers_by_block_type( # This function relies on 'layers_block_type' in hf_config, # for w/o this attribute, we will need to have workarounds like so attn_block_type = block_type == LayerBlockType.attention - is_transformer = (not self.is_hybrid and not self.has_noops - and not self.is_attention_free) + is_transformer = not self.is_hybrid and \ + not self.has_noops and \ + not self.is_attention_free start, end = self.get_layers_start_end_indices(parallel_config) if is_transformer: @@ -1343,8 +1303,8 @@ def is_v1_compatible(self) -> bool: @property def is_matryoshka(self) -> bool: - return hasattr(self.hf_config, "matryoshka_dimensions") or getattr( - self.hf_config, "is_matryoshka", False) + return (hasattr(self.hf_config, "matryoshka_dimensions") + or getattr(self.hf_config, "is_matryoshka", False)) @property def matryoshka_dimensions(self): @@ -1558,7 +1518,7 @@ class LoadConfig: """Configuration for loading the model weights.""" load_format: Union[str, LoadFormat, - "BaseModelLoader"] = (LoadFormat.AUTO.value) + "BaseModelLoader"] = LoadFormat.AUTO.value """The format of the model weights to load:\n - "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.\n @@ -1621,8 +1581,7 @@ def __post_init__(self): if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: logger.info( "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns, - ) + 
self.ignore_patterns) else: self.ignore_patterns = ["original/**/*"] @@ -1738,8 +1697,7 @@ def stateless_init_dp_group(self) -> "ProcessGroup": self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend="gloo", - ) + backend="gloo") return dp_group @@ -1772,8 +1730,8 @@ def compute_hash(self): return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: - self.world_size = (self.pipeline_parallel_size * - self.tensor_parallel_size) + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size if self.data_parallel_size > 1: # Data parallel was specified in the engine args. @@ -1791,13 +1749,11 @@ def __post_init__(self) -> None: if self.distributed_executor_backend == "external_launcher": import os - os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") ray_only_devices: list[str] = [] from vllm.platforms import current_platform - if (current_platform.device_type in ray_only_devices and self.world_size > 1): if self.distributed_executor_backend is None: @@ -1812,7 +1768,6 @@ def __post_init__(self) -> None: # current node and we aren't in a ray placement group. from vllm.executor import ray_utils - backend: DistributedExecutorBackend = "mp" ray_found = ray_utils.ray_is_available() if current_platform.is_neuron(): @@ -1831,10 +1786,8 @@ def __post_init__(self) -> None: backend = "ray" else: from ray import is_initialized as ray_is_initialized - if ray_is_initialized(): from ray.util import get_current_placement_group - if get_current_placement_group(): backend = "ray" self.distributed_executor_backend = backend @@ -1856,16 +1809,11 @@ def _verify_args(self) -> None: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase from vllm.platforms import current_platform - if self.distributed_executor_backend not in ( - "ray", - "mp", - "uni", - "external_launcher", - None, - ) and not (isinstance(self.distributed_executor_backend, type) - and issubclass(self.distributed_executor_backend, - ExecutorBase)): + "ray", "mp", "uni", + "external_launcher", None) and not (isinstance( + self.distributed_executor_backend, type) and issubclass( + self.distributed_executor_backend, ExecutorBase)): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " @@ -1873,7 +1821,6 @@ def _verify_args(self) -> None: " custom ExecutorBase subclass.") if self.use_ray: from vllm.executor import ray_utils - ray_utils.assert_ray_available() if not current_platform.use_custom_allreduce(): @@ -1882,8 +1829,8 @@ def _verify_args(self) -> None: "Disabled the custom all-reduce kernel because it is not " "supported on current platform.") if self.ray_workers_use_nsight and not self.use_ray: - raise ValueError( - "Unable to use nsight profiling unless workers run with Ray.") + raise ValueError("Unable to use nsight profiling unless workers " + "run with Ray.") assert isinstance(self.worker_extension_cls, str), ( "worker_extension_cls must be a string (qualified class name).") @@ -2033,17 +1980,13 @@ def __post_init__(self) -> None: self.max_model_len = 8192 logger.warning( "max_model_len was is not set. Defaulting to arbitrary value " - "of %d.", - self.max_model_len, - ) + "of %d.", self.max_model_len) if self.max_num_seqs is None: self.max_num_seqs = 128 logger.warning( "max_num_seqs was is not set. 
Defaulting to arbitrary value " - "of %d.", - self.max_num_seqs, - ) + "of %d.", self.max_num_seqs) if self.max_num_batched_tokens is None: if self.enable_chunked_prefill: @@ -2083,8 +2026,7 @@ def __post_init__(self) -> None: if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", - self.max_num_batched_tokens, - ) + self.max_num_batched_tokens) self.chunked_prefill_enabled = self.enable_chunked_prefill if self.max_num_partial_prefills > 1: @@ -2096,10 +2038,8 @@ def __post_init__(self) -> None: "Concurrent partial prefills enabled with " "max_num_partial_prefills=%d, max_long_partial_prefills=%d, " "long_prefill_token_threshold=%d", - self.max_num_partial_prefills, - self.max_long_partial_prefills, - self.long_prefill_token_threshold, - ) + self.max_num_partial_prefills, self.max_long_partial_prefills, + self.long_prefill_token_threshold) self._verify_args() @@ -2198,7 +2138,6 @@ def __post_init__(self): if self.device == "auto": # Automated device type detection from vllm.platforms import current_platform - self.device_type = current_platform.device_type if not self.device_type: raise RuntimeError( @@ -2365,6 +2304,7 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: return hf_config def __post_init__(self): + # Note: "method" is a new parameter that helps to extend the # configuration of non-model-based proposers, and the "model" parameter # will be used to set the draft model, eagle head, or additional weight @@ -2376,9 +2316,9 @@ def __post_init__(self): if self.model is None and self.num_speculative_tokens is not None: # TODO(Shangming): Refactor mtp configuration logic when supporting # mtp acceleration for more models besides deepseek_v3 - if (self.target_model_config - and self.target_model_config.hf_text_config.model_type - == "deepseek_v3"): + if self.target_model_config and \ + self.target_model_config.hf_text_config.model_type \ + == "deepseek_v3": # use the draft model from the same model: self.model = self.target_model_config.model elif self.method in ("ngram", "[ngram]"): @@ -2458,10 +2398,10 @@ def __post_init__(self): ) # Automatically detect the method - if self.method in ("eagle", "eagle3"): + if self.method in ('eagle', 'eagle3'): pass - elif ("eagle-" in self.draft_model_config.model.lower() - or "eagle3-" in self.draft_model_config.model.lower()): + elif "eagle-" in self.draft_model_config.model.lower() or \ + "eagle3-" in self.draft_model_config.model.lower(): self.method = "eagle" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" @@ -2480,22 +2420,20 @@ def __post_init__(self): from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) - if isinstance(self.draft_model_config.hf_config, EAGLEConfig): pass else: eagle_config = EAGLEConfig( self.draft_model_config.hf_config, - method=self.method, - ) + method=self.method) self.draft_model_config.hf_config = eagle_config - if self.num_speculative_tokens is not None and hasattr( - self.draft_model_config.hf_config, - "num_lookahead_tokens"): - self.draft_model_config.hf_config.num_lookahead_tokens = ( - self.num_speculative_tokens) + if (self.num_speculative_tokens is not None + and hasattr(self.draft_model_config.hf_config, + "num_lookahead_tokens")): + self.draft_model_config.hf_config.num_lookahead_tokens = \ + self.num_speculative_tokens n_predict = getattr(self.draft_model_config.hf_config, "n_predict", None) @@ -2503,19 +2441,19 @@ def __post_init__(self): if self.num_speculative_tokens is None: # 
Default to max value defined in draft model config. self.num_speculative_tokens = n_predict - elif (self.num_speculative_tokens > n_predict - and self.num_speculative_tokens % n_predict != 0): + elif self.num_speculative_tokens > n_predict and \ + self.num_speculative_tokens % n_predict != 0: # Ensure divisibility for MTP module reuse. raise ValueError( f"num_speculative_tokens:{self.num_speculative_tokens}" f" must be divisible by {n_predict=}") - self.draft_tensor_parallel_size = ( + self.draft_tensor_parallel_size = \ SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, self.draft_tensor_parallel_size, - self.draft_model_config.hf_config, - )) + self.draft_model_config.hf_config + ) self.draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( @@ -2527,8 +2465,7 @@ def __post_init__(self): self.draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( self.target_parallel_config, - self.draft_tensor_parallel_size, - )) + self.draft_tensor_parallel_size)) if self.acceptance_method == "typical_acceptance_sampler": if self.posterior_threshold is None: @@ -2557,6 +2494,7 @@ def _maybe_override_draft_max_model_len( """ if speculative_max_model_len is not None: + if speculative_max_model_len > draft_max_model_len: raise ValueError(f"{speculative_max_model_len=} cannot be " f"larger than {draft_max_model_len=}") @@ -2574,10 +2512,9 @@ def _maybe_override_draft_max_model_len( @staticmethod def _verify_and_get_draft_tp( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig, - ) -> int: + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig) -> int: """ Verifies and adjusts the tensor parallel size for a draft model specified using speculative_draft_tensor_parallel_size. @@ -2591,15 +2528,12 @@ def _verify_and_get_draft_tp( logger.warning( "%s cannot currently be run with tp>1; " "setting speculative_draft_tensor_parallel_size=1", - draft_hf_config.model_type, - ) + draft_hf_config.model_type) else: - speculative_draft_tensor_parallel_size = ( - target_parallel_config.tensor_parallel_size) + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size elif speculative_draft_tensor_parallel_size not in ( - 1, - target_parallel_config.tensor_parallel_size, - ): + 1, target_parallel_config.tensor_parallel_size): raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " f"other value than 1 or target model tensor_parallel_size") @@ -2652,8 +2586,8 @@ def _verify_args(self) -> None: "Expected values are rejection_sampler or " "typical_acceptance_sampler.") - if (self.acceptance_method != "rejection_sampler" - and self.acceptance_method != "typical_acceptance_sampler"): + if (self.acceptance_method != 'rejection_sampler' + and self.acceptance_method != 'typical_acceptance_sampler'): raise ValueError( "Expected acceptance_method to be either " "rejection_sampler or typical_acceptance_sampler. 
Instead it " @@ -2676,8 +2610,8 @@ def _verify_args(self) -> None: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - if (self.method == "eagle3" and self.target_model_config and "llama" - not in self.target_model_config.hf_text_config.model_type): + if self.method == "eagle3" and self.target_model_config and \ + "llama" not in self.target_model_config.hf_text_config.model_type: raise ValueError( "Eagle3 is only supported for Llama models. " f"Got {self.target_model_config.hf_text_config.model_type=}") @@ -2837,6 +2771,7 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): + if self.max_prompt_adapters < 1: raise ValueError(f"max_prompt_adapters " f"({self.max_prompt_adapters}) must be >= 1.") @@ -3021,7 +2956,6 @@ def _get_and_verify_dtype( torch_dtype = torch.bfloat16 from vllm.platforms import current_platform - if (current_platform.is_cpu() and current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC @@ -3114,8 +3048,8 @@ def _get_and_verify_max_len( for key in possible_keys: max_len = getattr(hf_config, key, None) if max_len is not None: - max_len_key = (key - if max_len < derived_max_model_len else max_len_key) + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key derived_max_model_len = min(derived_max_model_len, max_len) # For Command-R / Cohere, Cohere2 / Aya Vision models if tmp_max_len := getattr(hf_config, "model_max_length", None): @@ -3125,9 +3059,10 @@ def _get_and_verify_max_len( # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. if disable_sliding_window and sliding_window_len is not None: + sliding_window_len_min = get_min_sliding_window(sliding_window_len) - max_len_key = ("sliding_window" if sliding_window_len_min - < derived_max_model_len else max_len_key) + max_len_key = "sliding_window" \ + if sliding_window_len_min < derived_max_model_len else max_len_key derived_max_model_len = min(derived_max_model_len, sliding_window_len_min) @@ -3147,10 +3082,8 @@ def _get_and_verify_max_len( logger.warning( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", - possible_keys, - default_max_len, - ) + "%s. 
Assuming the model's maximum length is %d.", possible_keys, + default_max_len) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) @@ -3209,9 +3142,7 @@ def _get_and_verify_max_len( if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: logger.warning( "%s Make sure the value is correct and within the " - "model context size.", - msg, - ) + "model context size.", msg) else: raise ValueError( f"{msg} To allow overriding this maximum, set " @@ -3220,7 +3151,7 @@ def _get_and_verify_max_len( def get_min_sliding_window( - sliding_window: Union[int, list[Optional[int]]], ) -> int: + sliding_window: Union[int, list[Optional[int]]]) -> int: if isinstance(sliding_window, list): return min(s for s in sliding_window if s is not None) @@ -3320,13 +3251,11 @@ def __post_init__(self): if self.backend not in valid_guided_backends: raise ValueError(f"Invalid backend '{self.backend}'," f" must be one of {valid_guided_backends}") - if self.disable_any_whitespace and self.backend not in ( - "xgrammar", - "guidance", - ): + if (self.disable_any_whitespace + and self.backend not in ("xgrammar", "guidance")): raise ValueError("disable_any_whitespace is only supported for " "xgrammar and guidance backends.") - if self.disable_additional_properties and self.backend != "guidance": + if (self.disable_additional_properties and self.backend != "guidance"): raise ValueError("disable_additional_properties is only supported " "for the guidance backend.") @@ -3352,7 +3281,6 @@ def _extract_backend_options(self): @dataclass class ObservabilityConfig: """Configuration for observability - metrics and tracing.""" - show_hidden_metrics: bool = False otlp_traces_endpoint: Optional[str] = None @@ -3453,10 +3381,9 @@ def from_cli(cls, cli_value: str) -> "KVTransferConfig": return KVTransferConfig.model_validate_json(cli_value) def model_post_init(self, __context: Any) -> None: + if self.kv_role is not None and self.kv_role not in [ - "kv_producer", - "kv_consumer", - "kv_both", + "kv_producer", "kv_consumer", "kv_both" ]: raise ValueError( f"Unsupported kv_role: {self.kv_role}. " @@ -3470,25 +3397,18 @@ def model_post_init(self, __context: Any) -> None: @property def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and self.kv_role in [ - "kv_producer", - "kv_consumer", - "kv_both", - ] + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] @property def is_kv_producer(self) -> bool: - return self.kv_connector is not None and self.kv_role in [ - "kv_producer", - "kv_both", - ] + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_both"] @property def is_kv_consumer(self) -> bool: - return self.kv_connector is not None and self.kv_role in [ - "kv_consumer", - "kv_both", - ] + return self.kv_connector is not None and \ + self.kv_role in ["kv_consumer", "kv_both"] def get_from_extra_config(self, key, default) -> Any: return self.kv_connector_extra_config.get(key, default) @@ -3582,8 +3502,7 @@ class CompilationConfig(BaseModel): static shapes. However, we find the general shape compilation is sufficient for most cases. It might be beneficial to compile for certain small batchsizes, where inductor is good at optimizing. - """ # noqa - + """ # noqa level: int = 0 debug_dump_path: str = "" cache_dir: str = "" @@ -3615,7 +3534,6 @@ class PassConfig(BaseModel): TODO(luka) better pass enabling system. - enable_sequence_parallelism: whether to enable sequence parallelism. 
""" - dump_graph_stages: list[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) enable_fusion: bool = True @@ -3629,11 +3547,8 @@ def uuid(self): Do not include dump_graph_* in the hash - they don't affect compilation. """ - dict_ = self.model_dump(include={ - "enable_fusion", - "enable_noop", - "enable_sequence_parallelism", - }) + dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \ + "enable_sequence_parallelism"}) return InductorPass.hash_dict(dict_) def model_post_init(self, __context: Any) -> None: @@ -3711,6 +3626,7 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": return CompilationConfig.model_validate(dict_value) def model_post_init(self, __context: Any) -> None: + count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" @@ -3724,7 +3640,7 @@ def model_post_init(self, __context: Any) -> None: # https://github.com/vllm-project/vllm/issues/14703 if is_torch_equal_or_newer("2.6"): - KEY = "enable_auto_functionalized_v2" + KEY = 'enable_auto_functionalized_v2' if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False @@ -3735,8 +3651,8 @@ def model_post_init(self, __context: Any) -> None: if not isinstance(v, str): assert callable(v), ( f"pass {k} should be callable or a qualified name") - self.inductor_compile_config[k] = (v if isinstance( - v, InductorPass) else CallableInductorPass(v)) + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) continue # resolve function from qualified name @@ -3744,8 +3660,8 @@ def model_post_init(self, __context: Any) -> None: module = ".".join(names[:-1]) func_name = names[-1] func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = (func if isinstance( - func, InductorPass) else CallableInductorPass(func)) + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) self.enabled_custom_ops = Counter() self.disabled_custom_ops = Counter() @@ -3758,11 +3674,9 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: raise ValueError("No compilation level is set.") from torch._dynamo.backends.registry import list_backends - torch_backends = list_backends(exclude_tags=tuple()) if self.level in [ - CompilationLevel.DYNAMO_AS_IS, - CompilationLevel.DYNAMO_ONCE, + CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE ]: if self.backend == "": return "eager" @@ -3775,7 +3689,6 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: assert self.level == CompilationLevel.PIECEWISE from vllm.compilation.backends import VllmBackend - return VllmBackend(vllm_config) def init_with_cudagraph_sizes(self, @@ -3789,12 +3702,9 @@ def init_with_cudagraph_sizes(self, # de-duplicate the sizes provided by the config self.cudagraph_capture_sizes = list( set(self.cudagraph_capture_sizes)) - logger.info( - ("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - cudagraph_capture_sizes, - self.cudagraph_capture_sizes, - ) + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + cudagraph_capture_sizes, self.cudagraph_capture_sizes) computed_compile_sizes = [] if self.compile_sizes is not None: @@ -3802,9 +3712,9 @@ def init_with_cudagraph_sizes(self, self.compile_sizes = list(set(self.compile_sizes)) for x in self.compile_sizes: if isinstance(x, str): - assert 
x == "cudagraph_capture_sizes", ( - "Unrecognized size type in compile_sizes, " - f"expect 'cudagraph_capture_sizes', got {x}") + assert x == "cudagraph_capture_sizes", \ + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" computed_compile_sizes.extend(self.cudagraph_capture_sizes) else: assert isinstance(x, int) @@ -3813,8 +3723,8 @@ def init_with_cudagraph_sizes(self, # sort to make sure cudagraph capture sizes are in descending order self.cudagraph_capture_sizes.sort(reverse=True) - self.max_capture_size = (self.cudagraph_capture_sizes[0] - if self.cudagraph_capture_sizes else 0) + self.max_capture_size = self.cudagraph_capture_sizes[ + 0] if self.cudagraph_capture_sizes else 0 # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = [ @@ -3827,8 +3737,8 @@ def init_with_cudagraph_sizes(self, self.bs_to_padded_graph_size[bs] = start else: self.bs_to_padded_graph_size[bs] = end - self.bs_to_padded_graph_size[self.max_capture_size] = ( - self.max_capture_size) + self.bs_to_padded_graph_size[ + self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): # If default, override splitting ops for piecewise cudagraph on V1. @@ -3858,8 +3768,7 @@ class VllmConfig: lora_config: Optional[LoRAConfig] = None speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore - decoding_config: DecodingConfig = field(default_factory=DecodingConfig, - init=True) + decoding_config: Optional[DecodingConfig] = None observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None @@ -3891,7 +3800,6 @@ def compute_hash(self) -> str: # summarize vllm config vllm_factors: list[Any] = [] from vllm import __version__ - vllm_factors.append(__version__) vllm_factors.append(envs.VLLM_USE_V1) if self.model_config: @@ -3976,11 +3884,9 @@ def _get_quantization_config( load_config: LoadConfig) -> Optional[QuantizationConfig]: """Get the quantization config.""" from vllm.platforms import current_platform - if model_config.quantization is not None: from vllm.model_executor.model_loader.weight_utils import ( get_quant_config) - quant_config = get_quant_config(model_config, load_config) capability_tuple = current_platform.get_device_capability() @@ -4027,13 +3933,12 @@ def with_hf_config( return replace(self, model_config=model_config) def __post_init__(self): - """Verify configs are valid & consistent with each other.""" + """Verify configs are valid & consistent with each other. 
+ """ if self.model_config is not None: - self.model_config.verify_async_output_proc( - self.parallel_config, - self.speculative_config, - self.device_config, - ) + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) if self.cache_config is not None: @@ -4047,17 +3952,17 @@ def __post_init__(self): self.prompt_adapter_config.verify_with_model_config( self.model_config) - if (self.quant_config is None and self.model_config is not None - and self.load_config is not None): + if self.quant_config is None and \ + self.model_config is not None and self.load_config is not None: self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) from vllm.platforms import current_platform - - if (self.scheduler_config is not None and self.model_config is not None - and self.scheduler_config.chunked_prefill_enabled - and self.model_config.dtype == torch.float32 - and current_platform.get_device_capability() == (7, 5)): + if self.scheduler_config is not None and \ + self.model_config is not None and \ + self.scheduler_config.chunked_prefill_enabled and \ + self.model_config.dtype == torch.float32 and \ + current_platform.get_device_capability() == (7, 5): logger.warning_once( "Turing devices tensor cores do not support float32 matmul. " "To workaround this limitation, vLLM will set 'ieee' input " @@ -4067,8 +3972,8 @@ def __post_init__(self): self.compilation_config = CompilationConfig() if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") - if (envs.VLLM_USE_V1 and self.model_config is not None - and not self.model_config.enforce_eager): + if envs.VLLM_USE_V1 and self.model_config is not None and \ + not self.model_config.enforce_eager: # NOTE(woosuk): Currently, we use inductor because the piecewise # CUDA graphs do not work properly with the custom CUDA kernels. # FIXME(woosuk): Disable inductor to reduce the compilation time @@ -4084,25 +3989,24 @@ def __post_init__(self): self.compilation_config.level = CompilationLevel.PIECEWISE self.compilation_config.set_splitting_ops_for_v1() - if (self.parallel_config is not None - and self.parallel_config.tensor_parallel_size > 1 - and self.parallel_config.pipeline_parallel_size > 1 - and self.compilation_config is not None - and self.compilation_config.pass_config is not None and - self.compilation_config.pass_config.enable_sequence_parallelism - ): + if self.parallel_config is not None and \ + self.parallel_config.tensor_parallel_size > 1 and \ + self.parallel_config.pipeline_parallel_size > 1 and \ + self.compilation_config is not None and \ + self.compilation_config.pass_config is not None and \ + self.compilation_config.pass_config.enable_sequence_parallelism: logger.warning_once( "Sequence parallelism is not supported with pipeline " "parallelism. 
Disabling sequence parallelism.") - self.compilation_config.pass_config.enable_sequence_parallelism = ( - False) + self.compilation_config.pass_config.\ + enable_sequence_parallelism = False self._set_cudagraph_sizes() - if (self.cache_config is not None - and self.cache_config.cpu_offload_gb > 0 - and self.compilation_config.level - != CompilationLevel.NO_COMPILATION and not envs.VLLM_USE_V1): + if self.cache_config is not None and \ + self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION \ + and not envs.VLLM_USE_V1: logger.warning( "CPU offload is not supported with `torch.compile` in v0 yet." " Disabling `torch.compile`.") @@ -4116,9 +4020,9 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if (self.model_config and self.model_config.use_mla - and not (current_platform.is_cuda() - or current_platform.is_rocm())): + + if self.model_config and self.model_config.use_mla and \ + not (current_platform.is_cuda() or current_platform.is_rocm()): logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " "prefill and prefix caching to be disabled.") @@ -4126,8 +4030,7 @@ def __post_init__(self): self.scheduler_config.chunked_prefill_enabled = False self.scheduler_config.max_num_batched_tokens = max( self.scheduler_config.max_model_len, - _DEFAULT_MAX_NUM_BATCHED_TOKENS, - ) + _DEFAULT_MAX_NUM_BATCHED_TOKENS) if self.cache_config is not None: self.cache_config.enable_prefix_caching = False @@ -4149,10 +4052,8 @@ def update_sizes_for_sequence_parallelism(self, logger.warning( "Batch sizes %s are removed because they are not " "multiple of tp_size %d when " - "sequence parallelism is enabled", - removed_sizes, - self.parallel_config.tensor_parallel_size, - ) + "sequence parallelism is enabled", removed_sizes, + self.parallel_config.tensor_parallel_size) return [ size for size in possible_sizes @@ -4191,13 +4092,13 @@ def _set_cudagraph_sizes(self): if not envs.VLLM_USE_V1: batch_size_capture_list = [] max_batchsize_to_capture = 0 - if (self.scheduler_config is not None - and self.model_config is not None - and not self.model_config.enforce_eager): + if self.scheduler_config is not None and \ + self.model_config is not None and \ + not self.model_config.enforce_eager: + possible_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)] - if (self.parallel_config.tensor_parallel_size > 1 - and self.compilation_config.pass_config. - enable_sequence_parallelism): + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: possible_sizes = self.update_sizes_for_sequence_parallelism( possible_sizes) @@ -4220,16 +4121,14 @@ def _set_cudagraph_sizes(self): ] else: batch_size_capture_list = [] - if (self.model_config is not None - and not self.model_config.enforce_eager): + if self.model_config is not None and \ + not self.model_config.enforce_eager: batch_size_capture_list = [1, 2, 4 ] + [i for i in range(8, 513, 8)] - if (self.parallel_config.tensor_parallel_size > 1 - and self.compilation_config.pass_config. 
- enable_sequence_parallelism): - batch_size_capture_list = ( - self.update_sizes_for_sequence_parallelism( - batch_size_capture_list)) + if self.parallel_config.tensor_parallel_size > 1 and \ + self.compilation_config.pass_config.enable_sequence_parallelism: + batch_size_capture_list = \ + self.update_sizes_for_sequence_parallelism(batch_size_capture_list) max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ @@ -4290,7 +4189,6 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): global _current_vllm_config old_vllm_config = _current_vllm_config from vllm.compilation.counter import compilation_counter - num_models_seen = compilation_counter.num_models_seen try: _current_vllm_config = vllm_config @@ -4298,17 +4196,13 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): except Exception: raise else: - logger.debug( - "enabled custom ops: %s", - vllm_config.compilation_config.enabled_custom_ops, - ) - logger.debug( - "disabled custom ops: %s", - vllm_config.compilation_config.disabled_custom_ops, - ) - if (check_compile and vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and compilation_counter.num_models_seen == num_models_seen): + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) + if check_compile and \ + vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: # If the model supports compilation, # compilation_counter.num_models_seen should be increased # by at least 1. @@ -4318,8 +4212,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, check_compile=False): "`torch.compile` is turned on, but the model %s" " does not support it. Please open an issue on GitHub" " if you want it to be supported.", - vllm_config.model_config.model, - ) + vllm_config.model_config.model) finally: _current_vllm_config = old_vllm_config @@ -4331,7 +4224,6 @@ def get_current_vllm_config() -> VllmConfig: # config. 
logger.warning("Current vLLM config is not set.") from vllm.config import VllmConfig - return VllmConfig() return _current_vllm_config @@ -4349,7 +4241,7 @@ def contains_object_print(text): Returns: bool: True if a match is found, False otherwise """ - pattern = r"at 0x[a-fA-F0-9]{2,16}>" + pattern = r'at 0x[a-fA-F0-9]{2,16}>' match = re.search(pattern, text) return match is not None diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 303d04a566f..3680469f821 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -183,14 +183,10 @@ def check_object(obj: dict[str, Any]) -> bool: return True # Check for array unsupported keywords - if obj.get("type") == "array" and any(key in obj for key in ( - "uniqueItems", - "contains", - "minContains", - "maxContains", - "minItems", - "maxItems", - )): + if obj.get("type") == "array" and any( + key in obj + for key in ("uniqueItems", "contains", "minContains", + "maxContains", "minItems", "maxItems")): return True # Unsupported keywords for strings @@ -198,12 +194,9 @@ def check_object(obj: dict[str, Any]) -> bool: return True # Unsupported keywords for objects - if obj.get("type") == "object" and any(key in obj for key in ( - "minProperties", - "maxProperties", - "propertyNames", - "patternProperties", - )): + if obj.get("type") == "object" and any( + key in obj for key in ("minProperties", "maxProperties", + "propertyNames", "patternProperties")): return True # Recursively check all nested objects and arrays @@ -235,16 +228,16 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None: try: xgr.Grammar.from_regex(gd_params.regex) except Exception as err: - raise ValueError( - f"Failed to transform regex into a grammar: {err}") from err + raise ValueError("Failed to transform regex into a grammar: " + f"{err}") from err if gd_params.choice: choice_grammar = choice_as_grammar(gd_params.choice) try: xgr.Grammar.from_ebnf(choice_grammar) except Exception as err: - raise ValueError( - "Failed to transform choices into a grammar: {err}") from err + raise ValueError("Failed to transform choices into a grammar: " + "{err}") from err gd_params.choice = None gd_params.grammar = choice_grammar return From 6d26942081173dce1b1da7ead0ef187bf5536c72 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 29 Apr 2025 17:32:22 -0400 Subject: [PATCH 16/40] fix: make sure to initialize DecodingConfig by default, and fix types Signed-off-by: Aaron Pham --- vllm/config.py | 3 ++- vllm/v1/structured_output/backend_guidance.py | 5 +++-- vllm/v1/structured_output/backend_xgrammar.py | 12 ++++-------- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index abe59734e2d..12e358fe521 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3768,7 +3768,8 @@ class VllmConfig: lora_config: Optional[LoRAConfig] = None speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore - decoding_config: Optional[DecodingConfig] = None + decoding_config: DecodingConfig = field(default_factory=DecodingConfig, + init=True) observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py index 2003b4112f9..0bf89641b0c 100644 --- a/vllm/v1/structured_output/backend_guidance.py +++ 
b/vllm/v1/structured_output/backend_guidance.py @@ -136,10 +136,11 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: return r - def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: # this will automatically return [EOS] mask if the matcher is stopped # or otherwise in an error state - llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx) + llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, + batch_index) self.check_error() def is_terminated(self) -> bool: diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 3680469f821..e6373ac891a 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -88,8 +88,7 @@ def compile_grammar(self, request_type: StructuredOutputOptions, elif request_type == StructuredOutputOptions.JSON_OBJECT: ctx = self.compiler.compile_json_schema( '{"type": "object"}', - any_whitespace=not self.disable_any_whitespace, - ) + any_whitespace=not self.disable_any_whitespace) elif request_type == StructuredOutputOptions.GRAMMAR: ctx = self.compiler.compile_grammar(grammar_spec) elif request_type == StructuredOutputOptions.REGEX: @@ -152,16 +151,13 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: if not self.matcher.accept_token(token): logger.error( "Failed to advance FSM for request %s " - "for tokens %s. Please file an issue.", - request_id, - token, - ) + "for tokens %s. Please file an issue.", request_id, token) return False self.num_processed_tokens += 1 return True - def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: - self.matcher.fill_next_token_bitmask(bitmask, idx) + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + self.matcher.fill_next_token_bitmask(bitmask, batch_index) def is_terminated(self) -> bool: return self.matcher.is_terminated() From fcfef12d9e41470a29071294cd9f34f343838615 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 30 Apr 2025 12:20:44 +0000 Subject: [PATCH 17/40] --wip-- Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index d908fa320d3..b8b3db4d685 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -29,22 +30,20 @@ "num_speculative_tokens": 5, } +# yapf: disable PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER_SPEC_CONFIG = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None, None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, None), - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None, - None), + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None, None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, None), - ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", - "deepseek_r1", None), - ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, - NGRAM_SPEC_CONFIG), + ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", None), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", None, 
NGRAM_SPEC_CONFIG), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, NGRAM_SPEC_CONFIG), #FIXME: This test is flaky on CI thus disabled - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, NGRAM_SPEC_CONFIG - ), - ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", None, - EAGLE_SPEC_CONFIG) + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, NGRAM_SPEC_CONFIG), + ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", None, EAGLE_SPEC_CONFIG) ] +# yapf: enable PARAMS_MODELS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "auto"), From 1e828bd2cd499fabcd7f492ce27425b6c592f72d Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 30 Apr 2025 15:42:13 +0000 Subject: [PATCH 18/40] chore: move logic to manager Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 38 +++++---------- vllm/v1/structured_output/__init__.py | 66 +++++++++++++++++---------- 2 files changed, 53 insertions(+), 51 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d1c1ae98a68..d73a46f958e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -706,39 +706,25 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output: - reasoner = self.structured_output_manager.reasoner - advance_fsm = reasoner is None - - # NOTE: use_structured_output implies - # structured_output_request is not None, - # but type checker isn't smart enough to know this. - # This only affect type runtime, not actual runtime. - # assert is also not recommended on perf-sensitive runtime path. + if new_token_ids and request.use_structured_output and self.structured_output_manager.should_advance( # noqa: E501 + request): if TYPE_CHECKING: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None - - if reasoner is not None: - if request.structured_output_request.reasoning_ended: # noqa: E501 - advance_fsm = True - elif reasoner.is_reasoning_end(request.all_token_ids): - request.structured_output_request.reasoning_ended = True - advance_fsm = False - else: - advance_fsm = False - - if advance_fsm: - request.structured_output_request.grammar.accept_tokens( - req_id, new_token_ids) + request.structured_output_request.grammar.accept_tokens( + req_id, + new_token_ids, + ) # Add newly generated spec token ids to the request. if spec_token_ids is not None: - if request.use_structured_output: - metadata = request.structured_output_request - assert metadata is not None and metadata.grammar is not None + if request.use_structured_output and self.structured_output_manager.should_advance( # noqa: E501 + request): + if TYPE_CHECKING: + assert request.structured_output_request is not None + assert request.structured_output_request.grammar is not None # noqa: E501 # Needs to happen after new_token_ids are accepted. 
- request.spec_token_ids = metadata.grammar.validate_tokens( + request.spec_token_ids = request.structured_output_request.grammar.validate_tokens( # noqa: E501 spec_token_ids[req_index]) else: request.spec_token_ids = spec_token_ids[req_index] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 5765f33ea8f..e4b93df7b7c 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -46,6 +46,11 @@ def __init__(self, vllm_config: VllmConfig): scheduler_config=self.vllm_config.scheduler_config, lora_config=self.vllm_config.lora_config, ).get_lora_tokenizer(None) + # yapf: disable + reasoning_backend = vllm_config.decoding_config.reasoning_backend + if reasoning_backend is not None: + self.reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_backend)(tokenizer=self.tokenizer) # noqa: E501 + # yapf: enable def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: @@ -77,12 +82,6 @@ def grammar_init(self, request: Request) -> None: raise ValueError( f"Unsupported structured output backend: {backend}") - if ((reasoning_backend := - self.vllm_config.decoding_config.reasoning_backend) - is not None and self.reasoner is None): - self.reasoner = ReasoningParserManager.get_reasoning_parser( - reasoning_backend)(tokenizer=self.tokenizer) - grammar = self.executor.submit(self._async_create_grammar, request) request.structured_output_request.grammar = grammar # type: ignore[assignment] @@ -115,11 +114,10 @@ def grammar_bitmask( if self._grammar_bitmask is None: assert self.backend is not None max_batch_size = self.vllm_config.scheduler_config.max_num_seqs + max_num_spec_tokens = 0 if self.vllm_config.speculative_config is not None: - max_num_spec_tokens = self.vllm_config.\ - speculative_config.num_speculative_tokens - else: - max_num_spec_tokens = 0 + max_num_spec_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens # Allocate a bitmask for each token needing to be checked: # one for each speculative position, and one more for the @@ -128,15 +126,7 @@ def grammar_bitmask( self.backend.allocate_token_bitmask( max_batch_size * (1 + max_num_spec_tokens)) - # Fill the bitmask using the index of each request equal to its - # position in the batch. Resize the bitmask down to the size of - # the batch. bitmask_tensor = self._grammar_bitmask - batch_len = len(scheduled_spec_decode_tokens) - # Reset the relevant part of the bitmask before filling - if batch_len > 0: - bitmask_tensor[:batch_len].fill_(-1) - # Generate a batched bitmask for all structured output requests. # When speculative decoding is enabled, we need to include multiple # masks for each request, one for each possible bonus token position. @@ -144,13 +134,17 @@ def grammar_bitmask( cumulative_index = 0 ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) + + # Reset the relevant part of the bitmask before filling + if self.reasoner is not None: + bitmask_tensor[:len(ordered_seq)].fill_(-1) # NOTE: This outer loop can likely be parallelized to improve # performance of bitmask generation for large batches. 
for req_id, _ in ordered_seq: - full_request = requests[req_id] - request = full_request.structured_output_request + request = requests[req_id].structured_output_request if TYPE_CHECKING: - assert request is not None and request.grammar is not None + assert request is not None + assert request.grammar is not None apply_bitmask = ( request.reasoning_ended if self.reasoner is not None else True @@ -160,7 +154,7 @@ def grammar_bitmask( req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] for token in req_tokens: if apply_bitmask and not request.grammar.is_terminated(): - request.grammar.fill_bitmask(self._grammar_bitmask, + request.grammar.fill_bitmask(bitmask_tensor, cumulative_index) if token is not None: # In order to generate the correct bitmask for each @@ -173,15 +167,37 @@ def grammar_bitmask( if state_advancements > 0: request.grammar.rollback(state_advancements) - bitmask_tensor = self._grammar_bitmask - if cumulative_index < self._grammar_bitmask.shape[0]: - bitmask_tensor = self._grammar_bitmask[:cumulative_index] + if cumulative_index < bitmask_tensor.shape[0]: + bitmask_tensor = bitmask_tensor[:cumulative_index] # After finishing with the xgrammar operations, we convert to # np.ndarray, because that is much more efficient for serialization # and deserialization when sending this to the GPU workers. return bitmask_tensor.numpy() + def should_advance(self, request: Request) -> bool: + """To determine whether we can advance the FSM. + Supports thinking usage where we skip the reasoning components. + """ + if TYPE_CHECKING: + assert request.structured_output_request is not None + assert request.structured_output_request.grammar is not None + # by default, we should always advance + should_advance = self.reasoner is None + + # if there is a reasoning parser, then will + # we will need to determine when to start advancing + # the state machine. + if self.reasoner is not None: + if request.structured_output_request.reasoning_ended: + should_advance = True + elif self.reasoner.is_reasoning_end(request.all_token_ids): + request.structured_output_request.reasoning_ended = True + should_advance = False + else: + should_advance = False + return should_advance + def clear_backend(self) -> None: if self.backend is not None: self.backend.destroy() From ce1fddc8f021eb183277f30fe53505ebd34e2b79 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Wed, 30 Apr 2025 21:17:57 +0000 Subject: [PATCH 19/40] chore: update notes Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index e4b93df7b7c..2ef6fe5d0bb 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -176,13 +176,13 @@ def grammar_bitmask( return bitmask_tensor.numpy() def should_advance(self, request: Request) -> bool: - """To determine whether we can advance the FSM. - Supports thinking usage where we skip the reasoning components. - """ + # To determine whether we can advance the FSM. + # Supports thinking usage where we skip the reasoning components. if TYPE_CHECKING: assert request.structured_output_request is not None assert request.structured_output_request.grammar is not None # by default, we should always advance + # for cases that doesn't uses thinking mode. 
should_advance = self.reasoner is None # if there is a reasoning parser, then will From c211110566dbfd4c530fd884b748b928cdbb77f6 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 1 May 2025 02:22:01 +0000 Subject: [PATCH 20/40] fix: make sure works with both thinking, spec and struct matrixes Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 45 ++++++++++++++++----------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 2ef6fe5d0bb..173d4034d44 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm.reasoning import ReasoningParserManager from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar) @@ -21,6 +22,8 @@ from vllm.reasoning import ReasoningParser from vllm.v1.request import Request +else: + torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) @@ -34,6 +37,7 @@ def __init__(self, vllm_config: VllmConfig): self.vllm_config = vllm_config self._grammar_bitmask: Optional[torch.Tensor] = None + self._full_mask = torch.tensor(-1, dtype=torch.int32) # The default max_workers if not specified is the number of CPUs * 5, # which is way too high since these tasks are CPU-bound, not I/O bound. @@ -111,13 +115,14 @@ def grammar_bitmask( if not structured_output_request_ids: return None + max_num_spec_tokens = 0 + if self.vllm_config.speculative_config is not None: + max_num_spec_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens + if self._grammar_bitmask is None: assert self.backend is not None max_batch_size = self.vllm_config.scheduler_config.max_num_seqs - max_num_spec_tokens = 0 - if self.vllm_config.speculative_config is not None: - max_num_spec_tokens = \ - self.vllm_config.speculative_config.num_speculative_tokens # Allocate a bitmask for each token needing to be checked: # one for each speculative position, and one more for the @@ -136,8 +141,9 @@ def grammar_bitmask( key=lambda x: x[1]) # Reset the relevant part of the bitmask before filling - if self.reasoner is not None: - bitmask_tensor[:len(ordered_seq)].fill_(-1) + bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_( + self._full_mask) + # NOTE: This outer loop can likely be parallelized to improve # performance of bitmask generation for large batches. for req_id, _ in ordered_seq: @@ -183,20 +189,21 @@ def should_advance(self, request: Request) -> bool: assert request.structured_output_request.grammar is not None # by default, we should always advance # for cases that doesn't uses thinking mode. - should_advance = self.reasoner is None - - # if there is a reasoning parser, then will - # we will need to determine when to start advancing - # the state machine. 
if self.reasoner is not None: - if request.structured_output_request.reasoning_ended: - should_advance = True - elif self.reasoner.is_reasoning_end(request.all_token_ids): - request.structured_output_request.reasoning_ended = True - should_advance = False - else: - should_advance = False - return should_advance + structured_req = request.structured_output_request + + if structured_req.reasoning_ended: + return True + + # Check if reasoning ends in *this* step + if self.reasoner.is_reasoning_end(request.all_token_ids): + # Reasoning just ended, so we shouldn't advanced til + # next pass + structured_req.reasoning_ended = True + + return False + else: + return True def clear_backend(self) -> None: if self.backend is not None: From b89662a471d05134d7dbcddb2d30395b75b2fd7a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Thu, 1 May 2025 02:47:40 +0000 Subject: [PATCH 21/40] chore: cleanup logics Signed-off-by: Aaron Pham --- vllm/v1/core/sched/scheduler.py | 10 +++++----- vllm/v1/structured_output/__init__.py | 3 +++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7c6e522b055..f67410a274b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -721,7 +722,7 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output and self.structured_output_manager.should_advance( # noqa: E501 + if new_token_ids and self.structured_output_manager.should_advance( request): if TYPE_CHECKING: assert request.structured_output_request is not None @@ -733,13 +734,12 @@ def update_from_output( # Add newly generated spec token ids to the request. if spec_token_ids is not None: - if request.use_structured_output and self.structured_output_manager.should_advance( # noqa: E501 - request): + if self.structured_output_manager.should_advance(request): if TYPE_CHECKING: assert request.structured_output_request is not None - assert request.structured_output_request.grammar is not None # noqa: E501 + assert request.structured_output_request.grammar is not None # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = request.structured_output_request.grammar.validate_tokens( # noqa: E501 + request.spec_token_ids = request.structured_output_request.grammar.validate_tokens( spec_token_ids[req_index]) else: request.spec_token_ids = spec_token_ids[req_index] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 173d4034d44..a0d83d1deb7 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -182,6 +182,9 @@ def grammar_bitmask( return bitmask_tensor.numpy() def should_advance(self, request: Request) -> bool: + if not request.use_structured_output: + return False + # To determine whether we can advance the FSM. # Supports thinking usage where we skip the reasoning components. 
if TYPE_CHECKING: From 591da8e114ea58b02567195284aeceda494b7660 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Mon, 5 May 2025 18:33:21 +0000 Subject: [PATCH 22/40] fix: update to newer logics Signed-off-by: Aaron Pham --- vllm/engine/arg_utils.py | 1 - vllm/v1/structured_output/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c6a9e543e61..08dbb4c4503 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -388,7 +388,6 @@ class EngineArgs: pt_load_map_location: str = LoadConfig.pt_load_map_location def __post_init__(self): - self.enable_reasoning = self.reasoning_parser is not None # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index a0d83d1deb7..cd7b5ef4be1 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -52,7 +52,7 @@ def __init__(self, vllm_config: VllmConfig): ).get_lora_tokenizer(None) # yapf: disable reasoning_backend = vllm_config.decoding_config.reasoning_backend - if reasoning_backend is not None: + if reasoning_backend: self.reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_backend)(tokenizer=self.tokenizer) # noqa: E501 # yapf: enable From a807beede55a8fb619e0d6c81c87ed82fc379cd3 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 9 May 2025 17:40:02 +0000 Subject: [PATCH 23/40] chore: revert whitespace changes Signed-off-by: Aaron Pham --- docs/source/features/reasoning_outputs.md | 46 ++++++++++--------- .../llm/test_struct_output_generate.py | 2 - 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index b6bff178cad..305e196b9cc 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -10,12 +10,12 @@ Reasoning models return an additional `reasoning_content` field in their outputs vLLM currently supports the following reasoning models: -| Model Series | Parser Name | Structured Output Support | Tool Calling | -| ------------------------------------------------------------------------------------------------------------------------------------- | ------------- | ----------------------------- | ------------ | -| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | -| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | -| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | -| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | +| Model Series | Parser Name | Structured Output Support | Tool Calling | +|--------------|-------------|------------------|-------------| +| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | +| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | +| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | 
`granite` | ❌ | ❌ | +| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | - IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. @@ -64,22 +64,22 @@ Streaming chat completions are also supported for reasoning models. The `reasoni ```json { - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1694268190, - "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", - "system_fingerprint": "fp_44709d6fcb", - "choices": [ - { - "index": 0, - "delta": { - "role": "assistant", - "reasoning_content": "is" - }, - "logprobs": null, - "finish_reason": null - } - ] + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1694268190, + "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "reasoning_content": "is", + }, + "logprobs": null, + "finish_reason": null + } + ] } ``` @@ -142,6 +142,8 @@ The reasoning content is also available in the structured output. The structured vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 ``` +The following is an example client: + ```python from openai import OpenAI from pydantic import BaseModel diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 098793a74d7..9c2d8c4a760 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -30,7 +30,6 @@ "num_speculative_tokens": 5, } -# yapf: disable PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER_SPEC_CONFIG = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None, None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, None), @@ -43,7 +42,6 @@ ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, NGRAM_SPEC_CONFIG), ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", None, EAGLE_SPEC_CONFIG) ] -# yapf: enable PARAMS_MODELS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "auto"), From 6c2b9df814ab8f006f3d452d257c797fcbdea84a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 9 May 2025 17:44:15 +0000 Subject: [PATCH 24/40] fix(tests): ignore runaway properties Co-authored-by: Russell Bryant Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 9c2d8c4a760..fd13a422861 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -33,14 +33,19 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER_SPEC_CONFIG = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None, None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, None), - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None, None), + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None, + None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, None), - ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", None), + ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", + "deepseek_r1", None), ("Qwen/Qwen3-0.6B", 
"xgrammar", "auto", "qwen3", None, NGRAM_SPEC_CONFIG), - ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, NGRAM_SPEC_CONFIG), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, + NGRAM_SPEC_CONFIG), #FIXME: This test is flaky on CI thus disabled - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, NGRAM_SPEC_CONFIG), - ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", None, EAGLE_SPEC_CONFIG) + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, NGRAM_SPEC_CONFIG + ), + ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", None, + EAGLE_SPEC_CONFIG) ] PARAMS_MODELS_TOKENIZER_MODE = [ @@ -468,7 +473,8 @@ def test_structured_output( You are a helpful assistant. -Given the previous instructions, what is the weather in New York City? +Given the previous instructions, what is the weather in New York City? \ +Make the response as short as possible. """ # Change this once other backends support structural_tag From fb92d8a9380ab0cb4195854b9a317e762a676663 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 9 May 2025 17:47:14 +0000 Subject: [PATCH 25/40] fix: broken tests Signed-off-by: Aaron Pham --- .../entrypoints/llm/test_struct_output_generate.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index fd13a422861..37bbad823d5 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -34,16 +34,14 @@ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None, None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None, - None), - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, None), + None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, None), ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", None), - ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", None, NGRAM_SPEC_CONFIG), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", NGRAM_SPEC_CONFIG), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, NGRAM_SPEC_CONFIG), - #FIXME: This test is flaky on CI thus disabled - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, NGRAM_SPEC_CONFIG - ), + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, + NGRAM_SPEC_CONFIG), ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", None, EAGLE_SPEC_CONFIG) ] @@ -521,7 +519,8 @@ def test_structured_output( "type": "integer" } }, - "required": ["result"] + "required": ["result"], + "additionalProperties": False } sampling_params = SamplingParams( From 174e7e8cd402b9c6cbe7fcd580cf3d34e92cd34c Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 9 May 2025 14:30:03 -0400 Subject: [PATCH 26/40] Update tests/v1/entrypoints/llm/test_struct_output_generate.py Co-authored-by: Russell Bryant Signed-off-by: Aaron Pham --- tests/v1/entrypoints/llm/test_struct_output_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 37bbad823d5..741d72dce03 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -525,7 +525,7 @@ def test_structured_output( sampling_params = SamplingParams( temperature=0.1, 
- max_tokens=200, + max_tokens=4096, guided_decoding=GuidedDecodingParams(json=reasoning_schema)) outputs = llm.generate(prompts=[reasoning_prompt], sampling_params=sampling_params, From 42671cf4ee7752437b811e4c4d22b602d3831687 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 9 May 2025 18:39:01 +0000 Subject: [PATCH 27/40] revert: update noqa changes Signed-off-by: Aaron Pham --- vllm/config.py | 2 +- vllm/v1/core/sched/scheduler.py | 21 ++++++++------------- vllm/v1/structured_output/__init__.py | 5 ++--- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d4bf39891fd..fca2865f85d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3907,7 +3907,7 @@ class VllmConfig: lora_config: Optional[LoRAConfig] = None speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore - decoding_config: DecodingConfig = field(default_factory=DecodingConfig) + decoding_config: Optional[DecodingConfig] = None observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 551fd04dc0f..fcd6db16e67 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1,4 +1,3 @@ -# ruff: noqa: E501 # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -6,7 +5,7 @@ import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, Union from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch @@ -722,22 +721,18 @@ def update_from_output( if new_token_ids and self.structured_output_manager.should_advance( request): - if TYPE_CHECKING: - assert request.structured_output_request is not None - assert request.structured_output_request.grammar is not None - request.structured_output_request.grammar.accept_tokens( - req_id, - new_token_ids, - ) + # NOTE: structured_output_request + # should not be None if use_structured_output, we have + # check above, so safe to ignore type warning + request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] + req_id, new_token_ids) # Add newly generated spec token ids to the request. if spec_token_ids is not None: if self.structured_output_manager.should_advance(request): - if TYPE_CHECKING: - assert request.structured_output_request is not None - assert request.structured_output_request.grammar is not None + metadata = request.structured_output_request # Needs to happen after new_token_ids are accepted. 
- request.spec_token_ids = request.structured_output_request.grammar.validate_tokens( + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids[req_index]) else: request.spec_token_ids = spec_token_ids[req_index] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index cd7b5ef4be1..4fe413956f4 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -50,11 +50,10 @@ def __init__(self, vllm_config: VllmConfig): scheduler_config=self.vllm_config.scheduler_config, lora_config=self.vllm_config.lora_config, ).get_lora_tokenizer(None) - # yapf: disable reasoning_backend = vllm_config.decoding_config.reasoning_backend if reasoning_backend: - self.reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_backend)(tokenizer=self.tokenizer) # noqa: E501 - # yapf: enable + self.reasoner = ReasoningParserManager.get_reasoning_parser( + reasoning_backend)(tokenizer=self.tokenizer) # noqa: E501 def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: From 9c364d052a0c33c68d882c62f7274877010057ee Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Fri, 9 May 2025 23:40:31 +0000 Subject: [PATCH 28/40] chore: add a notes about bitmask reset Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 4fe413956f4..5fe1b30be16 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -139,7 +139,9 @@ def grammar_bitmask( ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1]) - # Reset the relevant part of the bitmask before filling + # Note that for thinking support, we will need to + # reset the relevant part of the bitmask for consequent + # request here. bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_( self._full_mask) @@ -157,7 +159,7 @@ def grammar_bitmask( state_advancements = 0 req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None] - for token in req_tokens: + for i, token in enumerate(req_tokens): if apply_bitmask and not request.grammar.is_terminated(): request.grammar.fill_bitmask(bitmask_tensor, cumulative_index) From ffd3fa143ee3d717be602cbeea97a43e7b0ac57a Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 00:52:30 +0000 Subject: [PATCH 29/40] fix: initialize default decoding_config Signed-off-by: Aaron Pham --- vllm/config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fca2865f85d..5ddeabf4300 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2274,7 +2274,7 @@ class SpeculativeConfig: `TypicalAcceptanceSampler`.""" speculative_token_tree: Optional[str] = None - """Specifies the tree structure for speculative token generation. + """Specifies the tree structure for speculative token generation. """ # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, @@ -3605,9 +3605,9 @@ class CompilationConfig(BaseModel): are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an internally managed buffer. Default is False. - - full_cuda_graph: whether to use a full cuda graph for the entire forward - pass rather than splitting certain operations such as attention into subgraphs. 
- Thus this flag cannot be used together with splitting_ops. This may provide + - full_cuda_graph: whether to use a full cuda graph for the entire forward + pass rather than splitting certain operations such as attention into subgraphs. + Thus this flag cannot be used together with splitting_ops. This may provide performance benefits for smaller models. - Inductor compilation: - use_inductor: whether to use inductor compilation. @@ -3907,7 +3907,7 @@ class VllmConfig: lora_config: Optional[LoRAConfig] = None speculative_config: SpeculativeConfig = field(default=None, init=True) # type: ignore - decoding_config: Optional[DecodingConfig] = None + decoding_config: DecodingConfig = field(default_factory=DecodingConfig) observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None From edd235b7c52746e897d4dd663e1051f5c1a4cea2 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 07:22:02 +0000 Subject: [PATCH 30/40] chore(test): use deepseek_r1 parser for qwen3 Signed-off-by: Aaron Pham --- tests/v1/entrypoints/llm/test_struct_output_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 14059772a4d..ce6c09745cc 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -37,7 +37,7 @@ None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, None), ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", None), - ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", NGRAM_SPEC_CONFIG), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", NGRAM_SPEC_CONFIG), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, From 3cbbd8cab3dab230c17d2e3a87a4677c556dab91 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 07:37:05 +0000 Subject: [PATCH 31/40] chore: separate out reasoning tests Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 131 +++++++++++------- 1 file changed, 83 insertions(+), 48 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index ce6c09745cc..935e1ba2adf 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -6,7 +6,7 @@ import json import re from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any import jsonschema import pytest @@ -17,6 +17,9 @@ from vllm.platforms import current_platform from vllm.sampling_params import GuidedDecodingParams, SamplingParams +if TYPE_CHECKING: + from vllm.config import TokenizerMode + NGRAM_SPEC_CONFIG = { "model": "[ngram]", "num_speculative_tokens": 5, @@ -30,19 +33,15 @@ "num_speculative_tokens": 5, } -PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER_SPEC_CONFIG = [ - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None, None), - ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, None), - ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None, - None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, None), - ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", - "deepseek_r1", None), - 
("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", NGRAM_SPEC_CONFIG), - ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None, - NGRAM_SPEC_CONFIG), - ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None, +PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_SPEC_CONFIG = [ + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), - ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", None, + ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), + ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG) ] @@ -77,8 +76,8 @@ def _load_json(s: str, backend: str) -> str: @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( - "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 - PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER_SPEC_CONFIG, + "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", # noqa: E501 + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_SPEC_CONFIG, ) def test_structured_output( monkeypatch: pytest.MonkeyPatch, @@ -90,7 +89,6 @@ def test_structured_output( sample_guided_choice: str, guided_decoding_backend: str, tokenizer_mode: str, - reasoning_parser: str | None, model_name: str, speculative_config: dict[str, Any], ): @@ -110,7 +108,6 @@ def test_structured_output( guided_decoding_backend=guided_decoding_backend, guided_decoding_disable_any_whitespace=True, tokenizer_mode=tokenizer_mode, - reasoning_parser=reasoning_parser, speculative_config=speculative_config) # @@ -519,40 +516,78 @@ def test_structured_output( pytest.fail("Invalid function call format: " f"{generated_text!r}\nError: {str(e)}") - # - # Test 12: Generate structured output with reasoning step - # - if reasoning_parser is not None: - reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Problem: What is 5 * 8 + 2?" 
# noqa: E501 - reasoning_schema = { - "type": "object", - "properties": { - "result": { - "type": "integer" - } - }, - "required": ["result"], - "additionalProperties": False - } - sampling_params = SamplingParams( - temperature=0.1, - max_tokens=4096, - guided_decoding=GuidedDecodingParams(json=reasoning_schema)) - outputs = llm.generate(prompts=[reasoning_prompt], - sampling_params=sampling_params, - use_tqdm=True) +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize( + "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 + [ + ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", + "deepseek_r1", None), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", + NGRAM_SPEC_CONFIG), + ], +) +def test_structured_output_with_reasoning_matrices( + monkeypatch: pytest.MonkeyPatch, + guided_decoding_backend: str, + tokenizer_mode: TokenizerMode, + reasoning_parser: str, + model_name: str, + speculative_config: dict[str, Any], +): + monkeypatch.setenv("VLLM_USE_V1", "1") - assert outputs is not None - output = outputs[0] - assert output is not None - assert isinstance(output, RequestOutput) - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + if current_platform.is_tpu() and speculative_config: + pytest.skip("TPU does not support speculative decoding") - output_json = json.loads(generated_text) - jsonschema.validate(instance=output_json, schema=reasoning_schema) + # Use a single LLM instance for several scenarios to + # speed up the test suite. + llm = LLM( + model=model_name, + # Don't use eager execution on TPUs because we want to test for no + # recompilation at runtime + enforce_eager=bool(not current_platform.is_tpu()), + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=True, + tokenizer_mode=tokenizer_mode, + reasoning_parser=reasoning_parser, + speculative_config=speculative_config, + ) + + reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Problem: What is 5 * 8 + 2?" 
# noqa: E501 + reasoning_schema = { + "type": "object", + "properties": { + "result": { + "type": "integer" + } + }, + "required": ["result"], + "additionalProperties": False + } + + sampling_params = SamplingParams( + temperature=0.1, + max_tokens=4096, + guided_decoding=GuidedDecodingParams(json=reasoning_schema), + ) + outputs = llm.generate( + prompts=[reasoning_prompt], + sampling_params=sampling_params, + use_tqdm=True, + ) + + assert outputs is not None + output = outputs[0] + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, schema=reasoning_schema) @pytest.mark.skip_global_cleanup From a559b72a64ad1d3bdd8b53d9baa945af950e8a04 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 07:54:34 +0000 Subject: [PATCH 32/40] fix: reasoning tests to parse it Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 935e1ba2adf..582eb4fc094 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -13,8 +13,10 @@ from pydantic import BaseModel from vllm.entrypoints.llm import LLM +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.outputs import RequestOutput from vllm.platforms import current_platform +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager from vllm.sampling_params import GuidedDecodingParams, SamplingParams if TYPE_CHECKING: @@ -523,8 +525,7 @@ def test_structured_output( [ ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", None), - ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", - NGRAM_SPEC_CONFIG), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", NGRAM_SPEC_CONFIG), ], ) def test_structured_output_with_reasoning_matrices( @@ -554,6 +555,9 @@ def test_structured_output_with_reasoning_matrices( reasoning_parser=reasoning_parser, speculative_config=speculative_config, ) + tokenizer = llm.get_tokenizer(None) + reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( + tokenizer=tokenizer) reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Problem: What is 5 * 8 + 2?" 
# noqa: E501 reasoning_schema = { @@ -569,24 +573,35 @@ def test_structured_output_with_reasoning_matrices( sampling_params = SamplingParams( temperature=0.1, - max_tokens=4096, + max_tokens=8192, guided_decoding=GuidedDecodingParams(json=reasoning_schema), ) outputs = llm.generate( - prompts=[reasoning_prompt], + [reasoning_prompt], sampling_params=sampling_params, use_tqdm=True, ) assert outputs is not None output = outputs[0] - assert output is not None - assert isinstance(output, RequestOutput) + assert output is not None and isinstance(output, RequestOutput) prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(generated_text) + reasoning_content, content = reasoner.extract_reasoning_content( + generated_text, + request=ChatCompletionRequest( + messages=[], + model="test-model", + seed=123, + ), + ) + assert content is not None + print( + f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" + ) - output_json = json.loads(generated_text) + output_json = json.loads(content) jsonschema.validate(instance=output_json, schema=reasoning_schema) From 1f3c3695a40cc728e4e33b1c4af45f235ee39c05 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 07:56:52 +0000 Subject: [PATCH 33/40] chore: replicate duplicate thinking budget Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 582eb4fc094..fbe95fd8cf1 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -525,7 +525,8 @@ def test_structured_output( [ ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", None), - ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", NGRAM_SPEC_CONFIG), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", + NGRAM_SPEC_CONFIG), ], ) def test_structured_output_with_reasoning_matrices( @@ -576,33 +577,33 @@ def test_structured_output_with_reasoning_matrices( max_tokens=8192, guided_decoding=GuidedDecodingParams(json=reasoning_schema), ) - outputs = llm.generate( - [reasoning_prompt], - sampling_params=sampling_params, - use_tqdm=True, - ) + for _ in range(2): + outputs = llm.generate( + [reasoning_prompt], + sampling_params=sampling_params, + use_tqdm=True, + ) - assert outputs is not None - output = outputs[0] - assert output is not None and isinstance(output, RequestOutput) - prompt = output.prompt - generated_text = output.outputs[0].text - print(generated_text) - reasoning_content, content = reasoner.extract_reasoning_content( - generated_text, - request=ChatCompletionRequest( - messages=[], - model="test-model", - seed=123, - ), - ) - assert content is not None - print( - f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" - ) + assert outputs is not None + output = outputs[0] + assert output is not None and isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + reasoning_content, content = reasoner.extract_reasoning_content( + generated_text, + request=ChatCompletionRequest( + messages=[], + model="test-model", + seed=123, + ), + ) + assert content is not None + print( + f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" + ) - output_json = json.loads(content) - 
jsonschema.validate(instance=output_json, schema=reasoning_schema) + output_json = json.loads(content) + jsonschema.validate(instance=output_json, schema=reasoning_schema) @pytest.mark.skip_global_cleanup From d5574be2627c6bef33019ceaecec108ae75f36bc Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 07:58:38 +0000 Subject: [PATCH 34/40] revert: remove duplications Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 49 +++++++++---------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index fbe95fd8cf1..aa55b330fce 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -577,33 +577,32 @@ def test_structured_output_with_reasoning_matrices( max_tokens=8192, guided_decoding=GuidedDecodingParams(json=reasoning_schema), ) - for _ in range(2): - outputs = llm.generate( - [reasoning_prompt], - sampling_params=sampling_params, - use_tqdm=True, - ) + outputs = llm.generate( + [reasoning_prompt], + sampling_params=sampling_params, + use_tqdm=True, + ) - assert outputs is not None - output = outputs[0] - assert output is not None and isinstance(output, RequestOutput) - prompt = output.prompt - generated_text = output.outputs[0].text - reasoning_content, content = reasoner.extract_reasoning_content( - generated_text, - request=ChatCompletionRequest( - messages=[], - model="test-model", - seed=123, - ), - ) - assert content is not None - print( - f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" - ) + assert outputs is not None + output = outputs[0] + assert output is not None and isinstance(output, RequestOutput) + prompt = output.prompt + generated_text = output.outputs[0].text + reasoning_content, content = reasoner.extract_reasoning_content( + generated_text, + request=ChatCompletionRequest( + messages=[], + model="test-model", + seed=123, + ), + ) + assert content is not None + print( + f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" + ) - output_json = json.loads(content) - jsonschema.validate(instance=output_json, schema=reasoning_schema) + output_json = json.loads(content) + jsonschema.validate(instance=output_json, schema=reasoning_schema) @pytest.mark.skip_global_cleanup From 59f2aa7b59b7861458abbb91133067a4974ee321 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 08:00:21 +0000 Subject: [PATCH 35/40] chore: reorder test logs Signed-off-by: Aaron Pham --- tests/v1/entrypoints/llm/test_struct_output_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index aa55b330fce..2d1a6fb0e5a 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -596,11 +596,11 @@ def test_structured_output_with_reasoning_matrices( seed=123, ), ) - assert content is not None print( f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}" ) + assert content is not None and reasoning_content is not None output_json = json.loads(content) jsonschema.validate(instance=output_json, schema=reasoning_schema) From ded38906c1f4b9cf58540221ef9f46ec07e1fca1 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 08:13:25 +0000 Subject: [PATCH 36/40] chore: keep main change to 
reduce diff Signed-off-by: Aaron Pham --- .../llm/test_struct_output_generate.py | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 2d1a6fb0e5a..f56003e7097 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -12,8 +12,8 @@ import pytest from pydantic import BaseModel +from tests.reasoning.utils import run_reasoning_extraction from vllm.entrypoints.llm import LLM -from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager @@ -35,11 +35,13 @@ "num_speculative_tokens": 5, } -PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_SPEC_CONFIG = [ +PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), + #FIXME: This test is flaky on CI thus disabled + #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), @@ -78,9 +80,8 @@ def _load_json(s: str, backend: str) -> str: @pytest.mark.skip_global_cleanup @pytest.mark.parametrize( - "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", # noqa: E501 - PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_SPEC_CONFIG, -) + "model_name, guided_decoding_backend, tokenizer_mode, speculative_config", + PARAMS_MODELS_BACKENDS_TOKENIZER_MODE) def test_structured_output( monkeypatch: pytest.MonkeyPatch, sample_json_schema: dict[str, Any], @@ -524,9 +525,8 @@ def test_structured_output( "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 [ ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", - "deepseek_r1", None), - ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", - NGRAM_SPEC_CONFIG), + "deepseek_r1", NGRAM_SPEC_CONFIG), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", None), ], ) def test_structured_output_with_reasoning_matrices( @@ -535,7 +535,7 @@ def test_structured_output_with_reasoning_matrices( tokenizer_mode: TokenizerMode, reasoning_parser: str, model_name: str, - speculative_config: dict[str, Any], + speculative_config: dict[str, Any] | None, ): monkeypatch.setenv("VLLM_USE_V1", "1") @@ -550,6 +550,7 @@ def test_structured_output_with_reasoning_matrices( # recompilation at runtime enforce_eager=bool(not current_platform.is_tpu()), max_model_len=1024, + max_num_seqs=16, guided_decoding_backend=guided_decoding_backend, guided_decoding_disable_any_whitespace=True, tokenizer_mode=tokenizer_mode, @@ -588,14 +589,8 @@ def test_structured_output_with_reasoning_matrices( assert output is not None and isinstance(output, RequestOutput) prompt = output.prompt generated_text = output.outputs[0].text - reasoning_content, content = reasoner.extract_reasoning_content( - generated_text, - request=ChatCompletionRequest( - messages=[], - model="test-model", - seed=123, - ), - ) + reasoning_content, content = run_reasoning_extraction( + reasoner, [generated_text]) print( f"Prompt: {prompt!r}\nReasoning: 
{reasoning_content!r}\nContent: {content!r}" ) From 0fb92a52768ac1f005ac92e23be81d61502e3aef Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 08:15:22 +0000 Subject: [PATCH 37/40] fix: use deepseek_r1 parser for tests Signed-off-by: Aaron Pham --- tests/v1/entrypoints/llm/test_struct_output_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index f56003e7097..79324f442c1 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -526,7 +526,7 @@ def test_structured_output( [ ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", NGRAM_SPEC_CONFIG), - ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "qwen3", None), + ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", None), ], ) def test_structured_output_with_reasoning_matrices( From 7ace2cbffde169cc69b2a412420696cc23d55ff3 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 08:23:21 +0000 Subject: [PATCH 38/40] chore: use a slightly larger models for smarter cot Signed-off-by: Aaron Pham --- tests/v1/entrypoints/llm/test_struct_output_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 79324f442c1..2034449cd24 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -526,7 +526,7 @@ def test_structured_output( [ ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", "deepseek_r1", NGRAM_SPEC_CONFIG), - ("Qwen/Qwen3-0.6B", "xgrammar", "auto", "deepseek_r1", None), + ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None), ], ) def test_structured_output_with_reasoning_matrices( From 1816b3b24ff1a7cce0c63c5ca276f6b1ded3d3a1 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 08:31:46 +0000 Subject: [PATCH 39/40] fix: support for qwen3 prompts Signed-off-by: Aaron Pham --- tests/v1/entrypoints/llm/test_struct_output_generate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 2034449cd24..25bbcd901d6 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -561,7 +561,7 @@ def test_structured_output_with_reasoning_matrices( reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)( tokenizer=tokenizer) - reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Problem: What is 5 * 8 + 2?" # noqa: E501 + reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?" 
# noqa: E501 reasoning_schema = { "type": "object", "properties": { @@ -572,6 +572,8 @@ def test_structured_output_with_reasoning_matrices( "required": ["result"], "additionalProperties": False } + if "Qwen3" in model_name: + reasoning_prompt += "\n" sampling_params = SamplingParams( temperature=0.1, From d96fa456f5a64dcd86050e8138e82058d3ad6faa Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 13 May 2025 21:44:33 -0400 Subject: [PATCH 40/40] chore: make it more clear Signed-off-by: Aaron Pham --- vllm/v1/structured_output/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 5fe1b30be16..c701ab1d35a 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -52,8 +52,9 @@ def __init__(self, vllm_config: VllmConfig): ).get_lora_tokenizer(None) reasoning_backend = vllm_config.decoding_config.reasoning_backend if reasoning_backend: - self.reasoner = ReasoningParserManager.get_reasoning_parser( - reasoning_backend)(tokenizer=self.tokenizer) # noqa: E501 + reasoner_cls = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: if request.structured_output_request is None:
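
For reference, the end-to-end flow exercised by the reasoning test in this series can be reduced to the offline sketch below. It only relies on calls that appear in the diffs above (`LLM(..., reasoning_parser=...)`, `GuidedDecodingParams(json=...)`, `ReasoningParserManager`, `extract_reasoning_content`); the model, prompt, and request placeholder are illustrative stand-ins rather than the exact test fixtures.

```python
# Offline sketch of structured output combined with a reasoning parser.
# Loosely mirrors tests/v1/entrypoints/llm/test_struct_output_generate.py;
# the model and prompt are illustrative, not the exact test inputs.
import json

import jsonschema

from vllm.entrypoints.llm import LLM
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

schema = {
    "type": "object",
    "properties": {"result": {"type": "integer"}},
    "required": ["result"],
    "additionalProperties": False,
}

llm = LLM(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    max_model_len=1024,
    guided_decoding_backend="xgrammar",
    reasoning_parser="deepseek_r1",
)
reasoner = ReasoningParserManager.get_reasoning_parser("deepseek_r1")(
    tokenizer=llm.get_tokenizer(None))

sampling_params = SamplingParams(
    temperature=0.1,
    max_tokens=8192,
    guided_decoding=GuidedDecodingParams(json=schema),
)
output = llm.generate(
    ["What is 5 * 8 + 2? Reply as a JSON object with a single key 'result'."],
    sampling_params=sampling_params,
)[0]
generated_text = output.outputs[0].text

# The grammar only constrains the final answer; the chain of thought is
# recovered separately by the reasoning parser.
reasoning_content, content = reasoner.extract_reasoning_content(
    generated_text,
    request=ChatCompletionRequest(messages=[], model="test-model", seed=123),
)
jsonschema.validate(instance=json.loads(content), schema=schema)
```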
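
On the serving side, the docs change in this series introduces "The following is an example client:"; a minimal client in the same spirit (not the exact snippet from the docs) is sketched below. It assumes a server started as in the docs, `vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1`, and uses vLLM's `guided_json` extra-body extension; the `Answer` model and prompt are illustrative.

```python
# Client-side sketch: request a schema-constrained answer from a reasoning
# model; the chain of thought is returned separately as reasoning_content.
from openai import OpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    result: int


client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    messages=[{
        "role": "user",
        "content": ("What is 5 * 8 + 2? Answer as a JSON object "
                    "with a single key 'result'."),
    }],
    # vLLM extension: constrain only the visible answer to this JSON schema.
    extra_body={"guided_json": Answer.model_json_schema()},
)

message = completion.choices[0].message
print("reasoning:", message.reasoning_content)
print("answer:", Answer.model_validate_json(message.content))
```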
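
For streaming, the chunk format shown in the docs diff above implies the client should accumulate `reasoning_content` and `content` deltas separately; a rough sketch under the same server assumptions as the previous example:

```python
# Streaming sketch: each chunk's delta carries either reasoning_content or
# content, matching the chunk format shown in the docs above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    messages=[{"role": "user", "content": "What is 5 * 8 + 2?"}],
    stream=True,
)

reasoning_parts: list[str] = []
answer_parts: list[str] = []
for chunk in stream:
    delta = chunk.choices[0].delta
    # reasoning_content is a vLLM-specific field, so fall back to None if the
    # client library does not declare it on the delta model.
    reasoning_delta = getattr(delta, "reasoning_content", None)
    if reasoning_delta:
        reasoning_parts.append(reasoning_delta)
    if delta.content:
        answer_parts.append(delta.content)

print("reasoning:", "".join(reasoning_parts))
print("answer:", "".join(answer_parts))
```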