
Commit 3f9dc1d

hmellor authored and mawong-amd committed
Update some more deprecated type hinting (vllm-project#17998)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 0eac19c commit 3f9dc1d

10 files changed: +73 -73 lines changed
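This commit continues vLLM's migration off the typing aliases deprecated by PEP 585: Dict, List, Tuple, Type and DefaultDict become the built-in generics dict, list, tuple, type and collections.defaultdict, which are subscriptable in annotations on Python 3.9+. A minimal before/after sketch (illustrative only, not taken from the diff):

    from typing import Dict, List, Optional  # old-style aliases, still importable

    def count_tokens_old(prompts: List[str]) -> Dict[str, int]:
        return {p: len(p.split()) for p in prompts}

    # Equivalent modern annotations as used throughout this commit (Python >= 3.9).
    # Optional and Union still come from typing and are left unchanged by the diff.
    def count_tokens(prompts: list[str],
                     limit: Optional[int] = None) -> dict[str, int]:
        counts = {p: len(p.split()) for p in prompts}
        if limit is not None:
            counts = {p: min(c, limit) for p, c in counts.items()}
        return counts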

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -79,7 +79,9 @@ exclude = [
 "vllm/engine/**/*.py" = ["UP006", "UP035"]
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/lora/**/*.py" = ["UP006", "UP035"]
-"vllm/model_executor/**/*.py" = ["UP006", "UP035"]
+"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
+"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
+"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
 "vllm/platforms/**/*.py" = ["UP006", "UP035"]
 "vllm/plugins/**/*.py" = ["UP006", "UP035"]
 "vllm/profiler/**/*.py" = ["UP006", "UP035"]

vllm/model_executor/custom_op.py

Lines changed: 1 addition & 3 deletions
@@ -1,7 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, Type
-
 import torch.nn as nn
 
 from vllm.config import get_current_vllm_config
@@ -138,7 +136,7 @@ def default_on() -> bool:
     # Examples:
     # - MyOp.enabled()
     # - op_registry["my_op"].enabled()
-    op_registry: Dict[str, Type['CustomOp']] = {}
+    op_registry: dict[str, type['CustomOp']] = {}
 
     # Decorator to register custom ops.
     @classmethod
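op_registry maps a registration name to the CustomOp subclass itself, so dict[str, type['CustomOp']] (classes as values, not instances) is the natural annotation. The following self-contained sketch mirrors that registry-plus-decorator pattern; the register helper and MyOp class are hypothetical stand-ins, not vLLM's actual implementation:

    class CustomOp:
        # Illustrative registry: name -> class, matching the annotation in the diff.
        op_registry: dict[str, type['CustomOp']] = {}

        @classmethod
        def register(cls, name: str):
            def decorator(op_cls: type['CustomOp']) -> type['CustomOp']:
                cls.op_registry[name] = op_cls
                return op_cls
            return decorator

    @CustomOp.register("my_op")
    class MyOp(CustomOp):
        pass

    assert CustomOp.op_registry["my_op"] is MyOp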

vllm/model_executor/guided_decoding/guidance_logits_processors.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import os
-from typing import Any, List
+from typing import Any
 
 import llguidance
 import llguidance.hf
@@ -62,7 +62,7 @@ def _initialize(self):
 
     def __call__(
         self,
-        input_ids: List[int],
+        input_ids: list[int],
         scores: torch.Tensor,
     ) -> torch.Tensor:
         # we initialize the guidance model here
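The __call__(input_ids: list[int], scores: torch.Tensor) -> torch.Tensor signature is the usual logits-processor contract: given the token ids generated so far and the raw next-token logits, return possibly modified logits. A toy processor with that shape, unrelated to llguidance itself and purely illustrative:

    import torch

    class BanTokenProcessor:
        """Toy logits processor that forbids one fixed token id."""

        def __init__(self, banned_token_id: int):
            self.banned_token_id = banned_token_id

        def __call__(
            self,
            input_ids: list[int],
            scores: torch.Tensor,
        ) -> torch.Tensor:
            scores = scores.clone()
            scores[self.banned_token_id] = float("-inf")
            return scores

    processor = BanTokenProcessor(banned_token_id=0)
    masked = processor([1, 2, 3], torch.zeros(8))
    assert masked[0] == float("-inf")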

vllm/model_executor/guided_decoding/guided_fields.py

Lines changed: 5 additions & 5 deletions
@@ -1,16 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union
 
 from pydantic import BaseModel
 
 
 # These classes are deprecated, see SamplingParams
 class LLMGuidedOptions(TypedDict, total=False):
-    guided_json: Union[Dict, BaseModel, str]
+    guided_json: Union[dict, BaseModel, str]
     guided_regex: str
-    guided_choice: List[str]
+    guided_choice: list[str]
     guided_grammar: str
     guided_decoding_backend: str
     guided_whitespace_pattern: str
@@ -20,9 +20,9 @@ class LLMGuidedOptions(TypedDict, total=False):
 @dataclass
 class GuidedDecodingRequest:
     """One of the fields will be used to retrieve the logit processor."""
-    guided_json: Optional[Union[Dict, BaseModel, str]] = None
+    guided_json: Optional[Union[dict, BaseModel, str]] = None
     guided_regex: Optional[str] = None
-    guided_choice: Optional[List[str]] = None
+    guided_choice: Optional[list[str]] = None
     guided_grammar: Optional[str] = None
     guided_decoding_backend: Optional[str] = None
     guided_whitespace_pattern: Optional[str] = None
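LLMGuidedOptions is a TypedDict with total=False, so every key is optional and the built-in generics only appear in annotations; at runtime the value is a plain dict. A small usage sketch under that assumption, with the pydantic BaseModel member omitted for brevity:

    from typing import TypedDict, Union

    class GuidedOptionsSketch(TypedDict, total=False):
        guided_json: Union[dict, str]
        guided_choice: list[str]

    # total=False means any subset of the declared keys is valid.
    opts: GuidedOptionsSketch = {"guided_choice": ["yes", "no"]}
    opts["guided_json"] = '{"type": "object"}'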

vllm/model_executor/guided_decoding/outlines_decoding.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 from enum import Enum
 from json import dumps as json_dumps
 from re import escape as regex_escape
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 from transformers import PreTrainedTokenizerBase
 
@@ -111,7 +111,7 @@ def get_local_outlines_guided_decoding_logits_processor(
 
 def _get_guide_and_mode(
     guided_params: GuidedDecodingParams
-) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
+) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]:
     if guided_params.json:
         if isinstance(guided_params.json, dict):
             # turn dict into hashable string
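_get_guide_and_mode returns either a populated tuple[str, GuidedDecodingMode] or the sentinel (None, None), so callers can unpack unconditionally and branch on the first element. A hedged sketch of that return-type pattern with stand-in names:

    from enum import Enum
    from typing import Union

    class Mode(Enum):
        JSON = "json"
        REGEX = "regex"

    def get_guide_and_mode(spec: dict) -> Union[tuple[str, Mode], tuple[None, None]]:
        # Hypothetical dispatch, standing in for the real GuidedDecodingParams logic.
        if "json" in spec:
            return spec["json"], Mode.JSON
        if "regex" in spec:
            return spec["regex"], Mode.REGEX
        return None, None

    guide, mode = get_guide_and_mode({"regex": r"\d+"})
    assert mode is Mode.REGEX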

vllm/model_executor/guided_decoding/outlines_logits_processors.py

Lines changed: 8 additions & 8 deletions
@@ -19,7 +19,7 @@
 import json
 from collections import defaultdict
 from functools import lru_cache
-from typing import Callable, DefaultDict, Dict, List, Optional, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import torch
@@ -53,10 +53,10 @@ def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
         self._guide: Guide = guide
         self._reasoner: Optional[ReasoningParser] = reasoner
         # CFGState is used for the FSM state for CFGGuide
-        self._fsm_state: DefaultDict[int, Union[int,
+        self._fsm_state: defaultdict[int, Union[int,
                                                 CFGState]] = defaultdict(int)
 
-    def __call__(self, input_ids: List[int],
+    def __call__(self, input_ids: list[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
 
@@ -160,7 +160,7 @@ def __init__(
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
-    def __init__(self, schema: Union[str, Dict, BaseModel],
+    def __init__(self, schema: Union[str, dict, BaseModel],
                  tokenizer: PreTrainedTokenizerBase,
                  whitespace_pattern: Union[str, None],
                  reasoner: Optional[ReasoningParser]):
@@ -181,7 +181,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
         """
         if isinstance(schema, type(BaseModel)):
             schema_str = json.dumps(schema.model_json_schema())
-        elif isinstance(schema, Dict):
+        elif isinstance(schema, dict):
             schema_str = json.dumps(schema)
         elif isinstance(schema, str):
             schema_str = schema
@@ -252,11 +252,11 @@ def convert_token_to_string(token: str) -> str:
            return string
 
        def change_decoder(
-            decoder: Callable[[List[int]],
-                              str]) -> Callable[[List[int]], List[str]]:
+            decoder: Callable[[list[int]],
+                              str]) -> Callable[[list[int]], list[str]]:
            """Sync vLLM's decoder with the outlines by returning list."""
 
-            def new_decoder(inp_tokens: List[int]) -> List[str]:
+            def new_decoder(inp_tokens: list[int]) -> list[str]:
                if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
                        and isinstance(inp_tokens[0], list)):
                    inp_tokens = inp_tokens[0]
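Note that DefaultDict[...] is replaced by collections.defaultdict used directly as a generic; like the other standard-library containers it became subscriptable in annotations in Python 3.9, so the typing alias is no longer needed. A short sketch:

    from collections import defaultdict
    from typing import Union

    # collections.defaultdict is itself generic on Python >= 3.9.
    fsm_state: defaultdict[int, Union[int, str]] = defaultdict(int)

    fsm_state[7] += 1          # missing keys default to int() == 0
    fsm_state[9] = "terminal"  # values may also be str per the annotation
    assert fsm_state[7] == 1 and fsm_state[0] == 0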

vllm/model_executor/guided_decoding/xgrammar_decoding.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 import json
 import re
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, List
+from typing import TYPE_CHECKING, Any
 
 import torch
 
@@ -273,7 +273,7 @@ def escape_ebnf_string(s: str) -> str:
         return re.sub(r'(["\\])', r'\\\1', s)
 
     @staticmethod
-    def choice_as_grammar(choice: List[str] | None) -> str:
+    def choice_as_grammar(choice: list[str] | None) -> str:
         if choice is None:
             raise ValueError("Choice is not set")
         escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
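choice_as_grammar already uses the PEP 604 union syntax (list[str] | None). Evaluating that expression eagerly requires Python 3.10, so the module presumably either targets 3.10+ or enables postponed evaluation of annotations; the sketch below takes the latter route and simplifies the escaping logic to a hypothetical join:

    from __future__ import annotations  # annotations stay strings, so
                                        # list[str] | None works before 3.10

    def choice_as_grammar(choice: list[str] | None) -> str:
        # Hypothetical simplification of the real EBNF escaping.
        if choice is None:
            raise ValueError("Choice is not set")
        return " | ".join(f'"{c}"' for c in choice)

    assert choice_as_grammar(["a", "b"]) == '"a" | "b"'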

vllm/model_executor/pooling_metadata.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 import torch
 
@@ -23,9 +23,9 @@ class PoolingMetadata:
 
     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], PoolingParams]],
-        seq_data: Dict[int, Any],  # Specific data related to sequences
-        prompt_lens: List[int],
+        seq_groups: list[tuple[list[int], PoolingParams]],
+        seq_data: dict[int, Any],  # Specific data related to sequences
+        prompt_lens: list[int],
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data

vllm/model_executor/sampling_metadata.py

Lines changed: 44 additions & 44 deletions
@@ -2,7 +2,7 @@
 
 from array import array
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 
@@ -25,10 +25,10 @@ class SequenceGroupToSample:
     # |-- query_len ---|
 
     # Sequence ids for the sequence group in a previous step.
-    seq_ids: List[int]
+    seq_ids: list[int]
     sampling_params: SamplingParams
     # seq_id -> sequence data.
-    seq_data: Dict[int, SequenceData]
+    seq_data: dict[int, SequenceData]
     # The length of the sequence (all tokens seen in the past + new token to
     # compute attention) of the sequence group. None if it is in a decode
     # stage.
@@ -44,9 +44,9 @@ class SequenceGroupToSample:
     is_prompt: bool
     # Query token indices from logits. to compute prompt logprob. Empty if
    # prompt logprob is not required.
-    prompt_logprob_indices: List[int]
+    prompt_logprob_indices: list[int]
     # Sample token indices from logits. Empty if sampling is not required.
-    sample_indices: List[int]
+    sample_indices: list[int]
 
     @property
     def do_sample(self):
@@ -78,7 +78,7 @@ class SamplingMetadataCache:
     """Used to cache SamplingMetadata objects between scheduler iterations"""
 
     def __init__(self):
-        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
+        self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}
 
     def get_cached_seq_group_to_sample(self, num_seqs):
         if num_seqs not in self._seq_group_to_sample_cache:
@@ -130,9 +130,9 @@ def sample(logits):
 
     def __init__(
        self,
-        seq_groups: List[SequenceGroupToSample],
+        seq_groups: list[SequenceGroupToSample],
        selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
@@ -146,12 +146,12 @@ def __init__(
 
     @staticmethod
     def prepare(
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        seq_lens: List[int],
-        query_lens: List[int],
+        seq_group_metadata_list: list[SequenceGroupMetadata],
+        seq_lens: list[int],
+        query_lens: list[int],
         device: str,
         pin_memory: bool,
-        generators: Optional[Dict[str, torch.Generator]] = None,
+        generators: Optional[dict[str, torch.Generator]] = None,
         cache: Optional[SamplingMetadataCache] = None,
     ) -> "SamplingMetadata":
         (
@@ -195,16 +195,16 @@ def __repr__(self) -> str:
 
 
 def _prepare_seq_groups(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    seq_lens: List[int],
-    query_lens: List[int],
+    seq_group_metadata_list: list[SequenceGroupMetadata],
+    seq_lens: list[int],
+    query_lens: list[int],
     device: str,
-    generators: Optional[Dict[str, torch.Generator]] = None,
+    generators: Optional[dict[str, torch.Generator]] = None,
     cache: Optional[SamplingMetadataCache] = None,
-) -> Tuple[
-        List[SequenceGroupToSample],
-        List[int],
-        Dict[SamplingType, List[int]],
+) -> tuple[
+        list[SequenceGroupToSample],
+        list[int],
+        dict[SamplingType, list[int]],
         int,
 ]:
     """Prepare sequence groups and indices for sampling.
@@ -227,17 +227,17 @@ def _prepare_seq_groups(
     num_prompts: Total number of prompts from `seq_group_metadata_list`.
     """
     # Batched sequence groups for the current model forward stsep.
-    seq_groups: List[SequenceGroupToSample] = []
+    seq_groups: list[SequenceGroupToSample] = []
     # A list of token indices to sample/compute logprob. It is used to
     # prune the outcome logits from the model for the performance.
-    selected_token_indices: List[int] = []
+    selected_token_indices: list[int] = []
     # Used for selected_token_indices.
     model_output_idx = 0
 
     # Sampling type -> (
     # indices to sample/prompt logprob within pruned output logits,
     # indices to sample within pruned logits)
-    categorized_sample_indices: Dict[SamplingType, List[int]] = {
+    categorized_sample_indices: dict[SamplingType, list[int]] = {
         t: []
         for t in SamplingType
     }
@@ -265,9 +265,9 @@ def _prepare_seq_groups(
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
+        prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
                                              if cache is not None else [])
-        sample_indices: List[int] = (sample_obj.sample_indices
+        sample_indices: list[int] = (sample_obj.sample_indices
                                      if cache is not None else [])
         do_sample = seq_group_metadata.do_sample
 
@@ -389,16 +389,16 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool]:
-        prompt_tokens: List[array] = []
-        output_tokens: List[array] = []
-        top_ks: List[int] = []
-        temperatures: List[float] = []
-        top_ps: List[float] = []
-        min_ps: List[float] = []
-        presence_penalties: List[float] = []
-        frequency_penalties: List[float] = []
-        repetition_penalties: List[float] = []
+    ) -> tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: list[array] = []
+        output_tokens: list[array] = []
+        top_ks: list[int] = []
+        temperatures: list[float] = []
+        top_ps: list[float] = []
+        min_ps: list[float] = []
+        presence_penalties: list[float] = []
+        frequency_penalties: list[float] = []
+        repetition_penalties: list[float] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
@@ -496,15 +496,15 @@ def from_sampling_metadata(
     @classmethod
     def from_lists(
         cls,
-        temperatures: List[float],
-        top_ps: List[float],
-        top_ks: List[int],
-        min_ps: List[float],
-        presence_penalties: List[float],
-        frequency_penalties: List[float],
-        repetition_penalties: List[float],
-        prompt_tokens: List[array],
-        output_tokens: List[array],
+        temperatures: list[float],
+        top_ps: list[float],
+        top_ks: list[int],
+        min_ps: list[float],
+        presence_penalties: list[float],
+        frequency_penalties: list[float],
+        repetition_penalties: list[float],
+        prompt_tokens: list[array],
+        output_tokens: list[array],
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
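This file is the largest in the commit and converts both dataclass field annotations (for example seq_ids: list[int] on SequenceGroupToSample) and local-variable annotations (for example categorized_sample_indices: dict[SamplingType, list[int]]). Both forms behave exactly like the old typing aliases at runtime; a compact sketch with stand-in classes:

    from dataclasses import dataclass, field
    from enum import Enum

    class SamplingTypeSketch(Enum):
        GREEDY = 0
        RANDOM = 1

    @dataclass
    class SequenceGroupSketch:
        # Dataclass fields annotated with built-in generics, as in the diff.
        seq_ids: list[int]
        sample_indices: list[int] = field(default_factory=list)

    # Local-variable annotation mirroring categorized_sample_indices.
    categorized: dict[SamplingTypeSketch, list[int]] = {
        t: []
        for t in SamplingTypeSketch
    }
    categorized[SamplingTypeSketch.GREEDY].append(0)

    group = SequenceGroupSketch(seq_ids=[1, 2, 3])
    assert group.sample_indices == []
    assert categorized[SamplingTypeSketch.GREEDY] == [0]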

vllm/model_executor/utils.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Utils for model executor."""
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import torch
 
@@ -12,7 +12,7 @@ def set_random_seed(seed: int) -> None:
 
 def set_weight_attrs(
     weight: torch.Tensor,
-    weight_attrs: Optional[Dict[str, Any]],
+    weight_attrs: Optional[dict[str, Any]],
 ):
     """Set attributes on a weight tensor.
 