Skip to content

Commit 4da6d71

Browse files
committed
fix test
1 parent ef5957c commit 4da6d71

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/tests_pytorch/utilities/test_deepspeed_model_summary.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
@RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True)
25-
def test_deepspeed_summary(tmpdir):
25+
def test_deepspeed_summary(tmp_path):
2626
"""Test to ensure that the summary contains the correct values when stage 3 is enabled and that the trainer enables
2727
the `DeepSpeedSummary` when DeepSpeed is used."""
2828

@@ -37,12 +37,12 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
3737

3838
# check the additional params per device
3939
summary_data = model_summary._get_summary_data()
40-
params_per_device = summary_data[-1][-1]
40+
params_per_device = summary_data[4][-1]
4141
assert int(params_per_device[0]) == (model_summary.total_parameters // 2)
4242

4343
trainer = Trainer(
4444
strategy=DeepSpeedStrategy(stage=3),
45-
default_root_dir=tmpdir,
45+
default_root_dir=tmp_path,
4646
accelerator="gpu",
4747
fast_dev_run=True,
4848
devices=2,

0 commit comments

Comments
 (0)