@@ -61,7 +61,7 @@ def on_before_zero_grad(self, optimizer):
61
61
62
62
model = CurrentTestModel ()
63
63
64
- trainer = Trainer (default_root_dir = tmp_path , max_steps = max_steps , max_epochs = 2 )
64
+ trainer = Trainer (devices = 1 , default_root_dir = tmp_path , max_steps = max_steps , max_epochs = 2 )
65
65
assert model .on_before_zero_grad_called == 0
66
66
trainer .fit (model )
67
67
assert max_steps == model .on_before_zero_grad_called
@@ -406,7 +406,7 @@ def prepare_data(self): ...
406
406
@pytest .mark .parametrize (
407
407
"kwargs" ,
408
408
[
409
- {},
409
+ {"devices" : 1 },
410
410
# these precision plugins modify the optimization flow, so testing them explicitly
411
411
pytest .param ({"accelerator" : "gpu" , "devices" : 1 , "precision" : "16-mixed" }, marks = RunIf (min_cuda_gpus = 1 )),
412
412
pytest .param (
@@ -528,6 +528,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
528
528
# initial training to get a checkpoint
529
529
model = BoringModel ()
530
530
trainer = Trainer (
531
+ devices = 1 ,
531
532
default_root_dir = tmp_path ,
532
533
max_epochs = 1 ,
533
534
limit_train_batches = 2 ,
@@ -543,6 +544,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
543
544
callback = HookedCallback (called )
544
545
# already performed 1 step, resume and do 2 more
545
546
trainer = Trainer (
547
+ devices = 1 ,
546
548
default_root_dir = tmp_path ,
547
549
max_epochs = 2 ,
548
550
limit_train_batches = 2 ,
@@ -605,6 +607,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
605
607
# initial training to get a checkpoint
606
608
model = BoringModel ()
607
609
trainer = Trainer (
610
+ devices = 1 ,
608
611
default_root_dir = tmp_path ,
609
612
max_steps = 1 ,
610
613
limit_val_batches = 0 ,
@@ -624,6 +627,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
624
627
train_batches = 2
625
628
steps_after_reload = 1 + train_batches
626
629
trainer = Trainer (
630
+ devices = 1 ,
627
631
default_root_dir = tmp_path ,
628
632
max_steps = steps_after_reload ,
629
633
limit_val_batches = 0 ,
@@ -690,6 +694,7 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
690
694
assert is_overridden (f"on_{ noun } _model_train" , model ) == override_on_x_model_train
691
695
callback = HookedCallback (called )
692
696
trainer = Trainer (
697
+ devices = 1 ,
693
698
default_root_dir = tmp_path ,
694
699
max_epochs = 1 ,
695
700
limit_val_batches = batches ,
@@ -731,7 +736,11 @@ def test_trainer_model_hook_system_predict(tmp_path):
731
736
callback = HookedCallback (called )
732
737
batches = 2
733
738
trainer = Trainer (
734
- default_root_dir = tmp_path , limit_predict_batches = batches , enable_progress_bar = False , callbacks = [callback ]
739
+ devices = 1 ,
740
+ default_root_dir = tmp_path ,
741
+ limit_predict_batches = batches ,
742
+ enable_progress_bar = False ,
743
+ callbacks = [callback ],
735
744
)
736
745
trainer .predict (model )
737
746
expected = [
@@ -797,7 +806,7 @@ def predict_dataloader(self):
797
806
798
807
model = CustomBoringModel ()
799
808
800
- trainer = Trainer (default_root_dir = tmp_path , fast_dev_run = 5 )
809
+ trainer = Trainer (devices = 1 , default_root_dir = tmp_path , fast_dev_run = 5 )
801
810
802
811
trainer .fit (model )
803
812
trainer .test (model )
@@ -812,6 +821,7 @@ def test_trainer_datamodule_hook_system(tmp_path):
812
821
model = BoringModel ()
813
822
batches = 2
814
823
trainer = Trainer (
824
+ devices = 1 ,
815
825
default_root_dir = tmp_path ,
816
826
max_epochs = 1 ,
817
827
limit_train_batches = batches ,
@@ -887,7 +897,7 @@ class CustomHookedModel(HookedModel):
887
897
assert is_overridden ("configure_model" , model ) == override_configure_model
888
898
889
899
datamodule = CustomHookedDataModule (ldm_called )
890
- trainer = Trainer ()
900
+ trainer = Trainer (devices = 1 )
891
901
trainer .strategy .connect (model )
892
902
trainer ._data_connector .attach_data (model , datamodule = datamodule )
893
903
ckpt_path = str (tmp_path / "file.ckpt" )
@@ -960,6 +970,7 @@ def predict_step(self, *args, **kwargs):
960
970
961
971
model = MixedTrainModeModule ()
962
972
trainer = Trainer (
973
+ devices = 1 ,
963
974
default_root_dir = tmp_path ,
964
975
max_epochs = 1 ,
965
976
val_check_interval = 1 ,
0 commit comments