@@ -149,7 +149,7 @@ class CapturedJit(Generic[ReturnType]):
149
149
expected_st_vars_dtype_device : list [tuple [ShapeTracker , tuple [Variable , ...], DType , str ]]
150
150
151
151
def __reduce__ (self ):
152
- # TODO: free_intermediates here? optimize_weights here?
152
+ # TODO: free_intermediates here? replan_buffers_memory_layout here?
153
153
return self .__class__ , (self .ret , self .jit_cache , self .input_replace , self .extra_view_inputs ,
154
154
self .expected_names , self .expected_st_vars_dtype_device )
155
155
@@ -171,7 +171,7 @@ def free_intermediates(self):
171
171
if b ._base is not None and b ._base .allocated_views == 0 and b ._base .is_allocated (): b ._base .deallocate ()
172
172
self .__post_init__ () # reset the graph state
173
173
174
- def optimize_weights (self ):
174
+ def replan_buffers_memory_layout (self ):
175
175
blacklist = [t .lazydata .buffer for t in get_parameters (self .ret )]
176
176
asgn = _internal_memory_planner ([[b for item in self .jit_cache for b in item .bufs if b is not None and b not in blacklist ]], ignore_checks = True )
177
177
self .jit_cache = [ExecItem (item .prg , [asgn .get (b ,b ) if b is not None else None for b in item .bufs ]) for item in self .jit_cache ]
@@ -314,7 +314,7 @@ def __call__(self, *args, **kwargs) -> ReturnType:
314
314
315
315
# set this for next run
316
316
self .captured = CapturedJit (ret , jit_cache , input_replace , extra_view_inputs , names , st_vars_dtype_device )
317
- if self .optimize : self .captured .optimize_weights ()
317
+ if self .optimize : self .captured .replan_buffers_memory_layout ()
318
318
elif self .cnt >= 2 :
319
319
# jit exec
320
320
assert self .captured is not None
0 commit comments