Skip to content

Commit a427548

Browse files
committed
DS fix, continued (#3145)
1 parent 42be235 commit a427548

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/deepspeed/test_deepspeed_multiple_model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ def test_config_reference_update(self):
118118
accelerator = Accelerator(deepspeed_plugin=ds_plugins)
119119
from transformers.integrations.deepspeed import deepspeed_config
120120

121+
# Note that these have `auto` values being set so we need to adjust
121122
assert accelerator.deepspeed_plugin is zero2
123+
zero2.deepspeed_config["train_micro_batch_size_per_gpu"] = 1
124+
zero2.deepspeed_config.pop("train_batch_size")
122125
assert deepspeed_config() == accelerator.deepspeed_plugin.hf_ds_config.config
123126

124127
accelerator.state.select_deepspeed_plugin("zero3")
@@ -173,6 +176,6 @@ def test_prepare_multiple_models_zero3_inference(self):
173176
@slow
174177
def test_train_multiple_models(self):
175178
self.test_file_path = self.test_scripts_folder / "test_ds_multiple_model.py"
176-
args = ["--num_processes=2", "--num_machines=1", "--main_process_port=10999", str(self.test_file_path)]
179+
args = ["--num_processes=2", "--num_machines=1", "--main_process_port=0", str(self.test_file_path)]
177180
args = self.parser.parse_args(args)
178181
launch_command(args)

0 commit comments

Comments
 (0)