
Commit c2573b2

jit: rename optimize_weights -> replan_buffers_memory_layout (tinygrad#9751)
1 parent 493fb31 commit c2573b2

2 files changed: +6 -6 lines changed

test/test_jit.py

Lines changed: 3 additions & 3 deletions
@@ -617,8 +617,8 @@ def fxn(y):
     fxn(Tensor([2]))
     self.assertEqual(x.item(), 8)
 
-  def test_optimize_weights(self):
-    if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("optimize_weights useless")
+  def test_replan_buffers_memory_layout(self):
+    if not hasattr(Device[Device.DEFAULT].allocator, '_offset'): raise unittest.SkipTest("replan_buffers_memory_layout useless")
 
     ext_tensor = Tensor([1,24,23,45,1])
     ext_tensor_2 = Tensor([2,2,2,2,2])
@@ -630,7 +630,7 @@ def fxn(x:Tensor):
     out = fxn(Tensor([i,1,2,3,4]))
     self.assertEqual(out.item(), 11400+200*i)
     assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 4
-    fxn.captured.optimize_weights()
+    fxn.captured.replan_buffers_memory_layout()
     assert len(set([b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None])) == 2
 
     out = fxn(Tensor([11,1,2,3,4]))

tinygrad/engine/jit.py

Lines changed: 3 additions & 3 deletions
@@ -149,7 +149,7 @@ class CapturedJit(Generic[ReturnType]):
   expected_st_vars_dtype_device: list[tuple[ShapeTracker, tuple[Variable, ...], DType, str]]
 
   def __reduce__(self):
-    # TODO: free_intermediates here? optimize_weights here?
+    # TODO: free_intermediates here? replan_buffers_memory_layout here?
     return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                             self.expected_names, self.expected_st_vars_dtype_device)
 
@@ -171,7 +171,7 @@ def free_intermediates(self):
       if b._base is not None and b._base.allocated_views == 0 and b._base.is_allocated(): b._base.deallocate()
     self.__post_init__() # reset the graph state
 
-  def optimize_weights(self):
+  def replan_buffers_memory_layout(self):
     blacklist = [t.lazydata.buffer for t in get_parameters(self.ret)]
     asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
     self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
@@ -314,7 +314,7 @@ def __call__(self, *args, **kwargs) -> ReturnType:
 
       # set this for next run
       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
-      if self.optimize: self.captured.optimize_weights()
+      if self.optimize: self.captured.replan_buffers_memory_layout()
     elif self.cnt >= 2:
       # jit exec
       assert self.captured is not None
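
For context, here is a minimal sketch of how the renamed method is called from user code, in the spirit of the test above. The function body, the warmup loop, and the buffer-count check are illustrative assumptions, not part of this commit; only the replan_buffers_memory_layout() call (formerly optimize_weights()) comes from the diff.

from tinygrad import Tensor, TinyJit

ext_tensor = Tensor([1,24,23,45,1])   # captured from the enclosing scope, like in the test

@TinyJit
def fxn(x:Tensor):
  # hypothetical body for illustration only
  return (x + ext_tensor).sum()

# run a few times so the JIT captures the kernels and fxn.captured is populated
for i in range(3): fxn(Tensor([i,1,2,3,4])).item()

# count distinct base buffers used by the captured kernels, before and after replanning
def num_bases(): return len({b.base for item in fxn.captured.jit_cache for b in item.bufs if b is not None})

before = num_bases()
fxn.captured.replan_buffers_memory_layout()   # formerly fxn.captured.optimize_weights()
assert num_bases() <= before                  # the memory planner can only merge buffers, never add them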
