@@ -771,48 +771,50 @@ def repeat_fake(bs):
   if MLLOGGER:
     MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})

+  # TODO: put copy into jit
   while train_data is not None and i < train_steps and not achieved:
-    Tensor.training = True
-    BEAM.value = TRAIN_BEAM
-    st = time.perf_counter()
-    GlobalCounters.reset()
-    loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
-      train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
-      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
+    if getenv("TRAIN", 1):
+      Tensor.training = True
+      BEAM.value = TRAIN_BEAM
+      st = time.perf_counter()
+      GlobalCounters.reset()
+      loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
+        train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
+        train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)

-    pt = time.perf_counter()
+      pt = time.perf_counter()

-    try:
-      next_data = next(train_it)
-    except StopIteration:
-      next_data = None
-
-    dt = time.perf_counter()
-
-    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
-    loss = loss.item()
-
-    cl = time.perf_counter()
-    if BENCHMARK: step_times.append(cl - st)
-
-    tqdm.write(
-      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
-      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
-      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
-    if WANDB:
-      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
-                 "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
-                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i + 1)*BS})
-
-    train_data, next_data = next_data, None
-    i += 1
-
-    if i == BENCHMARK:
-      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
-      estimated_total_minutes = int(median_step_time * train_steps / 60)
-      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
-      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
-            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
+      try:
+        next_data = next(train_it)
+      except StopIteration:
+        next_data = None
+
+      dt = time.perf_counter()
+
+      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+      loss = loss.item()
+
+      cl = time.perf_counter()
+      if BENCHMARK: step_times.append(cl - st)
+
+      tqdm.write(
+        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
+        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
+        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
+      if WANDB:
+        wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
+                   "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
+                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i + 1)*BS})
+
+      train_data, next_data = next_data, None
+      i += 1
+
+      if i == BENCHMARK:
+        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
+        estimated_total_minutes = int(median_step_time * train_steps / 60)
+        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
+        print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
+              f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")

     # ** eval loop **
     if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
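
Note on the new gate: getenv("TRAIN", 1) reads the TRAIN environment variable with a default of 1, so existing runs behave as before, while TRAIN=0 skips everything moved under the gate and leaves the ungated parts of the loop (such as the eval check above) running. Below is a minimal, self-contained sketch of this pattern, not the repo's code: the local getenv mirrors the contract of tinygrad.helpers.getenv for int defaults, and fake_step is a hypothetical stand-in for train_step_bert.

import os, time

def getenv(key, default=0):
  # mirrors tinygrad.helpers.getenv: cast the env var to the default's type
  return type(default)(os.environ.get(key, default))

def fake_step():
  time.sleep(0.01)  # hypothetical stand-in for train_step_bert

for i in range(3):
  st = time.perf_counter()
  if getenv("TRAIN", 1):  # TRAIN=0 skips the gated work; the default still trains
    fake_step()
  print(f"step {i}: {(time.perf_counter() - st) * 1000.0:7.2f} ms run")

Run as-is (or with TRAIN=1), each step takes about 10 ms; with TRAIN=0 the gated work is skipped and each step reports near-zero time, which is the same mechanism the diff uses to run the loop without the training step.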