Skip to content

Commit 5f7c796

Browse files
authored
jit: prune independent copies (tinygrad#9749)
* jit: prune independent copies
* linter
* check kernel cnt
1 parent c2573b2 commit 5f7c796

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

test/test_jit.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,24 @@ def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
570570
out = w2_prune(a)
571571
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
572572

573+
def test_prune_w_independent_copy_correct(self):
  # `weights` is created on "CPU"; inside w2 it is doubled, made contiguous and moved
  # to Device.DEFAULT before the jit input `x` is added, so the copy itself never
  # reads `x`. Both the plain and the pruning jit must still return weights*2 + x.
  weights = Tensor.rand(16, device="CPU").realize()
  def w2(x) -> Tensor: return (weights*2).contiguous().to(Device.DEFAULT) + x
  jit_plain = TinyJit(w2)
  jit_pruned = TinyJit(w2, prune=True)

  # run the same three-call check on each variant, unpruned first
  for jitted in (jit_plain, jit_pruned):
    for _ in range(3):
      a = Tensor.rand(16).realize()
      out = jitted(a)
      actual = out.tolist()
      expected = [w*2+v for w,v in zip(weights.tolist(), a.tolist())]
      np.testing.assert_allclose(actual, expected)

  # after pruning, exactly one item may remain in the captured jit cache
  assert len(jit_pruned.captured.jit_cache) == 1, "prune should have removed the copy"
590+
573591
class TestJitFree(unittest.TestCase):
574592
def test_free_intermediates(self):
575593
ext_tensor = Tensor([1,24,23,45,1])

tinygrad/engine/jit.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,14 @@ def _access_resources(self, rawbufs:list[Buffer], write:list[int], new_dependenc
130130
# a marker for your graph supporting multiple devices of the same type
131131
# Empty subclass: it adds no attributes or methods of its own, everything comes from GraphRunner.
class MultiGraphRunner(GraphRunner): pass
132132

133+
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
  # Buffers that `ei` is treated as producing, decided by the type of its runner:
  #  - CompiledRunner: the bufs at the indexes in p.outs that are not also in p.ins
  #  - BufferCopy / BufferXfer: bufs[0]
  #  - any other runner: no buffers
  runner = ei.prg
  if isinstance(runner, CompiledRunner):
    produced: list[Buffer] = []
    for idx in runner.p.outs:
      # an index that is listed as an input as well is not reported as an output
      if idx in runner.p.ins: continue
      produced.append(cast(Buffer, ei.bufs[idx]))
    return produced
  if isinstance(runner, (BufferCopy, BufferXfer)): return [cast(Buffer, ei.bufs[0])]
  return []
137+
133138
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
  # Grows `depends` in place with a single forward pass over `jit_cache`:
  # whenever an item uses any buffer already in the set, the buffers returned by
  # get_out_buffers_for_ei for that item are added too. Items are visited in cache
  # order, so buffers added by an earlier item are seen by the later ones.
  for item in jit_cache:
    touches_dependency = any(buf in depends for buf in item.bufs)
    if not touches_dependency: continue
    depends.update(get_out_buffers_for_ei(item))
140141

141142
ReturnType = TypeVar('ReturnType')
142143
@dataclass
@@ -294,8 +295,7 @@ def __call__(self, *args, **kwargs) -> ReturnType:
294295
if self.prune:
295296
depends = set(input_buffers)
296297
update_depends(depends, jit_cache)
297-
pruned, onetime = partition(jit_cache,
298-
lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
298+
pruned, onetime = partition(jit_cache, lambda ei: any(b in depends for b in get_out_buffers_for_ei(ei)))
299299
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
300300
# run the onetime kernels here
301301
for ei in onetime:

0 commit comments

Comments
 (0)