Commit ce37be7

[misc][distributed] add seed to dummy weights (#6491)
1 parent 7f62077

File tree

1 file changed: 11 additions, 2 deletions


vllm/model_executor/model_loader/weight_utils.py

Lines changed: 11 additions & 2 deletions
@@ -440,24 +440,33 @@ def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
     high: float = 1e-3,
+    seed: int = 1234,
 ) -> None:
     """Initialize model weights with random values.
 
     The model weights must be randomly initialized for accurate performance
     measurements. Additionally, the model weights should not cause NaNs in the
     forward pass. We empirically found that initializing the weights with
     values between -1e-3 and 1e-3 works well for most models.
+
+    We use per-parameter random seed, so that dummy weights are consistent,
+    even if the model is partitioned across multiple devices. When the seed
+    is fixed, the random values generated by this function only depends on
+    the parameter's number of elements and its data type.
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            generator = torch.Generator(device=param.data.device)
+            generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
                 # uniform_ doesn't support < 16-bit datatypes (FP8)
                 dtype = param.data.dtype
                 tmp_param = param.data.to(torch.float16)
-                tmp_param = tmp_param.uniform_(low, high).to(dtype)
+                tmp_param = tmp_param.uniform_(low, high,
+                                               generator=generator).to(dtype)
                 param.data.copy_(tmp_param)
             else:
-                param.uniform_(low, high, generator=generator)
+                param.uniform_(low, high, generator=generator)
 
 
 def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
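
For context, a minimal sketch (not part of the commit) of how the new seed parameter can be exercised. It assumes initialize_dummy_weights is imported from the changed file and uses a small torch.nn.Linear as a stand-in for a real model:

import torch
import torch.nn as nn

from vllm.model_executor.model_loader.weight_utils import (
    initialize_dummy_weights)

# Two independent model instances with identically shaped parameters.
model_a = nn.Linear(8, 8)
model_b = nn.Linear(8, 8)

# With the same seed, the generated dummy values depend only on each
# parameter's number of elements and dtype, so both models end up with
# identical weights.
initialize_dummy_weights(model_a, seed=1234)
initialize_dummy_weights(model_b, seed=1234)

assert torch.equal(model_a.weight, model_b.weight)
assert torch.equal(model_a.bias, model_b.bias)

Per the new docstring, this same property is what keeps dummy weights consistent when a model is partitioned across devices: any two shards with the same number of elements and dtype receive the same values, regardless of which device initializes them.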
