 #

 import gc
-import math
 import os
 import time
 import weakref
@@ -293,9 +292,9 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                                        device="cpu")
         self.attn_mask = None
         self.attn_state = None
-        self.use_npu_graph = (self.vllm_config.compilation_config.level
-                              == CompilationLevel.PIECEWISE
-                              and not self.model_config.enforce_eager)
+        self.use_aclgraph = (self.vllm_config.compilation_config.level
+                             == CompilationLevel.PIECEWISE
+                             and not self.model_config.enforce_eager)
         self.aclgraph_batch_sizes = list(
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
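
The renamed flag above means ACL graph capture is only active when the compilation level is `CompilationLevel.PIECEWISE` (the `-O 3` level referenced later in `capture_model`) and `enforce_eager` is off. As a rough illustration only, a configuration sketch that satisfies this check; it is not part of the patch, and the exact `LLM`/`CompilationConfig` arguments may differ across vLLM versions:

# Illustrative configuration sketch (not part of this patch): run with
# piecewise compilation and eager mode off so use_aclgraph becomes True.
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",         # placeholder model name
    enforce_eager=False,
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,     # same effect as the -O 3 CLI flag
        cudagraph_capture_sizes=[1, 2, 4, 8],
    ),
)
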
@@ -508,6 +507,13 @@ def _process_reqs(
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
+        if (self.use_aclgraph and
+                total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
+            # Add padding to the batch size.
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(
+                total_num_scheduled_tokens)
+        else:
+            num_input_tokens = total_num_scheduled_tokens

         modified_batch = self.attn_metadata_builder.reorder_batch(
             self.input_batch, scheduler_output)
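
Assuming `aclgraph_batch_sizes` ends with the largest captured size (it is built from `reversed(...cudagraph_capture_sizes)` in `__init__`), the new branch pads the scheduled token count up to the nearest captured graph size whenever the batch fits into a pre-captured graph; otherwise the batch runs at its true size. The sketch below illustrates the expected rounding; it is a stand-in for, not the actual implementation of, `VllmConfig.pad_for_cudagraph`:

import bisect

def pad_to_capture_size(num_tokens: int, capture_sizes: list[int]) -> int:
    """Illustrative stand-in for VllmConfig.pad_for_cudagraph: round the
    token count up to the smallest captured batch size that can hold it."""
    sizes = sorted(capture_sizes)
    idx = bisect.bisect_left(sizes, num_tokens)
    return sizes[idx]

# Example: with capture sizes [1, 2, 4, 8, 16], 5 scheduled tokens are
# padded to 8 so the pre-captured ACL graph for batch size 8 can be replayed.
assert pad_to_capture_size(5, [1, 2, 4, 8, 16]) == 8
assert pad_to_capture_size(16, [1, 2, 4, 8, 16]) == 16
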
@@ -546,7 +552,7 @@ def _process_reqs(

         self.positions[:total_num_scheduled_tokens].copy_(
             self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        positions = self.positions[:total_num_scheduled_tokens]
+        positions = self.positions[:num_input_tokens]
         self.query_lens = torch.from_numpy(num_scheduled_tokens)

         self.seq_lens_np[:num_reqs] = (
@@ -605,7 +611,7 @@ def _process_reqs(
         # Copy the tensors to the NPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-        input_ids = self.input_ids[:total_num_scheduled_tokens]
+        input_ids = self.input_ids[:num_input_tokens]

         if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
             padding = torch.zeros(graph_pad_size,
@@ -615,7 +621,9 @@ def _process_reqs(
             positions = torch.cat([positions, padding])

         # Run forward pass
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
             model_kwargs = {}
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
@@ -1062,17 +1070,14 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         return kv_cache_spec

     def capture_model(self) -> None:
-        if not self.use_npu_graph:
+        if not self.use_aclgraph:
             logger.warning(
                 "Skipping NPU graph capture. Please add "
                 "-O %s to use NPU graphs.", CompilationLevel.PIECEWISE)
             return

         start_time = time.perf_counter()
         start_free_npu_memory = torch.npu.mem_get_info()[0]
-        # Since vllm aclgraph_batch_sizes is too large,
-        # we need to adjust its length to proper size.
-        self.verify_adjust_aclgraph_batch_sizes()

         # Trigger ACL graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
@@ -1091,63 +1096,3 @@ def capture_model(self) -> None:
         # This usually takes 5~20 seconds.
         logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                     elapsed_time, npu_graph_size / (1 << 30))
-
-    def verify_adjust_aclgraph_batch_sizes(self) -> None:
-        # Now, vllm-ascend support max capture size is 1920
-        max_capture_size = 1920
-        original_aclgraph_batch_sizes = self.aclgraph_batch_sizes
-        num_hidden_layers = self.vllm_config.model_config.hf_config.num_hidden_layers
-        max_support_len_aclgraph = self.get_max_support_len(
-            max_capture_size, num_hidden_layers)
-
-        if max_support_len_aclgraph < len(original_aclgraph_batch_sizes):
-            self.aclgraph_batch_sizes = self.sample_from_list(
-                max_support_len_aclgraph)
-
-            logger.info(
-                "Model:%s-num_hidden_layers:%d will adjust aclgraph_batch_sizes, pre-adjust-len: %s, post-adjust-len: %s",
-                self.vllm_config.model_config.architectures[0],
-                num_hidden_layers, len(original_aclgraph_batch_sizes),
-                len(self.aclgraph_batch_sizes))
-        else:
-            logger.info(
-                "Model:%s-num_hidden_layers:%d no need adjust aclgraph_batch_sizes, list_len: %s",
-                self.vllm_config.model_config.architectures[0],
-                num_hidden_layers, len(original_aclgraph_batch_sizes))
-
-    def get_max_support_len(self, max_capture_size, num_hidden_layers) -> int:
-        parallel_type_cnt = 0
-        dp_size = self.vllm_config.parallel_config.data_parallel_size
-        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        if dp_size > 1:
-            parallel_type_cnt += 1
-        if tp_size > 1:
-            parallel_type_cnt += 1
-        max_support_len_aclgraph = math.floor(max_capture_size /
-                                              (num_hidden_layers + 1) /
-                                              (parallel_type_cnt + 1))
-        logger.info(
-            "max_capture_size:%s, dp_size:%s, tp_size:%s, parallel_type_cnt:%s, max_support_len_aclgraph: %s:",
-            max_capture_size,
-            dp_size,
-            tp_size,
-            parallel_type_cnt,
-            max_support_len_aclgraph,
-        )
-
-        return max_support_len_aclgraph
-
-    def sample_from_list(self, sample_len) -> list[int]:
-        # we use this function to sample a new list from old list by given length, and maintain uniformity, for example:
-        # original: [1 8 16 24 32 40 48 56 64]
-        # --> sample length = 3: [1 32 64]
-        # --> sample length = 5: [1 16 32 48 64]
-        original_len = len(self.aclgraph_batch_sizes)
-        step = (original_len - 1) / (sample_len - 1)
-        indices = [round(i * step) for i in range(sample_len)]
-        # Align first and last element of the original list and sub-list
-        indices[0] = 0
-        indices[-1] = original_len - 1
-        # Sample new list
-        new_list = [self.aclgraph_batch_sizes[i] for i in indices]
-        return new_list
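
For context, the helpers deleted above trimmed `aclgraph_batch_sizes` to at most `floor(max_capture_size / (num_hidden_layers + 1) / (parallel_type_cnt + 1))` entries and then uniformly resampled the list. A quick worked example of that removed formula; the layer count below is purely illustrative:

import math

# Worked example of the removed get_max_support_len() formula, for
# illustration only: max_capture_size was hard-coded to 1920.
max_capture_size = 1920
num_hidden_layers = 32      # e.g. a 32-layer model (illustrative)
parallel_type_cnt = 1       # TP > 1, DP == 1

max_support_len = math.floor(max_capture_size /
                             (num_hidden_layers + 1) /
                             (parallel_type_cnt + 1))
print(max_support_len)      # 29 -> aclgraph_batch_sizes was cut to 29 entries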