File tree 1 file changed +2
-2
lines changed
1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -658,9 +658,9 @@ def train_bert():
658
658
# ** hyperparameters **
659
659
BS = config ["GLOBAL_BATCH_SIZE" ] = getenv ("BS" , 11 * len (GPUS ) if dtypes .default_float in (dtypes .float16 , dtypes .bfloat16 ) else 8 * len (GPUS ))
660
660
EVAL_BS = config ["EVAL_BS" ] = getenv ("EVAL_BS" , 1 * len (GPUS ))
661
- max_lr = config ["OPT_BASE_LEARNING_RATE" ] = getenv ("OPT_BASE_LEARNING_RATE" , 0.0002 * math .sqrt (BS / 96 ))
661
+ max_lr = config ["OPT_BASE_LEARNING_RATE" ] = getenv ("OPT_BASE_LEARNING_RATE" , 0.00018 * math .sqrt (BS / 96 ))
662
662
663
- train_steps = config ["TRAIN_STEPS" ] = getenv ("TRAIN_STEPS" , 3630000 // BS )
663
+ train_steps = config ["TRAIN_STEPS" ] = getenv ("TRAIN_STEPS" , 3300000 // BS )
664
664
warmup_steps = config ["NUM_WARMUP_STEPS" ] = getenv ("NUM_WARMUP_STEPS" , 1 )
665
665
max_eval_steps = config ["MAX_EVAL_STEPS" ] = getenv ("MAX_EVAL_STEPS" , (10000 + EVAL_BS - 1 ) // EVAL_BS ) # EVAL_BS * MAX_EVAL_STEPS >= 10000
666
666
eval_step_freq = config ["EVAL_STEP_FREQ" ] = getenv ("EVAL_STEP_FREQ" , int ((math .floor (0.05 * (230.23 * BS + 3000000 ) / 25000 ) * 25000 ) / BS )) # Round down
You can’t perform that action at this time.
0 commit comments