From 309afa20b7459583caebf96a2d770fde622e836f Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Sun, 23 Mar 2025 00:11:33 +0800 Subject: [PATCH] add Tensor.max_unpool2d (#9518) * why does max_unpool2d feel slower than out.gradient ... * slightly cleaner * what happened to ruff * need to think about this some more * slightly faster now? * clean up, 1 more failing edge case * ok good * working TINY_BACKEND * nit doc wording * retry CI --- docs/tensor/ops.md | 1 + extra/onnx.py | 6 +--- extra/torch_backend/backend.py | 11 ++++---- test/external/external_test_onnx_backend.py | 3 ++ test/external/external_test_onnx_ops.py | 12 +++++++- test/test_ops.py | 21 ++++++++++++++ tinygrad/tensor.py | 31 +++++++++++++++++++++ 7 files changed, 73 insertions(+), 12 deletions(-) diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index 2a2a872ae4c30..ef902b68a14dc 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -22,6 +22,7 @@ ::: tinygrad.Tensor.avg_pool2d ::: tinygrad.Tensor.max_pool2d +::: tinygrad.Tensor.max_unpool2d ::: tinygrad.Tensor.conv2d ::: tinygrad.Tensor.conv_transpose2d ::: tinygrad.Tensor.dot diff --git a/extra/onnx.py b/extra/onnx.py index d340b00bf4ac3..726e9099c92d3 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -434,11 +434,7 @@ def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OP return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding) def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]=None, pads:list[int]|int=0, strides:list[int]|int=1): - pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides)) - out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)] - ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh) - if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER") - return ret.pad(_onnx_pads_to_tiny_pads(pads)) + return Tensor.max_unpool2d(xT, xI, kernel_shape, strides, 1, pads, outshape if outshape is None else tuple(outshape)) def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py index 9dab77f3450a3..c338af886aa07 100644 --- a/extra/torch_backend/backend.py +++ b/extra/torch_backend/backend.py @@ -167,12 +167,11 @@ def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stri @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone") def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None): - if stride is not None and len(stride) == 0: stride = None - # TODO: utilize input indices once they are correct - # TODO: implement maxunpool - self_ = unwrap(self) - out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode) - return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0]) + return wrap(Tensor.max_unpool2d(unwrap(grad_out), unwrap(indices), output_size=unwrap(self).shape)) + +@torch.library.impl("aten::max_unpool2d", "privateuseone") +def max_unpool2d(self:torch.Tensor, indices:torch.Tensor, output_size): + return 
wrap(unwrap(self).max_unpool2d(unwrap(indices), output_size=output_size)) @torch.library.impl("aten::arange", "privateuseone") def arange(end, dtype=None, device=None, pin_memory=None): diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index d12fcd9f88675..fd9a5896eead0 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -54,6 +54,9 @@ def supports_device(cls, device: str) -> bool: backend_test.exclude('test_dynamicquantizelinear_cpu') backend_test.exclude('test_dynamicquantizelinear_expanded_cpu') +# BUG: we match ORT, tested in TestMainOnnxOps.test_maxunpool +backend_test.exclude('test_maxunpool_export_with_output_shape_cpu') + # about different dtypes if not is_dtype_supported(dtypes.float64): backend_test.exclude('float64') diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py index 7e846695f943c..9225018ce1236 100644 --- a/test/external/external_test_onnx_ops.py +++ b/test/external/external_test_onnx_ops.py @@ -48,6 +48,16 @@ def test_gather(self): outputs = ["y"] self.helper_test_single_op("Gather", inputs, attributes, outputs) + def test_maxunpool(self): + # test_maxunpool_export_with_output_shape_cpu + xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32) + xI = np.array([[[[5, 7], [13, 15]]]], dtype=np.int64) + output_shape = np.array((1, 1, 5, 5), dtype=np.int64) + inputs = {"x": xT, "indices": xI, "output_shape": output_shape} + attributes = {"kernel_shape": [2, 2], "strides": [2, 2]} + outputs = ["y"] + self.helper_test_single_op("MaxUnpool", inputs, attributes, outputs) + def test_quantize_linear(self): test_cases = [ {"test_case": "round_half_to_even", "qdtype": np.int8, "qzero_point": 0, "x": [-1.5, -0.5, 0.5, 1.5], "scale": 1.0}, @@ -221,7 +231,7 @@ def test_qlinear_add(self): } attributes = {} outputs = ["C"] - self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs) + self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs, atol=1) # TODO: look into why this is inaccurate with self.subTest(test_case="round_half_to_even"): inputs = { diff --git a/test/test_ops.py b/test/test_ops.py index 34700276a4f77..a343fd4c10faa 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2355,6 +2355,27 @@ def test_max_pool2d_return_indices(self): lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32), lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1], vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6) + # overlapping max indices + helper_test_op(None, + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=1, return_indices=True)[1].type(torch.int32), + lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=1, return_indices=True)[1], + vals=[[[[[1,2]*3]*6]]], forward_only=True) # Tensor([1,2,1,2,1,2]).expand(1,1,6,6) + + def test_max_unpool2d(self): + args = {"kernel_size":(5,5), "stride":(6,5)} + helper_test_op([(8,3,50,50)], + lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, return_indices=True, **args), **args), + lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, return_indices=True, **args), **args), forward_only=True) + args = {"kernel_size":(3,3), "stride":(6,7), "padding":1} + helper_test_op([(8,3,30,30)], + lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, return_indices=True, **args), **args, output_size=(30,30)), + lambda 
x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, return_indices=True, **args), **args, output_size=(30,30)), forward_only=True) + # batch_size and channel_size of output_size are ignored + helper_test_op([(1,3,7,6)], + lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True), + kernel_size=(2,2), output_size=(99,99,7,6)), + lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True), + kernel_size=(2,2), output_size=(99,99,7,6)), forward_only=True) def test_avg_pool2d(self): shape = (32,2,111,28) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 6b8ddab448b81..821a75c02f2a6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2203,6 +2203,37 @@ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation) return pooled.max(axis), spatial_sz - idx.max(axis) + def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None): + """ + Performs a partial inverse of `max_pool2d` using the indices from the argmax. + + When `output_size` is provided, the output shape disambiguates to the provided shape. + + NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(1, 17).reshape(1, 1, 4, 4) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + output, indices = Tensor.max_pool2d(t, return_indices=True) + print(output.numpy()) + print(indices.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.max_unpool2d(output, indices).numpy()) + ``` + """ + bs,c,*spatial_shape = self.shape + if output_size is None: + k_,d_,s_ = (make_tuple(x, len(spatial_shape)) for x in (kernel_size, dilation, stride if stride is not None else kernel_size)) + p_ = _flat_to_grouped(self._resolve_pool_pads(padding, len(spatial_shape))) + # https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1. + output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_)) + else: output_size = output_size[-len(spatial_shape):] + ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2) * self.reshape(bs,c,1,-1)).sum(3) + return ret.reshape(bs,c,*output_size) + def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0, dtype:DTypeLike|None=None) -> Tensor: """
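
A note on the `output_size` fallback added to `max_unpool2d` in `tinygrad/tensor.py` above: when no `output_size` is passed, the patch inverts the pooling output-size relationship (relationship 15, section 5.1 of https://arxiv.org/pdf/1603.07285). A minimal standalone sketch of that arithmetic, in plain Python with a hypothetical helper name that is not part of the patch:

```python
# Per spatial dim, the unpooled size implied by the pooling parameters:
#   out = (in - 1)*stride - (pad_before + pad_after) + dilation*(kernel - 1) + 1
def unpooled_size(in_size: int, kernel: int, stride: int, dilation: int = 1, pads: tuple[int, int] = (0, 0)) -> int:
  pA, pB = pads
  return (in_size - 1) * stride - (pA + pB) + dilation * (kernel - 1) + 1

# Example: max_pool2d with kernel 2 and stride 2 maps a 4x4 input to 2x2;
# inverting with the same parameters recovers the original 4x4 spatial shape.
assert unpooled_size(2, kernel=2, stride=2) == 4
```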
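For readers of the new `max_unpool2d` body: the `_one_hot_along_dim(...) * self` product followed by `.sum(3)` scatters each pooled value to the flat output position named by its index. A rough NumPy reference of the same placement, using hypothetical names and assuming one value per output position (last write wins on duplicate indices):

```python
import numpy as np

def unpool_reference(values: np.ndarray, indices: np.ndarray, out_spatial: tuple[int, ...]) -> np.ndarray:
  # values/indices have shape (bs, c, *pooled_spatial); indices are flat offsets into the output spatial grid.
  bs, c = values.shape[:2]
  out = np.zeros((bs, c, int(np.prod(out_spatial))), dtype=values.dtype)
  flat_vals, flat_idx = values.reshape(bs, c, -1), indices.reshape(bs, c, -1)
  for b in range(bs):
    for ch in range(c):
      out[b, ch, flat_idx[b, ch]] = flat_vals[b, ch]  # write each max back to the position it came from
  return out.reshape(bs, c, *out_spatial)
```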