@@ -2,11 +2,12 @@

 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List, Optional, Set

 if TYPE_CHECKING:
     from vllm.outputs import RequestOutput
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
+    from vllm.v1.output_processor import RequestState


 @dataclass
@@ -36,6 +37,12 @@ class SchedulerStats:
                                              default_factory=PrefixCacheStats)


+@dataclass
+class LoRAStats:
+    waiting_requests: Set[str] = field(default_factory=set)
+    running_requests: Set[str] = field(default_factory=set)
+
+
 @dataclass
 class RequestStateStats:
     """Stats that need to be tracked across delta updates."""
@@ -76,14 +83,17 @@ def __init__(self):
         self.time_per_output_tokens_iter: List[float] = []
         self.queue_times_iter: List[float] = []
         self.prefill_times_iter: List[float] = []
+        self.waiting_lora_adapters: Dict[str, int] = {}
+        self.running_lora_adapters: Dict[str, int] = {}

     def _time_since(self, start: float) -> float:
         """Calculate an interval relative to this iteration's timestamp."""
         return self.iteration_timestamp - start

     def update_from_output(self, output: "EngineCoreOutput",
                            engine_core_timestamp: float, is_prefilling: bool,
-                           prompt_len: int, req_stats: RequestStateStats):
+                           prompt_len: int, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         num_new_generation_tokens = len(output.new_token_ids)

         self.num_generation_tokens += num_new_generation_tokens
@@ -105,7 +115,8 @@ def update_from_output(self, output: "EngineCoreOutput",

         # Process request-level engine core events
         if output.events is not None:
-            self.update_from_events(output.events, is_prefilling, req_stats)
+            self.update_from_events(output.request_id, output.events,
+                                    is_prefilling, req_stats, lora_stats)

         # Process the batch-level "new tokens" engine core event
         if is_prefilling:
@@ -123,17 +134,21 @@ def update_from_output(self, output: "EngineCoreOutput",
         if num_new_generation_tokens > 0:
             req_stats.last_token_ts = engine_core_timestamp

-    def update_from_events(self, events: List["EngineCoreEvent"],
-                           is_prefilling: bool, req_stats: RequestStateStats):
+    def update_from_events(self, req_id: str, events: List["EngineCoreEvent"],
+                           is_prefilling: bool, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         # Avoid circular dependency
         from vllm.v1.engine import EngineCoreEventType
         for event in events:
             if event.type == EngineCoreEventType.QUEUED:
                 req_stats.queued_ts = event.timestamp
+                if lora_stats is not None:
+                    lora_stats.waiting_requests.add(req_id)
             elif event.type == EngineCoreEventType.SCHEDULED:
                 queued_interval = event.timestamp - req_stats.queued_ts
                 self.queue_times_iter.append(queued_interval)
                 req_stats.scheduled_ts = event.timestamp
+                LoRARequestStates.scheduled_request(lora_stats, req_id)

     def update_from_finished_request(self, finish_reason: "FinishReason",
                                      request_output: "RequestOutput",
@@ -151,3 +166,55 @@ def update_from_finished_request(self, finish_reason: "FinishReason",
                                          inference_time=inference_time,
                                          decode_time=decode_time)
         self.finished_requests.append(finished_req)
+
+
+class LoRARequestStates:
+    """Per-LoRA request state stats."""
+
+    def __init__(self):
+        self.lora_name_to_stats: Dict[str, LoRAStats] = {}
+
+    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
+        if req_state.lora_name is None:
+            return None
+        if req_state.lora_name not in self.lora_name_to_stats:
+            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
+        return self.lora_name_to_stats[req_state.lora_name]
+
+    def add_request(self, req_state: 'RequestState'):
+        if (lora_stats := self.get_stats(req_state)) is not None:
+            lora_stats.waiting_requests.add(req_state.request_id)
+
+    def finish_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.running_requests.remove(req_state.request_id)
+
+    def abort_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.waiting_requests.discard(req_state.request_id)
+        lora_stats.running_requests.discard(req_state.request_id)
+
+    # Break the pattern for this lifecycle method so we can
+    # call it from IterationStats.update_from_events()
+    @staticmethod
+    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
+        if lora_stats is None:
+            return
+        lora_stats.waiting_requests.remove(request_id)
+        lora_stats.running_requests.add(request_id)
+
+    def update_iteration_stats(self,
+                               iteration_stats: Optional[IterationStats]):
+        if iteration_stats is None:
+            return
+        for lora_name, stats in self.lora_name_to_stats.items():
+            if stats.waiting_requests:
+                iteration_stats.waiting_lora_adapters[lora_name] = \
+                    len(stats.waiting_requests)
+            if stats.running_requests:
+                iteration_stats.running_lora_adapters[lora_name] = \
+                    len(stats.running_requests)
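Not part of the diff, but a minimal sketch of the lifecycle these hooks are meant to follow, using only classes the patch itself defines. The SimpleNamespace objects are hypothetical stand-ins for RequestState and IterationStats; each method only touches the attributes shown here.

# Sketch: driving the LoRARequestStates lifecycle added above
# (assumes the patch is applied to vllm/v1/metrics/stats.py).
from types import SimpleNamespace

from vllm.v1.metrics.stats import LoRARequestStates

lora_states = LoRARequestStates()

# Stand-in for RequestState: only .lora_name / .request_id are read.
req = SimpleNamespace(request_id="req-0", lora_name="my-adapter")

# QUEUED: the request is counted as waiting for its adapter.
lora_states.add_request(req)

# SCHEDULED: moves it from waiting to running (static so that
# IterationStats.update_from_events() can call it too).
LoRARequestStates.scheduled_request(lora_states.get_stats(req),
                                    req.request_id)

# Snapshot per-adapter counts into this iteration's stats.
# Stand-in for IterationStats: only the two dicts are written.
iteration_stats = SimpleNamespace(waiting_lora_adapters={},
                                  running_lora_adapters={})
lora_states.update_iteration_stats(iteration_stats)
print(iteration_stats.running_lora_adapters)  # {'my-adapter': 1}

# Finished: the request leaves the adapter's running set.
lora_states.finish_request(req)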