Skip to content

Commit cff8b1c

Browse files
author
Mu Huai
committed
fix bug
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent cc97560 commit cff8b1c

File tree

2 files changed: +44 additions, −43 deletions

vllm/v1/core/sched/scheduler.py

Lines changed: 43 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -36,13 +36,13 @@
3636
class Scheduler(SchedulerInterface):
3737

3838
def __init__(
39-
self,
40-
vllm_config: VllmConfig,
41-
kv_cache_config: KVCacheConfig,
42-
structured_output_manager: StructuredOutputManager,
43-
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
44-
include_finished_set: bool = False,
45-
log_stats: bool = False,
39+
self,
40+
vllm_config: VllmConfig,
41+
kv_cache_config: KVCacheConfig,
42+
structured_output_manager: StructuredOutputManager,
43+
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
44+
include_finished_set: bool = False,
45+
log_stats: bool = False,
4646
) -> None:
4747
self.vllm_config = vllm_config
4848
self.scheduler_config = vllm_config.scheduler_config
@@ -65,8 +65,8 @@ def __init__(
6565
self.scheduler_config.max_num_batched_tokens
6666
self.max_model_len = self.scheduler_config.max_model_len
6767
self.enable_kv_cache_events = (
68-
self.kv_events_config is not None
69-
and self.kv_events_config.enable_kv_cache_events)
68+
self.kv_events_config is not None
69+
and self.kv_events_config.enable_kv_cache_events)
7070

7171
# Create KVConnector for the Scheduler. Note that each Worker
7272
# will have a corresponding KVConnector with Role=WORKER.
@@ -206,8 +206,8 @@ def schedule(self) -> SchedulerOutput:
206206
if request.has_encoder_inputs:
207207
(encoder_inputs_to_schedule, num_new_tokens,
208208
new_encoder_budget) = self._try_schedule_encoder_inputs(
209-
request, request.num_computed_tokens, num_new_tokens,
210-
encoder_budget)
209+
request, request.num_computed_tokens, num_new_tokens,
210+
encoder_budget)
211211

212212
if num_new_tokens == 0:
213213
# The request cannot be scheduled because one of the following
@@ -359,8 +359,8 @@ def schedule(self) -> SchedulerOutput:
359359
if request.has_encoder_inputs:
360360
(encoder_inputs_to_schedule, num_new_tokens,
361361
new_encoder_budget) = self._try_schedule_encoder_inputs(
362-
request, num_computed_tokens, num_new_tokens,
363-
encoder_budget)
362+
request, num_computed_tokens, num_new_tokens,
363+
encoder_budget)
364364
if num_new_tokens == 0:
365365
# The request cannot be scheduled.
366366
break
@@ -407,7 +407,7 @@ def schedule(self) -> SchedulerOutput:
407407
if self.lora_config and request.lora_request:
408408
scheduled_loras.add(request.lora_request.lora_int_id)
409409
req_to_new_block_ids[request.request_id] = (
410-
computed_blocks + new_blocks).get_block_ids()
410+
computed_blocks + new_blocks).get_block_ids()
411411
num_scheduled_tokens[request.request_id] = num_new_tokens
412412
token_budget -= num_new_tokens
413413
request.status = RequestStatus.RUNNING
@@ -522,19 +522,19 @@ def schedule(self) -> SchedulerOutput:
522522
return scheduler_output
523523

524524
def _make_cached_request_data(
525-
self,
526-
request: Request,
527-
num_scheduled_tokens: int,
528-
num_scheduled_spec_tokens: int,
529-
new_block_ids: list[int],
530-
resumed_from_preemption: bool,
525+
self,
526+
request: Request,
527+
num_scheduled_tokens: int,
528+
num_scheduled_spec_tokens: int,
529+
new_block_ids: list[int],
530+
resumed_from_preemption: bool,
531531
) -> CachedRequestData:
532532
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
533533
# them at each scheduling step.
534534
num_computed_tokens = request.num_computed_tokens
535535
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
536536
new_token_ids = request.all_token_ids[
537-
num_computed_tokens:num_computed_tokens + num_regular_tokens]
537+
num_computed_tokens:num_computed_tokens + num_regular_tokens]
538538

539539
req_data_queue = self._cached_reqs_data.get(request.request_id)
540540
if req_data_queue:
@@ -553,11 +553,11 @@ def _make_cached_request_data(
553553
return req_data
554554

555555
def _try_schedule_encoder_inputs(
556-
self,
557-
request: Request,
558-
num_computed_tokens: int,
559-
num_new_tokens: int,
560-
encoder_budget: int,
556+
self,
557+
request: Request,
558+
num_computed_tokens: int,
559+
num_new_tokens: int,
560+
encoder_budget: int,
561561
) -> tuple[list[int], int, int]:
562562
"""
563563
Determine which encoder inputs need to be scheduled in the current step,
@@ -636,9 +636,9 @@ def _try_schedule_encoder_inputs(
636636
return encoder_inputs_to_schedule, num_new_tokens, encoder_budget
637637

638638
def update_from_output(
639-
self,
640-
scheduler_output: SchedulerOutput,
641-
model_runner_output: ModelRunnerOutput,
639+
self,
640+
scheduler_output: SchedulerOutput,
641+
model_runner_output: ModelRunnerOutput,
642642
) -> EngineCoreOutputs:
643643
sampled_token_ids = model_runner_output.sampled_token_ids
644644
spec_token_ids = model_runner_output.spec_token_ids
@@ -749,8 +749,9 @@ def update_from_output(
749749
new_logprobs=new_logprobs,
750750
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
751751
stop_reason=request.stop_reason,
752-
events=request.take_events()),
752+
events=request.take_events(),
753753
trace_headers=request.trace_headers
754+
),
754755
)
755756
else:
756757
# Invariant: EngineCore returns no partial prefill outputs.
@@ -772,9 +773,9 @@ def update_from_output(
772773
scheduler_stats=self.make_stats(spec_decoding_stats),
773774
)
774775
if self.include_finished_set:
775-
#TODO currently sending duplicates here, improve this
776+
# TODO currently sending duplicates here, improve this
776777
engine_core_outputs.finished_requests = (
777-
scheduler_output.finished_req_ids | self.finished_req_ids)
778+
scheduler_output.finished_req_ids | self.finished_req_ids)
778779

779780
return engine_core_outputs
780781

@@ -785,9 +786,9 @@ def add_request(self, request: Request) -> None:
785786
request.record_event(EngineCoreEventType.QUEUED)
786787

787788
def finish_requests(
788-
self,
789-
request_ids: Union[str, Iterable[str]],
790-
finished_status: RequestStatus,
789+
self,
790+
request_ids: Union[str, Iterable[str]],
791+
finished_status: RequestStatus,
791792
) -> None:
792793
"""Handles the finish signal from outside the scheduler.
793794
@@ -796,7 +797,7 @@ def finish_requests(
796797
"""
797798
assert RequestStatus.is_finished(finished_status)
798799
if isinstance(request_ids, str):
799-
request_ids = (request_ids, )
800+
request_ids = (request_ids,)
800801
else:
801802
request_ids = set(request_ids)
802803

@@ -832,8 +833,8 @@ def reset_prefix_cache(self) -> bool:
832833
return self.kv_cache_manager.reset_prefix_cache()
833834

834835
def make_stats(
835-
self,
836-
spec_decoding_stats: Optional[SpecDecodingStats] = None,
836+
self,
837+
spec_decoding_stats: Optional[SpecDecodingStats] = None,
837838
) -> Optional[SchedulerStats]:
838839
if not self.log_stats:
839840
return None
@@ -848,10 +849,10 @@ def make_stats(
848849
)
849850

850851
def make_spec_decoding_stats(
851-
self,
852-
spec_decoding_stats: Optional[SpecDecodingStats],
853-
num_draft_tokens: int,
854-
num_accepted_tokens: int,
852+
self,
853+
spec_decoding_stats: Optional[SpecDecodingStats],
854+
num_draft_tokens: int,
855+
num_accepted_tokens: int,
855856
) -> Optional[SpecDecodingStats]:
856857
if not self.log_stats:
857858
return None

vllm/v1/engine/output_processor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -397,7 +397,7 @@ def process_outputs(
397397
def do_tracing(self, engine_core_output: EngineCoreOutput,
398398
req_state: RequestState,
399399
iteration_stats: Optional[IterationStats]):
400-
if engine_core_output.finish_reason is None or iteration_stats is None:
400+
if engine_core_output.finish_reason is None or iteration_stats is None or req_state is None or req_state.stats is None:
401401
return
402402
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
403403

0 commit comments

Comments (0)