36
36
class Scheduler (SchedulerInterface ):
37
37
38
38
def __init__ (
39
- self ,
40
- vllm_config : VllmConfig ,
41
- kv_cache_config : KVCacheConfig ,
42
- structured_output_manager : StructuredOutputManager ,
43
- mm_registry : MultiModalRegistry = MULTIMODAL_REGISTRY ,
44
- include_finished_set : bool = False ,
45
- log_stats : bool = False ,
39
+ self ,
40
+ vllm_config : VllmConfig ,
41
+ kv_cache_config : KVCacheConfig ,
42
+ structured_output_manager : StructuredOutputManager ,
43
+ mm_registry : MultiModalRegistry = MULTIMODAL_REGISTRY ,
44
+ include_finished_set : bool = False ,
45
+ log_stats : bool = False ,
46
46
) -> None :
47
47
self .vllm_config = vllm_config
48
48
self .scheduler_config = vllm_config .scheduler_config
@@ -65,8 +65,8 @@ def __init__(
65
65
self .scheduler_config .max_num_batched_tokens
66
66
self .max_model_len = self .scheduler_config .max_model_len
67
67
self .enable_kv_cache_events = (
68
- self .kv_events_config is not None
69
- and self .kv_events_config .enable_kv_cache_events )
68
+ self .kv_events_config is not None
69
+ and self .kv_events_config .enable_kv_cache_events )
70
70
71
71
# Create KVConnector for the Scheduler. Note that each Worker
72
72
# will have a corresponding KVConnector with Role=WORKER.
@@ -206,8 +206,8 @@ def schedule(self) -> SchedulerOutput:
206
206
if request .has_encoder_inputs :
207
207
(encoder_inputs_to_schedule , num_new_tokens ,
208
208
new_encoder_budget ) = self ._try_schedule_encoder_inputs (
209
- request , request .num_computed_tokens , num_new_tokens ,
210
- encoder_budget )
209
+ request , request .num_computed_tokens , num_new_tokens ,
210
+ encoder_budget )
211
211
212
212
if num_new_tokens == 0 :
213
213
# The request cannot be scheduled because one of the following
@@ -359,8 +359,8 @@ def schedule(self) -> SchedulerOutput:
359
359
if request .has_encoder_inputs :
360
360
(encoder_inputs_to_schedule , num_new_tokens ,
361
361
new_encoder_budget ) = self ._try_schedule_encoder_inputs (
362
- request , num_computed_tokens , num_new_tokens ,
363
- encoder_budget )
362
+ request , num_computed_tokens , num_new_tokens ,
363
+ encoder_budget )
364
364
if num_new_tokens == 0 :
365
365
# The request cannot be scheduled.
366
366
break
@@ -407,7 +407,7 @@ def schedule(self) -> SchedulerOutput:
407
407
if self .lora_config and request .lora_request :
408
408
scheduled_loras .add (request .lora_request .lora_int_id )
409
409
req_to_new_block_ids [request .request_id ] = (
410
- computed_blocks + new_blocks ).get_block_ids ()
410
+ computed_blocks + new_blocks ).get_block_ids ()
411
411
num_scheduled_tokens [request .request_id ] = num_new_tokens
412
412
token_budget -= num_new_tokens
413
413
request .status = RequestStatus .RUNNING
@@ -522,19 +522,19 @@ def schedule(self) -> SchedulerOutput:
522
522
return scheduler_output
523
523
524
524
def _make_cached_request_data (
525
- self ,
526
- request : Request ,
527
- num_scheduled_tokens : int ,
528
- num_scheduled_spec_tokens : int ,
529
- new_block_ids : list [int ],
530
- resumed_from_preemption : bool ,
525
+ self ,
526
+ request : Request ,
527
+ num_scheduled_tokens : int ,
528
+ num_scheduled_spec_tokens : int ,
529
+ new_block_ids : list [int ],
530
+ resumed_from_preemption : bool ,
531
531
) -> CachedRequestData :
532
532
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
533
533
# them at each scheduling step.
534
534
num_computed_tokens = request .num_computed_tokens
535
535
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
536
536
new_token_ids = request .all_token_ids [
537
- num_computed_tokens :num_computed_tokens + num_regular_tokens ]
537
+ num_computed_tokens :num_computed_tokens + num_regular_tokens ]
538
538
539
539
req_data_queue = self ._cached_reqs_data .get (request .request_id )
540
540
if req_data_queue :
@@ -553,11 +553,11 @@ def _make_cached_request_data(
553
553
return req_data
554
554
555
555
def _try_schedule_encoder_inputs (
556
- self ,
557
- request : Request ,
558
- num_computed_tokens : int ,
559
- num_new_tokens : int ,
560
- encoder_budget : int ,
556
+ self ,
557
+ request : Request ,
558
+ num_computed_tokens : int ,
559
+ num_new_tokens : int ,
560
+ encoder_budget : int ,
561
561
) -> tuple [list [int ], int , int ]:
562
562
"""
563
563
Determine which encoder inputs need to be scheduled in the current step,
@@ -636,9 +636,9 @@ def _try_schedule_encoder_inputs(
636
636
return encoder_inputs_to_schedule , num_new_tokens , encoder_budget
637
637
638
638
def update_from_output (
639
- self ,
640
- scheduler_output : SchedulerOutput ,
641
- model_runner_output : ModelRunnerOutput ,
639
+ self ,
640
+ scheduler_output : SchedulerOutput ,
641
+ model_runner_output : ModelRunnerOutput ,
642
642
) -> EngineCoreOutputs :
643
643
sampled_token_ids = model_runner_output .sampled_token_ids
644
644
spec_token_ids = model_runner_output .spec_token_ids
@@ -749,8 +749,9 @@ def update_from_output(
749
749
new_logprobs = new_logprobs ,
750
750
new_prompt_logprobs_tensors = prompt_logprobs_tensors ,
751
751
stop_reason = request .stop_reason ,
752
- events = request .take_events ()) ,
752
+ events = request .take_events (),
753
753
trace_headers = request .trace_headers
754
+ ),
754
755
)
755
756
else :
756
757
# Invariant: EngineCore returns no partial prefill outputs.
@@ -772,9 +773,9 @@ def update_from_output(
772
773
scheduler_stats = self .make_stats (spec_decoding_stats ),
773
774
)
774
775
if self .include_finished_set :
775
- #TODO currently sending duplicates here, improve this
776
+ # TODO currently sending duplicates here, improve this
776
777
engine_core_outputs .finished_requests = (
777
- scheduler_output .finished_req_ids | self .finished_req_ids )
778
+ scheduler_output .finished_req_ids | self .finished_req_ids )
778
779
779
780
return engine_core_outputs
780
781
@@ -785,9 +786,9 @@ def add_request(self, request: Request) -> None:
785
786
request .record_event (EngineCoreEventType .QUEUED )
786
787
787
788
def finish_requests (
788
- self ,
789
- request_ids : Union [str , Iterable [str ]],
790
- finished_status : RequestStatus ,
789
+ self ,
790
+ request_ids : Union [str , Iterable [str ]],
791
+ finished_status : RequestStatus ,
791
792
) -> None :
792
793
"""Handles the finish signal from outside the scheduler.
793
794
@@ -796,7 +797,7 @@ def finish_requests(
796
797
"""
797
798
assert RequestStatus .is_finished (finished_status )
798
799
if isinstance (request_ids , str ):
799
- request_ids = (request_ids , )
800
+ request_ids = (request_ids ,)
800
801
else :
801
802
request_ids = set (request_ids )
802
803
@@ -832,8 +833,8 @@ def reset_prefix_cache(self) -> bool:
832
833
return self .kv_cache_manager .reset_prefix_cache ()
833
834
834
835
def make_stats (
835
- self ,
836
- spec_decoding_stats : Optional [SpecDecodingStats ] = None ,
836
+ self ,
837
+ spec_decoding_stats : Optional [SpecDecodingStats ] = None ,
837
838
) -> Optional [SchedulerStats ]:
838
839
if not self .log_stats :
839
840
return None
@@ -848,10 +849,10 @@ def make_stats(
848
849
)
849
850
850
851
def make_spec_decoding_stats (
851
- self ,
852
- spec_decoding_stats : Optional [SpecDecodingStats ],
853
- num_draft_tokens : int ,
854
- num_accepted_tokens : int ,
852
+ self ,
853
+ spec_decoding_stats : Optional [SpecDecodingStats ],
854
+ num_draft_tokens : int ,
855
+ num_accepted_tokens : int ,
855
856
) -> Optional [SpecDecodingStats ]:
856
857
if not self .log_stats :
857
858
return None
0 commit comments