Skip to content

Commit 6a1a485

Browse files
committed
refactor
1 parent 3c7a2b7 commit 6a1a485

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/lightning/pytorch/callbacks/throughput_monitor.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,12 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
9393
dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
9494
self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)
9595

96-
if stage == TrainerFn.FITTING:
97-
if trainer.enable_validation:
98-
# `fit` includes validation inside
99-
throughput = Throughput(
100-
available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs
101-
)
102-
self._throughputs[RunningStage.VALIDATING] = throughput
96+
if stage == TrainerFn.FITTING and trainer.enable_validation:
97+
# `fit` includes validation inside
98+
throughput = Throughput(
99+
available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs
100+
)
101+
self._throughputs[RunningStage.VALIDATING] = throughput
103102

104103
throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
105104
stage = trainer.state.stage

0 commit comments

Comments
 (0)