@@ -962,6 +962,45 @@ def _process_sequence_group_outputs(
 
         return
 
+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_tokens updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In the multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case the prompt
+            # sequences are actually single-stepped. Always update.
+            assert seq_group.state.num_steps == 1
+            do_update = True
+
+        if do_update:
+            seq_group.update_num_computed_tokens(
+                seq_group_meta.token_chunk_size)
+
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
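
The update rule this new method encodes reduces to a small predicate over two inputs. A minimal standalone sketch of that rule, with the outcomes spelled out (the `should_update` helper below is hypothetical, not part of this diff):

from typing import Optional

# Hypothetical helper mirroring the do_update logic above; for illustration only.
def should_update(chunked_prefill_enabled: bool,
                  is_first_step_output: Optional[bool]) -> bool:
    if chunked_prefill_enabled:
        # Scheduled prompts are fully processed in the first step, so update
        # on the first step's output, or when all step outputs arrive in
        # one burst (signalled by None).
        return is_first_step_output is None or bool(is_first_step_output)
    # Without chunked prefill, prompts are single-stepped: always update.
    return True

assert should_update(True, True)       # first-step output
assert not should_update(True, False)  # later-step output: already counted
assert should_update(True, None)       # all steps submitted in one burst
assert should_update(False, True)      # plain multi-step prompt
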
@@ -972,64 +1011,6 @@ def _process_model_outputs(self,
         request_id: If provided, then only this request is going to be processed
         """
 
-        def update_prefill_num_computed_tokens(
-                seq_group: SequenceGroup,
-                seq_group_meta: SequenceGroupMetadata, num_outputs: int,
-                is_first_step_output: Optional[bool]) -> None:
-            """
-            When multi-step and chunked-prefill are enabled together, the
-            prefill sequence scheduled for multi-step execution turn into
-            decodes in the first step itself. This function accounts
-            for that conversion.
-
-            seq_group: SequenceGroup - A prefill seq_group
-            seq_group_meta: SequenceGroupMetadata - Metadata of the given
-                prefill seq_group
-            num_outputs: int - number of output tokens being processed for the
-                given seq_group
-            is_first_step_output: Optional[bool] -
-                If multi-step is enabled and num_outputs is 1, this value
-                indicates if this outputs belongs to the first step in the
-                multi-step.
-                If multi-step is enabled and num_outputs > 1, this value
-                must be None, as num_outputs > 1 indicates that outputs from
-                all the steps in multi-step are submitted in a single burst.
-                When multi-step is disabled, this value is always True.
-            """
-
-            assert seq_group_meta.is_prompt
-
-            token_chunk_size = seq_group_meta.token_chunk_size
-
-            if num_outputs == 1:
-                assert is_first_step_output is not None
-
-                if seq_group_meta.state.num_steps == 1:
-                    assert is_first_step_output is True
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                    return
-
-                # multi-step prefill is only supported when multi-step is
-                # enabled with chunked prefill
-                assert self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled
-                if is_first_step_output is True:
-                    # This sequence is a prompt during the first step only.
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                return
-
-            assert is_first_step_output is None
-
-            # multi-step prefill is only supported when multi-step is
-            # enabled with chunked prefill. Outputs from all the steps are
-            # submitted in a single burst.
-            assert self.scheduler_config.is_multi_step and \
-                self.scheduler_config.chunked_prefill_enabled
-            assert num_outputs == seq_group_meta.state.num_steps, \
-                f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}"  # noqa
-            # This sequence is a prompt during the first step only.
-            seq_group.update_num_computed_tokens(token_chunk_size)
-
         now = time.time()
 
         if len(ctx.output_queue) == 0:
@@ -1090,7 +1071,7 @@ def update_prefill_num_computed_tokens(
             seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
 
-            seq_group = scheduled_seq_group.seq_group
+            seq_group: SequenceGroup = scheduled_seq_group.seq_group
 
             if seq_group.is_finished():
                 finished_before.append(i)
@@ -1101,14 +1082,14 @@ def update_prefill_num_computed_tokens(
             else:
                 output = [outputs_by_sequence_group[0][i]]
 
-            if not is_async and seq_group_meta.is_prompt:
-                # Updates for all decodes happen when we actually append the
-                # token ids to the seq in process_outputs.
-                update_prefill_num_computed_tokens(seq_group, seq_group_meta,
-                                                   len(output),
-                                                   is_first_step_output)
-            elif not is_async:
-                seq_group.update_num_computed_tokens(1)
+            if not is_async:
+                if self.scheduler_config.is_multi_step:
+                    # Updates happen only if the sequence is prefill
+                    self._update_num_computed_tokens_for_multi_step_prefill(
+                        seq_group, seq_group_meta, is_first_step_output)
+                else:
+                    seq_group.update_num_computed_tokens(
+                        seq_group_meta.token_chunk_size)
 
             if outputs:
                 for o in outputs:
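
Note that the non-multi-step branch now charges `seq_group_meta.token_chunk_size` instead of a hard-coded 1. This is behavior-preserving under the assumption that the scheduler sizes decode steps at a single token, while chunked prompts get charged their full chunk. A toy trace under that assumption (the `ToySeqGroup` class below is a stand-in for illustration, not vLLM's `SequenceGroup`):

# Stand-in with the same update_num_computed_tokens signature; hypothetical.
class ToySeqGroup:
    def __init__(self) -> None:
        self.num_computed_tokens = 0

    def update_num_computed_tokens(self, n: int) -> None:
        self.num_computed_tokens += n

sg = ToySeqGroup()
# Two 512-token prompt chunks, then three decode steps (chunk size 1 each).
for token_chunk_size in (512, 512, 1, 1, 1):
    sg.update_num_computed_tokens(token_chunk_size)
assert sg.num_computed_tokens == 1027

Folding both cases into one call site is what lets the multi-step path above become a single `is_multi_step` check rather than a prompt/decode special case.
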
@@ -1132,16 +1113,8 @@ def update_prefill_num_computed_tokens(
             else:
                 self.output_processor.process_prompt_logprob(seq_group, output)
                 if seq_group_meta.do_sample:
-                    output_token_num = self.output_processor.process_outputs(
+                    self.output_processor.process_outputs(
                         seq_group, output, is_async)
-                    if self.speculative_config:
-                        # We -1 here because we always
-                        # (w/o speculative decoding) add the number of
-                        # computed tokens by one in the decoding phase.
-                        # Therefore, we remove that one token that
-                        # is already added.
-                        seq_group.update_num_computed_tokens(output_token_num -
-                                                             1)
 
             if seq_group.is_finished():
                 finished_now.append(i)
@@ -1250,20 +1223,15 @@ def _advance_to_next_step(
             if seq_group.is_finished():
                 continue
 
-            if seq_group_metadata.is_prompt:
-                if self.scheduler_config.is_multi_step and \
-                        self.scheduler_config.chunked_prefill_enabled:
-                    # Prompts are scheduled in multi-step only when
-                    # chunking is enabled. These prompts turn into
-                    # decodes after the very first step. Therefore,
-                    # we skip the update to the num_computed_tokens
-                    # here.
-                    seq_group.update_num_computed_tokens(1)
-                else:
-                    seq_group.update_num_computed_tokens(
-                        seq_group_metadata.token_chunk_size)
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(1)
+                seq_group.update_num_computed_tokens(
+                    seq_group_metadata.token_chunk_size)
+
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
@@ -1273,7 +1241,15 @@ def _advance_to_next_step(
 
                 assert len(seq_group.seqs) == 1
                 seq = seq_group.seqs[0]
-                seq.append_token_id(sample.output_token, sample.logprobs)
+
+                if self.scheduler_config.is_multi_step:
+                    is_prefill_append = \
+                        seq.data.get_num_uncomputed_tokens() == 0
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    if not is_prefill_append:
+                        seq_group.update_num_computed_tokens(1)
+                else:
+                    seq.append_token_id(sample.output_token, sample.logprobs)
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.