
Add NeuronxDistributedInference support, Speculative Decoding, Dynamic on-device sampling #16357


Merged on May 7, 2025.
Changes from all 44 commits:
6531f92
[NxDI upstream foundation] set up NxDI model runner
Feb 14, 2025
626850c
Support speculation with transformers-neuronx
sssrijan-amazon Oct 15, 2024
7fcf9b2
Add support for eagle speculation using transformers-neuronx
sssrijan-amazon Oct 15, 2024
3a8e7a5
Support speculation with neuronx-distributed-inference for batch 1
sssrijan-amazon Oct 29, 2024
bf9a4c6
[Tnx] Fix streaming flow for speculation
sssrijan-amazon Oct 29, 2024
a2671aa
[NxdI] Support eagle speculation
sssrijan-amazon Oct 31, 2024
0b075b1
[TNx+EAGLE] Use FusedSpeculativeDecoder for EAGLE + Linear token tree.
chongmni-aws Nov 7, 2024
703efd9
Fix the termination check for accepted speculative tokens.
chongmni-aws Nov 21, 2024
6c04502
[TNx][Bug fix] fix the incorrect speculation output check.
chongmni-aws Nov 22, 2024
31203d7
Add continuous batching with eagle
aws-amulyaab Dec 12, 2024
ae29889
[NxDI] Fix masking of padding in speculative output
aws-patlange Dec 19, 2024
466cd01
Remove assertion on bs=1 when using speculation now that we support bs>1
elaineyz Dec 31, 2024
ac90709
Modify NxDI and TNx multi step model runners (used for speculation) t…
elaineyz Feb 18, 2025
f10f17b
Add multi-step NxD model runner
aws-tailinpa Feb 18, 2025
d123747
Add Framework selection logic
aws-satyajith Feb 18, 2025
547709a
Fix no free blocks error
aws-patlange Dec 21, 2024
e91857a
Refactor and add basic docstrings
elaineyz Feb 20, 2025
70ad6a9
Modification to enable Vllm-neuronx instead of Vllm for KTF
aws-navyadhara Feb 21, 2025
2df94fe
Fix global_top_k to be aligned with NxDI default
aws-yishanm Jan 10, 2025
da8d1cf
Add neuron model runner tests for updating sampling param
chongmni-aws Oct 24, 2024
498a918
Updating requirements-neuron.txt
aws-navyadhara Feb 21, 2025
8bc6537
Updating requirements-neuron.txt
aws-navyadhara Feb 21, 2025
4907e3e
multi-node TP support
Feb 21, 2025
248b708
Removing codenames that fail IP Scanning
aws-navyadhara Feb 24, 2025
4d3c6ae
Format auto-check and formatting changes
aws-satyajith Feb 24, 2025
51f403f
Bug fix: Missing NxDI model runner addressed
aws-satyajith Feb 24, 2025
543f55d
set world_size default value as 1
Feb 24, 2025
d15c20b
Formatting changes to satisfy all pre-commit hooks
aws-satyajith Feb 25, 2025
49d7558
Revert "Removing codenames that fail IP Scanning"
aws-navyadhara Feb 26, 2025
87eeb3c
Fix issues identified by mypy checks
aws-satyajith Feb 26, 2025
39ad22c
Fix logging strings
aws-satyajith Feb 26, 2025
53e821a
add example offline script for EAGLE spec
elaineyz Feb 28, 2025
aaa9f17
Merge branch 'upstreaming_main' into upstream-neuron-vllm-04-08
aws-satyajith Apr 9, 2025
5f4eb2f
Add speculative token tree and skip EAGLEConfig creation for Neuron
aws-satyajith Apr 9, 2025
d04c552
Fix imports and satisfy pre-commit hooks
aws-satyajith Apr 9, 2025
ea4e9c7
Remove deprecated files and .gitignore additions
aws-satyajith Apr 9, 2025
7c86b26
Merge branch 'upstreaming_main' into upstream-neuron-vllm-04-08
aws-satyajith Apr 9, 2025
d34883c
Merge branch 'main' into upstream-neuron-vllm-04-08
aws-satyajith Apr 22, 2025
d762abd
Add docstring for speculative_token_tree
aws-satyajith Apr 22, 2025
9d49902
Merge branch 'upstreaming_main' into upstream-neuron-vllm-04-08
aws-satyajith Apr 28, 2025
a31715c
Merge branch 'vllm-project:main' into upstream-neuron-vllm-04-08
mrinalks May 2, 2025
fa055e5
Remove multi-node support. Remove num_lookahead_slots exception for s…
aws-satyajith May 2, 2025
8092f4b
Merge branch 'vllm-project:main' into upstream-neuron-vllm-04-08
aws-satyajith May 3, 2025
e298c41
Modify neuron speculative decoding examples to use latest speculative…
aws-satyajith May 3, 2025
54 changes: 54 additions & 0 deletions examples/offline_inference/neuron_eagle.py
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "What is annapurna labs?",
]

# Create a sampling params object.
sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)

# Create an LLM.
llm = LLM(
    model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct",
    speculative_config={
        "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
        "num_speculative_tokens": 5,
        "max_model_len": 2048
    },
    max_num_seqs=4,
    # The max_model_len and block_size arguments are required to be the same
    # as the max sequence length when targeting a Neuron device. Currently,
    # this is a known limitation of continuous batching support in
    # neuronx-distributed-inference.
    max_model_len=2048,
    block_size=2048,
    # The device can be automatically detected when the AWS Neuron SDK is
    # installed; the device argument can either be left unspecified for
    # automatic detection or assigned explicitly.
    device="neuron",
    tensor_parallel_size=32,
    override_neuron_config={
        "enable_eagle_speculation": True,
        "enable_fused_speculation": True
    },
)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r},\n\nGenerated text: {generated_text!r}")
64 changes: 64 additions & 0 deletions examples/offline_inference/neuron_speculation.py
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""

import os

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, I am a language model and I can help",
    "The president of the United States is",
    "The capital of France is",
]


def config_buckets():
    """Configure context length and token gen buckets."""
    # Creates XLA HLO graphs for all the context length buckets.
    os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
    # Creates XLA HLO graphs for all the token gen buckets.
    os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"


def initialize_model():
    """Create an LLM with speculative decoding."""
    return LLM(
        model="openlm-research/open_llama_7b",
        speculative_config={
            "model": "openlm-research/open_llama_3b",
            "num_speculative_tokens": 4,
            "max_model_len": 2048
        },
        max_num_seqs=4,
        max_model_len=2048,
        block_size=2048,
        use_v2_block_manager=True,
        device="neuron",
        tensor_parallel_size=32,
    )


def process_requests(model: LLM, sampling_params: SamplingParams):
    """Generate texts from prompts and print them."""
    outputs = model.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def main():
    """Main function that sets up the model and processes prompts."""
    config_buckets()
    model = initialize_model()
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, top_k=1)
    process_requests(model, sampling_params)


if __name__ == '__main__':
    main()
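
A note on config_buckets() above: each value in the comma-separated lists produces a compiled XLA graph, so compile time grows with the number of buckets. The sketch below (an illustration, not part of this PR) compiles only a single bucket sized to the example's max_model_len; it assumes the environment variable format accepts a single value just as it accepts a list.

def config_single_bucket(max_len: int = 2048):
    """Compile one context-length bucket and one token-gen bucket."""
    os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = str(max_len)
    os.environ['NEURON_TOKEN_GEN_BUCKETS'] = str(max_len)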
3 changes: 2 additions & 1 deletion requirements/neuron.txt
@@ -5,4 +5,5 @@
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 torch-neuronx >= 2.5.0
-neuronx-cc
+neuronx-cc>=2.0.0a0
+torchvision # Required for Llama3.2 multimodal image preprocessing
126 changes: 126 additions & 0 deletions tests/neuron/1_core/test_neuron_model_runner.py
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import MagicMock

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import NeuronModelRunner

os.environ[
    'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value


def _create_neuron_model_runner(model: str, *args,
                                **kwargs) -> NeuronModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    vllm_config = VllmConfig(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
    )
    neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config)
    return neuron_model_runner


def test_update_neuron_sampling_params_not_full_batch():
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled
    # Test sampling param updating only when TNx is the framework;
    # NxDI handles sampling parameter updating inside the model.
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        seq_group_metadata_list = [
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0.5,
                                               top_k=1,
                                               top_p=0.5),
                block_tables={0: [1]},
            )
        ]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Neuron sampling parameters are indexed by block_tables indices.
        # The first block_id of sequence 0 is 1, so its parameters are
        # placed at index 1. The sampling parameters are therefore:
        # Index 0: default sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        assert neuron_sampling_params.temperature == [1.0, 0.5]
        assert neuron_sampling_params.top_k == [
            model_runner._MAX_NEURON_SAMPLING_TOP_K, 1
        ]
        assert neuron_sampling_params.top_p == [1.0, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)


def test_update_neuron_sampling_params_full_batch():
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled

    # Test sampling param updating only when TNx is the framework;
    # NxDI handles sampling parameter updating inside the model.
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        seq_group_metadata_list = [
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0.5,
                                               top_k=1,
                                               top_p=0.5),
                block_tables={0: [1]},
            ),
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={1: SequenceData.from_seqs([4, 5, 6])},
                sampling_params=SamplingParams(temperature=0.2,
                                               top_k=2,
                                               top_p=0.2),
                block_tables={1: [0]},
            )
        ]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Neuron sampling parameters are indexed by block_tables indices.
        # The first block_id of sequence 0 is 1, so its parameters are
        # placed at index 1. The sampling parameters are therefore:
        # Index 0: sequence 1's sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        assert neuron_sampling_params.temperature == [0.2, 0.5]
        assert neuron_sampling_params.top_k == [2, 1]
        assert neuron_sampling_params.top_p == [0.2, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)
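
The environment variables exercised by these tests are the same knobs that select the Neuron framework and control dynamic on-device sampling outside of tests. A minimal sketch mirroring the test setup is shown below; the values are taken from the tests above, and whether they need to be set explicitly in a given deployment is an assumption.

import os

from vllm.platforms.neuron import NeuronFramework

# Select the transformers-neuronx backend explicitly; the framework selection
# logic added in this PR may also choose a backend when this is left unset.
os.environ["VLLM_NEURON_FRAMEWORK"] = NeuronFramework.TRANSFORMERS_NEURONX.value
# "0" keeps dynamic on-device sampling enabled, matching the tests above.
os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"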
6 changes: 5 additions & 1 deletion vllm/config.py
@@ -2269,6 +2269,9 @@ class SpeculativeConfig:
    """Scaling factor for entropy-based threshold, applied when using
    `TypicalAcceptanceSampler`."""

    speculative_token_tree: Optional[str] = None
    """Specifies the tree structure for speculative token generation.
    """
    # required configuration params passed from engine
    target_model_config: ModelConfig = field(default=None,
                                             init=True)  # type: ignore
@@ -2443,10 +2446,11 @@ def __post_init__(self):
"Chunked prefill and EAGLE are not compatible "
"when using V0.")

from vllm.platforms import current_platform
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
EAGLEConfig):
EAGLEConfig) or current_platform.is_neuron():
Review comments on this line:

Contributor: Can we eliminate the device-specific changes in vllm/config.py?

Contributor: +1, we should move this neuron config to override.

Contributor (Author): @mrinalks We currently use the self.draft_model_config.hf_config in multiple places. Moving this to override_neuron_config would mean deviating from the existing flow and would require comprehensive re-testing to ensure we didn't miss any parameters.

@liangfu EAGLEConfig was not present when we implemented EAGLE support in Neuron, hence the exception. I'll take a look at whether we can remove this exception for Neuron.

Contributor (Author): This is a valid change that we'll implement and test internally first. We will address supporting EAGLEConfig as a follow-up commit. I'm adding a comment on RFC #15970 to keep track of this change.
pass
else:
eagle_config = EAGLEConfig(
6 changes: 2 additions & 4 deletions vllm/engine/llm_engine.py
@@ -399,10 +399,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                 self.scheduler,
                 self.seq_counter,
                 get_tokenizer_for_seq,
-                stop_checker=StopChecker(
-                    self.scheduler_config.max_model_len,
-                    get_tokenizer_for_seq,
-                ),
+                stop_checker=StopChecker(self.scheduler_config.max_model_len,
+                                         get_tokenizer_for_seq),
             ))

         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}