from tinygrad import Tensor, dtypes
- from tinygrad.helpers import DEBUG, getenv
+ from tinygrad.helpers import DEBUG, getenv, prod
+ TORCH_DEBUG = getenv("TORCH_DEBUG")
import torch, pathlib
torch.autograd.grad_mode.set_multithreading_enabled(False)
@@ -46,31 +47,44 @@ def masked_select(self, mask):
  return wrap(Tensor(self.cpu().numpy()[mask.cpu().numpy()]))

@torch.library.impl("aten::as_strided", "privateuseone")
- def as_strided(tensor, size, stride, storage_offset=None):
-   if size == [] and storage_offset is not None:
-     # TODO: is this right?
-     return wrap(unwrap(tensor).flatten()[storage_offset:storage_offset+1].reshape(()))
-   # broadcast
-   if len(tensor.shape) == 0: return wrap(unwrap(tensor).reshape((1,)*len(size)).expand(size))
-   print("******* NOTE: this as_strided is wrong ***********\n", tensor.shape, size, stride, storage_offset)
-   return wrap(Tensor.zeros(*size))
+ def as_strided(tensor:torch.Tensor, size, stride, storage_offset=None):
+   #return tensor.cpu().as_strided(size, stride).tiny()
+   if TORCH_DEBUG >= 1: print("** NOTE: this as_strided is wrong", tensor.shape, size, stride, storage_offset)
+
+   if tuple(x for x in tensor.shape if x != 1) == tuple(x for x in size if x != 1):
+     # this is squeeze/unsqueeze
+     return tensor.reshape(size)
+
+   # TODO: how do i know this is permute?
+   if tensor.shape == (1000, 512) and size == [512, 1000] and stride == [0, 1]:
+     return wrap(unwrap(tensor).permute(1,0))
+
+   #print(tensor.cpu().numpy())
  raise NotImplementedError("fix as_strided")

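# ---- editor's note, not part of the diff ------------------------------------
# The handler above only covers the squeeze/unsqueeze case and one hard-coded
# permute. For reference, a general (if slow) as_strided can be emulated by
# gathering from the flattened base tensor, since the element at output index
# (i0, i1, ...) lives at storage_offset + sum(ik * stride[k]). The helper below
# is a hypothetical sketch (the name and the numpy-based index math are the
# editor's, not tinygrad's) and operates on the unwrapped tinygrad Tensor:
import numpy as np
def _as_strided_gather(t:Tensor, size, stride, storage_offset=None):
  idx = np.full(size, storage_offset or 0, dtype=np.int64)  # flat index of every output element
  for k, (sz, st) in enumerate(zip(size, stride)):
    shape = [1]*len(size); shape[k] = sz
    idx += np.arange(sz, dtype=np.int64).reshape(shape) * st
  return t.flatten()[Tensor(idx.reshape(-1).tolist())].reshape(tuple(size))
# ------------------------------------------------------------------------------
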
@torch.library.impl("aten::empty_strided", "privateuseone")
- def empty_strided(size, stride, dtype, layout, device, pin_memory=False):
-   if DEBUG >= 2: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
+ def empty_strided(size, stride, dtype, layout=None, device=None, pin_memory=False):
+   if TORCH_DEBUG: print(f"empty_strided {size=} {stride=} {dtype=} {layout=} {device=} {pin_memory=}")
  ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
  return wrap(ret)

@torch.library.impl("aten::empty.memory_format", "privateuseone")
def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=False, memory_format=None):
-   if DEBUG >= 2: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
-   ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype])
+   if TORCH_DEBUG: print(f"empty.memory_format {size=} {dtype=} {layout=} {device=} {pin_memory=} {memory_format=}")
+   ret = Tensor.empty(*size, dtype=torch_to_tiny_dtype[dtype or torch.get_default_dtype()])
  return wrap(ret)

+ @torch.library.impl("aten::max_pool2d_with_indices", "privateuseone")
+ def max_pool2d_with_indices(self:Tensor, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
+   # TODO: support return_indices in tinygrad
+   ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode)
+   # TODO: this is wrong
+   return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
+
@torch.library.impl("aten::convolution_overrideable", "privateuseone")
def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
-   #print(f"{input.shape=} {weight.shape=} {bias.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
+   if TORCH_DEBUG >= 1:
+     print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
  return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
                                   groups=groups, stride=stride, dilation=dilation, padding=padding))
  #raise NotImplementedError("need convolution")
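
# ---- editor's note, not part of the diff ------------------------------------
# convolution_overrideable above accepts `transposed` and `output_padding` but
# never looks at them, so a transposed convolution would silently run as a
# normal one. A hedged sketch of what the missing branch could look like,
# assuming tinygrad's Tensor.conv_transpose2d keyword arguments:
def _convolution_transposed_sketch(input, weight, bias, stride, padding, dilation, output_padding, groups):
  return wrap(unwrap(input).conv_transpose2d(unwrap(weight), unwrap(bias) if bias is not None else None,
    groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding))
# ------------------------------------------------------------------------------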
@@ -94,45 +108,89 @@ def cat_out(tensors, out, dim=0): unwrap(out).replace(Tensor.cat(*[unwrap(x) for
@torch.library.impl("aten::index.Tensor", "privateuseone")
def index_tensor(x, y): return wrap(unwrap(x)[y[0].tolist()])

+ # register some decompositions
+ from torch._decomp import get_decompositions
+ aten = torch.ops.aten
+ decomps = [
+   aten.native_batch_norm,
+   aten.addmm,
+   # NOTE: many of these don't work or cause infinite loops
+   #aten.var_mean,
+   #aten.var,
+   #aten.rsqrt,
+   #aten.max_pool2d_with_indices,
+ ]
+ for k,v in get_decompositions(decomps).items():
+   key = str(k._schema).split("(")[0]
+   if TORCH_DEBUG >= 2: print("register decomp for", k)
+   torch.library.impl(key, "privateuseone")(v)
+
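# ---- editor's note, not part of the diff ------------------------------------
# get_decompositions returns {OpOverload: python impl}; registering those impls
# for "privateuseone" lets composite ops be re-expressed in terms of primitives
# the table below already handles (roughly, aten.addmm(input, mat1, mat2)
# becomes beta*input + alpha*(mat1 @ mat2), so only mm/mul/add need native
# handlers). A quick way to see which schema keys the loop above registers:
for _op in get_decompositions(decomps):
  print(str(_op._schema).split("(")[0])   # e.g. "aten::addmm"
# ------------------------------------------------------------------------------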

tiny_backend = {
  "aten.view": Tensor.reshape,
  "aten.add.Tensor": Tensor.add,
  "aten.sub.Tensor": Tensor.sub,
  "aten.mul.Tensor": Tensor.mul,
  "aten.div.Tensor": Tensor.div,
- "aten.add_.Tensor" : lambda x ,y : x .assign (x .add (y )),
134
+ "aten.add_.Tensor" : lambda x ,y , alpha = 1 : x .assign (x .add (y )* alpha ),
  "aten.pow.Tensor_Scalar": Tensor.pow,
  "aten.bitwise_and.Tensor": Tensor.bitwise_and,
  "aten.eq.Tensor": Tensor.eq, "aten.eq.Scalar": Tensor.eq,
  "aten.ne.Tensor": Tensor.ne, "aten.ne.Scalar": Tensor.ne,
  "aten.gt.Tensor": Tensor.__gt__, "aten.gt.Scalar": Tensor.__gt__,
  "aten.lt.Tensor": Tensor.__lt__, "aten.lt.Scalar": Tensor.__lt__,
+   "aten.le.Tensor": Tensor.__le__, "aten.le.Scalar": Tensor.__le__,
+   "aten.abs": Tensor.abs,
+   "aten.exp": Tensor.exp,
  "aten.exp2": Tensor.exp2,
  "aten.min": Tensor.min,
  "aten.max": Tensor.max,
  "aten.relu": Tensor.relu,
+   "aten.relu_": lambda x: x.assign(x.relu()),
  "aten.mean": Tensor.mean,
+   "aten.mean.dim": Tensor.mean,
  "aten.neg": Tensor.neg,
+   "aten.reciprocal": Tensor.reciprocal,
+   "aten.sqrt": Tensor.sqrt,
+   "aten.rsqrt": Tensor.rsqrt,
  "aten.mm": Tensor.matmul,
+   "aten.var.correction": Tensor.var,
+   # TODO: support var_mean in tinygrad
+   "aten.var_mean.correction": lambda self, dims, keepdim=False, correction=1: (self.var(dims, keepdim, correction), self.mean(dims, keepdim)),
+   # NOTE: axis=[] in torch means all, change tinygrad?
+   "aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None:
+     out.replace(Tensor.sum(self, axis if len(axis) else None, keepdim), allow_shape_mismatch=True),
+   "aten.argmax": Tensor.argmax,
+   "aten.scatter.value": Tensor.scatter,
+   "aten.gather": Tensor.gather,
+   "aten.where.self": Tensor.where,
+ "aten._log_softmax" : lambda self ,dim ,half_to_float : self .softmax (dim ),
+   "aten.random_": lambda self:
+     self.assign(Tensor.randint(*self.shape, low=dtypes.min(self.dtype), high=dtypes.max(self.dtype), device=self.device, dtype=self.dtype)),
+   "aten.uniform_": lambda self, low=0, high=1: self.assign(Tensor.uniform(*self.shape, low=low, high=high)),
+   "aten.normal_": lambda self, mean=0, std=1: self.assign(Tensor.normal(*self.shape, mean=mean, std=std)),
}

- # there's earlier things to hook here
+ # NOTE: there's earlier things to hook these, so the .out form isn't needed
#"aten.add.out": lambda x,y,out: out.replace(x+y, allow_shape_mismatch=True),
#"aten.abs.out": lambda x,out: out.replace(x.abs(), allow_shape_mismatch=True),
#"aten.ceil.out": lambda x,out: out.replace(x.ceil(), allow_shape_mismatch=True),
#"aten.exp2.out": lambda x,out: out.replace(x.exp2(), allow_shape_mismatch=True),

def wrap_fxn(k,f):
  def nf(*args, **kwargs):
-     #print(k, len(args), kwargs.keys())
+     if TORCH_DEBUG: print(k, len(args), [x.shape if isinstance(x, torch.Tensor) else x for x in args],
+       {k:v.shape if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()})
    args = [unwrap(x) if isinstance(x, torch.Tensor) else x for x in args]
    kwargs = {k:unwrap(v) if isinstance(v, torch.Tensor) else v for k,v in kwargs.items()}
-     return wrap(f(*args, **kwargs))
+     out = f(*args, **kwargs)
+     if isinstance(out, Tensor): return wrap(out)
+     elif isinstance(out, tuple): return tuple(wrap(x) for x in out)
+     else: raise RuntimeError(f"unknown output type {type(out)}")
  return nf

for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::"), "privateuseone")(wrap_fxn(k,v))
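
# ---- editor's note, not part of the diff ------------------------------------
# A minimal smoke test of the table registered above, assuming the rest of this
# file sets up the PrivateUse1 device (the tinygrad backend exposes it under a
# name like "tiny"; that registration is outside these hunks):
def _smoke_test():
  a = torch.empty(4, device="tiny").uniform_()  # aten.empty.memory_format, then aten.uniform_
  b = (a + a).relu()                            # aten.add.Tensor -> Tensor.add, aten.relu -> Tensor.relu
  print(b.cpu().numpy())
# ------------------------------------------------------------------------------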

- if getenv("TORCH_DEBUG"):
+ if TORCH_DEBUG:
  from torch.utils._python_dispatch import TorchDispatchMode
  class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args, kwargs=None):