
Commit 459ad86

aws-satyajith, Aaron Dou, ssrijan-amazon, chongmni-aws, and aws-amulyaab authored and Yuqi Zhang committed
Add NeuronxDistributedInference support, Speculative Decoding, Dynamic on-device sampling (vllm-project#16357)
Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
Co-authored-by: Aaron Dou <yzdou@amazon.com>
Co-authored-by: Shashwat Srijan <sssrijan@amazon.com>
Co-authored-by: Chongming Ni <chongmni@amazon.com>
Co-authored-by: Amulya Ballakur <amulyaab@amazon.com>
Co-authored-by: Patrick Lange <patlange@amazon.com>
Co-authored-by: Elaine Zhao <elaineyz@amazon.com>
Co-authored-by: Lin Lin Pan <tailinpa@amazon.com>
Co-authored-by: Navyadhara Gogineni <navyadha@amazon.com>
Co-authored-by: Yishan McNabb <yishanm@amazon.com>
Co-authored-by: Mrinal Shukla <181322398+mrinalks@users.noreply.github.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent c6c10e6 commit 459ad86

15 files changed (+1623, -102 lines)
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on Neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "What is annapurna labs?",
]

# Create a sampling params object.
sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)

# Create an LLM.
llm = LLM(
    model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct",
    speculative_config={
        "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
        "num_speculative_tokens": 5,
        "max_model_len": 2048,
    },
    max_num_seqs=4,
    # The max_model_len and block_size arguments are required to be the same
    # as the max sequence length when targeting the Neuron device.
    # Currently, this is a known limitation in continuous batching support
    # in neuronx-distributed-inference.
    max_model_len=2048,
    block_size=2048,
    # The device can be automatically detected when the AWS Neuron SDK is
    # installed. The device argument can be either unspecified for automated
    # detection, or explicitly assigned.
    device="neuron",
    tensor_parallel_size=32,
    override_neuron_config={
        "enable_eagle_speculation": True,
        "enable_fused_speculation": True,
    },
)

# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r},\n\n\nGenerated text: {generated_text!r}")
Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with a speculative
decoding model on Neuron.
"""

import os

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, I am a language model and I can help",
    "The president of the United States is",
    "The capital of France is",
]


def config_buckets():
    """Configure context length and token gen buckets."""
    # Creates XLA HLO graphs for all the context length buckets.
    os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
    # Creates XLA HLO graphs for all the token gen buckets.
    os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"


def initialize_model():
    """Create an LLM with speculative decoding."""
    return LLM(
        model="openlm-research/open_llama_7b",
        speculative_config={
            "model": "openlm-research/open_llama_3b",
            "num_speculative_tokens": 4,
            "max_model_len": 2048,
        },
        max_num_seqs=4,
        max_model_len=2048,
        block_size=2048,
        use_v2_block_manager=True,
        device="neuron",
        tensor_parallel_size=32,
    )


def process_requests(model: LLM, sampling_params: SamplingParams):
    """Generate texts from prompts and print them."""
    outputs = model.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


def main():
    """Main function that sets up the model and processes prompts."""
    config_buckets()
    model = initialize_model()
    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, top_k=1)
    process_requests(model, sampling_params)


if __name__ == '__main__':
    main()
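The NEURON_CONTEXT_LENGTH_BUCKETS and NEURON_TOKEN_GEN_BUCKETS settings above pre-compile one XLA HLO graph per bucket length; at runtime a request is served by the smallest bucket that fits its sequence. A minimal, illustrative sketch of that selection rule (the actual bucketing is handled inside the Neuron backend, not by this helper):

def pick_bucket(seq_len: int, buckets: list[int]) -> int:
    """Return the smallest pre-compiled bucket that can hold seq_len tokens."""
    for bucket in sorted(buckets):
        if seq_len <= bucket:
            return bucket
    raise ValueError(f"sequence length {seq_len} exceeds the largest bucket")


print(pick_bucket(300, [128, 512, 1024, 2048]))  # -> 512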

requirements/neuron.txt

Lines changed: 2 additions & 1 deletion
@@ -5,4 +5,5 @@
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 torch-neuronx >= 2.5.0
-neuronx-cc
+neuronx-cc>=2.0.0a0
+torchvision  # Required for Llama3.2 multimodal image preprocessing
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import MagicMock

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import NeuronModelRunner

os.environ[
    'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value


def _create_neuron_model_runner(model: str, *args,
                                **kwargs) -> NeuronModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    vllm_config = VllmConfig(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
    )
    neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config)
    return neuron_model_runner


def test_update_neuron_sampling_params_not_full_batch():
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled
    # Test sampling param updating only when TNx is the framework;
    # NxDI handles sampling parameter updating inside the model.
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        seq_group_metadata_list = [
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0.5,
                                               top_k=1,
                                               top_p=0.5),
                block_tables={0: [1]},
            )
        ]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Index neuron sampling parameters based on block_tables indices.
        # The first block_id of sequence 0 is 1, so its parameters are
        # placed at index 1. So the sampling parameters will be:
        # Index 0: default sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        assert neuron_sampling_params.temperature == [1.0, 0.5]
        assert neuron_sampling_params.top_k == [
            model_runner._MAX_NEURON_SAMPLING_TOP_K, 1
        ]
        assert neuron_sampling_params.top_p == [1.0, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)


def test_update_neuron_sampling_params_full_batch():
    os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
    model_runner = _create_neuron_model_runner(
        "facebook/opt-125m",
        seed=0,
        dtype="float16",
        max_num_seqs=2,
    )
    assert not model_runner._on_device_sampling_disabled

    # Test sampling param updating only when TNx is the framework;
    # NxDI handles sampling parameter updating inside the model.
    if current_platform.use_transformers_neuronx():
        model_mock = MagicMock()
        model_runner.model = model_mock

        seq_group_metadata_list = [
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={0: SequenceData.from_seqs([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0.5,
                                               top_k=1,
                                               top_p=0.5),
                block_tables={0: [1]},
            ),
            SequenceGroupMetadata(
                request_id="test_0",
                is_prompt=True,
                seq_data={1: SequenceData.from_seqs([4, 5, 6])},
                sampling_params=SamplingParams(temperature=0.2,
                                               top_k=2,
                                               top_p=0.2),
                block_tables={1: [0]},
            )
        ]

        model_runner.prepare_model_input(seq_group_metadata_list)

        # Index neuron sampling parameters based on block_tables indices.
        # The first block_id of sequence 0 is 1, so its parameters are
        # placed at index 1. So the sampling parameters will be:
        # Index 0: sequence 1's sampling parameters
        # Index 1: sequence 0's sampling parameters.
        neuron_sampling_params = (
            model_runner.model_config.neuron_sampling_params)
        assert neuron_sampling_params.temperature == [0.2, 0.5]
        assert neuron_sampling_params.top_k == [2, 1]
        assert neuron_sampling_params.top_p == [0.2, 0.5]
        model_mock.model.update_generation_config.assert_called_once_with(
            neuron_sampling_params)
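Both tests assert the same layout: a sequence's sampling parameters are written to the slot given by its first block id, and every unused slot keeps the defaults (temperature 1.0, top_p 1.0, and the maximum supported top_k). A standalone sketch of that indexing rule, using a hypothetical helper rather than code from this commit (default_top_k is a placeholder for _MAX_NEURON_SAMPLING_TOP_K):

from dataclasses import dataclass


@dataclass
class BatchSamplingParams:
    """Per-slot sampling parameters, indexed by each sequence's first block id."""
    temperature: list[float]
    top_k: list[int]
    top_p: list[float]


def build_batch_params(seqs, max_num_seqs, default_top_k=256):
    """seqs: iterable of (first_block_id, temperature, top_k, top_p) tuples."""
    params = BatchSamplingParams(
        temperature=[1.0] * max_num_seqs,
        top_k=[default_top_k] * max_num_seqs,
        top_p=[1.0] * max_num_seqs,
    )
    for block_id, temperature, top_k, top_p in seqs:
        params.temperature[block_id] = temperature
        params.top_k[block_id] = top_k
        params.top_p[block_id] = top_p
    return params


# Mirrors test_update_neuron_sampling_params_not_full_batch: sequence 0 uses
# block 1, so its parameters land at index 1 and index 0 keeps the defaults.
print(build_batch_params([(1, 0.5, 1, 0.5)], max_num_seqs=2))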

vllm/config.py

Lines changed: 5 additions & 1 deletion
@@ -2273,6 +2273,9 @@ class SpeculativeConfig:
     """Scaling factor for entropy-based threshold, applied when using
     `TypicalAcceptanceSampler`."""
 
+    speculative_token_tree: Optional[str] = None
+    """Specifies the tree structure for speculative token generation.
+    """
     # required configuration params passed from engine
     target_model_config: ModelConfig = field(default=None,
                                              init=True)  # type: ignore
@@ -2447,10 +2450,11 @@ def __post_init__(self):
                     "Chunked prefill and EAGLE are not compatible "
                     "when using V0.")
 
+                from vllm.platforms import current_platform
                 from vllm.transformers_utils.configs.eagle import (
                     EAGLEConfig)
                 if isinstance(self.draft_model_config.hf_config,
-                              EAGLEConfig):
+                              EAGLEConfig) or current_platform.is_neuron():
                     pass
                 else:
                     eagle_config = EAGLEConfig(

vllm/engine/llm_engine.py

Lines changed: 2 additions & 4 deletions
@@ -399,10 +399,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                 self.scheduler,
                 self.seq_counter,
                 get_tokenizer_for_seq,
-                stop_checker=StopChecker(
-                    self.scheduler_config.max_model_len,
-                    get_tokenizer_for_seq,
-                ),
+                stop_checker=StopChecker(self.scheduler_config.max_model_len,
+                                         get_tokenizer_for_seq),
             ))
 
         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

0 commit comments
