@@ -110,7 +110,6 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
110
110
op = buf .replace (dtype = dtype , src = tuple (add_buffers (x , buffer_map , cache ) for x in buf .src ))
111
111
# track the buffer uop for the simplified uop
112
112
buffer_map [buf ] = buf_uop
113
- if op .op is Ops .BUFFER_VIEW : buffers [buf_uop ] = (x := op .src [0 ]).buf_uop .buffer .view (op .size , op .dtype , unwrap (x .st ).views [0 ].offset * x .dtype .itemsize )
114
113
# (early) bufferize
115
114
cache [buf ] = ret = UOp (Ops .VIEW , dtype .base , (buf_uop , op ), buf .st )
116
115
return ret
@@ -391,6 +390,8 @@ def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
391
390
ast = graph_rewrite (graph_rewrite (ast , unbind_vars + view_left , ctx = var_vals ), view_right )
392
391
# fix_kernel_ops
393
392
ast = graph_rewrite (ast , fix_kernel_ops , var_vals )
393
+ # create subbuffer
394
+ if ast .op is Ops .BUFFER_VIEW : buffers [bufs [0 ]] = bufs [1 ].buffer .view (ast .size , ast .dtype , (x := ast .src [0 ]).st_arg .views [0 ].offset * x .dtype .itemsize )
394
395
return ScheduleItem (ast , tuple (dedup ([x .buffer for x in bufs ])), sink .src [1 ].arg .metadata )
395
396
396
397
PROCESS_REPLAY_CAPTURE :dict [str , bytes ] = {}
@@ -454,8 +455,6 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
454
455
if any (x .op is Ops .ASSIGN and x .buf_uop is s for x in u .toposort ):
455
456
raise RuntimeError (f"cycle detected in graph, kernel must either depend on ASSIGN or BUFFER for { k } " )
456
457
assign_rep [a ] = kernel_assign [s ] = a .replace (src = a .src + (u ,))
457
- # increment the refcount of the target buf (this is required by the JIT and memory planner)
458
- u .buf_uop .buffer .ref (1 )
459
458
if assign_rep : sched_sink = sched_sink .substitute (assign_rep )
460
459
461
460
# display the final graph
@@ -478,6 +477,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
478
477
while queue :
479
478
u = queue .popleft ()
480
479
schedule .append (schedule_uop (u , var_vals ))
480
+ # increment the refcount of the target buf (this is required by the JIT and memory planner)
481
+ u .buf_uop .buffer .ref (1 )
481
482
for x in children .get (u , []):
482
483
in_degree [x ] -= 1
483
484
if in_degree [x ] == 0 : queue .append (x )
0 commit comments