@@ -63,6 +63,9 @@ def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
63
63
# support for using a contiguous permuted view instead of the parent view if one exists
64
64
(UPat (Ops .CONTIGUOUS , name = "contig" , src = (UPat (Ops .VIEW , name = "src" ),)), found_contiguous ),
65
65
(UPat (GroupOp .ALU , name = "alu" ), replace_contiguous ),
66
+ # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
67
+ (UPat ((Ops .BITCAST , Ops .CONTIGUOUS ), name = "root" ),
68
+ lambda root : root .replace (op = Ops .BUFFER_VIEW ) if isinstance (root .device , str ) and root .device .startswith ("DISK" ) else None ),
66
69
# remove CONST/BIND/BUFFER from SINK
67
70
(UPat (Ops .SINK , name = "root" ),
68
71
lambda root : UOp (Ops .SINK , root .dtype , new_src , root .arg )
@@ -112,6 +115,7 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
112
115
op = buf .replace (dtype = dtype , src = tuple (add_buffers (x , buffer_map , cache ) for x in buf .src ))
113
116
# track the buffer uop for the simplified uop
114
117
buffer_map [buf ] = buf_uop
118
+ if op .op is Ops .BUFFER_VIEW : buffers [buf_uop ] = (x := op .src [0 ]).buf_uop .buffer .view (op .size , op .dtype , unwrap (x .st ).views [0 ].offset * x .dtype .itemsize )
115
119
# (early) bufferize
116
120
cache [buf ] = ret = UOp (Ops .VIEW , dtype .base , (buf_uop , op ), buf .st )
117
121
return ret
@@ -132,23 +136,15 @@ def realize_before_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs)
132
136
# otherwise safety check pads
133
137
return None if (all (v .mask is None for v in st .views ) or can_pad (src , ctx .realizes , dict ())) else realize (ctx , b , src )
134
138
135
# NOTE(review): the leading "-" on these lines is the diff's removal marker, i.e. this helper is deleted by the change shown here;
# the bare integers between the lines are line-number gutters from the diff rendering, not code.
# create_subbuffer: pattern callback that rewrites `root` into a BUFFER_VIEW of `x` when `b` is on a DISK device.
#   base: UOp whose sources are rebuilt as (b, root-with-op-BUFFER_VIEW)
#   b:    buffer UOp; its device, size and dtype describe the view that gets registered
#   root: the matched op (BITCAST or CONTIGUOUS, per the rule that registers this callback)
#   x:    the single source of `root`; the view is taken from x.buf_uop.buffer
# Returns the rewritten `base`, or None when `b.device` is a tuple or does not start with "DISK".
- def create_subbuffer (base :UOp , b :UOp , root :UOp , x :UOp ):
136
# bail out for tuple devices (presumably multi-device, TODO confirm) and for any device that is not DISK
- if isinstance (b .device , tuple ) or not b .device .startswith ("DISK" ): return None
137
# register a view of x's underlying buffer under b; the third argument is views[0].offset scaled by x.dtype.itemsize
# (presumably a byte offset, confirm against the buffer's view() signature)
- buffers [b ] = x .buf_uop .buffer .view (b .size , b .dtype , unwrap (x .st ).views [0 ].offset * x .dtype .itemsize )
138
# swap root's op for BUFFER_VIEW and make (b, that) the new sources of base
- return base .replace (src = (b , root .replace (op = Ops .BUFFER_VIEW )))
139
-
140
139
# do_realize: list of (pattern, callback) rewrite rules that mark UOps for realization.
# NOTE(review): this span is a rendered diff. Bare integers are line-number gutters, "-" entries are removed
# and "+" entries are added by the change; only unmarked and "+" entries exist after the change.
do_realize = PatternMatcher ([
141
140
# always realize SINK parents
142
141
# the lambda records every source of the SINK in ctx.realizes, keyed by that source's buf_uop
(UPat (Ops .SINK , name = "sink" ), lambda ctx ,sink : ctx .realizes .update ((x .buf_uop , x ) for x in sink .src )),
143
142
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
144
143
(UPatScheduled ({Ops .ASSIGN , Ops .CONTIGUOUS , Ops .COPY , Ops .BUFFER_VIEW }), realize ),
145
144
# realize before expand or unsafe pad ops
146
145
# realize_before_view decides per view; its visible tail returns None when no realization is needed
(UPat (Ops .VIEW , name = "view" , src = (UPatScheduled (name = "src" ),)), realize_before_view ),
147
- # realize before COPY or BUFFER_VIEW
148
- (UPat (Ops .COPY , src = (UPat (), UPat .any (UPatScheduled (), UPatScheduled ().view ()),)), realize ),
149
- (UPat (Ops .BUFFER_VIEW , src = (UPat .any (UPatScheduled (), UPatScheduled ().view ()),)), realize ),
150
- # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
151
- (UPatScheduled ((Ops .BITCAST , Ops .CONTIGUOUS ), name = "root" , src = (UPat .var ("x" ),)), create_subbuffer ),
146
+ # realize before COPY
147
# the added rule matches a COPY whose second source is a UPatScheduled() directly; the removed COPY rule
# also accepted the UPatScheduled().view() form, and the removed BUFFER_VIEW/DISK rules have no replacement in this list
+ (UPat (Ops .COPY , src = (UPat (), UPatScheduled ())), realize ),
152
148
])
153
149
154
150
def append_uop (ctx :ScheduleContext , view :UOp , buf_uop :UOp ) -> None :
0 commit comments