+import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
+from transformers_neuronx.config import GenerationConfig

from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)
@@ -50,6 +52,9 @@ def from_broadcasted_tensor_dict(

class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

+    # NEURON has an upper limit on the top_k
+    _MAX_NEURON_SAMPLING_TOP_K = 256
+
    def __init__(
        self,
        model_config: ModelConfig,
@@ -76,6 +81,34 @@ def __init__(
        # Lazy initialization.
        self.model: nn.Module  # initialize after load_model.

+        # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value,
+        # turn off on-device sampling.
+        self._on_device_sampling_disabled = int(
+            os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0"))
+
+        # NEURON needs to update sampling parameters when request IDs change
+        # across batches. This variable stores the previous batch's request IDs
+        # to determine if an update is needed.
+        self._previous_batch_request_ids: List[str] = []
+
+        if not self._on_device_sampling_disabled:
+            logger.warning(
+                "On-device sampling is turned on in Neuron by default, only "
+                "top_k, top_p, and temperature are currently supported "
+                "sampling parameters. To turn off on-device sampling, please "
+                "set the environment variable "
+                "NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
+            )
+            self.model_config.neuron_sampling_params = GenerationConfig(
+                max_length=self.scheduler_config.max_model_len,
+                do_sample=True,
+                per_batch_line=True,
+                top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
+                    * self.scheduler_config.max_num_seqs,
+                top_p=[1.0] * self.scheduler_config.max_num_seqs,
+                temperature=[1.0] * self.scheduler_config.max_num_seqs,
+                dynamic=True,
+                global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
+
    def load_model(self) -> None:
        if find_spec("transformers_neuronx") is not None:
            self.model = get_neuron_model(
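
A minimal sketch (outside the diff) of how the NEURON_ON_DEVICE_SAMPLING_DISABLED toggle read in __init__ above behaves: any non-zero value disables on-device sampling, and the default "0" leaves it enabled.

# Sketch only, not part of the diff: the env-var toggle used by the runner.
import os

os.environ.pop("NEURON_ON_DEVICE_SAMPLING_DISABLED", None)
assert int(os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")) == 0  # on-device sampling stays enabled

os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "1"
assert int(os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")) == 1  # on-device sampling is turned off
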
@@ -215,7 +248,7 @@ def prepare_model_input(
        else:
            (input_tokens, input_positions,
             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
-            seq_lens = []
+            seq_lens = None
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
@@ -227,12 +260,49 @@ def prepare_model_input(
            self.pin_memory,
            generators=self.get_generators(finished_requests_ids))

+        if not self._on_device_sampling_disabled:
+            # Once the request IDs change in the current iteration, we will
+            # update the on-device sampling parameters.
+            current_batch_request_ids = [
+                seq_group_meta_data.request_id
+                for seq_group_meta_data in seq_group_metadata_list
+            ]
+            if current_batch_request_ids != self._previous_batch_request_ids:
+                self._update_neuron_sampling_params(sampling_metadata)
+                self._previous_batch_request_ids = current_batch_request_ids
+
        return ModelInputForNeuron(input_tokens=input_tokens,
                                   input_positions=input_positions,
                                   input_block_ids=input_block_ids,
                                   sampling_metadata=sampling_metadata,
                                   multi_modal_kwargs=multi_modal_kwargs)

+    def _update_neuron_sampling_params(self,
+                                       sampling_metadata: SamplingMetadata):
+        # Update Neuron sampling parameters (GenerationConfig in Neuron)
+        current_sampling_params = self.model_config.neuron_sampling_params
+        assert current_sampling_params is not None, (
+            f"Failed to update sampling_params, "
+            f"current sampling params is {current_sampling_params}")
+
+        top_k = current_sampling_params.top_k
+        top_p = current_sampling_params.top_p
+        temperature = current_sampling_params.temperature
+        for index, sequence_group_to_sample in enumerate(
+                sampling_metadata.seq_groups):
+            top_k[index] = self._convert_to_neuron_top_k(
+                sequence_group_to_sample.sampling_params.top_k)
+            top_p[index] = sequence_group_to_sample.sampling_params.top_p
+            temperature[index] = \
+                sequence_group_to_sample.sampling_params.temperature
+
+        self.model.model.update_generation_config(current_sampling_params)
+
+    def _convert_to_neuron_top_k(self, top_k: int) -> int:
+        if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
+            return self._MAX_NEURON_SAMPLING_TOP_K
+        return top_k
+
    @torch.inference_mode()
    def execute_model(
        self,
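
A minimal standalone sketch (not part of the diff) of the clamping done by _convert_to_neuron_top_k above: vLLM uses top_k = -1 to mean "no limit", and both that value and anything above the 256 cap collapse to _MAX_NEURON_SAMPLING_TOP_K, while in-range values pass through unchanged.

# Sketch only, not part of the diff: top_k clamping for Neuron's finite limit.
_MAX_NEURON_SAMPLING_TOP_K = 256

def convert_to_neuron_top_k(top_k: int) -> int:
    # Negative top_k ("no limit" in vLLM) and values above the cap are
    # clamped to the cap; everything else is kept as-is.
    if top_k < 0 or top_k > _MAX_NEURON_SAMPLING_TOP_K:
        return _MAX_NEURON_SAMPLING_TOP_K
    return top_k

assert convert_to_neuron_top_k(-1) == 256
assert convert_to_neuron_top_k(50) == 50
assert convert_to_neuron_top_k(1024) == 256
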
@@ -253,9 +323,13 @@ def execute_model(
                                         device=self.device),
        )

-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
-                                           model_input.sampling_metadata)
+        # Compute the logits only if the on-device sampling is turned off as
+        # on-device sampling outputs the token ids.
+        if self._on_device_sampling_disabled:
+            logits = self.model.compute_logits(hidden_states,
+                                               model_input.sampling_metadata)
+        else:
+            logits = hidden_states

        # Sample the next token.
        output = self.model.sample(
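
A minimal sketch (outside the diff) of the branch added in execute_model: with on-device sampling enabled, the Neuron forward pass already returns sampled token ids in place of hidden states, so the logits step is skipped and the tensor is handed to the sampler unchanged. The helper names below are hypothetical stand-ins for illustration, not vLLM APIs.

# Sketch only, not part of the diff; compute_logits_on_host is a hypothetical
# placeholder for self.model.compute_logits(...).
import torch

def compute_logits_on_host(hidden_states: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the model's logits projection.
    return hidden_states.float()

def pick_sampler_input(hidden_states: torch.Tensor,
                       on_device_sampling_disabled: bool) -> torch.Tensor:
    if on_device_sampling_disabled:
        # Host-side path: turn hidden states into logits before sampling.
        return compute_logits_on_host(hidden_states)
    # On-device path: the tensor already holds sampled token ids; pass through.
    return hidden_states

tokens = pick_sampler_input(torch.tensor([[42]]), on_device_sampling_disabled=False)
assert tokens.tolist() == [[42]]  # pass-through: already token ids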