
Commit f60e62d

feat: support parsing thinking tokens
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
1 parent 7a77454 commit f60e62d

File tree

5 files changed: +109, -24 lines

tests/v1/entrypoints/llm/test_struct_output_generate.py
Lines changed: 72 additions & 11 deletions

@@ -15,14 +15,37 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
-PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
-     "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "guidance:disable-any-whitespace",
-     "auto"),
-    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar:disable-any-whitespace",
-     "mistral"),
-    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar:disable-any-whitespace", "auto"),
+PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER = [
+    (
+        "mistralai/Ministral-8B-Instruct-2410",
+        "xgrammar:disable-any-whitespace",
+        "auto",
+        None,
+    ),
+    (
+        "mistralai/Ministral-8B-Instruct-2410",
+        "guidance:disable-any-whitespace",
+        "auto",
+        None,
+    ),
+    (
+        "mistralai/Ministral-8B-Instruct-2410",
+        "xgrammar:disable-any-whitespace",
+        "mistral",
+        None,
+    ),
+    (
+        "Qwen/Qwen2.5-1.5B-Instruct",
+        "xgrammar:disable-any-whitespace",
+        "auto",
+        None,
+    ),
+    (
+        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+        "xgrammar:disable-any-whitespace",
+        "auto",
+        "deepseek_r1",
+    ),
     #FIXME: This test is flaky on CI thus disabled
     #("Qwen/Qwen2.5-1.5B-Instruct", "guidance:disable-any-whitespace", "auto"),
 ]
@@ -47,8 +70,9 @@ class CarDescription(BaseModel):
 
 
 @pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode",
-                         PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
+@pytest.mark.parametrize(
+    "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser",
+    PARAMS_MODELS_BACKENDS_TOKENIZER_MODE_REASONING_PARSER)
 def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
@@ -59,6 +83,7 @@ def test_structured_output(
     sample_guided_choice: str,
     guided_decoding_backend: str,
     tokenizer_mode: str,
+    reasoning_parser: str | None,
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
@@ -69,7 +94,9 @@ def test_structured_output(
               enforce_eager=True,
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend,
-              tokenizer_mode=tokenizer_mode)
+              tokenizer_mode=tokenizer_mode,
+              enable_reasoning=reasoning_parser is not None,
+              reasoning_parser=reasoning_parser)
 
     #
     # Test 1: Generate JSON output based on a provided schema
@@ -364,6 +391,40 @@ def test_structured_output(
     output_json = json.loads(generated_text)
     jsonschema.validate(instance=output_json, schema=json_schema)
 
+    #
+    # Test 11: Generate structured output with reasoning step
+    #
+    if reasoning_parser is not None:
+        reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Problem: What is 5 * 8 + 2?"  # noqa: E501
+        reasoning_schema = {
+            "type": "object",
+            "properties": {
+                "result": {
+                    "type": "integer"
+                }
+            },
+            "required": ["result"]
+        }
+
+        sampling_params = SamplingParams(
+            temperature=0.1,  # Low temp for deterministic reasoning
+            max_tokens=200,
+            guided_decoding=GuidedDecodingParams(json=reasoning_schema))
+        outputs = llm.generate(prompts=[reasoning_prompt],
+                               sampling_params=sampling_params,
+                               use_tqdm=True)
+
+        assert outputs is not None
+        output = outputs[0]
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+        output_json = json.loads(generated_text)
+        jsonschema.validate(instance=output_json, schema=reasoning_schema)
+
 
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("model_name, tokenizer_mode",
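
For orientation, a minimal usage sketch of the new keywords outside the test harness. It simply mirrors the arguments the test above passes to LLM() and SamplingParams(); the model choice and schema are illustrative rather than prescriptive.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Mirrors the new parametrization: an R1-distilled model with the deepseek_r1
# reasoning parser and xgrammar-backed structured output.
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
          guided_decoding_backend="xgrammar:disable-any-whitespace",
          enable_reasoning=True,
          reasoning_parser="deepseek_r1")

schema = {
    "type": "object",
    "properties": {"result": {"type": "integer"}},
    "required": ["result"],
}
params = SamplingParams(max_tokens=200,
                        guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate(["What is 5 * 8 + 2? Reply as JSON."], params)
# The schema is enforced only after the model's thinking section has ended.
print(outputs[0].outputs[0].text)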

vllm/reasoning/abs_reasoning_parsers.py
Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ class ReasoningParserManager:
     reasoning_parsers: dict[str, type] = {}
 
     @classmethod
-    def get_reasoning_parser(cls, name: str) -> type[ReasoningParser]:
+    def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
         """
         Get reasoning parser by name which is registered by `register_module`.
 

vllm/v1/core/sched/scheduler.py
Lines changed: 14 additions & 5 deletions

@@ -653,11 +653,20 @@ def update_from_output(
             new_logprobs = logprobs.slice(req_index, req_index + 1)
 
             if new_token_ids and request.use_structured_output:
-                # NOTE: structured_output_request
-                # should not be None if use_structured_output, we have
-                # check above, so safe to ignore type warning
-                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
-                    req_id, new_token_ids)
+                advance_fsm = False
+                reasoner = self.structured_output_manager.reasoner
+                if reasoner is None or request.reasoning_ended:
+                    advance_fsm = True
+                elif reasoner.is_reasoning_end(request.all_token_ids):
+                    request.reasoning_ended = True
+                    advance_fsm = True
+
+                if advance_fsm:
+                    # NOTE: structured_output_request
+                    # should not be None if use_structured_output, we have
+                    # check above, so safe to ignore type warning
+                    request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
+                        req_id, new_token_ids)
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
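
The scheduler change gates FSM advancement on the end of the reasoning section: while the model is still emitting thinking tokens, no tokens are fed to the grammar; once the reasoner reports that reasoning has ended, the request is marked and the FSM advances as before. A self-contained sketch of that gate, using a stub reasoner and a placeholder end-of-think token id (both hypothetical, for illustration only):

class StubReasoner:
    # Hypothetical: reasoning ends once a designated end-of-think token id appears.
    END_THINK_ID = 999

    def is_reasoning_end(self, token_ids: list[int]) -> bool:
        return self.END_THINK_ID in token_ids


def should_advance_fsm(reasoner, reasoning_ended: bool,
                       all_token_ids: list[int]) -> tuple[bool, bool]:
    # Mirrors the scheduler logic above: returns (advance_fsm, new reasoning_ended flag).
    if reasoner is None or reasoning_ended:
        return True, reasoning_ended
    if reasoner.is_reasoning_end(all_token_ids):
        return True, True   # first step after the reasoning section closes
    return False, False     # still thinking: don't advance the grammar


# Still inside the thinking section: the grammar is left untouched.
assert should_advance_fsm(StubReasoner(), False, [1, 2, 3]) == (False, False)
# End-of-think marker seen: the request is marked and the FSM starts accepting tokens.
assert should_advance_fsm(StubReasoner(), False, [1, 2, 999]) == (True, True)
# No reasoner configured: behaviour is unchanged from before this commit.
assert should_advance_fsm(None, False, [1, 2, 3]) == (True, False)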

vllm/v1/request.py
Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ def __init__(
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
         self.structured_output_request = structured_output_request
+        self.reasoning_ended: bool = False
 
         self.status = (RequestStatus.WAITING_FOR_FSM
                        if sampling_params.guided_decoding is not None else

vllm/v1/structured_output/__init__.py
Lines changed: 21 additions & 7 deletions

@@ -119,15 +119,29 @@ def grammar_bitmask(
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
         bitmask_tensor = self._grammar_bitmask
+        # Reset the relevant part of the bitmask before filling
+        if batch_len > 0:
+            bitmask_tensor[:batch_len].fill_(-1)
+
         for req_id, batch_index in structured_output_request_ids.items():
-            request = requests[req_id].structured_output_request
-            assert request is not None and request.grammar is not None
-            if not request.grammar.is_terminated():
-                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
-        if batch_len < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:batch_len]
+            full_request = requests[req_id]
+            so_request = full_request.structured_output_request
+            assert so_request is not None and so_request.grammar is not None
+
+            apply_bitmask = (self.reasoner is None
+                             or full_request.reasoning_ended
+                             or self.reasoner.is_reasoning_end(
+                                 full_request.all_token_ids))
+
+            if apply_bitmask and not so_request.grammar.is_terminated():
+                so_request.grammar.fill_bitmask(bitmask_tensor, batch_index)
+
+        if batch_len < bitmask_tensor.shape[0]:
+            final_bitmask_tensor = bitmask_tensor[:batch_len]
+        else:
+            final_bitmask_tensor = bitmask_tensor
 
         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
         # and deserialization when sending this to the GPU workers.
-        return bitmask_tensor.numpy()
+        return final_bitmask_tensor.numpy()
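
Two things happen in this hunk: the live slice of the bitmask is reset to -1 before any grammar writes, and fill_bitmask() is only applied once the request's reasoning section has ended. Assuming xgrammar's packed-bitmask convention (one bit per vocabulary token stored in int32 words, set bit = token allowed), a row of -1 has every bit set, so during the thinking phase no token is masked out. A small sketch of that convention; the vocabulary size and helper are illustrative:

import torch

VOCAB_SIZE = 152_064                    # illustrative
NUM_WORDS = (VOCAB_SIZE + 31) // 32     # one bit per token, packed into int32 words


def token_allowed(row: torch.Tensor, token_id: int) -> bool:
    # Read the packed bit for a single token id.
    word = int(row[token_id // 32].item())
    return bool((word >> (token_id % 32)) & 1)


row = torch.empty(NUM_WORDS, dtype=torch.int32)
row.fill_(-1)                           # -1 == all 32 bits set (two's complement)
assert token_allowed(row, 0)
assert token_allowed(row, VOCAB_SIZE - 1)
# During reasoning the row stays at -1 (nothing masked); after reasoning ends,
# grammar.fill_bitmask() overwrites it with the schema-constrained bits.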
