diff --git a/requirements/common.txt b/requirements/common.txt index 80f90e60007..f9eb016c9e7 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -20,7 +20,9 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.11, < 0.11 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" -outlines == 0.1.11 +outlines_core == 0.2.10 +# required for outlines backend disk cache +diskcache == 5.6.3 lark == 1.2.2 xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" typing_extensions >= 4.10 diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index fdbdccd4654..d8b0a107f8e 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -15,14 +15,18 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" -GUIDED_DECODING_BACKENDS = [ + +# Separate backends which support grammars vs ones +# which only support regex based constraints in tests. +GRAMMAR_DECODING_BACKENDS = [ # (backend, disable_any_whitespace), - ("outlines", False), ("lm-format-enforcer", False), ("xgrammar", True), ("guidance", True), ] +ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS) + @pytest.fixture(scope="module") def llm(): @@ -38,7 +42,7 @@ def llm(): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, disable_any_whitespace: bool): sampling_params = SamplingParams( @@ -48,6 +52,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, regex=sample_regex, backend=guided_decoding_backend, disable_any_whitespace=disable_any_whitespace)) + outputs = llm.generate(prompts=[ f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2, @@ -68,7 +73,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_json_completion(sample_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -102,7 +107,7 @@ def test_guided_json_completion(sample_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_complex_json_completion(sample_complex_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -137,7 +142,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_definition_json_completion(sample_definition_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -172,7 +177,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_enum_json_completion(sample_enum_json_schema, llm, 
guided_decoding_backend: str, disable_any_whitespace: bool): @@ -217,7 +222,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_choice_completion(sample_guided_choice, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -247,7 +252,7 @@ def test_guided_choice_completion(sample_guided_choice, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + GRAMMAR_DECODING_BACKENDS) def test_guided_grammar(sample_sql_statements, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -343,7 +348,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + GRAMMAR_DECODING_BACKENDS) def test_guided_json_object(llm, guided_decoding_backend: str, disable_any_whitespace: bool): sampling_params = SamplingParams( @@ -376,7 +381,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str, # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) - assert isinstance(parsed_json, dict) + # A list is not what was intended, but is still valid + # json. + assert isinstance(parsed_json, (dict, list)) class CarType(str, Enum): @@ -394,7 +401,7 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, disable_any_whitespace: bool): json_schema = CarDescription.model_json_schema() @@ -426,7 +433,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, disable_any_whitespace: bool): sample_output_schema = { diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 6cd966f8480..66980bffa88 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -39,26 +39,23 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" regex_LP = RegexLogitsProcessor(sample_regex, zephyr_7B_tokenzer, - reasoner=None) + reasoner=None, + vocab_size=32000) json_LP = JSONLogitsProcessor(sample_json_schema, zephyr_7B_tokenzer, whitespace_pattern=None, - reasoner=None) + reasoner=None, + vocab_size=32000) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an example IPv4 address with this regex: {sample_regex}") tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - regex_LP(token_ids, tensor) + tensor = regex_LP([], tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}" - ) tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - json_LP(token_ids, tensor) + tensor = json_LP([], tensor) assert tensor.shape == 
original_tensor.shape assert not torch.allclose(tensor, original_tensor) @@ -80,8 +77,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, seed=0, dtype="bfloat16", ) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = get_local_guided_decoding_logits_processor( @@ -91,13 +86,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - tensor = regex_lp(token_ids, tensor) + # allowed tokens at state 0 + tensor = regex_lp([], tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}" - ) json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( @@ -105,7 +98,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - tensor = json_lp(token_ids, tensor) + tensor = json_lp([], tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) @@ -129,7 +122,6 @@ async def test_guided_logits_processor_with_reasoning( dtype="bfloat16", ) token_ids = deepseek_r1_qwen_tokenizer.encode( - f"Give an example IPv4 address with this regex: {sample_regex}." "here is the thinking process") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) @@ -142,12 +134,11 @@ async def test_guided_logits_processor_with_reasoning( assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - tensor = regex_lp(token_ids, tensor) + regex_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape assert torch.allclose(tensor, original_tensor) token_ids = deepseek_r1_qwen_tokenizer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}." "here is the thinking process") json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) @@ -159,14 +150,13 @@ async def test_guided_logits_processor_with_reasoning( assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - tensor = json_lp(token_ids, tensor) + json_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape assert torch.allclose(tensor, original_tensor) # Thinking is over, so the tensor should change. token_ids = deepseek_r1_qwen_tokenizer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}." 
- "here is the thinking process Then") + "here is the thinking process") json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = get_local_guided_decoding_logits_processor( diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index 2ab87a0ef41..a2c6a66fef9 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -71,7 +71,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, assert isinstance(schema, dict) # use build_regex_from_schema used in JSONLogitsProcessor to create Guide - from outlines_core.fsm.json_schema import build_regex_from_schema + from outlines_core.json_schema import build_regex_from_schema regex = build_regex_from_schema(json.dumps(schema)) compiled = re.compile(regex) matches = compiled.fullmatch(json.dumps(sample_output)) is not None diff --git a/vllm/config.py b/vllm/config.py index dddfdabd126..88a45e782fb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3289,7 +3289,8 @@ def get_served_model_name(model: str, GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", "xgrammar", "guidance"] -GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] + +GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"] GuidedDecodingBackend = Literal[GuidedDecodingBackendV0, GuidedDecodingBackendV1] diff --git a/vllm/envs.py b/vllm/envs.py index fe3fa91fbe3..dbbb52e772b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -109,6 +109,7 @@ VLLM_DP_MASTER_PORT: int = 0 VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False + VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 @@ -737,6 +738,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", + # Whether to turn on the outlines cache for V0 + # This cache is unbounded and on disk, so it's not safe to use in + # an environment with potentially malicious users. + "VLLM_V1_USE_OUTLINES_CACHE": + lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1", + # Gap between padding buckets for the forward pass. So we have # 8, we will run forward pass with [16, 24, 32, ...]. 
"VLLM_TPU_BUCKET_PADDING_GAP": diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index a2b61a1b19e..743b5a8cfed 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -110,13 +110,12 @@ async def get_guided_decoding_logits_processor( guided_params = maybe_backend_fallback(guided_params) - # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( - guided_params, tokenizer, reasoner) + guided_params, tokenizer, reasoner, model_config) if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) @@ -151,13 +150,12 @@ def get_local_guided_decoding_logits_processor( reasoning_backend) reasoner = reasoner_class(tokenizer) - # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( - guided_params, tokenizer, reasoner) + guided_params, tokenizer, reasoner, model_config) if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bcd7494e6ce..dfa5527583b 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -10,8 +10,9 @@ from transformers import PreTrainedTokenizerBase +from vllm.config import ModelConfig from vllm.model_executor.guided_decoding.outlines_logits_processors import ( - CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) + JSONLogitsProcessor, RegexLogitsProcessor) from vllm.reasoning import ReasoningParser from vllm.sampling_params import GuidedDecodingParams @@ -20,36 +21,8 @@ class GuidedDecodingMode(Enum): JSON = "json" REGEX = "regex" CHOICE = "choice" - GRAMMAR = "grammar" -# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark -# the main difference is that we changed the start: value to -# start: object | array, so we are denying scalar values as the root of the -# JSON. Starting with scalars as the root seems to cause llama to generate -# without stop. 
-JSON_GRAMMAR = r""" -?start: object | array - -?value: object -| array -| UNESCAPED_STRING -| SIGNED_NUMBER -> number -| "true" -> true -| "false" -> false -| "null" -> null - -array : "[" [value ("," value)*] "]" -object : "{" [pair ("," pair)*] "}" -pair : UNESCAPED_STRING ":" value - -%import common.UNESCAPED_STRING -%import common.SIGNED_NUMBER -%import common.WS - -%ignore WS -""" - global_thread_pool = None # used for generating logits processor fsm # It's not yet clear that using more provides a benefit, and it could @@ -59,16 +32,12 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser], -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, - None]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], model_config: ModelConfig +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. - We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. """ global global_thread_pool guide, mode = _get_guide_and_mode(guided_params) @@ -82,31 +51,28 @@ async def get_outlines_guided_decoding_logits_processor( global_thread_pool = concurrent.futures.ThreadPoolExecutor( max_workers=max_workers) loop = asyncio.get_running_loop() - + vocab_size = model_config.get_vocab_size() return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, mode, guided_params.whitespace_pattern, - reasoner) + reasoner, vocab_size) def get_local_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser], -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, - None]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], model_config: ModelConfig +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. - We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. """ guide, mode = _get_guide_and_mode(guided_params) if not guide or not mode: return None return _get_logits_processor(guide, tokenizer, mode, - guided_params.whitespace_pattern, reasoner) + guided_params.whitespace_pattern, reasoner, + model_config.get_vocab_size()) def _get_guide_and_mode( @@ -129,9 +95,10 @@ def _get_guide_and_mode( choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE elif guided_params.grammar: - return guided_params.grammar, GuidedDecodingMode.GRAMMAR - elif guided_params.json_object: - return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR + raise ValueError( + "The `outlines` guided decoding backend no longer supports grammar " + "guided generation. 
Please use either the `xgrammar` or `guidance` " + "backend") else: return None, None @@ -142,13 +109,12 @@ def _get_logits_processor( mode: GuidedDecodingMode, whitespace_pattern: Union[str, None], reasoner: Optional[ReasoningParser], -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: + vocab_size: int, +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, - reasoner) + reasoner, vocab_size) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: - return RegexLogitsProcessor(guide, tokenizer, reasoner) - elif mode == GuidedDecodingMode.GRAMMAR: - return CFGLogitsProcessor(guide, tokenizer, reasoner) + return RegexLogitsProcessor(guide, tokenizer, reasoner, vocab_size) else: raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 8ae7c7b6b2c..054a465bed6 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -7,155 +7,124 @@ # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at - +# # http://www.apache.org/licenses/LICENSE-2.0 - +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy + +from __future__ import annotations + +import hashlib +import importlib.metadata import json -from collections import defaultdict -from functools import lru_cache -from typing import Callable, Optional, Union +import os +import re +from typing import Optional, Union -import numpy as np import torch -from outlines import grammars -from outlines.caching import cache, disable_cache -from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide, - RegexGuide, Write) -from outlines.fsm.parsing import PartialLark -from outlines_core.fsm.json_schema import build_regex_from_schema +from cachetools import LRUCache +from diskcache import Cache +from outlines_core import Guide, Index, Vocabulary +from outlines_core.json_schema import build_regex_from_schema +from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel, + allocate_token_bitmask) from pydantic import BaseModel from transformers import PreTrainedTokenizerBase +from transformers.file_utils import SPIECE_UNDERLINE +from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode import vllm.envs as envs from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.reasoning import ReasoningParser +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -if envs.VLLM_V0_USE_OUTLINES_CACHE: - logger.warning("Enabling outlines cache. This is an unbounded on-disk " - "cache. 
It may consume a lot of disk space and should " - "not be used with untrusted clients.") -else: - disable_cache() +CACHE = None class BaseLogitsProcessor: - def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]): + def __init__(self, guide: Guide, vocab_size: int, eos_token_id: int, + reasoner: Optional[ReasoningParser]): self._guide: Guide = guide + self._eos_token_id = eos_token_id self._reasoner: Optional[ReasoningParser] = reasoner - # CFGState is used for the FSM state for CFGGuide - self._fsm_state: defaultdict[int, Union[int, - CFGState]] = defaultdict(int) + self._mask = allocate_token_bitmask(vocab_size) def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: - """Use the FSM to bias the logits before sampling the next token.""" # Skip the structured logits processing if reasoning is not finished. # reasoner is not None only when `--reasoning-parser` is set. - if self._reasoner is not None: - if not self._reasoner.is_reasoning_end(input_ids): - return scores - else: - # Remove the reasoning tokens from the input_ids - # We need this because our implementation relies on the - # hash of the input_ids to store the FSM state. - input_ids = self._reasoner.extract_content_ids(input_ids) - - seq_id = hash(tuple(input_ids)) - - if len(input_ids) > 0: - last_token = input_ids[-1] - last_seq_id = hash(tuple(input_ids[:-1])) - self._fsm_state[seq_id] = self._guide.get_next_state( - state=self._fsm_state[last_seq_id], token_id=last_token) - else: - # Note: this is a hack. - # Lark pickling does not work properly (silent failure), - # which breaks the RPC (which uses python pickleing). - # We need to find a better solution. - # On the first time this is called, we simply re-create - # the Lark object. - if isinstance(self._guide, CFGGuide): - self._guide.parser = PartialLark( - self._guide.cfg_string, - parser="lalr", - import_paths=[grammars.GRAMMAR_PATH], - ) - self._fsm_state[seq_id] = CFGState( - parser_state=self._guide.parser.parse(""), prev_token=None) - - instruction = self._guide.get_next_instruction( - state=self._fsm_state[seq_id]) - - if type(instruction) == Generate: # noqa: E721 - allowed_tokens = instruction.tokens - elif type(instruction) == Write: # noqa: E721 - # TODO: support fast forward tokens - allowed_tokens = [instruction.tokens[0]] - else: - raise TypeError( - f"Unsupported instruction type {type(instruction)}") - - mask = torch.full((scores.shape[-1], ), - -torch.inf, - device=scores.device) - # The tokenizer may support more token ids than the model can generate, - # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256 - # but scores.shape == torch.Size([128256]) - # Using NumPy is faster for filtering token ids - allowed_tokens = np.array(allowed_tokens, dtype=np.int64) - allowed_tokens = torch.tensor(allowed_tokens, device=scores.device) - allowed_tokens = allowed_tokens.masked_select( - allowed_tokens < scores.shape[-1]) - mask.index_fill_(0, allowed_tokens, 0) - if current_platform.is_hpu(): - # Workaround for HPU bug where add_() raise RuntimeError: - # synNodeCreateWithId failed for node: strided_insert - # with synStatus 1 [Invalid argument], hopefully it will - # be fixed in the future releases of the HPU runtime. 
- scores = scores.add(mask) - else: - scores.add_(mask) + if self._reasoner is not None and not self._reasoner.is_reasoning_end( + input_ids): + return scores + + # Remove the reasoning tokens from the input_ids + # We need this because our implementation relies on the + # input_ids sequence to store the FSM state. + input_ids = (self._reasoner.extract_content_ids(input_ids) + if self._reasoner is not None else input_ids) + + # Vllm V0 engine has a weird bug where we have to repeat + # the eos token id twice for generation to stop, or at least + # that is what we have to do from here in any case. + # This is a patch until a better solution can be pushed + # to outlines_core + if input_ids and input_ids[-1] != self._eos_token_id: + self._guide.advance(token_id=input_ids[-1], return_tokens=False) + + self._guide.write_mask_into( + data_ptr=self._mask.data_ptr(), + numel=self._mask.numel(), + element_size=self._mask.element_size(), + ) + + # Any allowed tokens beyond the length of the scores will + # be ignored by the kernel, taking care of the issue with + # models such as Llama 3.2 Vision with an `<|image|>` token + # with id 128256, but scores.shape == torch.Size([128256]) + _apply_token_bitmask_inplace_kernel( + logits=scores.unsqueeze(dim=0), + # mask must be on same device + mask=self._mask.to(scores.device)) + self._mask.to("cpu") + return scores class RegexLogitsProcessor(BaseLogitsProcessor): @classmethod - @cache() def _get_guide(cls, regex_string: str, tokenizer: PreTrainedTokenizerBase) -> Guide: - tokenizer = _adapt_tokenizer(tokenizer) - return RegexGuide.from_regex(regex_string, tokenizer) - - def __init__( - self, - regex_string: str, - tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser], - ): - """Compile the FSM that drives the regex-structured generation. - - Parameters - ---------- - regex_string - A string that represents a regular expression - tokenizer - The model's tokenizer - - """ + global CACHE + if CACHE is None: + CACHE = get_cache() + vocabulary = get_vocabulary(tokenizer) # type: ignore[arg-type] + cache_key = f"{vocabulary._hash}_{regex_string}" + if CACHE is not None and cache_key in CACHE: + return Guide(CACHE[cache_key]) + + index = Index(regex_string, vocabulary.inner) + + if CACHE is not None: + CACHE[cache_key] = index + + return Guide(index) + + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser], vocab_size: int) -> None: super().__init__( - RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner) + guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer), + vocab_size=vocab_size, + eos_token_id=tokenizer.eos_token_id, # type: ignore + reasoner=reasoner) class JSONLogitsProcessor(RegexLogitsProcessor): @@ -163,22 +132,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, dict, BaseModel], tokenizer: PreTrainedTokenizerBase, whitespace_pattern: Union[str, None], - reasoner: Optional[ReasoningParser]): - """Compile the FSM that drives the JSON-guided generation. 
- - Parameters - ---------- - schema - A JSON schema that encodes the structure we want the model to - generate - tokenizer - The model's tokenizer - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact - string literals) - Example: allow only a single space or newline with - `whitespace_pattern=r"[\n ]?"` - """ + reasoner: Optional[ReasoningParser], vocab_size: int) -> None: + if isinstance(schema, type(BaseModel)): schema_str = json.dumps(schema.model_json_schema()) elif isinstance(schema, dict): @@ -190,57 +145,42 @@ def __init__(self, schema: Union[str, dict, BaseModel], f"Cannot parse schema {schema}. The schema must be either " f"a Pydantic object, a dictionary or a string that contains " f"the JSON Schema specification") - regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string, tokenizer, reasoner) - - -class CFGLogitsProcessor(BaseLogitsProcessor): - @classmethod - @cache() - def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: - tokenizer = _adapt_tokenizer(tokenizer) - return CFGGuide(cfg, tokenizer) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string, tokenizer, reasoner, vocab_size) - def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser]): - """Compile the FSM that drives the context free grammar generation. - Parameters - ---------- - cfg - A string that represents a context-free grammar - tokenizer - The model's tokenizer +class OutlinesVocabulary: + """ + Wrapper class for `outlines_core.Vocabulary`, + which allows us to store a hash with the vocabulary + """ - """ - super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer), - reasoner) - self._guide = self._guide.copy() + def __init__(self, vocabulary: Vocabulary): + # Actual vocabulary object + self.inner = vocabulary + # Have to do abs(hash()) because python hashes can + # be negative, and we are using hash as a cache key. + hex_str = hashlib.sha256( + vocabulary.__repr__().encode('utf-8')).hexdigest() + hash_int = int(hex_str, 16) + self._hash = hash_int -@lru_cache(maxsize=32) -def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): - """Adapt vLLM's tokenizer to use to compile the FSM. +re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") +re_replacement_seq = re.compile(r"^.?�+.?$") - The API of Outlines tokenizers is slightly different to that of - `transformers`. The decoder of outlines, returns a list whereas - the decode of vLLM returns an str. To sync the vLLM decoder with - outlines internal api, the decoder should be adapted. In addition - we need to handle the missing spaces to Llama's tokenizer to be - able to compile FSMs for this model. +def _reduced_vocabulary(tokenizer: AnyTokenizer, + eos_token_id: int) -> dict[bytes, list[int]]: + """Create a map from vocabulary tokens to lists of equivalent token ids. 
+ + Returns: + A Dict of token string -> equivalent token ids """ - if getattr(tokenizer, "_outlines_adapted", False): - return tokenizer - - tokenizer = copy.deepcopy(tokenizer) - - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) + unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()} def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE string = tokenizer.convert_tokens_to_string([token]) @@ -251,21 +191,122 @@ def convert_token_to_string(token: str) -> str: return string - def change_decoder( - decoder: Callable[[list[int]], - str]) -> Callable[[list[int]], list[str]]: - """Sync vLLM's decoder with the outlines by returning list.""" + vocabulary: dict[bytes, list[int]] = {} + empty_token_ids: list[int] = [] + for token, token_idx in tokenizer.get_vocab().items(): + if token in tokenizer.all_special_tokens: # type: ignore + continue + + token_str = convert_token_to_string(token) + if token_str: + if isinstance(token, (bytes, bytearray)): + # For BPE tokenizers where tokens are stored as bytes. + + # safe to ignore since token_str is of type (bytearray, bytes) + # by this point. + token_bytes = bytes(token_str) # type: ignore[arg-type] + + elif "\ufffd" in token_str and not re_replacement_seq.match(token): + # Handle tokens with invalid UTF-8 sequences. + if re_llama_byte_token.match(token): + # Llama-like tokenizers use <0xXX> for incomplete sequences. + token_bytes = bytes([int(token[3:5], 16)]) + else: + # GPT2 tokenizers: map each byte back using unicode_to_bytes + byte_vals = [unicode_to_bytes.get(c) for c in token] + if None in byte_vals: + raise RuntimeError( + f"Cannot convert token `{token}`" + f" ({token_idx}) to bytes: {token_str}") + # safe to ignore, since if None in byte_vals, + # an error is thrown. + token_bytes = bytes(byte_vals) # type: ignore[arg-type] + else: + token_bytes = token_str.encode('utf-8') - def new_decoder(inp_tokens: list[int]) -> list[str]: - if (isinstance(inp_tokens, list) and len(inp_tokens) == 1 - and isinstance(inp_tokens[0], list)): - inp_tokens = inp_tokens[0] - return [decoder(inp_tokens)] + if token_idx != eos_token_id: + vocabulary.setdefault(token_bytes, []).append(token_idx) + else: + empty_token_ids.append(token_idx) - return new_decoder + return vocabulary - tokenizer.convert_token_to_string = convert_token_to_string - tokenizer.decode = change_decoder(tokenizer.decode) - setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 - return tokenizer +def get_vocabulary(tokenizer: AnyTokenizer) -> Vocabulary: + """Get the `Vocabulary` object for a given tokenizer. + """ + if hasattr(tokenizer, "_outlines_vocabulary"): + return tokenizer._outlines_vocabulary # type: ignore + + try: + if hasattr( + tokenizer, + "eos_token_id", + ) and tokenizer.eos_token_id is not None: + eos_token_id = tokenizer.eos_token_id + else: + raise ValueError( + f"Error during guided decoding setup: Tokenizer" + f" ({type(tokenizer)}) has no `eos_token_id` property, " + "but `eos_token_id` is required for guided decoding" + " to work properly.") + + reduced_vocab = _reduced_vocabulary( + tokenizer, + eos_token_id #type: ignore + ) + vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id, + reduced_vocab)) + tokenizer._outlines_vocabulary = vocabulary # type: ignore + + return vocabulary + except AttributeError as e: + raise ValueError(f"Cannot get the vocabulary of the tokenizer " + f"({type(tokenizer)}). 
The tokenizer should have a " + "get_vocab method.") from e + + +def get_cache_path() -> str: + """Get the directory path used to cache compiled guide indexes""" + outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR") + xdg_cache_home = os.getenv("XDG_CACHE_HOME") + home_dir = os.path.expanduser("~") + + if outlines_cache_dir: + # OUTLINES_CACHE_DIR takes precedence + return outlines_cache_dir + elif xdg_cache_home: + return os.path.join(xdg_cache_home, ".cache", "outlines") + # If homedir is "/", we may be inside a container, and thus writing to + # root would be problematic, so we fall back to the system temp directory. + # Also validate the path exists, since os.path.expanduser does + # not guarantee existence. + elif os.path.isdir(home_dir) and home_dir != "/": + # Default Unix fallback: ~/.cache/outlines + return os.path.join(home_dir, ".cache", "outlines") + else: + import tempfile + + # home_dir may be / inside a docker container without an existing user + tempdir = tempfile.gettempdir() + return os.path.join(tempdir, ".cache", "outlines") + + +def get_cache(): + """Get the Cache instance to be used for index caching""" + + cache_dir = get_cache_path() + if envs.VLLM_V0_USE_OUTLINES_CACHE: + logger.warning("Enabling outlines cache. This is an unbounded on-disk " + "cache. It may consume a lot of disk space and should " + "not be used with untrusted clients.") + cache = Cache(cache_dir, eviction_policy="none", cull_limit=0) + outlines_version = importlib.metadata.version("outlines_core") + + cached_version = cache.get('__version__', None) + if cached_version != outlines_version: + cache.clear() + cache.set('__version__', outlines_version) + return cache + else: + return LRUCache(maxsize=128) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 64a75614878..28c7b3320b4 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -22,6 +22,8 @@ from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) +from vllm.v1.structured_output.backend_outlines import ( + validate_structured_output_request_outlines) from vllm.v1.structured_output.backend_xgrammar import ( validate_xgrammar_grammar) @@ -181,6 +183,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None: # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. validate_guidance_grammar(params, tokenizer=None) + elif engine_level_backend == "outlines": + # outlines backend + validate_structured_output_request_outlines(params) else: # NOTE: engine_level_backend must be "auto" here, because we have # checked supported_backends above. 
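With the validation hook above in place, the outlines backend can be selected for guided decoding on the V1 engine. A minimal usage sketch, modeled on the updated tests/entrypoints/llm/test_guided_generate.py; the model name and regex below are illustrative assumptions, not values taken from this diff:

# Sketch: request regex-constrained output through the outlines backend.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", seed=0)
sampling_params = SamplingParams(
    temperature=0.8,
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(
        regex=r"\d{1,3}(\.\d{1,3}){3}",  # simple IPv4-shaped pattern
        backend="outlines"))
outputs = llm.generate(
    prompts=["Give an example IPv4 address:"],
    sampling_params=sampling_params)
print(outputs[0].outputs[0].text)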
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index c701ab1d35a..bddccbcfc30 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -82,6 +82,15 @@ def grammar_init(self, request: Request) -> None: tokenizer=self.tokenizer, vocab_size=vocab_size, ) + elif backend == "outlines": + from vllm.v1.structured_output.backend_outlines import ( + OutlinesBackend) + + self.backend = OutlinesBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend}") diff --git a/vllm/v1/structured_output/backend_outlines.py b/vllm/v1/structured_output/backend_outlines.py new file mode 100644 index 00000000000..9256ed71db2 --- /dev/null +++ b/vllm/v1/structured_output/backend_outlines.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import ast +import json +import re +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch + +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + OutlinesVocabulary, get_cache, get_vocabulary) +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) + +if TYPE_CHECKING: + import outlines_core as oc +else: + oc = LazyLoader("oc", globals(), "outlines_core") + +# Python 3.11+ sre_parse and sre_constants +# are deprecated, so we must import them from re +if sys.version_info >= (3, 11): + from re import _constants as sre_constants # type: ignore[attr-defined] + from re import _parser as sre_parse # type: ignore[attr-defined] +else: + import sre_constants + import sre_parse + + +@dataclass +class OutlinesBackend(StructuredOutputBackend): + + def __post_init__(self): + self.vocabulary = get_vocabulary(self.tokenizer) + self.cache = get_cache() + + def _compile_index(self, regex_string: str, + vocabulary: OutlinesVocabulary) -> oc.Index: + cache_key = f"{vocabulary._hash}_{regex_string}" + if cache_key in self.cache: + return self.cache[cache_key] + + index = oc.Index(regex_string, vocabulary.inner) + self.cache[cache_key] = index + + return index + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + if request_type == StructuredOutputOptions.JSON: + regex = oc.json_schema.build_regex_from_schema(grammar_spec) + elif request_type == StructuredOutputOptions.REGEX: + regex = grammar_spec + elif request_type == StructuredOutputOptions.CHOICE: + choices = ast.literal_eval(grammar_spec) + choices = [re.escape(c) for c in choices] + regex = "(" + "|".join(choices) + ")" + else: + raise ValueError( + f"Invalid request type for Outlines backend ({request_type!s})" + ) + index = self._compile_index(regex, self.vocabulary) + return OutlinesGrammar( + vocab_size=self.vocab_size, + guide=oc.Guide(index, + max_rollback=self.vllm_config.speculative_config. 
+ num_speculative_tokens)) + + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + return torch.full( + (max_num_seqs, (self.vocab_size + 31) // 32), + -1, + dtype=torch.int32, + pin_memory=torch.cuda.is_available(), + ) + + def destroy(self): + pass + + +@dataclass +class OutlinesGrammar(StructuredOutputGrammar): + + vocab_size: int + guide: oc.Guide = field(hash=False) + num_processed_tokens: int = field(default_factory=lambda: 0, + repr=False, + hash=False, + init=False) + + # outlines_core signals done on DFA accept; vLLM expects done after EOS. + # We delay the finished flag by one step so EOS can still be emitted. + _prev_finished: bool = field(default=False, + init=False, + repr=False, + hash=False) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """Accepts a list of tokens and advances the FSM. + + Returns True if the FSM was advanced successfully. + Returns False if the FSM failed to advance. + """ + if self.guide.accepts_tokens(tokens): + # Advance cannot fail because we checked Guide.accepts_tokens() + for t in tokens: + self.guide.advance(t) + return True + return False + + def rollback(self, num_tokens: int) -> None: + self.guide.rollback_state(num_tokens) + self.num_processed_tokens -= num_tokens + + def validate_tokens(self, tokens: list[int]) -> list[int]: + accepted: list[int] = [] + for tok in tokens: + accepted.append(tok) + if not self.guide.accepts_tokens(accepted): + accepted.pop() + break + return accepted + + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + mask = bitmask[idx] + self.guide.write_mask_into(mask.data_ptr(), mask.numel(), + mask.element_size()) + + def is_terminated(self) -> bool: + curr = self.guide.is_finished() + prev = self._prev_finished + self._prev_finished = curr + return prev + + def reset(self): + self.num_processed_tokens = 0 + self._prev_finished = False + self.guide.reset() + + +def validate_structured_output_request_outlines(params: SamplingParams): + if params.guided_decoding is None: + return + + gd_params = params.guided_decoding + + if gd_params.regex: + validate_regex_is_buildable(gd_params.regex) + elif gd_params.json: + if isinstance(gd_params.json, str): + try: + schema = json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + schema = gd_params.json + pattern = oc.json_schema.build_regex_from_schema(schema) + validate_regex_is_buildable(pattern) + elif gd_params.choice: + choices = [re.escape(str(choice)) for choice in gd_params.choice] + regex = "(" + "|".join(choices) + ")" + validate_regex_is_buildable(regex) + elif gd_params.grammar: + raise ValueError("Outlines guided decoding backend " + "does not support grammar specifications") + + +def _prefix_needs_context(parsed) -> bool: + """Return True if there's a look-around/anchor before any consumer.""" + + def subpattern_consumes(parsed) -> bool: + """Return True if subpattern can consume at least one character.""" + tokens = parsed.data if hasattr(parsed, 'data') else parsed + for ttype, tval in tokens: + # literal, character class, or dot always consumes + if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY): + return True + # quantified subpattern: check inner pattern + elif ttype == sre_parse.MAX_REPEAT: + _, mx, sub = tval + if mx != 0 and subpattern_consumes(sub): + return True + # alternation: if any branch consumes, the whole does + elif ttype == sre_parse.BRANCH: + _, branches = tval + if any(subpattern_consumes(br) 
for br in branches): + return True + # grouped subpattern: recurse into its contents + elif ttype == sre_parse.SUBPATTERN and subpattern_consumes( + tval[3]): + return True + # No consumers, return False + return False + + tokens = parsed.data if hasattr(parsed, 'data') else parsed + for ttype, tval in tokens: + # Direct anchors or look-around + if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT, + sre_constants.ASSERT_NOT): + return True + + # Nested subpattern: check + if ttype == sre_parse.SUBPATTERN: + # tval: (group, add_flags, del_flags, subpattern) + if _prefix_needs_context(tval[3]): + return True + if subpattern_consumes(tval[3]): + return False + + # if any branch has a prefix anchor => True, + # else if at least one branch consumes => prefix ends => False + elif ttype == sre_parse.BRANCH: + saw_consumer = False + for br in tval[1]: + if _prefix_needs_context(br): + return True + if subpattern_consumes(br): + saw_consumer = True + if saw_consumer: + return False + + # Immediate consumer tokens + elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY): + return False + + # if subpattern has anchor => True, if it can consume => stop + elif ttype == sre_parse.MAX_REPEAT: + if _prefix_needs_context(tval[2]): + return True + if subpattern_consumes(tval[2]): + return False + + return False + + +def _check_unsupported(parsed) -> None: + """Check for regex features unsupported by regex-automata""" + tokens = parsed.data if hasattr(parsed, 'data') else parsed + for ttype, tval in tokens: + + # backreference + if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS): + raise ValueError("Backreferences are unsupported.") + + # look-around assertion + elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT): + raise ValueError("Look-around assertions are unsupported.") + + # unicode word boundaries + elif ttype == sre_parse.AT: + if tval in (sre_constants.AT_BOUNDARY, + sre_constants.AT_NON_BOUNDARY): + raise ValueError("Unicode word boundaries are unsupported.") + + elif ttype == sre_parse.BRANCH: + # tval is (None, branches) + for branch in tval[1]: + _check_unsupported(branch) + + # tval is (min, max, subpattern) + elif ttype == sre_parse.MAX_REPEAT: + _check_unsupported(tval[2]) + + +def validate_regex_is_buildable(pattern: str) -> None: + """ + Validates that the input regex is not using unsupported features + of the `regex-automata` crate (outlines_core regex engine), and has a + universal start state. + The definition of universal start state used can be found at: + https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state + """ + try: + parsed = sre_parse.parse(pattern) + + except sre_constants.error as e: + raise ValueError(f"Error parsing regex: {e}") from e + + try: + _check_unsupported(parsed) + except ValueError as e: + raise ValueError( + f"Regex uses an unsupported feature for guided decoding: {e}. " + "Only basic matching constructs are supported; lookarounds, " + "backreferences, and unicode boundaries are not.") from e + + if _prefix_needs_context(parsed): + raise ValueError( + "Regex does not have an anchored universal start state. " + "This means that the regex uses anchors (^) or look-arounds " + "in a way which requires context before any token is matched. " + "Guided decoding needs regexes that can match without needing " + "that context. Try rewriting the pattern without using these " + f"constructs. Pattern:\n{pattern}")
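For reference, a rough sketch of how the validator added above behaves, assuming the module is importable at the path introduced in this diff; the concrete patterns are illustrative only:

# Plain matching constructs pass validation; look-arounds, backreferences and
# unicode word boundaries are rejected before any outlines_core Index is built.
from vllm.v1.structured_output.backend_outlines import (
    validate_regex_is_buildable)

validate_regex_is_buildable(r"\d{1,3}(\.\d{1,3}){3}")  # accepted
try:
    validate_regex_is_buildable(r"(?=foo)\w+")  # look-ahead assertion
except ValueError as exc:
    print(f"rejected: {exc}")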