Skip to content

Commit 42af40e

Browse files
committed
parameterize
1 parent 9e4f10b commit 42af40e

File tree

1 file changed

+14
-54
lines changed

1 file changed

+14
-54
lines changed

tests/tests_pytorch/callbacks/test_throughput_monitor.py

+14-54
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,22 @@ def test_throughput_monitor_fit_no_length_fn(tmp_path):
160160
]
161161

162162

163-
def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
163+
@pytest.mark.parametrize("log_every_n_steps", [1, 3])
164+
def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path):
164165
logger_mock = Mock()
165166
logger_mock.save_dir = tmp_path
166167
monitor = ThroughputMonitor(length_fn=lambda x: 3 * 2, batch_size_fn=lambda x: 3, window_size=4, separator="|")
167168
model = BoringModel()
168169
model.flops_per_batch = 10
169170

170-
# accumulate_grad_batches=2, log_every_n_steps=3
171171
trainer = Trainer(
172172
devices=1,
173173
logger=logger_mock,
174174
callbacks=monitor,
175175
limit_train_batches=5,
176176
limit_val_batches=0,
177177
max_epochs=2,
178-
log_every_n_steps=3,
178+
log_every_n_steps=log_every_n_steps,
179179
accumulate_grad_batches=2,
180180
enable_checkpointing=False,
181181
enable_model_summary=False,
@@ -194,61 +194,19 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
194194
"train|device|flops_per_sec": 10.0,
195195
"train|device|mfu": 0.1,
196196
}
197-
assert logger_mock.log_metrics.mock_calls == [
197+
198+
all_log_calls = [
198199
call(
199200
metrics={
200-
**expected,
201-
"train|time": 5.5,
202-
"train|batches": 5,
203-
"train|samples": 15,
204-
"train|lengths": 30,
201+
# The very first batch doesn't have the *_per_sec metrics yet
202+
**(expected if log_every_n_steps > 1 else {}),
203+
"train|time": 2.5,
204+
"train|batches": 2,
205+
"train|samples": 6,
206+
"train|lengths": 12,
205207
"epoch": 0,
206208
},
207-
step=2,
208-
),
209-
call(
210-
metrics={
211-
**expected,
212-
"train|time": 10.5,
213-
"train|batches": 10,
214-
"train|samples": 30,
215-
"train|lengths": 60,
216-
"epoch": 1,
217-
},
218-
step=5,
219-
),
220-
]
221-
222-
# accumulate_grad_batches=2, log_every_n_steps=1
223-
trainer = Trainer(
224-
devices=1,
225-
logger=logger_mock,
226-
callbacks=monitor,
227-
limit_train_batches=5,
228-
limit_val_batches=0,
229-
max_epochs=2,
230-
log_every_n_steps=1,
231-
accumulate_grad_batches=2,
232-
enable_checkpointing=False,
233-
enable_model_summary=False,
234-
enable_progress_bar=False,
235-
)
236-
timings = [0.0] + [0.5 + i for i in range(1, 11)]
237-
with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch(
238-
"time.perf_counter", side_effect=timings
239-
):
240-
trainer.fit(model)
241-
242-
expected = {
243-
"train|device|batches_per_sec": 1.0,
244-
"train|device|samples_per_sec": 3.0,
245-
"train|device|items_per_sec": 6.0,
246-
"train|device|flops_per_sec": 10.0,
247-
"train|device|mfu": 0.1,
248-
}
249-
assert logger_mock.log_metrics.mock_calls == [
250-
call(
251-
metrics={"train|time": 2.5, "train|batches": 2, "train|samples": 6, "train|lengths": 12, "epoch": 0}, step=0
209+
step=0,
252210
),
253211
call(
254212
metrics={
@@ -306,6 +264,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
306264
step=5,
307265
),
308266
]
267+
expected_log_calls = all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps]
268+
assert logger_mock.log_metrics.mock_calls == expected_log_calls
309269

310270

311271
@pytest.mark.parametrize("fn", ["validate", "test", "predict"])

0 commit comments

Comments
 (0)