@@ -1056,7 +1056,8 @@ def _process_model_outputs(self,
1056
1056
# LLMEngine/AsyncLLMEngine directly
1057
1057
if is_async :
1058
1058
# Log stats.
1059
- self .do_log_stats (scheduler_outputs , outputs , finished_before )
1059
+ self .do_log_stats (scheduler_outputs , outputs , finished_before ,
1060
+ skip )
1060
1061
1061
1062
# Tracing
1062
1063
self .do_tracing (scheduler_outputs )
@@ -1363,25 +1364,31 @@ def remove_logger(self, logger_name: str) -> None:
1363
1364
def do_log_stats (self ,
1364
1365
scheduler_outputs : Optional [SchedulerOutputs ] = None ,
1365
1366
model_output : Optional [List [SamplerOutput ]] = None ,
1366
- finished_before : Optional [List [int ]] = None ) -> None :
1367
+ finished_before : Optional [List [int ]] = None ,
1368
+ skip : Optional [List [int ]] = None ) -> None :
1367
1369
"""Forced log when no requests active."""
1368
1370
if self .log_stats :
1369
1371
stats = self ._get_stats (scheduler_outputs , model_output ,
1370
- finished_before )
1372
+ finished_before , skip )
1371
1373
for logger in self .stat_loggers .values ():
1372
1374
logger .log (stats )
1373
1375
1374
1376
def _get_stats (self ,
1375
1377
scheduler_outputs : Optional [SchedulerOutputs ],
1376
1378
model_output : Optional [List [SamplerOutput ]] = None ,
1377
- finished_before : Optional [List [int ]] = None ) -> Stats :
1379
+ finished_before : Optional [List [int ]] = None ,
1380
+ skip : Optional [List [int ]] = None ) -> Stats :
1378
1381
"""Get Stats to be Logged to Prometheus.
1379
1382
1380
1383
Args:
1381
1384
scheduler_outputs: Optional, used to populate metrics related to
1382
1385
the scheduled batch,
1383
1386
model_output: Optional, used to emit speculative decoding metrics
1384
1387
which are created by the workers.
1388
+ finished_before: Optional, indices of sequences that were finished
1389
+ before. These sequences will be ignored.
1390
+ skip: Optional, indices of sequences that were preempted. These
1391
+ sequences will be ignored.
1385
1392
"""
1386
1393
now = time .time ()
1387
1394
@@ -1456,6 +1463,11 @@ def _get_stats(self,
1456
1463
actual_num_batched_tokens -= 1
1457
1464
continue
1458
1465
1466
+ # Currently, skip == preempted sequences, so we need to skip
1467
+ # their log stats
1468
+ if skip and idx in skip :
1469
+ continue
1470
+
1459
1471
group_was_prefill = idx < scheduler_outputs .num_prefill_groups
1460
1472
seq_group = scheduled_seq_group .seq_group
1461
1473
0 commit comments