@@ -1206,18 +1206,13 @@ def test_failure_51(self):
1206
1206
UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (12 , 1024 , 1 ), strides = (1024 , 1 , 0 ), offset = 0 , mask = None , contiguous = True ),)), src = ()),
1207
1207
UOp (Ops .RECIP , dtypes .half , arg = None , src = (
1208
1208
UOp (Ops .ADD , dtypes .half , arg = None , src = (
1209
- UOp (Ops .WHERE , dtypes .half , arg = None , src = (
1210
- x6 := UOp (Ops .VALID , dtypes .bool , arg = None , src = (
1211
- UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (12 , 1024 , 1 ), strides = (0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1212
- UOp (Ops .CONST , dtypes .half , arg = 1.0 , src = ()),
1213
- x9 := UOp (Ops .CONST , dtypes .half , arg = 0.0 , src = ()),)),
1209
+ UOp (Ops .CONST , dtypes .half , arg = 1.0 , src = (
1210
+ x6 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (12 , 1024 , 1 ), strides = (0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1214
1211
UOp (Ops .EXP2 , dtypes .half , arg = None , src = (
1215
1212
UOp (Ops .MUL , dtypes .half , arg = None , src = (
1216
1213
UOp (Ops .MUL , dtypes .half , arg = None , src = (
1217
- UOp (Ops .WHERE , dtypes .half , arg = None , src = (
1218
- x6 ,
1219
- UOp (Ops .CONST , dtypes .half , arg = 2.0 , src = ()),
1220
- x9 ,)),
1214
+ UOp (Ops .CONST , dtypes .half , arg = 2.0 , src = (
1215
+ x6 ,)),
1221
1216
UOp (Ops .ADD , dtypes .half , arg = None , src = (
1222
1217
UOp (Ops .CAST , dtypes .half , arg = None , src = (
1223
1218
UOp (Ops .REDUCE_AXIS , dtypes .float , arg = (Ops .ADD , (2 ,)), src = (
@@ -1232,10 +1227,8 @@ def test_failure_51(self):
1232
1227
UOp (Ops .LOAD , dtypes .half , arg = None , src = (
1233
1228
UOp (Ops .DEFINE_GLOBAL , dtypes .half .ptr (), arg = 3 , src = ()),
1234
1229
UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (12 , 1024 , 1 ), strides = (0 , 1 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),)),)),
1235
- UOp (Ops .WHERE , dtypes .half , arg = None , src = (
1236
- x6 ,
1237
- UOp (Ops .CONST , dtypes .half , arg = - 1.4426950408889634 , src = ()),
1238
- x9 ,)),)),)),)),)),)),))
1230
+ UOp (Ops .CONST , dtypes .half , arg = - 1.4426950408889634 , src = (
1231
+ x6 ,)),)),)),)),)),)),))
1239
1232
opts = [Opt (op = OptOps .TC , axis = 0 , arg = 2 )]
1240
1233
helper_test_lin (Kernel (ast , opts = Device [Device .DEFAULT ].renderer ), opts = opts , failed_platforms = [])
1241
1234
@@ -1283,17 +1276,14 @@ def test_failure_53(self):
1283
1276
UOp (Ops .WHERE , dtypes .int , arg = None , src = (
1284
1277
UOp (Ops .VALID , dtypes .bool , arg = None , src = (
1285
1278
UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (50001 , 99999 ), strides = (0 , 0 ), offset = 0 , mask = ((0 , 50001 ), (49999 , 99999 )), contiguous = False ), View (shape = (1024 , 50000 , 50000 ), strides = (0 , 1 , 100000 ), offset = 0 , mask = None , contiguous = False ))), src = ()),)),
1286
- UOp (Ops .CONST , dtypes .int , arg = 1 , src = ()),
1287
- x20 := UOp (Ops .CONST , dtypes .int , arg = 0 , src = ()),)),)),
1288
- UOp (Ops .WHERE , dtypes .int , arg = None , src = (
1289
- x22 := UOp (Ops .VALID , dtypes .bool , arg = None , src = (
1290
- UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (1024 , 50000 , 1 ), strides = (0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1291
- UOp (Ops .CONST , dtypes .int , arg = - 1 , src = ()),
1292
- x20 ,)),)),)),
1293
- UOp (Ops .WHERE , dtypes .bool , arg = None , src = (
1294
- x22 ,
1295
- UOp (Ops .CONST , dtypes .bool , arg = True , src = ()),
1296
- UOp (Ops .CONST , dtypes .bool , arg = False , src = ()),)),)),)),)),)),)),))
1279
+ UOp (Ops .CONST , dtypes .int , arg = 1 , src = (
1280
+ x20 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (1024 , 50000 , 50000 ), strides = (0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1281
+ UOp (Ops .CONST , dtypes .int , arg = 0 , src = (
1282
+ x20 ,)),)),)),
1283
+ UOp (Ops .CONST , dtypes .int , arg = - 1 , src = (
1284
+ x23 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (1024 , 50000 , 1 ), strides = (0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),)),)),
1285
+ UOp (Ops .CONST , dtypes .bool , arg = True , src = (
1286
+ x23 ,)),)),)),)),)),)),))
1297
1287
opts = [Opt (op = OptOps .GROUPTOP , axis = 1 , arg = 16 )]
1298
1288
helper_test_lin (Kernel (ast , opts = Device [Device .DEFAULT ].renderer ), opts = opts , failed_platforms = ["AMD" , "GPU" , "METAL" , "NV" , "CUDA" ])
1299
1289
@@ -1348,11 +1338,8 @@ def test_failure_56(self):
1348
1338
UOp (Ops .MUL , dtypes .float , arg = None , src = (
1349
1339
UOp (Ops .CAST , dtypes .float , arg = None , src = (
1350
1340
UOp (Ops .CMPLT , dtypes .bool , arg = None , src = (
1351
- x7 := UOp (Ops .WHERE , dtypes .float , arg = None , src = (
1352
- x8 := UOp (Ops .VALID , dtypes .bool , arg = None , src = (
1353
- UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1354
- x10 := UOp (Ops .CONST , dtypes .float , arg = 0.0 , src = ()),
1355
- x10 ,)),
1341
+ x7 := UOp (Ops .CONST , dtypes .float , arg = 0.0 , src = (
1342
+ x8 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1356
1343
UOp (Ops .MAX , dtypes .float , arg = None , src = (
1357
1344
UOp (Ops .ADD , dtypes .float , arg = None , src = (
1358
1345
UOp (Ops .MUL , dtypes .float , arg = None , src = (
@@ -1364,20 +1351,18 @@ def test_failure_56(self):
1364
1351
UOp (Ops .MUL , dtypes .float , arg = None , src = (
1365
1352
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1366
1353
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 2 , src = ()),
1367
- x22 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 1 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1368
- UOp (Ops .WHERE , dtypes .float , arg = None , src = (
1369
- x8 ,
1370
- UOp (Ops .CONST , dtypes .float , arg = - 1.0 , src = ()),
1371
- x10 ,)),)),)),
1354
+ x20 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 1 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1355
+ UOp (Ops .CONST , dtypes .float , arg = - 1.0 , src = (
1356
+ x8 ,)),)),)),
1372
1357
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1373
1358
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 3 , src = ()),
1374
- x22 ,)),)),
1359
+ x20 ,)),)),
1375
1360
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1376
1361
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 4 , src = ()),
1377
- x22 ,)),)),
1362
+ x20 ,)),)),
1378
1363
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1379
1364
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 5 , src = ()),
1380
- x22 ,)),)),
1365
+ x20 ,)),)),
1381
1366
x7 ,)),)),)),
1382
1367
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1383
1368
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 6 , src = ()),
@@ -1394,11 +1379,8 @@ def test_failure_57(self):
1394
1379
UOp (Ops .MUL , dtypes .float , arg = None , src = (
1395
1380
UOp (Ops .CAST , dtypes .float , arg = None , src = (
1396
1381
UOp (Ops .CMPLT , dtypes .bool , arg = None , src = (
1397
- x7 := UOp (Ops .WHERE , dtypes .float , arg = None , src = (
1398
- x8 := UOp (Ops .VALID , dtypes .bool , arg = None , src = (
1399
- UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1400
- x10 := UOp (Ops .CONST , dtypes .float , arg = 0.0 , src = ()),
1401
- x10 ,)),
1382
+ x7 := UOp (Ops .CONST , dtypes .float , arg = 0.0 , src = (
1383
+ x8 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 0 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1402
1384
UOp (Ops .MAX , dtypes .float , arg = None , src = (
1403
1385
UOp (Ops .ADD , dtypes .float , arg = None , src = (
1404
1386
UOp (Ops .MUL , dtypes .float , arg = None , src = (
@@ -1410,20 +1392,18 @@ def test_failure_57(self):
1410
1392
UOp (Ops .MUL , dtypes .float , arg = None , src = (
1411
1393
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1412
1394
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 2 , src = ()),
1413
- x22 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 1 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1414
- UOp (Ops .WHERE , dtypes .float , arg = None , src = (
1415
- x8 ,
1416
- UOp (Ops .CONST , dtypes .float , arg = - 1.0 , src = ()),
1417
- x10 ,)),)),)),
1395
+ x20 := UOp (Ops .VIEW , dtypes .void , arg = ShapeTracker (views = (View (shape = (128 , 16 , 11 , 11 ), strides = (0 , 1 , 0 , 0 ), offset = 0 , mask = None , contiguous = False ),)), src = ()),)),
1396
+ UOp (Ops .CONST , dtypes .float , arg = - 1.0 , src = (
1397
+ x8 ,)),)),)),
1418
1398
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1419
1399
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 3 , src = ()),
1420
- x22 ,)),)),
1400
+ x20 ,)),)),
1421
1401
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1422
1402
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 4 , src = ()),
1423
- x22 ,)),)),
1403
+ x20 ,)),)),
1424
1404
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1425
1405
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 5 , src = ()),
1426
- x22 ,)),)),
1406
+ x20 ,)),)),
1427
1407
x7 ,)),)),)),
1428
1408
UOp (Ops .LOAD , dtypes .float , arg = None , src = (
1429
1409
UOp (Ops .DEFINE_GLOBAL , dtypes .float .ptr (), arg = 6 , src = ()),
0 commit comments