21
21
get_tensor_model_parallel_rank , get_tensor_model_parallel_world_size ,
22
22
get_tp_group )
23
23
from vllm .logger import init_logger
24
- from vllm .utils import round_down
24
+ from vllm .utils import make_zmq_path , make_zmq_socket , round_down
25
25
from vllm .v1 .core .sched .output import SchedulerOutput
26
26
from vllm .v1 .request import RequestStatus
27
27
@@ -379,7 +379,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
379
379
# hack to keeps us moving. We will switch when moving to etcd
380
380
# or where we have a single ZMQ socket in the scheduler.
381
381
port = envs .VLLM_NIXL_SIDE_CHANNEL_PORT + rank
382
- path = f "tcp:// { host } : { port } "
382
+ path = make_zmq_path ( "tcp" , host , port )
383
383
logger .debug ("Starting listening on path: %s" , path )
384
384
with zmq_ctx (zmq .ROUTER , path ) as sock :
385
385
ready_event .set ()
@@ -397,7 +397,7 @@ def _nixl_handshake(self, host: str, port: int):
397
397
# NOTE(rob): we need each rank to have a unique port. This is
398
398
# a hack to keep us moving. We will switch when moving to etcd
399
399
# or where we have a single ZMQ socket in the scheduler.
400
- path = f "tcp:// { host } : { port + self .rank } "
400
+ path = make_zmq_path ( "tcp" , host , port + self .rank )
401
401
logger .debug ("Querying metadata on path: %s" , path )
402
402
with zmq_ctx (zmq .REQ , path ) as sock :
403
403
# Send query for the request.
@@ -741,20 +741,16 @@ def _get_block_descs_ids(self, engine_id: str,
741
741
def zmq_ctx (socket_type : Any , addr : str ) -> Iterator [zmq .Socket ]:
742
742
"""Context manager for a ZMQ socket"""
743
743
744
+ if socket_type not in (zmq .ROUTER , zmq .REQ ):
745
+ raise ValueError (f"Unexpected socket type: { socket_type } " )
746
+
744
747
ctx : Optional [zmq .Context ] = None
745
748
try :
746
749
ctx = zmq .Context () # type: ignore[attr-defined]
747
-
748
- if socket_type == zmq .ROUTER :
749
- socket = ctx .socket (zmq .ROUTER )
750
- socket .bind (addr )
751
- elif socket_type == zmq .REQ :
752
- socket = ctx .socket (zmq .REQ )
753
- socket .connect (addr )
754
- else :
755
- raise ValueError (f"Unexpected socket type: { socket_type } " )
756
-
757
- yield socket
750
+ yield make_zmq_socket (ctx = ctx ,
751
+ path = addr ,
752
+ socket_type = socket_type ,
753
+ bind = socket_type == zmq .ROUTER )
758
754
finally :
759
755
if ctx is not None :
760
756
ctx .destroy (linger = 0 )
0 commit comments