 55 |  55 | from vllm_ascend.attention.attention import AttentionMaskBuilder
 56 |  56 | from vllm_ascend.attention.attention_v1 import AscendAttentionState
 57 |  57 | from vllm_ascend.platform import NPUPlatform
    |  58 | +from vllm_ascend.utils import vllm_version_is
 58 |  59 |
 59 |  60 | if TYPE_CHECKING:
 60 |  61 |     import xgrammar as xgr  # type: ignore[import-untyped]

@@ -187,14 +188,26 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

187 | 188 |         # Request states.
188 | 189 |         self.requests: Dict[str, CachedRequestState] = {}
189 | 190 |         # Persistent batch.
190 |     | -        self.input_batch = InputBatch(
191 |     | -            max_num_reqs=self.max_num_reqs,
192 |     | -            max_model_len=self.model_config.max_model_len,
193 |     | -            max_num_blocks_per_req=self.max_num_blocks_per_req,
194 |     | -            device=self.device,
195 |     | -            pin_memory=True,
196 |     | -            vocab_size=self.model_config.get_vocab_size(),
197 |     | -        )
    | 191 | +        # Remove this after we drop 0.8.5 support
    | 192 | +        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
    | 193 | +            self.input_batch = InputBatch(
    | 194 | +                max_num_reqs=self.max_num_reqs,
    | 195 | +                max_model_len=self.model_config.max_model_len,
    | 196 | +                max_num_blocks_per_req=self.max_num_blocks_per_req,
    | 197 | +                device=self.device,
    | 198 | +                pin_memory=True,
    | 199 | +                vocab_size=self.model_config.get_vocab_size(),
    | 200 | +            )
    | 201 | +        else:
    | 202 | +            self.input_batch = InputBatch(
    | 203 | +                max_num_reqs=self.max_num_reqs,
    | 204 | +                max_model_len=self.model_config.max_model_len,
    | 205 | +                max_num_blocks_per_req=self.max_num_blocks_per_req,
    | 206 | +                max_num_batched_tokens=self.max_num_tokens,
    | 207 | +                device=self.device,
    | 208 | +                pin_memory=True,
    | 209 | +                vocab_size=self.model_config.get_vocab_size(),
    | 210 | +            )
198 | 211 |
199 | 212 |         self.input_ids = torch.zeros(self.max_num_tokens,
200 | 213 |                                      dtype=torch.int32,
|