diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md
index dcfa7e53a6440..7db33a930458e 100644
--- a/docs/tensor/ops.md
+++ b/docs/tensor/ops.md
@@ -32,6 +32,7 @@
 ::: tinygrad.Tensor.tril
 ::: tinygrad.Tensor.interpolate
 ::: tinygrad.Tensor.scatter
+::: tinygrad.Tensor.scatter_reduce

 ## Neural Network (functional)

diff --git a/extra/onnx.py b/extra/onnx.py
index 81bb57199ab3b..8c3938f899e1b 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -691,9 +691,10 @@ def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none"
   else: raise NotImplementedError("reduction doesn't support max or min")
   return x
-def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
+def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
   indices = (indices < 0).where(x.shape[axis], 0) + indices
-  return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
+  if reduction == "none": return x.scatter(axis, indices, updates)
+  return x.scatter_reduce(axis, indices, updates, {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}.get(reduction))
 def GatherElements(x:Tensor, indices:Tensor, axis:int):
   indices = (indices < 0).where(x.shape[axis], 0) + indices
   return x.gather(axis, indices)
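With this change, the ONNX `ScatterElements` reductions map one-to-one onto `Tensor.scatter_reduce` names. A minimal sketch of that mapping outside the ONNX runner, assuming only the `scatter`/`scatter_reduce` behavior added in this diff (the tensors here are illustrative):

```python
from tinygrad import Tensor

x = Tensor([[1.0, 2.0, 3.0]])
indices = Tensor([[0, 0, 2]])
updates = Tensor([[10.0, 20.0, 30.0]])

# ONNX reduction name -> tinygrad scatter_reduce name, as in the diff
onnx_to_tg = {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}
for onnx_red in ("none", "add", "mul", "min", "max"):
  out = x.scatter(1, indices, updates) if onnx_red == "none" else \
        x.scatter_reduce(1, indices, updates, onnx_to_tg[onnx_red])
  # e.g. "add" gives [[31., 2., 33.]]: out[0,0] = 1 + 10 + 20, since include_self defaults to True
  print(onnx_red, out.numpy())
```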
diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py
index 1c227d22f8dcd..dece6a21dc037 100644
--- a/test/external/external_test_onnx_backend.py
+++ b/test/external/external_test_onnx_backend.py
@@ -175,9 +175,7 @@ def supports_device(cls, device: str) -> bool:
 backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
 backend_test.exclude('test_group_normalization_*') # numerical inaccuracy problem. Current Group Normalization OP fails test
-backend_test.exclude('test_scatter_elements_with_reduction_min_cpu') # min not yet supported
 backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
-backend_test.exclude('test_scatter_elements_with_reduction_max_cpu') # max not yet supported
 backend_test.exclude('test_scatternd_max_cpu') # max not yet supported

 if Device.DEFAULT in ['GPU', 'METAL']:
diff --git a/test/test_ops.py b/test/test_ops.py
index 03b89dadc740c..8bee773ae860e 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2572,12 +2572,6 @@ def test_scatter(self):
       vals=[[1.,2.,3.,4.], [1.,0.]])

   def test_scatter_add(self):
-    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
-    a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
-    for dim in (0,1,2,-1,-2,-3):
-      helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="add"),
-        lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="add"), forward_only=True)
-
     b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
     a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
     helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),
@@ -2592,10 +2586,6 @@ def test_scatter_add(self):
   def test_scatter_mul(self):
     b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
     a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
-    for dim in (0,1,2,-1,-2,-3):
-      helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="multiply"),
-        lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="multiply"), forward_only=True)
-
     helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
       lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)
@@ -2605,10 +2595,38 @@ def test_scatter_mul(self):
       lambda x: x.scatter(1, b, float("nan"), reduce="multiply"),
       lambda x: x.scatter(1, a, float("nan"), reduce="multiply"), forward_only=True,)

+  def test_scatter_no_reduce_tensor_src(self):
+    with self.assertRaises(TypeError):
+      Tensor.ones(4).scatter(dim=1, index=Tensor([0]), src=Tensor.ones(4), reduce="add")
+
+  def test_scatter_reduce(self):
+    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
+    a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
+    for reduce in ("sum", "prod", "mean", "amin", "amax"):
+      for dim in (0,1,2,-1,-2,-3):
+        helper_test_op([(4,5,6), (4,5,6)],
+          lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
+          lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
+        helper_test_op([(4,5,6), (4,5,6)],
+          lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
+          lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)
+
+  def test_scatter_reduce_prod_zeros(self):
+    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
+    a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
     x = Tensor.zeros([4,5,6]).float()
     y = torch.zeros([4,5,6]).float()
-    helper_test_op([(4,5,6)], lambda src: y.scatter(dim=1, index=b, src=src, reduce="multiply"),
-      lambda src: x.scatter(dim=1, index=a, src=src, reduce="multiply"), forward_only=True)
+    helper_test_op([(4,5,6)],
+      lambda src: y.scatter_reduce(dim=1, index=b, src=src, reduce="prod"),
+      lambda src: x.scatter_reduce(dim=1, index=a, src=src, reduce="prod"), forward_only=True)
+
+  def test_scatter_reduce_invalid_reduce_op(self):
+    b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
+    a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
+    self.helper_test_exception([(4,5,6), (4,5,6)],
+      lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),
+      lambda x,src: x.scatter_reduce(dim=0, index=a, src=src, reduce="INVALID"),
+      RuntimeError)

   def test_scaled_dot_product_attention(self):
     helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)],
       torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)
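The new tests drive `Tensor.scatter_reduce` across every reduce mode and dim, always with `torch.Tensor.scatter_reduce` as the reference. For readers unfamiliar with the torch semantics being matched, a small torch-only illustration of what `include_self` toggles (the values are illustrative):

```python
import torch

x = torch.tensor([10.0, 10.0, 10.0])
index = torch.tensor([0, 0, 2])
src = torch.tensor([1.0, 2.0, 3.0])

# include_self=True: the existing value of x participates in the reduction
print(x.scatter_reduce(0, index, src, reduce="sum"))                      # tensor([13., 10., 13.])
# include_self=False: only the scattered src values are reduced; untouched slots keep x
print(x.scatter_reduce(0, index, src, reduce="sum", include_self=False))  # tensor([3., 10., 3.])
```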
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 56bbbbf3c92d4..04ced7b697f9f 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -2490,7 +2490,7 @@ def test_contiguous_view_realizes(self):
     self.assertEqual(b.lazydata.base.buffer.size, 16)

 class TestUOpBecome(unittest.TestCase):
-  # the simplest case, if we create a new BUFFER for this UOp
+  # the simplest case, if we create a new BUFFER for this tensor UOp
   def test_new_buffer(self):
     a = Tensor.empty(4, 4)
     b = Tensor.empty(4, 4)
@@ -2498,8 +2498,10 @@ def test_new_buffer(self):
     check_schedule(add, 1)
     # NOTE: realized base is always a flat buffer
     assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
-    # the Tensor UOp can optionally stack a VIEW on top of BUFFER
+    # the Tensor UOp can optionally stack movement ops on top of BUFFER, in this case to preserve the (4, 4) shape of the tensor
     assert UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)).match(add.lazydata, {})
+    self.assertEqual(add.lazydata.size, 16)
+    self.assertEqual(add.lazydata.shape, (4, 4))

   def test_new_buffer_view(self):
     a = Tensor.empty(4, 4)
@@ -2507,10 +2509,28 @@ def test_new_buffer_view(self):
     add = (a+b).reshape(8, 2)
     check_schedule(add, 1)
     assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
-    # VIEW is preserverd after the becomes rewrite.
+    # the shape is preserved in the becomes_map.
     self.assertEqual(add.lazydata.shape, (8, 2))
     assert add.lazydata is not add.lazydata.base

+  def test_new_flat_buffer(self):
+    a = Tensor.empty(4,)
+    b = Tensor.empty(4,)
+    add = a+b
+    check_schedule(add, 1)
+    # BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER
+    assert UPat(Ops.BUFFER).match(add.lazydata, {})
+
+  # sometimes we prefer to perform an op before the movement ops; in that case we should stack the mops on top of the new buffer
+
+  @unittest.expectedFailure
+  def test_new_buffer_mops(self):
+    a = Tensor.empty(4, 1)
+    b = a.expand(4, 4).reciprocal()
+    check_schedule(b, 1)
+    self.assertEqual(b.lazydata.base.realized.size, 4)
+    assert UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE),)).match(b.lazydata, {}), f"{b.lazydata}"
+
   def test_become_existing_buffer(self):
     a = Tensor.empty(4, 4)
     b = a*1
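The `TestUOpBecome` additions pin down how a realized tensor's UOp becomes a flat `BUFFER` with movement ops stacked on top to preserve its shape. A standalone sketch of the same check outside the test harness, assuming the `UPat`/`Ops` imports that `test_schedule.py` already uses and that `realize()` triggers the same becomes rewrite the tests exercise:

```python
from tinygrad import Tensor
from tinygrad.ops import Ops, UPat

add = (Tensor.empty(4, 4) + Tensor.empty(4, 4)).reshape(8, 2)
add.realize()
# the realized base is a flat BUFFER; movement ops on top keep the (8, 2) shape
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
assert add.lazydata.shape == (8, 2)
```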
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index d61809e388ca9..c18deae24019e 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -66,10 +66,10 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
   # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
   (UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
    lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
-  # remove CONST/BIND/BUFFER/VIEW from SINK
+  # remove CONST/BIND/BUFFER from SINK
   (UPat(Ops.SINK, name="root"), lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
-    if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
+    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
 ])

 remove_movement_ops = merge_views+PatternMatcher([
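The tensor.py diff below factors the plumbing shared by `scatter` and the new `scatter_reduce` into a `_pre_scatter` helper: `index` is expanded into a one-hot mask over `self.shape[dim]`, so every scatter reduction becomes a masked reduce along an auxiliary axis. A toy NumPy rendering of that trick on a 1-D example (NumPy and these variable names are illustrative, not part of the diff):

```python
import numpy as np

# scatter-as-masked-reduce, 1-D case: out[index[j]] <op>= src[j]
self_ = np.array([10., 10., 10.])
index = np.array([0, 0, 2])
src = np.array([1., 2., 3.])

# one-hot mask of shape (len(index), len(self_)): mask[j, k] = (index[j] == k)
mask = index[:, None] == np.arange(len(self_))[None, :]
# broadcast src along the one-hot axis, then reduce over j wherever the mask is set
summed = np.where(mask, src[:, None], 0.).sum(axis=0)
print(self_ + summed)  # "sum" with include_self=True -> [13. 10. 13.]
```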
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c32a53e10255f..f180090e91fc3 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2430,11 +2430,26 @@ def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:boo
       x = x.gather(i, index)
     return x.cast(self.dtype)

+  def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
+    assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all be equal, {self.ndim=} {index.ndim=} {src.ndim=}"
+    assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
+      f"All dimensions of {index.shape=} must be <= the corresponding dimensions of {src.shape=} and of {self.shape=} (except dimension {dim})"
+    # shrink src to index shape to shrink away the unused values
+    src = src.shrink(tuple((0,s) for s in index.shape))
+    # prepare src and mask for reduce with respect to dim
+    src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
+    mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
+    # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
+    src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
+    return src, mask
+
   def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
     """
     Scatters `src` values along an axis specified by `dim`.
     Apply `add` or `multiply` reduction operation with `reduce`.

+    NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.
+
     ```python exec="true" source="above" session="tensor" result="python"
     src = Tensor.arange(1, 11).reshape(2, 5)
     print(src.numpy())
@@ -2455,22 +2470,59 @@ def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Un
     ```
     """
     if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
-    index, dim = index.to(self.device), self._resolve_dim(dim)
+    if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg, see scatter_reduce")
     src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
-    assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
-    assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
-      f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
-    # shrink src to index shape to shrink away the unused values
-    src = src.shrink(tuple((0,s) for s in index.shape))
-    # prepare src and mask for reduce with respect to dim
-    src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
-    mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
-    # pad src and mask to self.shape so that reduce can be done with padded values as no-ops
-    src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
+    index, dim = index.to(self.device), self._resolve_dim(dim)
+    src, mask = self._pre_scatter(dim, index, src)
+    # TODO: should not overwrite acc_dtype here?
     if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
     if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
     return _masked_setitem(self, src, mask, (-1,))

+  def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
+                     include_self:bool=True) -> Tensor:
+    """
+    Scatters `src` values along an axis specified by `dim`.
+    Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.
+
+    Set `include_self=False` to exclude values in the `self` Tensor from the reduction.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
+    print(src.numpy())
+    index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
+    print(index.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
+    ```
+    """
+    src = src.cast(self.dtype)
+    index, dim = index.to(self.device), self._resolve_dim(dim)
+    src, mask = self._pre_scatter(dim, index, src)
+    def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
+    # TODO: should not overwrite acc_dtype here?
+    if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
+    if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
+    if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
+    if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
+    if reduce == "mean":
+      count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
+      return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
+    raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")
+
   # ***** unary ops *****

   def logical_not(self):
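With the diff applied, the new method can be sanity-checked directly against the PyTorch reference the test suite uses. A minimal standalone sketch, assuming a tinygrad build that contains this change (the shapes mirror the ones in test_ops.py):

```python
import numpy as np
import torch
from tinygrad import Tensor

index_np = np.random.randint(0, 3, size=(3, 4, 5))
x_np = np.random.rand(4, 5, 6).astype(np.float32)
src_np = np.random.rand(4, 5, 6).astype(np.float32)

for reduce in ("sum", "prod", "mean", "amax", "amin"):
  out_tg = Tensor(x_np).scatter_reduce(0, Tensor(index_np.astype(np.int32)), Tensor(src_np),
                                       reduce=reduce, include_self=False)
  out_pt = torch.from_numpy(x_np).scatter_reduce(0, torch.from_numpy(index_np), torch.from_numpy(src_np),
                                                 reduce=reduce, include_self=False)
  # both libraries should agree up to float32 accumulation error
  np.testing.assert_allclose(out_tg.numpy(), out_pt.numpy(), rtol=1e-5, atol=1e-6)
```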