from tinygrad import Tensor, dtypes
-from tinygrad.helpers import DEBUG
+from tinygrad.helpers import DEBUG, getenv
import torch, pathlib
+torch.autograd.grad_mode.set_multithreading_enabled(False)
+
+# https://pytorch.org/docs/stable/torch.compiler_ir.html

# TODO: don't replicate this in cpp
torch_to_tiny_dtype = {
  torch.float32: dtypes.float32,
  torch.float64: dtypes.float64,
+  torch.uint8: dtypes.uint8,
+  torch.int8: dtypes.int8,
  torch.int32: dtypes.int32,
  torch.int64: dtypes.int64,
  torch.bool: dtypes.bool,
}

import torch.utils.cpp_extension
mod = torch.utils.cpp_extension.load(name="custom_device_extension", sources=[pathlib.Path(__file__).parent / "wrapped_tensor.cpp"])
-wrap, unwrap = mod.wrap, mod.unwrap
+def wrap(x:Tensor) -> torch.Tensor: return mod.wrap(x)
+def unwrap(x:torch.Tensor) -> Tensor:
+  assert isinstance(x, torch.Tensor), f"x isn't {type(x)}"
+  return mod.unwrap(x)
class TinyBackend: pass
torch.utils.rename_privateuse1_backend("tiny")
torch._register_device_module("tiny", TinyBackend)
torch.utils.generate_methods_for_privateuse1_backend()

-@torch.library.impl("aten::view", "privateuseone")
-def view(x, sz): return mod.wrap(mod.unwrap(x).reshape(sz))
-
-@torch.library.impl("aten::min", "privateuseone")
-def min(x): return mod.wrap(mod.unwrap(x).min())
-
-@torch.library.impl("aten::max", "privateuseone")
-def max(x): return mod.wrap(mod.unwrap(x).max())
-
@torch.library.impl("aten::zero_", "privateuseone")
def zero_(x):
-  tt = mod.unwrap(x)
+  tt = unwrap(x)
  tt.replace(tt.zeros_like())

@torch.library.impl("aten::fill_.Scalar", "privateuseone")
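The wrap/unwrap helpers introduced above are the bridge between torch.Tensor objects on the renamed "tiny" (PrivateUse1) device and the underlying tinygrad Tensors. A minimal round-trip sketch of how they are expected to behave, assuming the wrapped_tensor.cpp extension builds and mod.wrap places its result on the PrivateUse1 device (values are illustrative):

    t = Tensor([1.0, 2.0, 3.0])   # a tinygrad Tensor
    tt = wrap(t)                  # torch.Tensor living on the "tiny" device
    assert isinstance(tt, torch.Tensor) and tt.device.type == "tiny"
    assert isinstance(unwrap(tt), Tensor)  # unwrap recovers a tinygrad Tensor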
@@ -51,11 +50,14 @@ def as_strided(tensor, size, stride, storage_offset=None):
  if size == [] and storage_offset is not None:
    # TODO: is this right?
    return wrap(unwrap(tensor).flatten()[storage_offset:storage_offset+1].reshape(()))
-  print(tensor.shape, size, stride, storage_offset)
+  # broadcast
+  if len(tensor.shape) == 0: return wrap(unwrap(tensor).reshape((1,)*len(size)).expand(size))
+  print("******* NOTE: this as_strided is wrong ***********\n", tensor.shape, size, stride, storage_offset)
+  return wrap(Tensor.zeros(*size))
  raise NotImplementedError("fix as_strided")

@torch.library.impl("aten::empty_strided", "privateuseone")
-def empty_strided(size, stride, dtype, layout, device, pin_memory):
+def empty_strided(size, stride, dtype, layout, device, pin_memory=False):
  if DEBUG >= 2: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
  return wrap(ret)
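The new broadcast branch in as_strided maps a 0-d tensor to the requested size with a reshape/expand on the tinygrad side. A small sketch of the equivalent tinygrad call, with a made-up target size:

    t = Tensor(2.0)                          # 0-d tinygrad tensor
    out = t.reshape((1, 1)).expand((3, 4))   # what the handler does for size=(3, 4)
    assert out.shape == (3, 4)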
@@ -68,49 +70,73 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F

@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
-  print(input, weight, bias)
-  raise NotImplementedError
+  #print(f"{input.shape=} {weight.shape=} {bias.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
+  return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
+                                   groups=groups, stride=stride, dilation=dilation, padding=padding))
+  #raise NotImplementedError("need convolution")

@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src, dest):
  if str(src.device) == "tiny" and str(dest.device) == "tiny":
    unwrap(dest).replace(unwrap(src), allow_shape_mismatch=True)
  elif str(src.device) == "tiny" and str(dest.device) == "cpu":
-    dest[:] = torch.from_numpy(unwrap(src).numpy())
+    # TODO: is there a better way?
+    dest.resize_(src.numel()).resize_(src.shape)
+    dest.copy_(torch.from_numpy(unwrap(src).numpy()))
  elif str(src.device) == "cpu" and str(dest.device) == "tiny":
    unwrap(dest).assign(Tensor(src.numpy()))
  else:
    raise NotImplementedError(f"can't copy from {src.device} -> {dest.device}")

-@torch.library.impl("aten::exp2.out", "privateuseone")
-def exp2_out(x, out): unwrap(out).replace(unwrap(x).exp2(), allow_shape_mismatch=True)
-
-@torch.library.impl("aten::ceil.out", "privateuseone")
-def ceil_out(x, out): unwrap(out).replace(unwrap(x).ceil(), allow_shape_mismatch=True)
-
-@torch.library.impl("aten::abs.out", "privateuseone")
-def abs_out(x, out): unwrap(out).replace(unwrap(x).abs(), allow_shape_mismatch=True)
-
-@torch.library.impl("aten::bitwise_and.Tensor", "privateuseone")
-def bitwise_and_tensor(x, y): return wrap(unwrap(x) & unwrap(y))
-
-@torch.library.impl("aten::add.Tensor", "privateuseone")
-def add_tensor(x, y): return wrap(unwrap(x) + unwrap(y))
-
-@torch.library.impl("aten::mul.Tensor", "privateuseone")
-def mul_tensor(x, y): return wrap(unwrap(x) * unwrap(y))
-
-@torch.library.impl("aten::div.Tensor", "privateuseone")
-def div_tensor(x, y): return wrap(unwrap(x) / unwrap(y))
-
-@torch.library.impl("aten::eq.Tensor", "privateuseone")
-def eq_tensor(x, y): return wrap(unwrap(x).eq(unwrap(y)))
-
-@torch.library.impl("aten::ne.Tensor", "privateuseone")
-def ne_tensor(x, y): return wrap(unwrap(x).ne(unwrap(y)))
-
-@torch.library.impl("aten::ne.Scalar", "privateuseone")
-def ne_scalar(x, y): return wrap(unwrap(x).ne(y))
+@torch.library.impl("aten::cat.out", "privateuseone")
+def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for x in tensors], dim=dim), allow_shape_mismatch=True)
+
+@torch.library.impl("aten::index.Tensor", "privateuseone")
+def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])
+
+tiny_backend = {
+  "aten.view": Tensor.reshape,
+  "aten.add.Tensor": Tensor.add,
+  "aten.sub.Tensor": Tensor.sub,
+  "aten.mul.Tensor": Tensor.mul,
+  "aten.div.Tensor": Tensor.div,
+  "aten.add_.Tensor": lambda x,y: x.assign(x.add(y)),
+  "aten.pow.Tensor_Scalar": Tensor.pow,
+  "aten.bitwise_and.Tensor": Tensor.bitwise_and,
+  "aten.eq.Tensor": Tensor.eq, "aten.eq.Scalar": Tensor.eq,
+  "aten.ne.Tensor": Tensor.ne, "aten.ne.Scalar": Tensor.ne,
+  "aten.gt.Tensor": Tensor.__gt__, "aten.gt.Scalar": Tensor.__gt__,
+  "aten.lt.Tensor": Tensor.__lt__, "aten.lt.Scalar": Tensor.__lt__,
+  "aten.exp2": Tensor.exp2,
+  "aten.min": Tensor.min,
+  "aten.max": Tensor.max,
+  "aten.relu": Tensor.relu,
+  "aten.mean": Tensor.mean,
+  "aten.neg": Tensor.neg,
+  "aten.mm": Tensor.matmul,
+}

-@torch.library.impl("aten::gt.Scalar", "privateuseone")
-def gt_scalar(x, y): return wrap(unwrap(x) > y)
+# there's earlier things to hook here
+#"aten.add.out": lambda x,y,out: out.replace(x+y, allow_shape_mismatch=True),
+#"aten.abs.out": lambda x,out: out.replace(x.abs(), allow_shape_mismatch=True),
+#"aten.ceil.out": lambda x,out: out.replace(x.ceil(), allow_shape_mismatch=True),
+#"aten.exp2.out": lambda x,out: out.replace(x.exp2(), allow_shape_mismatch=True),
+
+def wrap_fxn(k,f):
+  def nf(*args, **kwargs):
+    #print(k, len(args), kwargs.keys())
+    args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
+    kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
+    return wrap(f(*args, **kwargs))
+  return nf
+
+for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
+
+if getenv("TORCH_DEBUG"):
+  from torch.utils._python_dispatch import TorchDispatchMode
+  class DispatchLog(TorchDispatchMode):
+    def __torch_dispatch__(self, func, types, args, kwargs=None):
+      #print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
+      print(f"Dispatch Log: {func}")
+      return func(*args, **(kwargs or {}))
+  DispatchLog().__enter__()
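With these registrations in place, a torch program can target the backend purely by device name. A rough usage sketch, assuming this file is importable (the module name below is hypothetical) and the C++ extension builds; setting TORCH_DEBUG=1 in the environment additionally logs every dispatched op via DispatchLog:

    import torch
    import backend                        # hypothetical: import this file under whatever name it has
    a = torch.ones(4, 4, device="tiny")   # allocation and fill go through the registered aten ops
    b = torch.ones(4, 4, device="tiny")
    c = (a @ b).relu()                    # dispatches roughly to aten.mm and aten.relu from tiny_backend
    print(c.cpu().numpy())                # aten::_copy_from copies the result back to CPU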