@@ -1436,7 +1436,8 @@ def _process_model_outputs(self,
1436
1436
# LLMEngine/AsyncLLMEngine directly
1437
1437
if is_async :
1438
1438
# Log stats.
1439
- self .do_log_stats (scheduler_outputs , outputs , finished_before )
1439
+ self .do_log_stats (scheduler_outputs , outputs , finished_before ,
1440
+ skip )
1440
1441
1441
1442
# Tracing
1442
1443
self .do_tracing (scheduler_outputs )
@@ -1743,25 +1744,31 @@ def remove_logger(self, logger_name: str) -> None:
1743
1744
def do_log_stats (self ,
1744
1745
scheduler_outputs : Optional [SchedulerOutputs ] = None ,
1745
1746
model_output : Optional [List [SamplerOutput ]] = None ,
1746
- finished_before : Optional [List [int ]] = None ) -> None :
1747
+ finished_before : Optional [List [int ]] = None ,
1748
+ skip : Optional [List [int ]] = None ) -> None :
1747
1749
"""Forced log when no requests active."""
1748
1750
if self .log_stats :
1749
1751
stats = self ._get_stats (scheduler_outputs , model_output ,
1750
- finished_before )
1752
+ finished_before , skip )
1751
1753
for logger in self .stat_loggers .values ():
1752
1754
logger .log (stats )
1753
1755
1754
1756
def _get_stats (self ,
1755
1757
scheduler_outputs : Optional [SchedulerOutputs ],
1756
1758
model_output : Optional [List [SamplerOutput ]] = None ,
1757
- finished_before : Optional [List [int ]] = None ) -> Stats :
1759
+ finished_before : Optional [List [int ]] = None ,
1760
+ skip : Optional [List [int ]] = None ) -> Stats :
1758
1761
"""Get Stats to be Logged to Prometheus.
1759
1762
1760
1763
Args:
1761
1764
scheduler_outputs: Optional, used to populate metrics related to
1762
1765
the scheduled batch,
1763
1766
model_output: Optional, used to emit speculative decoding metrics
1764
1767
which are created by the workers.
1768
+ finished_before: Optional, indices of sequences that were finished
1769
+ before. These sequences will be ignored.
1770
+ skip: Optional, indices of sequences that were preempted. These
1771
+ sequences will be ignored.
1765
1772
"""
1766
1773
now = time .time ()
1767
1774
@@ -1836,6 +1843,11 @@ def _get_stats(self,
1836
1843
actual_num_batched_tokens -= 1
1837
1844
continue
1838
1845
1846
+ # Currently, skip == preempted sequences, so we need to skip
1847
+ # their log stats
1848
+ if skip and idx in skip :
1849
+ continue
1850
+
1839
1851
group_was_prefill = idx < scheduler_outputs .num_prefill_groups
1840
1852
seq_group = scheduled_seq_group .seq_group
1841
1853
0 commit comments