Skip to content

Commit da1cccb

Browse files
committed
fix initial pass for z3
Signed-off-by: Masahiro Tanaka <mtanaka@microsoft.com>
1 parent 7c78179 commit da1cccb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

deepspeed/compile/passes/zero3_compile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,12 @@ def add_z3_gather_release_fw(gm: GraphModule,
9999
debug_log=False) -> GraphModule:
100100

101101
nz3 = get_deepcompile_handle()
102-
graph = gm.graph
103102

104103
real_inputs = create_inputs_fn()
105104
param_indices = profiling_results[graph_id].param_indices
106105

107-
graph = add_gather_and_release(graph_id, graph, param_manager[graph_id], get_param_nodes(graph, param_indices))
106+
gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id],
107+
get_param_nodes(gm.graph, param_indices))
108108

109109
nz3.register_graph_z3(graph_id, [v[1] for v in param_indices]) # Need this before profiling
110110

@@ -163,7 +163,7 @@ def add_z3_gather_release_bw(gm: GraphModule,
163163
if rank == 0 and debug_log:
164164
print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}")
165165

166-
gm.graph = fast_free_schedule(gm.graph, get_accelerator().available_memory(), 0, debug_log=debug_log)
166+
# gm.graph = fast_free_schedule(gm.graph, get_accelerator().available_memory(), 0, debug_log=debug_log)
167167
return gm
168168

169169

0 commit comments

Comments
 (0)