Skip to content

Commit c20c173

Browse files
authored
Merge branch 'master' into chualan/fix-19658
2 parents 5be022e + 60289d7 commit c20c173

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

tests/tests_pytorch/models/test_hooks.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def on_before_zero_grad(self, optimizer):
6161

6262
model = CurrentTestModel()
6363

64-
trainer = Trainer(default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2)
64+
trainer = Trainer(devices=1, default_root_dir=tmp_path, max_steps=max_steps, max_epochs=2)
6565
assert model.on_before_zero_grad_called == 0
6666
trainer.fit(model)
6767
assert max_steps == model.on_before_zero_grad_called
@@ -406,7 +406,7 @@ def prepare_data(self): ...
406406
@pytest.mark.parametrize(
407407
"kwargs",
408408
[
409-
{},
409+
{"devices": 1},
410410
# these precision plugins modify the optimization flow, so testing them explicitly
411411
pytest.param({"accelerator": "gpu", "devices": 1, "precision": "16-mixed"}, marks=RunIf(min_cuda_gpus=1)),
412412
pytest.param(
@@ -528,6 +528,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
528528
# initial training to get a checkpoint
529529
model = BoringModel()
530530
trainer = Trainer(
531+
devices=1,
531532
default_root_dir=tmp_path,
532533
max_epochs=1,
533534
limit_train_batches=2,
@@ -543,6 +544,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
543544
callback = HookedCallback(called)
544545
# already performed 1 step, resume and do 2 more
545546
trainer = Trainer(
547+
devices=1,
546548
default_root_dir=tmp_path,
547549
max_epochs=2,
548550
limit_train_batches=2,
@@ -605,6 +607,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
605607
# initial training to get a checkpoint
606608
model = BoringModel()
607609
trainer = Trainer(
610+
devices=1,
608611
default_root_dir=tmp_path,
609612
max_steps=1,
610613
limit_val_batches=0,
@@ -624,6 +627,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
624627
train_batches = 2
625628
steps_after_reload = 1 + train_batches
626629
trainer = Trainer(
630+
devices=1,
627631
default_root_dir=tmp_path,
628632
max_steps=steps_after_reload,
629633
limit_val_batches=0,
@@ -690,6 +694,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
690694
assert is_overridden(f"on_{noun}_model_train", model) == override_on_x_model_train
691695
callback = HookedCallback(called)
692696
trainer = Trainer(
697+
devices=1,
693698
default_root_dir=tmp_path,
694699
max_epochs=1,
695700
limit_val_batches=batches,
@@ -731,7 +736,11 @@ def test_trainer_model_hook_system_predict(tmp_path):
731736
callback = HookedCallback(called)
732737
batches = 2
733738
trainer = Trainer(
734-
default_root_dir=tmp_path, limit_predict_batches=batches, enable_progress_bar=False, callbacks=[callback]
739+
devices=1,
740+
default_root_dir=tmp_path,
741+
limit_predict_batches=batches,
742+
enable_progress_bar=False,
743+
callbacks=[callback],
735744
)
736745
trainer.predict(model)
737746
expected = [
@@ -797,7 +806,7 @@ def predict_dataloader(self):
797806

798807
model = CustomBoringModel()
799808

800-
trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=5)
809+
trainer = Trainer(devices=1, default_root_dir=tmp_path, fast_dev_run=5)
801810

802811
trainer.fit(model)
803812
trainer.test(model)
@@ -812,6 +821,7 @@ def test_trainer_datamodule_hook_system(tmp_path):
812821
model = BoringModel()
813822
batches = 2
814823
trainer = Trainer(
824+
devices=1,
815825
default_root_dir=tmp_path,
816826
max_epochs=1,
817827
limit_train_batches=batches,
@@ -887,7 +897,7 @@ class CustomHookedModel(HookedModel):
887897
assert is_overridden("configure_model", model) == override_configure_model
888898

889899
datamodule = CustomHookedDataModule(ldm_called)
890-
trainer = Trainer()
900+
trainer = Trainer(devices=1)
891901
trainer.strategy.connect(model)
892902
trainer._data_connector.attach_data(model, datamodule=datamodule)
893903
ckpt_path = str(tmp_path / "file.ckpt")
@@ -960,6 +970,7 @@ def predict_step(self, *args, **kwargs):
960970

961971
model = MixedTrainModeModule()
962972
trainer = Trainer(
973+
devices=1,
963974
default_root_dir=tmp_path,
964975
max_epochs=1,
965976
val_check_interval=1,

0 commit comments

Comments
 (0)