14
14
15
15
import json
16
16
import os
17
- import time
18
17
from dataclasses import dataclass , field , is_dataclass
19
18
from typing import List , Optional , Union
20
19
21
20
import lightning .pytorch as pl
21
+ import numpy as np
22
22
import torch
23
23
from omegaconf import OmegaConf , open_dict
24
24
39
39
)
40
40
from nemo .core .config import hydra_runner
41
41
from nemo .utils import logging
42
+ from nemo .utils .timers import SimpleTimer
42
43
43
44
"""
44
45
Transcribe audio file on a single CPU/GPU. Useful for transcription of moderate amounts of audio data.
@@ -203,6 +204,8 @@ class TranscriptionConfig:
203
204
extract_nbest : bool = False # Extract n-best hypotheses from the model
204
205
205
206
calculate_rtfx : bool = False
207
+ warmup_steps : int = 0 # by default - no warmup
208
+ run_steps : int = 1 # by default - single run
206
209
207
210
208
211
@hydra_runner (config_name = "TranscriptionConfig" , schema = TranscriptionConfig )
@@ -378,11 +381,16 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
378
381
)
379
382
total_duration += item ["duration" ]
380
383
384
+ if cfg .warmup_steps == 0 :
385
+ logging .warning (
386
+ "RTFx measurement enabled, but warmup_steps=0. "
387
+ "At least one warmup step is recommended to measure RTFx"
388
+ )
389
+
390
+ timer = SimpleTimer ()
391
+ model_measurements = []
381
392
with torch .amp .autocast ('cuda' if torch .cuda .is_available () else 'cpu' , dtype = amp_dtype , enabled = cfg .amp ):
382
393
with torch .no_grad ():
383
- if cfg .calculate_rtfx :
384
- start_time = time .time ()
385
-
386
394
override_cfg = asr_model .get_transcribe_config ()
387
395
override_cfg .batch_size = cfg .batch_size
388
396
override_cfg .num_workers = cfg .num_workers
@@ -395,12 +403,29 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
395
403
if hasattr (override_cfg , "prompt" ):
396
404
override_cfg .prompt = parse_multitask_prompt (OmegaConf .to_container (cfg .prompt ))
397
405
398
- transcriptions = asr_model .transcribe (
399
- audio = filepaths ,
400
- override_config = override_cfg ,
401
- )
402
- if cfg .calculate_rtfx :
403
- transcribe_time = time .time () - start_time
406
+ device = next (asr_model .parameters ()).device
407
+ for run_step in range (cfg .warmup_steps + cfg .run_steps ):
408
+ if run_step < cfg .warmup_steps :
409
+ logging .info (f"Running warmup step { run_step } " )
410
+ # reset timer
411
+ timer .reset ()
412
+ timer .start (device = device )
413
+ # call transcribe
414
+ transcriptions = asr_model .transcribe (
415
+ audio = filepaths ,
416
+ override_config = override_cfg ,
417
+ )
418
+ # stop timer, log time
419
+ timer .stop (device = device )
420
+ logging .info (f"Model time for iteration { run_step } : { timer .total_sec ():.3f} " )
421
+ if run_step >= cfg .warmup_steps :
422
+ model_measurements .append (timer .total_sec ())
423
+
424
+ model_measurements_np = np .asarray (model_measurements )
425
+ logging .info (
426
+ f"Model time avg: { model_measurements_np .mean ():.3f} "
427
+ + (f" (std: { model_measurements_np .std ():.3f} )" if cfg .run_steps > 1 else "" )
428
+ )
404
429
405
430
if cfg .dataset_manifest is not None :
406
431
logging .info (f"Finished transcribing from manifest file: { cfg .dataset_manifest } " )
@@ -453,7 +478,11 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
453
478
logging .info (f"{ total_res } " )
454
479
455
480
if cfg .calculate_rtfx :
456
- logging .info (f"Dataset RTFx { (total_duration / transcribe_time )} " )
481
+ rtfx_measurements = total_duration / model_measurements_np
482
+ logging .info (
483
+ f"Model RTFx on the dataset: { rtfx_measurements .mean ():.3f} "
484
+ + (f" (std: { rtfx_measurements .std ():.3f} )" if cfg .run_steps > 1 else "" )
485
+ )
457
486
458
487
return cfg
459
488
0 commit comments