
Commit 8c0d0a1

Add return_indices to max_pool (tinygrad#9506)
* wow argmax is so good
* 1 less line
* clean up and better variable names
* is this torch thing right...?
* add more tests
* slap a TODO on it
* clean ups
* prettier looking code and fix ceil mode test
* add return types and some docs
* ok that was a bad example since indices == value, just no example
1 parent 189f62d commit 8c0d0a1

File tree: 4 files changed (+51 / -16 lines)

extra/onnx.py

Lines changed: 3 additions & 5 deletions
@@ -409,11 +409,9 @@ def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NO
 
 def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
             storage_order:int=0, strides:list[int]|int=1):
-  ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
-  # tests expect indices with int64 dtype
-  # TODO: if there are repeated values, this is wrong
-  indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
-  return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
+  pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
+  ret, idx = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, return_indices=True)
+  return ret, idx.transpose(-2, -1).cast(dtypes.int64) if storage_order else idx.cast(dtypes.int64)
 
 def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
          kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
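The handler now gets indices directly from Tensor.max_pool2d instead of reconstructing them by comparing pooled values against the input, which the old TODO noted was wrong for repeated values. A minimal sketch of the Tensor-level contract it builds on (example values are illustrative, not from the commit):

    from tinygrad import Tensor, dtypes

    x = Tensor([[[[1., 2.],
                  [3., 4.]]]])                   # NCHW, shape (1,1,2,2)
    ret, idx = x.max_pool2d(kernel_size=(2,2), return_indices=True)
    print(ret.numpy())                           # [[[[4.]]]]
    print(idx.cast(dtypes.int64).numpy())        # [[[[3]]]] -- flat row-major index within the spatial dims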

extra/torch_backend/backend.py

Lines changed: 3 additions & 4 deletions
@@ -162,15 +162,14 @@ def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=F
 def max_pool2d_with_indices(self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False):
   # TODO: supprt stride [] in tinygrad?
   if stride is not None and len(stride) == 0: stride = None
-  # TODO: support return_indices in tinygrad
-  ret = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode)
-  # TODO: this is wrong
-  return (wrap(ret), wrap(Tensor.zeros_like(ret, dtype=dtypes.int64)))
+  ret, idx = unwrap(self).max_pool2d(kernel_size, stride, dilation, padding, ceil_mode, return_indices=True)
+  return (wrap(ret), wrap(idx.cast(dtypes.int64)))
 
 @torch.library.impl("aten::max_pool2d_with_indices_backward", "privateuseone")
 def max_pool2d_with_indices_backward(grad_out:torch.Tensor, self:torch.Tensor, kernel_size:tuple[int, ...], stride=None, padding=0, dilation=1, ceil_mode=False, indices=None):
   if stride is not None and len(stride) == 0: stride = None
   # TODO: utilize input indices once they are correct
+  # TODO: implement maxunpool
   self_ = unwrap(self)
   out = Tensor.max_pool2d(self_, kernel_size, stride, dilation, padding, ceil_mode)
   return wrap(out.gradient(self_, gradient=unwrap(grad_out))[0])
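With real indices available, the forward op no longer fabricates a zero index tensor; the backward still recomputes the forward and differentiates through it rather than scattering via the saved indices (hence the new maxunpool TODO). A hypothetical round trip through the backend, assuming it has been loaded and registered its privateuseone device (the "tiny" device name and import path below are assumptions, not from the diff):

    import torch
    import extra.torch_backend.backend  # assumed import path; registers the device

    x = torch.arange(16.).reshape(1, 1, 4, 4).to("tiny")   # "tiny" device name is an assumption
    vals, idx = torch.nn.functional.max_pool2d(x, kernel_size=2, return_indices=True)
    print(idx.cpu())   # tensor([[[[ 5,  7], [13, 15]]]]) -- real argmax positions, not zeros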

test/test_ops.py

Lines changed: 29 additions & 0 deletions
@@ -2327,6 +2327,35 @@ def test_max_pool2d_ceil_mode_output_size_reduce_by_one(self):
     lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
     lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
 
+  def test_max_pool2d_return_indices(self):
+    # batch and multi-channel
+    helper_test_op([(2,3,6,6)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), return_indices=True)[1], forward_only=True)
+    # dilation
+    helper_test_op([(1,1,10,10)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,2), dilation=(2,3), return_indices=True)[1], forward_only=True)
+    # padding
+    helper_test_op([(1,1,5,5)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), padding=1, return_indices=True)[1], forward_only=True)
+    # ceil mode padding
+    helper_test_op([(1, 1, 7, 7)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), ceil_mode=True, return_indices=True)[1],
+      forward_only=True)
+    # global maxpool
+    helper_test_op([(1,1,12,13)],
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(12, 13), return_indices=True)[1],
+      forward_only=True)
+    # multiple identical values in same window and overlapping windows
+    helper_test_op(None,
+      lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1].type(torch.int32),
+      lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=1, return_indices=True)[1],
+      vals=[[[[[1]*6]*6]]], forward_only=True) # Tensor.ones(1,1,6,6)
+
   def test_avg_pool2d(self):
     shape = (32,2,111,28)
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
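The final case is the interesting one: with an all-ones input every window is a tie, and the test asserts tinygrad resolves ties the same way PyTorch does, returning the index of each window's first element in row-major order. A small illustration of the expected behavior:

    import torch

    x = torch.ones(1, 1, 6, 6)
    _, idx = torch.nn.functional.max_pool2d(x, kernel_size=3, stride=1, return_indices=True)
    print(idx[0, 0])   # entry (i, j) equals 6*i + j: the flat index of each window's top-left element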

tinygrad/tensor.py

Lines changed: 16 additions & 7 deletions
@@ -2110,7 +2110,7 @@ def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple
 
   # NOTE: these work for more than 2D
   def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
-                 ceil_mode=False, count_include_pad=True):
+                 ceil_mode=False, count_include_pad=True) -> Tensor:
     """
     Applies average pooling over a tensor.
 
@@ -2158,7 +2158,7 @@ def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._po
     return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
 
   def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
-                 ceil_mode=False):
+                 ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
     """
     Applies max pooling over a tensor.
 
@@ -2175,6 +2175,7 @@ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1,
     `(padding_left, padding_right, padding_top, padding_bottom, ...)`.
 
     When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
+    When `return_indices` is set to `True`, the argmax will be returned along with the max values.
 
     NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
 
@@ -2191,9 +2192,16 @@ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1,
     print(t.max_pool2d(padding=1).numpy())
     ```
     """
-    pads = self._resolve_pool_pads(padding, len(k_ := make_tuple(kernel_size, 2)))
+    axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
+    pads = self._resolve_pool_pads(padding, len(k_))
     if ceil_mode: pads = self._apply_ceil_mode(pads, k_, stride if stride is not None else k_, dilation)
-    return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
+    pooled = self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    if not return_indices: return pooled.max(axis)
+    spatial_sz = math.prod(spatial_shape := self.shape[-len(k_):])
+    idx = Tensor.arange(spatial_sz,0,-1, requires_grad=False, device=self.device).reshape(spatial_shape)
+    m = pooled == pooled.max(axis, keepdim=True)
+    idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
+    return pooled.max(axis), spatial_sz - idx.max(axis)
 
   def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
              dtype:DTypeLike|None=None) -> Tensor:
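The index computation needs no dedicated argmax kernel. A descending arange laid over the spatial dims is padded and pooled exactly like the input, masked to the positions that attain each window's max, and reduced with the same max; because the arange descends, the largest surviving entry maps to the smallest flat position, so ties resolve to the first occurrence, matching PyTorch. A NumPy sketch of the trick on a single window (illustrative only, not the tinygrad code):

    import numpy as np

    x = np.array([[1., 5., 5.],
                  [2., 5., 3.]])                            # one 2x3 spatial plane
    spatial_sz = x.size                                     # 6
    idx_plane = np.arange(spatial_sz, 0, -1).reshape(x.shape)  # [[6,5,4],[3,2,1]]

    win = x[0:2, 0:2]                                       # a 2x2 window: [[1,5],[2,5]]
    mask = win == win.max()                                 # max 5 is tied at flat positions 1 and 4
    masked = mask * idx_plane[0:2, 0:2]                     # [[0,5],[0,2]]
    print(spatial_sz - masked.max())                        # 6 - 5 = 1 -> the first occurrence wins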
@@ -2577,7 +2585,7 @@ def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any
       return mask.where(src, 0).sum(-1, dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
     raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
 
-  def sort(self, dim:int=-1, descending:bool=False):
+  def sort(self, dim:int=-1, descending:bool=False) -> tuple[Tensor, Tensor]:
     """
     Performs a bitonic sort on the tensor along the specified dimension.
 
@@ -2621,14 +2629,15 @@ def sort(self, dim:int=-1, descending:bool=False):
     x = blue_box.cat(flipped_green_box.flip(flip_dims), dim=crossover_dim)
     x = x.flatten(dim, dim+n_stages-1).shrink(tuple((0, orig_len) if i == dim else None for i in range(x.ndim)))
     # compute indices for sorted values
-    idx = Tensor.arange(orig_len, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim))).expand(x.shape)
+    idx = Tensor.arange(orig_len, requires_grad=False, device=self.device).reshape(tuple(orig_len if i == dim else 1 for i in range(x.ndim)))
+    idx = idx.expand(x.shape)
     def compute_counts(t:Tensor): return ((idx.unsqueeze(dim) <= idx.unsqueeze(dim+1)) & (t.unsqueeze(dim) == t.unsqueeze(dim+1))).sum(dim+1)
     count_orig, count_sorted = compute_counts(self), compute_counts(x)
     cond = (self.unsqueeze(dim+1) == x.unsqueeze(dim)) & (count_orig.unsqueeze(dim+1) == count_sorted.unsqueeze(dim))
     idx = (cond * idx.unsqueeze(dim+1)).sum(dim)
     return x, idx
 
-  def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True):
+  def topk(self, k:int, dim:int=-1, largest:bool=True, sorted_:bool=True) -> tuple[Tensor, Tensor]:
     """
     Computes the top-k elements of the tensor along the specified `dim`.
 
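The sort and topk changes are lighter by comparison: return-type annotations that document the existing (values, indices) contract, and requires_grad=False on the index arange so the integer helper tensor stays out of autograd. The annotated contract, for reference:

    from tinygrad import Tensor

    vals, idx = Tensor([3., 1., 2.]).sort()   # -> tuple[Tensor, Tensor]
    print(vals.numpy(), idx.numpy())          # [1. 2. 3.] [1 2 0]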