[pull] master from tinygrad:master #101

Merged: 5 commits, Feb 12, 2025

7 changes: 5 additions & 2 deletions .github/workflows/test.yml
@@ -612,7 +612,10 @@ jobs:
key: windows-minimal
deps: testing_minimal
- name: Set env
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1'}}"
shell: bash
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1'}}" >> $GITHUB_ENV
- name: Run pytest (${{ matrix.backend }})
shell: bash
run: python -m pytest -n=auto test/test_tiny.py test/test_ops.py --durations=20
run: |
python -c "from tinygrad import Device; assert Device.DEFAULT == '${{ matrix.backend }}'.upper(), Device.DEFAULT"
python -m pytest -n=auto test/test_tiny.py test/test_ops.py --durations=20
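
For context: the reworked "Set env" step appends the backend variable to $GITHUB_ENV so later steps inherit it, and the new first line of the pytest step asserts the backend actually took effect. A minimal local sketch of that check, assuming the env var alone (LLVM=1 or CLANG=1) is what selects tinygrad's default device, as the new assert expects:

# sketch: set the backend env var before importing tinygrad (what the "Set env" step does for CI)
import os
os.environ["LLVM"] = "1"
from tinygrad import Device
assert Device.DEFAULT == "LLVM", Device.DEFAULT  # same assertion the new pytest step runs first
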
3 changes: 2 additions & 1 deletion examples/handcode_opt.py
@@ -5,8 +5,9 @@
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import Ops, sym_infer
from tinygrad.device import Compiled
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
from extra.optimization.helpers import time_linearizer

def get_sched_resnet():
mdl = ResNet50()
21 changes: 21 additions & 0 deletions extra/optimization/helpers.py
@@ -100,3 +100,24 @@ def lin_to_feats(lin:Kernel, use_sts=True):
else:
assert len(ret) == 274, f"wrong len {len(ret)}"
return ret

from tinygrad.device import Device, Buffer
from tinygrad.engine.search import _ensure_buffer_alloc, _time_program
from tinygrad.helpers import to_function_name, CACHELEVEL, diskcache_put

def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

dev = Device[lin.opts.device]
assert dev.compiler is not None

rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
p = lin.to_program()
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)
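
The scripts below switch their imports to this new location; a usage sketch of the relocated helper, assuming a serialized kernel AST taken from load_worlds() as the extra/optimization scripts do:

from tinygrad.engine.search import bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer

lin = ast_str_to_lin(load_worlds()[0])     # rebuild a Kernel from a serialized AST string
rawbufs = bufs_from_lin(lin)               # allocate buffers matching the kernel's arguments
tm = time_linearizer(lin, rawbufs, cnt=3)  # min wall-clock time over 3 runs, in seconds
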
4 changes: 2 additions & 2 deletions extra/optimization/rl.py
@@ -3,10 +3,10 @@
import math, random
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_kernel_actions
from tinygrad.engine.search import actions, bufs_from_lin, get_kernel_actions
from tinygrad.nn.optim import Adam
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer

if __name__ == "__main__":
net = PolicyNet()
4 changes: 2 additions & 2 deletions extra/optimization/search.py
@@ -1,11 +1,11 @@
import argparse
from extra.optimization.helpers import ast_str_to_lin
from extra.optimization.helpers import ast_str_to_lin, time_linearizer

from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.engine.search import beam_search, bufs_from_lin


if __name__ == '__main__':
4 changes: 2 additions & 2 deletions extra/optimization/test_net.py
@@ -6,8 +6,8 @@
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import bufs_from_lin, time_linearizer, actions, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet

4 changes: 2 additions & 2 deletions extra/optimization/test_time_linearizer.py
@@ -1,5 +1,5 @@
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import bufs_from_lin, time_linearizer, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
from tinygrad.engine.search import bufs_from_lin, get_kernel_actions

if __name__ == "__main__":
ast_strs = load_worlds()
4 changes: 2 additions & 2 deletions test/external/external_benchmark_hcopt.py
@@ -1,7 +1,7 @@
import random
from tinygrad.helpers import getenv
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer

def optimize_kernel(k):
# TODO: update this
4 changes: 2 additions & 2 deletions test/external/speed_beam_v_hcopt.py
@@ -1,7 +1,7 @@
from tinygrad import Device
from tinygrad.helpers import getenv, DEBUG, BEAM
from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer

if __name__ == "__main__":
filter_reduce = bool(getenv("FILTER_REDUCE"))
3 changes: 1 addition & 2 deletions test/external/verify_kernel.py
@@ -1,10 +1,9 @@
import argparse
from collections import defaultdict
from extra.optimization.helpers import kern_str_to_lin
from extra.optimization.helpers import kern_str_to_lin, time_linearizer
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.helpers import colored
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer

# Use this with the LOGKERNS options to verify that all executed kernels are valid and evaluate to the same ground truth results

20 changes: 10 additions & 10 deletions test/test_linearizer.py
@@ -521,7 +521,7 @@ def test_var_multireduce(self):
first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
# store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
# store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
# verify_lazyop(store)
second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
squares = (second_x+neg_mean)*(second_x+neg_mean)
@@ -854,7 +854,7 @@ def test_two_nested_range(self):
ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
# RANGE -> LOAD -> RANGE -> ASSIGN
#assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]])
#assert any(x.op is Ops.LOAD for x in lin.uops[ranges[0]:ranges[1]])

def test_three_nested_range(self):
a = Tensor.randn(2, ).realize()
@@ -865,7 +865,7 @@ def test_three_nested_range(self):
# RANGE -> RANGE -> LOAD -> RANGE -> ASSIGN
# NOTE: nothing should toposort between the first two ranges
#assert ranges[0]+1 == ranges[1]
#assert any(x.op is UOps.LOAD for x in lin.uops[ranges[1]:ranges[2]])
#assert any(x.op is Ops.LOAD for x in lin.uops[ranges[1]:ranges[2]])

def test_two_nested_range_alt_indexing(self):
a = Tensor([2, 2]).realize()
@@ -895,14 +895,14 @@ def test_range_outer_op_before_phi_nested_range(self):
assert len(ranges) == 1 # NOTE: it collapses now
#if getenv("PTX"):
# LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> ASSIGN
# assert lin.uops[ranges[0]-2].op is UOps.LOAD
# assert lin.uops[ranges[0]-2].op is Ops.LOAD
# assert ranges[1] == ranges[0]+6
# assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
# assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [Ops.LOAD, Ops.ALU]
# LOAD -> RANGE -> LOAD -> ALU -> RANGE -> ASSIGN
#else:
# assert lin.uops[ranges[0]-2].op is UOps.LOAD
# assert lin.uops[ranges[0]-2].op is Ops.LOAD
# assert ranges[1] == ranges[0]+3
# assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [UOps.LOAD, UOps.ALU]
# assert [x.op for x in lin.uops[ranges[1]-2:ranges[1]]] == [Ops.LOAD, Ops.ALU]

def test_range_outer_op_after_phi(self):
a = Tensor.randn(4, 1).realize()
@@ -1306,7 +1306,7 @@ def test_grouped_store_phis(self):
# check that the float4 cast collapses
store_vals = [u.src[-1] for u in k.uops if u.op is Ops.STORE]
for val in store_vals:
assert val.dtype == dtypes.float.vec(4) # and val.op is not UOps.VECTORIZE
assert val.dtype == dtypes.float.vec(4) # and val.op is not Ops.VECTORIZE

@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -1345,7 +1345,7 @@ def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v)
barrier = [u for u in k.uops if u.op is Ops.BARRIER][0]
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
assert store.src[-1].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
assert store.src[-1].dtype.count > 1 # and store.src[2].op is not Ops.VECTORIZE
# # check the children's vins
# TODO: src ALU are not the same, should it?
# assert barrier.src == tuple(local_stores)
@@ -1362,7 +1362,7 @@ def test_grouped_store_local_only(self):

# the float4 value stores directly in lds and we skip upcast
self.assertEqual(stores[0].src[-1].dtype, dtypes.float.vec(4))
#assert stores[0].src[-1].op is not UOps.VECTORIZE
#assert stores[0].src[-1].op is not Ops.VECTORIZE

# the global store doesn't change
assert stores[1].src[-1].dtype == dtypes.float
6 changes: 3 additions & 3 deletions test/test_linearizer_failures.py
@@ -117,7 +117,7 @@ def test_failure_6(self):
ast_const(dtypes.int, 10, st_src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=0)]
# COMPILE FAILED, KeyError: UOps.CONST
# COMPILE FAILED, KeyError: Ops.CONST
helper_test_lin(Kernel(ast), opts, failed_platforms=[])

def test_failure_7(self):
@@ -804,7 +804,7 @@ def test_failure_32(self):
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05)

def test_failure_33(self):
# UOps.UNMUL left after linearize
# Ops.UNMUL left after linearize
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
@@ -868,7 +868,7 @@ def test_failure_35(self): self.test_failure_34(True)

# from world fuzz_linearizer: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_N=100 FUZZ_NTH=84 python3 ./test/external/fuzz_linearizer.py
def test_failure_36(self):
# UOps.UNMUL left after linearize
# Ops.UNMUL left after linearize
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()),
4 changes: 2 additions & 2 deletions test/test_linearizer_overflows.py
@@ -4,8 +4,8 @@
from tinygrad import dtypes, Device
from tinygrad.helpers import CI
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin
from tinygrad.engine.search import Opt, OptOps, bufs_from_lin
from extra.optimization.helpers import time_linearizer

# stuff needed to unpack a kernel
from tinygrad.ops import UOp, Ops
3 changes: 2 additions & 1 deletion test/test_search.py
@@ -4,14 +4,15 @@
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.engine.search import bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.engine.realize import capturing
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from extra.optimization.helpers import time_linearizer

class TestTimeLinearizer(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0")
3 changes: 1 addition & 2 deletions test/test_subbuffer.py
@@ -36,8 +36,7 @@ def test_subbuffer_len(self):

def test_subbuffer_used(self):
t = Tensor.arange(0, 10, dtype=dtypes.uint8).realize()
# TODO: why does it needs contiguous
vt = t[2:4].contiguous().realize()
vt = t[2:4].realize()
out = (vt + 100).tolist()
assert out == [102, 103]

4 changes: 2 additions & 2 deletions tinygrad/codegen/kernel.py
@@ -325,8 +325,8 @@ def apply_tensor_cores(self, use_tensor_cores=1, extra_opts:Optional[list[Opt]]=
-1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
[0-N]: uses only the n'th tensor core available; useful for search
tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL
1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers
0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
"""
if tc_select is None: tc_select = TC_SELECT.value
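
The tc_select/tc_opt knobs documented above are arguments to apply_tensor_cores; a hedged sketch of a call, assuming k is a Kernel built from a matmul-like AST:

# search every available tensor core (tc_select=-1) and allow padding of M/N/K axes (tc_opt=2);
# the method returns False when no tensor core matches the kernel
applied = k.apply_tensor_cores(use_tensor_cores=1, tc_select=-1, tc_opt=2)
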
19 changes: 1 addition & 18 deletions tinygrad/engine/search.py
@@ -4,7 +4,7 @@
from dataclasses import replace
from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored
from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -197,20 +197,3 @@ def try_exec(local_size):
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
return ret[1]

def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

dev = Device[lin.opts.device]
assert dev.compiler is not None

rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
p = lin.to_program()
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)
2 changes: 1 addition & 1 deletion tinygrad/ops.py
@@ -638,7 +638,7 @@ def _min_max(self) -> tuple[ConstType, ConstType]:
if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
# TODO: Ops.SPECIAL is Ops.DEFINE_VAR
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
if self.op is Ops.CONST: return self.arg, self.arg
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
4 changes: 2 additions & 2 deletions tinygrad/renderer/__init__.py
@@ -121,8 +121,8 @@ class Renderer:
has_local: bool = True
has_shared: bool = True
# NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: UOps.SPECIAL int32 indexes right now
global_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
local_max: Optional[tuple[int, ...]] = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
shared_max: int = 32768
tensor_cores: list[TensorCore] = []
extra_matcher: Optional[PatternMatcher] = None
7 changes: 1 addition & 6 deletions tinygrad/spec.py
@@ -32,11 +32,6 @@
# NOTE: the arg here specifies clone=True, which prevents folding same device copy
(UPat(Ops.COPY, name="copy", src=(UPat(Ops.DEVICE), UPat.var("x"))), lambda copy,x: isinstance(copy.arg, bool) and copy.dtype == x.dtype),

# VIEW(BUFFER) applies a ShapeTracker on top of the underlying device buffer
# NOTE: VIEW size exactly matches the underlying BUFFER, tensor doesn't apply movement ops to the VIEW
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),
lambda view,buf: view.dtype == buf.dtype and view.size == buf.size and view.st.contiguous),

# ASSIGN changes the value of a realized buffer
(UPat(Ops.ASSIGN, name="assign", src=(UPat.var("target"), UPat.var("new_val"))),
lambda assign,target,new_val: target.is_realized and (assign.dtype == target.dtype == new_val.dtype)),
@@ -113,7 +108,7 @@
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local

# NOTE: for testing, we let sinks be anything
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
#(UPat(Ops.SINK, src=UPat(Ops.STORE)), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat((Ops.NOOP, Ops.CUSTOM)), lambda: True),
