@@ -440,24 +440,33 @@ def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
     high: float = 1e-3,
+    seed: int = 1234,
 ) -> None:
     """Initialize model weights with random values.
 
     The model weights must be randomly initialized for accurate performance
     measurements. Additionally, the model weights should not cause NaNs in the
     forward pass. We empirically found that initializing the weights with
     values between -1e-3 and 1e-3 works well for most models.
+
+    We use a per-parameter random seed, so that the dummy weights are
+    consistent even if the model is partitioned across multiple devices.
+    When the seed is fixed, the random values generated by this function
+    depend only on the parameter's number of elements and its data type.
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            generator = torch.Generator(device=param.data.device)
+            generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
                 # uniform_ doesn't support < 16-bit datatypes (FP8)
                 dtype = param.data.dtype
                 tmp_param = param.data.to(torch.float16)
-                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                tmp_param = tmp_param.uniform_(low, high,
+                                               generator=generator).to(dtype)
                 param.data.copy_(tmp_param)
             else:
                 param.uniform_(low, high, generator=generator)
 
 
 def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
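
Below is a minimal standalone sketch (not part of the diff) illustrating the consistency property the new docstring describes: because each parameter re-seeds its own generator, the values it receives depend only on its number of elements and its dtype, not on the order in which parameters are filled. The helper name `fill_dummy` is hypothetical.

```python
import torch

def fill_dummy(param: torch.Tensor, seed: int = 1234,
               low: float = -1e-3, high: float = 1e-3) -> torch.Tensor:
    # Re-seed a fresh generator for every parameter, mirroring the change above.
    generator = torch.Generator(device=param.device)
    generator.manual_seed(seed)
    return param.uniform_(low, high, generator=generator)

# Two tensors with the same numel and dtype receive identical values, even
# though an unrelated tensor is filled in between and their shapes differ.
a = fill_dummy(torch.empty(4, 8))
_ = fill_dummy(torch.empty(16))    # unrelated parameter filled in between
b = fill_dummy(torch.empty(8, 4))  # same numel (32), different shape
assert torch.equal(a.flatten(), b.flatten())
```

With a single shared generator, the values a parameter receives would depend on how many random numbers were drawn before it, which differs across tensor-parallel ranks; re-seeding per parameter removes that ordering dependence, which is why the sharded dummy weights stay consistent.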