From 5c9c8de86f4a20a32ec5c1808400e1ebb017ce82 Mon Sep 17 00:00:00 2001 From: Shukant Pal Date: Sun, 21 Jan 2024 11:53:16 -0800 Subject: [PATCH] Use default value as initial_scale_power if FP16 scaling params not provided The dynamic_loss_scale_args is None if some scaling param is not specified in the config: https://github.com/microsoft/DeepSpeed/blob/9d2660d2a3fac767972f01ac96858b2605ffc0e4/deepspeed/runtime/config.py#L215 In that case, DeepSpeed uses 2**32 as the initial scale instead of the 2**16 documented as the default for initial_scale_power. --- deepspeed/runtime/fp16/loss_scaler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/fp16/loss_scaler.py b/deepspeed/runtime/fp16/loss_scaler.py index 451451c51a32..dc82c0adf24f 100755 --- a/deepspeed/runtime/fp16/loss_scaler.py +++ b/deepspeed/runtime/fp16/loss_scaler.py @@ -23,6 +23,7 @@ import torch from deepspeed import comm as dist +from deepspeed.runtime.constants import FP16_INITIAL_SCALE_POWER_DEFAULT from deepspeed.utils import logger INITIAL_LOSS_SCALE = 'init_scale' @@ -109,14 +110,14 @@ class DynamicLossScaler(LossScalerBase): always using the highest loss scale possible without incurring overflow. Args: - init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` + init_scale (float, optional, default=2**16): Initial loss scale attempted by :class:`DynamicLossScaler.` scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. 
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. consecutive_hysteresis (bool, optional, default=False): Whether to refill hysteresis if we reach an iteration that doesn't overflow """ def __init__(self, - init_scale=2**32, + init_scale=2**FP16_INITIAL_SCALE_POWER_DEFAULT, scale_factor=2., scale_window=1000, min_scale=1,