Skip to content

Commit a8f5aec

Browse files
authored
[V1] Update zmq socket creation in nixl connector (#18148)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
1 parent de71fec commit a8f5aec

File tree

3 files changed

+34
-15
lines changed

3 files changed

+34
-15
lines changed

tests/test_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
1818
MemorySnapshot, PlaceholderModule, StoreBoolean,
1919
bind_kv_cache, deprecate_kwargs, get_open_port,
20-
make_zmq_socket, memory_profiling,
20+
make_zmq_path, make_zmq_socket, memory_profiling,
2121
merge_async_iterators, sha256, split_zmq_path,
2222
supports_kw, swap_dict_values)
2323

@@ -714,3 +714,8 @@ def test_make_zmq_socket_ipv6():
714714
# Clean up
715715
zsock.close()
716716
ctx.term()
717+
718+
719+
def test_make_zmq_path():
    """Check that make_zmq_path formats IPv4 and IPv6 TCP endpoints."""
    # A plain IPv4 host is joined to the port as-is.
    ipv4_path = make_zmq_path("tcp", "127.0.0.1", "5555")
    assert ipv4_path == "tcp://127.0.0.1:5555"

    # An IPv6 host is expected back wrapped in square brackets.
    ipv6_path = make_zmq_path("tcp", "::1", "5555")
    assert ipv6_path == "tcp://[::1]:5555"

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
2222
get_tp_group)
2323
from vllm.logger import init_logger
24-
from vllm.utils import round_down
24+
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
2525
from vllm.v1.core.sched.output import SchedulerOutput
2626
from vllm.v1.request import RequestStatus
2727

@@ -379,7 +379,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
379379
# hack to keep us moving. We will switch when moving to etcd
380380
# or where we have a single ZMQ socket in the scheduler.
381381
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
382-
path = f"tcp://{host}:{port}"
382+
path = make_zmq_path("tcp", host, port)
383383
logger.debug("Starting listening on path: %s", path)
384384
with zmq_ctx(zmq.ROUTER, path) as sock:
385385
ready_event.set()
@@ -397,7 +397,7 @@ def _nixl_handshake(self, host: str, port: int):
397397
# NOTE(rob): we need each rank to have a unique port. This is
398398
# a hack to keep us moving. We will switch when moving to etcd
399399
# or where we have a single ZMQ socket in the scheduler.
400-
path = f"tcp://{host}:{port + self.rank}"
400+
path = make_zmq_path("tcp", host, port + self.rank)
401401
logger.debug("Querying metadata on path: %s", path)
402402
with zmq_ctx(zmq.REQ, path) as sock:
403403
# Send query for the request.
@@ -741,20 +741,16 @@ def _get_block_descs_ids(self, engine_id: str,
741741
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
742742
"""Context manager for a ZMQ socket"""
743743

744+
if socket_type not in (zmq.ROUTER, zmq.REQ):
745+
raise ValueError(f"Unexpected socket type: {socket_type}")
746+
744747
ctx: Optional[zmq.Context] = None
745748
try:
746749
ctx = zmq.Context() # type: ignore[attr-defined]
747-
748-
if socket_type == zmq.ROUTER:
749-
socket = ctx.socket(zmq.ROUTER)
750-
socket.bind(addr)
751-
elif socket_type == zmq.REQ:
752-
socket = ctx.socket(zmq.REQ)
753-
socket.connect(addr)
754-
else:
755-
raise ValueError(f"Unexpected socket type: {socket_type}")
756-
757-
yield socket
750+
yield make_zmq_socket(ctx=ctx,
751+
path=addr,
752+
socket_type=socket_type,
753+
bind=socket_type == zmq.ROUTER)
758754
finally:
759755
if ctx is not None:
760756
ctx.destroy(linger=0)

vllm/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2350,6 +2350,24 @@ def split_zmq_path(path: str) -> Tuple[str, str, str]:
23502350
return scheme, host, port
23512351

23522352

2353+
def make_zmq_path(scheme: str,
                  host: str,
                  port: Optional[Union[int, str]] = None) -> str:
    """Make a ZMQ path from its parts.

    Args:
        scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
        host: The host - can be an IPv4 address, IPv6 address, or hostname.
            An IPv6 address is wrapped in square brackets when a port is
            appended, so the port separator stays unambiguous.
        port: Optional port, given as an int or as a numeric string; only
            used for TCP sockets. Any falsy value (None, 0 or "") means
            "no port" and yields a path without a port suffix.

    Returns:
        A properly formatted ZMQ path string.
    """
    # Port-less transports (e.g. ipc, inproc) only carry scheme and host.
    if not port:
        return f"{scheme}://{host}"

    # A bare IPv6 literal contains colons itself, so bracket it before
    # appending ":<port>".
    if is_valid_ipv6_address(host):
        host = f"[{host}]"
    return f"{scheme}://{host}:{port}"
2369+
2370+
23532371
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
23542372
def make_zmq_socket(
23552373
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]

0 commit comments

Comments
 (0)