Refactor pplx init logic to make it modular (prepare for deepep) #18200

Open · youkaichao wants to merge 32 commits into main
Changes shown from 14 of 32 commits.

Commits:
cbe41ef
revert enable_expert_parallel everywhere
youkaichao May 15, 2025
5c6ef5e
tmp
youkaichao May 15, 2025
29aebf6
tmp
youkaichao May 15, 2025
058d8e5
fix typing
youkaichao May 15, 2025
b5dce6f
fix typing
youkaichao May 15, 2025
ed9299a
document options
youkaichao May 15, 2025
88cb9d5
fix inductor
youkaichao May 15, 2025
729239f
fix shutdown error
youkaichao May 15, 2025
1bf90d6
fix shutdown error
youkaichao May 15, 2025
4dc2455
comment
youkaichao May 15, 2025
b680ce9
fix max_num_tokens
youkaichao May 15, 2025
f5c6b57
add comments
youkaichao May 15, 2025
ad70c44
fix max_num_tokens
youkaichao May 15, 2025
b20e977
allow per-layer all2all
youkaichao May 15, 2025
1e60d54
merge into init_prepare_finalize
youkaichao May 16, 2025
15d673b
disable inductor
youkaichao May 16, 2025
63f029b
fix
youkaichao May 16, 2025
d75ac1c
rename to manager
youkaichao May 16, 2025
60499df
fix for non-pplx
youkaichao May 16, 2025
cd6858e
fix for non-pplx
youkaichao May 16, 2025
3f6a862
move prepare_communication_buffer_for_model to base
youkaichao May 16, 2025
d419736
fix no ep case
youkaichao May 16, 2025
5e9d2c9
annotate moe and quant_config
youkaichao May 16, 2025
522ea26
fix typing
youkaichao May 16, 2025
f3fc838
fix init
youkaichao May 16, 2025
c3cd65c
fix cross-node init
youkaichao May 16, 2025
259a724
fix typing
youkaichao May 16, 2025
52945ba
fix typing
youkaichao May 16, 2025
9c73776
fix typing
youkaichao May 16, 2025
5b4095b
meaningful comments
youkaichao May 17, 2025
0944f27
Merge branch 'main' into refactor_pplx
youkaichao May 17, 2025
cfa027b
fix error
youkaichao May 19, 2025
121 changes: 119 additions & 2 deletions vllm/distributed/device_communicators/all2all.py
@@ -1,12 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
import threading
from typing import TYPE_CHECKING
from weakref import WeakValueDictionary

import torch
import torch.distributed as dist

from vllm.forward_context import get_forward_context
from vllm.logger import init_logger

logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None


class Cache:

def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety

def get_or_create(self, kwargs, func):
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))

with self._lock:
instance = self._cache.get(key)
if instance is None:
instance = func(**kwargs)
self._cache[key] = instance
return instance


class All2AllBase:

def __init__(self, cpu_group, model):
def __init__(self, cpu_group, model: torch.nn.Module):
self.cpu_group = cpu_group

# compute some common properties
@@ -21,6 +53,8 @@ def __init__(self, cpu_group, model):
self.ep_group = get_ep_group()
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = self.ep_group.rank_in_group
self.world_size = self.ep_group.world_size

# all2all communication often has separate implementations for
# intra-node and inter-node communication
@@ -46,7 +80,7 @@ class NaiveAll2All(All2AllBase):
debugging.
"""

def __init__(self, cpu_group, model):
def __init__(self, cpu_group, model: torch.nn.Module):
super().__init__(cpu_group, model)

def naive_multicast(self, x: torch.Tensor,
@@ -91,3 +125,86 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:

def destroy(self):
pass


class PPLXAll2All(All2AllBase):
"""
All2All communication based on PPLX kernels.
"""

def __init__(self, cpu_group, model: torch.nn.Module):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
import pplx_kernels as pplx
super().__init__(cpu_group, model)

if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: "
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=self.ep_group.ranks[0],
group=self.cpu_group)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)

self.handle_cache = Cache()

moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
moe_layer = module
moe = moe_layer.moe_config
max_num_tokens = moe.max_num_tokens

all_to_all_args = dict(
max_num_tokens=max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=self.rank,
world_size=self.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=self.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
group_name=self.cpu_group.group_name,
)

pplx_handle = self.handle_cache.get_or_create(
all_to_all_args, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)

moe_layer.quant_method.fused_experts.prepare_finalize.a2a \
= pplx_handle
Contributor:

I think it would be cleaner to call _construct_prepare_finalize and stash the result in moe_layer.quant_method.fused_experts.prepare_finalize rather than poking the a2a into the already constructed object.

Maybe make the whole process into a method, e.g. move the _construct_prepare_finalize call + set_prepare_finalize into a single init_prepare_finalize method.

Collaborator:

I also am generally opposed to shoving things into objects like this, but I do wonder sometimes if I'm colored by my C++ background and should embrace Python's "flexibility" more.

@bnellnm what would that look like? Maybe we try it in a followup?

Member Author:

Merged them into init_prepare_finalize now.
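
For illustration only, here is a minimal, self-contained sketch of the "merge into init_prepare_finalize" idea discussed above. All names below (PrepareFinalize, FusedExperts, FusedMoEMethod) are stand-ins taken from the review comments, not the PR's actual classes:

```python
# Hypothetical sketch of folding handle wiring into a single init_prepare_finalize
# method, rather than assigning .a2a on an already constructed object.
# Class and method names are illustrative assumptions based on the discussion above.

class PrepareFinalize:
    """Wraps an all2all handle used to dispatch/combine expert inputs."""

    def __init__(self, a2a_handle):
        self.a2a = a2a_handle


class FusedExperts:
    def __init__(self):
        self.prepare_finalize = None

    def set_prepare_finalize(self, prepare_finalize):
        self.prepare_finalize = prepare_finalize


class FusedMoEMethod:
    """Stand-in for moe_layer.quant_method in the diff above."""

    def __init__(self):
        self.fused_experts = FusedExperts()

    def init_prepare_finalize(self, a2a_handle):
        # Construct the prepare/finalize object around the handle up front.
        self.fused_experts.set_prepare_finalize(PrepareFinalize(a2a_handle))


# Usage: the all2all backend would call this once per MoE layer.
method = FusedMoEMethod()
method.init_prepare_finalize(a2a_handle=object())
```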


def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()

if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
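
A note on the Cache helper added at the top of this file: it deduplicates pplx AllToAll handles by their keyword arguments while holding them only through weak references, so MoE layers with identical configurations share one handle and the handle is released once no layer keeps a strong reference to it. A standalone sketch of the same pattern, where a toy Handle class stands in for pplx.AllToAll.internode / .intranode:

```python
import threading
from weakref import WeakValueDictionary


class Cache:
    """Same shape as the Cache class added in all2all.py above."""

    def __init__(self):
        self._cache: WeakValueDictionary = WeakValueDictionary()
        self._lock = threading.RLock()  # Reentrant lock for thread safety

    def get_or_create(self, kwargs, func):
        # Create a hashable key from the kwargs
        key = tuple(sorted((k, v) for k, v in kwargs.items()))
        with self._lock:
            instance = self._cache.get(key)
            if instance is None:
                instance = func(**kwargs)
                self._cache[key] = instance
            return instance


class Handle:
    """Toy stand-in for a pplx AllToAll handle (assumption, not the real class)."""

    def __init__(self, max_num_tokens, num_experts):
        self.max_num_tokens = max_num_tokens
        self.num_experts = num_experts


cache = Cache()
args = dict(max_num_tokens=256, num_experts=8)
h1 = cache.get_or_create(args, Handle)
h2 = cache.get_or_create(args, Handle)
assert h1 is h2  # layers with identical configs share one handle
```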
12 changes: 11 additions & 1 deletion vllm/distributed/device_communicators/cuda_communicator.py
@@ -31,7 +31,14 @@ def __init__(self,
use_pynccl = "ep" not in unique_name

self.use_pynccl = use_pynccl
self.use_all2all = "ep" in unique_name

use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
use_ep = config.parallel_config.enable_expert_parallel

self.use_all2all = "ep" in unique_name and use_ep
self.all2all_impl: Optional[All2AllBase] = None
self.use_custom_allreduce = use_custom_allreduce

@@ -151,6 +158,9 @@ def prepare_communication_buffer_for_model(self,
if all2all_backend == "naive":
from .all2all import NaiveAll2All
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
elif all2all_backend == "pplx":
from .all2all import PPLXAll2All
self.all2all_impl = PPLXAll2All(self.cpu_group, model)

def dispatch(
self, hidden_states: torch.Tensor,
56 changes: 6 additions & 50 deletions vllm/distributed/parallel_state.py
@@ -23,7 +23,6 @@
"""
import contextlib
import gc
import importlib.util
import pickle
import weakref
from collections import namedtuple
@@ -43,7 +42,7 @@
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
run_once, supports_custom_op)
supports_custom_op)


@dataclass
@@ -769,10 +768,14 @@ def dispatch(
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
else:
return hidden_states, router_logits

def combine(self, hidden_states) -> torch.Tensor:
if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states)
else:
return hidden_states


_WORLD: Optional[GroupCoordinator] = None
@@ -937,49 +940,9 @@ def init_distributed_environment(
"world group already initialized with a different world size")


PPLX_DID_INIT: bool = False


@run_once
def pplx_init(rank, world_size):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None

if has_pplx and world_size > 1:
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id, nvshmem_init)
try:
global PPLX_DID_INIT
logger.debug(
"Initialize NVSHMEM for PPLX kernels: rank=%d, "
"world size=%d", rank, world_size)
uid = nvshmem_get_unique_id(
) if rank == 0 else nvshmem_alloc_empty_unique_id()
uid_gpu = uid.cuda()
get_world_group().broadcast(uid_gpu, src=0)
uid = uid_gpu.to(device='cpu')
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, rank, world_size)
PPLX_DID_INIT = True
except Exception as ex:
logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex)


@run_once
def pplx_finalize():
global PPLX_DID_INIT
if PPLX_DID_INIT:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
from vllm.model_executor.layers.fused_moe.layer import (
_all_to_all_cache)
_all_to_all_cache.destroy()
nvshmem_finalize()


def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""
@@ -1082,14 +1045,10 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
_EP.rank_in_group)

if enable_expert_parallel:
pplx_init(rank, world_size)


def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
enable_expert_parallel: bool = False,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
@@ -1100,8 +1059,7 @@ def ensure_model_parallel_initialized(
get_world_group().device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size,
enable_expert_parallel, backend)
pipeline_model_parallel_size, backend)
return

assert (
@@ -1180,8 +1138,6 @@ def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP

pplx_finalize()

if _TP:
_TP.destroy()
_TP = None
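
The dispatch/combine change above makes the group coordinator a pass-through when no device communicator is attached. A toy sketch of that fallback in isolation (the real GroupCoordinator requires an initialized distributed environment, so a stand-in class is used here):

```python
import torch


class ToyCoordinator:
    """Toy stand-in for GroupCoordinator's new dispatch/combine fallback."""

    def __init__(self, device_communicator=None):
        self.device_communicator = device_communicator

    def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(hidden_states, router_logits)
        # No all2all backend attached: pass inputs through unchanged.
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.device_communicator is not None:
            return self.device_communicator.combine(hidden_states)
        return hidden_states


coord = ToyCoordinator()
h = torch.randn(4, 16)
logits = torch.randn(4, 8)
out_h, out_logits = coord.dispatch(h, logits)
assert out_h is h and out_logits is logits
assert coord.combine(h) is h
```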
3 changes: 3 additions & 0 deletions vllm/envs.py
@@ -778,6 +778,9 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")),

# all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
}
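
For context on the VLLM_ALL2ALL_BACKEND options documented above, a hedged usage sketch: the model name and the enable_expert_parallel keyword argument are assumptions based on this PR and may differ from a released vLLM version, and the backend only takes effect when expert parallelism is enabled, per the CudaCommunicator change earlier in this diff.

```python
# Usage sketch (assumptions noted below): select the pplx all2all backend for
# expert-parallel MoE. Requires pplx_kernels to be installed; see
# tools/ep_kernels/README.md referenced in the PPLXAll2All assertion above.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "pplx"  # default is "naive"

from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MoE model
    tensor_parallel_size=2,
    enable_expert_parallel=True,  # gates CudaCommunicator.use_all2all (assumed LLM kwarg)
)
```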