@@ -2110,7 +2110,7 @@ def _apply_ceil_mode(self, pads:Sequence[int], k_:tuple[sint, ...], s_:int|tuple
2110
2110
2111
2111
# NOTE: these work for more than 2D
2112
2112
def avg_pool2d (self , kernel_size :tuple [int , ...]= (2 ,2 ), stride = None , dilation = 1 , padding :int | tuple [int , ...]= 0 ,
2113
- ceil_mode = False , count_include_pad = True ):
2113
+ ceil_mode = False , count_include_pad = True ) -> Tensor :
2114
2114
"""
2115
2115
Applies average pooling over a tensor.
2116
2116
@@ -2158,7 +2158,7 @@ def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._po
2158
2158
return pool (self , ceil_pads ).sum (axis ) / pool (self .pad (reg_pads ).ones_like (), tuple (cp - rp for cp ,rp in zip (ceil_pads , reg_pads ))).sum (axis )
2159
2159
2160
2160
def max_pool2d (self , kernel_size :tuple [int , ...]= (2 ,2 ), stride = None , dilation = 1 , padding :int | tuple [int , ...]= 0 ,
2161
- ceil_mode = False ) :
2161
+ ceil_mode = False , return_indices = False ) -> Tensor | tuple [ Tensor , Tensor ] :
2162
2162
"""
2163
2163
Applies max pooling over a tensor.
2164
2164
@@ -2175,6 +2175,7 @@ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1,
2175
2175
`(padding_left, padding_right, padding_top, padding_bottom, ...)`.
2176
2176
2177
2177
When `ceil_mode` is set to `True`, output shape will be determined using ceil division.
2178
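+ When `return_indices` is set to `True`, a `(values, indices)` tuple is returned, where `indices` holds the flattened position of each maximum within the unpadded spatial dimensions of the input.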
2178
2179
2179
2180
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
2180
2181
@@ -2191,9 +2192,16 @@ def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1,
2191
2192
print(t.max_pool2d(padding=1).numpy())
2192
2193
```
2193
2194
"""
2194
- pads = self ._resolve_pool_pads (padding , len (k_ := make_tuple (kernel_size , 2 )))
2195
+ axis = tuple (range (- len (k_ := make_tuple (kernel_size , 2 )), 0 ))
2196
+ pads = self ._resolve_pool_pads (padding , len (k_ ))
2195
2197
if ceil_mode : pads = self ._apply_ceil_mode (pads , k_ , stride if stride is not None else k_ , dilation )
2196
- return self .pad (pads , value = dtypes .min (self .dtype ))._pool (k_ , stride if stride is not None else k_ , dilation ).max (tuple (range (- len (k_ ), 0 )))
2198
+ pooled = self .pad (pads , value = dtypes .min (self .dtype ))._pool (k_ , stride if stride is not None else k_ , dilation )
2199
+ if not return_indices : return pooled .max (axis )
2200
+ spatial_sz = math .prod (spatial_shape := self .shape [- len (k_ ):])
2201
+ idx = Tensor .arange (spatial_sz ,0 ,- 1 , requires_grad = False , device = self .device ).reshape (spatial_shape )
2202
+ m = pooled == pooled .max (axis , keepdim = True )
2203
+ idx = m * idx .pad (pads , value = dtypes .min (idx .dtype ))._pool (k_ , stride if stride is not None else k_ , dilation )
2204
+ return pooled .max (axis ), spatial_sz - idx .max (axis )
2197
2205
2198
2206
def conv2d (self , weight :Tensor , bias :Tensor | None = None , groups = 1 , stride = 1 , dilation = 1 , padding :int | tuple [int , ...]= 0 ,
2199
2207
dtype :DTypeLike | None = None ) -> Tensor :
@@ -2577,7 +2585,7 @@ def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any
2577
2585
return mask .where (src , 0 ).sum (- 1 , dtype = self .dtype ).add (self if include_self else _inv_mask (self , 0 )).div (count )
2578
2586
raise RuntimeError (f"{ reduce = } must be one of 'sum', 'prod', 'mean', 'amax', 'amin'" )
2579
2587
2580
- def sort (self , dim :int = - 1 , descending :bool = False ):
2588
+ def sort (self , dim :int = - 1 , descending :bool = False ) -> tuple [ Tensor , Tensor ] :
2581
2589
"""
2582
2590
Performs a bitonic sort on the tensor along the specified dimension.
2583
2591
@@ -2621,14 +2629,15 @@ def sort(self, dim:int=-1, descending:bool=False):
2621
2629
x = blue_box .cat (flipped_green_box .flip (flip_dims ), dim = crossover_dim )
2622
2630
x = x .flatten (dim , dim + n_stages - 1 ).shrink (tuple ((0 , orig_len ) if i == dim else None for i in range (x .ndim )))
2623
2631
# compute indices for sorted values
2624
- idx = Tensor .arange (orig_len , device = self .device ).reshape (tuple (orig_len if i == dim else 1 for i in range (x .ndim ))).expand (x .shape )
2632
+ idx = Tensor .arange (orig_len , requires_grad = False , device = self .device ).reshape (tuple (orig_len if i == dim else 1 for i in range (x .ndim )))
2633
+ idx = idx .expand (x .shape )
2625
2634
def compute_counts (t :Tensor ): return ((idx .unsqueeze (dim ) <= idx .unsqueeze (dim + 1 )) & (t .unsqueeze (dim ) == t .unsqueeze (dim + 1 ))).sum (dim + 1 )
2626
2635
count_orig , count_sorted = compute_counts (self ), compute_counts (x )
2627
2636
cond = (self .unsqueeze (dim + 1 ) == x .unsqueeze (dim )) & (count_orig .unsqueeze (dim + 1 ) == count_sorted .unsqueeze (dim ))
2628
2637
idx = (cond * idx .unsqueeze (dim + 1 )).sum (dim )
2629
2638
return x , idx
2630
2639
2631
- def topk (self , k :int , dim :int = - 1 , largest :bool = True , sorted_ :bool = True ):
2640
+ def topk (self , k :int , dim :int = - 1 , largest :bool = True , sorted_ :bool = True ) -> tuple [ Tensor , Tensor ] :
2632
2641
"""
2633
2642
Computes the top-k elements of the tensor along the specified `dim`.
2634
2643
0 commit comments