@@ -164,6 +164,24 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             raise NotImplementedError(
                 "Non-Attention backend is not supported by V1 NPUModelRunner.")

+        self.attn_backend = get_attn_backend(
+            self.head_size,
+            self.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+            use_mla=self.model_config.use_mla,
+        )
+        if self.attn_backend is None:
+            error_msg = (
+                f"Error with get_attn_backend: {self.head_size=}, "
+                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
+                f"{self.model_config.is_attention_free=}, "
+                f"{self.model_config.use_mla=}")
+            logger.error(error_msg)
+            raise NotImplementedError(
+                "Non-Attention backend is not supported by V1 NPUModelRunner.")
+
         self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
             weakref.proxy(self))
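The block added above resolves the attention backend once in __init__ and fails fast when vLLM cannot provide one. A minimal standalone sketch of the same fail-fast pattern follows; the import path and the positional signature of get_attn_backend are assumptions inferred from the call in the hunk, so verify them against the installed vLLM release.

    # Sketch only, not part of the diff: mirrors the runner's fail-fast backend lookup.
    import torch
    from vllm.attention import get_attn_backend  # assumed import path

    def resolve_backend(head_size: int, dtype: torch.dtype, kv_cache_dtype: str,
                        block_size: int, is_attention_free: bool, use_mla: bool):
        backend = get_attn_backend(head_size, dtype, kv_cache_dtype, block_size,
                                   is_attention_free, use_mla=use_mla)
        if backend is None:
            # Echo every input so a missing-backend report is actionable.
            raise NotImplementedError(
                f"No attention backend for {head_size=}, {dtype=}, "
                f"{kv_cache_dtype=}, {block_size=}, {is_attention_free=}, {use_mla=}")
        return backend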
@@ -196,6 +214,17 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                 pin_memory=True,
                 vocab_size=self.model_config.get_vocab_size(),
             )
+        else:
+            self.input_batch = InputBatch(
+                max_num_reqs=self.max_num_reqs,
+                max_model_len=self.model_config.max_model_len,
+                max_num_blocks_per_req=self.max_num_blocks_per_req,
+                max_num_batched_tokens=self.max_num_tokens,
+                device=self.device,
+                pin_memory=True,
+                vocab_size=self.model_config.get_vocab_size(),
+            )
+
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32,
                                      device=self.device)
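The added else branch pairs with the InputBatch construction kept in the context lines above it, presumably selecting between two constructor signatures depending on the installed vLLM release (the vllm_version_is checks removed in the later hunks point at 0.8.5 and 0.8.5.post1). A minimal stand-in for such a gate is sketched below; the real helper lives in the project's utilities and may compare versions differently, so treat this as illustrative only.

    # Sketch only: an exact-match version gate like the ones this PR removes elsewhere.
    from importlib.metadata import version

    def vllm_version_is(target: str) -> bool:
        # True when the installed vLLM release string matches exactly.
        return version("vllm") == target

    # if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
    #     ...build InputBatch with one signature...
    # else:
    #     ...build InputBatch with the other (as in the added else branch above)...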
@@ -542,10 +571,7 @@ def _process_reqs(

         block_table_indices = (req_indices * self.max_num_blocks_per_req +
                                positions_np // self.block_size)
-        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        else:
-            block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
         block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
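The surrounding arithmetic flattens the per-request block table and turns (request, position) pairs into physical slot indices: req_indices * max_num_blocks_per_req + positions // block_size picks the logical block, and block_number * block_size + positions % block_size gives the slot. A small worked example with toy values (none of these numbers come from the runner):

    # Toy illustration of the slot-mapping arithmetic above.
    import numpy as np

    block_size = 4
    max_num_blocks_per_req = 3
    # Request 0 owns physical blocks [7, 2, 5]; request 1 owns [0, 9, 1].
    block_table = np.array([[7, 2, 5],
                            [0, 9, 1]], dtype=np.int32)

    req_indices = np.array([0, 0, 1])   # which request each token belongs to
    positions_np = np.array([1, 5, 2])  # token position within its request

    block_table_indices = (req_indices * max_num_blocks_per_req +
                           positions_np // block_size)
    block_numbers = block_table.flatten()[block_table_indices]
    block_offsets = positions_np % block_size
    slot_mapping = block_numbers * block_size + block_offsets
    # block_numbers -> [7, 2, 0]; slot_mapping -> [29, 9, 2]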
@@ -960,16 +986,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         import torch_npu
         kv_caches: Dict[str, torch.Tensor] = {}
-        if not (vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1")):
-            self.input_batch = InputBatch(
-                max_num_reqs=self.max_num_reqs,
-                max_model_len=self.model_config.max_model_len,
-                max_num_batched_tokens=self.max_num_tokens,
-                device=self.device,
-                pin_memory=True,
-                vocab_size=self.model_config.get_vocab_size(),
-                kv_cache_config=kv_cache_config,
-            )

         for kv_cache_group in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group.kv_cache_spec
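This hunk drops the version-gated InputBatch re-creation from initialize_kv_cache, which the earlier hunk appears to consolidate into __init__, and keeps only the walk over kv_cache_config.kv_cache_groups. A rough sketch of that walk is below; only kv_cache_groups and kv_cache_spec appear in the diff, while layer_names is an assumption about vLLM's v1 KV-cache config objects.

    # Sketch only: iterating the same structure as the loop above.
    def iter_kv_layers(kv_cache_config):
        for group in kv_cache_config.kv_cache_groups:
            spec = group.kv_cache_spec        # spec shared by every layer in the group
            for layer_name in group.layer_names:  # attribute name assumed
                yield layer_name, spec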