
Add NeuronxDistributedInference support, Speculative Decoding, Dynamic on-device sampling #16357


Merged
merged 44 commits into vllm-project:main from upstream-neuron-vllm-04-08 on May 7, 2025

Conversation

aws-satyajith
Contributor

@aws-satyajith aws-satyajith commented Apr 9, 2025

Add the following Neuron features as a part of RFC #15970 :

  1. NeuronX Distributed (NxD) Inference Support

    1. Allow customers to select a framework based on preference or availability. Default to neuronx-distributed-inference (NxD); if it is unavailable, fall back to transformers-neuronx (TNx).
    2. Support inference with NxD by adding worker/neuronx_distributed_model_runner.py.
    3. Add a framework-detection utility that returns the framework currently in use (a sketch of this detection logic appears after this list).
  2. Speculative Decoding

    1. To enable speculative decoding with NxD, we added worker/multi_step_neuronx_distributed_model_runner.py.
    2. To enable speculative decoding with TNx, we added worker/multi_step_neuron_model_runner.py. This model runner is selected in neuron_worker.py when speculation is enabled.
  3. Dynamic On-device Sampling

    1. Extract the sampling params (top_k, top_p, temperature) and pass them to execute_model().
  4. Multi-node Tensor Parallelism Inference

    1. The communication between master and worker nodes happens at two layers,
      the control plane layer for metadata communication (i.e., the input prompts) from master node to work nodes. Specifically, enable do_metadata_broadcast, while supplying conversion methods from ModelInputForNeuron to broadcast-able dictionary and vice versa in neuron_model_runner.py.
      the Neuron backend (i.e., NxD, TNx) is doing all the collectives operations in model forward.
      Examples of usage can be found in [examples/neuron/multi_node](https://github.com/aws-neuron/upstreaming-to-
    2. vllm/tree/neuron-2.22-vllm-v0.7.2/examples/neuron/multi_node)
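
As an illustration of the framework selection described in item 1, here is a minimal sketch of the detection logic. The get_neuron_framework_to_use name comes from this PR's worker utilities, but the NeuronFramework enum, the probing via importlib, and the module names assumed importable (neuronx_distributed_inference, transformers_neuronx) are illustrative assumptions rather than the merged implementation:

import importlib.util
from enum import Enum


class NeuronFramework(Enum):
    NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference"
    TRANSFORMERS_NEURONX = "transformers-neuronx"


def get_neuron_framework_to_use() -> NeuronFramework:
    """Prefer NxD when it is installed; otherwise fall back to TNx."""
    if importlib.util.find_spec("neuronx_distributed_inference") is not None:
        return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
    if importlib.util.find_spec("transformers_neuronx") is not None:
        return NeuronFramework.TRANSFORMERS_NEURONX
    raise RuntimeError(
        "Neither neuronx-distributed-inference nor transformers-neuronx is "
        "installed; one of them is required for the Neuron backend.")

neuron_worker.py can then branch on the returned value to pick the matching model runner.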

RFC: #15970

Note: This is a fixed version of #16043. The following issues have been fixed in this revision:

  1. Eliminated unnecessary commits.
  2. Signed off each commit.
  3. Resolved merge conflicts that arose in the past few days.

Aaron Dou and others added 30 commits April 9, 2025 16:49

github-actions bot commented Apr 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation and ci/build labels Apr 9, 2025
@aws-satyajith
Contributor Author

@liangfu @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill

Could you please review this PR? There is an RFC related to this as well for additional context and discussion: #15970

@robertgshaw2-redhat robertgshaw2-redhat self-assigned this Apr 18, 2025
Contributor

@liangfu liangfu left a comment

Thank you @aws-satyajith for contributing. I left a few comments in the PR.

These comments specifically concern:

  • Isolate the num_lookahead_slots-related change into a separate PR.
  • Move the get_neuron_framework_to_use function from vllm/worker/utils.py to vllm/platforms/neuron.py.
  • Isolate the multi-node example into a separate PR, since its behavior does not seem to be consistent with other hardware backends.
  • To avoid environment-variable type errors, it would be better to define environment-variable types in vllm/envs.py with a VLLM_NEURON_ prefix.

In addition, I feel there are a number of features/components bundled in this PR.
I propose breaking the bundled PR down into a few individual components/features:

1/ Introduce neuronx-distributed-inference as a dependency for the Neuron backend and replace the existing transformers-neuronx-based implementation (for simplicity), with a basic test to ensure the integration does not break in the future.
2/ Add on-device sampling support, with a test script.
3/ Add speculative decoding support, with a test script.

If we do not remove the transformers-neuronx-based implementation, there would be:
a/ four model_runner scripts in the worker directory, and
b/ two packages (transformers-neuronx and neuronx-distributed-inference) whose behavior may or may not be consistent across different features and configurations.


# Use mpirun to trigger inference on head/worker nodes

/opt/amazon/openmpi/bin/mpirun \
Contributor

similar to #8692

I would propose staying consistent with the ecosystem and leveraging the interfaces described in https://docs.vllm.ai/en/latest/serving/distributed_serving.html

Contributor Author

Thanks for linking the PR and docs. I checked them out and read the associated comments. In summary:

  1. You would propose that we should be able to perform multi-node inference using
vllm serve /path/to/the/model/in/the/container \
--tensor-parallel-size 128
  2. We should document the -x fields below and explain the reason for the setup.

Both of these require reworking the feature significantly. The previous PR was closed because the above pending items were not addressed. Is this a correct understanding of the situation?

Contributor

yes, that's a good understanding.

Contributor Author

I discussed this with @mrinalks and we decided not to include multi-node support in this PR, as it needs to be heavily reworked. I'll remove all the multi-node-specific code in the next revision.


# Create an LLM.
llm = LLM(
    model=TARGET_MODEL_PATH,
Contributor

Stay consistent with the other offline scripts?

e.g.

def main():
    # ... some details ...

    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        ...
    )

Contributor Author

Just confirming: do you mean we should change the model to TinyLlama, or that we should remove the TARGET_MODEL_PATH constant and use the string inline?

Contributor

It's the latter, since demonstrating EAGLE isn't going to be feasible with TinyLlama.
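
For reference, a minimal offline-inference sketch in the inlined style suggested above; the model name, tensor-parallel size, and sampling values are placeholders rather than the example shipped with this PR:

from vllm import LLM, SamplingParams


def main():
    # Model string is inlined, matching the other offline examples; the
    # specific model and settings below are placeholders only.
    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        tensor_parallel_size=32,
        max_num_seqs=4,
    )
    sampling_params = SamplingParams(top_k=10, top_p=0.95, temperature=0.8)
    outputs = llm.generate(["The president of the United States is"],
                           sampling_params)
    for output in outputs:
        print(output.outputs[0].text)


if __name__ == "__main__":
    main()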

from vllm.transformers_utils.configs.eagle import (
    EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
-             EAGLEConfig):
+             EAGLEConfig) or current_platform.is_neuron():
Contributor

Can we eliminate the device-specific changes in vllm/config.py?

Contributor

+1, we should move this Neuron config to the override.

Contributor Author

@mrinalks We currently use self.draft_model_config.hf_config in multiple places. Moving this to override_neuron_config would mean deviating from the existing flow and would require comprehensive re-testing to ensure we don't miss any parameters.

@liangfu EAGLEConfig was not present when we implemented EAGLE support on Neuron, hence the exception. I'll take a look at whether we can remove this exception for Neuron.

Contributor Author

This is a valid change that we'll implement and test internally first. We will address supporting EAGLEConfig in a follow-up commit. I'm adding a comment on RFC #15970 to keep track of this change.

Comment on lines 393 to 396
if self.device_config.device_type == "neuron":
    num_lookahead_slots = self.scheduler_config.num_lookahead_slots
else:
    num_lookahead_slots = 0
Contributor

Can we eliminate the device-specific changes in vllm/engine/llm_engine.py?

Contributor Author

We can probably move this into override_neuron_config. I'll look into making that change.

Contributor Author

Thank you for bringing this up @liangfu. I looked into this a little deeper, and it has potential implications for non-Neuron workflows. Some notes:

num_lookahead_slots is part of self.scheduler_config, and we're setting it to 0 for all non-Neuron cases irrespective of the actual value.
This is problematic for two reasons:

  1. If someone uses num_lookahead_slots in the StopChecker for a non-Neuron workflow in the future, they will not see the value from scheduler_config.num_lookahead_slots; instead they'll see 0.
  2. We want to avoid hardware-dependent exceptions in StopChecker, as you pointed out in the parent comment.

I'll check further and see if we can remove this dependency altogether or come up with a non-impacting solution.

Contributor Author

Discussed this further with @elaineyz and identified steps forward to address both of the above points. I'll need some additional time, but I'll include these changes in the next revision.

Contributor Author

Agreed with Liangfu to remove num_lookahead_slots from llm_engine and stop_checker. This will impact customers using ignore_eos = True; we will address that issue separately as part of RFC #15970.

Leaving a comment on the RFC as well for tracking.

Comment on lines 52 to 53
self.rank = int(os.getenv("NEURON_RANK_ID",
                          DEFAULT_NEURON_RANK_ID))
Contributor

If we would like to keep these environment variables, I think it's better to be consistent with environment_variables in vllm/envs.py,

similar to

    VLLM_ROCM_USE_AITER: bool = False
    VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
    VLLM_ROCM_USE_AITER_LINEAR: bool = True
    VLLM_ROCM_USE_AITER_MOE: bool = True
    VLLM_ROCM_USE_AITER_RMSNORM: bool = True
    VLLM_ROCM_USE_AITER_MLA: bool = True
    VLLM_ROCM_USE_SKINNY_GEMM: bool = True
    VLLM_ROCM_FP8_PADDING: bool = True
    VLLM_ROCM_MOE_PADDING: bool = True
    VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True

We may introduce the VLLM_NEURON_ prefix for Neuron-specific environment variables.

Contributor

+1

Comment on lines 45 to 47
self.enable_neuron_multi_node = (os.getenv(
    "ENABLE_NEURON_MULTI_NODE",
    DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true")
Contributor

It's error-prone to lowercase the string and compare it with "true".

I propose defining environment-variable types in https://github.com/vllm-project/vllm/blob/main/vllm/envs.py

Contributor

+1
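
For context, a rough sketch of what typed, VLLM_NEURON_-prefixed entries could look like if registered in the lambda-based environment_variables mapping that vllm/envs.py uses; the VLLM_NEURON_RANK_ID and VLLM_NEURON_MULTI_NODE names and their defaults are hypothetical, not part of this PR:

import os
from typing import Any, Callable

# Hypothetical Neuron-specific entries following the vllm/envs.py pattern;
# names and defaults are illustrative only.
neuron_environment_variables: dict[str, Callable[[], Any]] = {
    # Rank of this node in a multi-node Neuron deployment (int, default 0).
    "VLLM_NEURON_RANK_ID":
    lambda: int(os.getenv("VLLM_NEURON_RANK_ID", "0")),

    # Whether multi-node Neuron inference is enabled (bool, default False);
    # accepting "1"/"true" avoids the fragile .lower() == "true" comparison.
    "VLLM_NEURON_MULTI_NODE":
    lambda: os.getenv("VLLM_NEURON_MULTI_NODE", "0").lower() in ("1", "true"),
}

Call sites would then read the typed values from this mapping instead of parsing raw strings inline.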

Contributor

@liangfu liangfu left a comment

I think it's hard to move forward as-is, since there are quite a lot of code changes bundled in this PR.
In addition, the community intends to "fully remove V0" as part of the Q2 plan (#15735).

@robertgshaw2-redhat @simon-mo @comaniac what do you think?

@comaniac
Collaborator

> I think it's hard to move forward as-is, since there are quite a lot of code changes bundled in this PR. In addition, the community intends to "fully remove V0" as part of the Q2 plan (#15735).
>
> @robertgshaw2-redhat @simon-mo @comaniac what do you think?

Sorry for overlooking this, and I second @liangfu. Specifically, we are trying to freeze and deprecate V0 in Q2 if possible, so we are not in favor of taking in large changes other than bug fixes for V0. However, we would be happy to discuss an RFC for Neuron integration in V1. Also cc @WoosukKwon

@mrinalks
Contributor

> > I think it's hard to move forward as-is, since there are quite a lot of code changes bundled in this PR. In addition, the community intends to "fully remove V0" as part of the Q2 plan (#15735).
> >
> > @robertgshaw2-redhat @simon-mo @comaniac what do you think?
>
> Sorry for overlooking this, and I second @liangfu. Specifically, we are trying to freeze and deprecate V0 in Q2 if possible, so we are not in favor of taking in large changes other than bug fixes for V0. However, we would be happy to discuss an RFC for Neuron integration in V1. Also cc @WoosukKwon

We have agreement from @robertgshaw2-redhat (and, I believe, +simon-mo and +woosuk as well) to merge the NxDI V0 changes as a final RFC to vLLM, and to then take a snapshot of the final changes for longer-term V0 support for customers who need more time to migrate to V1 due to performance reasons on Neuron.

Meanwhile, we are also prepping a Q2 roadmap to align with the vLLM maintainers (I believe @simon-mo asked for it a few weeks back), which focuses entirely on V1 efforts. All in all, we are aligned that V0 is getting deprecated; we just need our changes merged, functional, and archived so we can cleanly switch over to V1 as we work towards the V1 architecture + Neuron (performant drops). More on this in the next couple of weeks!

mrinalks and others added 2 commits May 2, 2025 14:12
Contributor

@liangfu liangfu left a comment

Thank you @aws-satyajith for the update. I'm okay with the proposed change. I look forward to working together on the follow-up RFC for NxDI V1 support.

@mrinalks
Contributor

mrinalks commented May 7, 2025

@simon-mo @WoosukKwon @robertgshaw2-redhat Could you please review or approve this pull request so we can merge our final V0 RFC?
Per your earlier request, @liangfu has also reviewed and blessed the pull request.

@simon-mo simon-mo merged commit 043e4c4 into vllm-project:main May 7, 2025
25 checks passed
@mrinalks
Contributor

mrinalks commented May 7, 2025

Thanks @simon-mo 🥳

@aws-satyajith aws-satyajith deleted the upstream-neuron-vllm-04-08 branch May 7, 2025 17:31
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…c on-device sampling (vllm-project#16357)

Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
Co-authored-by: Aaron Dou <yzdou@amazon.com>
Co-authored-by: Shashwat Srijan <sssrijan@amazon.com>
Co-authored-by: Chongming Ni <chongmni@amazon.com>
Co-authored-by: Amulya Ballakur <amulyaab@amazon.com>
Co-authored-by: Patrick Lange <patlange@amazon.com>
Co-authored-by: Elaine Zhao <elaineyz@amazon.com>
Co-authored-by: Lin Lin Pan <tailinpa@amazon.com>
Co-authored-by: Navyadhara Gogineni <navyadha@amazon.com>
Co-authored-by: Yishan McNabb <yishanm@amazon.com>
Co-authored-by: Mrinal Shukla <181322398+mrinalks@users.noreply.github.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Labels: ci/build, documentation