@@ -31,6 +31,7 @@ def device_count(self): return getenv("GPUS", 1) # TODO: device count in tiny?
31
31
torch .utils .rename_privateuse1_backend ("tiny" )
32
32
torch ._register_device_module ("tiny" , TinyBackend ())
33
33
torch .utils .generate_methods_for_privateuse1_backend ()
34
+ aten = torch .ops .aten
34
35
35
36
# in place operations with views
36
37
def is_view (self : torch .Tensor ) -> bool : return getattr (self , "_base" , None ) is not None
@@ -75,9 +76,37 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
75
76
def index_tensor (x , y ):
76
77
return aten .index (x .cpu (), [z .cpu () if isinstance (z , torch .Tensor ) else None for z in y ]).to (x .device )
77
78
79
+ @torch .library .impl ("aten::index_put" , "privateuseone" )
80
+ def index_put (self , indices , values , accumulate = False ):
81
+ return aten .index_put (self .cpu (), [z .cpu () if isinstance (z , torch .Tensor ) else None for z in indices ], values .cpu (), accumulate ).tiny ()
82
+
78
83
@torch .library .impl ("aten::randperm.generator_out" , "privateuseone" )
79
84
def randperm_generator (n , generator = None , out = None ): out .copy_ (torch .randperm (n , generator = generator , device = "cpu" ).tiny ())
80
85
86
+ @torch .library .impl ("aten::cumprod" , "privateuseone" )
87
+ # TODO: move to tinygrad
88
+ def cumprod (self , dim , dtype = None ): return aten .cumprod (self .cpu (), dim , dtype = dtype ).tiny ()
89
+
90
+ @torch .library .impl ("aten::cummax" , "privateuseone" )
91
+ def cummax (self , dim ):
92
+ # TODO: support cummax with indices to match torch
93
+ cummax , indices = aten .cummax (self .cpu (), dim )
94
+ return (cummax .tiny (), indices .tiny ())
95
+
96
+ @torch .library .impl ("aten::nonzero" , "privateuseone" )
97
+ # TODO: move to tinygrad
98
+ def nonzero (self ): return aten .nonzero (self .cpu ()).tiny ()
99
+
100
+ def upsample_backward (grad_out , output_size , input_size , * args , f = None ): return f (grad_out .cpu (), output_size , input_size , * args ).tiny ()
101
+
102
+ for i in [
103
+ "upsample_linear1d_backward" , "upsample_nearest1d_backward" , "_upsample_nearest_exact1d_backward" ,
104
+ "upsample_nearest2d_backward" , "_upsample_nearest_exact2d_backward" ,
105
+ "upsample_nearest3d_backward" , "_upsample_nearest_exact3d_backward" ,
106
+ "upsample_trilinear3d_backward" , "upsample_bilinear2d_backward"
107
+ ]:
108
+ torch .library .impl (f"aten::{ i } " , "privateuseone" )(functools .partial (upsample_backward , f = getattr (aten , i )))
109
+
81
110
# *** end bad functions on CPU ***
82
111
83
112
@torch .library .impl ("aten::zero_" , "privateuseone" )
@@ -162,24 +191,58 @@ def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None
162
191
def convolution_overrideable (input , weight , bias , stride , padding , dilation , transposed , output_padding , groups ):
163
192
if TORCH_DEBUG >= 1 :
164
193
print (f"convolution { input .shape = } { weight .shape = } { stride = } { padding = } { dilation = } { transposed = } { output_padding = } { groups = } " )
165
- return wrap (unwrap (input ).conv2d (unwrap (weight ), unwrap (bias ) if bias is not None else None ,
166
- groups = groups , stride = stride , dilation = dilation , padding = padding ))
194
+ input , weight , bias = unwrap (input ), unwrap (weight ), unwrap (bias ) if bias is not None else None
195
+ if not transposed : return wrap (input .conv2d (weight , bias , groups = groups , stride = stride , dilation = dilation , padding = padding ))
196
+ return wrap (input .conv_transpose2d (weight , bias , groups = groups , stride = stride , dilation = dilation , padding = padding , output_padding = output_padding ))
167
197
168
198
@torch .library .impl ("aten::convolution_backward_overrideable" , "privateuseone" )
169
199
def convolution_backward_overrideable (grad_out , input , weight , stride , padding , dilation , transposed , output_padding , groups , output_mask ):
170
200
if TORCH_DEBUG >= 1 :
171
201
print (f"convolution_backward { input .shape = } { weight .shape = } { stride = } { padding = } { dilation = } { transposed = } { output_padding = } { groups = } " )
172
202
grad_out , input , weight , bias = unwrap (grad_out ), unwrap (input ), unwrap (weight ), Tensor .zeros (weight .shape [0 ], device = _from_torch_device (weight .device ))
173
- out = Tensor .conv2d (input , weight , bias , groups = groups , stride = stride , dilation = dilation , padding = padding )
203
+ if not transposed : out = Tensor .conv2d (input , weight , bias , groups = groups , stride = stride , dilation = dilation , padding = padding )
204
+ else :
205
+ bias = Tensor .zeros (weight .shape [1 ] * groups )
206
+ out = Tensor .conv_transpose2d (input , weight , bias , groups = groups , stride = stride , dilation = dilation , padding = padding , output_padding = output_padding )
174
207
grads = out .gradient (* [t for t ,m in zip ([input , weight , bias ], output_mask ) if m ], gradient = grad_out )
175
208
return tuple ([wrap (grads .pop (0 )) if m else None for m in output_mask ])
176
209
210
+ def avg_pool (self , kernel_size , stride = [], padding = 0 , ceil_mode = False , count_include_pad = True , divisor_override = None ):
211
+ return wrap (unwrap (self ).avg_pool2d (kernel_size , stride if stride != [] else None , padding = padding , ceil_mode = ceil_mode , count_include_pad = count_include_pad ))
212
+
213
+ def avg_pool_backward (grad_out , self , kernel_size , stride = None , padding = 0 , ceil_mode = False , count_include_pad = True , divisor_override = None ):
214
+ self , grad_out = unwrap (self ), unwrap (grad_out )
215
+ out = Tensor .avg_pool2d (self , kernel_size , stride if stride != [] else None , dilation = 1 , padding = padding , ceil_mode = ceil_mode , count_include_pad = count_include_pad )
216
+ return wrap (out .gradient (self , gradient = grad_out )[0 ])
217
+
218
+ for dim in [2 , 3 ]:
219
+ torch .library .impl (f"aten::avg_pool{ dim } d" , "privateuseone" )(avg_pool )
220
+ torch .library .impl (f"aten::avg_pool{ dim } d_backward" , "privateuseone" )(avg_pool_backward )
221
+
222
+ def pad_forward (self , padding , mode = None ): return wrap (Tensor .pad (unwrap (self ), padding , mode = mode ))
223
+
224
+ def pad_backward (grad_out , self , padding , mode ):
225
+ self , grad_out = unwrap (self ), unwrap (grad_out )
226
+ out = Tensor .pad (self , padding , mode = mode )
227
+ return wrap (out .gradient (self , gradient = grad_out )[0 ])
228
+
229
+ for dim in [1 , 2 , 3 ]:
230
+ for pad_type , mode in [("replication" , "replicate" ), ("reflection" , "reflect" )]:
231
+ torch .library .impl (f"aten::{ pad_type } _pad{ dim } d" , "privateuseone" )(functools .partial (pad_forward , mode = mode ))
232
+ torch .library .impl (f"aten::{ pad_type } _pad{ dim } d_backward" , "privateuseone" )(functools .partial (pad_backward , mode = mode ))
233
+
177
234
def upsample (self , size , align_corners = False , mode = None ): return wrap (Tensor .interpolate (unwrap (self ), size , mode = mode , align_corners = align_corners ))
178
235
for i ,pre in enumerate (["" , "bi" , "tri" ]):
179
236
torch .library .impl (f"aten::upsample_{ pre } linear{ i + 1 } d" , "privateuseone" )(functools .partial (upsample , mode = "linear" ))
180
237
torch .library .impl (f"aten::upsample_nearest{ i + 1 } d" , "privateuseone" )(functools .partial (upsample , mode = "nearest" ))
181
238
torch .library .impl (f"aten::_upsample_nearest_exact{ i + 1 } d" , "privateuseone" )(functools .partial (upsample , mode = "nearest-exact" ))
182
239
240
+ @torch .library .impl ("aten::scatter_add.out" , "privateuseone" )
241
+ def scatter_add (self , dim , index , src , out ):
242
+ self , index , src , out = unwrap (self ), unwrap (index ), unwrap (src ), unwrap (out )
243
+ if self .shape == (): return wrap (out .assign (src ))
244
+ return wrap (out .assign (Tensor .scatter_reduce (self , dim , index , src , reduce = 'sum' )))
245
+
183
246
@torch .library .impl ("aten::_copy_from" , "privateuseone" )
184
247
def _copy_from (src : torch .Tensor , dest , non_blocking = False ):
185
248
realize = dest .is_tiny and maybe_realize_storage (dest )
@@ -222,7 +285,6 @@ def sort_values(input, dim=-1, descending=False, stable=True, values=None, indic
222
285
223
286
# register some decompositions
224
287
from torch ._decomp import get_decompositions
225
- aten = torch .ops .aten
226
288
decomps = [
227
289
aten .native_batch_norm , aten .native_batch_norm_backward ,
228
290
aten .native_layer_norm_backward ,
@@ -344,7 +406,7 @@ def sort_values(input, dim=-1, descending=False, stable=True, values=None, indic
344
406
"aten.scatter.value_out" : Tensor .scatter ,
345
407
"aten.where.self_out" : Tensor .where ,
346
408
"aten.prod.int_out" : Tensor .prod ,
347
- "aten.scatter_add.out " : functools . partial ( Tensor .scatter_reduce , reduce = 'sum' ) ,
409
+ "aten.scatter.src_out " : Tensor .scatter ,
348
410
# NOTE: axis=[] in torch means all, change tinygrad?
349
411
"aten.sum.IntList_out" : lambda self ,axis ,keepdim = False ,dtype = None :
350
412
self .sum (axis if axis is None or len (axis ) else None , keepdim ,
@@ -408,9 +470,8 @@ def _wrap_out(*args, **kwargs):
408
470
"aten.logical_not" : Tensor .logical_not ,
409
471
"aten.logical_or_" : inplace_fn ("x" )(lambda x , y : x .assign (x | y )),
410
472
"aten.multinomial" : Tensor .multinomial ,
411
- "aten.pad" : Tensor .pad ,
412
- "aten.reflection_pad2d" : functools .partial (Tensor .pad , mode = "reflect" ),
413
473
"aten.masked_fill_.Scalar" : inplace_fn ("self" )(lambda self , mask , value : self .assign (self .masked_fill (mask , value ))),
474
+ "aten.masked_fill_.Tensor" : inplace_fn ("self" )(lambda self , mask , value : self .assign (self .masked_fill (mask , value ))),
414
475
"aten.masked_fill.Scalar" : Tensor .masked_fill ,
415
476
"aten.masked_fill.Tensor" : Tensor .masked_fill ,
416
477
"aten.masked_select" : Tensor .masked_select ,
@@ -441,6 +502,9 @@ def _wrap_out(*args, **kwargs):
441
502
"aten.repeat" : Tensor .repeat ,
442
503
"aten.lerp.Tensor" : Tensor .lerp ,
443
504
"aten.expand" : Tensor .expand ,
505
+ "aten.ones_like" : lambda self , dtype = None , device = None , ** kwargs :
506
+ self .ones_like (** {k : v for k , v in {"dtype" : _from_torch_dtype (dtype ) if dtype else None ,
507
+ "device" : _from_torch_device (device ) if device else None }.items () if v is not None }),
444
508
"aten.t" : Tensor .transpose ,
445
509
"aten.detach" : Tensor .detach ,
446
510
"aten.max.dim" : lambda self , dim , keepdim = False : (self .max (dim , keepdim ), self .argmax (dim , keepdim ).cast (dtype = dtypes .int64 ))
0 commit comments