From f3306399453fc9a5484eb32ae27a217da61ada0c Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 9 May 2025 17:51:14 -0700
Subject: [PATCH 1/2] [BugFix] Initialize random seed to 0 for V1

Signed-off-by: Woosuk Kwon
---
 vllm/config.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index ef0163eaff8..af66e8e397d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -260,7 +260,8 @@ class ModelConfig:
     - "float" is shorthand for FP32 precision.\n
     - "float32" for FP32 precision."""
     seed: Optional[int] = None
-    """Random seed for reproducibility."""
+    """Random seed for reproducibility. Initialized to None in V0, but
+    initialized to 0 in V1."""
     hf_config_path: Optional[str] = None
     """Name or path of the Hugging Face config to use. If unspecified,
     model name or path will be used."""
@@ -440,6 +441,18 @@ def compute_hash(self) -> str:
         return hashlib.sha256(str(factors).encode()).hexdigest()
 
     def __post_init__(self) -> None:
+        # Set the default seed to 0 in V1.
+        # NOTE(woosuk): In V0, we set the default seed to None because the
+        # driver worker shares the same process as the user process, and thus
+        # setting a seed affects the user process as well.
+        # In V1, we use separate processes for workers (unless
+        # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
+        # doesn't affect the user process. However, without a consistent seed,
+        # different tensor parallel workers would sample different tokens,
+        # leading to inconsistent results.
+        if self.seed is None and envs.VLLM_USE_V1:
+            self.seed = 0
+
         self.model = maybe_model_redirect(self.model)
 
         # The tokenizer is consistent with the model by default.

From 32f85bdf543b697fe51a6e8df661119cf50589f9 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 12 May 2025 21:02:21 -0700
Subject: [PATCH 2/2] Add warning

Signed-off-by: Woosuk Kwon
---
 vllm/config.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index 9ab7f51fde9..770fa9f627e 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -451,8 +451,14 @@ def __post_init__(self) -> None:
         # doesn't affect the user process. However, without a consistent seed,
         # different tensor parallel workers would sample different tokens,
         # leading to inconsistent results.
-        if self.seed is None and envs.VLLM_USE_V1:
+        if envs.VLLM_USE_V1 and self.seed is None:
             self.seed = 0
+            if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
+                logger.warning(
+                    "The global random seed is set to %d. Since "
+                    "VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
+                    "affect the random state of the Python process that "
+                    "launched vLLM.", self.seed)
 
         self.model = maybe_model_redirect(self.model)
 
         # The tokenizer is consistent with the model by default.