From cbe41ef69e399407907809cea9e210dd1f1538f6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 15:15:11 +0800 Subject: [PATCH 01/31] revert enable_expert_parallel everywhere Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 7 +++---- vllm/v1/worker/gpu_worker.py | 3 +-- vllm/v1/worker/tpu_worker.py | 3 +-- vllm/worker/cpu_worker.py | 3 +-- vllm/worker/hpu_worker.py | 6 ++---- vllm/worker/tpu_worker.py | 3 +-- vllm/worker/worker.py | 3 +-- vllm/worker/xpu_worker.py | 3 +-- 8 files changed, 11 insertions(+), 20 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 51c519d8f86..9b293f6459f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -979,7 +979,6 @@ def pplx_finalize(): def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, - enable_expert_parallel: bool = False, backend: Optional[str] = None, ) -> None: """ @@ -1012,10 +1011,12 @@ def initialize_model_parallel( get_world_group().device_group) data_parallel_size = 1 + enable_expert_parallel = False from vllm.config import get_current_vllm_config config = get_current_vllm_config() if config is not None: data_parallel_size = config.parallel_config.data_parallel_size + enable_expert_parallel = config.parallel_config.enable_expert_parallel # the layout order is: ExternalDP x DP x PP x TP # ExternalDP is the data parallel group that is not part of the model, @@ -1089,7 +1090,6 @@ def initialize_model_parallel( def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, - enable_expert_parallel: bool = False, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1100,8 +1100,7 @@ def ensure_model_parallel_initialized( get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, - enable_expert_parallel, backend) + pipeline_model_parallel_size, backend) return assert ( diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d85701fa93d..5352b1c5a37 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -341,8 +341,7 @@ def init_worker_distributed_environment( distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 25715407cee..9eea26d8524 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -265,5 +265,4 @@ def init_tpu_worker_distributed_environment( backend="gloo", ) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index a92cf1e5a3b..1436a404335 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -390,8 +390,7 @@ def init_distributed_environment(self) -> None: ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) 
def get_cache_block_size_bytes(self) -> int: """Return the size in bytes of a single KV cache block. diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 42882992f2d..7898c645d66 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -416,8 +416,7 @@ def init_worker_distributed_environment( backend='hccl') ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) if torch.distributed.is_initialized(): torch_world_size = torch.distributed.get_world_size() @@ -443,8 +442,7 @@ def init_worker_distributed_environment( torch.distributed.all_reduce(dummy_tensor_hpu) assert dummy_tensor_hpu.item() == parallel_config.world_size ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 891ed66599d..4bb9bea022f 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -76,8 +76,7 @@ def init_device(self) -> None: ) ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.enable_expert_parallel) + self.parallel_config.pipeline_parallel_size) # Device initialization should happen after initializing the distributed # runtime. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 41546462e5c..17f636765ff 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -530,8 +530,7 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 65085f80f97..17f53352517 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -176,8 +176,7 @@ def init_worker_distributed_environment(self) -> None: ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) # global all_reduce needed for overall oneccl warm up torch.distributed.all_reduce(torch.zeros(1).xpu()) From 5c6ef5ec0e8d2076bfbd383f54ec2a26a7994fe4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:01:59 +0800 Subject: [PATCH 02/31] tmp Signed-off-by: youkaichao --- .../device_communicators/all2all.py | 91 ++++++++++++++++++- .../device_communicators/cuda_communicator.py | 12 ++- vllm/distributed/parallel_state.py | 53 +---------- vllm/model_executor/layers/fused_moe/layer.py | 66 +------------- .../layers/fused_moe/pplx_prepare_finalize.py | 1 - 5 files changed, 109 insertions(+), 114 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index b69647b0058..d12dc67e588 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,12 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 +import 
importlib.util +from typing import TYPE_CHECKING + import torch +import torch.distributed as dist from vllm.forward_context import get_forward_context +from vllm.logger import init_logger + +logger = init_logger() + +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE class All2AllBase: - def __init__(self, cpu_group, model): + def __init__(self, cpu_group, model: torch.nn.Module): self.cpu_group = cpu_group # compute some common properties @@ -21,6 +31,8 @@ def __init__(self, cpu_group, model): self.ep_group = get_ep_group() self.dp_rank = self.dp_group.rank_in_group self.dp_world_size = self.dp_group.world_size + self.rank = self.ep_group.rank_in_group + self.world_size = self.ep_group.world_size # all2all communication often has separate implementations for # intra-node and inter-node communication @@ -46,7 +58,7 @@ class NaiveAll2All(All2AllBase): debugging. """ - def __init__(self, cpu_group, model): + def __init__(self, cpu_group, model: torch.nn.Module): super().__init__(cpu_group, model) def naive_multicast(self, x: torch.Tensor, @@ -91,3 +103,78 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def destroy(self): pass + + +class PPLXAll2All(All2AllBase): + """ + All2All communication based on PPLX kernels. + """ + + def __init__(self, cpu_group, model: torch.nn.Module): + has_pplx = importlib.util.find_spec("pplx_kernels") is not None + assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa + import pplx_kernels as pplx + super().__init__(cpu_group, model) + moe_layer: FusedMoE = None + for module in model.modules(): + if module.__class__.__name__ == "FusedMoE": + moe_layer = module + break + # assume all MoE layers have the same config + moe = moe_layer.moe_config + MOE_DP_CHUNK_SIZE = 256 + max_num_tokens = MOE_DP_CHUNK_SIZE + + all_to_all_args = dict( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=self.rank, + world_size=self.world_size, + dp_size=self.tp_group. 
+ world_size, # dp_size actually means tp_size, bug in pplx kernels + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else + ((moe.hidden_dim + moe.block_size - 1) // + moe.block_size * torch.float32.itemsize))) + + if self.internode: + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init) + logger.debug( + "Initialize NVSHMEM for pplx_kernels: " + "rank=%d, world size=%d", self.rank, self.world_size) + uid = nvshmem_get_unique_id( + ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() + dist.broadcast(uid, + src=self.ep_group.ranks[0], + group=self.cpu_group) + logger.debug("PPLX NVSHMEM UID = %s", uid) + nvshmem_init(uid, self.rank, self.world_size) + self.pplx_handle = pplx.AllToAll.internode(**all_to_all_args) + else: + self.pplx_handle = pplx.AllToAll.intranode(**all_to_all_args) + + # TODO: refactor the initialization logic + for module in model.modules(): + if module.__class__.__name__ == "FusedMoE": + module.quant_method.fused_experts.prepare_finalize.a2a \ + = self.pplx_handle + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + self.pplx_handle.destroy() + from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") + nvshmem_finalize() diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 13303f94b8e..cdf60cfece8 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -31,7 +31,14 @@ def __init__(self, use_pynccl = "ep" not in unique_name self.use_pynccl = use_pynccl - self.use_all2all = "ep" in unique_name + + use_ep = False + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + use_ep = config.parallel_config.enable_expert_parallel + + self.use_all2all = "ep" in unique_name and use_ep self.all2all_impl: Optional[All2AllBase] = None self.use_custom_allreduce = use_custom_allreduce @@ -151,6 +158,9 @@ def prepare_communication_buffer_for_model(self, if all2all_backend == "naive": from .all2all import NaiveAll2All self.all2all_impl = NaiveAll2All(self.cpu_group, model) + elif all2all_backend == "pplx": + from .all2all import PPLXAll2All + self.all2all_impl = PPLXAll2All(self.cpu_group, model) def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 9b293f6459f..e1c4e2e26ba 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,7 +23,6 @@ """ import contextlib import gc -import importlib.util import pickle import weakref from collections import namedtuple @@ -43,7 +42,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, - run_once, supports_custom_op) + supports_custom_op) @dataclass @@ -769,10 +768,14 @@ def dispatch( if self.device_communicator is not None: return self.device_communicator.dispatch(hidden_states, router_logits) + 
else: + return hidden_states, router_logits def combine(self, hidden_states) -> torch.Tensor: if self.device_communicator is not None: return self.device_communicator.combine(hidden_states) + else: + return hidden_states _WORLD: Optional[GroupCoordinator] = None @@ -937,45 +940,6 @@ def init_distributed_environment( "world group already initialized with a different world size") -PPLX_DID_INIT: bool = False - - -@run_once -def pplx_init(rank, world_size): - has_pplx = importlib.util.find_spec("pplx_kernels") is not None - - if has_pplx and world_size > 1: - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init) - try: - global PPLX_DID_INIT - logger.debug( - "Initialize NVSHMEM for PPLX kernels: rank=%d, " - "world size=%d", rank, world_size) - uid = nvshmem_get_unique_id( - ) if rank == 0 else nvshmem_alloc_empty_unique_id() - uid_gpu = uid.cuda() - get_world_group().broadcast(uid_gpu, src=0) - uid = uid_gpu.to(device='cpu') - logger.debug("PPLX NVSHMEM UID = %s", uid) - nvshmem_init(uid, rank, world_size) - PPLX_DID_INIT = True - except Exception as ex: - logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex) - - -@run_once -def pplx_finalize(): - global PPLX_DID_INIT - if PPLX_DID_INIT: - from pplx_kernels.nvshmem import nvshmem_finalize - logger.debug("PPLX NVSHMEM finalize") - from vllm.model_executor.layers.fused_moe.layer import ( - _all_to_all_cache) - _all_to_all_cache.destroy() - nvshmem_finalize() - - def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, @@ -1011,12 +975,10 @@ def initialize_model_parallel( get_world_group().device_group) data_parallel_size = 1 - enable_expert_parallel = False from vllm.config import get_current_vllm_config config = get_current_vllm_config() if config is not None: data_parallel_size = config.parallel_config.data_parallel_size - enable_expert_parallel = config.parallel_config.enable_expert_parallel # the layout order is: ExternalDP x DP x PP x TP # ExternalDP is the data parallel group that is not part of the model, @@ -1083,9 +1045,6 @@ def initialize_model_parallel( _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, _EP.rank_in_group) - if enable_expert_parallel: - pplx_init(rank, world_size) - def ensure_model_parallel_initialized( tensor_model_parallel_size: int, @@ -1179,8 +1138,6 @@ def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP - pplx_finalize() - if _TP: _TP.destroy() _TP = None diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0b3c02d1ba2..4ea2e80b80a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import importlib -import threading from abc import abstractmethod from dataclasses import dataclass from enum import Enum from typing import Callable, Optional -from weakref import WeakValueDictionary import torch import torch.nn.functional as F @@ -74,7 +72,8 @@ class FusedMoEParallelConfig: @property def use_pplx_kernels(self): - return self.dp_size > 1 and self.use_ep and has_pplx + return self.dp_size > 1 and self.use_ep and \ + envs.VLLM_ALL2ALL_BACKEND == "pplx" @staticmethod def make(tp_size_: int, dp_size_: int, @@ -275,46 +274,6 @@ def apply( raise NotImplementedError -class AllToAllCache: - - def __init__(self): - self._cache: WeakValueDictionary = WeakValueDictionary() - self._lock = threading.RLock() 
# Reentrant lock for thread safety - - def destroy(self): - with self._lock: - # TODO: can we do del self._cache? - for _, a2a in self._cache.items(): - a2a.destroy() - - def get_or_create(self, **kwargs): - assert has_pplx - import pplx_kernels as pplx - - # Create a hashable key from the kwargs - key = tuple(sorted((k, v) for k, v in kwargs.items())) - - with self._lock: - instance = self._cache.get(key) - if instance is None: - # TODO (varun): Add support to switch to intranode - # when all communications are within the same - # node. - logger.debug("Create AllToAll %s", kwargs) - instance = pplx.AllToAll.internode(**kwargs) - self._cache[key] = instance - return instance - - -# Global singleton -_all_to_all_cache = AllToAllCache() - - -# Factory function as a cleaner interface -def get_all_to_all(**kwargs): - return _all_to_all_cache.get_or_create(**kwargs) - - @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" @@ -690,26 +649,8 @@ def _construct_prepare_finalize( rank = moe.ep_rank if moe.use_pplx_kernels: - logger.debug("using PplxPrepareAndFinalize") - - all_to_all = get_all_to_all( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else - ((moe.hidden_dim + moe.block_size - 1) // - moe.block_size * torch.float32.itemsize))) - return PplxPrepareAndFinalize( - all_to_all, + None, max_num_tokens=max_num_tokens, world_size=world_size, rank=rank, @@ -834,6 +775,7 @@ def __init__( # TODO (bnell): this needs to be fixed for quantized types. in_dtype=params_dtype, ) + self.moe_config = moe # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index b1126b94e45..783ebebbfec 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -9,7 +9,6 @@ moe_kernel_quantize_input) -# Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. 
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): From 29aebf6709fdfbae06bcf6782f8754d77c4f396d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:03:45 +0800 Subject: [PATCH 03/31] tmp Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index d12dc67e588..92bcfff9aad 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -140,7 +140,9 @@ def __init__(self, cpu_group, model: torch.nn.Module): # For per-token: set to sizeof(float32) hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ((moe.hidden_dim + moe.block_size - 1) // - moe.block_size * torch.float32.itemsize))) + moe.block_size * torch.float32.itemsize)), + group_name=self.cpu_group.group_name, + ) if self.internode: from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, From 058d8e50dae22befcbcfe367b9c383cbd6d0a0a1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:09:18 +0800 Subject: [PATCH 04/31] fix typing Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 92bcfff9aad..81aceb6895d 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -12,6 +12,8 @@ if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.layer import FusedMoE +else: + FusedMoE = None class All2AllBase: From b5dce6f24db917fecde7bc61a567690997ab460f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:10:41 +0800 Subject: [PATCH 05/31] fix typing Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 81aceb6895d..f2d53f232a1 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -8,7 +8,7 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger -logger = init_logger() +logger = init_logger(__name__) if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.layer import FusedMoE From ed9299a18366bf1e2d000b93b0b1d14571fd68f6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:11:45 +0800 Subject: [PATCH 06/31] document options Signed-off-by: youkaichao --- vllm/envs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index fe3fa91fbe3..b122f577778 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -778,6 +778,9 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), # all2all backend for vllm's expert parallel communication + # Available options: + # - "naive": naive all2all implementation using all-reduce + # - "pplx": use pplx kernels "VLLM_ALL2ALL_BACKEND": lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), } From 88cb9d5382cd4d6d0a609012531df4619ae02d17 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:22:26 +0800 Subject: [PATCH 07/31] fix inductor Signed-off-by: youkaichao --- vllm/platforms/cuda.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bdee8b2f821..9163b97c51a 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -158,7 +158,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "currently not supported with CUDA Graphs.") vllm_config.model_config.enforce_eager = True compilation_config.use_cudagraph = False - compilation_config.use_inductor = False @classmethod def get_current_memory_usage(cls, From 729239f17023d47abfba55285e258ef68a98e05f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:24:20 +0800 Subject: [PATCH 08/31] fix shutdown error Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index f2d53f232a1..86ff8dc5302 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -179,6 +179,8 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def destroy(self): self.pplx_handle.destroy() + torch.cuda.synchronize() + torch.distributed.barrier(self.cpu_group) from pplx_kernels.nvshmem import nvshmem_finalize logger.debug("PPLX NVSHMEM finalize") nvshmem_finalize() From 1bf90d65c618cf81caf54ec23d5ee52851dfb470 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:25:07 +0800 Subject: [PATCH 09/31] fix shutdown error Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 86ff8dc5302..0b4451ee836 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -179,8 +179,7 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def destroy(self): self.pplx_handle.destroy() - torch.cuda.synchronize() - torch.distributed.barrier(self.cpu_group) - from pplx_kernels.nvshmem import nvshmem_finalize - logger.debug("PPLX NVSHMEM finalize") - nvshmem_finalize() + if self.internode: + from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") + nvshmem_finalize() From 4dc24559364eb85183848ecc2990b32d1529fdcf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:32:15 +0800 Subject: [PATCH 10/31] comment Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4ea2e80b80a..9e0aa174d70 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -650,7 +650,7 @@ def _construct_prepare_finalize( if moe.use_pplx_kernels: return PplxPrepareAndFinalize( - None, + None, # will be set later in prepare_communication_buffer_for_model max_num_tokens=max_num_tokens, world_size=world_size, rank=rank, From b680ce94379ee511339dc082132ee850f26da8f0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:38:43 +0800 Subject: [PATCH 11/31] fix max_num_tokens Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 3 +-- vllm/model_executor/layers/fused_moe/layer.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py 
b/vllm/distributed/device_communicators/all2all.py index 0b4451ee836..a2de64c29ef 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -124,8 +124,7 @@ def __init__(self, cpu_group, model: torch.nn.Module): break # assume all MoE layers have the same config moe = moe_layer.moe_config - MOE_DP_CHUNK_SIZE = 256 - max_num_tokens = MOE_DP_CHUNK_SIZE + max_num_tokens = moe.max_num_tokens all_to_all_args = dict( max_num_tokens=max_num_tokens, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9e0aa174d70..faf31c89d61 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -196,6 +196,8 @@ class MoEConfig: # TODO: add more quantization params, blocked, per-token, etc. block_size: int = 128 + max_num_tokens: int + @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -643,7 +645,6 @@ def determine_expert_map( def _construct_prepare_finalize( moe: MoEConfig, quant_config: Optional[QuantizationConfig] ) -> Optional[FusedMoEPrepareAndFinalize]: - max_num_tokens = MOE_DP_CHUNK_SIZE world_size = moe.ep_size dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. rank = moe.ep_rank @@ -651,7 +652,7 @@ def _construct_prepare_finalize( if moe.use_pplx_kernels: return PplxPrepareAndFinalize( None, # will be set later in prepare_communication_buffer_for_model - max_num_tokens=max_num_tokens, + max_num_tokens=moe.max_num_tokens, world_size=world_size, rank=rank, dp_size=dp_size, @@ -774,6 +775,7 @@ def __init__( moe_parallel_config=self.moe_parallel_config, # TODO (bnell): this needs to be fixed for quantized types. in_dtype=params_dtype, + max_num_tokens=MOE_DP_CHUNK_SIZE, ) self.moe_config = moe From f5c6b57641f6b06bad45c1440151146988ac73c6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:40:56 +0800 Subject: [PATCH 12/31] add comments Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index a2de64c29ef..1e9a0ffc2f3 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -146,6 +146,7 @@ def __init__(self, cpu_group, model: torch.nn.Module): ) if self.internode: + # inter-node communication needs nvshmem from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, nvshmem_get_unique_id, nvshmem_init) @@ -161,6 +162,7 @@ def __init__(self, cpu_group, model: torch.nn.Module): nvshmem_init(uid, self.rank, self.world_size) self.pplx_handle = pplx.AllToAll.internode(**all_to_all_args) else: + # intra-node communication uses p2p mapping directly self.pplx_handle = pplx.AllToAll.intranode(**all_to_all_args) # TODO: refactor the initialization logic From ad70c44c4b37521c9f4b2f8569e8489ad830f37e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 19:42:21 +0800 Subject: [PATCH 13/31] fix max_num_tokens Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index faf31c89d61..bffd87702e9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -196,7 +196,7 @@ class MoEConfig: # TODO: add more 
quantization params, blocked, per-token, etc. block_size: int = 128 - max_num_tokens: int + max_num_tokens: int = MOE_DP_CHUNK_SIZE @property def tp_size(self): From b20e977e6f7765d1db71f86430613aad6a1cd414 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 15 May 2025 21:10:56 +0800 Subject: [PATCH 14/31] allow per-layer all2all Signed-off-by: youkaichao --- .../device_communicators/all2all.py | 100 +++++++++++------- 1 file changed, 62 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 1e9a0ffc2f3..7b6b8fc2f8a 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util +import threading from typing import TYPE_CHECKING +from weakref import WeakValueDictionary import torch import torch.distributed as dist @@ -16,6 +18,24 @@ FusedMoE = None +class Cache: + + def __init__(self): + self._cache: WeakValueDictionary = WeakValueDictionary() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, kwargs, func): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + instance = self._cache.get(key) + if instance is None: + instance = func(**kwargs) + self._cache[key] = instance + return instance + + class All2AllBase: def __init__(self, cpu_group, model: torch.nn.Module): @@ -117,36 +137,10 @@ def __init__(self, cpu_group, model: torch.nn.Module): assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa import pplx_kernels as pplx super().__init__(cpu_group, model) - moe_layer: FusedMoE = None - for module in model.modules(): - if module.__class__.__name__ == "FusedMoE": - moe_layer = module - break - # assume all MoE layers have the same config - moe = moe_layer.moe_config - max_num_tokens = moe.max_num_tokens - - all_to_all_args = dict( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=self.rank, - world_size=self.world_size, - dp_size=self.tp_group. 
- world_size, # dp_size actually means tp_size, bug in pplx kernels - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else - ((moe.hidden_dim + moe.block_size - 1) // - moe.block_size * torch.float32.itemsize)), - group_name=self.cpu_group.group_name, - ) if self.internode: - # inter-node communication needs nvshmem + # inter-node communication needs nvshmem, + # intra-node communication uses p2p mapping directly from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, nvshmem_get_unique_id, nvshmem_init) @@ -160,16 +154,43 @@ def __init__(self, cpu_group, model: torch.nn.Module): group=self.cpu_group) logger.debug("PPLX NVSHMEM UID = %s", uid) nvshmem_init(uid, self.rank, self.world_size) - self.pplx_handle = pplx.AllToAll.internode(**all_to_all_args) - else: - # intra-node communication uses p2p mapping directly - self.pplx_handle = pplx.AllToAll.intranode(**all_to_all_args) - # TODO: refactor the initialization logic - for module in model.modules(): - if module.__class__.__name__ == "FusedMoE": - module.quant_method.fused_experts.prepare_finalize.a2a \ - = self.pplx_handle + self.handle_cache = Cache() + + moe_modules = [ + module for module in model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + for module in moe_modules: + moe_layer = module + moe = moe_layer.moe_config + max_num_tokens = moe.max_num_tokens + + all_to_all_args = dict( + max_num_tokens=max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=self.rank, + world_size=self.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=self.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize)), + group_name=self.cpu_group.group_name, + ) + + pplx_handle = self.handle_cache.get_or_create( + all_to_all_args, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) + + moe_layer.quant_method.fused_experts.prepare_finalize.a2a \ + = pplx_handle def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): @@ -179,7 +200,10 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: raise NotImplementedError def destroy(self): - self.pplx_handle.destroy() + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() + if self.internode: from pplx_kernels.nvshmem import nvshmem_finalize logger.debug("PPLX NVSHMEM finalize") From 1e60d54ad4ca8236004b39778833b840c62b68c7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 10:32:50 +0800 Subject: [PATCH 15/31] merge into init_prepare_finalize Signed-off-by: youkaichao --- .../device_communicators/all2all.py | 95 +--------- .../base_device_communicator.py | 74 ++++++++ .../device_communicators/cuda_communicator.py | 10 -- vllm/model_executor/layers/fused_moe/layer.py | 168 +++++++++--------- .../model_executor/layers/quantization/fp8.py | 19 +- 5 files changed, 175 insertions(+), 191 deletions(-) diff --git 
a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7b6b8fc2f8a..7347e091eac 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util -import threading from typing import TYPE_CHECKING -from weakref import WeakValueDictionary import torch import torch.distributed as dist @@ -10,6 +8,8 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger +from .base_device_communicator import All2AllBase, Cache + logger = init_logger(__name__) if TYPE_CHECKING: @@ -18,60 +18,6 @@ FusedMoE = None -class Cache: - - def __init__(self): - self._cache: WeakValueDictionary = WeakValueDictionary() - self._lock = threading.RLock() # Reentrant lock for thread safety - - def get_or_create(self, kwargs, func): - # Create a hashable key from the kwargs - key = tuple(sorted((k, v) for k, v in kwargs.items())) - - with self._lock: - instance = self._cache.get(key) - if instance is None: - instance = func(**kwargs) - self._cache[key] = instance - return instance - - -class All2AllBase: - - def __init__(self, cpu_group, model: torch.nn.Module): - self.cpu_group = cpu_group - - # compute some common properties - from vllm.distributed.parallel_state import (get_dp_group, - get_ep_group, - get_tp_group, - in_the_same_node_as) - - # all2all lives in ep group, which is merged from dp and tp group - self.dp_group = get_dp_group() - self.tp_group = get_tp_group() - self.ep_group = get_ep_group() - self.dp_rank = self.dp_group.rank_in_group - self.dp_world_size = self.dp_group.world_size - self.rank = self.ep_group.rank_in_group - self.world_size = self.ep_group.world_size - - # all2all communication often has separate implementations for - # intra-node and inter-node communication - self.intranode = in_the_same_node_as(cpu_group, source_rank=0) - self.internode = not self.intranode - - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - raise NotImplementedError - - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - def destroy(self): - pass - - class NaiveAll2All(All2AllBase): """ A naive implementation of all2all communication. @@ -135,7 +81,6 @@ class PPLXAll2All(All2AllBase): def __init__(self, cpu_group, model: torch.nn.Module): has_pplx = importlib.util.find_spec("pplx_kernels") is not None assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." 
# noqa - import pplx_kernels as pplx super().__init__(cpu_group, model) if self.internode: @@ -162,35 +107,13 @@ def __init__(self, cpu_group, model: torch.nn.Module): if module.__class__.__name__ == "FusedMoE" ] for module in moe_modules: - moe_layer = module - moe = moe_layer.moe_config - max_num_tokens = moe.max_num_tokens - - all_to_all_args = dict( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=self.rank, - world_size=self.world_size, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=self.tp_group.world_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( - (moe.hidden_dim + moe.block_size - 1) // moe.block_size * - torch.float32.itemsize)), - group_name=self.cpu_group.group_name, - ) - - pplx_handle = self.handle_cache.get_or_create( - all_to_all_args, pplx.AllToAll.internode - if self.internode else pplx.AllToAll.intranode) - - moe_layer.quant_method.fused_experts.prepare_finalize.a2a \ - = pplx_handle + module.quant_method.init_prepare_finalize() + + def get_handle(self, kwargs): + import pplx_kernels as pplx + return self.handle_cache.get_or_create( + kwargs, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index ead79872bd4..f17f6f16b40 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,11 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 +import threading from typing import Optional +from weakref import WeakValueDictionary import torch import torch.distributed as dist from torch.distributed import ProcessGroup +class Cache: + + def __init__(self): + self._cache: WeakValueDictionary = WeakValueDictionary() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, kwargs, func): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + instance = self._cache.get(key) + if instance is None: + instance = func(**kwargs) + self._cache[key] = instance + return instance + + +class All2AllBase: + + def __init__(self, cpu_group, model: torch.nn.Module): + self.cpu_group = cpu_group + + # compute some common properties + from vllm.distributed.parallel_state import (get_dp_group, + get_ep_group, + get_tp_group, + in_the_same_node_as) + + # all2all lives in ep group, which is merged from dp and tp group + self.dp_group = get_dp_group() + self.tp_group = get_tp_group() + self.ep_group = get_ep_group() + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size + self.rank = self.ep_group.rank_in_group + self.world_size = self.ep_group.world_size + + # all2all communication often has separate implementations for + # intra-node and inter-node communication + self.intranode = in_the_same_node_as(cpu_group, source_rank=0) + self.internode = not self.intranode + + def get_handle(self, kwargs): + # get a handle for the all2all communication, + # based on the kwargs. 
+ # different layers can have different configs, + # e.g. one layer has hidden size 1024, another has 2048. + # usually the underlying implementation caches the handle + # and reuse it for the same config. + raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + class DeviceCommunicatorBase: """ Base class for device-specific communicator. @@ -31,6 +96,15 @@ def __init__(self, self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + use_ep = False + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + use_ep = config.parallel_config.enable_expert_parallel + + self.use_all2all = "ep" in unique_name and use_ep + self.all2all_impl: Optional[All2AllBase] = None + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index cdf60cfece8..b69d91965ad 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -7,7 +7,6 @@ import vllm.envs as envs -from .all2all import All2AllBase from .base_device_communicator import DeviceCommunicatorBase @@ -31,15 +30,6 @@ def __init__(self, use_pynccl = "ep" not in unique_name self.use_pynccl = use_pynccl - - use_ep = False - from vllm.config import get_current_vllm_config - config = get_current_vllm_config() - if config is not None: - use_ep = config.parallel_config.enable_expert_parallel - - self.use_all2all = "ep" in unique_name and use_ep - self.all2all_impl: Optional[All2AllBase] = None self.use_custom_allreduce = use_custom_allreduce # lazy import to avoid documentation build error diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bffd87702e9..245417ecf31 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -246,13 +246,58 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: FusedMoEPrepareAndFinalize, - ) -> bool: - return False + def init_prepare_finalize(self): + all2all_impl = get_ep_group().device_communicator.all2all_impl + assert all2all_impl is not None + + moe: MoEConfig = self.moe + + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=all2all_impl.rank, + world_size=all2all_impl.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_impl.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else + ((moe.hidden_dim + moe.block_size - 1) // + moe.block_size * torch.float32.itemsize)), + group_name=all2all_impl.cpu_group.group_name, + ) + + handle = all2all_impl.get_handle(all_to_all_args) + + prepare_finalize = None 
+ if moe.use_pplx_kernels: + prepare_finalize = PplxPrepareAndFinalize( + handle, + max_num_tokens=moe.max_num_tokens, + world_size=all2all_impl.world_size, + rank=all2all_impl.rank, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_impl.tp_group.world_size, + quant_dtype=moe.in_dtype, + ) + + experts = self.select_gemm_impl(prepare_finalize) + + self.fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + def select_gemm_impl( + self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] + ) -> FusedMoEPermuteExpertsUnpermute: + # based on the all2all implementation, select the appropriate + # gemm implementation + pass @abstractmethod def apply( @@ -292,6 +337,42 @@ def __init__(self, moe: MoEConfig): else: self.rocm_aiter_fused_experts = None # type: ignore + def select_gemm_impl( + self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]): + + assert self.fused_experts == fused_experts + + all2all_impl = get_ep_group().device_communicator.all2all_impl + assert all2all_impl is not None + + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + + if isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + logger.debug("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=all2all_impl.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_impl.tp_group.world_size, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + logger.debug("TritonExperts %s", self.moe) + experts = TritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, + ) + return experts + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -391,47 +472,6 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: FusedMoEPrepareAndFinalize, - ) -> bool: - assert self.fused_experts == fused_experts - - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - - if isinstance(prepare_finalize, - (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): - logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedTritonExperts( - max_num_tokens=MOE_DP_CHUNK_SIZE, - world_size=world_size, - dp_size=dp_size, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - ) - else: - logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) - - self.fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) - - return True - def forward_cuda( self, layer: torch.nn.Module, @@ -642,26 +682,6 @@ def determine_expert_map( return (local_num_experts, expert_map) -def _construct_prepare_finalize( - moe: MoEConfig, quant_config: Optional[QuantizationConfig] -) -> Optional[FusedMoEPrepareAndFinalize]: - world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. 
- rank = moe.ep_rank - - if moe.use_pplx_kernels: - return PplxPrepareAndFinalize( - None, # will be set later in prepare_communication_buffer_for_model - max_num_tokens=moe.max_num_tokens, - world_size=world_size, - rank=rank, - dp_size=dp_size, - quant_dtype=moe.in_dtype, - ) - - return None - - class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -785,25 +805,13 @@ def __init__( if quant_config is None: quant_method = UnquantizedFusedMoEMethod(moe) - prepare_finalize = _construct_prepare_finalize(moe, quant_config) else: quant_method = quant_config.get_quant_method(self, prefix) - # No pplx for quantized types yet. - prepare_finalize = None assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method - if prepare_finalize is not None: - world_size = moe.ep_size - dp_size = int(moe.ep_size // moe.dp_size) - success = self.quant_method.set_prepare_finalize( - dp_size, world_size, prepare_finalize) - if not success: - logger.warning("DP+EP not supported for %s.", - type(self.quant_method)) - moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f4cdc3db1a0..487177c6b63 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -10,7 +10,6 @@ from torch.nn.parameter import Parameter import vllm.envs as envs -import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -791,17 +790,12 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w13_input_scale del layer.w2_input_scale - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - ) -> bool: + def select_gemm_impl(self, prepare_finalize): from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) - if self.use_marlin or self.rocm_aiter_moe_enabled: - return False + assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( + "Marlin and ROCm AITER are not supported with all2all yet.") experts = TritonOrDeepGemmExperts( use_fp8_w8a8=True, @@ -809,12 +803,7 @@ def set_prepare_finalize( allow_deep_gemm=self.allow_deep_gemm, ) - self.fused_experts = mk.FusedMoEModularKernel( - prepare_finalize, - experts, - ) - - return True + return experts def apply( self, From 15d673b286a91034e6cbf788cd27eab824a8882a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 10:38:58 +0800 Subject: [PATCH 16/31] disable inductor Signed-off-by: youkaichao --- vllm/platforms/cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 9163b97c51a..3b0e77f6e47 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -158,6 +158,8 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "currently not supported with CUDA Graphs.") vllm_config.model_config.enforce_eager = True compilation_config.use_cudagraph = False + # FIXME: inductor breaks cudagraph (from @bnell) + compilation_config.use_inductor = False @classmethod def get_current_memory_usage(cls, From 63f029b11e1bb058e3754e21dc713e4fb63406e8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 10:48:29 +0800 Subject: [PATCH 17/31] fix Signed-off-by: youkaichao 
--- vllm/distributed/device_communicators/all2all.py | 7 ------- .../device_communicators/cuda_communicator.py | 9 +++++++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 7347e091eac..8b641dd890d 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -102,13 +102,6 @@ def __init__(self, cpu_group, model: torch.nn.Module): self.handle_cache = Cache() - moe_modules = [ - module for module in model.modules() - if module.__class__.__name__ == "FusedMoE" - ] - for module in moe_modules: - module.quant_method.init_prepare_finalize() - def get_handle(self, kwargs): import pplx_kernels as pplx return self.handle_cache.get_or_create( diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index b69d91965ad..7cca11c1ef1 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -151,6 +151,15 @@ def prepare_communication_buffer_for_model(self, elif all2all_backend == "pplx": from .all2all import PPLXAll2All self.all2all_impl = PPLXAll2All(self.cpu_group, model) + else: + raise ValueError(f"Unknown all2all backend: {all2all_backend}") + + moe_modules = [ + module for module in model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + for module in moe_modules: + module.quant_method.init_prepare_finalize() def dispatch( self, hidden_states: torch.Tensor, From d75ac1c27f5267ac4986e9c3e979ff48145aef8e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 10:51:14 +0800 Subject: [PATCH 18/31] rename to manager Signed-off-by: youkaichao --- .../device_communicators/all2all.py | 6 ++-- .../base_device_communicator.py | 4 +-- .../device_communicators/cuda_communicator.py | 22 +++++++-------- vllm/model_executor/layers/fused_moe/layer.py | 28 +++++++++---------- 4 files changed, 30 insertions(+), 30 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 8b641dd890d..b061d7215d9 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -8,7 +8,7 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger -from .base_device_communicator import All2AllBase, Cache +from .base_device_communicator import All2AllManagerBase, Cache logger = init_logger(__name__) @@ -18,7 +18,7 @@ FusedMoE = None -class NaiveAll2All(All2AllBase): +class NaiveAll2AllManager(All2AllManagerBase): """ A naive implementation of all2all communication. It uses all-reduce under the hood, which is not @@ -73,7 +73,7 @@ def destroy(self): pass -class PPLXAll2All(All2AllBase): +class PPLXAll2AllManager(All2AllManagerBase): """ All2All communication based on PPLX kernels. 
""" diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index f17f6f16b40..a8846fc8320 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -26,7 +26,7 @@ def get_or_create(self, kwargs, func): return instance -class All2AllBase: +class All2AllManagerBase: def __init__(self, cpu_group, model: torch.nn.Module): self.cpu_group = cpu_group @@ -103,7 +103,7 @@ def __init__(self, use_ep = config.parallel_config.enable_expert_parallel self.use_all2all = "ep" in unique_name and use_ep - self.all2all_impl: Optional[All2AllBase] = None + self.all2all_manager: Optional[All2AllManagerBase] = None def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 7cca11c1ef1..2a8f59c28c7 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -133,9 +133,9 @@ def destroy(self): self.pynccl_comm = None if self.ca_comm is not None: self.ca_comm = None - if self.all2all_impl is not None: - self.all2all_impl.destroy() - self.all2all_impl = None + if self.all2all_manager is not None: + self.all2all_manager.destroy() + self.all2all_manager = None def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None: @@ -146,11 +146,11 @@ def prepare_communication_buffer_for_model(self, return all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": - from .all2all import NaiveAll2All - self.all2all_impl = NaiveAll2All(self.cpu_group, model) + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group, model) elif all2all_backend == "pplx": - from .all2all import PPLXAll2All - self.all2all_impl = PPLXAll2All(self.cpu_group, model) + from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group, model) else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") @@ -164,12 +164,12 @@ def prepare_communication_buffer_for_model(self, def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - assert self.all2all_impl is not None - hidden_states, router_logits = self.all2all_impl.dispatch( + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( hidden_states, router_logits) return hidden_states, router_logits def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - assert self.all2all_impl is not None - hidden_states = self.all2all_impl.combine(hidden_states) + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) return hidden_states diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 245417ecf31..632b5e6d285 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -247,8 +247,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, raise NotImplementedError def init_prepare_finalize(self): - all2all_impl = get_ep_group().device_communicator.all2all_impl - assert all2all_impl is not None + all2all_manager = 
get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None moe: MoEConfig = self.moe @@ -256,10 +256,10 @@ def init_prepare_finalize(self): max_num_tokens=moe.max_num_tokens, num_experts=moe.num_experts, experts_per_token=moe.experts_per_token, # topk - rank=all2all_impl.rank, - world_size=all2all_impl.world_size, + rank=all2all_manager.rank, + world_size=all2all_manager.world_size, # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_impl.tp_group.world_size, + dp_size=all2all_manager.tp_group.world_size, hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, # For blocked per token: set to @@ -268,20 +268,20 @@ def init_prepare_finalize(self): hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ((moe.hidden_dim + moe.block_size - 1) // moe.block_size * torch.float32.itemsize)), - group_name=all2all_impl.cpu_group.group_name, + group_name=all2all_manager.cpu_group.group_name, ) - handle = all2all_impl.get_handle(all_to_all_args) + handle = all2all_manager.get_handle(all_to_all_args) prepare_finalize = None if moe.use_pplx_kernels: prepare_finalize = PplxPrepareAndFinalize( handle, max_num_tokens=moe.max_num_tokens, - world_size=all2all_impl.world_size, - rank=all2all_impl.rank, + world_size=all2all_manager.world_size, + rank=all2all_manager.rank, # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_impl.tp_group.world_size, + dp_size=all2all_manager.tp_group.world_size, quant_dtype=moe.in_dtype, ) @@ -342,8 +342,8 @@ def select_gemm_impl( assert self.fused_experts == fused_experts - all2all_impl = get_ep_group().device_communicator.all2all_impl - assert all2all_impl is not None + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None experts: Optional[FusedMoEPermuteExpertsUnpermute] = None @@ -352,9 +352,9 @@ def select_gemm_impl( logger.debug("BatchedTritonExperts %s", self.moe) experts = BatchedTritonExperts( max_num_tokens=MOE_DP_CHUNK_SIZE, - world_size=all2all_impl.world_size, + world_size=all2all_manager.world_size, # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_impl.tp_group.world_size, + dp_size=all2all_manager.tp_group.world_size, use_fp8_w8a8=False, use_int8_w8a8=False, use_int8_w8a16=False, From 60499dfe419cb46feed581cce83a98fded96e7c9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 10:59:26 +0800 Subject: [PATCH 19/31] fix for non-pplx Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 632b5e6d285..a6e68bf26ed 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -252,29 +252,29 @@ def init_prepare_finalize(self): moe: MoEConfig = self.moe - all_to_all_args = dict( - max_num_tokens=moe.max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=all2all_manager.rank, - world_size=all2all_manager.world_size, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if 
moe.in_dtype.itemsize != 1 else - ((moe.hidden_dim + moe.block_size - 1) // - moe.block_size * torch.float32.itemsize)), - group_name=all2all_manager.cpu_group.group_name, - ) - - handle = all2all_manager.get_handle(all_to_all_args) - prepare_finalize = None if moe.use_pplx_kernels: + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=all2all_manager.rank, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize)), + group_name=all2all_manager.cpu_group.group_name, + ) + + handle = all2all_manager.get_handle(all_to_all_args) + prepare_finalize = PplxPrepareAndFinalize( handle, max_num_tokens=moe.max_num_tokens, From cd6858e896f59d16bb3312d896907ceec43295e6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 11:03:56 +0800 Subject: [PATCH 20/31] fix for non-pplx Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a6e68bf26ed..bf6a31793d2 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -285,12 +285,12 @@ def init_prepare_finalize(self): quant_dtype=moe.in_dtype, ) - experts = self.select_gemm_impl(prepare_finalize) - - self.fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) + if prepare_finalize is not None: + experts = self.select_gemm_impl(prepare_finalize) + self.fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) def select_gemm_impl( self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] From 3f6a862b58a651c7e15b520e1bf7c36a50446b92 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 11:13:59 +0800 Subject: [PATCH 21/31] move prepare_communication_buffer_for_model to base Signed-off-by: youkaichao --- .../device_communicators/all2all.py | 8 ++-- .../base_device_communicator.py | 10 +++-- .../device_communicators/cuda_communicator.py | 40 ++++++++----------- 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index b061d7215d9..333ea5536ec 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -26,8 +26,8 @@ class NaiveAll2AllManager(All2AllManagerBase): debugging. """ - def __init__(self, cpu_group, model: torch.nn.Module): - super().__init__(cpu_group, model) + def __init__(self, cpu_group): + super().__init__(cpu_group) def naive_multicast(self, x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor): @@ -78,10 +78,10 @@ class PPLXAll2AllManager(All2AllManagerBase): All2All communication based on PPLX kernels. """ - def __init__(self, cpu_group, model: torch.nn.Module): + def __init__(self, cpu_group): has_pplx = importlib.util.find_spec("pplx_kernels") is not None assert has_pplx, "pplx_kernels not found. 
Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa - super().__init__(cpu_group, model) + super().__init__(cpu_group) if self.internode: # inter-node communication needs nvshmem, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index a8846fc8320..f32ddd3c000 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -28,7 +28,7 @@ def get_or_create(self, kwargs, func): class All2AllManagerBase: - def __init__(self, cpu_group, model: torch.nn.Module): + def __init__(self, cpu_group): self.cpu_group = cpu_group # compute some common properties @@ -228,9 +228,13 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None: """ Prepare the communication buffer for the model. - This is a no-op in the base class. """ - pass + moe_modules = [ + module for module in model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + for module in moe_modules: + module.quant_method.init_prepare_finalize() def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 2a8f59c28c7..a05a13f51d4 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -6,9 +6,12 @@ from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.logger import init_logger from .base_device_communicator import DeviceCommunicatorBase +logger = init_logger(__name__) + class CudaCommunicator(DeviceCommunicatorBase): @@ -53,6 +56,19 @@ def __init__(self, device=self.device, ) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") + elif all2all_backend == "pplx": + from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + logger.info("Using PPLX all2all manager.") + else: + raise ValueError(f"Unknown all2all backend: {all2all_backend}") + def all_reduce(self, input_): # always try custom allreduce first, # and then pynccl. @@ -137,30 +153,6 @@ def destroy(self): self.all2all_manager.destroy() self.all2all_manager = None - def prepare_communication_buffer_for_model(self, - model: torch.nn.Module) -> None: - """ - Prepare the communication buffer for the model. 
- """ - if not self.use_all2all: - return - all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend == "naive": - from .all2all import NaiveAll2AllManager - self.all2all_manager = NaiveAll2AllManager(self.cpu_group, model) - elif all2all_backend == "pplx": - from .all2all import PPLXAll2AllManager - self.all2all_manager = PPLXAll2AllManager(self.cpu_group, model) - else: - raise ValueError(f"Unknown all2all backend: {all2all_backend}") - - moe_modules = [ - module for module in model.modules() - if module.__class__.__name__ == "FusedMoE" - ] - for module in moe_modules: - module.quant_method.init_prepare_finalize() - def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: From d4197366b95cd35b17d6aa3b80e35e32c2e07a2f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 11:16:31 +0800 Subject: [PATCH 22/31] fix no ep case Signed-off-by: youkaichao --- .../device_communicators/base_device_communicator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index f32ddd3c000..76a8735b783 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -229,6 +229,9 @@ def prepare_communication_buffer_for_model(self, """ Prepare the communication buffer for the model. """ + if not self.use_all2all: + return + moe_modules = [ module for module in model.modules() if module.__class__.__name__ == "FusedMoE" From 5e9d2c90e321e2f3ad437977588facd55980fcf2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 11:31:26 +0800 Subject: [PATCH 23/31] annotate moe and quant_config Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index bf6a31793d2..0a4faee1f6a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -239,6 +239,8 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): + moe: Optional[MoEConfig] = None + quant_config: Optional[QuantizationConfig] = None @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, From 522ea26eb0acaa446e686250743661300a812b1f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 11:33:32 +0800 Subject: [PATCH 24/31] fix typing Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0a4faee1f6a..d6812d81f71 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -299,7 +299,7 @@ def select_gemm_impl( ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation - pass + raise NotImplementedError @abstractmethod def apply( From f3fc838ba19723e007d182bf40432b769dbe6940 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 11:40:48 +0800 Subject: [PATCH 25/31] fix init Signed-off-by: youkaichao --- .../device_communicators/base_device_communicator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 76a8735b783..af71fc7c334 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -33,18 +33,18 @@ def __init__(self, cpu_group): # compute some common properties from vllm.distributed.parallel_state import (get_dp_group, - get_ep_group, get_tp_group, in_the_same_node_as) # all2all lives in ep group, which is merged from dp and tp group self.dp_group = get_dp_group() self.tp_group = get_tp_group() - self.ep_group = get_ep_group() + # no self.ep_group since self.ep_group is still in construction + # when we create this object self.dp_rank = self.dp_group.rank_in_group self.dp_world_size = self.dp_group.world_size - self.rank = self.ep_group.rank_in_group - self.world_size = self.ep_group.world_size + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) # all2all communication often has separate implementations for # intra-node and inter-node communication From c3cd65cf44e2580a65534ec6cc5619e0f47236d5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 11:47:44 +0800 Subject: [PATCH 26/31] fix cross-node init Signed-off-by: youkaichao --- vllm/distributed/device_communicators/all2all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 333ea5536ec..a250ec89cd5 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -95,7 +95,7 @@ def __init__(self, cpu_group): uid = nvshmem_get_unique_id( ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() dist.broadcast(uid, - src=self.ep_group.ranks[0], + src=dist.get_process_group_ranks(self.cpu_group)[0], group=self.cpu_group) logger.debug("PPLX NVSHMEM UID = %s", uid) nvshmem_init(uid, self.rank, self.world_size) From 259a724ecf91875806466e2a2cdd4f82c662eaf3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 12:09:47 +0800 Subject: [PATCH 27/31] fix typing Signed-off-by: youkaichao --- .../device_communicators/base_device_communicator.py | 3 ++- vllm/model_executor/layers/fused_moe/layer.py | 8 +++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index af71fc7c334..1c26ec88aa9 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -237,7 +237,8 @@ def prepare_communication_buffer_for_model(self, if module.__class__.__name__ == "FusedMoE" ] for module in moe_modules: - module.quant_method.init_prepare_finalize() + module.quant_method.init_prepare_finalize(module.moe_config, + module.quant_config) def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d6812d81f71..c2ffa8387f5 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -239,8 +239,6 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): - moe: Optional[MoEConfig] = None - quant_config: Optional[QuantizationConfig] = None @abstractmethod def create_weights(self, layer: 
torch.nn.Module, num_experts: int, @@ -248,12 +246,11 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def init_prepare_finalize(self): + def init_prepare_finalize(self, moe: MoEConfig, + quant_config: Optional[QuantizationConfig]): all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - moe: MoEConfig = self.moe - prepare_finalize = None if moe.use_pplx_kernels: all_to_all_args = dict( @@ -800,6 +797,7 @@ def __init__( max_num_tokens=MOE_DP_CHUNK_SIZE, ) self.moe_config = moe + self.quant_config = quant_config # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. From 52945ba3e562f351a3cf3e69de9b645dbadaa544 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 14:48:10 +0800 Subject: [PATCH 28/31] fix typing Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c2ffa8387f5..79beb619e98 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -326,7 +326,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: MoEConfig): super().__init__() - self.fused_experts = fused_experts + self.fused_experts = fused_experts # noqa self.moe = moe self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 487177c6b63..e95de9382ca 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -460,7 +460,7 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") - self.fused_experts = functools.partial( + self.fused_experts = functools.partial( # noqa fused_experts, block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm) From 9c737768463eb6075e302d3b2b08c6928ae6bf83 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 16 May 2025 14:56:20 +0800 Subject: [PATCH 29/31] fix typing Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- vllm/model_executor/layers/quantization/fp8.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 79beb619e98..5543a45720e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -326,7 +326,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: MoEConfig): super().__init__() - self.fused_experts = fused_experts # noqa + self.fused_experts = fused_experts # type: ignore self.moe = moe self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e95de9382ca..652bf76673c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -460,7 +460,7 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") 
- self.fused_experts = functools.partial( # noqa + self.fused_experts = functools.partial( # type: ignore fused_experts, block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm) From 5b4095bce14b250b82adfda8bdae352479f22611 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 18 May 2025 00:05:04 +0800 Subject: [PATCH 30/31] meaningful comments Signed-off-by: youkaichao --- vllm/model_executor/layers/fused_moe/layer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5543a45720e..823313d4c5d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -296,7 +296,9 @@ def select_gemm_impl( ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation - raise NotImplementedError + raise NotImplementedError( + "Subclass must select appropriate gemm implementation" + " based on the prepare_finalize") @abstractmethod def apply( From cfa027bff7084e40b156bce25f6b6360c1c21b94 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 19 May 2025 13:24:07 +0800 Subject: [PATCH 31/31] fix error Signed-off-by: youkaichao --- .../device_communicators/base_device_communicator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 1c26ec88aa9..52b97094914 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -100,7 +100,10 @@ def __init__(self, from vllm.config import get_current_vllm_config config = get_current_vllm_config() if config is not None: - use_ep = config.parallel_config.enable_expert_parallel + # as long as we use data parallel (coupled data parallel + # where all data parallel ranks execute forward together), + # we initialize the all2all manager used in expert parallel. + use_ep = config.parallel_config.data_parallel_size > 1 self.use_all2all = "ep" in unique_name and use_ep self.all2all_manager: Optional[All2AllManagerBase] = None
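
Note: taken together, patches 17-31 above converge on a single flow: the CUDA communicator chooses its all2all manager in its constructor based on VLLM_ALL2ALL_BACKEND, and the base device communicator's prepare_communication_buffer_for_model walks the model for FusedMoE modules and calls quant_method.init_prepare_finalize(module.moe_config, module.quant_config) on each. The sketch below is a minimal, self-contained illustration of that flow under simplifying assumptions: the Dummy* classes, the default backend value, and the toy model are invented for this example and are not the vLLM classes touched by these patches.

# A minimal sketch of the flow patches 17-31 converge on. Everything named
# Dummy* here is a stand-in invented for the example; only the environment
# variable name (VLLM_ALL2ALL_BACKEND), the FusedMoE class-name match, and the
# init_prepare_finalize(moe_config, quant_config) call shape come from the
# diffs above.
import os

import torch.nn as nn


class DummyAll2AllManager:
    """Stand-in for NaiveAll2AllManager / PPLXAll2AllManager."""

    def __init__(self, cpu_group):
        # In the real managers (after patch 25), rank/world_size are taken
        # from this cpu_group via torch.distributed, because the EP group is
        # still under construction when the manager is built.
        self.cpu_group = cpu_group

    def destroy(self):
        pass


class DummyQuantMethod:
    """Stand-in for a FusedMoEMethodBase subclass."""

    def init_prepare_finalize(self, moe_config, quant_config):
        # The real method reads the all2all manager off the EP group's
        # communicator and builds the prepare/finalize objects from it.
        print("init_prepare_finalize:", moe_config, quant_config)


class FusedMoE(nn.Module):
    """Stand-in carrying only the attributes the module walk relies on."""

    def __init__(self):
        super().__init__()
        self.moe_config = {"num_experts": 8}  # placeholder MoE config
        self.quant_config = None
        self.quant_method = DummyQuantMethod()


class DummyCommunicator:
    """Stand-in for DeviceCommunicatorBase + CudaCommunicator."""

    def __init__(self, cpu_group=None, use_all2all=True):
        # Patch 31: in the real code this gate is
        # '"ep" in unique_name and data_parallel_size > 1'.
        self.use_all2all = use_all2all
        self.all2all_manager = None
        if use_all2all:
            # Patch 21 moved this backend selection into the constructor.
            backend = os.environ.get("VLLM_ALL2ALL_BACKEND", "naive")
            if backend in ("naive", "pplx"):
                # Both branches use the same stand-in here; the real code
                # picks NaiveAll2AllManager or PPLXAll2AllManager.
                self.all2all_manager = DummyAll2AllManager(cpu_group)
            else:
                raise ValueError(f"Unknown all2all backend: {backend}")

    def prepare_communication_buffer_for_model(self, model: nn.Module) -> None:
        # Patches 21/22: this hook now lives in the base class and is a
        # no-op unless expert parallelism is in play.
        if not self.use_all2all:
            return
        moe_modules = [
            m for m in model.modules()
            if m.__class__.__name__ == "FusedMoE"
        ]
        for m in moe_modules:
            # Patch 27: the config objects are passed in explicitly.
            m.quant_method.init_prepare_finalize(m.moe_config, m.quant_config)


model = nn.Sequential(nn.Linear(4, 4), FusedMoE())
comm = DummyCommunicator(use_all2all=True)
comm.prepare_communication_buffer_for_model(model)

Keeping the module walk in the base class means a device-specific communicator only has to construct its all2all manager; the FusedMoE hook itself is shared across backends.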
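
A second sketch, focused on the MoE-layer side: after patches 19, 20, and 27, init_prepare_finalize receives the MoE and quantization configs explicitly, builds the pplx handle only when use_pplx_kernels is set, and leaves the default fused_experts in place otherwise. The names below (MoEConfigSketch, make_pplx_handle, ModularKernelSketch, MethodSketch) are hypothetical stand-ins for illustration, not vLLM APIs.

# Minimal sketch of the non-pplx guard from patches 19/20 plus the signature
# change from patch 27. All classes and helpers here are invented stand-ins.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MoEConfigSketch:
    use_pplx_kernels: bool
    max_num_tokens: int = 1024


def make_pplx_handle(moe: MoEConfigSketch) -> Any:
    # Placeholder for all2all_manager.get_handle(all_to_all_args).
    return object()


class ModularKernelSketch:
    """Placeholder for FusedMoEModularKernel(prepare_finalize, experts)."""

    def __init__(self, prepare_finalize: Any, experts: Any):
        self.prepare_finalize = prepare_finalize
        self.experts = experts


class MethodSketch:
    """Stand-in for FusedMoEMethodBase after patches 19/20/27."""

    def __init__(self):
        # Default path: plain fused_experts, untouched unless pplx is used.
        self.fused_experts = "default fused_experts"

    def select_gemm_impl(self, prepare_finalize: Any) -> Any:
        return "batched experts impl"

    def init_prepare_finalize(self, moe: MoEConfigSketch,
                              quant_config: Optional[Any]) -> None:
        prepare_finalize = None
        if moe.use_pplx_kernels:
            # The all_to_all_args dict and handle are only built on the
            # pplx path, so non-pplx runs never touch pplx_kernels.
            handle = make_pplx_handle(moe)
            prepare_finalize = ("PplxPrepareAndFinalize", handle)
        if prepare_finalize is not None:
            experts = self.select_gemm_impl(prepare_finalize)
            self.fused_experts = ModularKernelSketch(prepare_finalize, experts)


m = MethodSketch()
m.init_prepare_finalize(MoEConfigSketch(use_pplx_kernels=False), None)
assert m.fused_experts == "default fused_experts"  # non-pplx path untouched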