diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py new file mode 100644 index 00000000000..4f63f1a2fb3 --- /dev/null +++ b/examples/offline_inference/neuron_eagle.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to run offline inference with an EAGLE speculative +decoding model on neuron. To use EAGLE speculative decoding, you must use +a draft model that is specifically fine-tuned for EAGLE speculation. +Additionally, to use EAGLE with NxD Inference, the draft model must include +the LM head weights from the target model. These weights are shared between +the draft and target model. +""" + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "What is annapurna labs?", +] + +# Create a sampling params object. +sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) + +# Create an LLM. +llm = LLM( + model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", + speculative_config={ + "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", + "num_speculative_tokens": 5, + "max_model_len": 2048 + }, + max_num_seqs=4, + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in neuronx-distributed-inference. + max_model_len=2048, + block_size=2048, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. + device="neuron", + tensor_parallel_size=32, + override_neuron_config={ + "enable_eagle_speculation": True, + "enable_fused_speculation": True + }, +) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}") diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py new file mode 100644 index 00000000000..bef434bae5b --- /dev/null +++ b/examples/offline_inference/neuron_speculation.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to run offline inference with a speculative +decoding model on neuron. +""" + +import os + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, I am a language model and I can help", + "The president of the United States is", + "The capital of France is", +] + + +def config_buckets(): + """Configure context length and token gen buckets.""" + # creates XLA hlo graphs for all the context length buckets. + os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" + # creates XLA hlo graphs for all the token gen buckets. 
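+    # Note (assumption about Neuron bucketing behavior): each bucket size is
+    # compiled into its own graph ahead of time, and at runtime a request is
+    # padded up to the smallest bucket that fits it, so finer-grained buckets
+    # reduce padding at the cost of longer compilation.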
+ os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + + +def initialize_model(): + """Create an LLM with speculative decoding.""" + return LLM( + model="openlm-research/open_llama_7b", + speculative_config={ + "model": "openlm-research/open_llama_3b", + "num_speculative_tokens": 4, + "max_model_len": 2048 + }, + max_num_seqs=4, + max_model_len=2048, + block_size=2048, + use_v2_block_manager=True, + device="neuron", + tensor_parallel_size=32, + ) + + +def process_requests(model: LLM, sampling_params: SamplingParams): + """Generate texts from prompts and print them.""" + outputs = model.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +def main(): + """Main function that sets up the model and processes prompts.""" + config_buckets() + model = initialize_model() + # Create a sampling params object. + sampling_params = SamplingParams(max_tokens=100, top_k=1) + process_requests(model, sampling_params) + + +if __name__ == '__main__': + main() diff --git a/requirements/neuron.txt b/requirements/neuron.txt index f8e3030834e..7df478eddde 100644 --- a/requirements/neuron.txt +++ b/requirements/neuron.txt @@ -5,4 +5,5 @@ packaging>=24.2 setuptools>=77.0.3,<80.0.0 torch-neuronx >= 2.5.0 -neuronx-cc +neuronx-cc>=2.0.0a0 +torchvision # Required for Llama3.2 multimodal image preprocessing diff --git a/tests/neuron/1_core/test_neuron_model_runner.py b/tests/neuron/1_core/test_neuron_model_runner.py new file mode 100644 index 00000000000..92417fb64f7 --- /dev/null +++ b/tests/neuron/1_core/test_neuron_model_runner.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import MagicMock + +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform +from vllm.platforms.neuron import NeuronFramework +from vllm.sampling_params import SamplingParams +from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.worker.neuron_model_runner import NeuronModelRunner + +os.environ[ + 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value + + +def _create_neuron_model_runner(model: str, *args, + **kwargs) -> NeuronModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + vllm_config = VllmConfig( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + ) + neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config) + return neuron_model_runner + + +def test_update_neuron_sampling_params_not_full_batch(): + os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" + model_runner = _create_neuron_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + max_num_seqs=2, + ) + assert not model_runner._on_device_sampling_disabled + # Test sampling param updating only when TNx is framework + # NxDI handles sampling parameter updating inside model + if current_platform.use_transformers_neuronx(): + model_mock = MagicMock() + model_runner.model = model_mock + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id="test_0", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0.5, + top_k=1, + top_p=0.5), + block_tables={0: [1]}, + ) + ] + + 
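+        # prepare_model_input() sees that the batch's request IDs changed and
+        # pushes the per-sequence sampling parameters to the device via
+        # update_generation_config().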
model_runner.prepare_model_input(seq_group_metadata_list) + + # Index neuron sampling parameters based on block_tables indices. + # The first block_id of the sequence 0 is 1, so its parameters are + # placed at index 1. So the sampling parameters will be: + # Index 0: default sampling parameters + # Index 1: sequecne 0's sampling parameters. + neuron_sampling_params = ( + model_runner.model_config.neuron_sampling_params) + assert neuron_sampling_params.temperature == [1.0, 0.5] + assert neuron_sampling_params.top_k == [ + model_runner._MAX_NEURON_SAMPLING_TOP_K, 1 + ] + assert neuron_sampling_params.top_p == [1.0, 0.5] + model_mock.model.update_generation_config.assert_called_once_with( + neuron_sampling_params) + + +def test_update_neuron_sampling_params_full_batch(): + os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" + model_runner = _create_neuron_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + max_num_seqs=2, + ) + assert not model_runner._on_device_sampling_disabled + + # Test sampling param updating only when TNx is framework + # NxDI handles sampling parameter updating inside model + if current_platform.use_transformers_neuronx(): + model_mock = MagicMock() + model_runner.model = model_mock + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id="test_0", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0.5, + top_k=1, + top_p=0.5), + block_tables={0: [1]}, + ), + SequenceGroupMetadata( + request_id="test_0", + is_prompt=True, + seq_data={1: SequenceData.from_seqs([4, 5, 6])}, + sampling_params=SamplingParams(temperature=0.2, + top_k=2, + top_p=0.2), + block_tables={1: [0]}, + ) + ] + + model_runner.prepare_model_input(seq_group_metadata_list) + + # Index neuron sampling parameters based on block_tables indices. + # The first block_id of the sequence 0 is 1, so its parameters are + # placed at index 1. So the sampling parameters will be: + # Index 0: sequence 1's sampling parameters + # Index 1: sequecne 0's sampling parameters. + neuron_sampling_params = ( + model_runner.model_config.neuron_sampling_params) + assert neuron_sampling_params.temperature == [0.2, 0.5] + assert neuron_sampling_params.top_k == [2, 1] + assert neuron_sampling_params.top_p == [0.2, 0.5] + model_mock.model.update_generation_config.assert_called_once_with( + neuron_sampling_params) diff --git a/vllm/config.py b/vllm/config.py index 9738d2fd0e0..658fac669ea 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2269,6 +2269,9 @@ class SpeculativeConfig: """Scaling factor for entropy-based threshold, applied when using `TypicalAcceptanceSampler`.""" + speculative_token_tree: Optional[str] = None + """Specifies the tree structure for speculative token generation. 
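+    On the Neuron backend this string is parsed with `ast.literal_eval` into
+    a `Dict[int, List[int]]` mapping a node index to the indices of its child
+    nodes, e.g. "{0: [1, 2], 1: [3], 2: [4]}" (illustrative value only).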
+ """ # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, init=True) # type: ignore @@ -2443,10 +2446,11 @@ def __post_init__(self): "Chunked prefill and EAGLE are not compatible " "when using V0.") + from vllm.platforms import current_platform from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) if isinstance(self.draft_model_config.hf_config, - EAGLEConfig): + EAGLEConfig) or current_platform.is_neuron(): pass else: eagle_config = EAGLEConfig( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4398852daac..90fa68142f4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -399,10 +399,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.scheduler, self.seq_counter, get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - get_tokenizer_for_seq, - ), + stop_checker=StopChecker(self.scheduler_config.max_model_len, + get_tokenizer_for_seq), )) self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index a7b313f4e50..e4a48483764 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -"""Utilities for selecting and loading neuron models.""" +"""Utilities for selecting and loading Neuron models in transformers-neuronx +framework.""" +import ast import copy import importlib import os @@ -9,7 +11,8 @@ import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -113,6 +116,67 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() +class NeuronSpeculationCausalLM(nn.Module): + """A Neuron-optimized causal language model with speculative decoding.""" + + SPECULATION_TERMINATION_ID = -1 + + def __init__(self, speculation_model) -> None: + super().__init__() + self.model = speculation_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + ) -> torch.Tensor: + tokens, counts = self.model.speculative_iteration( + input_ids, positions, input_block_ids) + + # Mark the end of accepted speculative tokens for each sequence with the + # speculation termination id. + batch_size, steps = tokens.shape + mask = torch.arange(steps).expand(batch_size, -1) >= counts + tokens[mask] = self.SPECULATION_TERMINATION_ID + + return tokens + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[List[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + # Organize input tensors by step instead of by sequence. 
+ accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == self.SPECULATION_TERMINATION_ID + for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -138,6 +202,7 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]: def _get_default_neuron_config(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig): + """Generate a neuron config based on vllm config args.""" from transformers_neuronx.config import ContinuousBatchingConfig from transformers_neuronx.constants import LAYOUT_BSH @@ -162,6 +227,27 @@ def _get_default_neuron_config(model_config: ModelConfig, return default_neuron_args +def _get_default_neuron_config_for_speculation( + model_config: ModelConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + """Generate a neuron config for speculative decoding based on + vllm config args.""" + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + + default_neuron_args = dict(collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + on_device_embedding=True, + continuous_batching=continuous_batching_config, + on_device_generation=copy.deepcopy( + model_config.neuron_sampling_params)) + return default_neuron_args + + def _get_neuron_on_device_generation_config(model_config: ModelConfig): if not _is_neuron_on_device_sampling_disabled(model_config): return copy.deepcopy(model_config.neuron_sampling_params) @@ -213,7 +299,7 @@ def _get_neuron_config_after_override(default_neuron_config, def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: - + """Initializes a neuron-optimized model for inference.""" # Create a model instance. model = NeuronCausalLM( model_config.hf_config, @@ -230,7 +316,6 @@ def get_neuron_model(model_config: ModelConfig, n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", [scheduler_config.max_model_len]) - # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, tp_degree=parallel_config.tensor_parallel_size, amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], @@ -240,3 +325,151 @@ def get_neuron_model(model_config: ModelConfig, batch_size=scheduler_config.max_num_seqs) return model.eval() + + +def get_neuron_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized speculation model for inference. 
+ + This method is only applicable for speculation with a standalone draft model + """ + from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder + + # For Eagle SD, we need to pass in additional parameters in neuron config. + is_eagle = getattr(speculation_config.draft_model_config.hf_config, + "is_eagle", False) + + # Create target model instance. + target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + if is_eagle: + default_neuron_config_args['is_eagle_target'] = True + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, + scheduler_config)) + if is_eagle: + default_draft_neuron_config_args['is_eagle_draft'] = True + default_draft_neuron_config_args['has_pre_attention_norm'] = False + + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. + draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + num_speculative_tokens = speculation_config.num_speculative_tokens + # Create speculation model instance. + speculation_model = FusedSpeculativeDecoder(draft_model.model, + target_model.model, + num_speculative_tokens) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) + + +def get_neuron_eagle_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized EAGLE speculation model for inference.""" + from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder + + # Create target model instance. 
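+    # The target and draft models are compiled with the same bucket sizes and
+    # batch size, then combined into a single FusedSpeculativeDecoder below.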
+ target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + default_neuron_config_args['is_eagle_target'] = True + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, + scheduler_config)) + default_draft_neuron_config_args['is_eagle_draft'] = True + default_draft_neuron_config_args['has_pre_attention_norm'] = False + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. + draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + token_tree: Dict[int, List[int]] = ast.literal_eval( + speculation_config.speculative_token_tree) + + speculation_model = EagleSpeculativeDecoder(draft_model.model, + target_model.model, + token_tree=token_tree) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py new file mode 100644 index 00000000000..f879c99ac2e --- /dev/null +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -0,0 +1,584 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for selecting and loading Neuron models in +neuronx-distributed-inference framework.""" +# Disabling yapf because yapf and isort have conflicts for the below imports +# yapf: disable +import copy +import hashlib +import importlib +import multiprocessing +import os +import shutil +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from neuronx_distributed_inference.models.config import ( + FusedSpecNeuronConfig, OnDeviceSamplingConfig) +from neuronx_distributed_inference.models.mllama.utils import ( + create_vision_mask) +from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config) +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig + +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from 
vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput) + +# yapf: enable +logger = init_logger(__name__) + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "float32", + "half": "float16", + "float16": "float16", + "bfloat16": "bfloat16", + "float": "float32", + "float32": "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", +} + +# Models supported by Neuronx distributed for inference. +_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = { + "LlamaForCausalLM": + ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "DbrxForCausalLM": + ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", + "NeuronDbrxForCausalLM"), + "MixtralForCausalLM": + ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", + "NeuronMixtralForCausalLM"), + "MllamaForConditionalGeneration": + ("neuronx_distributed_inference.models.mllama.modeling_mllama", + "NeuronMllamaForCausalLM"), +} + + +class NeuronCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + ) -> torch.Tensor: + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=input_block_ids, + sampling_params=sampling_params) + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + return output.hidden_states + else: + return output.logits[:, -1, :] + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + batch_size = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" + # Organize input tensors by step instead of by sequence. 
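+            # With on-device sampling, `logits` already holds the sampled
+            # token id for each sequence, so flatten it into a 1-D list.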
+ accepted_token_ids_by_step = logits.flatten() + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + step_output_token_ids = [] + for i, seq_id in enumerate(seq_ids): + token_id = accepted_token_ids_by_step[i] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + return SamplerOutput(outputs=step_output_token_ids) + else: + return self.sampler(logits, sampling_metadata) + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + self.config.neuron_config = neuron_config + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + + +class NeuronMllamaForCausalLM(nn.Module): + + def __init__(self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor( + config.get_text_config().vocab_size, logits_as_input=True) + + self.on_device_sampling_disabled = on_device_sampling_disabled + if self.on_device_sampling_disabled: + # Use default sampler + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + seq_ids: torch.Tensor, pixel_values: torch.Tensor, + aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, + has_image: torch.Tensor, sampling_params) -> torch.Tensor: + self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) + output = self.model( + input_ids.to(torch.int32), + attention_mask=None, + position_ids=positions.to(torch.int32), + seq_ids=seq_ids.flatten().to(torch.int32), + 
pixel_values=pixel_values.to( + self.config.vision_config.torch_dtype), + aspect_ratios=aspect_ratios.to(torch.int32), + vision_mask=self.vision_mask.to(torch.int32), + sampling_params=sampling_params, + num_chunks=num_chunks.to(torch.int32), + has_image=has_image.to(torch.int32), + ) + if self.config.neuron_config.on_device_sampling_config: + return output.hidden_states + return output.logits[:, -1, :] + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample(self, hidden_states, sampling_metadata): + if not self.on_device_sampling_disabled: + with torch.profiler.record_function("sample"): + hidden_states = hidden_states.flatten() + res = [] + sample_idx = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + samples = [] + for seq_id in seq_ids: + token_id = hidden_states[sample_idx].item() + samples.append( + SequenceOutput( + parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)})) + sample_idx += 1 + res.append( + CompletionSequenceGroupOutput(samples=samples, + prompt_logprobs=None)) + next_tokens = SamplerOutput(outputs=res) + else: + next_tokens = self.sampler(None, hidden_states, sampling_metadata) + return next_tokens + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + self.config.neuron_config = neuron_config + logger.info("neuron_config buckets: %s", + self.config.neuron_config.buckets) + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + try: + self.model = neuronx_model_cls(compiled_model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.vision_token_id = tokenizer( + "<|image|>", add_special_tokens=False).input_ids + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError): + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + + logger.info("\nCompiling and saving model to %s", model_name_or_path) + + p = multiprocessing.Process(target=compile_model, + args=(self, compiled_model_path)) + p.start() + p.join() + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(compiled_model_path) + logger.info("Successfully 
compiled and saved the model in %s", + compiled_model_path) + + # Read "<|image|>" token_id from the tokenizer + self.vision_token_id = tokenizer("<|image|>", + add_special_tokens=False).input_ids + logger.info("\nLoading model from compiled checkpoint...") + self.model.load(compiled_model_path) + + +def compile_model(neuron_model, traced_model_path): + neuron_model.model.compile(traced_model_path) + + +class NeuronSpeculationCausalLM(nn.Module): + """A Neuron-optimized causal language model with speculative decoding.""" + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + ) -> torch.Tensor: + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=input_block_ids, + sampling_params=sampling_params) + # CTX encoding + if (positions[:, 0]).sum().item() == 0: + return output.fused_outputs[0][:, 0:1] + + # Fused Spec (Generation) + accepted_tokens_with_padding = output.fused_outputs[0] + next_pos_ids = output.fused_outputs[-1] + generated_token_counts = next_pos_ids - positions + + assert torch.any(generated_token_counts == 0).item() is False, \ + "NxDI model generated no output for one or more sequences." + + batch_size, steps = accepted_tokens_with_padding.shape + mask = torch.arange(steps).expand(batch_size, + -1) >= generated_token_counts + accepted_tokens_with_padding[mask] = -1 + + return accepted_tokens_with_padding + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[List[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] + # Organize input tensors by step instead of by sequence. 
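+        # [batch_size, num_steps] -> [num_steps, batch_size]; positions that
+        # forward() marked with -1 are padding past the accepted tokens.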
+ accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + def load_weights(self, model_name_or_path: str, + draft_model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) + config = neuronx_model_cls.get_config_cls()( + neuron_config, + load_config=load_pretrained_config(model_name_or_path)) + + draft_neuron_config = copy.deepcopy(config.neuron_config) + if not config.neuron_config.enable_eagle_speculation: + draft_neuron_config.speculation_length = 0 + draft_neuron_config.trace_tokengen_model = True + draft_neuron_config.enable_fused_speculation = False + if config.neuron_config.enable_eagle_speculation: + draft_neuron_config.is_eagle_draft = True + draft_neuron_config.sequence_parallel_enabled = False + draft_config = neuronx_model_cls.get_config_cls()( + draft_neuron_config, + load_config=load_pretrained_config(draft_model_name_or_path)) + fused_spec_config = (FusedSpecNeuronConfig( + neuronx_model_cls._model_cls, + draft_config=draft_config, + draft_model_path=draft_model_name_or_path)) + config.fused_spec_config = fused_spec_config + self.config.neuron_config = neuron_config + + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s Recompiling...", + compiled_model_path) + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + if not os.path.exists(draft_model_name_or_path): + if draft_model_name_or_path != model_name_or_path: + hf_model = 
AutoModelForCausalLM.from_pretrained( + draft_model_name_or_path) + saved_path = os.path.join("local-models", + draft_model_name_or_path) + hf_model.save_pretrained(saved_path) + draft_model_name_or_path = saved_path + else: + draft_model_name_or_path = model_name_or_path + config.fused_spec_config.draft_model_path = draft_model_name_or_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _NEURON_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on Neuron " + f"for now. Supported architectures: " + f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + """Generate a neuron config based on vllm config args.""" + on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, + deterministic=False) + batch_size = scheduler_config.max_num_seqs + + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + ctx_batch_size=1, + batch_size=batch_size, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + enable_bucketing=True, + is_continuous_batching=(batch_size > 1), + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + padding_side="right", + on_device_sampling_config=on_device_sampling_config, + sequence_parallel_enabled=True, + ) + return neuron_config + + +def _get_default_speculation_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Generate a neuron config for speculative decoding based on vllm config + args.""" + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + batch_size=scheduler_config.max_num_seqs, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + speculation_length=speculation_config.num_speculative_tokens, + trace_tokengen_model=False, + enable_fused_speculation=True, + enable_bucketing=True, + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + on_device_sampling_config=dict( + top_k=1, + do_sample=False, + )) + return neuron_config + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + """Update default neuron config values with override args""" + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return default_neuron_config + + +def get_neuron_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + """Initializes a neuron-optimized model for inference.""" + model_arch = _get_model_architecture(model_config.hf_config) + if model_arch == "MllamaForConditionalGeneration": + model = NeuronMllamaForCausalLM(model_config.hf_config) + else: + model = NeuronCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config + 
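+    # load_weights() loads precompiled artifacts when NEURON_COMPILED_ARTIFACTS
+    # points at them; otherwise the model is compiled and the artifacts are
+    # stored under a directory named after the hashed neuron config.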
model.load_weights(model_config.model, + neuron_config=neuron_config, + override_neuron_config=override_neuron_config) + return model.eval() + + +def get_neuron_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized speculation model for inference. + + This model handles speculation using both a draft model and an EAGLE draft. + """ + model = NeuronSpeculationCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_speculation_config( + model_config, parallel_config, scheduler_config, speculation_config) + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config + model.load_weights(model_config.model, + speculation_config.draft_model_config.model, + neuron_config=neuron_config, + override_neuron_config=override_neuron_config) + return model.eval() diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 0ed22104317..b1df4fd1339 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -176,17 +176,26 @@ def cpu_platform_plugin() -> Optional[str]: def neuron_platform_plugin() -> Optional[str]: - is_neuron = False + tnx_installed = False + nxd_installed = False logger.debug("Checking if Neuron platform is available.") try: import transformers_neuronx # noqa: F401 - is_neuron = True + tnx_installed = True logger.debug("Confirmed Neuron platform is available because" " transformers_neuronx is found.") - except ImportError as e: - logger.debug("Neuron platform is not available because: %s", str(e)) + except ImportError: pass + try: + import neuronx_distributed_inference # noqa: F401 + nxd_installed = True + logger.debug("Confirmed Neuron platform is available because" + " neuronx_distributed_inference is found.") + except ImportError: + pass + + is_neuron = tnx_installed or nxd_installed return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index e37a3a578cf..71f7c718cdf 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 - +import enum +import os +from functools import lru_cache from typing import TYPE_CHECKING, Optional from vllm import envs @@ -15,6 +17,11 @@ logger = init_logger(__name__) +class NeuronFramework(enum.Enum): + TRANSFORMERS_NEURONX = "transformers-neuronx" + NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" + + class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON device_name: str = "neuron" @@ -43,8 +50,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: assert (vllm_config.lora_config is None), "LoRA is not supported for Neuron backend." - assert (not vllm_config.speculative_config - ), "Speculative decoding not yet supported for Neuron backend." 
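+        # Neuron uses one KV-cache block per sequence, so block_size is
+        # expected to match max_model_len (see the cache config handling
+        # below).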
         cache_config = vllm_config.cache_config
         if cache_config:
@@ -67,3 +72,71 @@ def get_device_communicator_cls(cls) -> str:
     @classmethod
     def use_all_gather(cls) -> bool:
         return True
+
+    @classmethod
+    @lru_cache
+    def is_neuronx_distributed_inference(cls) -> bool:
+        try:
+            import neuronx_distributed_inference
+        except ImportError:
+            neuronx_distributed_inference = None
+        return neuronx_distributed_inference is not None
+
+    @classmethod
+    @lru_cache
+    def is_transformers_neuronx(cls) -> bool:
+        try:
+            import transformers_neuronx
+        except ImportError:
+            transformers_neuronx = None
+        return transformers_neuronx is not None
+
+    def get_neuron_framework_to_use(self):
+        """Return the specified framework if the corresponding installation is
+        available.
+
+        If no framework is specified, use neuronx-distributed-inference by
+        default. If that is unavailable, fall back to transformers-neuronx.
+        """
+        if not self.is_neuron():
+            raise AssertionError(
+                f"Neuron Framework unavailable for platform: {self}")
+
+        tnx_installed = self.is_transformers_neuronx()
+        nxd_installed = self.is_neuronx_distributed_inference()
+
+        specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK")
+        tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value
+        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value
+        if specified_framework == tnx_framework and tnx_installed:
+            return NeuronFramework.TRANSFORMERS_NEURONX
+
+        if ((specified_framework == nxd_framework and nxd_installed)
+                or (specified_framework is None and nxd_installed)):
+            return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
+
+        if specified_framework is None and tnx_installed:
+            return NeuronFramework.TRANSFORMERS_NEURONX
+
+        return None
+
+    def use_neuronx_distributed(self):
+        """
+        Return True if the framework determined in
+        get_neuron_framework_to_use() is
+        NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This
+        is used to select the Neuron model framework and framework-specific
+        configuration to apply during model compilation.
+        """
+        nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
+        return self.get_neuron_framework_to_use() == nxd_framework
+
+    def use_transformers_neuronx(self):
+        """
+        Return True if the framework determined in
+        get_neuron_framework_to_use() is NeuronFramework.TRANSFORMERS_NEURONX,
+        False otherwise. This is used to select the Neuron model framework and
+        framework-specific configuration to apply during model compilation.
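+        Framework selection can be forced through the VLLM_NEURON_FRAMEWORK
+        environment variable ("transformers-neuronx" or
+        "neuronx-distributed-inference").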
+ """ + return self.get_neuron_framework_to_use( + ) == NeuronFramework.TRANSFORMERS_NEURONX diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py new file mode 100644 index 00000000000..9618a4b49ff --- /dev/null +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: Apache-2.0 + +from importlib.util import find_spec +from typing import List, Optional + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import IntermediateTensors +from vllm.worker.neuron_model_runner import (ModelInputForNeuron, + NeuronModelRunner) + + +class MultiStepNeuronModelRunner(NeuronModelRunner): + """A model runner for multi step decoding using the transformers_neuronx + framework""" + + def __init__( + self, + vllm_config: VllmConfig, + ): + super().__init__(vllm_config) + self.speculation_config = self.speculative_config + from transformers_neuronx.config import GenerationConfig + self.speculation_config.draft_model_config.neuron_sampling_params = ( + GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K + )) + + def load_model(self) -> None: + if find_spec("transformers_neuronx") is not None: + from vllm.model_executor.model_loader.neuron import ( + get_neuron_eagle_speculation_model, + get_neuron_speculation_model) + if self.speculation_config.speculative_token_tree is not None: + self.model = get_neuron_eagle_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config) + else: + self.model = get_neuron_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config) + else: + raise NotImplementedError( + "Supports only Transformer-NeuronX based models.") + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return output diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py new file mode 100644 index 00000000000..b6a3492a493 --- /dev/null +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import List, Optional + +import torch + +from vllm.config import VllmConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalKwargs +from vllm.sequence import IntermediateTensors +from vllm.worker.neuronx_distributed_model_runner 
import ( + NeuronxDistributedModelRunner) + + +class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner): + """A model runner for multi-step decoding using the + neuronx-distributed-inference framework""" + + def __init__( + self, + vllm_config: VllmConfig, + ): + super().__init__(vllm_config) + + def load_model(self) -> None: + from vllm.model_executor.model_loader.neuronx_distributed import ( + get_neuron_speculation_model) + self.model = get_neuron_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculative_config) + + @torch.inference_mode() + def execute_model( + self, + model_input, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + sampling_params = torch.tensor([[ + seq_group.sampling_params.top_k, + seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature, + ] for seq_group in model_input.sampling_metadata.seq_groups]) + + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + sampling_params=sampling_params, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return output diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index e046ebc449d..c80b69e78dc 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -2,20 +2,20 @@ import os from dataclasses import dataclass -from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from torch import nn -from transformers_neuronx.config import GenerationConfig -from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context +from vllm.config import DeviceConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model -from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalKwargs) +from vllm.platforms import current_platform +from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -34,12 +34,18 @@ class ModelInputForNeuron(ModelRunnerInputBase): input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None input_block_ids: Optional[torch.Tensor] = None - sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None + sampling_metadata: SamplingMetadata = None + multi_modal_kwargs: BatchedTensorInputs = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + return { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "input_block_ids": self.input_block_ids, + "sampling_metadata": self.sampling_metadata, + 
"multi_modal_kwargs": self.multi_modal_kwargs, + } @classmethod def from_broadcasted_tensor_dict( @@ -47,11 +53,17 @@ def from_broadcasted_tensor_dict( tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> "ModelInputForNeuron": - assert attn_backend is None - return cls.from_broadcasted_tensor_dict(tensor_dict) + return ModelInputForNeuron( + input_tokens=tensor_dict["input_tokens"], + input_positions=tensor_dict["input_positions"], + input_block_ids=tensor_dict["input_block_ids"], + sampling_metadata=tensor_dict["sampling_metadata"], + multi_modal_kwargs=tensor_dict["multi_modal_kwargs"], + ) class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): + """A model runner for AWS Neuron hardware""" # NEURON has an upper limit on the top_k _MAX_NEURON_SAMPLING_TOP_K = 256 @@ -61,13 +73,20 @@ def __init__( vllm_config: VllmConfig, ): ModelRunnerBase.__init__(self, vllm_config) - model_config = self.model_config - if model_config is not None and model_config.get_sliding_window(): + + if (self.model_config is not None + and self.model_config.get_sliding_window()): logger.warning("Sliding window is not supported on Neuron. " "The model will run without sliding window.") + self.device_config = (self.device_config if self.device_config + is not None else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) + # Lazy initialization. self.model: nn.Module # initialize after load_model. @@ -82,32 +101,33 @@ def __init__( self._previous_batch_request_ids: List[str] = [] if not self._on_device_sampling_disabled: - logger.warning( - "On-device sampling is turned on in Neuron by default, only " - "top_k, top_p, and temperature are current supported sampling " - "parameters. To turn off the on-device sampling, please set " - "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." - ) - self.model_config.neuron_sampling_params = GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p=[1.0] * self.scheduler_config.max_num_seqs, - temperature=[1.0] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) + self._init_neuron_sampling() - def load_model(self) -> None: - if find_spec("transformers_neuronx") is not None: - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + def _init_neuron_sampling(self) -> None: + if current_platform.use_transformers_neuronx(): + from transformers_neuronx.config import GenerationConfig else: - raise NotImplementedError( - "Supports only Transformer-NeuronX based models.") + from transformers import GenerationConfig + logger.warning( + "On-device sampling is turned on in Neuron by default, only " + "top_k, top_p, and temperature are current supported sampling " + "parameters. 
To turn off the on-device sampling, please set " + "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.") + self.model_config.neuron_sampling_params = GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) + + def load_model(self) -> None: + self.model = get_neuron_model(self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def get_model(self) -> nn.Module: return self.model @@ -240,6 +260,16 @@ def prepare_model_input( (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) seq_lens = None + + if not self._on_device_sampling_disabled: + for seq_group_metadata in seq_group_metadata_list: + sampling_params = seq_group_metadata.sampling_params + top_k, top_p, temperature = ( + self._convert_to_neuron_sampling_params(sampling_params)) + sampling_params.top_k = top_k + sampling_params.top_p = top_p + sampling_params.temperature = temperature + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, @@ -251,7 +281,8 @@ def prepare_model_input( self.pin_memory, generators=self.get_generators(finished_requests_ids)) - if not self._on_device_sampling_disabled: + if current_platform.use_transformers_neuronx( + ) and not self._on_device_sampling_disabled: # Once the request IDs are changed in current iteration, we will # update the on-device sampling parameters. current_batch_request_ids = [ @@ -259,7 +290,7 @@ def prepare_model_input( for seq_group_meta_data in seq_group_metadata_list ] if current_batch_request_ids != self._previous_batch_request_ids: - self._update_neuron_sampling_params(sampling_metadata) + self._update_neuron_sampling_params(seq_group_metadata_list) self._previous_batch_request_ids = current_batch_request_ids return ModelInputForNeuron(input_tokens=input_tokens, @@ -268,31 +299,59 @@ def prepare_model_input( sampling_metadata=sampling_metadata, multi_modal_kwargs=multi_modal_kwargs) - def _update_neuron_sampling_params(self, - sampling_metadata: SamplingMetadata): + def _update_neuron_sampling_params( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): # Update Neuron sampling parameters (GenerationConfig in Neuron) current_sampling_params = self.model_config.neuron_sampling_params assert current_sampling_params is not None, ( f"Failed to update sampling_params, " f"current sampling params is {current_sampling_params}") + is_update_needed = False + top_k = current_sampling_params.top_k top_p = current_sampling_params.top_p temperature = current_sampling_params.temperature - for index, sequence_group_to_sample in enumerate( - sampling_metadata.seq_groups): - top_k[index] = self._convert_to_neuron_top_k( - sequence_group_to_sample.sampling_params.top_k) - top_p[index] = sequence_group_to_sample.sampling_params.top_p - temperature[index] = \ - sequence_group_to_sample.sampling_params.temperature - self.model.model.update_generation_config(current_sampling_params) + # The index of a sequence's sampling parameters in neuron is equal to + # its index in `input_block_ids`. 
+ for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + + seq_group_top_k = sampling_params.top_k + seq_group_top_p = sampling_params.top_p + seq_group_temperature = sampling_params.temperature - def _convert_to_neuron_top_k(self, top_k: int) -> int: + for seq_id in seq_ids: + index = seq_group_metadata.block_tables[seq_id][0] + if (top_k[index] != seq_group_top_k + or top_p[index] != seq_group_top_p + or temperature[index] != seq_group_temperature): + is_update_needed = True + + top_k[index] = seq_group_top_k + top_p[index] = seq_group_top_p + temperature[index] = seq_group_temperature + + # update_generation_config is only available in transformers-neuronx + if is_update_needed and current_platform.use_transformers_neuronx(): + self.model.model.update_generation_config(current_sampling_params) + + def _convert_to_neuron_sampling_params( + self, sampling_params: SamplingParams) -> Tuple[int, float, float]: + # Returns the top_k, top_p and temperature parameters for neuron. + top_k = sampling_params.top_k + top_p = sampling_params.top_p + temperature = sampling_params.temperature + + if temperature == 0.0: + # Enable greedy sampling on zero temperature + return (1, 1.0, 1.0) if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: - return self._MAX_NEURON_SAMPLING_TOP_K - return top_k + top_k = self._MAX_NEURON_SAMPLING_TOP_K + + return (top_k, top_p, temperature) @torch.inference_mode() def execute_model( @@ -306,7 +365,26 @@ def execute_model( raise ValueError( "NeuronModelRunner does not support multi-step execution.") - with set_forward_context(None, self.vllm_config, 0): + # extract top_k, top_p and temperature from model_input for neuron + # forward call + sampling_params = (torch.tensor([[ + seq_group.sampling_params.top_k, seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature + ] for seq_group in model_input.sampling_metadata.seq_groups])) + + if current_platform.use_neuronx_distributed(): + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + sampling_params=sampling_params, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device), + ) + elif current_platform.use_transformers_neuronx(): + # [TODO] validate on-device sampling + # The model signature may need change for on-device sampling hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index df651e05a7b..aa8e39613ee 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,61 +1,81 @@ # SPDX-License-Identifier: Apache-2.0 """A Neuron worker class.""" +import os from typing import List, Optional, Tuple -import torch import torch.distributed from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.platforms.neuron import NeuronFramework from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoRANotSupportedWorkerBase, WorkerBase, 
WorkerInput) +logger = init_logger(__name__) + class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = True, - ) -> None: + model_runner: NeuronModelRunner + + def __init__(self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner: NeuronModelRunner = NeuronModelRunner( - vllm_config=vllm_config) - self.is_driver_worker = is_driver_worker - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[List[SamplerOutput]]: - assert execute_model_req is not None - assert (not execute_model_req.blocks_to_swap_in - and not execute_model_req.blocks_to_swap_out - and not execute_model_req.blocks_to_copy), ( - "Cache operations are not supported for Neuron backend.") - assert execute_model_req.num_lookahead_slots == 0, ( - "lookahead not supported for Neuron backend.") - output = LocalOrDistributedWorkerBase.execute_model( - self, execute_model_req) - return output + neuron_framework = current_platform.get_neuron_framework_to_use() + if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: + self.model_runner = self.get_tnx_model_runner(vllm_config) + elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: + self.model_runner = self.get_neuronx_distributed_model_runner( + vllm_config) + else: + raise NotImplementedError( + "Specified framework" + + f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + + " is either not installed or not supported." + + " Supported frameworks: " + + "[transformers-neuronx, neuronx-distributed-inference]") + + def get_tnx_model_runner(self, vllm_config): + from vllm.worker.multi_step_neuron_model_runner import ( + MultiStepNeuronModelRunner) + if self.speculative_config is not None: + return MultiStepNeuronModelRunner(vllm_config=vllm_config) + else: + return NeuronModelRunner(vllm_config=vllm_config) + + def get_neuronx_distributed_model_runner(self, vllm_config): + from vllm.worker.multi_step_neuronx_distributed_model_runner import ( + MultiStepNeuronxDistributedModelRunner) + from vllm.worker.neuronx_distributed_model_runner import ( + NeuronxDistributedModelRunner) + if self.speculative_config is not None: + return MultiStepNeuronxDistributedModelRunner( + vllm_config=vllm_config) + else: + return NeuronxDistributedModelRunner(vllm_config=vllm_config) def init_device(self) -> None: self.init_distributed_environment() @@ -121,17 +141,17 @@ def get_cache_block_size_bytes(self) -> int: def init_distributed_environment(self): """Neuron uses transformers-neuronx for tensor parallelism. - It has only one process to control multiple devices. - vLLM still needs the environment initialized when TP/PP > 1, - so we initialize a distributed environment with one process. 
+ + vLLM still needs the environment initialized when TP/PP > 1 """ init_distributed_environment( world_size=1, - rank=0, - local_rank=0, + rank=self.rank, + local_rank=self.local_rank, distributed_init_method=self.distributed_init_method, backend="gloo", ) + ensure_model_parallel_initialized( 1, 1, diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py new file mode 100644 index 00000000000..4e784e5e030 --- /dev/null +++ b/vllm/worker/neuronx_distributed_model_runner.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional + +import torch +from neuronx_distributed_inference.modules.generation.sampling import ( + prepare_sampling_params) + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.neuronx_distributed import ( + _get_model_architecture, get_neuron_model) +from vllm.sequence import IntermediateTensors +from vllm.worker.neuron_model_runner import (ModelInputForNeuron, + NeuronModelRunner) + +logger = init_logger(__name__) + + +class NeuronxDistributedModelRunner(NeuronModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + ): + super().__init__(vllm_config) + + def load_model(self) -> None: + self.model = get_neuron_model(self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def get_nxd_sampling_params(self, sampling_metadata): + if self.model.config.neuron_config.on_device_sampling_config: + max_topk = (self.model.config.neuron_config. + on_device_sampling_config.global_topk) + else: + max_topk = self.model.config.vocab_size + + top_k = [1] * self.scheduler_config.max_num_seqs + top_p = [1.0] * self.scheduler_config.max_num_seqs + temperature = [1.0] * self.scheduler_config.max_num_seqs + + for index, sequenceGroupToSample in enumerate( + sampling_metadata.seq_groups): + top_k[index] = (sequenceGroupToSample.sampling_params.top_k + if sequenceGroupToSample.sampling_params.top_k > 0 + else max_topk) + top_p[index] = sequenceGroupToSample.sampling_params.top_p + temperature[index] = ( + sequenceGroupToSample.sampling_params.temperature) + + sampling_params = prepare_sampling_params( + batch_size=self.scheduler_config.max_num_seqs, + top_k=top_k, + top_p=top_p, + temperature=temperature) + return sampling_params + + def get_multi_modal_data_neuron(self, input_images): + raise NotImplementedError("need to restore multi-modal support") + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "NeuronModelRunner does not support multi-step execution.") + + if _get_model_architecture( + self.model.config) != "MllamaForConditionalGeneration": + return super().execute_model(model_input, kv_caches, + intermediate_tensors, num_steps) + + sampling_params = self.get_nxd_sampling_params( + model_input.sampling_metadata) + + if model_input.multi_modal_kwargs.get('image') is not None: + pixel_values = [] + aspect_ratios = [] + num_chunks = [] + has_image = [] + for multi_modal_input in model_input.multi_modal_kwargs.get( + 'image'): + image_tensors = self.get_multi_modal_data_neuron( + multi_modal_input.squeeze(0)) + pixel_values.append(image_tensors[0]) + 
aspect_ratios.append(image_tensors[1]) + num_chunks.append(image_tensors[2]) + has_image.append(image_tensors[3]) + + pixel_values = torch.cat(pixel_values, dim=0) + aspect_ratios = torch.cat(aspect_ratios, dim=0) + num_chunks = torch.cat(num_chunks, dim=0) + has_image = torch.cat(has_image, dim=0) + + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + seq_ids=model_input.input_block_ids, + pixel_values=pixel_values, + aspect_ratios=aspect_ratios, + sampling_params=sampling_params, + num_chunks=num_chunks, + has_image=has_image, + ) + else: + empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], + dtype=torch.bfloat16) + empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) + num_chunks = torch.tensor([[1] + ]) # dummy num_chunks, will not be used + has_image = torch.tensor([0]) + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + seq_ids=model_input.input_block_ids, + pixel_values=empty_pixel_values, + aspect_ratios=empty_aspect_ratios, + sampling_params=sampling_params, + num_chunks=num_chunks, + has_image=has_image, + ) + + output = self.model.sample( + hidden_states=hidden_states, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output]
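The conversion helper added to NeuronModelRunner above is the piece most callers will feel: with on-device sampling enabled, a zero temperature is rewritten to greedy decoding and top_k is clamped to Neuron's limit of 256. Below is a minimal, standalone sketch of that mapping for reference; the function name and the __main__ checks are illustrative only and are not part of this patch.

from typing import Tuple

# Upper limit the Neuron on-device sampler accepts for top_k (see
# _MAX_NEURON_SAMPLING_TOP_K in NeuronModelRunner above).
MAX_NEURON_SAMPLING_TOP_K = 256


def convert_to_neuron_sampling_params(
        top_k: int, top_p: float,
        temperature: float) -> Tuple[int, float, float]:
    """Mirror of the conversion applied before on-device sampling."""
    if temperature == 0.0:
        # Zero temperature means greedy decoding: force top_k=1 and neutral
        # top_p/temperature so the device sampler picks the argmax token.
        return 1, 1.0, 1.0
    if top_k < 0 or top_k > MAX_NEURON_SAMPLING_TOP_K:
        # top_k == -1 ("unrestricted") and oversized values are clamped to
        # the device limit instead of being rejected.
        top_k = MAX_NEURON_SAMPLING_TOP_K
    return top_k, top_p, temperature


if __name__ == "__main__":
    assert convert_to_neuron_sampling_params(-1, 0.9, 0.0) == (1, 1.0, 1.0)
    assert convert_to_neuron_sampling_params(-1, 0.9, 0.7) == (256, 0.9, 0.7)
    assert convert_to_neuron_sampling_params(8, 0.5, 1.0) == (8, 0.5, 1.0)

Clamping rather than rejecting keeps requests that ask for an unrestricted top_k runnable on Neuron, at the cost of silently narrowing the candidate set to 256 tokens.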