20
20
# helpers
21
21
22
22
23
- def to_fp8 (tensor : torch .tensor ) -> torch .tensor :
23
+ def to_fp8 (tensor : torch .Tensor ) -> torch .Tensor :
24
24
finfo = torch .finfo (torch .float8_e4m3fn )
25
25
return torch .round (tensor .clamp (
26
26
min = finfo .min , max = finfo .max )).to (dtype = torch .float8_e4m3fn )
27
27
28
28
29
- def to_int8 (tensor : torch .tensor ) -> torch .tensor :
29
+ def to_int8 (tensor : torch .Tensor ) -> torch .Tensor :
30
30
return torch .round (tensor .clamp (min = - 128 , max = 127 )).to (dtype = torch .int8 )
31
31
32
32
33
33
def make_rand_tensors (dtype : torch .dtype , m : int , n : int ,
34
- k : int ) -> Tuple [torch .tensor , torch .tensor ]:
34
+ k : int ) -> Tuple [torch .Tensor , torch .Tensor ]:
35
35
36
36
a = torch .randn ((m , k ), device = 'cuda' ) * 5
37
37
b = torch .randn ((n , k ), device = 'cuda' ).t () * 5
@@ -47,25 +47,25 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
47
47
# impl
48
48
49
49
50
- def pytorch_mm_impl (a : torch .tensor , b : torch .tensor , scale_a : torch .tensor ,
51
- scale_b : torch .tensor ,
52
- out_dtype : torch .dtype ) -> torch .tensor :
50
+ def pytorch_mm_impl (a : torch .Tensor , b : torch .Tensor , scale_a : torch .Tensor ,
51
+ scale_b : torch .Tensor ,
52
+ out_dtype : torch .dtype ) -> torch .Tensor :
53
53
return torch .mm (a , b )
54
54
55
55
56
- def pytorch_fp8_impl (a : torch .tensor , b : torch .tensor , scale_a : torch .tensor ,
57
- scale_b : torch .tensor ,
58
- out_dtype : torch .dtype ) -> torch .tensor :
56
+ def pytorch_fp8_impl (a : torch .Tensor , b : torch .Tensor , scale_a : torch .Tensor ,
57
+ scale_b : torch .Tensor ,
58
+ out_dtype : torch .dtype ) -> torch .Tensor :
59
59
return torch ._scaled_mm (a ,
60
60
b ,
61
61
scale_a = scale_a ,
62
62
scale_b = scale_b ,
63
63
out_dtype = out_dtype )
64
64
65
65
66
- def pytorch_fp8_impl_fast_accum (a : torch .tensor , b : torch .tensor ,
67
- scale_a : torch .tensor , scale_b : torch .tensor ,
68
- out_dtype : torch .dtype ) -> torch .tensor :
66
+ def pytorch_fp8_impl_fast_accum (a : torch .Tensor , b : torch .Tensor ,
67
+ scale_a : torch .Tensor , scale_b : torch .Tensor ,
68
+ out_dtype : torch .dtype ) -> torch .Tensor :
69
69
return torch ._scaled_mm (a ,
70
70
b ,
71
71
scale_a = scale_a ,
@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
74
74
use_fast_accum = True )
75
75
76
76
77
- def cutlass_impl (a : torch .tensor , b : torch .tensor , scale_a : torch .tensor ,
78
- scale_b : torch .tensor ,
79
- out_dtype : torch .dtype ) -> torch .tensor :
77
+ def cutlass_impl (a : torch .Tensor , b : torch .Tensor , scale_a : torch .Tensor ,
78
+ scale_b : torch .Tensor ,
79
+ out_dtype : torch .dtype ) -> torch .Tensor :
80
80
return ops .cutlass_scaled_mm (a , b , scale_a , scale_b , out_dtype = out_dtype )
81
81
82
82
83
83
# bench
84
- def bench_fn (a : torch .tensor , b : torch .tensor , scale_a : torch .tensor ,
85
- scale_b : torch .tensor , out_dtype : torch .dtype , label : str ,
84
+ def bench_fn (a : torch .Tensor , b : torch .Tensor , scale_a : torch .Tensor ,
85
+ scale_b : torch .Tensor , out_dtype : torch .dtype , label : str ,
86
86
sub_label : str , fn : Callable , description : str ) -> TMeasurement :
87
87
88
88
min_run_time = 1
0 commit comments