Refactor pplx init logic to make it modular (prepare for deepep) #18200

Status: Open. Wants to merge 32 commits into base: main. Changes shown from 29 commits.

Commits (32)
All commits are by youkaichao.

cbe41ef  revert enable_expert_parallel everywhere (May 15, 2025)
5c6ef5e  tmp (May 15, 2025)
29aebf6  tmp (May 15, 2025)
058d8e5  fix typing (May 15, 2025)
b5dce6f  fix typing (May 15, 2025)
ed9299a  document options (May 15, 2025)
88cb9d5  fix inductor (May 15, 2025)
729239f  fix shutdown error (May 15, 2025)
1bf90d6  fix shutdown error (May 15, 2025)
4dc2455  comment (May 15, 2025)
b680ce9  fix max_num_tokens (May 15, 2025)
f5c6b57  add comments (May 15, 2025)
ad70c44  fix max_num_tokens (May 15, 2025)
b20e977  allow per-layer all2all (May 15, 2025)
1e60d54  merge into init_prepare_finalize (May 16, 2025)
15d673b  disable inductor (May 16, 2025)
63f029b  fix (May 16, 2025)
d75ac1c  rename to manager (May 16, 2025)
60499df  fix for non-pplx (May 16, 2025)
cd6858e  fix for non-pplx (May 16, 2025)
3f6a862  move prepare_communication_buffer_for_model to base (May 16, 2025)
d419736  fix no ep case (May 16, 2025)
5e9d2c9  annotate moe and quant_config (May 16, 2025)
522ea26  fix typing (May 16, 2025)
f3fc838  fix init (May 16, 2025)
c3cd65c  fix cross-node init (May 16, 2025)
259a724  fix typing (May 16, 2025)
52945ba  fix typing (May 16, 2025)
9c73776  fix typing (May 16, 2025)
5b4095b  meaningful comments (May 17, 2025)
0944f27  Merge branch 'main' into refactor_pplx (May 17, 2025)
cfa027b  fix error (May 19, 2025)
101 changes: 67 additions & 34 deletions vllm/distributed/device_communicators/all2all.py
@@ -1,53 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist

from vllm.forward_context import get_forward_context
from vllm.logger import init_logger

from .base_device_communicator import All2AllManagerBase, Cache

class All2AllBase:

def __init__(self, cpu_group, model):
self.cpu_group = cpu_group

# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_ep_group,
get_tp_group,
in_the_same_node_as)

# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
self.ep_group = get_ep_group()
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size

# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
logger = init_logger(__name__)

def destroy(self):
pass
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None


class NaiveAll2All(All2AllBase):
class NaiveAll2AllManager(All2AllManagerBase):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
efficient at all. The main purpose is for testing and
debugging.
"""

def __init__(self, cpu_group, model):
super().__init__(cpu_group, model)
def __init__(self, cpu_group):
super().__init__(cpu_group)

def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -91,3 +71,56 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:

def destroy(self):
pass


class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""

def __init__(self, cpu_group):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)

if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: "
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)

self.handle_cache = Cache()

def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()

if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
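
Note on the internode branch of PPLXAll2AllManager.__init__ above: rank 0 of the EP cpu_group creates an NVSHMEM unique id and broadcasts it to the other ranks before calling nvshmem_init. Below is a minimal runnable sketch of that bootstrap pattern, with a plain uint8 tensor standing in for the opaque NVSHMEM id (the real code uses nvshmem_get_unique_id / nvshmem_alloc_empty_unique_id from pplx_kernels); the single-process gloo group is only for demonstration, not how vLLM sets up the group.

```python
import torch
import torch.distributed as dist


def bootstrap_unique_id(cpu_group) -> torch.Tensor:
    rank = dist.get_rank(cpu_group)
    # Rank 0 creates the id; other ranks allocate an empty buffer of the same
    # shape, mirroring nvshmem_get_unique_id / nvshmem_alloc_empty_unique_id.
    uid = (torch.randint(0, 256, (128,), dtype=torch.uint8)
           if rank == 0 else torch.zeros(128, dtype=torch.uint8))
    # Broadcast from the first global rank of the group, as in the PR.
    src = dist.get_process_group_ranks(cpu_group)[0]
    dist.broadcast(uid, src=src, group=cpu_group)
    return uid


if __name__ == "__main__":
    # Single-process demo group; in vLLM this would be the EP cpu_group.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)
    uid = bootstrap_unique_id(dist.group.WORLD)
    print(uid[:8])
    dist.destroy_process_group()
```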
86 changes: 84 additions & 2 deletions vllm/distributed/device_communicators/base_device_communicator.py
@@ -1,11 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
import threading
from typing import Optional
from weakref import WeakValueDictionary

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


class Cache:

def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety

def get_or_create(self, kwargs, func):
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))

with self._lock:
instance = self._cache.get(key)
if instance is None:
instance = func(**kwargs)
self._cache[key] = instance
return instance


class All2AllManagerBase:

def __init__(self, cpu_group):
self.cpu_group = cpu_group

# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
in_the_same_node_as)

# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)

# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode

def get_handle(self, kwargs):
# get a handle for the all2all communication,
# based on the kwargs.
# different layers can have different configs,
# e.g. one layer has hidden size 1024, another has 2048.
# usually the underlying implementation caches the handle
# and reuse it for the same config.
raise NotImplementedError

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def destroy(self):
pass


class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
@@ -31,6 +96,15 @@ def __init__(self,
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)

use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
use_ep = config.parallel_config.enable_expert_parallel

self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
@@ -154,9 +228,17 @@ def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
This is a no-op in the base class.
"""
pass
if not self.use_all2all:
return

moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_config)

def dispatch(
self, hidden_states: torch.Tensor,
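
Note on the prepare_communication_buffer_for_model hook in the hunk above: it walks model.modules() and matches MoE layers by class name ("FusedMoE") rather than importing the class, then hands each layer's config to its quant method. The sketch below uses hypothetical stand-in classes (not vLLM's real FusedMoE or quant method), but the discovery logic is the same.

```python
import torch.nn as nn


class FusedMoE(nn.Module):
    """Stand-in with the attributes the hook expects; not vLLM's real class."""

    def __init__(self):
        super().__init__()
        self.moe_config = {"hidden_size": 1024}  # hypothetical config
        self.quant_config = None
        self.quant_method = self  # stand-in quant method

    def init_prepare_finalize(self, moe_config, quant_config):
        print("init_prepare_finalize called with", moe_config, quant_config)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.moe = FusedMoE()


model = ToyModel()
# Same discovery logic as prepare_communication_buffer_for_model above:
moe_modules = [m for m in model.modules()
               if m.__class__.__name__ == "FusedMoE"]
for module in moe_modules:
    module.quant_method.init_prepare_finalize(module.moe_config,
                                              module.quant_config)
```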
45 changes: 23 additions & 22 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -6,10 +6,12 @@
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger

from .all2all import All2AllBase
from .base_device_communicator import DeviceCommunicatorBase

logger = init_logger(__name__)


class CudaCommunicator(DeviceCommunicatorBase):

@@ -31,8 +33,6 @@ def __init__(self,
use_pynccl = "ep" not in unique_name

self.use_pynccl = use_pynccl
self.use_all2all = "ep" in unique_name
self.all2all_impl: Optional[All2AllBase] = None
self.use_custom_allreduce = use_custom_allreduce

# lazy import to avoid documentation build error
@@ -56,6 +56,19 @@
device=self.device,
)

if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")

def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
@@ -136,31 +149,19 @@ def destroy(self):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.all2all_impl is not None:
self.all2all_impl.destroy()
self.all2all_impl = None

def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
return
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2All
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None

def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch(
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_impl is not None
hidden_states = self.all2all_impl.combine(hidden_states)
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states
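
Note on the backend dispatch that moved into CudaCommunicator.__init__ above: it boils down to mapping VLLM_ALL2ALL_BACKEND to a manager class. Below is a standalone sketch with trivial stand-in classes (the real ones are NaiveAll2AllManager and PPLXAll2AllManager; the "naive" fallback here is an assumption for the demo, not necessarily vLLM's default).

```python
import os


class _NaiveManager:  # stand-in for NaiveAll2AllManager
    def __init__(self, cpu_group):
        self.cpu_group = cpu_group


class _PPLXManager:  # stand-in for PPLXAll2AllManager
    def __init__(self, cpu_group):
        self.cpu_group = cpu_group


_BACKENDS = {"naive": _NaiveManager, "pplx": _PPLXManager}


def make_all2all_manager(cpu_group=None):
    backend = os.environ.get("VLLM_ALL2ALL_BACKEND", "naive")
    manager_cls = _BACKENDS.get(backend)
    if manager_cls is None:
        raise ValueError(f"Unknown all2all backend: {backend}")
    return manager_cls(cpu_group)


print(type(make_all2all_manager()).__name__)  # _NaiveManager unless overridden
```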
Review thread on lines 156 to 167 (the dispatch/combine methods above):

Contributor:
How would these methods work if we weren't using the naive manager? e.g. the pplx all2all object might have a different instance for each layer.

Member Author:
Yes, the dispatch/combine in DeviceCommunicatorBase is not used for the pplx kernel, and I agree with your prepare/finalize call inside every layer now. I will try to remove dispatch/combine in DeviceCommunicatorBase.

Member Author:
> I will try to remove dispatch/combine in the DeviceCommunicatorBase

This would be in a future PR; we need prepare_finalize for the naive all2all implementation first, and then we can remove these functions from DeviceCommunicatorBase.
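
On the per-layer question above: for the pplx path, get_handle is keyed by each layer's kwargs through the new Cache, so layers with identical configs share one handle while a layer with a different config gets its own instance. A minimal sketch of that behavior follows (the Cache class is copied from base_device_communicator.py in this PR; _FakeHandle and its hidden_size kwarg are hypothetical stand-ins for pplx.AllToAll and its constructor arguments).

```python
import threading
from weakref import WeakValueDictionary


class Cache:
    def __init__(self):
        self._cache: WeakValueDictionary = WeakValueDictionary()
        self._lock = threading.RLock()  # Reentrant lock for thread safety

    def get_or_create(self, kwargs, func):
        # Create a hashable key from the kwargs
        key = tuple(sorted((k, v) for k, v in kwargs.items()))
        with self._lock:
            instance = self._cache.get(key)
            if instance is None:
                instance = func(**kwargs)
                self._cache[key] = instance
            return instance


class _FakeHandle:  # stand-in for a pplx.AllToAll handle
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size


cache = Cache()
h1 = cache.get_or_create({"hidden_size": 1024}, _FakeHandle)
h2 = cache.get_or_create({"hidden_size": 1024}, _FakeHandle)  # reused
h3 = cache.get_or_create({"hidden_size": 2048}, _FakeHandle)  # new handle
assert h1 is h2 and h1 is not h3
```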