
[pull] master from tinygrad:master #103

Merged 3 commits on Feb 13, 2025
1 change: 1 addition & 0 deletions docs/tensor/ops.md
@@ -32,6 +32,7 @@
::: tinygrad.Tensor.tril
::: tinygrad.Tensor.interpolate
::: tinygrad.Tensor.scatter
::: tinygrad.Tensor.scatter_reduce

## Neural Network (functional)

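For reference, the `Tensor.scatter_reduce` entry added above documents the new op implemented in `tinygrad/tensor.py` below. A minimal usage sketch mirroring the docstring examples from this PR (assuming `from tinygrad import Tensor, dtypes`, as in current tinygrad):

```python
from tinygrad import Tensor, dtypes

src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)   # [[1..5], [6..10]]
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])            # route both rows into row 0
out = Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce="sum")
print(out.numpy())  # expected per the docstring example below: [[ 8. 10. 12. 14. 16.]]
```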
5 changes: 3 additions & 2 deletions extra/onnx.py
@@ -691,9 +691,10 @@ def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none"
else: raise NotImplementedError("reduction doesn't support max or min")
return x

def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
if reduction == "none": return x.scatter(axis, indices, updates)
return x.scatter_reduce(axis, indices, updates, {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}.get(reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
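The ONNX `reduction` attribute values (`add`, `mul`, `min`, `max`) differ from the reduce names `Tensor.scatter_reduce` accepts, hence the translation dict in the updated handler. A standalone sketch of the same dispatch (illustrative only, outside the file's real import context):

```python
# ONNX "reduction" attribute -> tinygrad scatter_reduce "reduce" argument
ONNX_TO_TINYGRAD_REDUCE = {"add": "sum", "mul": "prod", "min": "amin", "max": "amax"}

def scatter_elements_sketch(x, indices, updates, axis=0, reduction="none"):
    indices = (indices < 0).where(x.shape[axis], 0) + indices         # wrap negative indices
    if reduction == "none": return x.scatter(axis, indices, updates)  # plain scatter, no reduction
    return x.scatter_reduce(axis, indices, updates, ONNX_TO_TINYGRAD_REDUCE[reduction])
```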
2 changes: 0 additions & 2 deletions test/external/external_test_onnx_backend.py
@@ -175,9 +175,7 @@ def supports_device(cls, device: str) -> bool:
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
backend_test.exclude('test_group_normalization_*') # numerical inaccuracy problem. Current Group Normalization OP fails test

backend_test.exclude('test_scatter_elements_with_reduction_min_cpu') # min not yet supported
backend_test.exclude('test_scatternd_min_cpu') # min not yet supported
backend_test.exclude('test_scatter_elements_with_reduction_max_cpu') # max not yet supported
backend_test.exclude('test_scatternd_max_cpu') # max not yet supported

if Device.DEFAULT in ['GPU', 'METAL']:
42 changes: 30 additions & 12 deletions test/test_ops.py
@@ -2572,12 +2572,6 @@ def test_scatter(self):
vals=[[1.,2.,3.,4.], [1.,0.]])

def test_scatter_add(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="add"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="add"), forward_only=True)

b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),
@@ -2592,10 +2586,6 @@ def test_scatter_add(self):
def test_scatter_mul(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="multiply"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="multiply"), forward_only=True)

helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)

@@ -2605,10 +2595,38 @@ def test_scatter_mul(self):
lambda x: x.scatter(1, b, float("nan"), reduce="multiply"),
lambda x: x.scatter(1, a, float("nan"), reduce="multiply"), forward_only=True,)

def test_scatter_no_reduce_tensor_src(self):
with self.assertRaises(TypeError):
Tensor.ones(4).scatter(dim=1, index=Tensor([0]), src=Tensor.ones(4), reduce="add")

def test_scatter_reduce(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for reduce in ("sum", "prod", "mean", "amin", "amax"):
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce), forward_only=True)
helper_test_op([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=dim, index=b, src=src, reduce=reduce, include_self=False),
lambda x,src: x.scatter_reduce(dim=dim, index=a, src=src, reduce=reduce, include_self=False), forward_only=True)

def test_scatter_reduce_prod_zeros(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
x = Tensor.zeros([4,5,6]).float()
y = torch.zeros([4,5,6]).float()
helper_test_op([(4,5,6)], lambda src: y.scatter(dim=1, index=b, src=src, reduce="multiply"),
lambda src: x.scatter(dim=1, index=a, src=src, reduce="multiply"), forward_only=True)
helper_test_op([(4,5,6)],
lambda src: y.scatter_reduce(dim=1, index=b, src=src, reduce="prod"),
lambda src: x.scatter_reduce(dim=1, index=a, src=src, reduce="prod"), forward_only=True)

def test_scatter_reduce_invalid_reduce_op(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
self.helper_test_exception([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),
lambda x,src: x.scatter_reduce(dim=0, index=a, src=src, reduce="INVALID"),
RuntimeError)

def test_scaled_dot_product_attention(self):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)
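The new scatter_reduce tests check parity with torch; as a mental model for the semantics they exercise (including `include_self=False`), a loop-based NumPy reference could look like the sketch below. The helper name and the torch-parity assumptions are mine, not part of the test suite:

```python
import numpy as np

def scatter_reduce_ref(x, dim, index, src, reduce, include_self=True):
    # naive reference: src[pos] is routed to out[pos with pos[dim] replaced by index[pos]]
    out = x.astype(np.float64).copy()
    counts = np.full(out.shape, 1.0 if include_self else 0.0)
    if not include_self:
        # slots that receive at least one update start from the reduction identity
        identity = {"sum": 0.0, "mean": 0.0, "prod": 1.0, "amax": -np.inf, "amin": np.inf}[reduce]
        for pos in np.ndindex(*index.shape):
            tgt = list(pos); tgt[dim] = index[pos]
            out[tuple(tgt)] = identity
    for pos in np.ndindex(*index.shape):
        tgt = list(pos); tgt[dim] = index[pos]
        tgt, v = tuple(tgt), float(src[pos])
        if reduce in ("sum", "mean"): out[tgt] += v
        elif reduce == "prod":        out[tgt] *= v
        elif reduce == "amax":        out[tgt] = max(out[tgt], v)
        elif reduce == "amin":        out[tgt] = min(out[tgt], v)
        counts[tgt] += 1.0
    if reduce == "mean":  # mean divides by the number of contributions (plus self if included)
        out = np.where(counts > 0, out / np.maximum(counts, 1.0), out)
    return out
```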
26 changes: 23 additions & 3 deletions test/test_schedule.py
@@ -2490,27 +2490,47 @@ def test_contiguous_view_realizes(self):
self.assertEqual(b.lazydata.base.buffer.size, 16)

class TestUOpBecome(unittest.TestCase):
# the simplest case, if we create a new BUFFER for this UOp
# the simplest case, if we create a new BUFFER for this tensor UOp
def test_new_buffer(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = a+b
check_schedule(add, 1)
# NOTE: realized base is always a flat buffer
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
# the Tensor UOp can optionally stack a VIEW on top of BUFFER
# the Tensor UOp can optionally stack movement ops on top of BUFFER, in this case to preserve the (4, 4) shape of the tensor
assert UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)).match(add.lazydata, {})
self.assertEqual(add.lazydata.size, 16)
self.assertEqual(add.lazydata.shape, (4, 4))

def test_new_buffer_view(self):
a = Tensor.empty(4, 4)
b = Tensor.empty(4, 4)
add = (a+b).reshape(8, 2)
check_schedule(add, 1)
assert UPat(Ops.BUFFER).match(add.lazydata.base, {})
# VIEW is preserved after the becomes rewrite.
# the shape is preserved in the becomes_map.
self.assertEqual(add.lazydata.shape, (8, 2))
assert add.lazydata is not add.lazydata.base

def test_new_flat_buffer(self):
a = Tensor.empty(4,)
b = Tensor.empty(4,)
add = a+b
check_schedule(add, 1)
# BUFFER already has a shape (4,), this tensor just becomes a contiguous BUFFER
assert UPat(Ops.BUFFER).match(add.lazydata, {})

# sometimes we prefer to perform an op before movement ops, in this case we should stack the mops on top of the new buffer

@unittest.expectedFailure
def test_new_buffer_mops(self):
a = Tensor.empty(4, 1)
b = a.expand(4, 4).reciprocal()
check_schedule(b, 1)
self.assertEqual(b.lazydata.base.realized.size, 4)
assert UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE),)).match(b.lazydata, {}), f"{b.lazydata}"

def test_become_existing_buffer(self):
a = Tensor.empty(4, 4)
b = a*1
4 changes: 2 additions & 2 deletions tinygrad/engine/schedule.py
@@ -66,10 +66,10 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), name="root"),
lambda root: root.replace(op=Ops.BUFFER_VIEW) if isinstance(root.device, str) and root.device.startswith("DISK") else None),
# remove CONST/BIND/BUFFER/VIEW from SINK
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])

remove_movement_ops = merge_views+PatternMatcher([
74 changes: 63 additions & 11 deletions tinygrad/tensor.py
@@ -2430,11 +2430,26 @@ def interpolate(self, size:tuple[int, ...], mode:str="linear", align_corners:boo
x = x.gather(i, index)
return x.cast(self.dtype)

def _pre_scatter(self, dim:int, index:Tensor, src:Tensor) -> tuple[Tensor, Tensor]:
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all be equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
return src, mask

def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']]=None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `add` or `multiply` reduction operation with `reduce`.

NOTE: To use the `reduce` argument with a Tensor `src`, see `Tensor.scatter_reduce`.

```python exec="true" source="above" session="tensor" result="python"
src = Tensor.arange(1, 11).reshape(2, 5)
print(src.numpy())
@@ -2455,22 +2470,59 @@ def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Un
```
"""
if reduce not in {None, "add", "multiply"}: raise TypeError(f"{reduce=} must be one of None, 'multiply', or 'add'")
index, dim = index.to(self.device), self._resolve_dim(dim)
if reduce and isinstance(src, Tensor): raise TypeError("Tensor src is not supported with reduce arg. see scatter_reduce")
src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
index, dim = index.to(self.device), self._resolve_dim(dim)
src, mask = self._pre_scatter(dim, index, src)
# TODO: should not overwrite acc_dtype here?
if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
return _masked_setitem(self, src, mask, (-1,))

def scatter_reduce(self, dim:int, index:Tensor, src:Tensor, reduce:Literal["sum", "prod", "mean", "amax", "amin"],
include_self:bool=True) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `"sum"`, `"prod"`, `"mean"`, `"amax"`, or `"amin"` reduction operations with `reduce`.

Set `include_self=False` to exclude values in the `self` Tensor from the reduction.

```python exec="true" source="above" session="tensor" result="python"
src = Tensor.arange(1, 11).cast(dtypes.float).reshape(2, 5)
print(src.numpy())
index = Tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
print(index.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='sum').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='prod').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(1, 5, dtype=src.dtype).scatter_reduce(0, index, src, reduce='mean', include_self=False).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amax').numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([[-10, 20, 0, 5, 10]], dtype=src.dtype).scatter_reduce(0, index, src, reduce='amin').numpy())
```
"""
src = src.cast(self.dtype)
index, dim = index.to(self.device), self._resolve_dim(dim)
src, mask = self._pre_scatter(dim, index, src)
def _inv_mask(a:Union[Tensor, ConstType], b:Union[Tensor, ConstType]) -> Tensor: return mask.any(-1).logical_not().where(a, b)
# TODO: should not overwrite acc_dtype here?
if reduce == "sum": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype).mul(self if include_self else _inv_mask(self, 1))
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
if reduce == "amin": return mask.where(src, m := dtypes.max(src.dtype)).min(-1).minimum(self if include_self else _inv_mask(self, m))
if reduce == "mean":
count = mask.where(1, 0).sum(-1, acc_dtype=self.dtype).add(1 if include_self else _inv_mask(1, 0))
return mask.where(src, 0).sum(-1, acc_dtype=self.dtype).add(self if include_self else _inv_mask(self, 0)).div(count)
raise RuntimeError(f"{reduce=} must be one of 'sum', 'prod', 'mean', 'amax', 'amin'")

# ***** unary ops *****

def logical_not(self):
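The new `_pre_scatter` helper shares the masking machinery between `scatter` and `scatter_reduce`: `src` is broadcast along an extra trailing dimension of size `self.shape[dim]`, a one-hot mask built from `index` marks which output slot each source element feeds, and every reduction becomes a masked reduce over that extra dimension. A rough 1-D NumPy illustration of the same one-hot trick (hypothetical helper, not the tinygrad code path):

```python
import numpy as np

def scatter_sum_1d(x, index, src):
    # mask[i, j] is True iff src[i] is routed to output slot j
    mask = np.eye(x.shape[0], dtype=bool)[index]   # (len(src), len(x)) one-hot rows
    contrib = np.where(mask, src[:, None], 0.0)    # masked contributions, 0 elsewhere
    return x + contrib.sum(axis=0)                 # reduce over the source axis

x = np.ones(5)
index = np.array([0, 0, 3])
src = np.array([1.0, 2.0, 4.0])
print(scatter_sum_1d(x, index, src))  # [4. 1. 1. 5. 1.]
```

Swapping the fill value 0.0 for 1.0 and `sum` for `prod` (or for `max`/`min` with the dtype's extreme values as fill) yields the other reductions, which is exactly how the `reduce` branches in the new `scatter_reduce` differ.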