Skip to content

Commit 3ed146a

Browse files
authored
Revert "rename Opt amt to arg (tinygrad#8767)" (tinygrad#8769)
This reverts commit bf04165.
1 parent bf04165 commit 3ed146a

15 files changed

+129
-129
lines changed

extra/mcts_search.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def add_node(node:MCTSNode):
162162
if node.n == 0: return
163163
for parent in node.parents: G.add_edge(parent, node)
164164
gopts = node.kernel.applied_opts
165-
edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT"
165+
edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].amt}" if len(gopts) else "ROOT"
166166
G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
167167
fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '')
168168
if node.children is not None:

test/external/external_debug_metal_sd_conv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
3030
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=4, src=()),
3131
x17,)),)),)),))
32-
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=2)]
32+
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=2)]
3333

3434
k = Kernel(ast)
3535
for opt in opts: k.apply_opt(opt)

test/external/external_test_hcq_fuzz_failures.py

+1-1
Large diffs are not rendered by default.

test/external/external_test_nv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ def setUpClass(self):
2626

2727
def test_oor_kernels(self):
2828
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
29-
opts = [Opt(op=OptOps.TC, axis=6, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501
29+
opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501
3030
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"])
3131

3232
def test_error_on_huge_dims(self):
3333
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
34-
opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2)] # noqa: E501
34+
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501
3535
with self.assertRaises(RuntimeError) as cm:
3636
lin = Kernel(ast)
3737
for opt in opts: lin.apply_opt(opt)

test/external/external_test_train_gpt2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_1(self):
2626
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(38633472), arg=2, src=()),
2727
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(0, 0, 768, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
2828

29-
opts = [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=3), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
29+
opts = [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.LOCAL, axis=0, amt=2)]
3030
kernel = Kernel(ast)
3131
for opt in opts: kernel.apply_opt(opt)
3232
run_linearizer(kernel)
@@ -46,7 +46,7 @@ def test_2(self):
4646
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(205852672), arg=2, src=()),
4747
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(51463168, 50257, 1, 0), offset=0, mask=((0, 4), (0, 1024), (0, 50257), (0, 768)), contiguous=False),)), src=()),)),)),)),)),))
4848

49-
opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=4)]
49+
opts = [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=4)]
5050
kernel = Kernel(ast)
5151
for opt in opts: kernel.apply_opt(opt)
5252
run_linearizer(kernel)

test/external/external_test_valid_remove.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_valid_removal(self):
5151
x19,)),
5252
x29,)),)),)),))
5353

54-
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
54+
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
5555
kernel = Kernel(ast)
5656

5757
for opt in opts: kernel.apply_opt(opt)
@@ -108,7 +108,7 @@ def test_const_idx(self):
108108
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()),
109109
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),))
110110

111-
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
111+
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
112112
kernel = Kernel(ast)
113113

114114
for opt in opts: kernel.apply_opt(opt)

test/test_arange.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNR
4141
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)
4242

4343
@unittest.skip("doesn't work yet")
44-
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)])
44+
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)])
4545

4646
def test_all_opts(self, opts=None, exclude=None):
4747
k = Kernel(Tensor.arange(256).schedule()[-1].ast)
@@ -59,11 +59,11 @@ def test_all_opts(self, opts=None, exclude=None):
5959
self.test_complexity(opts)
6060
def test_all_opts_w_local(self):
6161
with contextlib.suppress(KernelOptError):
62-
return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, arg=32)])
62+
return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, amt=32)])
6363
def test_all_opts_w_upcast(self): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4)])
64-
def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
64+
def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
6565
def test_all_opts_w_upcast_and_unroll(self):
66-
return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)])
66+
return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
6767

6868
class TestIndexing(unittest.TestCase):
6969
# update: passing after CAST_BEFORE_VIEW=1 deletion

test/test_hcq.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_exec_update_fuzz(self):
160160
b = a + 1
161161
si = b.schedule()[-1]
162162
k = Kernel(si.ast, opts=TestHCQ.d0.renderer)
163-
for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, arg=3))
163+
for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=3))
164164

165165
runner = CompiledRunner(k.to_program())
166166

test/test_linearizer.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -1300,10 +1300,10 @@ def test_arange_opts(self):
13001300
helper_linearizer_opt(a, [
13011301
[Opt(OptOps.GROUP, 0, 32)],
13021302
[Opt(OptOps.GROUPTOP, 0, 32)],
1303-
[Opt(op=OptOps.LOCAL, axis=0, arg=8)],
1304-
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0)],
1305-
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)],
1306-
[Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501
1303+
[Opt(op=OptOps.LOCAL, axis=0, amt=8)],
1304+
[Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0)],
1305+
[Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8)],
1306+
[Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4)], # noqa: E501
13071307
])
13081308

13091309
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -1363,8 +1363,8 @@ def test_skip_unmatching_upcasts(self):
13631363
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
13641364
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
13651365
opt = [
1366-
Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
1367-
Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)
1366+
Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
1367+
Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
13681368
]
13691369
k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1]
13701370
out = [u for u in k.uops if u.op is Ops.STORE][0]
@@ -1381,9 +1381,9 @@ def test_skip_unmatching_upcasts_with_gep(self):
13811381
UOp(Ops.LOAD, dtypes.float, src=(
13821382
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
13831383
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
1384-
opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
1385-
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
1386-
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
1384+
opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
1385+
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
1386+
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
13871387
k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1]
13881388
out = [u for u in k.uops if u.op is Ops.STORE][0]
13891389
assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype.count != 1
@@ -1606,9 +1606,9 @@ def test_half4_load_unrolled(self):
16061606

16071607
# TODO: fix this, expected might change but should be positive
16081608
for expected, opts in [
1609-
((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
1610-
((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
1611-
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
1609+
((7, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
1610+
((5, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
1611+
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, amt=4)]),
16121612
]:
16131613
k = Kernel(ast)
16141614
for opt in opts: k.apply_opt(opt)
@@ -1637,8 +1637,8 @@ def test_float4_acc(self):
16371637
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
16381638

16391639
for expected, opts in [
1640-
(1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]),
1641-
(4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]),
1640+
(1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]),
1641+
(4, [Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)]),
16421642
]:
16431643
k = Kernel(ast)
16441644
for opt in opts: k.apply_opt(opt)
@@ -1660,8 +1660,8 @@ def test_float2_acc(self):
16601660
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
16611661
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
16621662
for expected, opts in [
1663-
(16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501
1664-
(4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]),
1663+
(16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501
1664+
(4, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2)]),
16651665
]:
16661666
k = Kernel(ast)
16671667
for opt in opts: k.apply_opt(opt)

test/test_linearizer_dumb.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_unmerged_ifs(self):
3535
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
3636
ast_const(dtypes.half, 0.0, st_src=(
3737
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
38-
opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
38+
opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0)]
3939
k = Kernel(ast, opts=Device["METAL"].renderer)
4040
k.required_optimizations()
4141
for opt in opts: k.apply_opt(opt)
@@ -70,7 +70,7 @@ def test_max_simplify_and_cancel(self):
7070
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
7171
ast_const(dtypes.int, 1000, st_src=(
7272
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
73-
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
73+
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8)]
7474
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
7575
k.required_optimizations()
7676
for opt in opts: k.apply_opt(opt)
@@ -88,7 +88,7 @@ def test_expander_new_srcs(self):
8888
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
8989
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
9090
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
91-
opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)]
91+
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
9292
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
9393
k.required_optimizations()
9494
for opt in opts: k.apply_opt(opt)
@@ -155,7 +155,7 @@ def test_unaligns_idxs(self):
155155
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
156156
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
157157
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
158-
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=3)]
158+
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)]
159159
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
160160
for opt in opts: k.apply_opt(opt)
161161
prg = k.to_program()
@@ -186,7 +186,7 @@ def test_unrolled_float4_align(self):
186186
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
187187
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
188188
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
189-
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)]
189+
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)]
190190
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
191191
for opt in opts: k.apply_opt(opt)
192192
prg = k.to_program()
@@ -210,7 +210,7 @@ def test_upcasted_stores_out_of_order(self):
210210
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
211211
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
212212
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
213-
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)]
213+
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
214214
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
215215
for opt in opts: k.apply_opt(opt)
216216
prg = k.to_program()

0 commit comments

Comments
 (0)