[pull] master from tinygrad:master #191

Merged 1 commit on Mar 23, 2025
1 change: 1 addition & 0 deletions docs/tensor/ops.md
@@ -22,6 +22,7 @@

 ::: tinygrad.Tensor.avg_pool2d
 ::: tinygrad.Tensor.max_pool2d
+::: tinygrad.Tensor.max_unpool2d
 ::: tinygrad.Tensor.conv2d
 ::: tinygrad.Tensor.conv_transpose2d
 ::: tinygrad.Tensor.dot
6 changes: 1 addition & 5 deletions extra/onnx.py
Original file line number Diff line number Diff line change
@@ -434,11 +434,7 @@ def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OP
   return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads, output_padding=output_padding)

 def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shape:list[int]=None, pads:list[int]|int=0, strides:list[int]|int=1):
-  pads, strides = (make_tuple(x, len(xI.shape)) for x in (pads, strides))
-  out_sh = [(ks//2)*2 + st * inps for inps, st, ks in zip(xI.shape, strides, kernel_shape)]
-  ret = (xI.reshape(-1, 1)._one_hot_along_dim(prod(out_sh)) * xT.reshape(-1, 1)).sum(0).reshape(1, 1, *out_sh)
-  if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
-  return ret.pad(_onnx_pads_to_tiny_pads(pads))
+  return Tensor.max_unpool2d(xT, xI, kernel_shape, strides, 1, pads, outshape if outshape is None else tuple(outshape))

 def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
 def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
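The rewritten `MaxUnpool` delegates everything, including the default output-shape inference, to `Tensor.max_unpool2d`. For reference, here is a minimal standalone sketch of the output-size rule that applies when no `output_shape` is given: the inverse of relationship 15 in section 5.1 of arXiv:1603.07285. The helper name `unpool_output_size` is hypothetical, not part of the PR.

```python
# Hypothetical helper, not from the PR: the default unpool output size per axis,
# inverting relationship 15 from arXiv:1603.07285, section 5.1.
def unpool_output_size(in_shape, kernel, stride, dilation=(1, 1), pads=((0, 0), (0, 0))):
  # in_shape: spatial dims of the pooled tensor; pads: (before, after) per axis
  return tuple((i - 1) * s - (pA + pB) + (d * (k - 1) + 1)
               for i, k, d, s, (pA, pB) in zip(in_shape, kernel, stride, dilation, pads))

# a 2x2 pooled map with kernel 2, stride 2 unpools back to 4x4
assert unpool_output_size((2, 2), (2, 2), (2, 2)) == (4, 4)
```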
11 changes: 5 additions & 6 deletions extra/torch_backend/backend.py
@@ -167,12 +167,11 @@ def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stri

 @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
 def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
-  if stride is not None and len(stride) == 0: stride = None
-  # TODO: utilize input indices once they are correct
-  # TODO: implement maxunpool
-  self_ = unwrap(self)
-  out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode)
-  return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0])
+  return wrap(Tensor.max_unpool2d(unwrap(grad_out), unwrap(indices), output_size=unwrap(self).shape))

+@torch.library.impl("aten::max_unpool2d", "privateuseone")
+def max_unpool2d(self:torch.Tensor, indices:torch.Tensor, output_size):
+  return wrap(unwrap(self).max_unpool2d(unwrap(indices), output_size=output_size))
+
 @torch.library.impl("aten::arange", "privateuseone")
 def arange(end, dtype=None, device=None, pin_memory=None):
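The backward rewrite relies on the fact that, for non-overlapping windows, the gradient of max pooling is exactly a max-unpool of the upstream gradient: each element of `grad_out` is scattered back to the input position recorded in `indices`. A quick illustrative check in plain PyTorch (my sketch, not part of the PR):

```python
import torch
import torch.nn.functional as F

# Sketch: for non-overlapping windows (stride == kernel), scattering grad_out with
# max_unpool2d reproduces the autograd gradient of max_pool2d.
x = torch.randn(2, 3, 8, 8, requires_grad=True)
out, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
grad_out = torch.randn_like(out)

out.backward(grad_out)  # autograd's gradient of max pooling
scattered = F.max_unpool2d(grad_out, idx, kernel_size=2, stride=2, output_size=x.shape[-2:])
assert torch.allclose(x.grad, scattered)
```

With overlapping windows, autograd sums the contributions that land on the same input index, which is why tinygrad's `max_unpool2d` accumulates duplicates via a sum rather than an indexed write.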
3 changes: 3 additions & 0 deletions test/external/external_test_onnx_backend.py
@@ -54,6 +54,9 @@ def supports_device(cls, device: str) -> bool:
 backend_test.exclude('test_dynamicquantizelinear_cpu')
 backend_test.exclude('test_dynamicquantizelinear_expanded_cpu')

+# BUG: we match ORT, tested in TestMainOnnxOps.test_maxunpool
+backend_test.exclude('test_maxunpool_export_with_output_shape_cpu')
+
 # about different dtypes
 if not is_dtype_supported(dtypes.float64):
   backend_test.exclude('float64')
12 changes: 11 additions & 1 deletion test/external/external_test_onnx_ops.py
@@ -48,6 +48,16 @@ def test_gather(self):
outputs = ["y"]
self.helper_test_single_op("Gather", inputs, attributes, outputs)

def test_maxunpool(self):
# test_maxunpool_export_with_output_shape_cpu
xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32)
xI = np.array([[[[5, 7], [13, 15]]]], dtype=np.int64)
output_shape = np.array((1, 1, 5, 5), dtype=np.int64)
inputs = {"x": xT, "indices": xI, "output_shape": output_shape}
attributes = {"kernel_shape": [2, 2], "strides": [2, 2]}
outputs = ["y"]
self.helper_test_single_op("MaxUnpool", inputs, attributes, outputs)

def test_quantize_linear(self):
test_cases = [
{"test_case": "round_half_to_even", "qdtype": np.int8, "qzero_point": 0, "x": [-1.5, -0.5, 0.5, 1.5], "scale": 1.0},
@@ -221,7 +231,7 @@ def test_qlinear_add(self):
       }
       attributes = {}
       outputs = ["C"]
-      self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs)
+      self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs, atol=1) # TODO: look into why this is inaccurate

     with self.subTest(test_case="round_half_to_even"):
       inputs = {
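The new `test_maxunpool` pins down the ORT behaviour that tinygrad now matches (hence the backend-test exclusion above); as far as I can tell, the disagreement with the ONNX reference test concerns only where the extra rows and columns of the larger `output_shape` are padded. Assuming the indices are flattened positions in the pre-pool tensor, a quick NumPy sketch of where the test values land before any padding:

```python
import numpy as np

# Sketch under the assumption that MaxUnpool indices are flat positions in the
# default (4x4) unpooled output, before output_shape padding is applied.
xT = np.array([[[[5, 6], [7, 8]]]], dtype=np.float32)
xI = np.array([[[[5, 7], [13, 15]]]], dtype=np.int64)

out = np.zeros(16, dtype=np.float32)
out[xI.ravel()] = xT.ravel()  # scatter each value to its flat index
print(out.reshape(1, 1, 4, 4))
# [[[[0. 0. 0. 0.]
#    [0. 5. 0. 6.]
#    [0. 0. 0. 0.]
#    [0. 7. 0. 8.]]]]
```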
21 changes: 21 additions & 0 deletions test/test_ops.py
@@ -2355,6 +2355,27 @@ def test_max_pool2d_return_indices(self):
       lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32),
       lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1],
       vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6)
+    # overlapping max indices
+    helper_test_op(None,
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=1, return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=1, return_indices=True)[1],
+      vals=[[[[[1,2]*3]*6]]], forward_only=True) # Tensor([1,2,1,2,1,2]).expand(1,1,6,6)
+
+  def test_max_unpool2d(self):
+    args = {"kernel_size":(5,5), "stride":(6,5)}
+    helper_test_op([(8,3,50,50)],
+      lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, return_indices=True, **args), **args),
+      lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, return_indices=True, **args), **args), forward_only=True)
+    args = {"kernel_size":(3,3), "stride":(6,7), "padding":1}
+    helper_test_op([(8,3,30,30)],
+      lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, return_indices=True, **args), **args, output_size=(30,30)),
+      lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, return_indices=True, **args), **args, output_size=(30,30)), forward_only=True)
+    # batch_size and channel_size of output_size are ignored
+    helper_test_op([(1,3,7,6)],
+      lambda x: torch.nn.functional.max_unpool2d(*torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True),
+                                                 kernel_size=(2,2), output_size=(99,99,7,6)),
+      lambda x: Tensor.max_unpool2d(*Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True),
+                                    kernel_size=(2,2), output_size=(99,99,7,6)), forward_only=True)

   def test_avg_pool2d(self):
     shape = (32,2,111,28)
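The "overlapping max indices" case added to `test_max_pool2d_return_indices` matters because, with stride smaller than the kernel, neighbouring windows can select the same input element, so the returned indices contain duplicates. A small plain-PyTorch illustration (mine, not from the PR):

```python
import torch
import torch.nn.functional as F

# With stride 1 and the repeating pattern [1, 2, 1, 2], two overlapping windows
# pick the same maximum, so its index appears twice.
x = torch.tensor([[[[1., 2., 1., 2.]]]])
_, idx = F.max_pool2d(x, kernel_size=(1, 2), stride=1, return_indices=True)
print(idx)  # tensor([[[[1, 1, 3]]]])
```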
31 changes: 31 additions & 0 deletions tinygrad/tensor.py
@@ -2203,6 +2203,37 @@ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1,
     idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
     return pooled.max(axis), spatial_sz - idx.max(axis)

+  def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
+    """
+    Performs a partial inverse of `max_pool2d` using the indices from the argmax.
+
+    When `output_size` is provided, it disambiguates the otherwise ambiguous output shape.
+
+    NOTE: unlike PyTorch, this implementation is not limited to 2D pooling and works for any number of spatial dimensions.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(1, 17).reshape(1, 1, 4, 4)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    output, indices = Tensor.max_pool2d(t, return_indices=True)
+    print(output.numpy())
+    print(indices.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.max_unpool2d(output, indices).numpy())
+    ```
+    """
+    bs,c,*spatial_shape = self.shape
+    if output_size is None:
+      k_,d_,s_ = (make_tuple(x, len(spatial_shape)) for x in (kernel_size, dilation, stride if stride is not None else kernel_size))
+      p_ = _flat_to_grouped(self._resolve_pool_pads(padding, len(spatial_shape)))
+      # https://arxiv.org/pdf/1603.07285 inverse of relationship 15 in section 5.1.
+      output_size = tuple((i-1)*s - (pB+pA) + (d*(k-1)+1) for i,k,d,s,(pA,pB) in zip(spatial_shape,k_,d_,s_,p_))
+    else: output_size = output_size[-len(spatial_shape):]
+    ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2) * self.reshape(bs,c,1,-1)).sum(3)
+    return ret.reshape(bs,c,*output_size)
+
   def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
              dtype:DTypeLike|None=None) -> Tensor:
     """