Skip to content

Commit d87f39e

Browse files
authored
[Bugfix] Add init_cached_hf_modules to RayWorkerWrapper (#4286)
1 parent d3c8180 commit d87f39e

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

vllm/executor/ray_gpu_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
100100
)(RayWorkerWrapper).remote(
101101
worker_module_name="vllm.worker.worker",
102102
worker_class_name="Worker",
103+
trust_remote_code=self.model_config.trust_remote_code,
103104
)
104105

105106
worker_ip = ray.get(worker.get_node_ip.remote())
@@ -110,6 +111,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
110111
self.driver_worker = RayWorkerWrapper(
111112
worker_module_name="vllm.worker.worker",
112113
worker_class_name="Worker",
114+
trust_remote_code=self.model_config.trust_remote_code,
113115
)
114116
else:
115117
# Else, added to the list of workers.

vllm/worker/worker_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,15 @@ class WorkerWrapperBase:
103103

104104
def __init__(self,
105105
worker_module_name=None,
106-
worker_class_name=None) -> None:
106+
worker_class_name=None,
107+
trust_remote_code: bool = False) -> None:
107108
self.worker_module_name = worker_module_name
108109
self.worker_class_name = worker_class_name
109110
self.worker = None
111+
if trust_remote_code:
112+
# note: lazy import to avoid importing torch before initializing
113+
from vllm.utils import init_cached_hf_modules
114+
init_cached_hf_modules()
110115

111116
@staticmethod
112117
def update_environment_variables(envs: Dict[str, str]) -> None:

0 commit comments

Comments
 (0)