Skip to content

Commit 08bf784

Browse files
Jason-CKYrussellb
andauthored
[Bugfix] validate grammar and throw 400 error instead of crashing the engine when xgrammar validation fails (#17623)
Signed-off-by: Jason Cheng <jasoncky96@gmail.com> Co-authored-by: Russell Bryant <rbryant@redhat.com>
1 parent d45fe33 commit 08bf784

File tree

4 files changed

+240
-1
lines changed

4 files changed

+240
-1
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import openai # use the official client for correctness check
4+
import pytest
5+
import pytest_asyncio
6+
7+
from tests.utils import RemoteOpenAIServer
8+
9+
# any model with a chat template defined in tokenizer_config should work here
10+
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
11+
12+
13+
@pytest.fixture(scope="module")
14+
def default_server_args():
15+
return [
16+
# use half precision for speed and memory savings in CI environment
17+
"--max-model-len",
18+
"2048",
19+
"--max-num-seqs",
20+
"128",
21+
"--enforce-eager",
22+
]
23+
24+
25+
@pytest.fixture(scope="module")
26+
def server(default_server_args):
27+
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
28+
yield remote_server
29+
30+
31+
@pytest_asyncio.fixture
32+
async def client(server):
33+
async with server.get_async_client() as async_client:
34+
yield async_client
35+
36+
37+
@pytest.mark.asyncio
38+
@pytest.mark.parametrize(
39+
"model_name",
40+
[MODEL_NAME],
41+
)
42+
async def test_invalid_json_schema(client: openai.AsyncOpenAI,
43+
model_name: str) -> None:
44+
invalid_json_schema = {
45+
"$defs": {
46+
"CarType": {
47+
"enum": ["sedan", "SUV", "Truck", "Coupe"],
48+
"title": "CarType",
49+
"type": "string",
50+
}
51+
},
52+
"properties": {
53+
"brand": {
54+
"title": "Brand",
55+
"type": "string"
56+
},
57+
"model": {
58+
"title": "Model",
59+
"type": "string"
60+
},
61+
"car_type": {
62+
"$ref": "#/$defs/CarType"
63+
},
64+
"foo": "bar",
65+
},
66+
"required": ["brand", "model", "car_type"],
67+
"title": "CarDescription",
68+
"type": "object",
69+
}
70+
prompt = ("Generate a JSON with the brand, model and car_type of"
71+
"the most iconic car from the 90's")
72+
with pytest.raises((openai.BadRequestError, openai.APIError)):
73+
await client.chat.completions.create(
74+
model=model_name,
75+
messages=[{
76+
"role": "user",
77+
"content": prompt,
78+
}],
79+
extra_body={"guided_json": invalid_json_schema},
80+
)
81+
82+
83+
@pytest.mark.asyncio
84+
@pytest.mark.parametrize(
85+
"model_name",
86+
[MODEL_NAME],
87+
)
88+
async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
89+
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
90+
"End in .com and new line. Example result:"
91+
"alan.turing@enigma.com\n")
92+
93+
with pytest.raises((openai.BadRequestError, openai.APIError)):
94+
await client.chat.completions.create(
95+
model=model_name,
96+
messages=[{
97+
"role": "user",
98+
"content": prompt,
99+
}],
100+
extra_body={
101+
"guided_regex": r"[.*",
102+
"stop": ["\n"]
103+
},
104+
)
105+
106+
107+
@pytest.mark.asyncio
108+
@pytest.mark.parametrize(
109+
"model_name",
110+
[MODEL_NAME],
111+
)
112+
async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
113+
invalid_simplified_sql_grammar = """
114+
root ::= select_statementinvalidsyntax
115+
116+
select_statement ::= "SELECT " column " from " table " where " condition
117+
118+
column ::= "col_1 " | "col_2 "
119+
120+
table ::= "table_1 " | "table_2 "
121+
122+
condition ::= column "= " number
123+
124+
number ::= "1 " | "2 "
125+
"""
126+
127+
prompt = ("Generate an SQL query to show the 'username' and 'email'"
128+
"from the 'users' table.")
129+
with pytest.raises((openai.BadRequestError, openai.APIError)):
130+
await client.chat.completions.create(
131+
model=model_name,
132+
messages=[{
133+
"role": "user",
134+
"content": prompt,
135+
}],
136+
extra_body={"guided_grammar": invalid_simplified_sql_grammar},
137+
)

tests/v1/entrypoints/openai/test_completion.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,3 +584,97 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
584584
assert max(logprobs_arg,
585585
1) <= len(top_logprobs) <= logprobs_arg + 1
586586
assert len(logprobs.tokens) > 5
587+
588+
589+
@pytest.mark.asyncio
590+
@pytest.mark.parametrize(
591+
"model_name",
592+
[MODEL_NAME],
593+
)
594+
async def test_invalid_json_schema(client: openai.AsyncOpenAI,
595+
model_name: str) -> None:
596+
invalid_json_schema = {
597+
"$defs": {
598+
"CarType": {
599+
"enum": ["sedan", "SUV", "Truck", "Coupe"],
600+
"title": "CarType",
601+
"type": "string",
602+
}
603+
},
604+
"properties": {
605+
"brand": {
606+
"title": "Brand",
607+
"type": "string"
608+
},
609+
"model": {
610+
"title": "Model",
611+
"type": "string"
612+
},
613+
"car_type": {
614+
"$ref": "#/$defs/CarType"
615+
},
616+
"foo": "bar",
617+
},
618+
"required": ["brand", "model", "car_type"],
619+
"title": "CarDescription",
620+
"type": "object",
621+
}
622+
prompt = ("Generate a JSON with the brand, model and car_type of"
623+
"the most iconic car from the 90's")
624+
with pytest.raises((openai.BadRequestError, openai.APIError)):
625+
await client.completions.create(
626+
model=model_name,
627+
prompt=prompt,
628+
extra_body={"guided_json": invalid_json_schema},
629+
)
630+
631+
632+
@pytest.mark.asyncio
633+
@pytest.mark.parametrize(
634+
"model_name",
635+
[MODEL_NAME],
636+
)
637+
async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
638+
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
639+
"End in .com and new line. Example result:"
640+
"alan.turing@enigma.com\n")
641+
642+
with pytest.raises((openai.BadRequestError, openai.APIError)):
643+
await client.completions.create(
644+
model=model_name,
645+
prompt=prompt,
646+
extra_body={
647+
"guided_regex": r"[.*",
648+
"stop": ["\n"]
649+
},
650+
)
651+
652+
653+
@pytest.mark.asyncio
654+
@pytest.mark.parametrize(
655+
"model_name",
656+
[MODEL_NAME],
657+
)
658+
async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
659+
invalid_simplified_sql_grammar = """
660+
root ::= select_statementinvalidsyntax
661+
662+
select_statement ::= "SELECT " column " from " table " where " condition
663+
664+
column ::= "col_1 " | "col_2 "
665+
666+
table ::= "table_1 " | "table_2 "
667+
668+
condition ::= column "= " number
669+
670+
number ::= "1 " | "2 "
671+
"""
672+
673+
prompt = ("Generate an SQL query to show the 'username' and 'email'"
674+
"from the 'users' table.")
675+
with pytest.raises((openai.BadRequestError, openai.APIError)):
676+
await client.completions.create(
677+
model=model_name,
678+
prompt=prompt,
679+
extra_body={"guided_grammar": invalid_simplified_sql_grammar},
680+
)

vllm/v1/engine/processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,10 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
188188
validate_xgrammar_grammar(params)
189189
params.guided_decoding.backend = "xgrammar"
190190
except ValueError:
191-
# The request includes some jsonschema feature(s) that
191+
# The request either failed validation
192+
# or includes some jsonschema feature(s) that
192193
# are not supported in xgrammar. Fall back to guidance.
194+
validate_guidance_grammar(params, tokenizer=None)
193195
params.guided_decoding.backend = "guidance"
194196
# Remember that this backend was set automatically
195197
params.guided_decoding.backend_was_auto = True

vllm/v1/structured_output/backend_xgrammar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,12 @@ def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
282282
else:
283283
schema = gd_params.json
284284

285+
try:
286+
xgr.Grammar.from_json_schema(schema)
287+
except Exception as err:
288+
raise ValueError("Failed to transform json schema into a grammar: "
289+
f"{err}") from err
290+
285291
if has_xgrammar_unsupported_json_features(schema):
286292
raise ValueError("The provided JSON schema contains features not "
287293
"supported by xgrammar.")

0 commit comments

Comments
 (0)