@@ -130,13 +130,14 @@ def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependenc
130
130
# a marker for your graph supporting multiple devices of the same type
131
131
class MultiGraphRunner (GraphRunner ): pass
132
132
133
+ def get_out_buffers_for_ei (ei :ExecItem ) -> list [Buffer ]:
134
+ if isinstance (ei .prg , CompiledRunner ): return [cast (Buffer , ei .bufs [out ]) for out in ei .prg .p .outs if out not in ei .prg .p .ins ]
135
+ if isinstance (ei .prg , (BufferCopy , BufferXfer )): return [cast (Buffer , ei .bufs [0 ])]
136
+ return []
137
+
133
138
def update_depends (depends :set [Buffer | None ], jit_cache :list [ExecItem ]):
134
139
for ei in jit_cache :
135
- if any (b in depends for b in ei .bufs ):
136
- if isinstance (ei .prg , CompiledRunner ):
137
- depends .update (cast (Buffer , ei .bufs [out ]) for out in ei .prg .p .outs if out not in ei .prg .p .ins )
138
- if isinstance (ei .prg , (BufferCopy , BufferXfer )):
139
- depends .add (cast (Buffer , ei .bufs [0 ]))
140
+ if any (b in depends for b in ei .bufs ): depends .update (get_out_buffers_for_ei (ei ))
140
141
141
142
ReturnType = TypeVar ('ReturnType' )
142
143
@dataclass
@@ -294,8 +295,7 @@ def __call__(self, *args, **kwargs) -> ReturnType:
294
295
if self .prune :
295
296
depends = set (input_buffers )
296
297
update_depends (depends , jit_cache )
297
- pruned , onetime = partition (jit_cache ,
298
- lambda ei : not isinstance (ei .prg , CompiledRunner ) or any (ei .bufs [out ] in depends for out in ei .prg .p .outs ))
298
+ pruned , onetime = partition (jit_cache , lambda ei : any (b in depends for b in get_out_buffers_for_ei (ei )))
299
299
if DEBUG >= 1 : print (f"pruned from { len (jit_cache )} -> { len (pruned )} kernels" )
300
300
# run the onetime kernels here
301
301
for ei in onetime :
0 commit comments