
Commit 17bd581

WoosukKwon authored and jimpang committed
[V1][Spec decode] Move drafter to model runner (vllm-project#13363)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 6155ee9 commit 17bd581
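
In short, this commit moves n-gram draft-token generation out of EngineCore and into the GPU model runner: the runner now owns the NgramProposer (self.drafter), drafts tokens right after sampling, and returns them as ModelRunnerOutput.spec_token_ids, which the scheduler attaches to each request in update_from_output(). A rough sketch of the new per-step flow, condensed from the diffs below (not exact code):

# One EngineCore.step() after this change (simplified):
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
#   Inside GPUModelRunner.execute_model():
#     spec_token_ids = (self.generate_draft_token_ids(valid_sampled_token_ids)
#                       if self.use_spec_decode else None)
#     return ModelRunnerOutput(..., spec_token_ids=spec_token_ids, ...)
engine_core_outputs = self.scheduler.update_from_output(scheduler_output, output)
#   Which now does, per running request:
#     request.spec_token_ids = spec_token_ids[req_index]
# The drafted tokens are fed back on the next schedule() via
# scheduled_spec_decode_tokens and verified on the next forward pass.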

File tree: 9 files changed, +84 −57 lines

tests/v1/core/test_scheduler.py

Lines changed: 7 additions & 0 deletions
@@ -203,6 +203,7 @@ def test_schedule_partial_requests():
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
         sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
         sampled_token_ids=[[EOS_TOKEN_ID],
                            [10,
                             11]],  # First request hits EOS, second continues
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
 
@@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
         },
         sampled_token_ids=[[10, 42, 12],
                            [13, 14]],  # First request hits stop token
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
 
@@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
         },
         sampled_token_ids=[[10, 11, 12],
                            [13]],  # First request exceeds max_tokens
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
 
@@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
         sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={})
 
@@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
         sampled_token_ids=[[0]],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
         sampled_token_ids=[[0]],
+        spec_token_ids=None,
         logprobs=None,
         prompt_logprobs_dict={},
     )

vllm/v1/core/scheduler.py

Lines changed: 4 additions & 7 deletions
@@ -474,6 +474,7 @@ def update_from_output(
         model_runner_output: "ModelRunnerOutput",
     ) -> EngineCoreOutputs:
         sampled_token_ids = model_runner_output.sampled_token_ids
+        spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -530,13 +531,9 @@ def update_from_output(
                         self.encoder_cache_manager.free_encoder_input(
                             request, input_id)
 
-            if request.num_computed_tokens >= request.num_tokens:
-                # Clear the spec tokens as the request has generated
-                # a new token. Here, We assume all spec tokens are verified
-                # if we perform speculative decoding for this request.
-                # Therefore, we can clear all spec tokens after
-                # the generation step.
-                request.clear_spec_tokens()
+            # Add newly generated spec token ids to the request.
+            if spec_token_ids is not None:
+                request.spec_token_ids = spec_token_ids[req_index]
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)

vllm/v1/engine/core.py

Lines changed: 0 additions & 30 deletions
@@ -27,7 +27,6 @@
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
-from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -86,15 +85,6 @@ def __init__(
                                     self.batch_queue_size)
             self.batch_queue = queue.Queue(self.batch_queue_size)
 
-        # Setup speculative decode.
-        # TODO: find a better way to check if we are using ngram.
-        self.use_spec_decode = False
-        if self.scheduler.speculative_config:
-            assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
-                , "Only ngram spec decode is supported in V1."
-            self.proposer = NgramProposer()
-            self.use_spec_decode = True
-
     def _initialize_kv_caches(self,
                               vllm_config: VllmConfig) -> Tuple[int, int]:
         start = time.time()
@@ -158,9 +148,6 @@ def step(self) -> EngineCoreOutputs:
             return EngineCoreOutputs(
                 outputs=[], scheduler_stats=self.scheduler.make_stats())
 
-        if self.use_spec_decode:
-            self.propose_tokens()
-
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
@@ -221,23 +208,6 @@ def shutdown(self):
     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)
 
-    def propose_tokens(self):
-        assert self.scheduler.speculative_config is not None
-        for req in self.scheduler.running:
-            # Ignore requests that are doing chunked prefill.
-            if req.num_computed_tokens < req.num_tokens - 1:
-                continue
-            # Ignore requests that already have spec tokens.
-            if req.spec_token_ids:
-                continue
-            spec_tokens = self.proposer.propose(
-                req.all_token_ids,
-                self.scheduler.speculative_config.ngram_prompt_lookup_min,
-                self.scheduler.speculative_config.num_speculative_tokens,
-            )
-            if spec_tokens:
-                req.append_spec_token_ids(spec_tokens)
-
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
 

vllm/v1/outputs.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ class ModelRunnerOutput:
     # each request due to speculative/jump decoding.
     sampled_token_ids: List[List[int]]
 
+    # num_reqs x num_spec_tokens
+    spec_token_ids: Optional[List[List[int]]]
+
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs]

vllm/v1/request.py

Lines changed: 0 additions & 12 deletions
@@ -104,18 +104,6 @@ def append_output_token_ids(
         self._output_token_ids.extend(token_ids)
         self._all_token_ids.extend(token_ids)
 
-    def append_spec_token_ids(
-        self,
-        token_ids: Union[int, List[int]],
-    ) -> None:
-        if isinstance(token_ids, int):
-            self.spec_token_ids.append(token_ids)
-        else:
-            self.spec_token_ids.extend(token_ids)
-
-    def clear_spec_tokens(self) -> None:
-        self.spec_token_ids.clear()
-
     @property
     def num_tokens(self) -> int:
         return len(self._all_token_ids)

vllm/v1/spec_decode/ngram_proposer.py

Lines changed: 15 additions & 8 deletions
@@ -1,16 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import List, Optional
 
-from vllm.v1.utils import ConstantList
+import numpy as np
 
 
 class NgramProposer:
 
     def __init__(self):
         pass
 
-    def propose(self, context_token_ids: ConstantList[int], n: int,
-                k: int) -> Optional[List[int]]:
+    def propose(
+        self,
+        context_token_ids: np.ndarray,
+        n: int,
+        k: int,
+    ) -> Optional[np.ndarray]:
         """Proposes the next sequence of tokens based on n-gram pattern
         matching in the context. The function finds matches of the last n
         tokens in the previous context, and returns k tokens that followed
@@ -25,8 +29,8 @@ def propose(self, context_token_ids: ConstantList[int], n: int,
             the maximum amount of tokens until the end.
 
         Returns:
-            List[int]: The sequence of tokens that followed
-                the matched n-gram in the context.
+            np.ndarray: The sequence of tokens that followed
+                the matched n-gram in the context.
             None: If no matching n-gram pattern is found.
 
         Example:
@@ -66,9 +70,12 @@ def _kmp_lps_array(pattern: List[int]) -> List[int]:
         return lps
 
     @staticmethod
-    def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
-                           k: int) -> Optional[List[int]]:
-        context_len = len(context_token_ids)
+    def _find_subarray_kmp(
+        context_token_ids: np.ndarray,
+        n: int,
+        k: int,
+    ) -> Optional[np.ndarray]:
+        context_len = context_token_ids.shape[0]
         assert n > 0
 
         pattern = context_token_ids[-n:]

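With this change the proposer works directly on numpy arrays instead of ConstantList. A minimal usage sketch (the token values are invented for illustration; the exact result depends on which earlier occurrence the KMP search picks):

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer

proposer = NgramProposer()
# The context ends in the 2-gram [10, 20], which also appears at the start,
# followed there by [30, 40]; with n=2 and k=2 those should be the draft.
context = np.array([10, 20, 30, 40, 10, 20])
draft = proposer.propose(context, n=2, k=2)  # expected: array([30, 40]); None if no match
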
vllm/v1/worker/gpu_input_batch.py

Lines changed: 7 additions & 0 deletions
@@ -78,6 +78,7 @@ def __init__(
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
+        self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
 
@@ -217,7 +218,11 @@ def add_request(
         end_idx = start_idx + len(request.output_token_ids)
         self.token_ids_cpu[req_index,
                            start_idx:end_idx] = request.output_token_ids
+        # Number of token ids in token_ids_cpu.
+        # NOTE(woosuk): This may include spec decode tokens.
         self.num_tokens[req_index] = request.num_tokens
+        # Number of tokens without spec decode tokens.
+        self.num_tokens_no_spec[req_index] = request.num_tokens
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
         self.block_table.add_row(req_index, request.block_ids)
@@ -356,6 +361,8 @@ def condense(self, empty_req_indices: List[int]) -> None:
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
             self.num_tokens[empty_index] = num_tokens
+            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
+                last_req_index]
             self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                 last_req_index]
             self.num_computed_tokens_cpu[

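The new num_tokens_no_spec counter exists because token_ids_cpu can temporarily hold draft tokens: per the comments above, num_tokens counts everything written into the buffer (possibly including spec decode tokens), while num_tokens_no_spec tracks only prompt plus verified output tokens, and therefore marks where the next sampled tokens should be written. A small illustration with invented values:

import numpy as np

# One request slot: 4 prompt tokens, 1 verified output token, and 2 draft
# tokens (900, 901) that were scheduled for verification this step.
row = np.array([101, 102, 103, 104, 7, 900, 901], dtype=np.int32)  # token_ids_cpu[i]
num_tokens = 7           # everything in the buffer, drafts included
num_tokens_no_spec = 5   # prompt + verified output only
# Newly sampled tokens for this request are written starting at index
# num_tokens_no_spec, overwriting the draft slots (see generate_draft_token_ids
# in gpu_model_runner.py below).
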
vllm/v1/worker/gpu_model_runner.py

Lines changed: 47 additions & 0 deletions
@@ -33,6 +33,7 @@
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
+from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -117,6 +118,15 @@ def __init__(
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
 
+        # Set up speculative decoding.
+        self.use_spec_decode = False
+        if self.speculative_config:
+            # TODO: find a better way to check if we are using ngram.
+            assert self.speculative_config.ngram_prompt_lookup_min, \
+                "Currently, only ngram spec decode is supported in V1."
+            self.drafter = NgramProposer()
+            self.use_spec_decode = True
+
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
         # Persistent batch.
@@ -367,6 +377,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             self.input_batch.token_ids_cpu[
                 req_index,
                 start_token_index:end_token_index] = req_data.new_token_ids
+            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                 req_id, [])
@@ -1009,15 +1020,51 @@ def execute_model(
                 for seq in sampled_token_ids[valid_mask].split(gen_lens)
             ]
 
+        if not self.use_spec_decode:
+            spec_token_ids = None
+        else:
+            spec_token_ids = self.generate_draft_token_ids(
+                valid_sampled_token_ids)
+
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
+            spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
         return model_runner_output
 
+    def generate_draft_token_ids(
+        self,
+        sampled_token_ids: List[List[int]],
+    ) -> List[List[int]]:
+        # TODO(woosuk): Optimize.
+        num_reqs = len(sampled_token_ids)
+        draft_token_ids: List[List[int]] = []
+        for i in range(num_reqs):
+            if len(sampled_token_ids[i]) == 0:
+                # Skip speculative decoding.
+                draft_token_ids.append([])
+                continue
+
+            # Add sampled_token_ids to token_ids_cpu.
+            start_idx = self.input_batch.num_tokens_no_spec[i]
+            end_idx = start_idx + len(sampled_token_ids[i])
+            self.input_batch.token_ids_cpu[
+                i, start_idx:end_idx] = sampled_token_ids[i]
+            drafter_output = self.drafter.propose(
+                self.input_batch.token_ids_cpu[i, :end_idx],
+                self.speculative_config.ngram_prompt_lookup_min,
+                self.speculative_config.num_speculative_tokens,
+            )
+            if drafter_output is None or len(drafter_output) == 0:
+                draft_token_ids.append([])
+            else:
+                draft_token_ids.append(drafter_output.tolist())
+        return draft_token_ids
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117

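A standalone sketch of what generate_draft_token_ids does for a single request, using the CPU buffer bookkeeping from gpu_input_batch.py (token values and the n/k settings are invented; in the runner they come from speculative_config):

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer

# Request 0: five verified tokens already in the CPU buffer, one token sampled
# in this step.
token_ids_cpu = np.zeros((1, 32), dtype=np.int32)
token_ids_cpu[0, :5] = [7, 8, 9, 7, 8]
num_tokens_no_spec = np.array([5], dtype=np.int32)
sampled_token_ids = [[9]]  # valid_sampled_token_ids from the sampler

start_idx = num_tokens_no_spec[0]
end_idx = start_idx + len(sampled_token_ids[0])
token_ids_cpu[0, start_idx:end_idx] = sampled_token_ids[0]  # history: [7, 8, 9, 7, 8, 9]

drafter = NgramProposer()
draft = drafter.propose(token_ids_cpu[0, :end_idx], n=2, k=2)
spec_token_ids = [draft.tolist() if draft is not None and len(draft) else []]
# e.g. [[7, 8]]: the tokens that followed the earlier occurrence of the
# trailing [8, 9] n-gram; these go into ModelRunnerOutput.spec_token_ids.
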
vllm/v1/worker/tpu_model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -696,6 +696,7 @@ def execute_model(
             req_ids=all_req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
+            spec_token_ids=None,
             logprobs=None,
             prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
         )
