diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index b69647b0058..a250ec89cd5 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,44 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 +import importlib.util +from typing import TYPE_CHECKING + import torch +import torch.distributed as dist from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from .base_device_communicator import All2AllManagerBase, Cache -class All2AllBase: - - def __init__(self, cpu_group, model): - self.cpu_group = cpu_group - - # compute some common properties - from vllm.distributed.parallel_state import (get_dp_group, - get_ep_group, - get_tp_group, - in_the_same_node_as) - - # all2all lives in ep group, which is merged from dp and tp group - self.dp_group = get_dp_group() - self.tp_group = get_tp_group() - self.ep_group = get_ep_group() - self.dp_rank = self.dp_group.rank_in_group - self.dp_world_size = self.dp_group.world_size - - # all2all communication often has separate implementations for - # intra-node and inter-node communication - self.intranode = in_the_same_node_as(cpu_group, source_rank=0) - self.internode = not self.intranode - - def dispatch(self, hidden_states: torch.Tensor, - router_logits: torch.Tensor): - raise NotImplementedError - - def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - raise NotImplementedError +logger = init_logger(__name__) - def destroy(self): - pass +if TYPE_CHECKING: + from vllm.model_executor.layers.fused_moe.layer import FusedMoE +else: + FusedMoE = None -class NaiveAll2All(All2AllBase): +class NaiveAll2AllManager(All2AllManagerBase): """ A naive implementation of all2all communication. It uses all-reduce under the hood, which is not @@ -46,8 +26,8 @@ class NaiveAll2All(All2AllBase): debugging. """ - def __init__(self, cpu_group, model): - super().__init__(cpu_group, model) + def __init__(self, cpu_group): + super().__init__(cpu_group) def naive_multicast(self, x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor): @@ -91,3 +71,56 @@ def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def destroy(self): pass + + +class PPLXAll2AllManager(All2AllManagerBase): + """ + All2All communication based on PPLX kernels. + """ + + def __init__(self, cpu_group): + has_pplx = importlib.util.find_spec("pplx_kernels") is not None + assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." 
# noqa + super().__init__(cpu_group) + + if self.internode: + # inter-node communication needs nvshmem, + # intra-node communication uses p2p mapping directly + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_get_unique_id, + nvshmem_init) + logger.debug( + "Initialize NVSHMEM for pplx_kernels: " + "rank=%d, world size=%d", self.rank, self.world_size) + uid = nvshmem_get_unique_id( + ) if self.rank == 0 else nvshmem_alloc_empty_unique_id() + dist.broadcast(uid, + src=dist.get_process_group_ranks(self.cpu_group)[0], + group=self.cpu_group) + logger.debug("PPLX NVSHMEM UID = %s", uid) + nvshmem_init(uid, self.rank, self.world_size) + + self.handle_cache = Cache() + + def get_handle(self, kwargs): + import pplx_kernels as pplx + return self.handle_cache.get_or_create( + kwargs, pplx.AllToAll.internode + if self.internode else pplx.AllToAll.intranode) + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() + + if self.internode: + from pplx_kernels.nvshmem import nvshmem_finalize + logger.debug("PPLX NVSHMEM finalize") + nvshmem_finalize() diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index ead79872bd4..52b97094914 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,11 +1,76 @@ # SPDX-License-Identifier: Apache-2.0 +import threading from typing import Optional +from weakref import WeakValueDictionary import torch import torch.distributed as dist from torch.distributed import ProcessGroup +class Cache: + + def __init__(self): + self._cache: WeakValueDictionary = WeakValueDictionary() + self._lock = threading.RLock() # Reentrant lock for thread safety + + def get_or_create(self, kwargs, func): + # Create a hashable key from the kwargs + key = tuple(sorted((k, v) for k, v in kwargs.items())) + + with self._lock: + instance = self._cache.get(key) + if instance is None: + instance = func(**kwargs) + self._cache[key] = instance + return instance + + +class All2AllManagerBase: + + def __init__(self, cpu_group): + self.cpu_group = cpu_group + + # compute some common properties + from vllm.distributed.parallel_state import (get_dp_group, + get_tp_group, + in_the_same_node_as) + + # all2all lives in ep group, which is merged from dp and tp group + self.dp_group = get_dp_group() + self.tp_group = get_tp_group() + # no self.ep_group since self.ep_group is still in construction + # when we create this object + self.dp_rank = self.dp_group.rank_in_group + self.dp_world_size = self.dp_group.world_size + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + + # all2all communication often has separate implementations for + # intra-node and inter-node communication + self.intranode = in_the_same_node_as(cpu_group, source_rank=0) + self.internode = not self.intranode + + def get_handle(self, kwargs): + # get a handle for the all2all communication, + # based on the kwargs. + # different layers can have different configs, + # e.g. one layer has hidden size 1024, another has 2048. + # usually the underlying implementation caches the handle + # and reuse it for the same config. 
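
A minimal standalone sketch of how the kwargs-keyed Cache added above deduplicates all2all handles. DummyHandle is a hypothetical stand-in for pplx.AllToAll; in the real get_handle the factory passed in is pplx.AllToAll.internode or pplx.AllToAll.intranode.

    import threading
    from weakref import WeakValueDictionary


    class Cache:
        # same pattern as the Cache class added in base_device_communicator.py
        def __init__(self):
            self._cache: WeakValueDictionary = WeakValueDictionary()
            self._lock = threading.RLock()

        def get_or_create(self, kwargs, func):
            key = tuple(sorted((k, v) for k, v in kwargs.items()))
            with self._lock:
                instance = self._cache.get(key)
                if instance is None:
                    instance = func(**kwargs)
                    self._cache[key] = instance
                return instance


    class DummyHandle:
        # hypothetical stand-in for pplx.AllToAll
        def __init__(self, hidden_dim, num_experts):
            self.hidden_dim = hidden_dim
            self.num_experts = num_experts


    cache = Cache()
    h1 = cache.get_or_create(dict(hidden_dim=1024, num_experts=8), DummyHandle)
    h2 = cache.get_or_create(dict(num_experts=8, hidden_dim=1024), DummyHandle)
    h3 = cache.get_or_create(dict(hidden_dim=2048, num_experts=8), DummyHandle)
    assert h1 is h2      # same config -> same handle (key order does not matter)
    assert h1 is not h3  # different config -> a new handle is created

Because the cache holds weak references, a handle is reclaimed once no layer keeps a strong reference to it, which is why destroy() iterates the live entries under the lock.
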
+ raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + class DeviceCommunicatorBase: """ Base class for device-specific communicator. @@ -31,6 +96,18 @@ def __init__(self, self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + use_ep = False + from vllm.config import get_current_vllm_config + config = get_current_vllm_config() + if config is not None: + # as long as we use data parallel (coupled data parallel + # where all data parallel ranks execute forward together), + # we initialize the all2all manager used in expert parallel. + use_ep = config.parallel_config.data_parallel_size > 1 + + self.use_all2all = "ep" in unique_name and use_ep + self.all2all_manager: Optional[All2AllManagerBase] = None + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ @@ -154,9 +231,17 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None: """ Prepare the communication buffer for the model. - This is a no-op in the base class. """ - pass + if not self.use_all2all: + return + + moe_modules = [ + module for module in model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + for module in moe_modules: + module.quant_method.init_prepare_finalize(module.moe_config, + module.quant_config) def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 13303f94b8e..a05a13f51d4 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -6,10 +6,12 @@ from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm.logger import init_logger -from .all2all import All2AllBase from .base_device_communicator import DeviceCommunicatorBase +logger = init_logger(__name__) + class CudaCommunicator(DeviceCommunicatorBase): @@ -31,8 +33,6 @@ def __init__(self, use_pynccl = "ep" not in unique_name self.use_pynccl = use_pynccl - self.use_all2all = "ep" in unique_name - self.all2all_impl: Optional[All2AllBase] = None self.use_custom_allreduce = use_custom_allreduce # lazy import to avoid documentation build error @@ -56,6 +56,19 @@ def __init__(self, device=self.device, ) + if self.use_all2all: + all2all_backend = envs.VLLM_ALL2ALL_BACKEND + if all2all_backend == "naive": + from .all2all import NaiveAll2AllManager + self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + logger.info("Using naive all2all manager.") + elif all2all_backend == "pplx": + from .all2all import PPLXAll2AllManager + self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + logger.info("Using PPLX all2all manager.") + else: + raise ValueError(f"Unknown all2all backend: {all2all_backend}") + def all_reduce(self, input_): # always try custom allreduce first, # and then pynccl. @@ -136,31 +149,19 @@ def destroy(self): self.pynccl_comm = None if self.ca_comm is not None: self.ca_comm = None - if self.all2all_impl is not None: - self.all2all_impl.destroy() - self.all2all_impl = None - - def prepare_communication_buffer_for_model(self, - model: torch.nn.Module) -> None: - """ - Prepare the communication buffer for the model. 
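
The new prepare_communication_buffer_for_model hook above finds MoE layers by class name rather than with isinstance, which avoids importing FusedMoE at module load time. A small sketch of that matching pattern with a toy model; the FusedMoE class here is a hypothetical stand-in that only shares the name with the real layer.

    import torch.nn as nn


    class FusedMoE(nn.Module):
        # hypothetical stand-in; only the class name matters for the lookup
        def forward(self, x):
            return x


    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(8, 8), FusedMoE(), FusedMoE()])

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x


    model = ToyModel()
    moe_modules = [
        m for m in model.modules() if m.__class__.__name__ == "FusedMoE"
    ]
    assert len(moe_modules) == 2
    # the communicator would then call, per module:
    #   m.quant_method.init_prepare_finalize(m.moe_config, m.quant_config)
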
- """ - if not self.use_all2all: - return - all2all_backend = envs.VLLM_ALL2ALL_BACKEND - if all2all_backend == "naive": - from .all2all import NaiveAll2All - self.all2all_impl = NaiveAll2All(self.cpu_group, model) + if self.all2all_manager is not None: + self.all2all_manager.destroy() + self.all2all_manager = None def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - assert self.all2all_impl is not None - hidden_states, router_logits = self.all2all_impl.dispatch( + assert self.all2all_manager is not None + hidden_states, router_logits = self.all2all_manager.dispatch( hidden_states, router_logits) return hidden_states, router_logits def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - assert self.all2all_impl is not None - hidden_states = self.all2all_impl.combine(hidden_states) + assert self.all2all_manager is not None + hidden_states = self.all2all_manager.combine(hidden_states) return hidden_states diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 51c519d8f86..e1c4e2e26ba 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -23,7 +23,6 @@ """ import contextlib import gc -import importlib.util import pickle import weakref from collections import namedtuple @@ -43,7 +42,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, - run_once, supports_custom_op) + supports_custom_op) @dataclass @@ -769,10 +768,14 @@ def dispatch( if self.device_communicator is not None: return self.device_communicator.dispatch(hidden_states, router_logits) + else: + return hidden_states, router_logits def combine(self, hidden_states) -> torch.Tensor: if self.device_communicator is not None: return self.device_communicator.combine(hidden_states) + else: + return hidden_states _WORLD: Optional[GroupCoordinator] = None @@ -937,49 +940,9 @@ def init_distributed_environment( "world group already initialized with a different world size") -PPLX_DID_INIT: bool = False - - -@run_once -def pplx_init(rank, world_size): - has_pplx = importlib.util.find_spec("pplx_kernels") is not None - - if has_pplx and world_size > 1: - from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, nvshmem_init) - try: - global PPLX_DID_INIT - logger.debug( - "Initialize NVSHMEM for PPLX kernels: rank=%d, " - "world size=%d", rank, world_size) - uid = nvshmem_get_unique_id( - ) if rank == 0 else nvshmem_alloc_empty_unique_id() - uid_gpu = uid.cuda() - get_world_group().broadcast(uid_gpu, src=0) - uid = uid_gpu.to(device='cpu') - logger.debug("PPLX NVSHMEM UID = %s", uid) - nvshmem_init(uid, rank, world_size) - PPLX_DID_INIT = True - except Exception as ex: - logger.error("Failed to initialize NVSHMEM for PPLX: %s", ex) - - -@run_once -def pplx_finalize(): - global PPLX_DID_INIT - if PPLX_DID_INIT: - from pplx_kernels.nvshmem import nvshmem_finalize - logger.debug("PPLX NVSHMEM finalize") - from vllm.model_executor.layers.fused_moe.layer import ( - _all_to_all_cache) - _all_to_all_cache.destroy() - nvshmem_finalize() - - def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, - enable_expert_parallel: bool = False, backend: Optional[str] = None, ) -> None: """ @@ -1082,14 +1045,10 @@ def initialize_model_parallel( _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group, _EP.rank_in_group) - 
if enable_expert_parallel: - pplx_init(rank, world_size) - def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, - enable_expert_parallel: bool = False, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1100,8 +1059,7 @@ def ensure_model_parallel_initialized( get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, - pipeline_model_parallel_size, - enable_expert_parallel, backend) + pipeline_model_parallel_size, backend) return assert ( @@ -1180,8 +1138,6 @@ def destroy_model_parallel(): """Set the groups to none and destroy them.""" global _TP - pplx_finalize() - if _TP: _TP.destroy() _TP = None diff --git a/vllm/envs.py b/vllm/envs.py index dc23c8ea531..7828473c5e1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -809,6 +809,9 @@ def get_vllm_port() -> Optional[int]: lambda: int(os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5557")), # all2all backend for vllm's expert parallel communication + # Available options: + # - "naive": naive all2all implementation using all-reduce + # - "pplx": use pplx kernels "VLLM_ALL2ALL_BACKEND": lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), } diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f1cb77f64ea..14bf37fffe3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import importlib -import threading from abc import abstractmethod from dataclasses import dataclass from enum import Enum from typing import Callable, Optional -from weakref import WeakValueDictionary import torch import torch.nn.functional as F @@ -74,7 +72,8 @@ class FusedMoEParallelConfig: @property def use_pplx_kernels(self): - return self.dp_size > 1 and self.use_ep and has_pplx + return self.dp_size > 1 and self.use_ep and \ + envs.VLLM_ALL2ALL_BACKEND == "pplx" @staticmethod def make(tp_size_: int, dp_size_: int, @@ -197,6 +196,8 @@ class MoEConfig: # TODO: add more quantization params, blocked, per-token, etc. 
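
use_pplx_kernels now keys off the VLLM_ALL2ALL_BACKEND environment variable instead of probing for a pplx import; the pplx path is taken only when data parallelism and expert parallelism are both enabled and the backend is explicitly "pplx". A small standalone sketch of that predicate (the env read is inlined here instead of going through vllm.envs):

    import os


    def use_pplx_kernels(dp_size: int, use_ep: bool) -> bool:
        # mirrors FusedMoEParallelConfig.use_pplx_kernels from the diff
        backend = os.getenv("VLLM_ALL2ALL_BACKEND", "naive")
        return dp_size > 1 and use_ep and backend == "pplx"


    os.environ["VLLM_ALL2ALL_BACKEND"] = "naive"
    assert not use_pplx_kernels(dp_size=2, use_ep=True)

    os.environ["VLLM_ALL2ALL_BACKEND"] = "pplx"
    assert use_pplx_kernels(dp_size=2, use_ep=True)
    assert not use_pplx_kernels(dp_size=1, use_ep=True)  # no DP -> naive path
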
block_size: int = 128 + max_num_tokens: int = MOE_DP_CHUNK_SIZE + @property def tp_size(self): return self.moe_parallel_config.tp_size @@ -245,13 +246,59 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: FusedMoEPrepareAndFinalize, - ) -> bool: - return False + def init_prepare_finalize(self, moe: MoEConfig, + quant_config: Optional[QuantizationConfig]): + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + prepare_finalize = None + if moe.use_pplx_kernels: + all_to_all_args = dict( + max_num_tokens=moe.max_num_tokens, + num_experts=moe.num_experts, + experts_per_token=moe.experts_per_token, # topk + rank=all2all_manager.rank, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + hidden_dim=moe.hidden_dim, + hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( + (moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize)), + group_name=all2all_manager.cpu_group.group_name, + ) + + handle = all2all_manager.get_handle(all_to_all_args) + + prepare_finalize = PplxPrepareAndFinalize( + handle, + max_num_tokens=moe.max_num_tokens, + world_size=all2all_manager.world_size, + rank=all2all_manager.rank, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + quant_dtype=moe.in_dtype, + ) + + if prepare_finalize is not None: + experts = self.select_gemm_impl(prepare_finalize) + self.fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + def select_gemm_impl( + self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] + ) -> FusedMoEPermuteExpertsUnpermute: + # based on the all2all implementation, select the appropriate + # gemm implementation + raise NotImplementedError( + "Subclass must select appropriate gemm implementation" + " based on the prepare_finalize") @abstractmethod def apply( @@ -275,53 +322,13 @@ def apply( raise NotImplementedError -class AllToAllCache: - - def __init__(self): - self._cache: WeakValueDictionary = WeakValueDictionary() - self._lock = threading.RLock() # Reentrant lock for thread safety - - def destroy(self): - with self._lock: - # TODO: can we do del self._cache? - for _, a2a in self._cache.items(): - a2a.destroy() - - def get_or_create(self, **kwargs): - assert has_pplx - import pplx_kernels as pplx - - # Create a hashable key from the kwargs - key = tuple(sorted((k, v) for k, v in kwargs.items())) - - with self._lock: - instance = self._cache.get(key) - if instance is None: - # TODO (varun): Add support to switch to intranode - # when all communications are within the same - # node. 
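
The hidden_dim_scale_bytes argument passed to the pplx AllToAll above only matters for 1-byte (fp8-style) activations, where one float32 scale is shipped per block of the hidden dimension. A worked sketch of the arithmetic, assuming the default block_size of 128 and a PyTorch build with fp8 dtypes:

    import torch


    def hidden_dim_scale_bytes(in_dtype: torch.dtype, hidden_dim: int,
                               block_size: int = 128) -> int:
        # non-1-byte dtypes (fp16/bf16) carry no per-block scales
        if in_dtype.itemsize != 1:
            return 0
        # blocked per-token quantization: one float32 scale per block
        num_blocks = (hidden_dim + block_size - 1) // block_size  # ceil_div
        return num_blocks * torch.float32.itemsize


    assert hidden_dim_scale_bytes(torch.bfloat16, 4096) == 0
    # fp8: ceil(4096 / 128) = 32 blocks, 32 * 4 bytes = 128 bytes
    assert hidden_dim_scale_bytes(torch.float8_e4m3fn, 4096) == 128
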
- logger.debug("Create AllToAll %s", kwargs) - instance = pplx.AllToAll.internode(**kwargs) - self._cache[key] = instance - return instance - - -# Global singleton -_all_to_all_cache = AllToAllCache() - - -# Factory function as a cleaner interface -def get_all_to_all(**kwargs): - return _all_to_all_cache.get_or_create(**kwargs) - - @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" def __init__(self, moe: MoEConfig): super().__init__() - self.fused_experts = fused_experts + self.fused_experts = fused_experts # type: ignore self.moe = moe self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() @@ -331,6 +338,42 @@ def __init__(self, moe: MoEConfig): else: self.rocm_aiter_fused_experts = None # type: ignore + def select_gemm_impl( + self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]): + + assert self.fused_experts == fused_experts + + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + + experts: Optional[FusedMoEPermuteExpertsUnpermute] = None + + if isinstance(prepare_finalize, + (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + logger.debug("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( + max_num_tokens=MOE_DP_CHUNK_SIZE, + world_size=all2all_manager.world_size, + # dp_size actually means tp_size, bug in pplx kernels + dp_size=all2all_manager.tp_group.world_size, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + logger.debug("TritonExperts %s", self.moe) + experts = TritonExperts( + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + per_channel_quant=False, + ) + return experts + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -430,47 +473,6 @@ def apply( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: FusedMoEPrepareAndFinalize, - ) -> bool: - assert self.fused_experts == fused_experts - - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - - if isinstance(prepare_finalize, - (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): - logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedTritonExperts( - max_num_tokens=MOE_DP_CHUNK_SIZE, - world_size=world_size, - dp_size=dp_size, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - ) - else: - logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) - - self.fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) - - return True - def forward_cuda( self, layer: torch.nn.Module, @@ -680,45 +682,6 @@ def determine_expert_map( return (local_num_experts, expert_map) -def _construct_prepare_finalize( - moe: MoEConfig, quant_config: Optional[QuantizationConfig] -) -> Optional[FusedMoEPrepareAndFinalize]: - max_num_tokens = MOE_DP_CHUNK_SIZE - world_size = moe.ep_size - dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP. 
- rank = moe.ep_rank - - if moe.use_pplx_kernels: - logger.debug("using PplxPrepareAndFinalize") - - all_to_all = get_all_to_all( - max_num_tokens=max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=rank, - world_size=world_size, - dp_size=dp_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else - ((moe.hidden_dim + moe.block_size - 1) // - moe.block_size * torch.float32.itemsize))) - - return PplxPrepareAndFinalize( - all_to_all, - max_num_tokens=max_num_tokens, - world_size=world_size, - rank=rank, - dp_size=dp_size, - quant_dtype=moe.in_dtype, - ) - - return None - - class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. @@ -832,7 +795,10 @@ def __init__( moe_parallel_config=self.moe_parallel_config, # TODO (bnell): this needs to be fixed for quantized types. in_dtype=params_dtype, + max_num_tokens=MOE_DP_CHUNK_SIZE, ) + self.moe_config = moe + self.quant_config = quant_config # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. @@ -840,25 +806,13 @@ def __init__( if quant_config is None: quant_method = UnquantizedFusedMoEMethod(moe) - prepare_finalize = _construct_prepare_finalize(moe, quant_config) else: quant_method = quant_config.get_quant_method(self, prefix) - # No pplx for quantized types yet. - prepare_finalize = None assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) self.quant_method = quant_method - if prepare_finalize is not None: - world_size = moe.ep_size - dp_size = int(moe.ep_size // moe.dp_size) - success = self.quant_method.set_prepare_finalize( - dp_size, world_size, prepare_finalize) - if not success: - logger.warning("DP+EP not supported for %s.", - type(self.quant_method)) - moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index b1126b94e45..783ebebbfec 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -9,7 +9,6 @@ moe_kernel_quantize_input) -# Note use: layer.get_all_to_all() to get an AllToAll instance # The max_num_tokens, world_size and dp_size must be the same # as the ones used to create the AllToAll. 
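
With this change the pplx handle is no longer built inside FusedMoE.__init__; the layer only records moe_config and quant_config, and construction is deferred until the EP communicator exists and prepare_communication_buffer_for_model fires. A structural sketch of the resulting call order, with hypothetical stubs standing in for the vLLM classes:

    class StubQuantMethod:
        # stands in for UnquantizedFusedMoEMethod / Fp8MoEMethod
        def __init__(self):
            self.fused_experts = "default_fused_experts"

        def init_prepare_finalize(self, moe_config, quant_config):
            # in vLLM this asks the all2all manager for a handle and wraps
            # select_gemm_impl(...) in a FusedMoEModularKernel
            self.fused_experts = ("modular_kernel", moe_config, quant_config)


    class StubFusedMoE:
        def __init__(self, moe_config, quant_config):
            # __init__ only records the configs now; no all2all handle yet
            self.moe_config = moe_config
            self.quant_config = quant_config
            self.quant_method = StubQuantMethod()


    class StubCommunicator:
        def prepare_communication_buffer_for_model(self, moe_layers):
            # called once the EP group (and its all2all manager) is ready
            for layer in moe_layers:
                layer.quant_method.init_prepare_finalize(layer.moe_config,
                                                         layer.quant_config)


    layers = [StubFusedMoE({"hidden_dim": 1024}, None) for _ in range(2)]
    StubCommunicator().prepare_communication_buffer_for_model(layers)
    assert all(layer.quant_method.fused_experts[0] == "modular_kernel"
               for layer in layers)
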
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f4cdc3db1a0..652bf76673c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -10,7 +10,6 @@ from torch.nn.parameter import Parameter import vllm.envs as envs -import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -461,7 +460,7 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once( "DeepGemm not supported on the current platform.") - self.fused_experts = functools.partial( + self.fused_experts = functools.partial( # type: ignore fused_experts, block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm) @@ -791,17 +790,12 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w13_input_scale del layer.w2_input_scale - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - ) -> bool: + def select_gemm_impl(self, prepare_finalize): from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) - if self.use_marlin or self.rocm_aiter_moe_enabled: - return False + assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( + "Marlin and ROCm AITER are not supported with all2all yet.") experts = TritonOrDeepGemmExperts( use_fp8_w8a8=True, @@ -809,12 +803,7 @@ def set_prepare_finalize( allow_deep_gemm=self.allow_deep_gemm, ) - self.fused_experts = mk.FusedMoEModularKernel( - prepare_finalize, - experts, - ) - - return True + return experts def apply( self, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bdee8b2f821..3b0e77f6e47 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -158,6 +158,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "currently not supported with CUDA Graphs.") vllm_config.model_config.enforce_eager = True compilation_config.use_cudagraph = False + # FIXME: inductor breaks cudagraph (from @bnell) compilation_config.use_inductor = False @classmethod diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 93129d98794..2746eabec0c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -341,8 +341,7 @@ def init_worker_distributed_environment( distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index ae3735ab025..fa4eb30ccd9 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -265,8 +265,7 @@ def init_tpu_worker_distributed_environment( backend="gloo", ) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) try: diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index a92cf1e5a3b..1436a404335 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -390,8 +390,7 @@ def init_distributed_environment(self) -> None: 
ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) def get_cache_block_size_bytes(self) -> int: """Return the size in bytes of a single KV cache block. diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 42882992f2d..7898c645d66 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -416,8 +416,7 @@ def init_worker_distributed_environment( backend='hccl') ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) if torch.distributed.is_initialized(): torch_world_size = torch.distributed.get_world_size() @@ -443,8 +442,7 @@ def init_worker_distributed_environment( torch.distributed.all_reduce(dummy_tensor_hpu) assert dummy_tensor_hpu.item() == parallel_config.world_size ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 891ed66599d..4bb9bea022f 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -76,8 +76,7 @@ def init_device(self) -> None: ) ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.enable_expert_parallel) + self.parallel_config.pipeline_parallel_size) # Device initialization should happen after initializing the distributed # runtime. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 41546462e5c..17f636765ff 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -530,8 +530,7 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) ensure_kv_transfer_initialized(vllm_config) diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index 65085f80f97..17f53352517 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -176,8 +176,7 @@ def init_worker_distributed_environment(self) -> None: ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.enable_expert_parallel) + parallel_config.pipeline_parallel_size) # global all_reduce needed for overall oneccl warm up torch.distributed.all_reduce(torch.zeros(1).xpu())
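
Since expert parallelism no longer threads enable_expert_parallel through ensure_model_parallel_initialized, the all2all backend is chosen purely by environment at startup. A hedged usage sketch; the model name and engine arguments are illustrative assumptions (enable_expert_parallel and data_parallel_size are taken to be the existing vLLM engine arguments and are not changed by this diff):

    import os

    # choose the all2all backend before the engine starts its workers;
    # "naive" (all-reduce based) is the default, "pplx" requires pplx_kernels
    # installed per tools/ep_kernels/README.md
    os.environ["VLLM_ALL2ALL_BACKEND"] = "pplx"

    from vllm import LLM  # noqa: E402

    # illustrative only: DP > 1 together with expert parallelism is what
    # routes MoE dispatch/combine through the all2all manager
    llm = LLM(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        tensor_parallel_size=2,
        data_parallel_size=2,
        enable_expert_parallel=True,
    )
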