@@ -160,22 +160,22 @@ def test_throughput_monitor_fit_no_length_fn(tmp_path):
     ]
 
 
-def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
+@pytest.mark.parametrize("log_every_n_steps", [1, 3])
+def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path):
     logger_mock = Mock()
     logger_mock.save_dir = tmp_path
     monitor = ThroughputMonitor(length_fn=lambda x: 3 * 2, batch_size_fn=lambda x: 3, window_size=4, separator="|")
     model = BoringModel()
     model.flops_per_batch = 10
 
-    # accumulate_grad_batches=2, log_every_n_steps=3
     trainer = Trainer(
         devices=1,
         logger=logger_mock,
         callbacks=monitor,
         limit_train_batches=5,
         limit_val_batches=0,
         max_epochs=2,
-        log_every_n_steps=3,
+        log_every_n_steps=log_every_n_steps,
         accumulate_grad_batches=2,
         enable_checkpointing=False,
         enable_model_summary=False,
@@ -194,61 +194,19 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
         "train|device|flops_per_sec": 10.0,
         "train|device|mfu": 0.1,
     }
-    assert logger_mock.log_metrics.mock_calls == [
+
+    all_log_calls = [
         call(
             metrics={
-                **expected,
-                "train|time": 5.5,
-                "train|batches": 5,
-                "train|samples": 15,
-                "train|lengths": 30,
+                # The very first batch doesn't have the *_per_sec metrics yet
+                **(expected if log_every_n_steps > 1 else {}),
+                "train|time": 2.5,
+                "train|batches": 2,
+                "train|samples": 6,
+                "train|lengths": 12,
                 "epoch": 0,
             },
-            step=2,
-        ),
-        call(
-            metrics={
-                **expected,
-                "train|time": 10.5,
-                "train|batches": 10,
-                "train|samples": 30,
-                "train|lengths": 60,
-                "epoch": 1,
-            },
-            step=5,
-        ),
-    ]
-
-    # accumulate_grad_batches=2, log_every_n_steps=1
-    trainer = Trainer(
-        devices=1,
-        logger=logger_mock,
-        callbacks=monitor,
-        limit_train_batches=5,
-        limit_val_batches=0,
-        max_epochs=2,
-        log_every_n_steps=1,
-        accumulate_grad_batches=2,
-        enable_checkpointing=False,
-        enable_model_summary=False,
-        enable_progress_bar=False,
-    )
-    timings = [0.0] + [0.5 + i for i in range(1, 11)]
-    with mock.patch("lightning.pytorch.callbacks.throughput_monitor.get_available_flops", return_value=100), mock.patch(
-        "time.perf_counter", side_effect=timings
-    ):
-        trainer.fit(model)
-
-    expected = {
-        "train|device|batches_per_sec": 1.0,
-        "train|device|samples_per_sec": 3.0,
-        "train|device|items_per_sec": 6.0,
-        "train|device|flops_per_sec": 10.0,
-        "train|device|mfu": 0.1,
-    }
-    assert logger_mock.log_metrics.mock_calls == [
-        call(
-            metrics={"train|time": 2.5, "train|batches": 2, "train|samples": 6, "train|lengths": 12, "epoch": 0}, step=0
+            step=0,
         ),
         call(
             metrics={
@@ -306,6 +264,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
             step=5,
         ),
     ]
+    expected_log_calls = all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps]
+    assert logger_mock.log_metrics.mock_calls == expected_log_calls
 
 
 @pytest.mark.parametrize("fn", ["validate", "test", "predict"])