Commit 63b9c10

jimpang authored and WoosukKwon committed
[Misc] Use torch.Tensor for type annotation (vllm-project#6505)
1 parent a6500d2 commit 63b9c10
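
The commit message is terse, so a note on the motivation: torch.tensor (lowercase) is a factory function that builds a tensor from data, while torch.Tensor (capitalized) is the tensor class itself. Only the class is a valid type annotation; static checkers such as mypy reject a function used as a type. A minimal sketch of the distinction (illustrative only, not part of the commit):

    import torch

    def scale(x: torch.Tensor) -> torch.Tensor:
        # Annotate with the class torch.Tensor; annotating with the factory
        # function torch.tensor draws a mypy error along the lines of
        # 'Function "torch.tensor" is not valid as a type'.
        return x * 2

    t = torch.tensor([1.0, 2.0])        # the function constructs a tensor
    assert isinstance(t, torch.Tensor)  # instances are of class torch.Tensor
    print(scale(t))                     # tensor([2., 4.])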

File tree

2 files changed (+18, -18 lines)


benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 17 additions & 17 deletions
@@ -20,18 +20,18 @@
 # helpers
 
 
-def to_fp8(tensor: torch.tensor) -> torch.tensor:
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     return torch.round(tensor.clamp(
         min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
 
 
-def to_int8(tensor: torch.tensor) -> torch.tensor:
+def to_int8(tensor: torch.Tensor) -> torch.Tensor:
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 
 def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
-                      k: int) -> Tuple[torch.tensor, torch.tensor]:
+                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     a = torch.randn((m, k), device='cuda') * 5
     b = torch.randn((n, k), device='cuda').t() * 5
@@ -47,25 +47,25 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 # impl
 
 
-def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                    scale_b: torch.tensor,
-                    out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                    scale_b: torch.Tensor,
+                    out_dtype: torch.dtype) -> torch.Tensor:
     return torch.mm(a, b)
 
 
-def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                     scale_b: torch.tensor,
-                     out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
                             scale_b=scale_b,
                             out_dtype=out_dtype)
 
 
-def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
-                                scale_a: torch.tensor, scale_b: torch.tensor,
-                                out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
+                                scale_a: torch.Tensor, scale_b: torch.Tensor,
+                                out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
                             use_fast_accum=True)
 
 
-def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                 scale_b: torch.tensor,
-                 out_dtype: torch.dtype) -> torch.tensor:
+def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                 scale_b: torch.Tensor,
+                 out_dtype: torch.dtype) -> torch.Tensor:
     return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
 
 
 # bench
-def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-             scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
+def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+             scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
              sub_label: str, fn: Callable, description: str) -> TMeasurement:
 
     min_run_time = 1
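
As context for the annotations above, a quick standalone sketch of how the int8 helper behaves; it mirrors the benchmark's to_int8 but runs on CPU (the benchmark itself allocates on 'cuda'):

    import torch

    def to_int8(tensor: torch.Tensor) -> torch.Tensor:
        # Clamp to the int8 range, round, then cast, as in the benchmark helper.
        return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)

    x = torch.randn(4, 4) * 100          # many values fall outside [-128, 127]
    q = to_int8(x)
    print(q.dtype)                       # torch.int8
    print(int(q.min()), int(q.max()))    # bounded by -128 and 127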

vllm/worker/worker.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def __init__(
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
         # Initialize gpu_cache as embedding models don't initialize kv_caches
-        self.gpu_cache: Optional[List[List[torch.tensor]]] = None
+        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
 
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
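
One subtlety about the worker.py fix: the old annotation never failed at runtime, because the typing module accepts any callable as a type argument, so the mistake only surfaces under static analysis. A small sketch of that behavior (the mypy diagnostic wording is an assumption, not from the commit):

    from typing import List, Optional

    import torch

    # Both lines evaluate without error at runtime; only a static
    # checker such as mypy flags the first one.
    bad: Optional[List[List[torch.tensor]]] = None   # mypy: function is not valid as a type
    good: Optional[List[List[torch.Tensor]]] = None  # accepted

    print(bad, good)  # None None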

0 commit comments
