Commit 5e58f4b

Tiny backend test_ops fix part 3 (tinygrad#9483)
* extract straightforward things from tinygrad#9302
* pass dtype and device for ones_like
1 parent 9fcef4d · commit 5e58f4b

2 files changed, +72 -7 lines changed


extra/torch_backend/backend.py (+71 -7)
@@ -31,6 +31,7 @@ def device_count(self): return getenv("GPUS", 1) # TODO: device count in tiny?
 torch.utils.rename_privateuse1_backend("tiny")
 torch._register_device_module("tiny", TinyBackend())
 torch.utils.generate_methods_for_privateuse1_backend()
+aten = torch.ops.aten
 
 # in place operations with views
 def is_view(self: torch.Tensor) -> bool: return getattr(self, "_base", None) is not None
@@ -75,9 +76,37 @@ def _index_put_impl_(self, indices, values, accumulate=False, unsafe=False):
 def index_tensor(x, y):
   return aten.index(x.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in y]).to(x.device)
 
+@torch.library.impl("aten::index_put", "privateuseone")
+def index_put(self, indices, values, accumulate=False):
+  return aten.index_put(self.cpu(), [z.cpu() if isinstance(z, torch.Tensor) else None for z in indices], values.cpu(), accumulate).tiny()
+
 @torch.library.impl("aten::randperm.generator_out", "privateuseone")
 def randperm_generator(n, generator=None, out=None): out.copy_(torch.randperm(n, generator=generator, device="cpu").tiny())
 
+@torch.library.impl("aten::cumprod", "privateuseone")
+# TODO: move to tinygrad
+def cumprod(self, dim, dtype=None): return aten.cumprod(self.cpu(), dim, dtype=dtype).tiny()
+
+@torch.library.impl("aten::cummax", "privateuseone")
+def cummax(self, dim):
+  # TODO: support cummax with indices to match torch
+  cummax, indices = aten.cummax(self.cpu(), dim)
+  return (cummax.tiny(), indices.tiny())
+
+@torch.library.impl("aten::nonzero", "privateuseone")
+# TODO: move to tinygrad
+def nonzero(self): return aten.nonzero(self.cpu()).tiny()
+
+def upsample_backward(grad_out, output_size, input_size, *args, f=None): return f(grad_out.cpu(), output_size, input_size, *args).tiny()
+
+for i in [
+  "upsample_linear1d_backward", "upsample_nearest1d_backward", "_upsample_nearest_exact1d_backward",
+  "upsample_nearest2d_backward", "_upsample_nearest_exact2d_backward",
+  "upsample_nearest3d_backward", "_upsample_nearest_exact3d_backward",
+  "upsample_trilinear3d_backward", "upsample_bilinear2d_backward"
+]:
+  torch.library.impl(f"aten::{i}", "privateuseone")(functools.partial(upsample_backward, f=getattr(aten, i)))
+
 # *** end bad functions on CPU ***
 
 @torch.library.impl("aten::zero_", "privateuseone")
@@ -162,24 +191,58 @@ def arange_start_step(start, end, step, dtype=None, device=None, pin_memory=None
 def convolution_overrideable(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups):
   if TORCH_DEBUG >= 1:
     print(f"convolution {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
-  return wrap(unwrap(input).conv2d(unwrap(weight), unwrap(bias) if bias is not None else None,
-                                   groups=groups, stride=stride, dilation=dilation, padding=padding))
+  input, weight, bias = unwrap(input), unwrap(weight), unwrap(bias) if bias is not None else None
+  if not transposed: return wrap(input.conv2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding))
+  return wrap(input.conv_transpose2d(weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding))
 
 @torch.library.impl("aten::convolution_backward_overrideable", "privateuseone")
 def convolution_backward_overrideable(grad_out, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask):
   if TORCH_DEBUG >= 1:
     print(f"convolution_backward {input.shape=} {weight.shape=} {stride=} {padding=} {dilation=} {transposed=} {output_padding=} {groups=}")
   grad_out, input, weight, bias = unwrap(grad_out), unwrap(input), unwrap(weight), Tensor.zeros(weight.shape[0], device=_from_torch_device(weight.device))
-  out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
+  if not transposed: out = Tensor.conv2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding)
+  else:
+    bias = Tensor.zeros(weight.shape[1] * groups)
+    out = Tensor.conv_transpose2d(input, weight, bias, groups=groups, stride=stride, dilation=dilation, padding=padding, output_padding=output_padding)
   grads = out.gradient(*[t for t,m in zip([input, weight, bias], output_mask) if m], gradient=grad_out)
   return tuple([wrap(grads.pop(0)) if m else None for m in output_mask])
 
+def avg_pool(self, kernel_size, stride=[], padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
+  return wrap(unwrap(self).avg_pool2d(kernel_size, stride if stride != [] else None, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad))
+
+def avg_pool_backward(grad_out, self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None):
+  self, grad_out = unwrap(self), unwrap(grad_out)
+  out = Tensor.avg_pool2d(self, kernel_size, stride if stride != [] else None, dilation=1, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
+  return wrap(out.gradient(self, gradient=grad_out)[0])
+
+for dim in [2, 3]:
+  torch.library.impl(f"aten::avg_pool{dim}d", "privateuseone")(avg_pool)
+  torch.library.impl(f"aten::avg_pool{dim}d_backward", "privateuseone")(avg_pool_backward)
+
+def pad_forward(self, padding, mode=None): return wrap(Tensor.pad(unwrap(self), padding, mode=mode))
+
+def pad_backward(grad_out, self, padding, mode):
+  self, grad_out = unwrap(self), unwrap(grad_out)
+  out = Tensor.pad(self, padding, mode=mode)
+  return wrap(out.gradient(self, gradient=grad_out)[0])
+
+for dim in [1, 2, 3]:
+  for pad_type, mode in [("replication", "replicate"), ("reflection", "reflect")]:
+    torch.library.impl(f"aten::{pad_type}_pad{dim}d", "privateuseone")(functools.partial(pad_forward, mode=mode))
+    torch.library.impl(f"aten::{pad_type}_pad{dim}d_backward", "privateuseone")(functools.partial(pad_backward, mode=mode))
+
 def upsample(self, size, align_corners=False, mode=None): return wrap(Tensor.interpolate(unwrap(self), size, mode=mode, align_corners=align_corners))
 for i,pre in enumerate(["", "bi", "tri"]):
   torch.library.impl(f"aten::upsample_{pre}linear{i+1}d", "privateuseone")(functools.partial(upsample, mode="linear"))
   torch.library.impl(f"aten::upsample_nearest{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest"))
   torch.library.impl(f"aten::_upsample_nearest_exact{i+1}d", "privateuseone")(functools.partial(upsample, mode="nearest-exact"))
 
+@torch.library.impl("aten::scatter_add.out", "privateuseone")
+def scatter_add(self, dim, index, src, out):
+  self, index, src, out = unwrap(self), unwrap(index), unwrap(src), unwrap(out)
+  if self.shape == (): return wrap(out.assign(src))
+  return wrap(out.assign(Tensor.scatter_reduce(self, dim, index, src, reduce='sum')))
+
 @torch.library.impl("aten::_copy_from", "privateuseone")
 def _copy_from(src: torch.Tensor, dest, non_blocking=False):
   realize = dest.is_tiny and maybe_realize_storage(dest)
@@ -222,7 +285,6 @@ def sort_values(input, dim=-1, descending=False, stable=True, values=None, indic
 
 # register some decompositions
 from torch._decomp import get_decompositions
-aten = torch.ops.aten
 decomps = [
   aten.native_batch_norm, aten.native_batch_norm_backward,
   aten.native_layer_norm_backward,
@@ -344,7 +406,7 @@ def sort_values(input, dim=-1, descending=False, stable=True, values=None, indic
   "aten.scatter.value_out": Tensor.scatter,
   "aten.where.self_out": Tensor.where,
   "aten.prod.int_out": Tensor.prod,
-  "aten.scatter_add.out": functools.partial(Tensor.scatter_reduce, reduce='sum'),
+  "aten.scatter.src_out": Tensor.scatter,
   # NOTE: axis=[] in torch means all, change tinygrad?
   "aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None:
     self.sum(axis if axis is None or len(axis) else None, keepdim,
@@ -408,9 +470,8 @@ def _wrap_out(*args, **kwargs):
   "aten.logical_not": Tensor.logical_not,
   "aten.logical_or_": inplace_fn("x")(lambda x, y: x.assign(x | y)),
   "aten.multinomial": Tensor.multinomial,
-  "aten.pad": Tensor.pad,
-  "aten.reflection_pad2d": functools.partial(Tensor.pad, mode="reflect"),
   "aten.masked_fill_.Scalar": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
+  "aten.masked_fill_.Tensor": inplace_fn("self")(lambda self, mask, value: self.assign(self.masked_fill(mask, value))),
   "aten.masked_fill.Scalar": Tensor.masked_fill,
   "aten.masked_fill.Tensor": Tensor.masked_fill,
   "aten.masked_select": Tensor.masked_select,
@@ -441,6 +502,9 @@ def _wrap_out(*args, **kwargs):
   "aten.repeat": Tensor.repeat,
   "aten.lerp.Tensor": Tensor.lerp,
   "aten.expand": Tensor.expand,
+  "aten.ones_like": lambda self, dtype=None, device=None, **kwargs:
+    self.ones_like(**{k: v for k, v in {"dtype": _from_torch_dtype(dtype) if dtype else None,
+                                        "device": _from_torch_device(device) if device else None}.items() if v is not None}),
   "aten.t": Tensor.transpose,
   "aten.detach": Tensor.detach,
   "aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))

test/test_ops.py (+1)
@@ -2671,6 +2671,7 @@ def test_gather(self):
       vals=[[1., 2., 3.]])
 
   @unittest.expectedFailure
+  @unittest.skipIf(torch._C._get_privateuse1_backend_name() == "tiny", 'results in a success instead of a failure')
   def test_gather_failure(self):
     # gather with inf values do not work, other values results in nan
     helper_test_op(None, lambda x: x.gather(dim=0, index=torch.tensor([2, 1, 0, 1, 2], requires_grad=False)),
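
A note on why the guard is needed (illustration, not from the commit): under unittest, an @unittest.expectedFailure test that passes is reported as an unexpected success, and on the tiny backend this gather case happens to work, so the test is skipped there instead. A minimal sketch of the same pattern; the helper name and test body are hypothetical:

import unittest
import torch

def _is_tiny_backend() -> bool:
  # same check the commit adds to the skipIf decorator (hypothetical helper)
  return torch._C._get_privateuse1_backend_name() == "tiny"

class Example(unittest.TestCase):
  @unittest.expectedFailure
  @unittest.skipIf(_is_tiny_backend(), "passes on tiny, which would count as an unexpected success")
  def test_known_failure_elsewhere(self):
    raise AssertionError("stands in for the gather-with-inf case")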
