@@ -1323,14 +1323,14 @@ def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype,
1323
1323
(UPat (Ops .VIEW , name = "vm1" , src = (UPat (Ops .VIEW , name = "vm2" ),)), lambda vm1 ,vm2 : vm2 .replace (arg = vm2 .st + vm1 .st )),
1324
1324
(UPat (Ops .VIEW , name = "vm" , src = (UPat .var ("x" ),)), lambda vm ,x : x if vm .st .contiguous and x .st is not None and x .shape == vm .shape else None ),
1325
1325
# merge unmasked const views
1326
- (UPat (Ops .VIEW , name = "view" , src = (UPat (Ops .CONST , name = "const" , src = (UPat (Ops .VIEW , name = "st" ),) ),)),
1326
+ (UPat (Ops .VIEW , name = "view" , src = (UPat (( Ops .CONST , Ops . DEFINE_VAR ) , name = "const" , src = (UPat (Ops .VIEW , name = "st" ),) ),)),
1327
1327
lambda st ,const ,view : const .replace (src = (st .replace (arg = st .st + view .st ),)) if all (v .mask is None for v in (st .st + view .st ).views ) else None ),
1328
1328
])
1329
1329
1330
1330
# push VIEW to parents
1331
1331
view_left = merge_views + PatternMatcher ([
1332
1332
# VIEW(CONST) becomes VALID
1333
- (UPat (Ops .VIEW , name = "vm" , src = (UPat . cvar ( "x" ),)), lambda vm ,x : UOp . const ( x . dtype , x . const_arg ).valid (vm .st )),
1333
+ (UPat (Ops .VIEW , name = "vm" , src = (UPat (( Ops . CONST , Ops . DEFINE_VAR ), name = "x" ),)), lambda vm ,x : x . replace ( src = () ).valid (vm .st )),
1334
1334
# VIEW before elementwise/buffer ops
1335
1335
(UPat (Ops .VIEW , name = "vm" , src = (UPat ({* GroupOp .ALU , Ops .CAST , Ops .BITCAST , Ops .ASSIGN }, name = "e" ),)),
1336
1336
lambda e ,vm : e .replace (src = tuple (s if s .st is None else s .view (vm .st ) if s is s .base else s .base .view (s .st + vm .st ) for s in e .src ))),
0 commit comments