@@ -199,6 +199,8 @@ def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
199
199
(UPat (Ops .PRELOAD , name = "root" ), lambda root :root .replace (op = Ops .LOAD )),
200
200
# once images are loaded they become the base dtype
201
201
(UPat (set (Ops )- {Ops .DEFINE_GLOBAL }, name = "x" ), lambda x : x .replace (dtype = x .dtype .base ) if isinstance (x .dtype , ImageDType ) else None ),
202
+ # CONST(VIEW) becomes VALID too, TODO: doesn't have to
203
+ (UPat (Ops .CONST , name = "x" , src = (UPat (Ops .VIEW , name = "st" ),)), lambda x ,st : UOp .const (x .dtype , x .const_arg ).valid (st .st )),
202
204
])
203
205
204
206
# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
@@ -438,11 +440,11 @@ def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
438
440
(UPatScheduled ((Ops .BITCAST , Ops .CONTIGUOUS ), name = "root" , src = (UPat .var ("x" ),)), create_subbuffer ),
439
441
])
440
442
441
- # **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
443
+ # **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
442
444
443
445
def unbind_variable (ctx :ScheduleContext , bind :UOp , var :UOp , val :UOp ):
444
- assert isinstance (val .src [ 1 ]. const_arg , int ), f"expected BIND value to be int { val } "
445
- ctx .var_vals [ret := var .replace (src = ())] = val .src [ 1 ]. const_arg
446
+ assert isinstance (val .const_arg , int ), f"expected BIND value to be int { val } "
447
+ ctx .var_vals [ret := var .replace (src = ())] = val .const_arg
446
448
return ret .valid (unwrap (bind .st ))
447
449
448
450
def load_realized (ctx :ScheduleContext , b :UOp , st :UOp ):
@@ -456,8 +458,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
456
458
return UOp (Ops .LOAD , x .dtype , (b , unwrap (st .st ).to_uop ()))
457
459
458
460
break_sched = PatternMatcher ([
459
- # CONST is always fused and generated
460
- (UPat (Ops .CONST , name = "x" , src = (UPat (Ops .VIEW , name = "st" ),)), lambda x ,st : UOp .const (x .dtype , x .const_arg ).valid (st .st )),
461
461
(UPat (Ops .BIND , name = "bind" , src = (UPat .var ("var" ), UPat .var ("val" ))), unbind_variable ),
462
462
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
463
463
(UPat (Ops .VIEW , name = "st" , src = (UPat (Ops .BUFFER , name = "b" ),)), load_realized ),
@@ -481,9 +481,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
481
481
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
482
482
(UPat (Ops .VIEW , name = "view" ),
483
483
lambda view : view .const_like (0 ) if (vm := view .st .views [- 1 ].mask ) is not None and any ((x [1 ]- x [0 ]) == 0 for x in vm ) else None ),
484
- # merge unmasked const views
485
- (UPat (Ops .VIEW , name = "view" , src = (UPat (Ops .CONST , name = "const" , src = (UPat (Ops .VIEW , name = "st" ),) ),)),
486
- lambda st ,const ,view : const .replace (src = (st .replace (arg = st .st + view .st ),)) if all (v .mask is None for v in (st .st + view .st ).views ) else None ),
487
484
])
488
485
489
486
@track_rewrites (named = True )
0 commit comments