Skip to content

Commit 9e4f10b

Browse files
committed
update
1 parent 6a1a485 commit 9e4f10b

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

tests/tests_pytorch/callbacks/test_throughput_monitor.py

+53
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,59 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
167167
model = BoringModel()
168168
model.flops_per_batch = 10
169169

170+
# accumulate_grad_batches=2, log_every_n_steps=3
171+
trainer = Trainer(
172+
devices=1,
173+
logger=logger_mock,
174+
callbacks=monitor,
175+
limit_train_batches=5,
176+
limit_val_batches=0,
177+
max_epochs=2,
178+
log_every_n_steps=3,
179+
accumulate_grad_batches=2,
180+
enable_checkpointing=False,
181+
enable_model_summary=False,
182+
enable_progress_bar=False,
183+
)
184+
timings = [0.0] + [0.5 + i for i in range(1, 11)]
185+
with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch(
186+
"time.perf_counter", side_effect=timings
187+
):
188+
trainer.fit(model)
189+
190+
expected = {
191+
"train|device|batches_per_sec": 1.0,
192+
"train|device|samples_per_sec": 3.0,
193+
"train|device|items_per_sec": 6.0,
194+
"train|device|flops_per_sec": 10.0,
195+
"train|device|mfu": 0.1,
196+
}
197+
assert logger_mock.log_metrics.mock_calls == [
198+
call(
199+
metrics={
200+
**expected,
201+
"train|time": 5.5,
202+
"train|batches": 5,
203+
"train|samples": 15,
204+
"train|lengths": 30,
205+
"epoch": 0,
206+
},
207+
step=2,
208+
),
209+
call(
210+
metrics={
211+
**expected,
212+
"train|time": 10.5,
213+
"train|batches": 10,
214+
"train|samples": 30,
215+
"train|lengths": 60,
216+
"epoch": 1,
217+
},
218+
step=5,
219+
),
220+
]
221+
222+
# accumulate_grad_batches=2, log_every_n_steps=1
170223
trainer = Trainer(
171224
devices=1,
172225
logger=logger_mock,

0 commit comments

Comments
 (0)