@@ -54,8 +54,8 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
54
54
(UPat (Ops .COPY , src = (UPat (), UPat .var ("copyin" )), name = "copy" ),
55
55
lambda copyin ,copy : copyin if copyin .device == copy .device and copy .arg is not True else None ),
56
56
# remove cast to image when it's already a contiguous image
57
- (UPat (Ops .VIEW , name = "vm1" , src = ( UPat ( Ops . CAST , name = "cast" , src = (UPat (Ops .VIEW , name = "vm2 " , src = (UPat (Ops .CONTIGUOUS , name = "base" )) ))),)),
58
- lambda cast ,base ,vm1 , vm2 : base .view (vm2 . st + vm1 .st ) if isinstance (cast .dtype , ImageDType ) and isinstance (base .dtype , ImageDType ) else None ),
57
+ (UPat (Ops .CAST , name = "cast" , src = (UPat (Ops .VIEW , name = "vm " , src = (UPat (Ops .CONTIGUOUS , name = "base" ))),)),
58
+ lambda cast ,base ,vm : base .view (vm .st ) if isinstance (cast .dtype , ImageDType ) and isinstance (base .dtype , ImageDType ) else None ),
59
59
# remove contiguous if we can just view the buffer
60
60
(UPat (Ops .CONTIGUOUS , name = "root" , src = (UPat (Ops .VIEW , name = "view" , src = (UPat (Ops .BUFFER , name = "buf" ),)),)),
61
61
lambda root ,view ,buf : view if view .st .contiguous and view .size == buf .size else None ),
0 commit comments