
Refactor pplx init logic to make it modular (prepare for deepep) #18200


Open
wants to merge 32 commits into main

Conversation

@youkaichao (Member) commented May 15, 2025

Follow-up to #15956: refactor the pplx-related logic to make it modular.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 and tpu (Related to Google TPUs) labels May 15, 2025
@youkaichao changed the title from "Refactor pplx init logic" to "Refactor pplx init logic to make it more moduler and prepare for deepep" May 15, 2025
@youkaichao changed the title from "Refactor pplx init logic to make it more moduler and prepare for deepep" to "Refactor pplx init logic to make it modular (prepare for deepep)" May 15, 2025
@youkaichao youkaichao marked this pull request as ready for review May 15, 2025 11:35
@simon-mo simon-mo requested a review from tlrmchlsmth May 15, 2025 16:15
@simon-mo (Collaborator):

@varun-sundar-rabindranath @bnellnm please help review, thx!

@@ -158,7 +158,6 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
                 "currently not supported with CUDA Graphs.")
             vllm_config.model_config.enforce_eager = True
             compilation_config.use_cudagraph = False
-            compilation_config.use_inductor = False
Contributor:

Removing this now will break things if eager mode is not used.

Contributor:

That said, vllm_config.model_config.enforce_eager = True could also be removed; I didn't want to land the PR with that change just in case there were other issues.

Collaborator:

Does vllm_config.model_config.enforce_eager = True cover us?

Actually, do things break if we set --enforce-eager but also make torch.compile happen with the compilation config?

Contributor:

AFAICT, if inductor is on, then it'll break no matter what other options are set. But cudagraphs + the eager backend work just fine.
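
For context, a minimal sketch of the configuration combination being discussed, using the VllmConfig / CompilationConfig field names from the diff above; this is illustrative only, not the platform's actual check_and_update_config logic:

```python
# Illustrative sketch only; the field names come from the diff in this thread,
# the helper function itself and the config access path are assumptions.
def force_eager_for_all2all(vllm_config) -> None:
    compilation_config = vllm_config.compilation_config
    # The all2all kernels are currently not supported with CUDA Graphs,
    # so eager mode is forced and CUDA graph capture is disabled.
    vllm_config.model_config.enforce_eager = True
    compilation_config.use_cudagraph = False
    # Per the discussion: cudagraphs + the eager backend work fine, but
    # leaving inductor enabled breaks regardless of the other options,
    # so this flag still needs to be turned off for now.
    compilation_config.use_inductor = False
```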

@tlrmchlsmth (Collaborator) left a comment:

Let's be sure that removing compilation_config.use_inductor = False doesn't break anything; otherwise LGTM.

@youkaichao youkaichao requested a review from mgoin as a code owner May 16, 2025 02:33
@youkaichao youkaichao marked this pull request as draft May 16, 2025 03:05
@youkaichao (Member Author) commented May 16, 2025

The procedure now:

During distributed environment initialization:

Only for the EP group, and only when expert parallel is enabled, the CUDA communicator creates the all2all manager based on VLLM_ALL2ALL_BACKEND and initializes nvshmem if necessary.

After the model is created:

  1. The model runner calls prepare_communication_buffer_for_model.
  2. The EP group's prepare_communication_buffer_for_model calls init_prepare_finalize on every MoE layer's quant_method.
  3. init_prepare_finalize accepts moe_config and quant_config; it can call get_ep_group().device_communicator.all2all_manager.get_handle to obtain the all2all handle, create the prepare_finalize object, call select_gemm_impl to select the GEMM implementation, and finally assemble the FusedMoEModularKernel (see the sketch below).
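
For illustration, a minimal sketch of how step 3 could look inside a quant_method, using the names mentioned above (get_ep_group, all2all_manager.get_handle, select_gemm_impl, FusedMoEModularKernel); the import paths, the get_handle arguments, and the prepare/finalize helper are assumptions, not the PR's exact code:

```python
# Hedged sketch of init_prepare_finalize; not the PR's exact implementation.
from vllm.distributed import get_ep_group  # import path is an assumption
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)  # import path is an assumption


class ExampleFusedMoEMethod:  # hypothetical quant_method

    def init_prepare_finalize(self, moe_config, quant_config):
        # Fetch the backend-specific all2all handle from the EP group's
        # device communicator (created from VLLM_ALL2ALL_BACKEND at init time).
        all2all_manager = get_ep_group().device_communicator.all2all_manager
        handle = all2all_manager.get_handle(moe_config)  # arguments assumed

        # Wrap the handle in a prepare/finalize object (hypothetical helper).
        prepare_finalize = self._make_prepare_finalize(handle)

        # Let the quant method pick a GEMM implementation that matches it.
        experts = self.select_gemm_impl(prepare_finalize)

        # Assemble the modular kernel used by the MoE layer's forward pass.
        self.fused_experts = FusedMoEModularKernel(prepare_finalize, experts)
```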

Comment on lines 156 to 167
 def dispatch(
         self, hidden_states: torch.Tensor,
         router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    assert self.all2all_impl is not None
-    hidden_states, router_logits = self.all2all_impl.dispatch(
+    assert self.all2all_manager is not None
+    hidden_states, router_logits = self.all2all_manager.dispatch(
         hidden_states, router_logits)
     return hidden_states, router_logits

 def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-    assert self.all2all_impl is not None
-    hidden_states = self.all2all_impl.combine(hidden_states)
+    assert self.all2all_manager is not None
+    hidden_states = self.all2all_manager.combine(hidden_states)
     return hidden_states
Contributor:

How would these methods work if we weren't using the naive manager? e.g. the pplx all2all object might have a different instance for each layer.

Member Author:

Yes, the dispatch/combine in DeviceCommunicatorBase is not used for the pplx kernel, and I agree with your point about the prepare/finalize call inside every layer now. I will try to remove dispatch/combine from DeviceCommunicatorBase.

Member Author:

> I will try to remove dispatch/combine in the DeviceCommunicatorBase

This would be in a future PR; we first need a prepare_finalize for the naive all2all implementation, and then we can remove these functions from DeviceCommunicatorBase.

Comment on lines 294 to 299
def select_gemm_impl(
        self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
) -> FusedMoEPermuteExpertsUnpermute:
    # based on the all2all implementation, select the appropriate
    # gemm implementation
    raise NotImplementedError
Contributor:

I think there should be some sort of error logging here so it's obvious that the combination of pplx + a particular MoE implementation is not supported. Or maybe in init_prepare_finalize?

Member Author:

Why here? This is the base class, and it just shows the interface. Selection of GEMM kernels for pplx is done in UnquantizedFusedMoEMethod.select_gemm_impl.

Contributor:

It doesn't necessarily have to be here, but it would be nice to get more than a NotImplementedError exception.

Member Author:

Added comments in 5b4095b.
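
For illustration, one way the base-class method could raise a more descriptive error than a bare NotImplementedError; this is a sketch that mirrors the snippet above only, not the actual change in 5b4095b, and the message text is an assumption:

```python
def select_gemm_impl(
        self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
) -> FusedMoEPermuteExpertsUnpermute:
    # Subclasses that support modular all2all kernels (e.g. pplx) override
    # this to pick a GEMM implementation matching the prepare/finalize object.
    raise NotImplementedError(
        f"{self.__class__.__name__} does not implement select_gemm_impl, so "
        "it cannot be combined with the selected all2all backend.")
```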


mergify bot commented May 17, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @youkaichao.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 17, 2025
@mergify mergify bot removed the needs-rebase label May 17, 2025
@youkaichao youkaichao marked this pull request as ready for review May 18, 2025 17:39
@youkaichao youkaichao enabled auto-merge (squash) May 19, 2025 03:05
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label May 19, 2025