@@ -1300,10 +1300,10 @@ def test_arange_opts(self):
1300
1300
helper_linearizer_opt (a , [
1301
1301
[Opt (OptOps .GROUP , 0 , 32 )],
1302
1302
[Opt (OptOps .GROUPTOP , 0 , 32 )],
1303
- [Opt (op = OptOps .LOCAL , axis = 0 , arg = 8 )],
1304
- [Opt (op = OptOps .LOCAL , axis = 0 , arg = 8 ), Opt (op = OptOps .UPCAST , axis = 0 , arg = 0 )],
1305
- [Opt (op = OptOps .LOCAL , axis = 0 , arg = 8 ), Opt (op = OptOps .UPCAST , axis = 0 , arg = 0 ), Opt (op = OptOps .GROUP , axis = 0 , arg = 8 )],
1306
- [Opt (op = OptOps .LOCAL , axis = 0 , arg = 8 ), Opt (op = OptOps .UPCAST , axis = 0 , arg = 0 ), Opt (op = OptOps .GROUP , axis = 0 , arg = 8 ), Opt (op = OptOps .UNROLL , axis = 1 , arg = 4 )], # noqa: E501
1303
+ [Opt (op = OptOps .LOCAL , axis = 0 , amt = 8 )],
1304
+ [Opt (op = OptOps .LOCAL , axis = 0 , amt = 8 ), Opt (op = OptOps .UPCAST , axis = 0 , amt = 0 )],
1305
+ [Opt (op = OptOps .LOCAL , axis = 0 , amt = 8 ), Opt (op = OptOps .UPCAST , axis = 0 , amt = 0 ), Opt (op = OptOps .GROUP , axis = 0 , amt = 8 )],
1306
+ [Opt (op = OptOps .LOCAL , axis = 0 , amt = 8 ), Opt (op = OptOps .UPCAST , axis = 0 , amt = 0 ), Opt (op = OptOps .GROUP , axis = 0 , amt = 8 ), Opt (op = OptOps .UNROLL , axis = 1 , amt = 4 )], # noqa: E501
1307
1307
])
1308
1308
1309
1309
@unittest .skipUnless (Device [Device .DEFAULT ].renderer .supports_float4 , "test requires float4" )
@@ -1363,8 +1363,8 @@ def test_skip_unmatching_upcasts(self):
1363
1363
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 1 ),
1364
1364
UOp (Ops .VIEW , arg = ShapeTracker (views = (View (shape = (240 , 40 , 1 , 1 ), strides = (1 , 240 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),))),)),)),)) # noqa: E501
1365
1365
opt = [
1366
- Opt (op = OptOps .UPCAST , axis = 1 , arg = 4 ), Opt (op = OptOps .LOCAL , axis = 0 , arg = 16 ),
1367
- Opt (op = OptOps .LOCAL , axis = 1 , arg = 2 ), Opt (op = OptOps .UPCAST , axis = 3 , arg = 2 )
1366
+ Opt (op = OptOps .UPCAST , axis = 1 , amt = 4 ), Opt (op = OptOps .LOCAL , axis = 0 , amt = 16 ),
1367
+ Opt (op = OptOps .LOCAL , axis = 1 , amt = 2 ), Opt (op = OptOps .UPCAST , axis = 3 , amt = 2 )
1368
1368
]
1369
1369
k = helper_linearizer_ast (ast , [Tensor .randn (240 * 40 ).realize ()], opts = [opt ])[- 1 ]
1370
1370
out = [u for u in k .uops if u .op is Ops .STORE ][0 ]
@@ -1381,9 +1381,9 @@ def test_skip_unmatching_upcasts_with_gep(self):
1381
1381
UOp (Ops .LOAD , dtypes .float , src = (
1382
1382
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 1 ),
1383
1383
UOp (Ops .VIEW , arg = ShapeTracker (views = (View (shape = (8 , 32 , 1 , 1 ), strides = (1 , 8 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),))),)),)),)) # noqa: E501
1384
- opt = [Opt (op = OptOps .LOCAL , axis = 1 , arg = 4 ), Opt (op = OptOps .UPCAST , axis = 2 , arg = 2 ), Opt (op = OptOps .LOCAL , axis = 1 , arg = 8 ),
1385
- Opt (op = OptOps .UPCAST , axis = 1 , arg = 0 ), Opt (op = OptOps .UPCAST , axis = 1 , arg = 4 ), Opt (op = OptOps .LOCAL , axis = 0 , arg = 8 ),
1386
- Opt (op = OptOps .UPCAST , axis = 1 , arg = 0 ), Opt (op = OptOps .UPCAST , axis = 0 , arg = 2 )]
1384
+ opt = [Opt (op = OptOps .LOCAL , axis = 1 , amt = 4 ), Opt (op = OptOps .UPCAST , axis = 2 , amt = 2 ), Opt (op = OptOps .LOCAL , axis = 1 , amt = 8 ),
1385
+ Opt (op = OptOps .UPCAST , axis = 1 , amt = 0 ), Opt (op = OptOps .UPCAST , axis = 1 , amt = 4 ), Opt (op = OptOps .LOCAL , axis = 0 , amt = 8 ),
1386
+ Opt (op = OptOps .UPCAST , axis = 1 , amt = 0 ), Opt (op = OptOps .UPCAST , axis = 0 , amt = 2 )]
1387
1387
k = helper_linearizer_ast (ast , [Tensor .randn (8 * 32 ).realize ()], opts = [opt ])[- 1 ]
1388
1388
out = [u for u in k .uops if u .op is Ops .STORE ][0 ]
1389
1389
assert out .src [- 1 ].op is Ops .VECTORIZE and out .src [- 1 ].dtype .count != 1
@@ -1606,9 +1606,9 @@ def test_half4_load_unrolled(self):
1606
1606
1607
1607
# TODO: fix this, expected might change but should be positive
1608
1608
for expected , opts in [
1609
- ((7 , 0 ), [Opt (op = OptOps .UPCAST , axis = 1 , arg = 4 ), Opt (op = OptOps .UPCAST , axis = 0 , arg = 3 ), Opt (op = OptOps .UNROLL , axis = 0 , arg = 4 )]),
1610
- ((5 , 0 ), [Opt (op = OptOps .UPCAST , axis = 1 , arg = 4 ), Opt (op = OptOps .UNROLL , axis = 0 , arg = 4 )]),
1611
- ((2 , 0 ), [Opt (op = OptOps .UNROLL , axis = 0 , arg = 4 )]),
1609
+ ((7 , 0 ), [Opt (op = OptOps .UPCAST , axis = 1 , amt = 4 ), Opt (op = OptOps .UPCAST , axis = 0 , amt = 3 ), Opt (op = OptOps .UNROLL , axis = 0 , amt = 4 )]),
1610
+ ((5 , 0 ), [Opt (op = OptOps .UPCAST , axis = 1 , amt = 4 ), Opt (op = OptOps .UNROLL , axis = 0 , amt = 4 )]),
1611
+ ((2 , 0 ), [Opt (op = OptOps .UNROLL , axis = 0 , amt = 4 )]),
1612
1612
]:
1613
1613
k = Kernel (ast )
1614
1614
for opt in opts : k .apply_opt (opt )
@@ -1637,8 +1637,8 @@ def test_float4_acc(self):
1637
1637
UOp (Ops .VIEW , arg = ShapeTracker (views = (View (shape = (1 , 1 , 128 , 512 , 512 , 1 , 1 , 1 ), strides = (0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),))),)),)),)),)) # noqa: E501
1638
1638
1639
1639
for expected , opts in [
1640
- (1 , [Opt (op = OptOps .UPCAST , axis = 2 , arg = 4 )]),
1641
- (4 , [Opt (op = OptOps .UPCAST , axis = 2 , arg = 4 ), Opt (op = OptOps .UPCAST , axis = 0 , arg = 4 )]),
1640
+ (1 , [Opt (op = OptOps .UPCAST , axis = 2 , amt = 4 )]),
1641
+ (4 , [Opt (op = OptOps .UPCAST , axis = 2 , amt = 4 ), Opt (op = OptOps .UPCAST , axis = 0 , amt = 4 )]),
1642
1642
]:
1643
1643
k = Kernel (ast )
1644
1644
for opt in opts : k .apply_opt (opt )
@@ -1660,8 +1660,8 @@ def test_float2_acc(self):
1660
1660
UOp (Ops .DEFINE_GLOBAL , dtypes .half .ptr (), arg = 1 ),
1661
1661
UOp (Ops .VIEW , arg = ShapeTracker (views = (View (shape = (256 , 64 , 3 , 56 , 2 , 3 , 56 , 2 ), strides = (1806336 , 28224 , 3 , 504 , 0 , 1 , 9 , 0 ), offset = 0 , mask = ((0 , 256 ), (0 , 64 ), (0 , 3 ), (0 , 56 ), (0 , 1 ), (0 , 3 ), (0 , 56 ), (0 , 1 )), contiguous = False ), View (shape = (256 , 64 , 3 , 115 , 3 , 115 ), strides = (7225344 , 112896 , 37632 , 336 , 112 , 1 ), offset = 0 , mask = ((0 , 256 ), (0 , 64 ), (0 , 3 ), (0 , 112 ), (0 , 3 ), (0 , 112 )), contiguous = False ), View (shape = (256 , 64 , 456 , 456 ), strides = (7617600 , 119025 , 345 , 1 ), offset = 0 , mask = ((0 , 256 ), (0 , 64 ), (0 , 345 ), (0 , 345 )), contiguous = False ), View (shape = (1 , 256 , 1 , 64 , 4 , 114 , 4 , 114 ), strides = (0 , 13307904 , 0 , 207936 , 51984 , 456 , 114 , 1 ), offset = 0 , mask = None , contiguous = True )))),)),)),)),)),)),)) # noqa: E501
1662
1662
for expected , opts in [
1663
- (16 , [Opt (op = OptOps .LOCAL , axis = 1 , arg = 16 ), Opt (op = OptOps .UPCAST , axis = 1 , arg = 0 ), Opt (op = OptOps .UPCAST , axis = 2 , arg = 2 ), Opt (op = OptOps .LOCAL , axis = 2 , arg = 3 ), Opt (op = OptOps .UPCAST , axis = 3 , arg = 4 )]), # noqa: E501
1664
- (4 , [Opt (op = OptOps .LOCAL , axis = 1 , arg = 16 ), Opt (op = OptOps .UPCAST , axis = 1 , arg = 0 ), Opt (op = OptOps .UPCAST , axis = 2 , arg = 2 )]),
1663
+ (16 , [Opt (op = OptOps .LOCAL , axis = 1 , amt = 16 ), Opt (op = OptOps .UPCAST , axis = 1 , amt = 0 ), Opt (op = OptOps .UPCAST , axis = 2 , amt = 2 ), Opt (op = OptOps .LOCAL , axis = 2 , amt = 3 ), Opt (op = OptOps .UPCAST , axis = 3 , amt = 4 )]), # noqa: E501
1664
+ (4 , [Opt (op = OptOps .LOCAL , axis = 1 , amt = 16 ), Opt (op = OptOps .UPCAST , axis = 1 , amt = 0 ), Opt (op = OptOps .UPCAST , axis = 2 , amt = 2 )]),
1665
1665
]:
1666
1666
k = Kernel (ast )
1667
1667
for opt in opts : k .apply_opt (opt )
0 commit comments