Commit cc90419

[Hardware][Neuron] Add on-device sampling support for Neuron (#8746)
Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com>
1 parent 27302dd commit cc90419
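This commit turns on-device sampling on by default for Neuron, gated by a single environment variable (see the neuron_model_runner.py diff below). A minimal opt-out sketch, assuming only that the variable is read before NeuronModelRunner is constructed:

import os

# Any non-zero integer disables on-device sampling; NeuronModelRunner.__init__
# reads the flag via int(os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")).
os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "1"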

File tree

2 files changed: +128 -13 lines changed
  vllm/model_executor/model_loader/neuron.py
  vllm/worker/neuron_model_runner.py


vllm/model_executor/model_loader/neuron.py

Lines changed: 50 additions & 9 deletions
@@ -1,4 +1,5 @@
 """Utilities for selecting and loading neuron models."""
+import copy
 import importlib
 import os
 from typing import Dict, List, Optional, Tuple
@@ -13,6 +14,8 @@
 from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
+                           SequenceOutput)
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
@@ -37,15 +40,18 @@
 
 class NeuronCasualLM(nn.Module):
 
-    def __init__(
-        self,
-        config: PretrainedConfig,
-    ) -> None:
+    def __init__(self,
+                 config: PretrainedConfig,
+                 on_device_sampling_disabled: bool = False) -> None:
         super().__init__()
         self.config = config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 logits_as_input=True)
-        self.sampler = Sampler()
+
+        self.on_device_sampling_disabled = on_device_sampling_disabled
+        if self.on_device_sampling_disabled:
+            # Use default sampler
+            self.sampler = Sampler()
 
         # Lazy initialized
         self.model: nn.Module
@@ -71,8 +77,29 @@ def sample(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+
+        if self.on_device_sampling_disabled:
+            next_tokens = self.sampler(logits, sampling_metadata)
+            return next_tokens
+
+        # On-device sampling outputs the token ids directly.
+        sampled_token_ids = logits.flatten()
+        next_tokens = []
+        sample_idx = 0
+        for seq_group in sampling_metadata.seq_groups:
+            samples = []
+            for seq_id in seq_group.seq_ids:
+                token_id = sampled_token_ids[sample_idx].item()
+                samples.append(
+                    SequenceOutput(parent_seq_id=seq_id,
+                                   output_token=token_id,
+                                   logprobs={token_id: Logprob(token_id)}))
+                sample_idx += 1
+            next_tokens.append(
+                CompletionSequenceGroupOutput(samples=samples,
+                                              prompt_logprobs=None))
+
+        return SamplerOutput(outputs=next_tokens)
 
     def load_weights(self, model_name_or_path: str, **kwargs):
         arch = _get_model_architecture(self.config)
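In the hunk above, when on-device sampling is active the "logits" tensor handed to sample() already contains sampled token ids, one per sequence, so the loop only re-groups them by sequence group. A toy rendition of that flattening walk, with a hypothetical ToySeqGroup standing in for vLLM's sequence-group metadata:

import torch
from dataclasses import dataclass
from typing import List

@dataclass
class ToySeqGroup:
    # Hypothetical stand-in for the entries of sampling_metadata.seq_groups.
    seq_ids: List[int]

device_output = torch.tensor([[11], [22], [33]])  # one token id per sequence
flat = device_output.flatten()

sample_idx = 0
for group in [ToySeqGroup(seq_ids=[0]), ToySeqGroup(seq_ids=[1, 2])]:
    for seq_id in group.seq_ids:
        print(f"seq {seq_id} -> token {flat[sample_idx].item()}")
        sample_idx += 1
# seq 0 -> token 11
# seq 1 -> token 22
# seq 2 -> token 33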
@@ -157,10 +184,22 @@ def _get_default_neuron_config(model_config: ModelConfig,
         quant=neuron_quantization_config_builder(model_config.quantization)
         if model_config.quantization else None,
         continuous_batching=continuous_batching_config,
-        weight_tiling=bool(model_config.quantization))
+        weight_tiling=bool(model_config.quantization),
+        on_device_generation=_get_neuron_on_device_generation_config(
+            model_config))
     return default_neuron_args
 
 
+def _get_neuron_on_device_generation_config(model_config: ModelConfig):
+    if not _is_neuron_on_device_sampling_disabled(model_config):
+        return copy.deepcopy(model_config.neuron_sampling_params)
+    return None
+
+
+def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
+    return not getattr(model_config, "neuron_sampling_params", None)
+
+
 def _get_neuron_config_after_override(default_neuron_config,
                                       overridden_neuron_config):
     from transformers_neuronx.config import NeuronConfig
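The two new helpers invert each other: on-device sampling counts as disabled exactly when the runner never attached neuron_sampling_params to the ModelConfig. A toy check of that getattr pattern, with SimpleNamespace standing in for ModelConfig:

from types import SimpleNamespace

cfg = SimpleNamespace()  # no neuron_sampling_params attribute yet
print(not getattr(cfg, "neuron_sampling_params", None))  # True: disabled

cfg.neuron_sampling_params = {"top_k": [256]}  # runner sets this when enabled
print(not getattr(cfg, "neuron_sampling_params", None))  # False: enabled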
@@ -174,7 +213,9 @@ def get_neuron_model(model_config: ModelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
 
     # Create a model instance.
-    model = NeuronCasualLM(model_config.hf_config)
+    model = NeuronCasualLM(
+        model_config.hf_config,
+        _is_neuron_on_device_sampling_disabled(model_config))
 
     default_neuron_config_args = _get_default_neuron_config(
         model_config, parallel_config, scheduler_config)

vllm/worker/neuron_model_runner.py

Lines changed: 78 additions & 4 deletions
@@ -1,9 +1,11 @@
+import os
 from dataclasses import dataclass
 from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
+from transformers_neuronx.config import GenerationConfig
 
 from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -50,6 +52,9 @@ def from_broadcasted_tensor_dict(
 
 class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
 
+    # NEURON has an upper limit on the top_k
+    _MAX_NEURON_SAMPLING_TOP_K = 256
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -76,6 +81,34 @@ def __init__(
         # Lazy initialization.
         self.model: nn.Module  # initialize after load_model.
 
+        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
+        # turn off on-device sampling.
+        self._on_device_sampling_disabled = int(
+            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
+
+        # NEURON needs to update sampling parameters when request IDs change
+        # across batches. This variable stores the previous batch's request
+        # IDs to determine if an update is needed.
+        self._previous_batch_request_ids: List[str] = []
+
+        if not self._on_device_sampling_disabled:
+            logger.warning(
+                "On-device sampling is turned on in Neuron by default; only "
+                "top_k, top_p, and temperature are currently supported "
+                "sampling parameters. To turn off on-device sampling, set the "
+                "environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
+            )
+            self.model_config.neuron_sampling_params = GenerationConfig(
+                max_length=self.scheduler_config.max_model_len,
+                do_sample=True,
+                per_batch_line=True,
+                top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
+                    * self.scheduler_config.max_num_seqs,
+                top_p=[1.0] * self.scheduler_config.max_num_seqs,
+                temperature=[1.0] * self.scheduler_config.max_num_seqs,
+                dynamic=True,
+                global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
+
     def load_model(self) -> None:
         if find_spec("transformers_neuronx") is not None:
             self.model = get_neuron_model(
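With per_batch_line=True, each sampling knob above is a list carrying one slot per sequence in the batch. A sketch of just the default list shapes, assuming max_num_seqs = 2 (the GenerationConfig keyword names come from transformers_neuronx and are not re-validated here):

_MAX_NEURON_SAMPLING_TOP_K = 256
max_num_seqs = 2

defaults = {
    "top_k": [_MAX_NEURON_SAMPLING_TOP_K] * max_num_seqs,  # [256, 256]
    "top_p": [1.0] * max_num_seqs,                         # [1.0, 1.0]
    "temperature": [1.0] * max_num_seqs,                   # [1.0, 1.0]
}
print(defaults)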
@@ -215,7 +248,7 @@ def prepare_model_input(
         else:
             (input_tokens, input_positions,
              input_block_ids) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
+            seq_lens = None
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
@@ -227,12 +260,49 @@ def prepare_model_input(
             self.pin_memory,
             generators=self.get_generators(finished_requests_ids))
 
+        if not self._on_device_sampling_disabled:
+            # If the request IDs changed in the current iteration, update
+            # the on-device sampling parameters.
+            current_batch_request_ids = [
+                seq_group_meta_data.request_id
+                for seq_group_meta_data in seq_group_metadata_list
+            ]
+            if current_batch_request_ids != self._previous_batch_request_ids:
+                self._update_neuron_sampling_params(sampling_metadata)
+                self._previous_batch_request_ids = current_batch_request_ids
+
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,
                                    sampling_metadata=sampling_metadata,
                                    multi_modal_kwargs=multi_modal_kwargs)
 
+    def _update_neuron_sampling_params(self,
+                                       sampling_metadata: SamplingMetadata):
+        # Update Neuron sampling parameters (GenerationConfig in Neuron).
+        current_sampling_params = self.model_config.neuron_sampling_params
+        assert current_sampling_params is not None, (
+            f"Failed to update sampling_params, "
+            f"current sampling params is {current_sampling_params}")
+
+        top_k = current_sampling_params.top_k
+        top_p = current_sampling_params.top_p
+        temperature = current_sampling_params.temperature
+        for index, sequence_group_to_sample in enumerate(
+                sampling_metadata.seq_groups):
+            top_k[index] = self._convert_to_neuron_top_k(
+                sequence_group_to_sample.sampling_params.top_k)
+            top_p[index] = sequence_group_to_sample.sampling_params.top_p
+            temperature[index] = \
+                sequence_group_to_sample.sampling_params.temperature
+
+        self.model.model.update_generation_config(current_sampling_params)
+
+    def _convert_to_neuron_top_k(self, top_k: int) -> int:
+        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+            return self._MAX_NEURON_SAMPLING_TOP_K
+        return top_k
+
     @torch.inference_mode()
     def execute_model(
         self,
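_convert_to_neuron_top_k is small enough to test standalone: vLLM uses -1 for "top-k disabled" and Neuron caps top_k at 256 (the _MAX_NEURON_SAMPLING_TOP_K constant above), so both ends clamp to the cap. A self-contained rendition:

_MAX_NEURON_SAMPLING_TOP_K = 256

def convert_to_neuron_top_k(top_k: int) -> int:
    # -1 ("disabled" in vLLM) and values above the cap both clamp to 256.
    if top_k < 0 or top_k > _MAX_NEURON_SAMPLING_TOP_K:
        return _MAX_NEURON_SAMPLING_TOP_K
    return top_k

assert convert_to_neuron_top_k(-1) == 256
assert convert_to_neuron_top_k(40) == 40
assert convert_to_neuron_top_k(1000) == 256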
@@ -253,9 +323,13 @@ def execute_model(
                                          device=self.device),
         )
 
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
-                                           model_input.sampling_metadata)
+        # Compute the logits only if on-device sampling is turned off, since
+        # on-device sampling outputs the token ids directly.
+        if self._on_device_sampling_disabled:
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+        else:
+            logits = hidden_states
 
         # Sample the next token.
         output = self.model.sample(
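One more detail from prepare_model_input above: the runner only pushes a new GenerationConfig to the device when the batch's request IDs differ from the previous batch's. That detection is a plain list comparison, as in this toy loop:

previous = []
for batch in (["req-a", "req-b"], ["req-a", "req-b"], ["req-c"]):
    if batch != previous:
        print("update sampling params for", batch)
        previous = batch
# fires for the first and third batches only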
