diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 4759d0c26c3..3c2571298e4 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -141,10 +141,10 @@ Remember to check whether the `reasoning_content` exists in the response before The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now. ```bash -VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 ``` -Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine. +The following is an example client: ```python from openai import OpenAI diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 5c116598ff3..25bbcd901d6 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -1,3 +1,4 @@ +# ruff: noqa: E501 # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -5,17 +6,22 @@ import json import re from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any import jsonschema import pytest from pydantic import BaseModel +from tests.reasoning.utils import run_reasoning_extraction from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.platforms import current_platform +from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager from vllm.sampling_params import GuidedDecodingParams, SamplingParams +if TYPE_CHECKING: + from vllm.config import TokenizerMode + NGRAM_SPEC_CONFIG = { "model": "[ngram]", "num_speculative_tokens": 5, @@ -444,7 +450,7 @@ def test_structured_output( prompt = """ You have access to the following function to retrieve the weather in a city: - + { "name": "get_weather", "parameters": { @@ -455,7 +461,7 @@ def test_structured_output( } } } - + If a you choose to call a function ONLY reply in the following format: <{start_tag}={function_name}>{parameters}{end_tag} where @@ -476,7 +482,7 @@ def test_structured_output( - Always add your sources when using search results to answer the user query You are a helpful assistant. - + Given the previous instructions, what is the weather in New York City? \ Make the response as short as possible. """ @@ -514,6 +520,88 @@ def test_structured_output( f"{generated_text!r}\nError: {str(e)}") +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize( + "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config", # noqa: E501 + [ + ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto", + "deepseek_r1", NGRAM_SPEC_CONFIG), + ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None), + ], +) +def test_structured_output_with_reasoning_matrices( + monkeypatch: pytest.MonkeyPatch, + guided_decoding_backend: str, + tokenizer_mode: TokenizerMode, + reasoning_parser: str, + model_name: str, + speculative_config: dict[str, Any] | None, +): + monkeypatch.setenv("VLLM_USE_V1", "1") + + if current_platform.is_tpu() and speculative_config: + pytest.skip("TPU does not support speculative decoding") + + # Use a single LLM instance for several scenarios to + # speed up the test suite. 
+    llm = LLM(
+        model=model_name,
+        # Don't use eager execution on TPUs because we want to test for no
+        # recompilation at runtime
+        enforce_eager=bool(not current_platform.is_tpu()),
+        max_model_len=1024,
+        max_num_seqs=16,
+        guided_decoding_backend=guided_decoding_backend,
+        guided_decoding_disable_any_whitespace=True,
+        tokenizer_mode=tokenizer_mode,
+        reasoning_parser=reasoning_parser,
+        speculative_config=speculative_config,
+    )
+    tokenizer = llm.get_tokenizer(None)
+    reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
+        tokenizer=tokenizer)
+
+    reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as a JSON object with a single key 'result'. Make sure to correct your reasoning if any issues arise.\nProblem: What is 5 * 8 + 2?"  # noqa: E501
+    reasoning_schema = {
+        "type": "object",
+        "properties": {
+            "result": {
+                "type": "integer"
+            }
+        },
+        "required": ["result"],
+        "additionalProperties": False
+    }
+    if "Qwen3" in model_name:
+        reasoning_prompt += "<think>\n"
+
+    sampling_params = SamplingParams(
+        temperature=0.1,
+        max_tokens=8192,
+        guided_decoding=GuidedDecodingParams(json=reasoning_schema),
+    )
+    outputs = llm.generate(
+        [reasoning_prompt],
+        sampling_params=sampling_params,
+        use_tqdm=True,
+    )
+
+    assert outputs is not None
+    output = outputs[0]
+    assert output is not None and isinstance(output, RequestOutput)
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    reasoning_content, content = run_reasoning_extraction(
+        reasoner, [generated_text])
+    print(
+        f"Prompt: {prompt!r}\nReasoning: {reasoning_content!r}\nContent: {content!r}"
+    )
+
+    assert content is not None and reasoning_content is not None
+    output_json = json.loads(content)
+    jsonschema.validate(instance=output_json, schema=reasoning_schema)
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("model_name, tokenizer_mode",
                          PARAMS_MODELS_TOKENIZER_MODE)
diff --git a/vllm/config.py b/vllm/config.py
index d8eabfb2e4f..1ca4fa4f095 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2325,7 +2325,7 @@ class SpeculativeConfig:
     `TypicalAcceptanceSampler`."""
 
     speculative_token_tree: Optional[str] = None
-    """Specifies the tree structure for speculative token generation.
+    """Specifies the tree structure for speculative token generation. 
""" # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, @@ -4017,7 +4017,7 @@ class VllmConfig: """LoRA configuration.""" speculative_config: Optional[SpeculativeConfig] = None """Speculative decoding configuration.""" - decoding_config: Optional[DecodingConfig] = None + decoding_config: DecodingConfig = field(default_factory=DecodingConfig) """Decoding configuration.""" observability_config: Optional[ObservabilityConfig] = None """Observability configuration.""" diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index 454167a0dc9..9dd5191da91 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import os from abc import abstractmethod from collections.abc import Sequence @@ -33,7 +35,7 @@ def vocab(self) -> dict[str, int]: return self.model_tokenizer.get_vocab() @abstractmethod - def is_reasoning_end(self, input_ids: list[int]) -> bool: + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: """ Check if the reasoning content ends in the input_ids. @@ -106,7 +108,7 @@ class ReasoningParserManager: reasoning_parsers: dict[str, type] = {} @classmethod - def get_reasoning_parser(cls, name) -> type: + def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]: """ Get reasoning parser by name which is registered by `register_module`. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7773853b096..3aff5542317 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -749,7 +749,8 @@ def update_from_output( # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and request.use_structured_output: + if new_token_ids and self.structured_output_manager.should_advance( + request): # NOTE: structured_output_request # should not be None if use_structured_output, we have # check above, so safe to ignore type warning @@ -758,11 +759,10 @@ def update_from_output( # Add newly generated spec token ids to the request. if spec_token_ids is not None: - if request.use_structured_output: + if self.structured_output_manager.should_advance(request): metadata = request.structured_output_request - assert metadata is not None and metadata.grammar is not None # Needs to happen after new_token_ids are accepted. 
- request.spec_token_ids = metadata.grammar.validate_tokens( + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] spec_token_ids[req_index]) else: request.spec_token_ids = spec_token_ids[req_index] diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 3183edb7c94..c701ab1d35a 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,16 +7,23 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.reasoning import ReasoningParserManager +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, StructuredOutputGrammar) +from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: import numpy as np import numpy.typing as npt import torch + from vllm.reasoning import ReasoningParser from vllm.v1.request import Request +else: + torch = LazyLoader("torch", globals(), "torch") logger = init_logger(__name__) @@ -26,9 +33,11 @@ class StructuredOutputManager: def __init__(self, vllm_config: VllmConfig): self.backend: Optional[StructuredOutputBackend] = None + self.reasoner: Optional[ReasoningParser] = None self.vllm_config = vllm_config self._grammar_bitmask: Optional[torch.Tensor] = None + self._full_mask = torch.tensor(-1, dtype=torch.int32) # The default max_workers if not specified is the number of CPUs * 5, # which is way too high since these tasks are CPU-bound, not I/O bound. @@ -36,24 +45,43 @@ def __init__(self, vllm_config: VllmConfig): # compilation, so we set it to half the number of CPUs. max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.tokenizer = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) + reasoning_backend = vllm_config.decoding_config.reasoning_backend + if reasoning_backend: + reasoner_cls = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + self.reasoner = reasoner_cls(tokenizer=self.tokenizer) def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: return + if TYPE_CHECKING: + assert request.sampling_params.guided_decoding is not None + # Initialize the backend the first time it is needed. # # NOTE: We only support a single backend. We do NOT support different # backends on a per-request basis in V1 (for now, anyway...). 
         if self.backend is None:
             backend = request.sampling_params.guided_decoding.backend
+            vocab_size = self.vllm_config.model_config.get_vocab_size()
             if backend == "xgrammar":
-                from vllm.v1.structured_output.backend_xgrammar import (
-                    XgrammarBackend)
-
-                self.backend = XgrammarBackend(self.vllm_config)
+                self.backend = XgrammarBackend(
+                    self.vllm_config,
+                    tokenizer=self.tokenizer,
+                    vocab_size=vocab_size,
+                )
             elif backend == "guidance":
-                self.backend = GuidanceBackend(self.vllm_config)
+                self.backend = GuidanceBackend(
+                    self.vllm_config,
+                    tokenizer=self.tokenizer,
+                    vocab_size=vocab_size,
+                )
             else:
                 raise ValueError(
                     f"Unsupported structured output backend: {backend}")
@@ -87,14 +115,14 @@ def grammar_bitmask(
         if not structured_output_request_ids:
             return None
 
+        max_num_spec_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            max_num_spec_tokens = \
+                self.vllm_config.speculative_config.num_speculative_tokens
+
         if self._grammar_bitmask is None:
             assert self.backend is not None
             max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
-            if self.vllm_config.speculative_config is not None:
-                max_num_spec_tokens = self.vllm_config.\
-                    speculative_config.num_speculative_tokens
-            else:
-                max_num_spec_tokens = 0
 
             # Allocate a bitmask for each token needing to be checked:
             # one for each speculative position, and one more for the
@@ -103,6 +131,7 @@ def grammar_bitmask(
             self.backend.allocate_token_bitmask(
                 max_batch_size * (1 + max_num_spec_tokens))
 
+        bitmask_tensor = self._grammar_bitmask
         # Generate a batched bitmask for all structured output requests.
         # When speculative decoding is enabled, we need to include multiple
         # masks for each request, one for each possible bonus token position.
@@ -110,16 +139,30 @@ def grammar_bitmask(
         cumulative_index = 0
         ordered_seq = sorted(structured_output_request_ids.items(),
                              key=lambda x: x[1])
+
+        # Note that for thinking support, we reset the relevant
+        # part of the bitmask here so that requests which are
+        # still reasoning are left unconstrained.
+        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
+            self._full_mask)
+
         # NOTE: This outer loop can likely be parallelized to improve
         # performance of bitmask generation for large batches.
         for req_id, _ in ordered_seq:
             request = requests[req_id].structured_output_request
-            assert request is not None and request.grammar is not None
+            if TYPE_CHECKING:
+                assert request is not None
+                assert request.grammar is not None
+
+            apply_bitmask = (
+                request.reasoning_ended if self.reasoner is not None else True
+            )  # noqa: E501
+
             state_advancements = 0
             req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
             for i, token in enumerate(req_tokens):
-                if not request.grammar.is_terminated():
-                    request.grammar.fill_bitmask(self._grammar_bitmask,
+                if apply_bitmask and not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(bitmask_tensor,
                                                  cumulative_index)
                 if token is not None:
                     # In order to generate the correct bitmask for each
@@ -132,15 +175,41 @@ def grammar_bitmask(
             if state_advancements > 0:
                 request.grammar.rollback(state_advancements)
 
-        bitmask_tensor = self._grammar_bitmask
-        if cumulative_index < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:cumulative_index]
+        if cumulative_index < bitmask_tensor.shape[0]:
+            bitmask_tensor = bitmask_tensor[:cumulative_index]
 
         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
         # and deserialization when sending this to the GPU workers. 
         return bitmask_tensor.numpy()
 
+    def should_advance(self, request: Request) -> bool:
+        if not request.use_structured_output:
+            return False
+
+        # Determine whether we can advance the FSM.
+        # For thinking models, the FSM is not advanced during reasoning.
+        if TYPE_CHECKING:
+            assert request.structured_output_request is not None
+            assert request.structured_output_request.grammar is not None
+        # By default, we should always advance
+        # for cases that don't use thinking mode.
+        if self.reasoner is not None:
+            structured_req = request.structured_output_request
+
+            if structured_req.reasoning_ended:
+                return True
+
+            # Check if reasoning ends in *this* step
+            if self.reasoner.is_reasoning_end(request.all_token_ids):
+                # Reasoning just ended, so we shouldn't advance
+                # until the next pass.
+                structured_req.reasoning_ended = True
+
+            return False
+        else:
+            return True
+
     def clear_backend(self) -> None:
         if self.backend is not None:
             self.backend.destroy()
diff --git a/vllm/v1/structured_output/backend_guidance.py b/vllm/v1/structured_output/backend_guidance.py
index 0ab175e781e..55c5f609095 100644
--- a/vllm/v1/structured_output/backend_guidance.py
+++ b/vllm/v1/structured_output/backend_guidance.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import copy
 import json
 import os
@@ -8,10 +10,8 @@
 
 import torch
 
-from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                      StructuredOutputGrammar,
@@ -54,25 +54,17 @@ def process_for_additional_properties(
     return guide_json_obj
 
 
+@dataclass
 class GuidanceBackend(StructuredOutputBackend):
 
-    def __init__(self, vllm_config: VllmConfig):
-        self.vllm_config = vllm_config
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        self.vllm_config = vllm_config
-        self.vocab_size = vllm_config.model_config.get_vocab_size()
-
+    def __post_init__(self):
         self.disable_any_whitespace = \
-            vllm_config.decoding_config.disable_any_whitespace
+            self.vllm_config.decoding_config.disable_any_whitespace
         self.disable_additional_properties = \
-            vllm_config.decoding_config.disable_additional_properties
+            self.vllm_config.decoding_config.disable_additional_properties
 
-        tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(
-            tokenizer, self.vocab_size)
+            self.tokenizer, self.vocab_size)
 
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py
index 33ca9f8cf48..09f6cdf7333 100644
--- a/vllm/v1/structured_output/backend_types.py
+++ b/vllm/v1/structured_output/backend_types.py
@@ -1,9 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import enum
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch
 
-import torch
+    from vllm.config import VllmConfig
+    from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 
 class StructuredOutputOptions(enum.Enum):
@@ -85,9 +93,14 @@ def reset(self):
         """
 
 
+@dataclass
 class 
StructuredOutputBackend(ABC): """Engine-level backend for structured output requests.""" + vllm_config: VllmConfig + tokenizer: AnyTokenizer + vocab_size: int + @abstractmethod def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: @@ -104,7 +117,7 @@ def compile_grammar(self, request_type: StructuredOutputOptions, """ @abstractmethod - def allocate_token_bitmask(self, max_num_seqs: int): + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: """ Allocates a token bitmask for the specified maximum number of sequences. diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py index 2ce2be337ec..f2570221da2 100644 --- a/vllm/v1/structured_output/backend_xgrammar.py +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import json from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -7,10 +9,8 @@ import torch import vllm.envs -from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import LazyLoader from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, @@ -28,61 +28,49 @@ logger = init_logger(__name__) +@dataclass class XgrammarBackend(StructuredOutputBackend): - def __init__(self, vllm_config: VllmConfig): - self.vllm_config = vllm_config - tokenizer_group = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) # type: ignore[arg-type] - + def __post_init__(self): self.disable_any_whitespace = \ - vllm_config.decoding_config.disable_any_whitespace + self.vllm_config.decoding_config.disable_any_whitespace - self.num_speculative_tokens = 0 - if self.vllm_config.speculative_config is not None: - self.num_speculative_tokens = \ - self.vllm_config.speculative_config.num_speculative_tokens - - tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.vocab_size = vllm_config.model_config.get_vocab_size() - if isinstance(tokenizer, MistralTokenizer): + if isinstance(self.tokenizer, MistralTokenizer): # NOTE: ideally, xgrammar should handle this accordingly. # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 try: - if tokenizer.is_tekken: - encoded_vocab = tokenizer._vocab + if self.tokenizer.is_tekken: + encoded_vocab = self.tokenizer._vocab else: encoded_vocab = [ token for token, _ in sorted( - tokenizer.get_vocab().items(), + self.tokenizer.get_vocab().items(), key=lambda x: x[1], ) ] stop_token_ids = None - if hasattr( - tokenizer, + if (hasattr( + self.tokenizer, "eos_token_id", - ) and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + ) and self.tokenizer.eos_token_id is not None): + stop_token_ids = [self.tokenizer.eos_token_id] except AttributeError as e: raise ValueError( f"Cannot get the vocabulary of the tokenizer " - f"{type(tokenizer)}. The tokenizer should have a " + f"{type(self.tokenizer)}. 
The tokenizer should have a " "get_vocab method.") from e tokenizer_info = xgr.TokenizerInfo( # type: ignore encoded_vocab=encoded_vocab, # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 vocab_type=xgr.VocabType.RAW - if tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, + if self.tokenizer.is_tekken else xgr.VocabType.BYTE_FALLBACK, vocab_size=self.vocab_size, stop_token_ids=stop_token_ids, add_prefix_space=True, ) else: tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer, + self.tokenizer, vocab_size=self.vocab_size, ) self.compiler = xgr.GrammarCompiler( @@ -92,6 +80,11 @@ def __init__(self, vllm_config: VllmConfig): cache_limit_bytes=vllm.envs.VLLM_XGRAMMAR_CACHE_MB * 1024 * 1024, ) + self.num_speculative_tokens = 0 + if self.vllm_config.speculative_config is not None: + self.num_speculative_tokens = \ + self.vllm_config.speculative_config.num_speculative_tokens + def compile_grammar(self, request_type: StructuredOutputOptions, grammar_spec: str) -> StructuredOutputGrammar: if request_type == StructuredOutputOptions.JSON: diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index 6ef472eb896..c16320b9e74 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -20,6 +20,7 @@ class StructuredOutputRequest: sampling_params: SamplingParams _grammar: Optional[Union[Future[StructuredOutputGrammar], StructuredOutputGrammar]] = None + reasoning_ended: bool = False def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports
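Taken together, the changes above let the V1 engine combine a reasoning parser with structured output: the grammar bitmask is only applied, and the FSM only advanced, once `is_reasoning_end` reports that the thinking section has finished. The sketch below mirrors the new test at the user level. It is a minimal offline example; the model, backend, schema, and sampling values are the same illustrative choices used in the test, not required settings.

```python
# Minimal sketch mirroring the new test: a reasoning model plus JSON-constrained
# output on the V1 engine. Model, schema, and sampling values are illustrative.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {"result": {"type": "integer"}},
    "required": ["result"],
    "additionalProperties": False,
}

llm = LLM(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    guided_decoding_backend="xgrammar",
    reasoning_parser="deepseek_r1",
    max_model_len=1024,
)

sampling_params = SamplingParams(
    temperature=0.1,
    max_tokens=2048,
    guided_decoding=GuidedDecodingParams(json=schema),
)

# The model is free to emit <think>...</think> tokens first; the grammar is
# only enforced once the reasoning parser reports that reasoning has ended.
outputs = llm.generate(
    ["What is 5 * 8 + 2? Answer as a JSON object with a single key 'result'."],
    sampling_params,
)
print(outputs[0].outputs[0].text)
```

The served path works the same way: start the server with `vllm serve <model> --reasoning-parser deepseek_r1` and send a guided-decoding request, as in the documentation change at the top of this diff.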
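The backend refactor is also worth spelling out: `StructuredOutputBackend` is now a dataclass whose `vllm_config`, `tokenizer`, and `vocab_size` fields are injected by `StructuredOutputManager`, so individual backends no longer construct their own tokenizer group. The sketch below shows what a hypothetical out-of-tree backend could look like under the new base class; the class name and method bodies are placeholders, and only the field and method signatures come from the diff.

```python
# Hypothetical backend sketch (not part of the diff). Fields are declared on
# the dataclass base; per-backend setup moves into __post_init__.
from dataclasses import dataclass

import torch

from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)


@dataclass
class MyBackend(StructuredOutputBackend):

    def __post_init__(self):
        # self.vllm_config, self.tokenizer, and self.vocab_size are already
        # populated by StructuredOutputManager.grammar_init().
        self.disable_any_whitespace = \
            self.vllm_config.decoding_config.disable_any_whitespace

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        # Compile grammar_spec (JSON schema, regex, ...) into a grammar object.
        raise NotImplementedError

    def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
        # One int32 word per 32 vocab entries; -1 means "all tokens allowed",
        # matching the full mask the manager uses to reset rows.
        return torch.full((max_num_seqs, (self.vocab_size + 31) // 32),
                          -1,
                          dtype=torch.int32)

    def destroy(self):
        # Invoked from StructuredOutputManager.clear_backend().
        pass
```

Because the manager now owns the tokenizer, both built-in backends shrink to a `__post_init__` that derives only backend-specific state, which is what the `GuidanceBackend` and `XgrammarBackend` hunks above do.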