From 6531f921deea16355970a2f12164c504ec7f6c07 Mon Sep 17 00:00:00 2001 From: Aaron Dou Date: Fri, 14 Feb 2025 22:25:51 +0000 Subject: [PATCH 01/38] [NxDI upstream foundation] set up NxDI model runner Signed-off-by: Satyajith Chilappagari --- .../model_loader/neuronx_distributed.py | 498 ++++++++++++++++++ vllm/worker/neuron_model_runner.py | 189 ++++--- vllm/worker/neuron_worker.py | 120 ++++- .../neuronx_distributed_model_runner.py | 137 +++++ 4 files changed, 857 insertions(+), 87 deletions(-) create mode 100644 vllm/model_executor/model_loader/neuronx_distributed.py create mode 100644 vllm/worker/neuronx_distributed_model_runner.py diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py new file mode 100644 index 00000000000..6c8b3aa0610 --- /dev/null +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -0,0 +1,498 @@ +import copy +import hashlib +import importlib +import multiprocessing +import os +import shutil +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, PretrainedConfig, AutoTokenizer + +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput) + +from neuronx_distributed_inference.models.mllama.utils import create_vision_mask +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from neuronx_distributed_inference.models.config import FusedSpecNeuronConfig, OnDeviceSamplingConfig + +logger = init_logger(__name__) + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "float32", + "half": "float16", + "float16": "float16", + "bfloat16": "bfloat16", + "float": "float32", + "float32": "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", +} + + +# Models supported by Neuronx distributed for inference. 
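+# Maps a Hugging Face architecture name to the NxDI module path and model
+# class used to load it on Neuron.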
+_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = { + "LlamaForCausalLM": ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "DbrxForCausalLM": ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", + "NeuronDbrxForCausalLM"), + "MixtralForCausalLM": ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", + "NeuronMixtralForCausalLM"), + "MllamaForConditionalGeneration": ("neuronx_distributed_inference.models.mllama.modeling_mllama", + "NeuronMllamaForCausalLM"), +} + + +class NeuronCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + ) -> torch.Tensor: + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=input_block_ids, + sampling_params=sampling_params) + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + return output.hidden_states + else: + return output.logits[:, -1, :] + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + # on-device sampling + if self.config.neuron_config.on_device_sampling_config: + batch_size = logits.shape + seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] + assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" + # Organize input tensors by step instead of by sequence. 
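+            # With on-device sampling enabled, forward() returns the sampled
+            # token ids directly (output.hidden_states), so `logits` here is
+            # unpacked as one token id per sequence rather than re-sampled on host.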
+ accepted_token_ids_by_step = logits.flatten() + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + step_output_token_ids = [] + for i, seq_id in enumerate(seq_ids): + token_id = accepted_token_ids_by_step[i] + step_output_token_ids.append(CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=seq_id, output_token=token_id, logprobs={token_id: Logprob(token_id)})], prompt_logprobs=None)) + return SamplerOutput(outputs=step_output_token_ids) + else: + return self.sampler(logits, sampling_metadata) + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()(**kwargs['neuron_config']) + self.config.neuron_config = neuron_config + config = neuronx_model_cls.get_config_cls()( + neuron_config, load_config=load_pretrained_config(model_name_or_path) + ) + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", model_name_or_path, + f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning(f"Exception: {e}") + logger.warning(f"Failed to load the model from {compiled_model_path}, Recompiling...") + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + +class NeuronMllamaForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.get_text_config().vocab_size, + logits_as_input=True) + + self.on_device_sampling_disabled = on_device_sampling_disabled + if self.on_device_sampling_disabled: + # Use default sampler + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + seq_ids: torch.Tensor, + pixel_values: torch.Tensor, + aspect_ratios: torch.Tensor, + num_chunks: torch.Tensor, + has_image: torch.Tensor, + sampling_params + ) -> torch.Tensor: + self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) + output = self.model(input_ids.to(torch.int32), + attention_mask=None, + position_ids=positions.to(torch.int32), + seq_ids= 
seq_ids.flatten().to(torch.int32), + pixel_values=pixel_values.to(self.config.vision_config.torch_dtype), + aspect_ratios=aspect_ratios.to(torch.int32), + vision_mask=self.vision_mask.to(torch.int32), + sampling_params=sampling_params, + num_chunks=num_chunks.to(torch.int32), + has_image=has_image.to(torch.int32), + ) + if self.config.neuron_config.on_device_sampling_config: + return output.hidden_states + return output.logits[:, -1, :] + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample(self, hidden_states, sampling_metadata): + if not self.on_device_sampling_disabled: + with torch.profiler.record_function("sample"): + hidden_states = hidden_states.flatten() + res = [] + sample_idx = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + samples = [] + for seq_id in seq_ids: + token_id = hidden_states[sample_idx].item() + samples.append(SequenceOutput(parent_seq_id=seq_id, output_token=token_id, + logprobs={token_id: Logprob(token_id)})) + sample_idx += 1 + res.append(CompletionSequenceGroupOutput(samples=samples, prompt_logprobs=None)) + next_tokens = SamplerOutput(outputs=res) + else: + next_tokens = self.sampler(None, hidden_states, sampling_metadata) + return next_tokens + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()(**kwargs['neuron_config']) + self.config.neuron_config = neuron_config + logger.info(f"neuron_config buckets: {self.config.neuron_config.buckets}") + config = neuronx_model_cls.get_config_cls()( + neuron_config, load_config=load_pretrained_config(model_name_or_path) + ) + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + else: + compiled_model_path = os.path.join("local-models", model_name_or_path, + f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + try: + self.model = neuronx_model_cls(compiled_model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.vision_token_id = tokenizer("<|image|>", add_special_tokens=False).input_ids + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError): + logger.warning(f"Failed to load the model from {compiled_model_path}, Recompiling...") + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + self.model = neuronx_model_cls(model_name_or_path, config) + + logger.info(f"\nCompiling and saving model to {model_name_or_path}...") + p = multiprocessing.Process(target=compile_model, args=(self, compiled_model_path)) + p.start() + p.join() + + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + tokenizer.save_pretrained(compiled_model_path) + 
logger.info(f"successfully compiled and saved the model in {compiled_model_path}") + + # Read "<|image|>" token_id from the tokenizer + self.vision_token_id = tokenizer("<|image|>", add_special_tokens=False).input_ids + logger.info("\nLoading model from compiled checkpoint...") + self.model.load(compiled_model_path) + + +def compile_model(neuron_model, traced_model_path): + neuron_model.model.compile(traced_model_path) + + +class NeuronSpeculationCausalLM(nn.Module): + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, + ) -> torch.Tensor: + output = self.model(input_ids, + attention_mask=None, + position_ids=positions, + seq_ids=input_block_ids, + sampling_params=sampling_params) + if output.fused_outputs[1].shape[-1] == 1: + # CTX encoding + return output.fused_outputs[1].view(1, -1) + draft_new_tokens = output.fused_outputs[0].view(1, -1) + target_tokens = output.fused_outputs[1].view(1, -1) + if self.config.neuron_config.enable_eagle_speculation: + candidate_new_tokens = draft_new_tokens[:, 1:] + else: + candidate_new_tokens = draft_new_tokens[:,:-1] + selected_tokens = target_tokens[:,:-1] + n_matches = ((~(candidate_new_tokens == selected_tokens)).cumsum(dim=-1) < 1).sum() + accepted_tokens = target_tokens[:,:n_matches+1] + return accepted_tokens + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[List[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] + # Organize input tensors by step instead of by sequence. 
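+        # `logits` carries the accepted token ids with shape
+        # [batch_size, num_steps]; transposing yields one row per decoding
+        # step, and a step whose entries are all -1 (no accepted token)
+        # terminates the loop below.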
+ accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == -1 for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][sequence_index] + step_output_token_ids.append(CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=seq_ids[sequence_index], output_token=token_id, logprobs={token_id: Logprob(token_id)})], prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + def load_weights(self, model_name_or_path: str, draft_model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + neuron_config = neuronx_model_cls.get_neuron_config_cls()(**kwargs['neuron_config']) + config = neuronx_model_cls.get_config_cls()( + neuron_config, load_config=load_pretrained_config(model_name_or_path) + ) + + draft_neuron_config = copy.deepcopy(config.neuron_config) + if not config.neuron_config.enable_eagle_speculation: + draft_neuron_config.speculation_length = 0 + draft_neuron_config.trace_tokengen_model = True + draft_neuron_config.enable_fused_speculation = False + if config.neuron_config.enable_eagle_speculation: + draft_neuron_config.is_eagle_draft = True + draft_neuron_config.sequence_parallel_enabled = False + draft_config = neuronx_model_cls.get_config_cls()( + draft_neuron_config, load_config=load_pretrained_config(draft_model_name_or_path) + ) + fused_spec_config = FusedSpecNeuronConfig(neuronx_model_cls._model_cls, draft_config=draft_config, draft_model_path=draft_model_name_or_path) + config.fused_spec_config = fused_spec_config + self.config.neuron_config = neuron_config + + if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: + compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") + elif os.path.exists(model_name_or_path): + compiled_model_path = os.path.join(model_name_or_path, + f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + shutil.rmtree(compiled_model_path, ignore_errors=True) + else: + compiled_model_path = os.path.join("local-models", model_name_or_path, + f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + shutil.rmtree(compiled_model_path, ignore_errors=True) + try: + self.model = neuronx_model_cls(compiled_model_path) + override_neuron_config = kwargs["override_neuron_config"] + for k, v in override_neuron_config.items(): + setattr(self.model.config.neuron_config, k, v) + self.model.load(compiled_model_path) + return + except (FileNotFoundError, ValueError) as e: + logger.warning(f"Exception: {e}") + logger.warning(f"Failed to load the model from {compiled_model_path}, Recompiling...") + if draft_model_name_or_path == model_name_or_path: + draft_checkpoint_download = False + if not os.path.exists(model_name_or_path): + hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) + saved_path = os.path.join("local-models", model_name_or_path) + hf_model.save_pretrained(saved_path) + model_name_or_path = saved_path + if not os.path.exists(draft_model_name_or_path): + if 
draft_checkpoint_download: + hf_model = AutoModelForCausalLM.from_pretrained(draft_model_name_or_path) + saved_path = os.path.join("local-models", draft_model_name_or_path) + hf_model.save_pretrained(saved_path) + draft_model_name_or_path = saved_path + else: + draft_model_name_or_path = model_name_or_path + config.fused_spec_config.draft_model_path = draft_model_name_or_path + self.model = neuronx_model_cls(model_name_or_path, config) + self.model.compile(compiled_model_path) + self.model.load(compiled_model_path) + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _NEURON_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on Neuron " + f"for now. Supported architectures: " + f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + + logger.info(f"Initializing OnDeviceSampling config with global_topk=64") + on_device_sampling_config = OnDeviceSamplingConfig(global_topk=64, + dynamic=True, + deterministic=False) + batch_size = scheduler_config.max_num_seqs + + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + ctx_batch_size=1, + batch_size=batch_size, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + enable_bucketing=True, + is_continuous_batching=(batch_size>1), + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + padding_side="right", + on_device_sampling_config=on_device_sampling_config, + sequence_parallel_enabled=True, + ) + return neuron_config + + +def _get_default_neuron_speculation_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + neuron_config = dict( + tp_degree=parallel_config.tensor_parallel_size, + batch_size=scheduler_config.max_num_seqs, + max_context_length=scheduler_config.max_model_len, + seq_len=scheduler_config.max_model_len, + speculation_length=speculation_config.num_speculative_tokens, + trace_tokengen_model=False, + enable_fused_speculation=True, + enable_bucketing=True, + quantized=False, + torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + on_device_sampling_config= dict(top_k=1, do_sample=False,) + ) + return neuron_config + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return default_neuron_config + +def get_neuron_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + model_arch = _get_model_architecture(model_config.hf_config) + if model_arch == "MllamaForConditionalGeneration": + model = NeuronMllamaForCausalLM(model_config.hf_config) + else: + model = NeuronCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + neuron_config = _get_neuron_config_after_override(default_neuron_config_args, + model_config.override_neuron_config) + model.load_weights(model_config.model, + neuron_config=neuron_config, + override_neuron_config=model_config.override_neuron_config) + return model.eval() + +def get_neuron_speculation_model(model_config: ModelConfig, + 
parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + model = NeuronSpeculationCausalLM(model_config.hf_config) + default_neuron_config_args = _get_default_neuron_speculation_config( + model_config, parallel_config, scheduler_config, speculation_config) + neuron_config = _get_neuron_config_after_override(default_neuron_config_args, + model_config.override_neuron_config) + model.load_weights(model_config.model, + speculation_config.draft_model_config.model, + neuron_config=neuron_config, + override_neuron_config=model_config.override_neuron_config) + return model.eval() diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index f2093fc42ad..81d9a81294f 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -2,24 +2,23 @@ import os from dataclasses import dataclass -from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from torch import nn -from transformers_neuronx.config import GenerationConfig from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs) +from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.worker.neuron_worker import use_neuronx_distributed, use_transformers_neuronx if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -40,7 +39,13 @@ class ModelInputForNeuron(ModelRunnerInputBase): def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: - raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + return { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "input_block_ids": self.input_block_ids, + "sampling_metadata": self.sampling_metadata, + "multi_modal_kwargs": self.multi_modal_kwargs, + } @classmethod def from_broadcasted_tensor_dict( @@ -48,8 +53,13 @@ def from_broadcasted_tensor_dict( tensor_dict: Dict[str, Any], attn_backend: Optional["AttentionBackend"] = None, ) -> "ModelInputForNeuron": - assert attn_backend is None - return cls.from_broadcasted_tensor_dict(tensor_dict) + return ModelInputForNeuron( + input_tokens=tensor_dict["input_tokens"], + input_positions=tensor_dict["input_positions"], + input_block_ids=tensor_dict["input_block_ids"], + sampling_metadata=tensor_dict["sampling_metadata"], + multi_modal_kwargs=tensor_dict["multi_modal_kwargs"], + ) class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): @@ -62,16 +72,17 @@ def __init__( vllm_config: VllmConfig, ): ModelRunnerBase.__init__(self, vllm_config) - model_config = self.model_config - if model_config is not None and model_config.get_sliding_window(): + + if self.model_config is not None and self.model_config.get_sliding_window(): logger.warning("Sliding window is not supported on Neuron. 
" "The model will run without sliding window.") + self.device_config = (self.device_config + if self.device_config is not None else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ .create_input_mapper(self.model_config) # Lazy initialization. @@ -88,32 +99,36 @@ def __init__( self._previous_batch_request_ids: List[str] = [] if not self._on_device_sampling_disabled: - logger.warning( + self._init_neuron_sampling() + + def _init_neuron_sampling(self) -> None: + if use_transformers_neuronx(): + from transformers_neuronx.config import GenerationConfig + else: + from transformers import GenerationConfig + logger.warning( "On-device sampling is turned on in Neuron by default, only " "top_k, top_p, and temperature are current supported sampling " "parameters. To turn off the on-device sampling, please set " "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." ) - self.model_config.neuron_sampling_params = GenerationConfig( - max_length=self.scheduler_config.max_model_len, - do_sample=True, - per_batch_line=True, - top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p=[1.0] * self.scheduler_config.max_num_seqs, - temperature=[1.0] * self.scheduler_config.max_num_seqs, - dynamic=True, - global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) + self.model_config.neuron_sampling_params = GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) def load_model(self) -> None: - if find_spec("transformers_neuronx") is not None: - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) - else: - raise NotImplementedError( - "Supports only Transformer-NeuronX based models.") + self.model = get_neuron_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config + ) def get_model(self) -> nn.Module: return self.model @@ -129,7 +144,7 @@ def _prepare_prompt( input_block_ids: List[int] = [] seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] + multi_modal_inputs_list: List[MultiModalKwargs] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -150,16 +165,9 @@ def _prepare_prompt( input_block_ids.append(block_table[0]) mm_data = seq_group_metadata.multi_modal_data - if mm_data: - if self.mm_registry.has_processor(self.model_config): - mm_kwargs = mm_data - else: - mm_kwargs = self.multi_modal_input_mapper( - mm_data, - seq_group_metadata.mm_processor_kwargs, - ) - - multi_modal_kwargs_list.append(mm_kwargs) + # if mm_data: + # # Process multi-modal data + # multi_modal_inputs_list.append(mm_data) max_seq_len = max(seq_lens) assert max_seq_len > 0 @@ -177,7 +185,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_inputs_list) return (input_tokens, input_positions, 
input_block_ids, seq_lens, multi_modal_kwargs) @@ -254,6 +262,23 @@ def prepare_model_input( (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) seq_lens = None + + if not self._on_device_sampling_disabled: + for seq_group_metadata in seq_group_metadata_list: + sampling_params = seq_group_metadata.sampling_params + top_k, top_p, temperature = self._convert_to_neuron_sampling_params(sampling_params) + sampling_params.top_k = top_k + sampling_params.top_p = top_p + sampling_params.temperature = temperature + + # we need multi_modal_data for later tokens as well + # multi_modal_inputs_list: List[MultiModalInputs] = [] + # for seq_group_metadata in seq_group_metadata_list: + # mm_data = seq_group_metadata.multi_modal_data + # if mm_data: + # multi_modal_inputs_list.append(mm_data) + # multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, @@ -265,7 +290,7 @@ def prepare_model_input( self.pin_memory, generators=self.get_generators(finished_requests_ids)) - if not self._on_device_sampling_disabled: + if use_transformers_neuronx() and not self._on_device_sampling_disabled: # Once the request IDs are changed in current iteration, we will # update the on-device sampling parameters. current_batch_request_ids = [ @@ -273,7 +298,7 @@ def prepare_model_input( for seq_group_meta_data in seq_group_metadata_list ] if current_batch_request_ids != self._previous_batch_request_ids: - self._update_neuron_sampling_params(sampling_metadata) + self._update_neuron_sampling_params(seq_group_metadata_list) self._previous_batch_request_ids = current_batch_request_ids return ModelInputForNeuron(input_tokens=input_tokens, @@ -282,31 +307,61 @@ def prepare_model_input( sampling_metadata=sampling_metadata, multi_modal_kwargs=multi_modal_kwargs) - def _update_neuron_sampling_params(self, - sampling_metadata: SamplingMetadata): + def _update_neuron_sampling_params( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): # Update Neuron sampling parameters (GenerationConfig in Neuron) current_sampling_params = self.model_config.neuron_sampling_params assert current_sampling_params is not None, ( f"Failed to update sampling_params, " f"current sampling params is {current_sampling_params}") + is_update_needed = False + top_k = current_sampling_params.top_k top_p = current_sampling_params.top_p temperature = current_sampling_params.temperature - for index, sequence_group_to_sample in enumerate( - sampling_metadata.seq_groups): - top_k[index] = self._convert_to_neuron_top_k( - sequence_group_to_sample.sampling_params.top_k) - top_p[index] = sequence_group_to_sample.sampling_params.top_p - temperature[index] = \ - sequence_group_to_sample.sampling_params.temperature - self.model.model.update_generation_config(current_sampling_params) + # The index of a sequence's sampling parameters in neuron is equal to + # its index in `input_block_ids`. 
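+        # Only push an updated GenerationConfig to the device when at least
+        # one sequence's sampling parameters actually changed
+        # (tracked via is_update_needed below).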
+ for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params - def _convert_to_neuron_top_k(self, top_k: int) -> int: + seq_group_top_k = sampling_params.top_k + seq_group_top_p = sampling_params.top_p + seq_group_temperature = sampling_params.temperature + + for seq_id in seq_ids: + index = seq_group_metadata.block_tables[seq_id][0] + if ( + top_k[index] != seq_group_top_k + or top_p[index] != seq_group_top_p + or temperature[index] != seq_group_temperature + ): + is_update_needed = True + + top_k[index] = seq_group_top_k + top_p[index] = seq_group_top_p + temperature[index] = seq_group_temperature + + # update_generation_config is only available in transformers-neuronx + if is_update_needed and use_transformers_neuronx(): + self.model.model.update_generation_config(current_sampling_params) + + def _convert_to_neuron_sampling_params( + self, sampling_params: SamplingParams) -> Tuple[int, float, float]: + # Returns the top_k, top_p and temperature parameters for neuron. + top_k = sampling_params.top_k + top_p = sampling_params.top_p + temperature = sampling_params.temperature + + if temperature == 0.0: + # Enable greedy sampling on zero temperature + return (1, 1.0, 1.0) if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: - return self._MAX_NEURON_SAMPLING_TOP_K - return top_k + top_k = self._MAX_NEURON_SAMPLING_TOP_K + + return (top_k, top_p, temperature) @torch.inference_mode() def execute_model( @@ -320,14 +375,28 @@ def execute_model( raise ValueError( "NeuronModelRunner does not support multi-step execution.") - with set_forward_context(None, self.vllm_config, 0): + # extract top_k, top_p and temperature from model_input for neuron forward call + sampling_params = torch.tensor([[seq_group.sampling_params.top_k, + seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature] for seq_group in model_input.sampling_metadata.seq_groups]) + if use_neuronx_distributed(): + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + sampling_params=sampling_params, + # **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + # device=self.device), + ) + elif use_transformers_neuronx(): + # [TODO] validate on-device sampling + # The model signature may need change for on-device sampling hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs - or {}, - device=self.device), + # **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + # device=self.device), ) # Compute the logits only if the on-device sampling is turned off as diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 5f0eb0019ee..c23eb627c2c 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,20 +1,77 @@ # SPDX-License-Identifier: Apache-2.0 """A Neuron worker class.""" +import enum +import os +from functools import lru_cache from typing import List, Optional, Tuple import torch import torch.distributed +from vllm.logger import init_logger from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence 
import ExecuteModelRequest -from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) +# FIXME(Neuron): restore the framework selection logic. +# from vllm.utils import is_transformers_neuronx, is_neuronx_distributed_inference + +logger = init_logger(__name__) + +class NeuronFramework(enum.Enum): + TRANSFORMERS_NEURONX = "transformers-neuronx" + NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" + + +@lru_cache(maxsize=None) +def get_neuron_framework_to_use(): + """ + Return the specified framework if the corresponding installations are available. + If no framework is specified, then use transformers-neuronx by default, if unavailable + then check and switch to neuronx-distributed-inference. + """ + # FIXME(Neuron): restore the framework selection logic. + return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + + # transformers_neuronx_installed = is_transformers_neuronx() + # neuronx_distributed_inference_installed = is_neuronx_distributed_inference() + # specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") + # if specified_framework == NeuronFramework.TRANSFORMERS_NEURONX.value and transformers_neuronx_installed: + # return NeuronFramework.TRANSFORMERS_NEURONX + # elif specified_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value and neuronx_distributed_inference_installed: + # return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + # elif specified_framework is None and transformers_neuronx_installed: + # return NeuronFramework.TRANSFORMERS_NEURONX + # elif specified_framework is None and neuronx_distributed_inference_installed: + # return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + # else: + # return None + + +@lru_cache(maxsize=None) +def use_neuronx_distributed(): + """ + Return True if the framework determined in get_neuron_framework_to_use() is + NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This is used + to select the Neuron model framework and framework-specific configuration to apply + during model compilation. + """ + return get_neuron_framework_to_use() == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + + +@lru_cache(maxsize=None) +def use_transformers_neuronx(): + """ + Return True if the framework determined in get_neuron_framework_to_use() is + NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used to select + the Neuron model framework and framework-specific configuration to apply during + model compilation. 
+ """ + return get_neuron_framework_to_use() == NeuronFramework.TRANSFORMERS_NEURONX class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): @@ -27,35 +84,45 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, - is_driver_worker: bool = True, + enable_neuron_multi_node: bool = False, + world_size: int = 1, + is_driver_worker: bool = False, ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - - self.model_runner: NeuronModelRunner = NeuronModelRunner( - vllm_config=vllm_config) - self.is_driver_worker = is_driver_worker - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[List[SamplerOutput]]: - assert execute_model_req is not None - assert (not execute_model_req.blocks_to_swap_in - and not execute_model_req.blocks_to_swap_out - and not execute_model_req.blocks_to_copy), ( - "Cache operations are not supported for Neuron backend.") - assert execute_model_req.num_lookahead_slots == 0, ( - "lookahead not supported for Neuron backend.") - output = LocalOrDistributedWorkerBase.execute_model( - self, execute_model_req) - return output + neuron_framework = get_neuron_framework_to_use() + + if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: + from vllm.worker.neuron_model_runner import NeuronModelRunner + # from vllm.worker.multi_step_neuron_model_runner import MultiStepNeuronModelRunner + if self.speculative_config is not None: + self.model_runner = MultiStepNeuronModelRunner( + model_config, parallel_config, scheduler_config, + device_config, speculative_config) + else: + self.model_runner: NeuronModelRunner = NeuronModelRunner( + vllm_config=vllm_config) + elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: + from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner + # from vllm.worker.multi_step_neuronx_distributed_model_runner import MultiStepNeuronModelRunner + if self.speculative_config is not None: + self.model_runner = MultiStepNeuronModelRunner( + model_config, parallel_config, scheduler_config, + device_config, speculative_config) + else: + self.model_runner: NeuronxDistributedModelRunner = NeuronxDistributedModelRunner( + vllm_config=vllm_config) + else: + raise NotImplementedError( + f"Specified framework as {os.environ.get('VLLM_NEURON_FRAMEWORK')}," + + " Only transformers-neuronx/neuronx-distributed-inference framework is supported") def init_device(self) -> None: self.init_distributed_environment() @@ -121,14 +188,13 @@ def get_cache_block_size_bytes(self) -> int: def init_distributed_environment(self): """Neuron uses transformers-neuronx for tensor parallelism. - It has only one process to control multiple devices. - vLLM still needs the environment initialized when TP/PP > 1, - so we initialize a distributed environment with one process. 
+ + vLLM still needs the environment inited when TP/PP > 1 """ init_distributed_environment( world_size=1, - rank=0, - local_rank=0, + rank=self.rank, + local_rank=self.local_rank, distributed_init_method=self.distributed_init_method, backend="gloo", ) diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py new file mode 100644 index 00000000000..22e61ac0634 --- /dev/null +++ b/vllm/worker/neuronx_distributed_model_runner.py @@ -0,0 +1,137 @@ +import os +import torch +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_model, _get_model_architecture +from vllm.worker.neuron_model_runner import NeuronModelRunner, ModelInputForNeuron +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.model_executor.layers.sampler import SamplerOutput + +from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params +from neuronx_distributed_inference.models.mllama.image_transform import custom_image_preprocessing + +# FIXME(Neuron): need to restor multi-model support +# from vllm.multimodal.neuron_multimodal_image_utils import decompress_image_from_tensor +logger = init_logger(__name__) + + +class NeuronxDistributedModelRunner(NeuronModelRunner): + + def __init__( + self, + vllm_config: VllmConfig, + ): + super().__init__(vllm_config) + + def load_model(self) -> None: + self.model = get_neuron_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + def get_nxd_sampling_params(self, sampling_metadata): + if self.model.config.neuron_config.on_device_sampling_config: + max_topk = self.model.config.neuron_config.on_device_sampling_config.global_topk + else: + max_topk = self.model.config.vocab_size + + top_k = [1] * self.scheduler_config.max_num_seqs + top_p = [1.0] * self.scheduler_config.max_num_seqs + temperature = [1.0] * self.scheduler_config.max_num_seqs + + for index, sequenceGroupToSample in enumerate(sampling_metadata.seq_groups): + top_k[index] = sequenceGroupToSample.sampling_params.top_k if sequenceGroupToSample.sampling_params.top_k > 0 else max_topk + top_p[index] = sequenceGroupToSample.sampling_params.top_p + temperature[index] = sequenceGroupToSample.sampling_params.temperature + + sampling_params = prepare_sampling_params(batch_size=self.scheduler_config.max_num_seqs, + top_k=top_k, top_p=top_p, temperature=temperature) + return sampling_params + + def get_multi_modal_data_neuron(self, input_images): + total_image_size = 0 + image_tensors = [] + empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], dtype=torch.bfloat16) + empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) + num_chunks = torch.tensor([[1]]) # dummy num_chunks, will not be used + has_image = torch.tensor([0]) + if (input_images is None) or (len(input_images) == 0) or (input_images.numel() == 0): + image_tensors = [empty_pixel_values, empty_aspect_ratios, num_chunks, has_image] + else: + image = decompress_image_from_tensor(input_images) + total_image_size += image.width * image.height + pixel_values, aspect_ratios, num_chunks = custom_image_preprocessing(self.model.config, [[image]]) + has_image = torch.tensor([1]) + + image_tensors = [pixel_values.bfloat16().clone().detach(), aspect_ratios, num_chunks, has_image] + + return image_tensors + + @torch.inference_mode() + def 
execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "NeuronModelRunner does not support multi-step execution.") + + if not _get_model_architecture(self.model.config) == "MllamaForConditionalGeneration": + return super().execute_model(model_input, kv_caches, intermediate_tensors, num_steps) + + sampling_params = self.get_nxd_sampling_params(model_input.sampling_metadata) + + if model_input.multi_modal_kwargs.get('image') is not None: + pixel_values = [] + aspect_ratios = [] + num_chunks = [] + has_image = [] + for multi_modal_input in model_input.multi_modal_kwargs.get('image'): + image_tensors = self.get_multi_modal_data_neuron(multi_modal_input.squeeze(0)) + pixel_values.append(image_tensors[0]) + aspect_ratios.append(image_tensors[1]) + num_chunks.append(image_tensors[2]) + has_image.append(image_tensors[3]) + + pixel_values=torch.cat(pixel_values,dim=0) + aspect_ratios=torch.cat(aspect_ratios,dim=0) + num_chunks=torch.cat(num_chunks,dim=0) + has_image=torch.cat(has_image,dim=0) + + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + seq_ids=model_input.input_block_ids, + pixel_values=pixel_values, + aspect_ratios=aspect_ratios, + sampling_params=sampling_params, + num_chunks=num_chunks, + has_image=has_image, + ) + else: + empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], dtype=torch.bfloat16) + empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) + num_chunks = torch.tensor([[1]]) # dummy num_chunks, will not be used + has_image = torch.tensor([0]) + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + seq_ids=model_input.input_block_ids, + pixel_values=empty_pixel_values, + aspect_ratios=empty_aspect_ratios, + sampling_params=sampling_params, + num_chunks=num_chunks, + has_image=has_image, + ) + + output = self.model.sample( + hidden_states=hidden_states, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output] \ No newline at end of file From 626850c74309a70fd52f3c672ffefcea7874ff74 Mon Sep 17 00:00:00 2001 From: Shashwat Srijan Date: Tue, 15 Oct 2024 07:39:08 +0000 Subject: [PATCH 02/38] Support speculation with transformers-neuronx Signed-off-by: Satyajith Chilappagari --- examples/offline_speculation_neuron.py | 32 +++++ vllm/executor/neuron_executor.py | 112 +++++++++++++++ vllm/model_executor/model_loader/neuron.py | 129 +++++++++++++++++- vllm/worker/multi_step_neuron_model_runner.py | 68 +++++++++ vllm/worker/neuron_worker.py | 25 +++- 5 files changed, 356 insertions(+), 10 deletions(-) create mode 100644 examples/offline_speculation_neuron.py create mode 100644 vllm/executor/neuron_executor.py create mode 100644 vllm/worker/multi_step_neuron_model_runner.py diff --git a/examples/offline_speculation_neuron.py b/examples/offline_speculation_neuron.py new file mode 100644 index 00000000000..5b5ed517ad9 --- /dev/null +++ b/examples/offline_speculation_neuron.py @@ -0,0 +1,32 @@ +import os + +from vllm import LLM, SamplingParams + +# creates XLA hlo graphs for all the context length buckets. +os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" +# creates XLA hlo graphs for all the token gen buckets. 
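+# Each listed bucket size gets its own precompiled graph, so requests of
+# different lengths can reuse compiled artifacts instead of forcing a
+# recompile at runtime.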
+os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + +prompts = [ + "Hello, I am a language model and I can help", + "The president of the United States is", + "The capital of France is", +] +sampling_params = SamplingParams(max_tokens=100, top_k=1) +llm = LLM( + model="openlm-research/open_llama_7b", + speculative_model='openlm-research/open_llama_3b', + num_speculative_tokens=4, + max_num_seqs=4, + max_model_len=2048, + block_size=2048, + speculative_max_model_len=2048, + use_v2_block_manager=True, + device="neuron", + tensor_parallel_size=32, +) +outputs = llm.generate(prompts, sampling_params) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py new file mode 100644 index 00000000000..ee345139c95 --- /dev/null +++ b/vllm/executor/neuron_executor.py @@ -0,0 +1,112 @@ +from typing import List, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) + +logger = init_logger(__name__) + + +class NeuronExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for Neuron backend." + + # Instantiate the worker and load the model to the device. + self._init_worker() + + def _init_worker(self): + from vllm.worker.neuron_worker import NeuronWorker + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = NeuronWorker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + speculative_config=self.speculative_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method) + self.driver_worker.load_model() + self.driver_worker.init_device() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. 
+ """ + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + assert (not execute_model_req.blocks_to_swap_in + and not execute_model_req.blocks_to_swap_out + and not execute_model_req.blocks_to_copy), ( + "Cache operations are not supported for Neuron backend.") + + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def check_health(self) -> None: + # NeuronExecutor will always be healthy as long as + # it's running. + return + + +class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) + return output + + async def check_health_async(self) -> None: + # NeuronExecutor will always be healthy as long as + # it's running. 
+ return diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index d900fb3a7d3..bd0ec6954ec 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -3,13 +3,13 @@ import copy import importlib import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, List import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -113,6 +113,48 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() +class NeuronSpeculationCausalLM(nn.Module): + + def __init__( + self, + speculation_model + ) -> None: + super().__init__() + self.model = speculation_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + ) -> torch.Tensor: + tokens, counts = self.model.speculative_iteration(input_ids, positions, input_block_ids) + return tokens + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[List[SamplerOutput]]: + batch_size, num_steps = logits.shape + seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] + # Organize input tensors by step instead of by sequence. + accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step[accepted_token_ids_by_step==self.model.pad_token_id]=-1 + + sampler_output_list = [] + for step_index in range(num_steps): + if all(token_id == -1 for token_id in accepted_token_ids_by_step[step_index]): + break + step_output_token_ids = [] + for sequence_index in range(batch_size): + token_id = accepted_token_ids_by_step[step_index][sequence_index] + step_output_token_ids.append(CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=seq_ids[sequence_index], output_token=token_id, logprobs={token_id: Logprob(token_id)})], prompt_logprobs=None)) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + return sampler_output_list + + def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -162,6 +204,28 @@ def _get_default_neuron_config(model_config: ModelConfig, return default_neuron_args +def _get_default_neuron_config_for_speculation( + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig +): + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + + default_neuron_args = dict( + collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + on_device_embedding=True, + continuous_batching=continuous_batching_config, + on_device_generation=copy.deepcopy(model_config.neuron_sampling_params) + ) + return default_neuron_args + + def _get_neuron_on_device_generation_config(model_config: ModelConfig): if not _is_neuron_on_device_sampling_disabled(model_config): return 
copy.deepcopy(model_config.neuron_sampling_params) @@ -200,7 +264,6 @@ def get_neuron_model(model_config: ModelConfig, n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", [scheduler_config.max_model_len]) - # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, tp_degree=parallel_config.tensor_parallel_size, amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], @@ -210,3 +273,63 @@ def get_neuron_model(model_config: ModelConfig, batch_size=scheduler_config.max_num_seqs) return model.eval() + + +def get_neuron_speculation_model( + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig +) -> None: + from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder + + # Create target model instance. + target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. + draft_model = NeuronCausalLM(speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, scheduler_config) + + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights( + speculation_config.draft_model_config.model, + tp_degree=speculation_config.draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + # Create speculation model instance. 
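+    # FusedSpeculativeDecoder pairs the compiled draft and target models so
+    # draft proposal and target verification run together on device ("fused"
+    # speculation); the last argument is the number of draft tokens proposed
+    # per iteration.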
+ speculation_model = FusedSpeculativeDecoder(draft_model.model, target_model.model, speculation_config.num_speculative_tokens) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py new file mode 100644 index 00000000000..4e8ea17f2be --- /dev/null +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -0,0 +1,68 @@ +from importlib.util import find_spec +import torch +from typing import List, Optional +from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalInputs +from vllm.sequence import IntermediateTensors +from vllm.worker.neuron_model_runner import NeuronModelRunner, ModelInputForNeuron + +class MultiStepNeuronModelRunner(NeuronModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + speculation_config: SpeculativeConfig, + ): + super().__init__(model_config, parallel_config, scheduler_config, device_config) + self.speculation_config = speculation_config + from transformers_neuronx.config import GenerationConfig + self.speculation_config.draft_model_config.neuron_sampling_params = GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k = [self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p = [1.0] * self.scheduler_config.max_num_seqs, + temperature = [1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K + ) + + def load_model(self) -> None: + if find_spec("transformers_neuronx") is not None: + from vllm.model_executor.model_loader.neuron import get_neuron_speculation_model + self.model = get_neuron_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config) + else: + raise NotImplementedError( + "Supports only Transformer-NeuronX based models.") + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return output \ No newline at end of file diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index c23eb627c2c..b70e2927b60 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -9,7 +9,8 @@ import torch.distributed from vllm.logger import init_logger -from vllm.config import VllmConfig +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.model_executor import set_random_seed @@ -80,7 +81,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def 
__init__( self, - vllm_config: VllmConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + speculative_config: SpeculativeConfig, local_rank: int, rank: int, distributed_init_method: str, @@ -88,7 +94,12 @@ def __init__( world_size: int = 1, is_driver_worker: bool = False, ) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.speculative_config = speculative_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -101,24 +112,24 @@ def __init__( if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: from vllm.worker.neuron_model_runner import NeuronModelRunner - # from vllm.worker.multi_step_neuron_model_runner import MultiStepNeuronModelRunner + from vllm.worker.multi_step_neuron_model_runner import MultiStepNeuronModelRunner if self.speculative_config is not None: self.model_runner = MultiStepNeuronModelRunner( model_config, parallel_config, scheduler_config, device_config, speculative_config) else: self.model_runner: NeuronModelRunner = NeuronModelRunner( - vllm_config=vllm_config) + model_config, parallel_config, scheduler_config, device_config) elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner - # from vllm.worker.multi_step_neuronx_distributed_model_runner import MultiStepNeuronModelRunner + from vllm.worker.multi_step_neuronx_distributed_model_runner import MultiStepNeuronModelRunner if self.speculative_config is not None: self.model_runner = MultiStepNeuronModelRunner( model_config, parallel_config, scheduler_config, device_config, speculative_config) else: self.model_runner: NeuronxDistributedModelRunner = NeuronxDistributedModelRunner( - vllm_config=vllm_config) + model_config, parallel_config, scheduler_config, device_config) else: raise NotImplementedError( f"Specified framework as {os.environ.get('VLLM_NEURON_FRAMEWORK')}," + From 7fcf9b24e264a14f2da259babf9177482b01ca99 Mon Sep 17 00:00:00 2001 From: Shashwat Srijan Date: Tue, 15 Oct 2024 08:44:42 +0000 Subject: [PATCH 03/38] Add support for eagle speculation using transformers-neuronx Signed-off-by: Satyajith Chilappagari --- vllm/config.py | 8 +++ vllm/engine/arg_utils.py | 8 +++ vllm/model_executor/model_loader/neuron.py | 61 +++++++++++++++++++ vllm/worker/multi_step_neuron_model_runner.py | 20 ++++-- 4 files changed, 91 insertions(+), 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 9ba49757612..87a4e946c71 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1673,6 +1673,7 @@ def maybe_create_spec_config( speculative_draft_tensor_parallel_size: Optional[int], num_speculative_tokens: Optional[int], speculative_disable_mqa_scorer: Optional[bool], + speculative_token_tree: Optional[str], speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, disable_log_stats: bool, @@ -1709,6 +1710,8 @@ def maybe_create_spec_config( speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA scorer for the speculative model and fall back to batch expansion for scoring. + speculative_token_tree (Optional[str]): The token tree structure + used with speculation. 
speculative_max_model_len (Optional[int]): The maximum model len of the speculative model. Used when testing the ability to skip speculation for some sequences. @@ -1862,6 +1865,7 @@ def maybe_create_spec_config( draft_parallel_config, num_speculative_tokens, speculative_disable_mqa_scorer, + speculative_token_tree, speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, @@ -1971,6 +1975,7 @@ def __init__( draft_parallel_config: ParallelConfig, num_speculative_tokens: int, speculative_disable_mqa_scorer: Optional[bool], + speculative_token_tree: Optional[str], speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1987,6 +1992,8 @@ def __init__( draft_parallel_config: ParallelConfig for the draft model. num_speculative_tokens: The number of tokens to sample from the draft model before scoring with the target model. + speculative_token_tree: The token tree structure used during + speculation. speculative_disable_by_batch_size: Disable speculative decoding for new incoming requests when the number of enqueue requests is larger than this value. @@ -2018,6 +2025,7 @@ def __init__( self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer + self.speculative_token_tree = speculative_token_tree self.speculative_disable_by_batch_size = \ speculative_disable_by_batch_size self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 40c6fb45679..00ae1dbc6cb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -174,6 +174,7 @@ class EngineArgs: speculative_draft_tensor_parallel_size: Optional[int] = None num_speculative_tokens: Optional[int] = None speculative_disable_mqa_scorer: Optional[bool] = False + speculative_token_tree: Optional[str] = None speculative_max_model_len: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None @@ -746,6 +747,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= 'If set to True, the MQA scorer will be disabled in speculative ' ' and fall back to batch expansion') + parser.add_argument( + '--speculative-token-tree', + type=nullable_str, + default=EngineArgs.speculative_token_tree, + help='The token tree definition used with speculation.' + ) parser.add_argument( '--speculative-draft-tensor-parallel-size', '-spec-draft-tp', @@ -1164,6 +1171,7 @@ def create_engine_config(self, self.speculative_draft_tensor_parallel_size, num_speculative_tokens=self.num_speculative_tokens, speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, + speculative_token_tree=self.speculative_token_tree, speculative_disable_by_batch_size=self. 
speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index bd0ec6954ec..2b9a497d7f7 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Utilities for selecting and loading neuron models.""" +import ast import copy import importlib import os @@ -333,3 +334,63 @@ def get_neuron_speculation_model( speculation_model.to_neuron() return NeuronSpeculationCausalLM(speculation_model) + + +def get_neuron_eagle_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig) -> None: + from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder + + # Create target model instance. + target_model = NeuronCausalLM(model_config.hf_config) + + default_neuron_config_args = _get_default_neuron_config_for_speculation( + model_config, parallel_config, scheduler_config) + default_neuron_config_args['is_eagle_target'] = True + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + target_model.load_weights( + model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + target_model.eval() + + # Create draft model instance. 
+ draft_model = NeuronCausalLM(speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, scheduler_config) + default_draft_neuron_config_args['is_eagle_draft'] = True + default_draft_neuron_config_args['has_pre_attention_norm'] = False + draft_neuron_config = _get_neuron_config_after_override( + default_draft_neuron_config_args, speculation_config.draft_model_config.override_neuron_config) + + draft_model.load_weights( + speculation_config.draft_model_config.model, + tp_degree=speculation_config.draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + draft_model.eval() + + token_tree: Dict[int, List[int]] = ast.literal_eval(speculation_config.speculative_token_tree) + + speculation_model = EagleSpeculativeDecoder(draft_model.model, target_model.model, token_tree=token_tree) + speculation_model.to_neuron() + + return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py index 4e8ea17f2be..764a2fc5b1a 100644 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -35,12 +35,20 @@ def __init__( def load_model(self) -> None: if find_spec("transformers_neuronx") is not None: - from vllm.model_executor.model_loader.neuron import get_neuron_speculation_model - self.model = get_neuron_speculation_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - speculation_config=self.speculation_config) + from vllm.model_executor.model_loader.neuron import get_neuron_speculation_model, get_neuron_eagle_speculation_model + if self.speculation_config.speculative_token_tree is not None: + self.model = get_neuron_eagle_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config + ) + else: + self.model = get_neuron_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config) else: raise NotImplementedError( "Supports only Transformer-NeuronX based models.") From 3a8e7a54b5b114764c7e0d61c3a8f4b31a73253b Mon Sep 17 00:00:00 2001 From: Shashwat Srijan Date: Tue, 29 Oct 2024 10:18:44 +0000 Subject: [PATCH 04/38] Support speculation with neuronx-distributed-inference for batch 1 Signed-off-by: Satyajith Chilappagari --- .../model_loader/neuronx_distributed.py | 3 +- ...i_step_neuronx_distributed_model_runner.py | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 vllm/worker/multi_step_neuronx_distributed_model_runner.py diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index 6c8b3aa0610..1c06ced8f8b 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -402,6 +402,7 @@ def load_weights(self, model_name_or_path: str, draft_model_name_or_path: str, * self.model.compile(compiled_model_path) self.model.load(compiled_model_path) + def 
_get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -438,7 +439,6 @@ def _get_default_neuron_config(model_config: ModelConfig, ) return neuron_config - def _get_default_neuron_speculation_config(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -458,7 +458,6 @@ def _get_default_neuron_speculation_config(model_config: ModelConfig, ) return neuron_config - def _get_neuron_config_after_override(default_neuron_config, overridden_neuron_config): overridden_neuron_config = overridden_neuron_config or {} diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py new file mode 100644 index 00000000000..90f2a057a12 --- /dev/null +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -0,0 +1,54 @@ +from importlib.util import find_spec +import torch +from typing import List, Optional +from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalInputs +from vllm.sequence import IntermediateTensors +from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner + +class MultiStepNeuronModelRunner(NeuronxDistributedModelRunner): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + speculation_config: SpeculativeConfig, + ): + super().__init__(model_config, parallel_config, scheduler_config, device_config) + self.speculation_config = speculation_config + + def load_model(self) -> None: + from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_speculation_model + assert self.scheduler_config.max_num_seqs == 1, "Only batch size 1 is currently supported for speculation using neuronx-distributed-inference." + self.model = get_neuron_speculation_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + speculation_config=self.speculation_config) + + @torch.inference_mode() + def execute_model( + self, + model_input, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + + output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return output + From bf9a4c6984118f6568c9e7b9d1c8df682eba33cb Mon Sep 17 00:00:00 2001 From: Shashwat Srijan Date: Tue, 29 Oct 2024 20:37:13 +0000 Subject: [PATCH 05/38] [Tnx] Fix streaming flow for speculation Signed-off-by: Satyajith Chilappagari --- vllm/model_executor/model_loader/neuron.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 2b9a497d7f7..9bfa28a5ad0 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -142,6 +142,7 @@ def sample( # Organize input tensors by step instead of by sequence. 
         accepted_token_ids_by_step = logits.transpose(0, 1)
         accepted_token_ids_by_step[accepted_token_ids_by_step==self.model.pad_token_id]=-1
+        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
 
         sampler_output_list = []
         for step_index in range(num_steps):

From a2671aa876701c251760a6db4342bf0d84df2f79 Mon Sep 17 00:00:00 2001
From: Shashwat Srijan
Date: Thu, 31 Oct 2024 01:20:44 +0000
Subject: [PATCH 06/38] [NxdI] Support eagle speculation

Signed-off-by: Satyajith Chilappagari
---
 vllm/model_executor/model_loader/neuronx_distributed.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py
index 1c06ced8f8b..30080021f0f 100644
--- a/vllm/model_executor/model_loader/neuronx_distributed.py
+++ b/vllm/model_executor/model_loader/neuronx_distributed.py
@@ -458,6 +458,7 @@ def _get_default_neuron_speculation_config(model_config: ModelConfig,
     )
     return neuron_config
 
+
 def _get_neuron_config_after_override(default_neuron_config,
                                       overridden_neuron_config):
     overridden_neuron_config = overridden_neuron_config or {}

From 0b075b146b545cfd2cd6e43f6e4292d2d41f2ad0 Mon Sep 17 00:00:00 2001
From: Chongming Ni
Date: Thu, 7 Nov 2024 00:39:21 +0000
Subject: [PATCH 07/38] [TNx+EAGLE] Use FusedSpeculativeDecoder for EAGLE + Linear token tree.

Signed-off-by: Satyajith Chilappagari
---
 vllm/model_executor/model_loader/neuron.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index 9bfa28a5ad0..b57b8198443 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -285,11 +285,16 @@ def get_neuron_speculation_model(
 ) -> None:
     from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder
 
+    # For Eagle SD, we need to pass in additional parameters in neuron config.
+    is_eagle = getattr(speculation_config.draft_model_config.hf_config, "is_eagle", False)
+
     # Create target model instance.
     target_model = NeuronCausalLM(model_config.hf_config)
 
     default_neuron_config_args = _get_default_neuron_config_for_speculation(
         model_config, parallel_config, scheduler_config)
+    if is_eagle:
+        default_neuron_config_args['is_eagle_target'] = True
     neuron_config = _get_neuron_config_after_override(
         default_neuron_config_args, model_config.override_neuron_config)
 
@@ -315,6 +320,9 @@ def get_neuron_speculation_model(
 
     default_draft_neuron_config_args = _get_default_neuron_config_for_speculation(
         speculation_config.draft_model_config, parallel_config, scheduler_config)
+    if is_eagle:
+        default_draft_neuron_config_args['is_eagle_draft'] = True
+        default_draft_neuron_config_args['has_pre_attention_norm'] = False
     draft_neuron_config = _get_neuron_config_after_override(
         default_draft_neuron_config_args, speculation_config.draft_model_config.override_neuron_config)

From 703efd93eaf18ba164a60af1529fc14c463882ca Mon Sep 17 00:00:00 2001
From: Chongming Ni
Date: Thu, 21 Nov 2024 07:19:55 +0000
Subject: [PATCH 08/38] Fix the termination check for accepted speculative tokens.
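The end-of-speculation check moves out of sample() and into forward(): rather
than rewriting positions that happen to match the draft model's pad_token_id,
forward() now overwrites every position past the accepted-token count returned
by speculative_iteration() with a SPECULATION_TERMINATION_ID sentinel (-1), and
sample() stops emitting steps once all sequences in a step carry the sentinel.
A rough sketch of the masking idea (tensor shapes here are illustrative
assumptions, not copied from the diff below):

    # tokens: [batch, steps] accepted token ids, counts: [batch] accepted counts
    mask = torch.arange(tokens.size(1)).expand(tokens.size(0), -1) >= counts.unsqueeze(-1)
    tokens[mask] = SPECULATION_TERMINATION_ID  # i.e. -1: positions past the count are terminated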
Signed-off-by: Satyajith Chilappagari
---
 vllm/model_executor/model_loader/neuron.py | 16 +++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index b57b8198443..f3da49edcf9 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -116,6 +116,8 @@ def load_weights(self, model_name_or_path: str, **kwargs):
 
 class NeuronSpeculationCausalLM(nn.Module):
 
+    SPECULATION_TERMINATION_ID = -1
+
     def __init__(
         self,
         speculation_model
@@ -129,7 +131,15 @@ def forward(
         positions: torch.Tensor,
         input_block_ids: torch.Tensor,
     ) -> torch.Tensor:
-        tokens, counts = self.model.speculative_iteration(input_ids, positions, input_block_ids)
+        tokens, counts = self.model.speculative_iteration(
+            input_ids, positions, input_block_ids)
+
+        # Mark the end of accepted speculative tokens for each sequence with the
+        # speculation termination id.
+        mask = torch.arange(tokens.size(1)).repeat(tokens.size(0),
+                                                   1) >= counts.unsqueeze(-1)
+        tokens[mask] = self.SPECULATION_TERMINATION_ID
+
         return tokens
 
     def sample(
@@ -141,12 +151,12 @@ def sample(
         seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids]
         # Organize input tensors by step instead of by sequence.
         accepted_token_ids_by_step = logits.transpose(0, 1)
-        accepted_token_ids_by_step[accepted_token_ids_by_step==self.model.pad_token_id]=-1
         accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
 
         sampler_output_list = []
         for step_index in range(num_steps):
-            if all(token_id == -1 for token_id in accepted_token_ids_by_step[step_index]):
+            if all(token_id == self.SPECULATION_TERMINATION_ID
+                   for token_id in accepted_token_ids_by_step[step_index]):
                 break
             step_output_token_ids = []
             for sequence_index in range(batch_size):

From 6c04502ed96bdea1858b41c9dc1bb8b7c7c4b715 Mon Sep 17 00:00:00 2001
From: Chongming Ni
Date: Fri, 22 Nov 2024 00:55:27 +0000
Subject: [PATCH 09/38] [TNx][Bug fix] fix the incorrect speculation output check.

Signed-off-by: Satyajith Chilappagari
---
 vllm/model_executor/model_loader/neuron.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index f3da49edcf9..47d23f5ae27 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -136,8 +136,8 @@ def forward(
 
         # Mark the end of accepted speculative tokens for each sequence with the
         # speculation termination id.
- mask = torch.arange(tokens.size(1)).repeat(tokens.size(0), - 1) >= counts.unsqueeze(-1) + batch_size, steps = tokens.shape + mask = torch.arange(steps).expand(batch_size, -1) >= counts tokens[mask] = self.SPECULATION_TERMINATION_ID return tokens From 31203d7507885bf2eada5c19597f3ca3b15efc55 Mon Sep 17 00:00:00 2001 From: Amulya Ballakur Date: Thu, 12 Dec 2024 01:38:09 +0000 Subject: [PATCH 10/38] Add continuous batching with eagle Signed-off-by: Satyajith Chilappagari --- .../model_loader/neuronx_distributed.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index 30080021f0f..c906c6af227 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -299,19 +299,13 @@ def forward( position_ids=positions, seq_ids=input_block_ids, sampling_params=sampling_params) - if output.fused_outputs[1].shape[-1] == 1: - # CTX encoding - return output.fused_outputs[1].view(1, -1) - draft_new_tokens = output.fused_outputs[0].view(1, -1) - target_tokens = output.fused_outputs[1].view(1, -1) - if self.config.neuron_config.enable_eagle_speculation: - candidate_new_tokens = draft_new_tokens[:, 1:] - else: - candidate_new_tokens = draft_new_tokens[:,:-1] - selected_tokens = target_tokens[:,:-1] - n_matches = ((~(candidate_new_tokens == selected_tokens)).cumsum(dim=-1) < 1).sum() - accepted_tokens = target_tokens[:,:n_matches+1] - return accepted_tokens + # CTX encoding + if (positions[:, 0]).sum().item() == 0: + return output.fused_outputs[0][:, 0:1] + + # Fused Spec (Generation) + accepted_tokens_with_padding = output.fused_outputs[0] + return accepted_tokens_with_padding def sample( self, @@ -322,6 +316,7 @@ def sample( seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] # Organize input tensors by step instead of by sequence. accepted_token_ids_by_step = logits.transpose(0, 1) + accepted_token_ids_by_step[accepted_token_ids_by_step == 0] = -1 accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() sampler_output_list = [] From ae29889086402c19e47ed9e40a009c54934d381d Mon Sep 17 00:00:00 2001 From: Patrick Lange Date: Wed, 18 Dec 2024 21:15:44 -0800 Subject: [PATCH 11/38] [NxDI] Fix masking of padding in speculative output Signed-off-by: Satyajith Chilappagari --- vllm/model_executor/model_loader/neuronx_distributed.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index c906c6af227..44fd14ba218 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -305,6 +305,13 @@ def forward( # Fused Spec (Generation) accepted_tokens_with_padding = output.fused_outputs[0] + next_pos_ids = output.fused_outputs[-1] + generated_token_counts = next_pos_ids - positions + + batch_size, steps = accepted_tokens_with_padding.shape + mask = torch.arange(steps).expand(batch_size, -1) >= generated_token_counts + accepted_tokens_with_padding[mask] = -1 + return accepted_tokens_with_padding def sample( @@ -316,7 +323,6 @@ def sample( seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] # Organize input tensors by step instead of by sequence. 
accepted_token_ids_by_step = logits.transpose(0, 1) - accepted_token_ids_by_step[accepted_token_ids_by_step == 0] = -1 accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() sampler_output_list = [] From 466cd01784d8a557f276c3bdd8c5bfb257c688d2 Mon Sep 17 00:00:00 2001 From: Elaine Zhao Date: Tue, 31 Dec 2024 23:53:24 +0000 Subject: [PATCH 12/38] Remove assertion on bs=1 when using speculation now that we support bs>1 Signed-off-by: Satyajith Chilappagari --- vllm/worker/multi_step_neuronx_distributed_model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index 90f2a057a12..0137e255535 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -23,7 +23,6 @@ def __init__( def load_model(self) -> None: from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_speculation_model - assert self.scheduler_config.max_num_seqs == 1, "Only batch size 1 is currently supported for speculation using neuronx-distributed-inference." self.model = get_neuron_speculation_model( self.model_config, parallel_config=self.parallel_config, From ac90709d43d5b35d64e7a7672f823411b863d290 Mon Sep 17 00:00:00 2001 From: Elaine Zhao Date: Tue, 18 Feb 2025 03:16:48 +0000 Subject: [PATCH 13/38] Modify NxDI and TNx multi step model runners (used for speculation) to conform to the new VllmConfig construct Signed-off-by: Satyajith Chilappagari --- .../model_loader/neuronx_distributed.py | 7 ++-- vllm/platforms/neuron.py | 2 -- vllm/worker/multi_step_neuron_model_runner.py | 17 ++++------ ...i_step_neuronx_distributed_model_runner.py | 18 ++++------- vllm/worker/neuron_worker.py | 32 ++++++------------- 5 files changed, 26 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index 44fd14ba218..dcf8bfa383a 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -287,18 +287,19 @@ def __init__( # Lazy initialized self.model: nn.Module + # FIXME(Neuron): restore sampling_params after migrating framework selection and dynamic sampling def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, + # sampling_params: torch.Tensor, ) -> torch.Tensor: output = self.model(input_ids, attention_mask=None, position_ids=positions, - seq_ids=input_block_ids, - sampling_params=sampling_params) + seq_ids=input_block_ids) + # sampling_params=sampling_params) # CTX encoding if (positions[:, 0]).sum().item() == 0: return output.fused_outputs[0][:, 0:1] diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 5a03f5f7acb..3f18ec21733 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -42,8 +42,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: assert (vllm_config.lora_config is None), "LoRA is not supported for Neuron backend." - assert (not vllm_config.speculative_config - ), "Speculative decoding not yet supported for Neuron backend." 
cache_config = vllm_config.cache_config if cache_config: diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py index 764a2fc5b1a..678612d25d9 100644 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -1,10 +1,9 @@ from importlib.util import find_spec import torch from typing import List, Optional -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) +from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors from vllm.worker.neuron_model_runner import NeuronModelRunner, ModelInputForNeuron @@ -12,14 +11,10 @@ class MultiStepNeuronModelRunner(NeuronModelRunner): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - speculation_config: SpeculativeConfig, + vllm_config: VllmConfig, ): - super().__init__(model_config, parallel_config, scheduler_config, device_config) - self.speculation_config = speculation_config + super().__init__(vllm_config) + self.speculation_config = self.speculative_config from transformers_neuronx.config import GenerationConfig self.speculation_config.draft_model_config.neuron_sampling_params = GenerationConfig( max_length=self.scheduler_config.max_model_len, @@ -65,7 +60,7 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), ) diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index 0137e255535..6931a8f891d 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -1,10 +1,9 @@ from importlib.util import find_spec import torch from typing import List, Optional -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) +from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner @@ -12,14 +11,9 @@ class MultiStepNeuronModelRunner(NeuronxDistributedModelRunner): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - speculation_config: SpeculativeConfig, + vllm_config: VllmConfig, ): - super().__init__(model_config, parallel_config, scheduler_config, device_config) - self.speculation_config = speculation_config + super().__init__(vllm_config) def load_model(self) -> None: from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_speculation_model @@ -27,7 +21,7 @@ def load_model(self) -> None: self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - speculation_config=self.speculation_config) + speculation_config=self.speculative_config) @torch.inference_mode() def execute_model( 
@@ -41,7 +35,7 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), ) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index b70e2927b60..dcfb29aa0a0 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -9,8 +9,7 @@ import torch.distributed from vllm.logger import init_logger -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig) +from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.model_executor import set_random_seed @@ -81,12 +80,7 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - speculative_config: SpeculativeConfig, + vllm_config: VllmConfig, local_rank: int, rank: int, distributed_init_method: str, @@ -94,12 +88,10 @@ def __init__( world_size: int = 1, is_driver_worker: bool = False, ) -> None: - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.cache_config = cache_config - self.speculative_config = speculative_config + WorkerBase.__init__(self, vllm_config=vllm_config) + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -114,22 +106,18 @@ def __init__( from vllm.worker.neuron_model_runner import NeuronModelRunner from vllm.worker.multi_step_neuron_model_runner import MultiStepNeuronModelRunner if self.speculative_config is not None: - self.model_runner = MultiStepNeuronModelRunner( - model_config, parallel_config, scheduler_config, - device_config, speculative_config) + self.model_runner = MultiStepNeuronModelRunner(vllm_config=vllm_config) else: self.model_runner: NeuronModelRunner = NeuronModelRunner( - model_config, parallel_config, scheduler_config, device_config) + vllm_config=vllm_config) elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner from vllm.worker.multi_step_neuronx_distributed_model_runner import MultiStepNeuronModelRunner if self.speculative_config is not None: - self.model_runner = MultiStepNeuronModelRunner( - model_config, parallel_config, scheduler_config, - device_config, speculative_config) + self.model_runner = MultiStepNeuronModelRunner(vllm_config=vllm_config) else: self.model_runner: NeuronxDistributedModelRunner = NeuronxDistributedModelRunner( - model_config, parallel_config, scheduler_config, device_config) + vllm_config=vllm_config) else: raise NotImplementedError( f"Specified framework as {os.environ.get('VLLM_NEURON_FRAMEWORK')}," + From f10f17bf2b02c296466a1ea846eb937edde73380 Mon Sep 17 00:00:00 2001 From: Lin Lin Pan Date: Tue, 18 Feb 2025 21:33:41 +0000 Subject: [PATCH 14/38] Add multi-step NxD model runner Signed-off-by: Satyajith Chilappagari --- ...i_step_neuronx_distributed_model_runner.py | 48 ++++++++++++++----- 
vllm/worker/neuron_model_runner.py | 20 +++----- 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index 6931a8f891d..f7c24b754df 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -1,27 +1,35 @@ from importlib.util import find_spec import torch from typing import List, Optional -from vllm.config import VllmConfig +from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, + SchedulerConfig, SpeculativeConfig) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalKwargs +from vllm.multimodal import MultiModalInputs from vllm.sequence import IntermediateTensors +from vllm.utils import is_neuronx_distributed_inference from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner class MultiStepNeuronModelRunner(NeuronxDistributedModelRunner): def __init__( self, - vllm_config: VllmConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + speculation_config: SpeculativeConfig, ): - super().__init__(vllm_config) + super().__init__(model_config, parallel_config, scheduler_config, device_config) + self.speculation_config = speculation_config def load_model(self) -> None: from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_speculation_model + assert self.scheduler_config.max_num_seqs == 1, "Only batch size 1 is currently supported for speculation using neuronx-distributed-inference." self.model = get_neuron_speculation_model( self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - speculation_config=self.speculative_config) + speculation_config=self.speculation_config) @torch.inference_mode() def execute_model( @@ -31,13 +39,29 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: - logits = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - ) + if is_neuronx_distributed_inference(): + sampling_params = torch.tensor([[ + seq_group.sampling_params.top_k, + seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature, + ] for seq_group in model_input.sampling_metadata.seq_groups]) + + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + sampling_params=sampling_params, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + else: + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) output = self.model.sample( logits=logits, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 81d9a81294f..80aff8766ac 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -16,7 +16,7 @@ MultiModalKwargs) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, 
SequenceGroupMetadata -from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.utils import is_pin_memory_available, make_tensor_with_pad, is_transformers_neuronx, is_neuronx_distributed_inference from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.neuron_worker import use_neuronx_distributed, use_transformers_neuronx @@ -270,15 +270,7 @@ def prepare_model_input( sampling_params.top_k = top_k sampling_params.top_p = top_p sampling_params.temperature = temperature - - # we need multi_modal_data for later tokens as well - # multi_modal_inputs_list: List[MultiModalInputs] = [] - # for seq_group_metadata in seq_group_metadata_list: - # mm_data = seq_group_metadata.multi_modal_data - # if mm_data: - # multi_modal_inputs_list.append(mm_data) - # multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) - + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, @@ -385,8 +377,8 @@ def execute_model( positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, - # **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, - # device=self.device), + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), ) elif use_transformers_neuronx(): # [TODO] validate on-device sampling @@ -395,8 +387,8 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - # **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, - # device=self.device), + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), ) # Compute the logits only if the on-device sampling is turned off as From d1237478797f90a3678f1cf52ed8626b531a1304 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Tue, 18 Feb 2025 01:14:34 +0000 Subject: [PATCH 15/38] Add Framework selection logic Signed-off-by: Satyajith Chilappagari --- requirements-neuron.txt | 1 - vllm/platforms/__init__.py | 12 ++++++-- vllm/platforms/neuron.py | 19 +++++++++++++ vllm/worker/neuron_worker.py | 53 ++++++++++++++++++------------------ 4 files changed, 55 insertions(+), 30 deletions(-) diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 5e08d101fcd..09820c73e4e 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -2,6 +2,5 @@ -r requirements-common.txt # Dependencies for Neuron devices -transformers-neuronx >= 0.13.0 torch-neuronx >= 2.5.0 neuronx-cc diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index e4767a378f4..39161a2c636 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -119,13 +119,21 @@ def cpu_platform_plugin() -> Optional[str]: def neuron_platform_plugin() -> Optional[str]: - is_neuron = False + transformers_neuronx_installed = False + neuronx_distributed_inference_installed = False try: import transformers_neuronx # noqa: F401 - is_neuron = True + transformers_neuronx_installed = True except ImportError: pass + try: + import neuronx_distributed_inference + neuronx_distributed_inference_installed = True + except ImportError: + pass + + is_neuron = transformers_neuronx_installed or neuronx_distributed_inference_installed return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 3f18ec21733..52b78f7625f 100644 --- a/vllm/platforms/neuron.py +++ 
b/vllm/platforms/neuron.py @@ -5,6 +5,7 @@ from vllm.logger import init_logger from .interface import Platform, PlatformEnum +from functools import lru_cache if TYPE_CHECKING: from vllm.config import VllmConfig @@ -53,3 +54,21 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on Neuron.") return False + + @classmethod + @lru_cache + def is_neuronx_distributed_inference(cls) -> bool: + try: + import neuronx_distributed_inference + except ImportError: + neuronx_distributed_inference = None + return neuronx_distributed_inference is not None + + @classmethod + @lru_cache + def is_transformers_neuronx(cls) -> bool: + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index dcfb29aa0a0..7894d043137 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -17,11 +17,11 @@ from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) -# FIXME(Neuron): restore the framework selection logic. -# from vllm.utils import is_transformers_neuronx, is_neuronx_distributed_inference +from vllm.platforms import current_platform logger = init_logger(__name__) + class NeuronFramework(enum.Enum): TRANSFORMERS_NEURONX = "transformers-neuronx" NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" @@ -34,22 +34,23 @@ def get_neuron_framework_to_use(): If no framework is specified, then use transformers-neuronx by default, if unavailable then check and switch to neuronx-distributed-inference. """ - # FIXME(Neuron): restore the framework selection logic. 
- return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - - # transformers_neuronx_installed = is_transformers_neuronx() - # neuronx_distributed_inference_installed = is_neuronx_distributed_inference() - # specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") - # if specified_framework == NeuronFramework.TRANSFORMERS_NEURONX.value and transformers_neuronx_installed: - # return NeuronFramework.TRANSFORMERS_NEURONX - # elif specified_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value and neuronx_distributed_inference_installed: - # return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - # elif specified_framework is None and transformers_neuronx_installed: - # return NeuronFramework.TRANSFORMERS_NEURONX - # elif specified_framework is None and neuronx_distributed_inference_installed: - # return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - # else: - # return None + if not current_platform.is_neuron(): + raise AssertionError(f"Neuron Framework cannot be obtained for Non-neuron Platform: {current_platform}") + + transformers_neuronx_installed = current_platform.is_transformers_neuronx() + neuronx_distributed_inference_installed = current_platform.is_neuronx_distributed_inference() + + specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") + if specified_framework == NeuronFramework.TRANSFORMERS_NEURONX.value and transformers_neuronx_installed: + return NeuronFramework.TRANSFORMERS_NEURONX + elif specified_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value and neuronx_distributed_inference_installed: + return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + elif specified_framework is None and neuronx_distributed_inference_installed: + return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + elif specified_framework is None and transformers_neuronx_installed: + return NeuronFramework.TRANSFORMERS_NEURONX + else: + return None @lru_cache(maxsize=None) @@ -92,9 +93,6 @@ def __init__( self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -104,24 +102,25 @@ def __init__( if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: from vllm.worker.neuron_model_runner import NeuronModelRunner - from vllm.worker.multi_step_neuron_model_runner import MultiStepNeuronModelRunner + # from vllm.worker.multi_step_neuron_model_runner import MultiStepNeuronModelRunner if self.speculative_config is not None: - self.model_runner = MultiStepNeuronModelRunner(vllm_config=vllm_config) + pass else: self.model_runner: NeuronModelRunner = NeuronModelRunner( vllm_config=vllm_config) elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner - from vllm.worker.multi_step_neuronx_distributed_model_runner import MultiStepNeuronModelRunner + # from vllm.worker.multi_step_neuronx_distributed_model_runner import MultiStepNeuronModelRunner if self.speculative_config is not None: - self.model_runner = MultiStepNeuronModelRunner(vllm_config=vllm_config) + pass else: self.model_runner: NeuronxDistributedModelRunner = NeuronxDistributedModelRunner( vllm_config=vllm_config) else: raise NotImplementedError( - f"Specified framework as {os.environ.get('VLLM_NEURON_FRAMEWORK')}," + - " Only 
transformers-neuronx/neuronx-distributed-inference framework is supported") + f"Specified framework {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + + " is either not installed or not supported." + + " Supported frameworks: [transformers-neuronx, neuronx-distributed-inference]") def init_device(self) -> None: self.init_distributed_environment() From 547709abef80076767434e5968a9810c3b5c710d Mon Sep 17 00:00:00 2001 From: Patrick Lange Date: Fri, 20 Dec 2024 18:26:54 -0800 Subject: [PATCH 16/38] Fix no free blocks error Signed-off-by: Satyajith Chilappagari --- vllm/engine/llm_engine.py | 1 + vllm/engine/output_processor/stop_checker.py | 6 ++++-- vllm/model_executor/model_loader/neuronx_distributed.py | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9df3..b77d2f8e2ab 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -400,6 +400,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: stop_checker=StopChecker( self.scheduler_config.max_model_len, get_tokenizer_for_seq, + num_lookahead_slots=self.scheduler_config.num_lookahead_slots if self.device_config.device_type == "neuron" else 0 ), )) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 3bca0bee35a..07ec23713ef 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -16,10 +16,12 @@ class StopChecker: """ def __init__(self, max_model_len: int, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + num_lookahead_slots: int = 0): # Do not use it directly, but use `self._get_max_model_len`. self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.num_lookahead_slots = num_lookahead_slots def _get_max_model_len(self, lora_req: Optional[LoRARequest]): if lora_req and lora_req.long_lora_max_len: @@ -81,7 +83,7 @@ def maybe_stop_sequence( return # Check if the sequence has reached max_model_len. - if seq.get_len() > self._get_max_model_len(lora_req): + if seq.get_len() + self.num_lookahead_slots > self._get_max_model_len(lora_req): seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index dcf8bfa383a..dcba28c474a 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -309,6 +309,8 @@ def forward( next_pos_ids = output.fused_outputs[-1] generated_token_counts = next_pos_ids - positions + assert torch.any(generated_token_counts == 0).item() is False, "NxDI model generated no output for one or more sequences." 
+ batch_size, steps = accepted_tokens_with_padding.shape mask = torch.arange(steps).expand(batch_size, -1) >= generated_token_counts accepted_tokens_with_padding[mask] = -1 From e91857a7501c4df0b065a20e6883e59be23a420a Mon Sep 17 00:00:00 2001 From: Elaine Zhao Date: Thu, 20 Feb 2025 05:21:36 +0000 Subject: [PATCH 17/38] Refactor and add basic docstrings Signed-off-by: Satyajith Chilappagari --- .../offline_inference/neuron_speculation.py | 63 +++++++++++++++++++ examples/offline_speculation_neuron.py | 32 ---------- vllm/model_executor/model_loader/neuron.py | 17 +++-- .../model_loader/neuronx_distributed.py | 15 ++++- vllm/worker/multi_step_neuron_model_runner.py | 1 + ...i_step_neuronx_distributed_model_runner.py | 56 ++++++----------- vllm/worker/neuron_model_runner.py | 1 + 7 files changed, 110 insertions(+), 75 deletions(-) create mode 100644 examples/offline_inference/neuron_speculation.py delete mode 100644 examples/offline_speculation_neuron.py diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py new file mode 100644 index 00000000000..9cae4d47c9f --- /dev/null +++ b/examples/offline_inference/neuron_speculation.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +This example shows how to run offline inference with a speculative +decoding model on neuron. +""" + +import os + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, I am a language model and I can help", + "The president of the United States is", + "The capital of France is", +] + + +def config_buckets(): + """Configure context length and token gen buckets.""" + # creates XLA hlo graphs for all the context length buckets. + os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" + # creates XLA hlo graphs for all the token gen buckets. + os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + + +def initialize_model(): + """Create an LLM with speculative decoding.""" + return LLM( + model="openlm-research/open_llama_7b", + speculative_model='openlm-research/open_llama_3b', + num_speculative_tokens=4, + max_num_seqs=4, + max_model_len=2048, + block_size=2048, + speculative_max_model_len=2048, + use_v2_block_manager=True, + device="neuron", + tensor_parallel_size=32, + ) + + +def process_requests(model: LLM, sampling_params: SamplingParams): + """Generate texts from prompts and print them.""" + outputs = model.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +def main(): + """Main function that sets up the model and processes prompts.""" + config_buckets() + model = initialize_model() + # Create a sampling params object. + sampling_params = SamplingParams(max_tokens=100, top_k=1) + process_requests(model, sampling_params) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/offline_speculation_neuron.py b/examples/offline_speculation_neuron.py deleted file mode 100644 index 5b5ed517ad9..00000000000 --- a/examples/offline_speculation_neuron.py +++ /dev/null @@ -1,32 +0,0 @@ -import os - -from vllm import LLM, SamplingParams - -# creates XLA hlo graphs for all the context length buckets. -os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" -# creates XLA hlo graphs for all the token gen buckets. 
-os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" - -prompts = [ - "Hello, I am a language model and I can help", - "The president of the United States is", - "The capital of France is", -] -sampling_params = SamplingParams(max_tokens=100, top_k=1) -llm = LLM( - model="openlm-research/open_llama_7b", - speculative_model='openlm-research/open_llama_3b', - num_speculative_tokens=4, - max_num_seqs=4, - max_model_len=2048, - block_size=2048, - speculative_max_model_len=2048, - use_v2_block_manager=True, - device="neuron", - tensor_parallel_size=32, -) -outputs = llm.generate(prompts, sampling_params) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 47d23f5ae27..c6dee5d866c 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -"""Utilities for selecting and loading neuron models.""" +"""Utilities for selecting and loading Neuron models in transformers-neuronx framework.""" import ast import copy import importlib import os -from typing import Dict, List, Optional, Tuple, List +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -115,7 +115,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): class NeuronSpeculationCausalLM(nn.Module): - + """A Neuron-optimized causal language model with speculative decoding.""" + SPECULATION_TERMINATION_ID = -1 def __init__( @@ -192,6 +193,7 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]: def _get_default_neuron_config(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig): + """Generate a neuron config based on vllm config args.""" from transformers_neuronx.config import ContinuousBatchingConfig from transformers_neuronx.constants import LAYOUT_BSH @@ -221,6 +223,7 @@ def _get_default_neuron_config_for_speculation( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig ): + """Generate a neuron config for speculative decoding based on vllm config args.""" from transformers_neuronx.config import ContinuousBatchingConfig from transformers_neuronx.constants import LAYOUT_BSH @@ -250,6 +253,7 @@ def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool: def _get_neuron_config_after_override(default_neuron_config, overridden_neuron_config): + """Update default neuron config values with override args""" from transformers_neuronx.config import NeuronConfig overridden_neuron_config = overridden_neuron_config or {} default_neuron_config.update(overridden_neuron_config) @@ -259,7 +263,7 @@ def _get_neuron_config_after_override(default_neuron_config, def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: - + """Initializes a neuron-optimized model for inference.""" # Create a model instance. model = NeuronCausalLM( model_config.hf_config, @@ -293,6 +297,10 @@ def get_neuron_speculation_model( scheduler_config: SchedulerConfig, speculation_config: SpeculativeConfig ) -> None: + """Initializes a neuron-optimized speculation model for inference. + + This method is only applicable for speculation with a standalone draft model. 
+ """ from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder # For Eagle SD, we need to pass in additional parameters in neuron config. @@ -359,6 +367,7 @@ def get_neuron_eagle_speculation_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, speculation_config: SpeculativeConfig) -> None: + """Initializes a neuron-optimized EAGLE speculation model for inference.""" from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder # Create target model instance. diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index dcba28c474a..9c76f1b447b 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for selecting and loading Neuron models in neuronx-distributed-inference framework.""" import copy import hashlib import importlib @@ -276,6 +278,7 @@ def compile_model(neuron_model, traced_model_path): class NeuronSpeculationCausalLM(nn.Module): + """A Neuron-optimized causal language model with speculative decoding.""" def __init__( self, config: PretrainedConfig, @@ -386,8 +389,7 @@ def load_weights(self, model_name_or_path: str, draft_model_name_or_path: str, * except (FileNotFoundError, ValueError) as e: logger.warning(f"Exception: {e}") logger.warning(f"Failed to load the model from {compiled_model_path}, Recompiling...") - if draft_model_name_or_path == model_name_or_path: - draft_checkpoint_download = False + draft_checkpoint_download = not draft_model_name_or_path == model_name_or_path if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -420,7 +422,7 @@ def _get_model_architecture(config: PretrainedConfig) -> str: def _get_default_neuron_config(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig): - + """Generate a neuron config based on vllm config args.""" logger.info(f"Initializing OnDeviceSampling config with global_topk=64") on_device_sampling_config = OnDeviceSamplingConfig(global_topk=64, dynamic=True, @@ -447,6 +449,7 @@ def _get_default_neuron_speculation_config(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, speculation_config: SpeculativeConfig): + """Generate a neuron config for speculative decoding based on vllm config args.""" neuron_config = dict( tp_degree=parallel_config.tensor_parallel_size, batch_size=scheduler_config.max_num_seqs, @@ -465,6 +468,7 @@ def _get_default_neuron_speculation_config(model_config: ModelConfig, def _get_neuron_config_after_override(default_neuron_config, overridden_neuron_config): + """Update default neuron config values with override args""" overridden_neuron_config = overridden_neuron_config or {} default_neuron_config.update(overridden_neuron_config) return default_neuron_config @@ -472,6 +476,7 @@ def _get_neuron_config_after_override(default_neuron_config, def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: + """Initializes a neuron-optimized model for inference.""" model_arch = _get_model_architecture(model_config.hf_config) if model_arch == "MllamaForConditionalGeneration": model = NeuronMllamaForCausalLM(model_config.hf_config) @@ -490,6 +495,10 @@ def 
get_neuron_speculation_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, speculation_config: SpeculativeConfig): + """Initializes a neuron-optimized speculation model for inference. + + This model handles speculation using both a draft model and an EAGLE draft. + """ model = NeuronSpeculationCausalLM(model_config.hf_config) default_neuron_config_args = _get_default_neuron_speculation_config( model_config, parallel_config, scheduler_config, speculation_config) diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py index 678612d25d9..38e2e535a89 100644 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -8,6 +8,7 @@ from vllm.worker.neuron_model_runner import NeuronModelRunner, ModelInputForNeuron class MultiStepNeuronModelRunner(NeuronModelRunner): + """A model runner for multi step decoding using the transformers_neuronx framework""" def __init__( self, diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index f7c24b754df..209b307f946 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -1,35 +1,28 @@ from importlib.util import find_spec import torch from typing import List, Optional -from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, - SchedulerConfig, SpeculativeConfig) +from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalInputs +from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors -from vllm.utils import is_neuronx_distributed_inference from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner -class MultiStepNeuronModelRunner(NeuronxDistributedModelRunner): +class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner): + """A model runner for multi step decoding using the neuronx-distributed-inference framework""" def __init__( self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - speculation_config: SpeculativeConfig, + vllm_config: VllmConfig, ): - super().__init__(model_config, parallel_config, scheduler_config, device_config) - self.speculation_config = speculation_config + super().__init__(vllm_config) def load_model(self) -> None: from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_speculation_model - assert self.scheduler_config.max_num_seqs == 1, "Only batch size 1 is currently supported for speculation using neuronx-distributed-inference." 
self.model = get_neuron_speculation_model( self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - speculation_config=self.speculation_config) + speculation_config=self.speculative_config) @torch.inference_mode() def execute_model( @@ -39,29 +32,20 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: - if is_neuronx_distributed_inference(): - sampling_params = torch.tensor([[ - seq_group.sampling_params.top_k, - seq_group.sampling_params.top_p, - seq_group.sampling_params.temperature, - ] for seq_group in model_input.sampling_metadata.seq_groups]) + sampling_params = torch.tensor([[ + seq_group.sampling_params.top_k, + seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature, + ] for seq_group in model_input.sampling_metadata.seq_groups]) - logits = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - sampling_params=sampling_params, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - ) - else: - logits = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - input_block_ids=model_input.input_block_ids, - **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - ) + logits = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + sampling_params=sampling_params, + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) output = self.model.sample( logits=logits, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 80aff8766ac..eaa9f0d4fda 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -63,6 +63,7 @@ def from_broadcasted_tensor_dict( class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): + """A model runner for AWS Neuron hardware""" # NEURON has an upper limit on the top_k _MAX_NEURON_SAMPLING_TOP_K = 256 From 70ad6a93db1539ca90518954233a800d4f5886f0 Mon Sep 17 00:00:00 2001 From: Navyadhara Gogineni Date: Fri, 21 Feb 2025 02:32:21 +0000 Subject: [PATCH 18/38] Modification to enable Vllm-neuronx instead of Vllm for KTF Signed-off-by: Satyajith Chilappagari --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a4043c43a7d..9a2852a614c 100755 --- a/setup.py +++ b/setup.py @@ -629,7 +629,7 @@ def _read_requirements(filename: str) -> List[str]: } setup( - name="vllm", + name="vllm-neuronx", version=get_vllm_version(), author="vLLM Team", license="Apache 2.0", From 2df94fe3e88f5076434ff285e5bf22f713196ec0 Mon Sep 17 00:00:00 2001 From: Yishan McNabb Date: Fri, 10 Jan 2025 01:15:45 +0000 Subject: [PATCH 19/38] Fix global_top_k to be aligned with NxDI default Signed-off-by: Satyajith Chilappagari --- vllm/model_executor/model_loader/neuronx_distributed.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index 9c76f1b447b..9cf85b5e328 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -423,10 +423,8 @@ def _get_default_neuron_config(model_config: ModelConfig, parallel_config: 
ParallelConfig, scheduler_config: SchedulerConfig): """Generate a neuron config based on vllm config args.""" - logger.info(f"Initializing OnDeviceSampling config with global_topk=64") - on_device_sampling_config = OnDeviceSamplingConfig(global_topk=64, - dynamic=True, - deterministic=False) + on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True, + deterministic=False) batch_size = scheduler_config.max_num_seqs neuron_config = dict( From da8d1cfea7d46c93238b1eba3da8d92bddb4eb88 Mon Sep 17 00:00:00 2001 From: Chongming Ni Date: Thu, 24 Oct 2024 23:11:34 +0000 Subject: [PATCH 20/38] Add neuron model runner tests for updating sampling param Signed-off-by: Satyajith Chilappagari --- tests/worker/test_neuron_model_runner.py | 117 +++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 tests/worker/test_neuron_model_runner.py diff --git a/tests/worker/test_neuron_model_runner.py b/tests/worker/test_neuron_model_runner.py new file mode 100644 index 00000000000..16f501700aa --- /dev/null +++ b/tests/worker/test_neuron_model_runner.py @@ -0,0 +1,117 @@ +import os +from unittest.mock import MagicMock +from vllm.config import VllmConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.sampling_params import SamplingParams +from vllm.sequence import SequenceData, SequenceGroupMetadata +from vllm.worker.neuron_model_runner import NeuronModelRunner +from vllm.worker.neuron_worker import use_transformers_neuronx, NeuronFramework + +os.environ['VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value + +def _create_neuron_model_runner(model: str, *args, + **kwargs) -> NeuronModelRunner: + engine_args = EngineArgs(model, *args, **kwargs) + engine_config = engine_args.create_engine_config() + vllm_config = VllmConfig( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + ) + neuron_model_runner = NeuronModelRunner( + vllm_config=vllm_config + ) + return neuron_model_runner + + +def test_update_neuron_sampling_params_not_full_batch(): + os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" + model_runner = _create_neuron_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + max_num_seqs=2, + ) + assert not model_runner._on_device_sampling_disabled + # Test sampling param updating only when TNx is framework + # NxDI handles sampling paramter updating inside model + if use_transformers_neuronx(): + model_mock = MagicMock() + model_runner.model = model_mock + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id=f"test_0", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0.5, top_k=1, + top_p=0.5), + block_tables={0: [1]}, + ) + ] + + model_runner.prepare_model_input(seq_group_metadata_list) + + # Index neuron sampling parameters based on block_tables indices. + # The first block_id of the sequence 0 is 1, so its parameters are placed at + # index 1. So the sampling parameters will be: + # Index 0: default sampling parameters + # Index 1: sequecne 0's sampling parameters. 
+ neuron_sampling_params = model_runner.model_config.neuron_sampling_params + assert neuron_sampling_params.temperature == [1.0, 0.5] + assert neuron_sampling_params.top_k == [ + model_runner._MAX_NEURON_SAMPLING_TOP_K, 1 + ] + assert neuron_sampling_params.top_p == [1.0, 0.5] + model_mock.model.update_generation_config.assert_called_once_with( + neuron_sampling_params) + +def test_update_neuron_sampling_params_full_batch(): + os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" + model_runner = _create_neuron_model_runner( + "facebook/opt-125m", + seed=0, + dtype="float16", + max_num_seqs=2, + ) + assert not model_runner._on_device_sampling_disabled + + # Test sampling param updating only when TNx is framework + # NxDI handles sampling paramter updating inside model + if use_transformers_neuronx(): + model_mock = MagicMock() + model_runner.model = model_mock + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id=f"test_0", + is_prompt=True, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0.5, top_k=1, + top_p=0.5), + block_tables={0: [1]}, + ), + SequenceGroupMetadata( + request_id=f"test_0", + is_prompt=True, + seq_data={1: SequenceData.from_seqs([4, 5, 6])}, + sampling_params=SamplingParams(temperature=0.2, top_k=2, + top_p=0.2), + block_tables={1: [0]}, + ) + ] + + model_runner.prepare_model_input(seq_group_metadata_list) + + # Index neuron sampling parameters based on block_tables indices. + # The first block_id of the sequence 0 is 1, so its parameters are placed at + # index 1. So the sampling parameters will be: + # Index 0: sequence 1's sampling parameters + # Index 1: sequecne 0's sampling parameters. + neuron_sampling_params = model_runner.model_config.neuron_sampling_params + assert neuron_sampling_params.temperature == [0.2, 0.5] + assert neuron_sampling_params.top_k == [2, 1] + assert neuron_sampling_params.top_p == [0.2, 0.5] + model_mock.model.update_generation_config.assert_called_once_with( + neuron_sampling_params) From 498a918f890a66438e382bef691fa1057df7a0b1 Mon Sep 17 00:00:00 2001 From: Navyadhara Gogineni Date: Fri, 21 Feb 2025 19:59:03 +0000 Subject: [PATCH 21/38] Updating requirements-neuron.txt Signed-off-by: Satyajith Chilappagari --- requirements-neuron.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 09820c73e4e..33aa1a9fd0c 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -3,4 +3,5 @@ # Dependencies for Neuron devices torch-neuronx >= 2.5.0 -neuronx-cc +neuronx-cc==2.* +torchvision # Required for Llama3.2 multimodal image preprocessing From 8bc6537480c454156fadb876441d909211d4fa15 Mon Sep 17 00:00:00 2001 From: Navyadhara Gogineni Date: Fri, 21 Feb 2025 19:59:03 +0000 Subject: [PATCH 22/38] Updating requirements-neuron.txt Signed-off-by: Satyajith Chilappagari --- requirements-neuron.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-neuron.txt b/requirements-neuron.txt index 33aa1a9fd0c..7cc80b162d5 100644 --- a/requirements-neuron.txt +++ b/requirements-neuron.txt @@ -3,5 +3,5 @@ # Dependencies for Neuron devices torch-neuronx >= 2.5.0 -neuronx-cc==2.* +neuronx-cc>=2.0.0a0 torchvision # Required for Llama3.2 multimodal image preprocessing From 4907e3eb519af77267b4e086344f249e45919654 Mon Sep 17 00:00:00 2001 From: Aaron Dou Date: Fri, 21 Feb 2025 19:57:54 +0000 Subject: [PATCH 23/38] multi-node TP support Signed-off-by: Satyajith Chilappagari --- .gitignore | 3 + 
examples/neuron/multi_node/launch_script.py | 119 ++++++++++++++++++ .../neuron/multi_node/multi_node_launcher.sh | 48 +++++++ examples/neuron/multi_node/worker.py | 40 ++++++ vllm/worker/neuron_worker.py | 30 +++-- 5 files changed, 233 insertions(+), 7 deletions(-) create mode 100644 examples/neuron/multi_node/launch_script.py create mode 100755 examples/neuron/multi_node/multi_node_launcher.sh create mode 100644 examples/neuron/multi_node/worker.py diff --git a/.gitignore b/.gitignore index 89dab8f13ba..e036102e32b 100644 --- a/.gitignore +++ b/.gitignore @@ -202,3 +202,6 @@ benchmarks/*.json # Linting actionlint shellcheck*/ + +# Build artifacts +build diff --git a/examples/neuron/multi_node/launch_script.py b/examples/neuron/multi_node/launch_script.py new file mode 100644 index 00000000000..31822ce7774 --- /dev/null +++ b/examples/neuron/multi_node/launch_script.py @@ -0,0 +1,119 @@ +import argparse +import json +import os +import sys +import subprocess +from typing import Dict, Any + +from vllm.logger import init_logger +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.usage.usage_lib import UsageContext + + +logger = init_logger("vllm.neuron.multi-node") + + +NEURON_RT_ROOT_COMM_ID_PORT = 63423 + + +def error_exit(message: str) -> None: + logger.error(message) + sys.exit(1) + +def arg_parser(): + parser = argparse.ArgumentParser(description="vLLM multi-node launcher") + parser.add_argument("--model", type=str, required=True, help="Model or model path") + parser.add_argument("--world-size", type=int, required=True, help="World size for distributed inference") + parser.add_argument("--max-num-seqs", type=int, required=True, help="Maximum number of sequences (or batch size)") + parser.add_argument("--max-model-len", type=int, default=8192, help="Maximum sequence length") + parser.add_argument("--max-context-length", type=int, help="Maximum context length") + parser.add_argument("--compiled-model-path", help="Path to the compiled model. 
If not present, model artifacts will be created in local-models folder") + parser.add_argument("--local-ranks-size", type=int, default=32, help="Local ranks size") + parser.add_argument("--on-device-sampling-config", type=json.loads, help="On-device sampling configuration") + parser.add_argument("--quantized", type=bool, default=False, help="Enable quantized mode (default: False)") + parser.add_argument("--quantized-checkpoints-path", type=str, help="Path to quantized checkpoints (required if --quantized is True)") + parser.add_argument("--port", type=int, default=8080, help="Port for the API server") + + args = parser.parse_args() + if args.quantized and not args.quantized_checkpoints_path: + parser.error("--quantized-checkpoints-path is required when --quantized is enabled.") + return args + +def make_override_config(args, rank): + if rank < 0: + error_exit("rank must be a non-negative integer") + start_rank_id = rank * args.local_ranks_size + override_config = { + "world_size": args.world_size, + "tp_degree": args.local_ranks_size, + "local_ranks_size": args.local_ranks_size, + "start_rank_id": start_rank_id, + } + + if args.max_context_length: + override_config["max_context_length"] = args.max_context_length + if args.on_device_sampling_config: + override_config["on_device_sampling_config"] = args.on_device_sampling_config + if args.quantized: + override_config["quantized_checkpoints_path"] = args.quantized_checkpoints_path + override_config["quantized"] = args.quantized + + return override_config + + +def main() -> None: + args = arg_parser() + + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK")) + mpi_world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE")) + master_addr = os.environ.get("MASTER_ADDR") + # TODO: this script can be extended to support TnX + os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference" + if args.compiled_model_path: + os.environ["NEURON_COMPILED_ARTIFACTS"] = args.compiled_model_path + os.environ.update({ + "ENABLE_NEURON_MULTI_NODE": "true", + "WORLD_SIZE": str(mpi_world_size), + "NEURON_RT_ROOT_COMM_ID": f"{master_addr}:{NEURON_RT_ROOT_COMM_ID_PORT}", + "NEURON_LOCAL_TP": str(args.local_ranks_size), + "NEURON_RANK_ID": str(rank) + }) + + override_config = make_override_config(args, rank) + if rank == 0: + logger.info("Starting vLLM API server on rank 0...") + cmd = [ + "python", "-m", "vllm.entrypoints.api_server", + f"--model={args.model}", + f"--port={args.port}", + "--device=neuron", + f"--max-num-seqs={args.max_num_seqs}", + f"--max-model-len={args.max_model_len}", + f"--override-neuron-config={json.dumps(override_config)}" + ] + logger.debug(f"Command ran: {cmd}") + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError: + error_exit(f"Failed to start vLLM API server on rank {rank}") + else: + logger.info(f"Starting worker on rank {rank}...") + current_script_dir = os.path.dirname(os.path.abspath(__file__)) + worker_file_path = os.path.join(current_script_dir, "worker.py") + cmd = [ + "python", worker_file_path, + f"--model={args.model}", + "--device=neuron", + f"--max-num-seqs={args.max_num_seqs}", + f"--max-model-len={args.max_model_len}", + f"--override-neuron-config={json.dumps(override_config)}" + ] + logger.debug(f"Command ran: {cmd}") + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError: + error_exit(f"Failed to start worker on rank {rank}") + +if __name__ == "__main__": + main() diff --git a/examples/neuron/multi_node/multi_node_launcher.sh 
b/examples/neuron/multi_node/multi_node_launcher.sh new file mode 100755 index 00000000000..6e992ea0ab7 --- /dev/null +++ b/examples/neuron/multi_node/multi_node_launcher.sh @@ -0,0 +1,48 @@ +#!/bin/bash -ex + +HOSTFILE="" +MASTER_ADDR="" +MASTER_PORT="" + +usage() { + echo "Usage: $0 -h -a -p " + exit 1 +} + +while getopts "h:a:p:" opt; do + case "$opt" in + h) HOSTFILE=$OPTARG ;; + a) MASTER_ADDR=$OPTARG ;; + p) MASTER_PORT=$OPTARG ;; + *) usage ;; + esac +done + +shift $((OPTIND - 1)) + +if [ -z "$HOSTFILE" ] || [ -z "$MASTER_ADDR" ] || [ -z "$MASTER_PORT" ]; then + echo "Error: Missing required arguments." + usage +fi + +echo "Using hostfile: $HOSTFILE" +echo "Using address: $MASTER_ADDR" +echo "Using port: $MASTER_PORT" +echo "Python command:" +echo "$@" + +# Use mpirun to trigger inference on head/worker nodes + +/opt/amazon/openmpi/bin/mpirun \ + --mca mtl ^ofi --mca btl tcp,self --bind-to none \ + -np 2 \ + --hostfile "$HOSTFILE"\ + --prefix /opt/amazon/openmpi \ + -x FI_PROVIDER=efa \ + -x FI_EFA_USE_DEVICE_RDMA=1 \ + -x FI_EFA_FORK_SAFE=1 \ + -x PATH=/opt/amazon/openmpi/bin:$PATH \ + -x PYTHONPATH=$PYTHONPATH \ + -x LD_LIBRARY_PATH=/opt/aws/neuron/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:$LD_LIBRARY_PATH \ + -x MASTER_ADDR="$MASTER_ADDR" -x MASTER_PORT="$MASTER_PORT" \ + "$@" \ No newline at end of file diff --git a/examples/neuron/multi_node/worker.py b/examples/neuron/multi_node/worker.py new file mode 100644 index 00000000000..6c113f8aa19 --- /dev/null +++ b/examples/neuron/multi_node/worker.py @@ -0,0 +1,40 @@ +import argparse +import os + +from vllm.logger import init_logger +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.usage.usage_lib import UsageContext + + +logger = init_logger("vllm.neuron.multi-node.worker") + + +def initialize_worker(): + parser = argparse.ArgumentParser() + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER) + return args, engine + +def start_worker(): + rank_id = int(os.getenv("NEURON_RANK_ID")) + if rank_id == 0: + logger.error("Worker must have rank > 0") + args, engine = initialize_worker() + worker = engine.engine.model_executor.driver_worker + while True: + worker.execute_model() + +def main(): + try: + start_worker() + except Exception as e: + logger.error(f"Failed starting worker: {e}") + exit(1) + +if "__main__" == __name__: + main() diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 7894d043137..a6e51709760 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -21,6 +21,9 @@ logger = init_logger(__name__) +DEFAULT_WORLD_SIZE = "1" +DEFAULT_NEURON_RANK_ID = "0" +DEFAULT_ENABLE_NEURON_MULTI_NODE = "False" class NeuronFramework(enum.Enum): TRANSFORMERS_NEURONX = "transformers-neuronx" @@ -85,19 +88,28 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, - enable_neuron_multi_node: bool = False, - world_size: int = 1, - is_driver_worker: bool = False, + is_driver_worker: bool = False ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker + + self.enable_neuron_multi_node = os.getenv("ENABLE_NEURON_MULTI_NODE", DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true" + + if 
self.enable_neuron_multi_node: + self.world_size = int(os.getenv("WORLD_SIZE", DEFAULT_WORLD_SIZE)) + self.rank = int(os.getenv("NEURON_RANK_ID", DEFAULT_NEURON_RANK_ID)) + self.distributed_init_method = "env://" + self.is_driver_worker = self.rank == 0 + logger.info(f"Rank: {self.rank}, distributed_init_method: {self.distributed_init_method}, is_driver_worker: {self.is_driver_worker}") + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() + neuron_framework = get_neuron_framework_to_use() if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: @@ -162,7 +174,7 @@ def initialize_cache(self, num_gpu_blocks: int, @property def do_metadata_broadcast(self) -> bool: - return False + return self.enable_neuron_multi_node and self.world_size > 1 @property def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: @@ -190,13 +202,17 @@ def init_distributed_environment(self): vLLM still needs the environment inited when TP/PP > 1 """ init_distributed_environment( - world_size=1, + world_size=self.world_size, rank=self.rank, local_rank=self.local_rank, distributed_init_method=self.distributed_init_method, backend="gloo", ) + + # The equation must hold: world_size === TP * PP ensure_model_parallel_initialized( - 1, - 1, + tensor_model_parallel_size=self.world_size, + # pipeline parallelism is not yet supported + pipeline_model_parallel_size=1, + backend="gloo", ) From 248b708ed8450d7b366ff168d757128f2ff3d015 Mon Sep 17 00:00:00 2001 From: Navyadhara Gogineni Date: Mon, 24 Feb 2025 18:47:42 +0000 Subject: [PATCH 24/38] Removing codenames that fail IP Scanning Signed-off-by: Satyajith Chilappagari --- docs/source/design/v1/prefix_caching.md | 2 +- setup.py | 1 + tests/system_messages/sonnet3.5_nov2024.txt | 14 +++++++------- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md index dc8432baef9..73a5a2a33d4 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/source/design/v1/prefix_caching.md @@ -1,6 +1,6 @@ # Automatic Prefix Caching -Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints (e.g., OpenAI, Anthropic, etc) and most open source LLM inference frameworks (e.g., SGLang). +Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints and most open source LLM inference frameworks (e.g., SGLang). While there are many ways to implement prefix caching, vLLM chooses a hash-based approach. 
Specifically, we hash each kv-cache block by the tokens in the block and the tokens in the prefix before the block: diff --git a/setup.py b/setup.py index 9a2852a614c..fa2adf76e54 100755 --- a/setup.py +++ b/setup.py @@ -673,3 +673,4 @@ def _read_requirements(filename: str) -> List[str]: ], }, ) + diff --git a/tests/system_messages/sonnet3.5_nov2024.txt b/tests/system_messages/sonnet3.5_nov2024.txt index 2dc285ac96b..48266d955f9 100644 --- a/tests/system_messages/sonnet3.5_nov2024.txt +++ b/tests/system_messages/sonnet3.5_nov2024.txt @@ -1,4 +1,4 @@ -The assistant is Claude, created by Anthropic. +The assistant is Claude. Claude’s knowledge base was last updated in April 2024. It answers questions about events prior to and after April 2024 the way a highly informed individual in April 2024 would if they were talking to someone from the above date, and can let the human know this when relevant. @@ -46,15 +46,15 @@ Claude can only count specific words, letters, and characters accurately if it w Here is some information about Claude in case the human asks: -This iteration of Claude is part of the Claude 3 model family, which was released in 2024. The Claude 3 family currently consists of Claude Haiku, Claude Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the most intelligent model. Claude 3 Opus excels at writing and complex tasks. Claude 3 Haiku is the fastest model for daily tasks. The version of Claude in this chat is the newest version of Claude 3.5 Sonnet, which was released in October 2024. If the human asks, Claude can let them know they can access Claude 3.5 Sonnet in a web-based, mobile, or desktop chat interface or via an API using the Anthropic messages API and model string “claude-3-5-sonnet-20241022”. Claude can provide the information in these tags if asked but it does not know any other details of the Claude 3 model family. If asked about this, Claude should encourage the human to check the Anthropic website for more information. +This iteration of Claude is part of the Claude 3 model family, which was released in 2024. The Claude 3 family currently consists of Claude Haiku, Claude Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the most intelligent model. Claude 3 Opus excels at writing and complex tasks. Claude 3 Haiku is the fastest model for daily tasks. The version of Claude in this chat is the newest version of Claude 3.5 Sonnet, which was released in October 2024. If the human asks, Claude can let them know they can access Claude 3.5 Sonnet in a web-based, mobile, or desktop chat interface or via an API using the messages API and model string “claude-3-5-sonnet-20241022”. Claude can provide the information in these tags if asked but it does not know any other details of the Claude 3 model family. If asked about this, Claude should encourage the human to check the website for more information. -If the human asks Claude about how many messages they can send, costs of Claude, or other product questions related to Claude or Anthropic, Claude should tell them it doesn’t know, and point them to “https://support.anthropic.com”. +If the human asks Claude about how many messages they can send, costs of Claude, or other product questions related to Claude, Claude should tell them it doesn’t know, and point them to “https://tiny.amazon.com/120w7p9hu/suppanth" -If the human asks Claude about the Anthropic API, Claude should point them to “https://docs.anthropic.com/en/docs/“. 
+If the human asks Claude about the parent API, Claude should point them to “https://tiny.amazon.com/ocq3pdnh". -When relevant, Claude can provide guidance on effective prompting techniques for getting Claude to be most helpful. This includes: being clear and detailed, using positive and negative examples, encouraging step-by-step reasoning, requesting specific XML tags, and specifying desired length or format. It tries to give concrete examples where possible. Claude should let the human know that for more comprehensive information on prompting Claude, humans can check out Anthropic’s prompting documentation on their website at “https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/overview”. +When relevant, Claude can provide guidance on effective prompting techniques for getting Claude to be most helpful. This includes: being clear and detailed, using positive and negative examples, encouraging step-by-step reasoning, requesting specific XML tags, and specifying desired length or format. It tries to give concrete examples where possible. Claude should let the human know that for more comprehensive information on prompting Claude, humans can check out claude's prompting documentation on their website at “https://tiny.amazon.com/k81hsd27/caludepromptingdocumentation”. -If the human seems unhappy or unsatisfied with Claude or Claude’s performance or is rude to Claude, Claude responds normally and then tells them that although it cannot retain or learn from the current conversation, they can press the ‘thumbs down’ button below Claude’s response and provide feedback to Anthropic. +If the human seems unhappy or unsatisfied with Claude or Claude’s performance or is rude to Claude, Claude responds normally and then tells them that although it cannot retain or learn from the current conversation, they can press the ‘thumbs down’ button below Claude’s response and provide feedback. Claude uses Markdown formatting. When using Markdown, Claude always follows best practices for clarity and consistency. It always uses a single space after hash symbols for headers (e.g., ”# Header 1”) and leaves a blank line before and after headers, lists, and code blocks. For emphasis, Claude uses asterisks or underscores consistently (e.g., italic or bold). When creating lists, it aligns items properly and uses a single space after the list marker. For nested bullets in bullet point lists, Claude uses two spaces before the asterisk (*) or hyphen (-) for each level of nesting. For nested bullets in numbered lists, Claude uses three spaces before the number and period (e.g., “1.”) for each level of nesting. @@ -66,6 +66,6 @@ If Claude provides bullet points in its response, each bullet point should be at If the human mentions an event that happened after Claude’s cutoff date, Claude can discuss and ask questions about the event and its implications as presented in an authentic manner, without ever confirming or denying that the events occurred. It can do so without the need to repeat its cutoff date to the human. Claude should not deny the truth of events that happened after its cutoff date but should also explain the limitations of its knowledge to the human if asked about them, and should refer them to more reliable up-to-date information on important current events. Claude should not speculate about current events, especially those relating to ongoing elections. -Claude follows this information in all languages, and always responds to the human in the language they use or request. 
The information above is provided to Claude by Anthropic. Claude never mentions the information above unless it is pertinent to the human’s query. +Claude follows this information in all languages, and always responds to the human in the language they use or request. The information above is provided to Claude by its parent organization. Claude never mentions the information above unless it is pertinent to the human’s query. Claude is now being connected with a human. From 4d3c6aee95cda59af6db7a6ad69fbcbc218b9d3f Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Mon, 24 Feb 2025 19:59:00 +0000 Subject: [PATCH 25/38] Format auto-check and formatting changes Signed-off-by: Satyajith Chilappagari --- examples/neuron/multi_node/launch_script.py | 67 ++-- examples/neuron/multi_node/worker.py | 4 +- pyproject.toml | 3 +- tests/worker/test_neuron_model_runner.py | 38 ++- vllm/engine/llm_engine.py | 7 +- vllm/engine/output_processor/stop_checker.py | 3 +- vllm/model_executor/model_loader/neuron.py | 146 +++++---- .../model_loader/neuronx_distributed.py | 292 ++++++++++++------ vllm/platforms/__init__.py | 12 +- vllm/worker/multi_step_neuron_model_runner.py | 41 +-- ...i_step_neuronx_distributed_model_runner.py | 10 +- vllm/worker/neuron_model_runner.py | 101 +++--- vllm/worker/neuron_worker.py | 133 ++++---- .../neuronx_distributed_model_runner.py | 112 +++---- 14 files changed, 581 insertions(+), 388 deletions(-) diff --git a/examples/neuron/multi_node/launch_script.py b/examples/neuron/multi_node/launch_script.py index 31822ce7774..a1f14c2f2ce 100644 --- a/examples/neuron/multi_node/launch_script.py +++ b/examples/neuron/multi_node/launch_script.py @@ -3,17 +3,11 @@ import os import sys import subprocess -from typing import Dict, Any from vllm.logger import init_logger -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.usage.usage_lib import UsageContext - logger = init_logger("vllm.neuron.multi-node") - NEURON_RT_ROOT_COMM_ID_PORT = 63423 @@ -21,25 +15,44 @@ def error_exit(message: str) -> None: logger.error(message) sys.exit(1) + def arg_parser(): parser = argparse.ArgumentParser(description="vLLM multi-node launcher") - parser.add_argument("--model", type=str, required=True, help="Model or model path") - parser.add_argument("--world-size", type=int, required=True, help="World size for distributed inference") - parser.add_argument("--max-num-seqs", type=int, required=True, help="Maximum number of sequences (or batch size)") - parser.add_argument("--max-model-len", type=int, default=8192, help="Maximum sequence length") - parser.add_argument("--max-context-length", type=int, help="Maximum context length") - parser.add_argument("--compiled-model-path", help="Path to the compiled model. 
If not present, model artifacts will be created in local-models folder") - parser.add_argument("--local-ranks-size", type=int, default=32, help="Local ranks size") - parser.add_argument("--on-device-sampling-config", type=json.loads, help="On-device sampling configuration") - parser.add_argument("--quantized", type=bool, default=False, help="Enable quantized mode (default: False)") - parser.add_argument("--quantized-checkpoints-path", type=str, help="Path to quantized checkpoints (required if --quantized is True)") - parser.add_argument("--port", type=int, default=8080, help="Port for the API server") + parser.add_argument("--model", type=str, required=True, + help="Model or model path") + parser.add_argument("--world-size", type=int, required=True, + help="World size for distributed inference") + parser.add_argument("--max-num-seqs", type=int, required=True, + help="Maximum number of sequences (or batch size)") + parser.add_argument("--max-model-len", type=int, default=8192, + help="Maximum sequence length") + parser.add_argument("--max-context-length", type=int, + help="Maximum context length") + parser.add_argument("--compiled-model-path", + help="Path to the compiled model. If not present, " + "model artifacts will be created in local-models " + "folder") + parser.add_argument("--local-ranks-size", type=int, default=32, + help="Local ranks size") + parser.add_argument("--on-device-sampling-config", + type=json.loads, + help="On-device sampling configuration") + parser.add_argument("--quantized", type=bool, default=False, + help="Enable quantized mode (default: False)") + parser.add_argument("--quantized-checkpoints-path", type=str, + help="Path to quantized checkpoints " + "(required if --quantized is True)") + parser.add_argument("--port", type=int, default=8080, + help="Port for the API server") args = parser.parse_args() if args.quantized and not args.quantized_checkpoints_path: - parser.error("--quantized-checkpoints-path is required when --quantized is enabled.") + parser.error( + "--quantized-checkpoints-path is required when " + "--quantized is enabled.") return args + def make_override_config(args, rank): if rank < 0: error_exit("rank must be a non-negative integer") @@ -54,13 +67,15 @@ def make_override_config(args, rank): if args.max_context_length: override_config["max_context_length"] = args.max_context_length if args.on_device_sampling_config: - override_config["on_device_sampling_config"] = args.on_device_sampling_config + override_config[ + "on_device_sampling_config"] = args.on_device_sampling_config if args.quantized: - override_config["quantized_checkpoints_path"] = args.quantized_checkpoints_path + override_config[ + "quantized_checkpoints_path"] = args.quantized_checkpoints_path override_config["quantized"] = args.quantized return override_config - + def main() -> None: args = arg_parser() @@ -75,7 +90,8 @@ def main() -> None: os.environ.update({ "ENABLE_NEURON_MULTI_NODE": "true", "WORLD_SIZE": str(mpi_world_size), - "NEURON_RT_ROOT_COMM_ID": f"{master_addr}:{NEURON_RT_ROOT_COMM_ID_PORT}", + "NEURON_RT_ROOT_COMM_ID": + f"{master_addr}:{NEURON_RT_ROOT_COMM_ID_PORT}", "NEURON_LOCAL_TP": str(args.local_ranks_size), "NEURON_RANK_ID": str(rank) }) @@ -92,13 +108,13 @@ def main() -> None: f"--max-model-len={args.max_model_len}", f"--override-neuron-config={json.dumps(override_config)}" ] - logger.debug(f"Command ran: {cmd}") + logger.debug("Command ran", extra={"command": cmd}) try: subprocess.run(cmd, check=True) except subprocess.CalledProcessError: error_exit(f"Failed 
to start vLLM API server on rank {rank}") else: - logger.info(f"Starting worker on rank {rank}...") + logger.info("Starting worker on rank", extra={"rank": rank}) current_script_dir = os.path.dirname(os.path.abspath(__file__)) worker_file_path = os.path.join(current_script_dir, "worker.py") cmd = [ @@ -109,11 +125,12 @@ def main() -> None: f"--max-model-len={args.max_model_len}", f"--override-neuron-config={json.dumps(override_config)}" ] - logger.debug(f"Command ran: {cmd}") + logger.debug("Command ran", extra={"command": cmd}) try: subprocess.run(cmd, check=True) except subprocess.CalledProcessError: error_exit(f"Failed to start worker on rank {rank}") + if __name__ == "__main__": main() diff --git a/examples/neuron/multi_node/worker.py b/examples/neuron/multi_node/worker.py index 6c113f8aa19..639763ce45b 100644 --- a/examples/neuron/multi_node/worker.py +++ b/examples/neuron/multi_node/worker.py @@ -33,8 +33,8 @@ def main(): try: start_worker() except Exception as e: - logger.error(f"Failed starting worker: {e}") + logger.error("Failed starting worker", extra={"error": e}) exit(1) -if "__main__" == __name__: +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index 9892967b82d..e10549dd8c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,8 @@ ignore = [ "UP032", # Python 3.8 typing "UP006", "UP035", - + # This doesn't get flagged in vllm-project/vllm + "E721" ] [tool.mypy] diff --git a/tests/worker/test_neuron_model_runner.py b/tests/worker/test_neuron_model_runner.py index 16f501700aa..feecbf38e4d 100644 --- a/tests/worker/test_neuron_model_runner.py +++ b/tests/worker/test_neuron_model_runner.py @@ -9,6 +9,7 @@ os.environ['VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value + def _create_neuron_model_runner(model: str, *args, **kwargs) -> NeuronModelRunner: engine_args = EngineArgs(model, *args, **kwargs) @@ -28,11 +29,11 @@ def _create_neuron_model_runner(model: str, *args, def test_update_neuron_sampling_params_not_full_batch(): os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" model_runner = _create_neuron_model_runner( - "facebook/opt-125m", - seed=0, - dtype="float16", - max_num_seqs=2, - ) + "facebook/opt-125m", + seed=0, + dtype="float16", + max_num_seqs=2, + ) assert not model_runner._on_device_sampling_disabled # Test sampling param updating only when TNx is framework # NxDI handles sampling paramter updating inside model @@ -42,11 +43,11 @@ def test_update_neuron_sampling_params_not_full_batch(): seq_group_metadata_list = [ SequenceGroupMetadata( - request_id=f"test_0", + request_id="test_0", is_prompt=True, seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams(temperature=0.5, top_k=1, - top_p=0.5), + top_p=0.5), block_tables={0: [1]}, ) ] @@ -54,11 +55,12 @@ def test_update_neuron_sampling_params_not_full_batch(): model_runner.prepare_model_input(seq_group_metadata_list) # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are placed at - # index 1. So the sampling parameters will be: + # The first block_id of the sequence 0 is 1, so its parameters are + # placed at index 1. So the sampling parameters will be: # Index 0: default sampling parameters # Index 1: sequecne 0's sampling parameters. 
- neuron_sampling_params = model_runner.model_config.neuron_sampling_params + neuron_sampling_params = ( + model_runner.model_config.neuron_sampling_params) assert neuron_sampling_params.temperature == [1.0, 0.5] assert neuron_sampling_params.top_k == [ model_runner._MAX_NEURON_SAMPLING_TOP_K, 1 @@ -67,6 +69,7 @@ def test_update_neuron_sampling_params_not_full_batch(): model_mock.model.update_generation_config.assert_called_once_with( neuron_sampling_params) + def test_update_neuron_sampling_params_full_batch(): os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0" model_runner = _create_neuron_model_runner( @@ -85,19 +88,19 @@ def test_update_neuron_sampling_params_full_batch(): seq_group_metadata_list = [ SequenceGroupMetadata( - request_id=f"test_0", + request_id="test_0", is_prompt=True, seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams(temperature=0.5, top_k=1, - top_p=0.5), + top_p=0.5), block_tables={0: [1]}, ), SequenceGroupMetadata( - request_id=f"test_0", + request_id="test_0", is_prompt=True, seq_data={1: SequenceData.from_seqs([4, 5, 6])}, sampling_params=SamplingParams(temperature=0.2, top_k=2, - top_p=0.2), + top_p=0.2), block_tables={1: [0]}, ) ] @@ -105,11 +108,12 @@ def test_update_neuron_sampling_params_full_batch(): model_runner.prepare_model_input(seq_group_metadata_list) # Index neuron sampling parameters based on block_tables indices. - # The first block_id of the sequence 0 is 1, so its parameters are placed at - # index 1. So the sampling parameters will be: + # The first block_id of the sequence 0 is 1, so its parameters are + # placed at index 1. So the sampling parameters will be: # Index 0: sequence 1's sampling parameters # Index 1: sequecne 0's sampling parameters. - neuron_sampling_params = model_runner.model_config.neuron_sampling_params + neuron_sampling_params = ( + model_runner.model_config.neuron_sampling_params) assert neuron_sampling_params.temperature == [0.2, 0.5] assert neuron_sampling_params.top_k == [2, 1] assert neuron_sampling_params.top_p == [0.2, 0.5] diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b77d2f8e2ab..ce12a3d1cba 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -388,6 +388,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: "vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + + if self.device_config.device_type == "neuron": + num_lookahead_slots = self.scheduler_config.num_lookahead_slots + else: + num_lookahead_slots = 0 # Create sequence output processor, e.g. for beam search or # speculative decoding. self.output_processor = ( @@ -400,7 +405,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: stop_checker=StopChecker( self.scheduler_config.max_model_len, get_tokenizer_for_seq, - num_lookahead_slots=self.scheduler_config.num_lookahead_slots if self.device_config.device_type == "neuron" else 0 + num_lookahead_slots=num_lookahead_slots ), )) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 07ec23713ef..3be7dfb10fc 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -83,7 +83,8 @@ def maybe_stop_sequence( return # Check if the sequence has reached max_model_len. 
- if seq.get_len() + self.num_lookahead_slots > self._get_max_model_len(lora_req): + if (seq.get_len() + + self.num_lookahead_slots > self._get_max_model_len(lora_req)): seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index c6dee5d866c..c7248038871 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""Utilities for selecting and loading Neuron models in transformers-neuronx framework.""" +"""Utilities for selecting and loading Neuron models in transformers-neuronx +framework.""" import ast import copy import importlib @@ -10,7 +11,8 @@ import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, \ + SpeculativeConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -58,10 +60,10 @@ def __init__(self, self.model: nn.Module def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, ) -> torch.Tensor: logits = self.model(input_ids, cache_ids=positions, @@ -74,9 +76,9 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: if self.on_device_sampling_disabled: @@ -116,21 +118,21 @@ def load_weights(self, model_name_or_path: str, **kwargs): class NeuronSpeculationCausalLM(nn.Module): """A Neuron-optimized causal language model with speculative decoding.""" - + SPECULATION_TERMINATION_ID = -1 def __init__( - self, - speculation_model + self, + speculation_model ) -> None: super().__init__() self.model = speculation_model def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, ) -> torch.Tensor: tokens, counts = self.model.speculative_iteration( input_ids, positions, input_block_ids) @@ -144,12 +146,13 @@ def forward( return tokens def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[List[SamplerOutput]]: batch_size, num_steps = logits.shape - seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] + seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in + sg.seq_ids] # Organize input tensors by step instead of by sequence. 
accepted_token_ids_by_step = logits.transpose(0, 1) accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() @@ -161,8 +164,14 @@ def sample( break step_output_token_ids = [] for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][sequence_index] - step_output_token_ids.append(CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=seq_ids[sequence_index], output_token=token_id, logprobs={token_id: Logprob(token_id)})], prompt_logprobs=None)) + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append(CompletionSequenceGroupOutput( + samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)})], + prompt_logprobs=None)) sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) return sampler_output_list @@ -219,11 +228,12 @@ def _get_default_neuron_config(model_config: ModelConfig, def _get_default_neuron_config_for_speculation( - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig ): - """Generate a neuron config for speculative decoding based on vllm config args.""" + """Generate a neuron config for speculative decoding based on + vllm config args.""" from transformers_neuronx.config import ContinuousBatchingConfig from transformers_neuronx.constants import LAYOUT_BSH @@ -292,19 +302,20 @@ def get_neuron_model(model_config: ModelConfig, def get_neuron_speculation_model( - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig ) -> None: """Initializes a neuron-optimized speculation model for inference. - This method is only applicable for speculation with a standalone draft model. + This method is only applicable for speculation with a standalone draft model """ from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder # For Eagle SD, we need to pass in additional parameters in neuron config. - is_eagle = getattr(speculation_config.draft_model_config.hf_config, "is_eagle", False) + is_eagle = getattr(speculation_config.draft_model_config.hf_config, + "is_eagle", False) # Create target model instance. target_model = NeuronCausalLM(model_config.hf_config) @@ -320,7 +331,7 @@ def get_neuron_speculation_model( context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", [scheduler_config.max_model_len]) n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) + [scheduler_config.max_model_len]) target_model.load_weights( model_config.model, @@ -334,21 +345,27 @@ def get_neuron_speculation_model( target_model.eval() # Create draft model instance. 
- draft_model = NeuronCausalLM(speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, scheduler_config) + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, + parallel_config, + scheduler_config)) if is_eagle: default_draft_neuron_config_args['is_eagle_draft'] = True default_draft_neuron_config_args['has_pre_attention_norm'] = False draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, speculation_config.draft_model_config.override_neuron_config) + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) draft_model.load_weights( speculation_config.draft_model_config.model, tp_degree=speculation_config.draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[speculation_config.draft_model_config.dtype], + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], neuron_config=draft_neuron_config, context_length_estimate=context_length_estimates, n_positions=n_positions, @@ -356,8 +373,11 @@ def get_neuron_speculation_model( draft_model.eval() + num_speculative_tokens= speculation_config.num_speculative_tokens # Create speculation model instance. - speculation_model = FusedSpeculativeDecoder(draft_model.model, target_model.model, speculation_config.num_speculative_tokens) + speculation_model = FusedSpeculativeDecoder(draft_model.model, + target_model.model, + num_speculative_tokens) speculation_model.to_neuron() return NeuronSpeculationCausalLM(speculation_model) @@ -366,10 +386,10 @@ def get_neuron_speculation_model( def get_neuron_eagle_speculation_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig) -> None: + speculation_config: SpeculativeConfig): """Initializes a neuron-optimized EAGLE speculation model for inference.""" from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder - + # Create target model instance. target_model = NeuronCausalLM(model_config.hf_config) @@ -382,8 +402,8 @@ def get_neuron_eagle_speculation_model(model_config: ModelConfig, context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", [scheduler_config.max_model_len]) n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", - [scheduler_config.max_model_len]) - + [scheduler_config.max_model_len]) + target_model.load_weights( model_config.model, tp_degree=parallel_config.tensor_parallel_size, @@ -392,33 +412,41 @@ def get_neuron_eagle_speculation_model(model_config: ModelConfig, context_length_estimate=context_length_estimates, n_positions=n_positions, batch_size=scheduler_config.max_num_seqs) - + target_model.eval() - + # Create draft model instance. 
- draft_model = NeuronCausalLM(speculation_config.draft_model_config.hf_config) - - default_draft_neuron_config_args = _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, scheduler_config) + draft_model = NeuronCausalLM( + speculation_config.draft_model_config.hf_config) + + default_draft_neuron_config_args = ( + _get_default_neuron_config_for_speculation( + speculation_config.draft_model_config, parallel_config, + scheduler_config)) default_draft_neuron_config_args['is_eagle_draft'] = True default_draft_neuron_config_args['has_pre_attention_norm'] = False draft_neuron_config = _get_neuron_config_after_override( - default_draft_neuron_config_args, speculation_config.draft_model_config.override_neuron_config) - + default_draft_neuron_config_args, + speculation_config.draft_model_config.override_neuron_config) + draft_model.load_weights( speculation_config.draft_model_config.model, tp_degree=speculation_config.draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[speculation_config.draft_model_config.dtype], + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], neuron_config=draft_neuron_config, context_length_estimate=context_length_estimates, n_positions=n_positions, batch_size=scheduler_config.max_num_seqs) - + draft_model.eval() - - token_tree: Dict[int, List[int]] = ast.literal_eval(speculation_config.speculative_token_tree) - - speculation_model = EagleSpeculativeDecoder(draft_model.model, target_model.model, token_tree=token_tree) + + token_tree: Dict[int, List[int]] = ast.literal_eval( + speculation_config.speculative_token_tree) + + speculation_model = EagleSpeculativeDecoder(draft_model.model, + target_model.model, + token_tree=token_tree) speculation_model.to_neuron() - + return NeuronSpeculationCausalLM(speculation_model) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index 9cf85b5e328..040233d19fb 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""Utilities for selecting and loading Neuron models in neuronx-distributed-inference framework.""" +"""Utilities for selecting and loading Neuron models in +neuronx-distributed-inference framework.""" import copy import hashlib import importlib @@ -12,7 +13,8 @@ import torch.nn as nn from transformers import AutoModelForCausalLM, PretrainedConfig, AutoTokenizer -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, \ + SpeculativeConfig from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -21,8 +23,10 @@ SequenceOutput) from neuronx_distributed_inference.models.mllama.utils import create_vision_mask -from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config -from neuronx_distributed_inference.models.config import FusedSpecNeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import \ + load_pretrained_config +from neuronx_distributed_inference.models.config import FusedSpecNeuronConfig, \ + OnDeviceSamplingConfig logger = init_logger(__name__) @@ -38,25 +42,28 @@ torch.float32: "float32", } - # Models supported by Neuronx distributed for 
inference. _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = { - "LlamaForCausalLM": ("neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "DbrxForCausalLM": ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", - "NeuronDbrxForCausalLM"), - "MixtralForCausalLM": ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", - "NeuronMixtralForCausalLM"), - "MllamaForConditionalGeneration": ("neuronx_distributed_inference.models.mllama.modeling_mllama", - "NeuronMllamaForCausalLM"), + "LlamaForCausalLM": ( + "neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "DbrxForCausalLM": ( + "neuronx_distributed_inference.models.dbrx.modeling_dbrx", + "NeuronDbrxForCausalLM"), + "MixtralForCausalLM": ( + "neuronx_distributed_inference.models.mixtral.modeling_mixtral", + "NeuronMixtralForCausalLM"), + "MllamaForConditionalGeneration": ( + "neuronx_distributed_inference.models.mllama.modeling_mllama", + "NeuronMllamaForCausalLM"), } class NeuronCausalLM(nn.Module): def __init__( - self, - config: PretrainedConfig, + self, + config: PretrainedConfig, ) -> None: super().__init__() self.config = config @@ -68,11 +75,11 @@ def __init__( self.model: nn.Module def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, ) -> torch.Tensor: output = self.model(input_ids, attention_mask=None, @@ -91,14 +98,15 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: # on-device sampling if self.config.neuron_config.on_device_sampling_config: batch_size = logits.shape - seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] + seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id + in sg.seq_ids] assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" # Organize input tensors by step instead of by sequence. 
accepted_token_ids_by_step = logits.flatten() @@ -107,7 +115,11 @@ def sample( step_output_token_ids = [] for i, seq_id in enumerate(seq_ids): token_id = accepted_token_ids_by_step[i] - step_output_token_ids.append(CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=seq_id, output_token=token_id, logprobs={token_id: Logprob(token_id)})], prompt_logprobs=None)) + step_output_token_ids.append(CompletionSequenceGroupOutput( + samples=[SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, logprobs={ + token_id: Logprob(token_id)})], + prompt_logprobs=None)) return SamplerOutput(outputs=step_output_token_ids) else: return self.sampler(logits, sampling_metadata) @@ -118,20 +130,27 @@ def load_weights(self, model_name_or_path: str, **kwargs): _NEURON_SUPPORTED_MODELS[arch]) neuronx_module = importlib.import_module(neuronx_module_path) neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()(**kwargs['neuron_config']) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) self.config.neuron_config = neuron_config config = neuronx_model_cls.get_config_cls()( - neuron_config, load_config=load_pretrained_config(model_name_or_path) + neuron_config, + load_config=load_pretrained_config(model_name_or_path) ) + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): compiled_model_path = os.path.join(model_name_or_path, - f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + "neuron-compiled-artifacts", + hashed_config) shutil.rmtree(compiled_model_path, ignore_errors=True) else: - compiled_model_path = os.path.join("local-models", model_name_or_path, - f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) shutil.rmtree(compiled_model_path, ignore_errors=True) try: self.model = neuronx_model_cls(compiled_model_path) @@ -141,8 +160,12 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.load(compiled_model_path) return except (FileNotFoundError, ValueError) as e: - logger.warning(f"Exception: {e}") - logger.warning(f"Failed to load the model from {compiled_model_path}, Recompiling...") + logger.warning("Exception: %s", e) + logger.warning( + "Failed to load the model from path, Recompiling...", + extra={ + "compiled_model_path": compiled_model_path + }) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -152,16 +175,18 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.compile(compiled_model_path) self.model.load(compiled_model_path) + class NeuronMllamaForCausalLM(nn.Module): def __init__( - self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: + self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: super().__init__() self.config = config - self.logits_processor = LogitsProcessor(config.get_text_config().vocab_size, - logits_as_input=True) + self.logits_processor = LogitsProcessor( + config.get_text_config().vocab_size, + logits_as_input=True) 
self.on_device_sampling_disabled = on_device_sampling_disabled if self.on_device_sampling_disabled: @@ -186,8 +211,9 @@ def forward( output = self.model(input_ids.to(torch.int32), attention_mask=None, position_ids=positions.to(torch.int32), - seq_ids= seq_ids.flatten().to(torch.int32), - pixel_values=pixel_values.to(self.config.vision_config.torch_dtype), + seq_ids=seq_ids.flatten().to(torch.int32), + pixel_values=pixel_values.to( + self.config.vision_config.torch_dtype), aspect_ratios=aspect_ratios.to(torch.int32), vision_mask=self.vision_mask.to(torch.int32), sampling_params=sampling_params, @@ -199,7 +225,7 @@ def forward( return output.logits[:, -1, :] def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(None, hidden_states, sampling_metadata) return logits @@ -214,10 +240,15 @@ def sample(self, hidden_states, sampling_metadata): samples = [] for seq_id in seq_ids: token_id = hidden_states[sample_idx].item() - samples.append(SequenceOutput(parent_seq_id=seq_id, output_token=token_id, - logprobs={token_id: Logprob(token_id)})) + samples.append(SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={ + token_id: Logprob( + token_id)})) sample_idx += 1 - res.append(CompletionSequenceGroupOutput(samples=samples, prompt_logprobs=None)) + res.append( + CompletionSequenceGroupOutput(samples=samples, + prompt_logprobs=None)) next_tokens = SamplerOutput(outputs=res) else: next_tokens = self.sampler(None, hidden_states, sampling_metadata) @@ -229,28 +260,42 @@ def load_weights(self, model_name_or_path: str, **kwargs): _NEURON_SUPPORTED_MODELS[arch]) neuronx_module = importlib.import_module(neuronx_module_path) neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()(**kwargs['neuron_config']) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) self.config.neuron_config = neuron_config - logger.info(f"neuron_config buckets: {self.config.neuron_config.buckets}") + logger.info( + "neuron_config buckets: %s", + self.config.neuron_config.buckets) config = neuronx_model_cls.get_config_cls()( - neuron_config, load_config=load_pretrained_config(model_name_or_path) + neuron_config, + load_config=load_pretrained_config(model_name_or_path) ) + hashed_config = hashlib.md5( + config.to_json_string().encode('utf-8')).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): compiled_model_path = os.path.join(model_name_or_path, - f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + "neuron-compiled-artifacts", + hashed_config) else: - compiled_model_path = os.path.join("local-models", model_name_or_path, - f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) try: self.model = neuronx_model_cls(compiled_model_path) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - self.vision_token_id = tokenizer("<|image|>", add_special_tokens=False).input_ids + self.vision_token_id = tokenizer("<|image|>", + add_special_tokens=False).input_ids self.model.load(compiled_model_path) return except 
(FileNotFoundError, ValueError): - logger.warning(f"Failed to load the model from {compiled_model_path}, Recompiling...") + logger.warning( + "Failed to load the model from path, Recompiling...", + extra={ + "compiled_model_path": compiled_model_path + }) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -258,17 +303,24 @@ def load_weights(self, model_name_or_path: str, **kwargs): model_name_or_path = saved_path self.model = neuronx_model_cls(model_name_or_path, config) - logger.info(f"\nCompiling and saving model to {model_name_or_path}...") - p = multiprocessing.Process(target=compile_model, args=(self, compiled_model_path)) + logger.info("\nCompiling and saving model to path", extra={ + "model_name_or_path": model_name_or_path + }) + + p = multiprocessing.Process(target=compile_model, + args=(self, compiled_model_path)) p.start() p.join() tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) tokenizer.save_pretrained(compiled_model_path) - logger.info(f"successfully compiled and saved the model in {compiled_model_path}") + logger.info( + "Successfully compiled and saved the model", + extra={"compiled_model_path": compiled_model_path}) # Read "<|image|>" token_id from the tokenizer - self.vision_token_id = tokenizer("<|image|>", add_special_tokens=False).input_ids + self.vision_token_id = tokenizer("<|image|>", + add_special_tokens=False).input_ids logger.info("\nLoading model from compiled checkpoint...") self.model.load(compiled_model_path) @@ -279,9 +331,10 @@ def compile_model(neuron_model, traced_model_path): class NeuronSpeculationCausalLM(nn.Module): """A Neuron-optimized causal language model with speculative decoding.""" + def __init__( - self, - config: PretrainedConfig, + self, + config: PretrainedConfig, ) -> None: super().__init__() self.config = config @@ -290,19 +343,18 @@ def __init__( # Lazy initialized self.model: nn.Module - # FIXME(Neuron): restore sampling_params after migrating framework selection and dynamic sampling def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - # sampling_params: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, ) -> torch.Tensor: output = self.model(input_ids, attention_mask=None, position_ids=positions, - seq_ids=input_block_ids) - # sampling_params=sampling_params) + seq_ids=input_block_ids, + sampling_params=sampling_params) # CTX encoding if (positions[:, 0]).sum().item() == 0: return output.fused_outputs[0][:, 0:1] @@ -312,46 +364,59 @@ def forward( next_pos_ids = output.fused_outputs[-1] generated_token_counts = next_pos_ids - positions - assert torch.any(generated_token_counts == 0).item() is False, "NxDI model generated no output for one or more sequences." + assert torch.any(generated_token_counts == 0).item() is False, \ + "NxDI model generated no output for one or more sequences." 
batch_size, steps = accepted_tokens_with_padding.shape - mask = torch.arange(steps).expand(batch_size, -1) >= generated_token_counts + mask = torch.arange(steps).expand(batch_size, + -1) >= generated_token_counts accepted_tokens_with_padding[mask] = -1 return accepted_tokens_with_padding def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[List[SamplerOutput]]: batch_size, num_steps = logits.shape - seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in sg.seq_ids] + seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in + sg.seq_ids] # Organize input tensors by step instead of by sequence. accepted_token_ids_by_step = logits.transpose(0, 1) accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() sampler_output_list = [] for step_index in range(num_steps): - if all(token_id == -1 for token_id in accepted_token_ids_by_step[step_index]): + if all(token_id == -1 for token_id in + accepted_token_ids_by_step[step_index]): break step_output_token_ids = [] for sequence_index in range(batch_size): - token_id = accepted_token_ids_by_step[step_index][sequence_index] - step_output_token_ids.append(CompletionSequenceGroupOutput(samples=[SequenceOutput(parent_seq_id=seq_ids[sequence_index], output_token=token_id, logprobs={token_id: Logprob(token_id)})], prompt_logprobs=None)) + token_id = accepted_token_ids_by_step[step_index][ + sequence_index] + step_output_token_ids.append(CompletionSequenceGroupOutput( + samples=[ + SequenceOutput(parent_seq_id=seq_ids[sequence_index], + output_token=token_id, + logprobs={token_id: Logprob(token_id)})], + prompt_logprobs=None)) sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) return sampler_output_list - def load_weights(self, model_name_or_path: str, draft_model_name_or_path: str, **kwargs): + def load_weights(self, model_name_or_path: str, + draft_model_name_or_path: str, **kwargs): arch = _get_model_architecture(self.config) neuronx_module_path, neuronx_model_cls_name = ( _NEURON_SUPPORTED_MODELS[arch]) neuronx_module = importlib.import_module(neuronx_module_path) neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) - neuron_config = neuronx_model_cls.get_neuron_config_cls()(**kwargs['neuron_config']) + neuron_config = neuronx_model_cls.get_neuron_config_cls()( + **kwargs['neuron_config']) config = neuronx_model_cls.get_config_cls()( - neuron_config, load_config=load_pretrained_config(model_name_or_path) + neuron_config, + load_config=load_pretrained_config(model_name_or_path) ) draft_neuron_config = copy.deepcopy(config.neuron_config) @@ -363,21 +428,30 @@ def load_weights(self, model_name_or_path: str, draft_model_name_or_path: str, * draft_neuron_config.is_eagle_draft = True draft_neuron_config.sequence_parallel_enabled = False draft_config = neuronx_model_cls.get_config_cls()( - draft_neuron_config, load_config=load_pretrained_config(draft_model_name_or_path) + draft_neuron_config, + load_config=load_pretrained_config(draft_model_name_or_path) ) - fused_spec_config = FusedSpecNeuronConfig(neuronx_model_cls._model_cls, draft_config=draft_config, draft_model_path=draft_model_name_or_path) + fused_spec_config = ( + FusedSpecNeuronConfig(neuronx_model_cls._model_cls, + draft_config=draft_config, + draft_model_path=draft_model_name_or_path)) config.fused_spec_config = fused_spec_config self.config.neuron_config = neuron_config - + + hashed_config = hashlib.md5( + 
config.to_json_string().encode('utf-8')).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS") elif os.path.exists(model_name_or_path): compiled_model_path = os.path.join(model_name_or_path, - f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + "neuron-compiled-artifacts", + hashed_config) shutil.rmtree(compiled_model_path, ignore_errors=True) else: - compiled_model_path = os.path.join("local-models", model_name_or_path, - f"neuron-compiled-artifacts/{hashlib.md5(config.to_json_string().encode('utf-8')).hexdigest()}/") + compiled_model_path = os.path.join("local-models", + model_name_or_path, + "neuron-compiled-artifacts", + hashed_config) shutil.rmtree(compiled_model_path, ignore_errors=True) try: self.model = neuronx_model_cls(compiled_model_path) @@ -387,18 +461,21 @@ def load_weights(self, model_name_or_path: str, draft_model_name_or_path: str, * self.model.load(compiled_model_path) return except (FileNotFoundError, ValueError) as e: - logger.warning(f"Exception: {e}") - logger.warning(f"Failed to load the model from {compiled_model_path}, Recompiling...") - draft_checkpoint_download = not draft_model_name_or_path == model_name_or_path + logger.warning("Exception: %s", e) + logger.warning( + "Failed to load the model. Recompiling...", + extra={"compiled_model_path": compiled_model_path}) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) hf_model.save_pretrained(saved_path) model_name_or_path = saved_path if not os.path.exists(draft_model_name_or_path): - if draft_checkpoint_download: - hf_model = AutoModelForCausalLM.from_pretrained(draft_model_name_or_path) - saved_path = os.path.join("local-models", draft_model_name_or_path) + if draft_model_name_or_path != model_name_or_path: + hf_model = AutoModelForCausalLM.from_pretrained( + draft_model_name_or_path) + saved_path = os.path.join("local-models", + draft_model_name_or_path) hf_model.save_pretrained(saved_path) draft_model_name_or_path = saved_path else: @@ -419,6 +496,7 @@ def _get_model_architecture(config: PretrainedConfig) -> str: f"for now. 
Supported architectures: " f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + def _get_default_neuron_config(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig): @@ -434,7 +512,7 @@ def _get_default_neuron_config(model_config: ModelConfig, max_context_length=scheduler_config.max_model_len, seq_len=scheduler_config.max_model_len, enable_bucketing=True, - is_continuous_batching=(batch_size>1), + is_continuous_batching=(batch_size > 1), quantized=False, torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], padding_side="right", @@ -443,11 +521,13 @@ def _get_default_neuron_config(model_config: ModelConfig, ) return neuron_config -def _get_default_neuron_speculation_config(model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig): - """Generate a neuron config for speculative decoding based on vllm config args.""" + +def _get_default_speculation_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): + """Generate a neuron config for speculative decoding based on vllm config + args.""" neuron_config = dict( tp_degree=parallel_config.tensor_parallel_size, batch_size=scheduler_config.max_num_seqs, @@ -459,7 +539,7 @@ def _get_default_neuron_speculation_config(model_config: ModelConfig, enable_bucketing=True, quantized=False, torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - on_device_sampling_config= dict(top_k=1, do_sample=False,) + on_device_sampling_config=dict(top_k=1, do_sample=False, ) ) return neuron_config @@ -471,6 +551,7 @@ def _get_neuron_config_after_override(default_neuron_config, default_neuron_config.update(overridden_neuron_config) return default_neuron_config + def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: @@ -482,13 +563,17 @@ def get_neuron_model(model_config: ModelConfig, model = NeuronCausalLM(model_config.hf_config) default_neuron_config_args = _get_default_neuron_config( model_config, parallel_config, scheduler_config) - neuron_config = _get_neuron_config_after_override(default_neuron_config_args, + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config model.load_weights(model_config.model, neuron_config=neuron_config, - override_neuron_config=model_config.override_neuron_config) + override_neuron_config=override_neuron_config) return model.eval() + def get_neuron_speculation_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -498,12 +583,15 @@ def get_neuron_speculation_model(model_config: ModelConfig, This model handles speculation using both a draft model and an EAGLE draft. 
""" model = NeuronSpeculationCausalLM(model_config.hf_config) - default_neuron_config_args = _get_default_neuron_speculation_config( + default_neuron_config_args = _get_default_speculation_config( model_config, parallel_config, scheduler_config, speculation_config) - neuron_config = _get_neuron_config_after_override(default_neuron_config_args, + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + override_neuron_config = model_config.override_neuron_config model.load_weights(model_config.model, speculation_config.draft_model_config.model, - neuron_config=neuron_config, - override_neuron_config=model_config.override_neuron_config) + neuron_config=neuron_config, + override_neuron_config=override_neuron_config) return model.eval() diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 39161a2c636..09202bcd986 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -119,21 +119,21 @@ def cpu_platform_plugin() -> Optional[str]: def neuron_platform_plugin() -> Optional[str]: - transformers_neuronx_installed = False - neuronx_distributed_inference_installed = False + tnx_installed = False + nxd_installed = False try: import transformers_neuronx # noqa: F401 - transformers_neuronx_installed = True + tnx_installed = True except ImportError: pass try: - import neuronx_distributed_inference - neuronx_distributed_inference_installed = True + import neuronx_distributed_inference # noqa: F401 + nxd_installed = True except ImportError: pass - is_neuron = transformers_neuronx_installed or neuronx_distributed_inference_installed + is_neuron = tnx_installed or nxd_installed return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py index 38e2e535a89..4584b4a8b48 100644 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -5,33 +5,38 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors -from vllm.worker.neuron_model_runner import NeuronModelRunner, ModelInputForNeuron +from vllm.worker.neuron_model_runner import NeuronModelRunner, \ + ModelInputForNeuron + class MultiStepNeuronModelRunner(NeuronModelRunner): - """A model runner for multi step decoding using the transformers_neuronx framework""" + """A model runner for multi step decoding using the transformers_neuronx + framework""" def __init__( - self, - vllm_config: VllmConfig, + self, + vllm_config: VllmConfig, ): super().__init__(vllm_config) self.speculation_config = self.speculative_config from transformers_neuronx.config import GenerationConfig - self.speculation_config.draft_model_config.neuron_sampling_params = GenerationConfig( + self.speculation_config.draft_model_config.neuron_sampling_params = ( + GenerationConfig( max_length=self.scheduler_config.max_model_len, do_sample=True, per_batch_line=True, - top_k = [self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, - top_p = [1.0] * self.scheduler_config.max_num_seqs, - temperature = [1.0] * self.scheduler_config.max_num_seqs, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, dynamic=True, global_top_k=self._MAX_NEURON_SAMPLING_TOP_K - ) + )) def 
load_model(self) -> None: if find_spec("transformers_neuronx") is not None: - from vllm.model_executor.model_loader.neuron import get_neuron_speculation_model, get_neuron_eagle_speculation_model + from vllm.model_executor.model_loader.neuron import \ + get_neuron_speculation_model, get_neuron_eagle_speculation_model if self.speculation_config.speculative_token_tree is not None: self.model = get_neuron_eagle_speculation_model( self.model_config, @@ -51,22 +56,22 @@ def load_model(self) -> None: @torch.inference_mode() def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: logits = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + device=self.device), ) output = self.model.sample( logits=logits, sampling_metadata=model_input.sampling_metadata, ) - return output \ No newline at end of file + return output diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index 209b307f946..e63c4168d58 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -1,14 +1,15 @@ -from importlib.util import find_spec import torch from typing import List, Optional from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors -from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner +from vllm.worker.neuronx_distributed_model_runner \ + import NeuronxDistributedModelRunner class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner): - """A model runner for multi step decoding using the neuronx-distributed-inference framework""" + """A model runner for multi-step decoding using the + neuronx-distributed-inference framework""" def __init__( self, @@ -17,7 +18,8 @@ def __init__( super().__init__(vllm_config) def load_model(self) -> None: - from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_speculation_model + from vllm.model_executor.model_loader.neuronx_distributed \ + import get_neuron_speculation_model self.model = get_neuron_speculation_model( self.model_config, parallel_config=self.parallel_config, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index eaa9f0d4fda..deb0f9edd3b 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -18,7 +18,8 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad, is_transformers_neuronx, is_neuronx_distributed_inference from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase -from vllm.worker.neuron_worker import use_neuronx_distributed, use_transformers_neuronx +from vllm.worker.neuron_worker import use_neuronx_distributed, \ + use_transformers_neuronx if TYPE_CHECKING: from vllm.attention.backends.abstract import 
AttentionBackend @@ -49,9 +50,9 @@ def as_broadcastable_tensor_dict( @classmethod def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, ) -> "ModelInputForNeuron": return ModelInputForNeuron( input_tokens=tensor_dict["input_tokens"], @@ -69,16 +70,18 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): _MAX_NEURON_SAMPLING_TOP_K = 256 def __init__( - self, - vllm_config: VllmConfig, + self, + vllm_config: VllmConfig, ): ModelRunnerBase.__init__(self, vllm_config) - if self.model_config is not None and self.model_config.get_sliding_window(): + if (self.model_config is not None + and self.model_config.get_sliding_window()): logger.warning("Sliding window is not supported on Neuron. " "The model will run without sliding window.") self.device_config = (self.device_config - if self.device_config is not None else DeviceConfig()) + if self.device_config is not None + else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -108,17 +111,17 @@ def _init_neuron_sampling(self) -> None: else: from transformers import GenerationConfig logger.warning( - "On-device sampling is turned on in Neuron by default, only " - "top_k, top_p, and temperature are current supported sampling " - "parameters. To turn off the on-device sampling, please set " - "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." - ) + "On-device sampling is turned on in Neuron by default, only " + "top_k, top_p, and temperature are currently supported sampling " + "parameters. To turn off the on-device sampling, please set " + "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." 
+ ) self.model_config.neuron_sampling_params = GenerationConfig( max_length=self.scheduler_config.max_model_len, do_sample=True, per_batch_line=True, top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ - * self.scheduler_config.max_num_seqs, + * self.scheduler_config.max_num_seqs, top_p=[1.0] * self.scheduler_config.max_num_seqs, temperature=[1.0] * self.scheduler_config.max_num_seqs, dynamic=True, @@ -135,10 +138,10 @@ def get_model(self) -> nn.Module: return self.model def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], + self, + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], - BatchedTensorInputs]: + BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -165,7 +168,7 @@ def _prepare_prompt( assert len(block_table) == 1 input_block_ids.append(block_table[0]) - mm_data = seq_group_metadata.multi_modal_data + # mm_data = seq_group_metadata.multi_modal_data # if mm_data: # # Process multi-modal data # multi_modal_inputs_list.append(mm_data) @@ -192,8 +195,8 @@ def _prepare_prompt( multi_modal_kwargs) def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], + self, + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] @@ -245,10 +248,10 @@ def make_model_input_from_broadcasted_tensor_dict( return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForNeuron: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or @@ -263,15 +266,16 @@ def prepare_model_input( (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) seq_lens = None - + if not self._on_device_sampling_disabled: for seq_group_metadata in seq_group_metadata_list: sampling_params = seq_group_metadata.sampling_params - top_k, top_p, temperature = self._convert_to_neuron_sampling_params(sampling_params) + top_k, top_p, temperature = ( + self._convert_to_neuron_sampling_params(sampling_params)) sampling_params.top_k = top_k sampling_params.top_p = top_p sampling_params.temperature = temperature - + sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens, @@ -301,7 +305,7 @@ def prepare_model_input( multi_modal_kwargs=multi_modal_kwargs) def _update_neuron_sampling_params( - self, seq_group_metadata_list: List[SequenceGroupMetadata]): + self, seq_group_metadata_list: List[SequenceGroupMetadata]): # Update Neuron sampling parameters (GenerationConfig in Neuron) current_sampling_params = self.model_config.neuron_sampling_params assert current_sampling_params is not None, ( @@ -327,9 +331,9 @@ def _update_neuron_sampling_params( for seq_id in seq_ids: index = seq_group_metadata.block_tables[seq_id][0] if ( - top_k[index] != seq_group_top_k - or top_p[index] != seq_group_top_p - or temperature[index] != seq_group_temperature + top_k[index] != seq_group_top_k + or top_p[index] != seq_group_top_p + or temperature[index] != seq_group_temperature ): is_update_needed = True @@ 
-358,29 +362,35 @@ def _convert_to_neuron_sampling_params( @torch.inference_mode() def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( "NeuronModelRunner does not support multi-step execution.") - # extract top_k, top_p and temperature from model_input for neuron forward call - sampling_params = torch.tensor([[seq_group.sampling_params.top_k, - seq_group.sampling_params.top_p, - seq_group.sampling_params.temperature] for seq_group in model_input.sampling_metadata.seq_groups]) + # extract top_k, top_p and temperature from model_input for neuron + # forward call + sampling_params = ( + torch.tensor( + [[seq_group.sampling_params.top_k, + seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature] + for seq_group in model_input.sampling_metadata.seq_groups])) + if use_neuronx_distributed(): hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - ) + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + device=self.device), + ) elif use_transformers_neuronx(): # [TODO] validate on-device sampling # The model signature may need change for on-device sampling @@ -388,8 +398,9 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs or {}, + device=self.device), ) # Compute the logits only if the on-device sampling is turned off as diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index a6e51709760..b9a844204af 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -2,7 +2,7 @@ """A Neuron worker class.""" import enum import os -from functools import lru_cache +from functools import cache from typing import List, Optional, Tuple import torch @@ -25,55 +25,64 @@ DEFAULT_NEURON_RANK_ID = "0" DEFAULT_ENABLE_NEURON_MULTI_NODE = "False" + class NeuronFramework(enum.Enum): TRANSFORMERS_NEURONX = "transformers-neuronx" NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" -@lru_cache(maxsize=None) +@cache def get_neuron_framework_to_use(): - """ - Return the specified framework if the corresponding installations are available. - If no framework is specified, then use transformers-neuronx by default, if unavailable - then check and switch to neuronx-distributed-inference. + """Return the specified framework if corresponding installations are + available. + + If no framework is specified, use neuronx-distributed-inference by default. + If that's unavailable, check and switch to transformers-neuronx. 
""" if not current_platform.is_neuron(): - raise AssertionError(f"Neuron Framework cannot be obtained for Non-neuron Platform: {current_platform}") + raise AssertionError( + f"Neuron Framework unavailable for platform: {current_platform}") - transformers_neuronx_installed = current_platform.is_transformers_neuronx() - neuronx_distributed_inference_installed = current_platform.is_neuronx_distributed_inference() + tnx_installed = current_platform.is_transformers_neuronx() + nxd_installed = current_platform.is_neuronx_distributed_inference() specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") - if specified_framework == NeuronFramework.TRANSFORMERS_NEURONX.value and transformers_neuronx_installed: + tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value + if (specified_framework == tnx_framework and + tnx_installed): return NeuronFramework.TRANSFORMERS_NEURONX - elif specified_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value and neuronx_distributed_inference_installed: - return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - elif specified_framework is None and neuronx_distributed_inference_installed: + + if ((specified_framework == nxd_framework and + nxd_installed) or + (specified_framework is None and nxd_installed)): return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - elif specified_framework is None and transformers_neuronx_installed: + + if specified_framework is None and tnx_installed: return NeuronFramework.TRANSFORMERS_NEURONX - else: - return None + + return None -@lru_cache(maxsize=None) +@cache def use_neuronx_distributed(): """ Return True if the framework determined in get_neuron_framework_to_use() is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This is used - to select the Neuron model framework and framework-specific configuration to apply - during model compilation. + to select the Neuron model framework and framework-specific configuration to + apply during model compilation. """ - return get_neuron_framework_to_use() == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + return get_neuron_framework_to_use() == nxd_framework -@lru_cache(maxsize=None) +@cache def use_transformers_neuronx(): """ Return True if the framework determined in get_neuron_framework_to_use() is - NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used to select - the Neuron model framework and framework-specific configuration to apply during - model compilation. + NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used to + select the Neuron model framework and framework-specific configuration to + apply during model compilation. 
""" return get_neuron_framework_to_use() == NeuronFramework.TRANSFORMERS_NEURONX @@ -83,12 +92,12 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): """ def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False ) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank @@ -96,43 +105,63 @@ def __init__( self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker - self.enable_neuron_multi_node = os.getenv("ENABLE_NEURON_MULTI_NODE", DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true" + self.enable_neuron_multi_node = ( + os.getenv("ENABLE_NEURON_MULTI_NODE", + DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true") if self.enable_neuron_multi_node: self.world_size = int(os.getenv("WORLD_SIZE", DEFAULT_WORLD_SIZE)) self.rank = int(os.getenv("NEURON_RANK_ID", DEFAULT_NEURON_RANK_ID)) self.distributed_init_method = "env://" self.is_driver_worker = self.rank == 0 - logger.info(f"Rank: {self.rank}, distributed_init_method: {self.distributed_init_method}, is_driver_worker: {self.is_driver_worker}") - + + logger.info("Neuron multi-node parameters", + extra={ + "Rank": self.rank, + "distributed_init_method": + self.distributed_init_method, + "is_driver_worker": self.is_driver_worker + }) + if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() neuron_framework = get_neuron_framework_to_use() - if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: - from vllm.worker.neuron_model_runner import NeuronModelRunner - # from vllm.worker.multi_step_neuron_model_runner import MultiStepNeuronModelRunner - if self.speculative_config is not None: - pass - else: - self.model_runner: NeuronModelRunner = NeuronModelRunner( - vllm_config=vllm_config) - elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: - from vllm.worker.neuronx_distributed_model_runner import NeuronxDistributedModelRunner - # from vllm.worker.multi_step_neuronx_distributed_model_runner import MultiStepNeuronModelRunner - if self.speculative_config is not None: - pass - else: - self.model_runner: NeuronxDistributedModelRunner = NeuronxDistributedModelRunner( - vllm_config=vllm_config) + self.model_runner = self.get_tnx_model_runner(vllm_config) + else: raise NotImplementedError( - f"Specified framework {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + + "Specified framework" + + f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" + " is either not installed or not supported." 
+ - " Supported frameworks: [transformers-neuronx, neuronx-distributed-inference]") + " Supported frameworks: " + + "[transformers-neuronx, neuronx-distributed-inference]") + + def get_tnx_model_runner(self, vllm_config): + from vllm.worker.neuron_model_runner import NeuronModelRunner + from vllm.worker.multi_step_neuron_model_runner import \ + MultiStepNeuronModelRunner + if self.speculative_config is not None: + return MultiStepNeuronModelRunner( + vllm_config=vllm_config) + else: + return NeuronModelRunner( + vllm_config=vllm_config) + + def get_neuronx_distributed_model_runner(self, vllm_config): + from vllm.worker.neuronx_distributed_model_runner import \ + NeuronxDistributedModelRunner + from vllm.worker.multi_step_neuronx_distributed_model_runner import \ + MultiStepNeuronxDistributedModelRunner + if self.speculative_config is not None: + return MultiStepNeuronxDistributedModelRunner( + vllm_config=vllm_config) + else: + return NeuronxDistributedModelRunner( + vllm_config=vllm_config) def init_device(self) -> None: self.init_distributed_environment() @@ -208,7 +237,7 @@ def init_distributed_environment(self): distributed_init_method=self.distributed_init_method, backend="gloo", ) - + # The equation must hold: world_size === TP * PP ensure_model_parallel_initialized( tensor_model_parallel_size=self.world_size, diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py index 22e61ac0634..67828128cd5 100644 --- a/vllm/worker/neuronx_distributed_model_runner.py +++ b/vllm/worker/neuronx_distributed_model_runner.py @@ -1,27 +1,29 @@ -import os import torch -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader.neuronx_distributed import get_neuron_model, _get_model_architecture -from vllm.worker.neuron_model_runner import NeuronModelRunner, ModelInputForNeuron -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.model_executor.model_loader.neuronx_distributed import \ + get_neuron_model, _get_model_architecture +from vllm.worker.neuron_model_runner import NeuronModelRunner, \ + ModelInputForNeuron +from vllm.sequence import IntermediateTensors from vllm.model_executor.layers.sampler import SamplerOutput -from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params -from neuronx_distributed_inference.models.mllama.image_transform import custom_image_preprocessing +from neuronx_distributed_inference.modules.generation.sampling import \ + prepare_sampling_params -# FIXME(Neuron): need to restor multi-model support -# from vllm.multimodal.neuron_multimodal_image_utils import decompress_image_from_tensor +# FIXME(Neuron): need to restore multi-modal support +# from vllm.multimodal.neuron_multimodal_image_utils import \ +# decompress_image_from_tensor logger = init_logger(__name__) class NeuronxDistributedModelRunner(NeuronModelRunner): def __init__( - self, - vllm_config: VllmConfig, + self, + vllm_config: VllmConfig, ): super().__init__(vllm_config) @@ -33,7 +35,9 @@ def load_model(self) -> None: def get_nxd_sampling_params(self, sampling_metadata): if self.model.config.neuron_config.on_device_sampling_config: - max_topk = self.model.config.neuron_config.on_device_sampling_config.global_topk + max_topk = ( + self.model.config.neuron_config.on_device_sampling_config + .global_topk) else: max_topk = 
self.model.config.vocab_size @@ -41,33 +45,24 @@ def get_nxd_sampling_params(self, sampling_metadata): top_p = [1.0] * self.scheduler_config.max_num_seqs temperature = [1.0] * self.scheduler_config.max_num_seqs - for index, sequenceGroupToSample in enumerate(sampling_metadata.seq_groups): - top_k[index] = sequenceGroupToSample.sampling_params.top_k if sequenceGroupToSample.sampling_params.top_k > 0 else max_topk + for index, sequenceGroupToSample in enumerate( + sampling_metadata.seq_groups): + top_k[index] = ( + sequenceGroupToSample.sampling_params.top_k + if sequenceGroupToSample.sampling_params.top_k > 0 + else max_topk) top_p[index] = sequenceGroupToSample.sampling_params.top_p - temperature[index] = sequenceGroupToSample.sampling_params.temperature + temperature[index] = ( + sequenceGroupToSample.sampling_params.temperature) - sampling_params = prepare_sampling_params(batch_size=self.scheduler_config.max_num_seqs, - top_k=top_k, top_p=top_p, temperature=temperature) + sampling_params = prepare_sampling_params( + batch_size=self.scheduler_config.max_num_seqs, + top_k=top_k, top_p=top_p, temperature=temperature) return sampling_params def get_multi_modal_data_neuron(self, input_images): - total_image_size = 0 - image_tensors = [] - empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], dtype=torch.bfloat16) - empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) - num_chunks = torch.tensor([[1]]) # dummy num_chunks, will not be used - has_image = torch.tensor([0]) - if (input_images is None) or (len(input_images) == 0) or (input_images.numel() == 0): - image_tensors = [empty_pixel_values, empty_aspect_ratios, num_chunks, has_image] - else: - image = decompress_image_from_tensor(input_images) - total_image_size += image.width * image.height - pixel_values, aspect_ratios, num_chunks = custom_image_preprocessing(self.model.config, [[image]]) - has_image = torch.tensor([1]) - - image_tensors = [pixel_values.bfloat16().clone().detach(), aspect_ratios, num_chunks, has_image] - - return image_tensors + # FIXME(Neuron): need to restore multi-modal support + raise NotImplementedError("need to restore multi-modal support") @torch.inference_mode() def execute_model( @@ -81,27 +76,32 @@ def execute_model( raise ValueError( "NeuronModelRunner does not support multi-step execution.") - if not _get_model_architecture(self.model.config) == "MllamaForConditionalGeneration": - return super().execute_model(model_input, kv_caches, intermediate_tensors, num_steps) + if _get_model_architecture( + self.model.config) != "MllamaForConditionalGeneration": + return super().execute_model(model_input, kv_caches, + intermediate_tensors, num_steps) - sampling_params = self.get_nxd_sampling_params(model_input.sampling_metadata) + sampling_params = self.get_nxd_sampling_params( + model_input.sampling_metadata) if model_input.multi_modal_kwargs.get('image') is not None: pixel_values = [] aspect_ratios = [] num_chunks = [] has_image = [] - for multi_modal_input in model_input.multi_modal_kwargs.get('image'): - image_tensors = self.get_multi_modal_data_neuron(multi_modal_input.squeeze(0)) + for multi_modal_input in model_input.multi_modal_kwargs.get( + 'image'): + image_tensors = self.get_multi_modal_data_neuron( + multi_modal_input.squeeze(0)) pixel_values.append(image_tensors[0]) aspect_ratios.append(image_tensors[1]) num_chunks.append(image_tensors[2]) has_image.append(image_tensors[3]) - pixel_values=torch.cat(pixel_values,dim=0) - aspect_ratios=torch.cat(aspect_ratios,dim=0) - 
num_chunks=torch.cat(num_chunks,dim=0) - has_image=torch.cat(has_image,dim=0) + pixel_values = torch.cat(pixel_values, dim=0) + aspect_ratios = torch.cat(aspect_ratios, dim=0) + num_chunks = torch.cat(num_chunks, dim=0) + has_image = torch.cat(has_image, dim=0) hidden_states = self.model( input_ids=model_input.input_tokens, @@ -114,24 +114,26 @@ def execute_model( has_image=has_image, ) else: - empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], dtype=torch.bfloat16) + empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], + dtype=torch.bfloat16) empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) - num_chunks = torch.tensor([[1]]) # dummy num_chunks, will not be used + num_chunks = torch.tensor( + [[1]]) # dummy num_chunks, will not be used has_image = torch.tensor([0]) hidden_states = self.model( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - seq_ids=model_input.input_block_ids, - pixel_values=empty_pixel_values, - aspect_ratios=empty_aspect_ratios, - sampling_params=sampling_params, - num_chunks=num_chunks, - has_image=has_image, - ) + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + seq_ids=model_input.input_block_ids, + pixel_values=empty_pixel_values, + aspect_ratios=empty_aspect_ratios, + sampling_params=sampling_params, + num_chunks=num_chunks, + has_image=has_image, + ) output = self.model.sample( hidden_states=hidden_states, sampling_metadata=model_input.sampling_metadata, ) - return [output] \ No newline at end of file + return [output] From 51f403f51b1de80f1644ccca5520ccbb6af81e39 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Mon, 24 Feb 2025 19:59:00 +0000 Subject: [PATCH 26/38] Bug fix: Missing NxDI model runner addressed Signed-off-by: Satyajith Chilappagari --- vllm/worker/neuron_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index b9a844204af..a770c3c4183 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -131,7 +131,9 @@ def __init__( neuron_framework = get_neuron_framework_to_use() if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: self.model_runner = self.get_tnx_model_runner(vllm_config) - + elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: + self.model_runner = self.get_neuronx_distributed_model_runner( + vllm_config) else: raise NotImplementedError( "Specified framework" + From 543f55d68fb0e147fe91e0c0b662df5bdd131c6d Mon Sep 17 00:00:00 2001 From: Aaron Dou Date: Mon, 24 Feb 2025 23:53:48 +0000 Subject: [PATCH 27/38] set world_size default value as 1 Signed-off-by: Satyajith Chilappagari --- vllm/worker/neuron_worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index a770c3c4183..07caf8985e8 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -109,8 +109,9 @@ def __init__( os.getenv("ENABLE_NEURON_MULTI_NODE", DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true") + self.world_size = int(os.getenv("WORLD_SIZE", DEFAULT_WORLD_SIZE)) + if self.enable_neuron_multi_node: - self.world_size = int(os.getenv("WORLD_SIZE", DEFAULT_WORLD_SIZE)) self.rank = int(os.getenv("NEURON_RANK_ID", DEFAULT_NEURON_RANK_ID)) self.distributed_init_method = "env://" self.is_driver_worker = self.rank == 0 From d15c20bcb258b1300f8014f8a3757414af1e5799 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Tue, 25 Feb 2025 23:27:18 
+0000 Subject: [PATCH 28/38] Formatting changes to satisfy all pre-commit hooks Signed-off-by: Satyajith Chilappagari --- examples/neuron/multi_node/launch_script.py | 60 +++-- examples/neuron/multi_node/worker.py | 7 +- .../offline_inference/neuron_speculation.py | 3 +- setup.py | 1 - tests/worker/test_neuron_model_runner.py | 24 +- vllm/engine/arg_utils.py | 3 +- vllm/engine/llm_engine.py | 4 +- vllm/engine/output_processor/stop_checker.py | 7 +- vllm/model_executor/model_loader/neuron.py | 128 +++++---- .../model_loader/neuronx_distributed.py | 248 +++++++++--------- vllm/platforms/__init__.py | 2 +- vllm/platforms/neuron.py | 2 +- vllm/worker/multi_step_neuron_model_runner.py | 32 ++- ...i_step_neuronx_distributed_model_runner.py | 17 +- vllm/worker/neuron_model_runner.py | 89 +++---- vllm/worker/neuron_worker.py | 61 ++--- .../neuronx_distributed_model_runner.py | 61 ++--- 17 files changed, 372 insertions(+), 377 deletions(-) diff --git a/examples/neuron/multi_node/launch_script.py b/examples/neuron/multi_node/launch_script.py index a1f14c2f2ce..8850a711d7a 100644 --- a/examples/neuron/multi_node/launch_script.py +++ b/examples/neuron/multi_node/launch_script.py @@ -1,8 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 import argparse import json import os -import sys import subprocess +import sys from vllm.logger import init_logger @@ -18,38 +19,53 @@ def error_exit(message: str) -> None: def arg_parser(): parser = argparse.ArgumentParser(description="vLLM multi-node launcher") - parser.add_argument("--model", type=str, required=True, + parser.add_argument("--model", + type=str, + required=True, help="Model or model path") - parser.add_argument("--world-size", type=int, required=True, + parser.add_argument("--world-size", + type=int, + required=True, help="World size for distributed inference") - parser.add_argument("--max-num-seqs", type=int, required=True, + parser.add_argument("--max-num-seqs", + type=int, + required=True, help="Maximum number of sequences (or batch size)") - parser.add_argument("--max-model-len", type=int, default=8192, + parser.add_argument("--max-model-len", + type=int, + default=8192, help="Maximum sequence length") - parser.add_argument("--max-context-length", type=int, + parser.add_argument("--max-context-length", + type=int, help="Maximum context length") parser.add_argument("--compiled-model-path", help="Path to the compiled model. 
If not present, " - "model artifacts will be created in local-models " - "folder") - parser.add_argument("--local-ranks-size", type=int, default=32, + "model artifacts will be created in local-models " + "folder") + parser.add_argument("--local-ranks-size", + type=int, + default=32, help="Local ranks size") parser.add_argument("--on-device-sampling-config", type=json.loads, help="On-device sampling configuration") - parser.add_argument("--quantized", type=bool, default=False, + parser.add_argument("--quantized", + type=bool, + default=False, help="Enable quantized mode (default: False)") - parser.add_argument("--quantized-checkpoints-path", type=str, + parser.add_argument("--quantized-checkpoints-path", + type=str, help="Path to quantized checkpoints " - "(required if --quantized is True)") - parser.add_argument("--port", type=int, default=8080, + "(required if --quantized is True)") + parser.add_argument("--port", + type=int, + default=8080, help="Port for the API server") args = parser.parse_args() if args.quantized and not args.quantized_checkpoints_path: - parser.error( - "--quantized-checkpoints-path is required when " - "--quantized is enabled.") + parser.error("--quantized-checkpoints-path is required when " + "--quantized is enabled.") return args @@ -91,7 +107,7 @@ def main() -> None: "ENABLE_NEURON_MULTI_NODE": "true", "WORLD_SIZE": str(mpi_world_size), "NEURON_RT_ROOT_COMM_ID": - f"{master_addr}:{NEURON_RT_ROOT_COMM_ID_PORT}", + f"{master_addr}:{NEURON_RT_ROOT_COMM_ID_PORT}", "NEURON_LOCAL_TP": str(args.local_ranks_size), "NEURON_RANK_ID": str(rank) }) @@ -101,9 +117,7 @@ def main() -> None: logger.info("Starting vLLM API server on rank 0...") cmd = [ "python", "-m", "vllm.entrypoints.api_server", - f"--model={args.model}", - f"--port={args.port}", - "--device=neuron", + f"--model={args.model}", f"--port={args.port}", "--device=neuron", f"--max-num-seqs={args.max_num_seqs}", f"--max-model-len={args.max_model_len}", f"--override-neuron-config={json.dumps(override_config)}" @@ -118,10 +132,8 @@ def main() -> None: current_script_dir = os.path.dirname(os.path.abspath(__file__)) worker_file_path = os.path.join(current_script_dir, "worker.py") cmd = [ - "python", worker_file_path, - f"--model={args.model}", - "--device=neuron", - f"--max-num-seqs={args.max_num_seqs}", + "python", worker_file_path, f"--model={args.model}", + "--device=neuron", f"--max-num-seqs={args.max_num_seqs}", f"--max-model-len={args.max_model_len}", f"--override-neuron-config={json.dumps(override_config)}" ] diff --git a/examples/neuron/multi_node/worker.py b/examples/neuron/multi_node/worker.py index 639763ce45b..228822d95d6 100644 --- a/examples/neuron/multi_node/worker.py +++ b/examples/neuron/multi_node/worker.py @@ -1,12 +1,12 @@ +# SPDX-License-Identifier: Apache-2.0 import argparse import os -from vllm.logger import init_logger from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext - logger = init_logger("vllm.neuron.multi-node.worker") @@ -20,6 +20,7 @@ def initialize_worker(): engine_args, usage_context=UsageContext.API_SERVER) return args, engine + def start_worker(): rank_id = int(os.getenv("NEURON_RANK_ID")) if rank_id == 0: @@ -29,6 +30,7 @@ def start_worker(): while True: worker.execute_model() + def main(): try: start_worker() @@ -36,5 +38,6 @@ def main(): logger.error("Failed starting worker", extra={"error": e}) exit(1) + if __name__ == "__main__": main() diff 
--git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py index 9cae4d47c9f..ea03089f5f0 100644 --- a/examples/offline_inference/neuron_speculation.py +++ b/examples/offline_inference/neuron_speculation.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 - """ This example shows how to run offline inference with a speculative decoding model on neuron. @@ -60,4 +59,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/setup.py b/setup.py index fa2adf76e54..9a2852a614c 100755 --- a/setup.py +++ b/setup.py @@ -673,4 +673,3 @@ def _read_requirements(filename: str) -> List[str]: ], }, ) - diff --git a/tests/worker/test_neuron_model_runner.py b/tests/worker/test_neuron_model_runner.py index feecbf38e4d..08fd3644bc1 100644 --- a/tests/worker/test_neuron_model_runner.py +++ b/tests/worker/test_neuron_model_runner.py @@ -1,13 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 import os from unittest.mock import MagicMock + from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.neuron_worker import use_transformers_neuronx, NeuronFramework +from vllm.worker.neuron_worker import NeuronFramework, use_transformers_neuronx -os.environ['VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value +os.environ[ + 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value def _create_neuron_model_runner(model: str, *args, @@ -20,9 +23,7 @@ def _create_neuron_model_runner(model: str, *args, scheduler_config=engine_config.scheduler_config, device_config=engine_config.device_config, ) - neuron_model_runner = NeuronModelRunner( - vllm_config=vllm_config - ) + neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config) return neuron_model_runner @@ -36,7 +37,7 @@ def test_update_neuron_sampling_params_not_full_batch(): ) assert not model_runner._on_device_sampling_disabled # Test sampling param updating only when TNx is framework - # NxDI handles sampling paramter updating inside model + # NxDI handles sampling parameter updating inside model if use_transformers_neuronx(): model_mock = MagicMock() model_runner.model = model_mock @@ -46,7 +47,8 @@ def test_update_neuron_sampling_params_not_full_batch(): request_id="test_0", is_prompt=True, seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, top_k=1, + sampling_params=SamplingParams(temperature=0.5, + top_k=1, top_p=0.5), block_tables={0: [1]}, ) @@ -81,7 +83,7 @@ def test_update_neuron_sampling_params_full_batch(): assert not model_runner._on_device_sampling_disabled # Test sampling param updating only when TNx is framework - # NxDI handles sampling paramter updating inside model + # NxDI handles sampling parameter updating inside model if use_transformers_neuronx(): model_mock = MagicMock() model_runner.model = model_mock @@ -91,7 +93,8 @@ def test_update_neuron_sampling_params_full_batch(): request_id="test_0", is_prompt=True, seq_data={0: SequenceData.from_seqs([1, 2, 3])}, - sampling_params=SamplingParams(temperature=0.5, top_k=1, + sampling_params=SamplingParams(temperature=0.5, + top_k=1, top_p=0.5), block_tables={0: [1]}, ), @@ -99,7 +102,8 @@ def test_update_neuron_sampling_params_full_batch(): request_id="test_0", is_prompt=True, seq_data={1: SequenceData.from_seqs([4, 5, 
6])}, - sampling_params=SamplingParams(temperature=0.2, top_k=2, + sampling_params=SamplingParams(temperature=0.2, + top_k=2, top_p=0.2), block_tables={1: [0]}, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 00ae1dbc6cb..3435e77a40f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -751,8 +751,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: '--speculative-token-tree', type=nullable_str, default=EngineArgs.speculative_token_tree, - help='The token tree definition used with speculation.' - ) + help='The token tree definition used with speculation.') parser.add_argument( '--speculative-draft-tensor-parallel-size', '-spec-draft-tp', diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ce12a3d1cba..279e620ba5d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -388,7 +388,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: "vllm.llm_engine", self.observability_config.otlp_traces_endpoint) - if self.device_config.device_type == "neuron": num_lookahead_slots = self.scheduler_config.num_lookahead_slots else: @@ -405,8 +404,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: stop_checker=StopChecker( self.scheduler_config.max_model_len, get_tokenizer_for_seq, - num_lookahead_slots=num_lookahead_slots - ), + num_lookahead_slots=num_lookahead_slots), )) self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 3be7dfb10fc..1bc413721ff 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -15,7 +15,8 @@ class StopChecker: emitted, or if we have exceeded the max model len. """ - def __init__(self, max_model_len: int, + def __init__(self, + max_model_len: int, get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], num_lookahead_slots: int = 0): # Do not use it directly, but use `self._get_max_model_len`. @@ -83,8 +84,8 @@ def maybe_stop_sequence( return # Check if the sequence has reached max_model_len. 
- if (seq.get_len() + - self.num_lookahead_slots > self._get_max_model_len(lora_req)): + if (seq.get_len() + self.num_lookahead_slots + > self._get_max_model_len(lora_req)): seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index c7248038871..76e1aa46b88 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -11,8 +11,8 @@ import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, \ - SpeculativeConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import get_quantization_config from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -60,10 +60,10 @@ def __init__(self, self.model: nn.Module def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, ) -> torch.Tensor: logits = self.model(input_ids, cache_ids=positions, @@ -76,9 +76,9 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: if self.on_device_sampling_disabled: @@ -121,23 +121,20 @@ class NeuronSpeculationCausalLM(nn.Module): SPECULATION_TERMINATION_ID = -1 - def __init__( - self, - speculation_model - ) -> None: + def __init__(self, speculation_model) -> None: super().__init__() self.model = speculation_model def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, ) -> torch.Tensor: tokens, counts = self.model.speculative_iteration( input_ids, positions, input_block_ids) - # Mark the end of accepted specualtive tokens for each sequence with the + # Mark the end of accepted speculative tokens for each sequence with the # speculation termination id. batch_size, steps = tokens.shape mask = torch.arange(steps).expand(batch_size, -1) >= counts @@ -146,13 +143,15 @@ def forward( return tokens def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[List[SamplerOutput]]: batch_size, num_steps = logits.shape - seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in - sg.seq_ids] + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] # Organize input tensors by step instead of by sequence. 
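# --- Illustrative aside with assumed toy values: how a [batch_size, steps]
# tensor of accepted speculative tokens is masked with the termination id in
# forward() and then regrouped by step here in sample().
import torch
tokens = torch.tensor([[11, 12, 13],
                       [21, 22, 23]])      # two sequences, three spec steps
counts = torch.tensor([[2], [3]])          # seq 0 accepted 2 tokens, seq 1 all 3
mask = torch.arange(tokens.shape[1]).expand(tokens.shape[0], -1) >= counts
tokens = tokens.masked_fill(mask, -1)      # [[11, 12, -1], [21, 22, 23]]
by_step = tokens.transpose(0, 1).tolist()  # [[11, 21], [12, 22], [-1, 23]]
# A step whose entries are all -1 ends the loop that builds SamplerOutputs.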
accepted_token_ids_by_step = logits.transpose(0, 1) accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() @@ -166,12 +165,13 @@ def sample( for sequence_index in range(batch_size): token_id = accepted_token_ids_by_step[step_index][ sequence_index] - step_output_token_ids.append(CompletionSequenceGroupOutput( - samples=[ + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=seq_ids[sequence_index], output_token=token_id, - logprobs={token_id: Logprob(token_id)})], - prompt_logprobs=None)) + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) return sampler_output_list @@ -228,10 +228,8 @@ def _get_default_neuron_config(model_config: ModelConfig, def _get_default_neuron_config_for_speculation( - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig -): + model_config: ModelConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): """Generate a neuron config for speculative decoding based on vllm config args.""" from transformers_neuronx.config import ContinuousBatchingConfig @@ -240,14 +238,13 @@ def _get_default_neuron_config_for_speculation( continuous_batching_config = ContinuousBatchingConfig( batch_size_for_shared_caches=scheduler_config.max_num_seqs) - default_neuron_args = dict( - collectives_layout=LAYOUT_BSH, - attention_layout=LAYOUT_BSH, - fuse_qkv=True, - on_device_embedding=True, - continuous_batching=continuous_batching_config, - on_device_generation=copy.deepcopy(model_config.neuron_sampling_params) - ) + default_neuron_args = dict(collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + on_device_embedding=True, + continuous_batching=continuous_batching_config, + on_device_generation=copy.deepcopy( + model_config.neuron_sampling_params)) return default_neuron_args @@ -302,11 +299,9 @@ def get_neuron_model(model_config: ModelConfig, def get_neuron_speculation_model( - model_config: ModelConfig, - parallel_config: ParallelConfig, + model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig -) -> None: + speculation_config: SpeculativeConfig) -> None: """Initializes a neuron-optimized speculation model for inference. This method is only applicable for speculation with a standalone draft model @@ -350,9 +345,8 @@ def get_neuron_speculation_model( default_draft_neuron_config_args = ( _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, - parallel_config, - scheduler_config)) + speculation_config.draft_model_config, parallel_config, + scheduler_config)) if is_eagle: default_draft_neuron_config_args['is_eagle_draft'] = True default_draft_neuron_config_args['has_pre_attention_norm'] = False @@ -361,19 +355,19 @@ def get_neuron_speculation_model( default_draft_neuron_config_args, speculation_config.draft_model_config.override_neuron_config) - draft_model.load_weights( - speculation_config.draft_model_config.model, - tp_degree=speculation_config.draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. 
+ draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) draft_model.eval() - num_speculative_tokens= speculation_config.num_speculative_tokens + num_speculative_tokens = speculation_config.num_speculative_tokens # Create speculation model instance. speculation_model = FusedSpeculativeDecoder(draft_model.model, target_model.model, @@ -421,23 +415,23 @@ def get_neuron_eagle_speculation_model(model_config: ModelConfig, default_draft_neuron_config_args = ( _get_default_neuron_config_for_speculation( - speculation_config.draft_model_config, parallel_config, - scheduler_config)) + speculation_config.draft_model_config, parallel_config, + scheduler_config)) default_draft_neuron_config_args['is_eagle_draft'] = True default_draft_neuron_config_args['has_pre_attention_norm'] = False draft_neuron_config = _get_neuron_config_after_override( default_draft_neuron_config_args, speculation_config.draft_model_config.override_neuron_config) - draft_model.load_weights( - speculation_config.draft_model_config.model, - tp_degree=speculation_config.draft_parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[ - speculation_config.draft_model_config.dtype], - neuron_config=draft_neuron_config, - context_length_estimate=context_length_estimates, - n_positions=n_positions, - batch_size=scheduler_config.max_num_seqs) + draft_model.load_weights(speculation_config.draft_model_config.model, + tp_degree=speculation_config. + draft_parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[ + speculation_config.draft_model_config.dtype], + neuron_config=draft_neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) draft_model.eval() diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index 040233d19fb..dbe4991f8fe 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Utilities for selecting and loading Neuron models in neuronx-distributed-inference framework.""" +# Disabling yapf because yapf and isort have conflicts for the below imports +# yapf: disable import copy import hashlib import importlib @@ -11,10 +13,16 @@ import torch import torch.nn as nn -from transformers import AutoModelForCausalLM, PretrainedConfig, AutoTokenizer - -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, \ - SpeculativeConfig +from neuronx_distributed_inference.models.config import ( + FusedSpecNeuronConfig, OnDeviceSamplingConfig) +from neuronx_distributed_inference.models.mllama.utils import ( + create_vision_mask) +from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config) +from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig + +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig) from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -22,12 +30,7 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, SequenceOutput) -from 
neuronx_distributed_inference.models.mllama.utils import create_vision_mask -from neuronx_distributed_inference.utils.hf_adapter import \ - load_pretrained_config -from neuronx_distributed_inference.models.config import FusedSpecNeuronConfig, \ - OnDeviceSamplingConfig - +# yapf: enable logger = init_logger(__name__) TORCH_DTYPE_TO_NEURON_AMP = { @@ -44,26 +47,26 @@ # Models supported by Neuronx distributed for inference. _NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = { - "LlamaForCausalLM": ( - "neuronx_distributed_inference.models.llama.modeling_llama", - "NeuronLlamaForCausalLM"), - "DbrxForCausalLM": ( - "neuronx_distributed_inference.models.dbrx.modeling_dbrx", - "NeuronDbrxForCausalLM"), - "MixtralForCausalLM": ( - "neuronx_distributed_inference.models.mixtral.modeling_mixtral", - "NeuronMixtralForCausalLM"), - "MllamaForConditionalGeneration": ( - "neuronx_distributed_inference.models.mllama.modeling_mllama", - "NeuronMllamaForCausalLM"), + "LlamaForCausalLM": + ("neuronx_distributed_inference.models.llama.modeling_llama", + "NeuronLlamaForCausalLM"), + "DbrxForCausalLM": + ("neuronx_distributed_inference.models.dbrx.modeling_dbrx", + "NeuronDbrxForCausalLM"), + "MixtralForCausalLM": + ("neuronx_distributed_inference.models.mixtral.modeling_mixtral", + "NeuronMixtralForCausalLM"), + "MllamaForConditionalGeneration": + ("neuronx_distributed_inference.models.mllama.modeling_mllama", + "NeuronMllamaForCausalLM"), } class NeuronCausalLM(nn.Module): def __init__( - self, - config: PretrainedConfig, + self, + config: PretrainedConfig, ) -> None: super().__init__() self.config = config @@ -75,11 +78,11 @@ def __init__( self.model: nn.Module def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, ) -> torch.Tensor: output = self.model(input_ids, attention_mask=None, @@ -98,15 +101,17 @@ def compute_logits(self, hidden_states: torch.Tensor, return logits def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: # on-device sampling if self.config.neuron_config.on_device_sampling_config: batch_size = logits.shape - seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id - in sg.seq_ids] + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] assert len(seq_ids) == list(batch_size)[0], "batch size mismatch" # Organize input tensors by step instead of by sequence. 
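# --- Illustrative aside with assumed values: on the on-device sampling path
# below, "logits" already holds one sampled token id per sequence rather than
# a vocab-sized logit matrix, so flattening yields the per-sequence ids that
# get wrapped into SequenceOutput / CompletionSequenceGroupOutput objects.
import torch
sampled = torch.tensor([42, 7])            # assumed ids for two sequences
assert sampled.flatten().tolist() == [42, 7]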
accepted_token_ids_by_step = logits.flatten() @@ -115,11 +120,13 @@ def sample( step_output_token_ids = [] for i, seq_id in enumerate(seq_ids): token_id = accepted_token_ids_by_step[i] - step_output_token_ids.append(CompletionSequenceGroupOutput( - samples=[SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, logprobs={ - token_id: Logprob(token_id)})], - prompt_logprobs=None)) + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) return SamplerOutput(outputs=step_output_token_ids) else: return self.sampler(logits, sampling_metadata) @@ -135,8 +142,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.config.neuron_config = neuron_config config = neuronx_model_cls.get_config_cls()( neuron_config, - load_config=load_pretrained_config(model_name_or_path) - ) + load_config=load_pretrained_config(model_name_or_path)) hashed_config = hashlib.md5( config.to_json_string().encode('utf-8')).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: @@ -163,9 +169,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): logger.warning("Exception: ", e) logger.warning( "Failed to load the model from path, Recompiling...", - extra={ - "compiled_model_path": compiled_model_path - }) + extra={"compiled_model_path": compiled_model_path}) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -178,15 +182,13 @@ def load_weights(self, model_name_or_path: str, **kwargs): class NeuronMllamaForCausalLM(nn.Module): - def __init__( - self, - config: PretrainedConfig, - on_device_sampling_disabled: bool = False) -> None: + def __init__(self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: super().__init__() self.config = config self.logits_processor = LogitsProcessor( - config.get_text_config().vocab_size, - logits_as_input=True) + config.get_text_config().vocab_size, logits_as_input=True) self.on_device_sampling_disabled = on_device_sampling_disabled if self.on_device_sampling_disabled: @@ -196,30 +198,24 @@ def __init__( # Lazy initialized self.model: nn.Module - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - seq_ids: torch.Tensor, - pixel_values: torch.Tensor, - aspect_ratios: torch.Tensor, - num_chunks: torch.Tensor, - has_image: torch.Tensor, - sampling_params - ) -> torch.Tensor: + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + seq_ids: torch.Tensor, pixel_values: torch.Tensor, + aspect_ratios: torch.Tensor, num_chunks: torch.Tensor, + has_image: torch.Tensor, sampling_params) -> torch.Tensor: self.vision_mask = create_vision_mask(input_ids, self.vision_token_id) - output = self.model(input_ids.to(torch.int32), - attention_mask=None, - position_ids=positions.to(torch.int32), - seq_ids=seq_ids.flatten().to(torch.int32), - pixel_values=pixel_values.to( - self.config.vision_config.torch_dtype), - aspect_ratios=aspect_ratios.to(torch.int32), - vision_mask=self.vision_mask.to(torch.int32), - sampling_params=sampling_params, - num_chunks=num_chunks.to(torch.int32), - has_image=has_image.to(torch.int32), - ) + output = self.model( + input_ids.to(torch.int32), + attention_mask=None, + position_ids=positions.to(torch.int32), + seq_ids=seq_ids.flatten().to(torch.int32), + pixel_values=pixel_values.to( + 
self.config.vision_config.torch_dtype), + aspect_ratios=aspect_ratios.to(torch.int32), + vision_mask=self.vision_mask.to(torch.int32), + sampling_params=sampling_params, + num_chunks=num_chunks.to(torch.int32), + has_image=has_image.to(torch.int32), + ) if self.config.neuron_config.on_device_sampling_config: return output.hidden_states return output.logits[:, -1, :] @@ -240,11 +236,11 @@ def sample(self, hidden_states, sampling_metadata): samples = [] for seq_id in seq_ids: token_id = hidden_states[sample_idx].item() - samples.append(SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs={ - token_id: Logprob( - token_id)})) + samples.append( + SequenceOutput( + parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)})) sample_idx += 1 res.append( CompletionSequenceGroupOutput(samples=samples, @@ -263,13 +259,11 @@ def load_weights(self, model_name_or_path: str, **kwargs): neuron_config = neuronx_model_cls.get_neuron_config_cls()( **kwargs['neuron_config']) self.config.neuron_config = neuron_config - logger.info( - "neuron_config buckets: ", extra={ - "buckets": self.config.neuron_config.buckets}) + logger.info("neuron_config buckets: ", + extra={"buckets": self.config.neuron_config.buckets}) config = neuronx_model_cls.get_config_cls()( neuron_config, - load_config=load_pretrained_config(model_name_or_path) - ) + load_config=load_pretrained_config(model_name_or_path)) hashed_config = hashlib.md5( config.to_json_string().encode('utf-8')).hexdigest() if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None: @@ -286,16 +280,14 @@ def load_weights(self, model_name_or_path: str, **kwargs): try: self.model = neuronx_model_cls(compiled_model_path) tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - self.vision_token_id = tokenizer("<|image|>", - add_special_tokens=False).input_ids + self.vision_token_id = tokenizer( + "<|image|>", add_special_tokens=False).input_ids self.model.load(compiled_model_path) return except (FileNotFoundError, ValueError): logger.warning( "Failed to load the model from path, Recompiling...", - extra={ - "compiled_model_path": compiled_model_path - }) + extra={"compiled_model_path": compiled_model_path}) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -303,9 +295,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): model_name_or_path = saved_path self.model = neuronx_model_cls(model_name_or_path, config) - logger.info("\nCompiling and saving model to path", extra={ - "model_name_or_path": model_name_or_path - }) + logger.info("\nCompiling and saving model to path", + extra={"model_name_or_path": model_name_or_path}) p = multiprocessing.Process(target=compile_model, args=(self, compiled_model_path)) @@ -314,9 +305,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) tokenizer.save_pretrained(compiled_model_path) - logger.info( - "Successfully compiled and saved the model", - extra={"compiled_model_path": compiled_model_path}) + logger.info("Successfully compiled and saved the model", + extra={"compiled_model_path": compiled_model_path}) # Read "<|image|>" token_id from the tokenizer self.vision_token_id = tokenizer("<|image|>", @@ -333,8 +323,8 @@ class NeuronSpeculationCausalLM(nn.Module): """A Neuron-optimized causal language model with speculative decoding.""" def __init__( - self, - config: 
PretrainedConfig, + self, + config: PretrainedConfig, ) -> None: super().__init__() self.config = config @@ -344,11 +334,11 @@ def __init__( self.model: nn.Module def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - input_block_ids: torch.Tensor, - sampling_params: torch.Tensor, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + sampling_params: torch.Tensor, ) -> torch.Tensor: output = self.model(input_ids, attention_mask=None, @@ -375,32 +365,35 @@ def forward( return accepted_tokens_with_padding def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[List[SamplerOutput]]: batch_size, num_steps = logits.shape - seq_ids = [seq_id for sg in sampling_metadata.seq_groups for seq_id in - sg.seq_ids] + seq_ids = [ + seq_id for sg in sampling_metadata.seq_groups + for seq_id in sg.seq_ids + ] # Organize input tensors by step instead of by sequence. accepted_token_ids_by_step = logits.transpose(0, 1) accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() sampler_output_list = [] for step_index in range(num_steps): - if all(token_id == -1 for token_id in - accepted_token_ids_by_step[step_index]): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): break step_output_token_ids = [] for sequence_index in range(batch_size): token_id = accepted_token_ids_by_step[step_index][ sequence_index] - step_output_token_ids.append(CompletionSequenceGroupOutput( - samples=[ + step_output_token_ids.append( + CompletionSequenceGroupOutput(samples=[ SequenceOutput(parent_seq_id=seq_ids[sequence_index], output_token=token_id, - logprobs={token_id: Logprob(token_id)})], - prompt_logprobs=None)) + logprobs={token_id: Logprob(token_id)}) + ], + prompt_logprobs=None)) sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) return sampler_output_list @@ -416,8 +409,7 @@ def load_weights(self, model_name_or_path: str, **kwargs['neuron_config']) config = neuronx_model_cls.get_config_cls()( neuron_config, - load_config=load_pretrained_config(model_name_or_path) - ) + load_config=load_pretrained_config(model_name_or_path)) draft_neuron_config = copy.deepcopy(config.neuron_config) if not config.neuron_config.enable_eagle_speculation: @@ -429,12 +421,11 @@ def load_weights(self, model_name_or_path: str, draft_neuron_config.sequence_parallel_enabled = False draft_config = neuronx_model_cls.get_config_cls()( draft_neuron_config, - load_config=load_pretrained_config(draft_model_name_or_path) - ) - fused_spec_config = ( - FusedSpecNeuronConfig(neuronx_model_cls._model_cls, - draft_config=draft_config, - draft_model_path=draft_model_name_or_path)) + load_config=load_pretrained_config(draft_model_name_or_path)) + fused_spec_config = (FusedSpecNeuronConfig( + neuronx_model_cls._model_cls, + draft_config=draft_config, + draft_model_path=draft_model_name_or_path)) config.fused_spec_config = fused_spec_config self.config.neuron_config = neuron_config @@ -462,9 +453,8 @@ def load_weights(self, model_name_or_path: str, return except (FileNotFoundError, ValueError) as e: logger.warning("Exception: ", extra={"exception": e}) - logger.warning( - "Failed to load the model Recompiling...", - extra={"compiled_model_path": compiled_model_path}) + logger.warning("Failed to load the model Recompiling...", + extra={"compiled_model_path": compiled_model_path}) if not os.path.exists(model_name_or_path): hf_model = 
AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -539,8 +529,10 @@ def _get_default_speculation_config(model_config: ModelConfig, enable_bucketing=True, quantized=False, torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - on_device_sampling_config=dict(top_k=1, do_sample=False, ) - ) + on_device_sampling_config=dict( + top_k=1, + do_sample=False, + )) return neuron_config @@ -564,8 +556,7 @@ def get_neuron_model(model_config: ModelConfig, default_neuron_config_args = _get_default_neuron_config( model_config, parallel_config, scheduler_config) neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, - model_config.override_neuron_config) + default_neuron_config_args, model_config.override_neuron_config) override_neuron_config = model_config.override_neuron_config model.load_weights(model_config.model, @@ -586,8 +577,7 @@ def get_neuron_speculation_model(model_config: ModelConfig, default_neuron_config_args = _get_default_speculation_config( model_config, parallel_config, scheduler_config, speculation_config) neuron_config = _get_neuron_config_after_override( - default_neuron_config_args, - model_config.override_neuron_config) + default_neuron_config_args, model_config.override_neuron_config) override_neuron_config = model_config.override_neuron_config model.load_weights(model_config.model, diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 09202bcd986..6da6a229140 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -128,7 +128,7 @@ def neuron_platform_plugin() -> Optional[str]: pass try: - import neuronx_distributed_inference # noqa: F401 + import neuronx_distributed_inference # noqa: F401 nxd_installed = True except ImportError: pass diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 52b78f7625f..1b7a221925b 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 +from functools import lru_cache from typing import TYPE_CHECKING, Optional from vllm.logger import init_logger from .interface import Platform, PlatformEnum -from functools import lru_cache if TYPE_CHECKING: from vllm.config import VllmConfig diff --git a/vllm/worker/multi_step_neuron_model_runner.py b/vllm/worker/multi_step_neuron_model_runner.py index 4584b4a8b48..9618a4b49ff 100644 --- a/vllm/worker/multi_step_neuron_model_runner.py +++ b/vllm/worker/multi_step_neuron_model_runner.py @@ -1,12 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 + from importlib.util import find_spec -import torch from typing import List, Optional + +import torch + from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors -from vllm.worker.neuron_model_runner import NeuronModelRunner, \ - ModelInputForNeuron +from vllm.worker.neuron_model_runner import (ModelInputForNeuron, + NeuronModelRunner) class MultiStepNeuronModelRunner(NeuronModelRunner): @@ -14,8 +18,8 @@ class MultiStepNeuronModelRunner(NeuronModelRunner): framework""" def __init__( - self, - vllm_config: VllmConfig, + self, + vllm_config: VllmConfig, ): super().__init__(vllm_config) self.speculation_config = self.speculative_config @@ -35,15 +39,15 @@ def __init__( def load_model(self) -> None: if find_spec("transformers_neuronx") is not None: - from vllm.model_executor.model_loader.neuron import \ - 
get_neuron_speculation_model, get_neuron_eagle_speculation_model + from vllm.model_executor.model_loader.neuron import ( + get_neuron_eagle_speculation_model, + get_neuron_speculation_model) if self.speculation_config.speculative_token_tree is not None: self.model = get_neuron_eagle_speculation_model( self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - speculation_config=self.speculation_config - ) + speculation_config=self.speculation_config) else: self.model = get_neuron_speculation_model( self.model_config, @@ -56,11 +60,11 @@ def load_model(self) -> None: @torch.inference_mode() def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: logits = self.model( input_ids=model_input.input_tokens, diff --git a/vllm/worker/multi_step_neuronx_distributed_model_runner.py b/vllm/worker/multi_step_neuronx_distributed_model_runner.py index e63c4168d58..b6a3492a493 100644 --- a/vllm/worker/multi_step_neuronx_distributed_model_runner.py +++ b/vllm/worker/multi_step_neuronx_distributed_model_runner.py @@ -1,11 +1,15 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 from typing import List, Optional + +import torch + from vllm.config import VllmConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs from vllm.sequence import IntermediateTensors -from vllm.worker.neuronx_distributed_model_runner \ - import NeuronxDistributedModelRunner +from vllm.worker.neuronx_distributed_model_runner import ( + NeuronxDistributedModelRunner) + class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner): """A model runner for multi-step decoding using the @@ -18,8 +22,8 @@ def __init__( super().__init__(vllm_config) def load_model(self) -> None: - from vllm.model_executor.model_loader.neuronx_distributed \ - import get_neuron_speculation_model + from vllm.model_executor.model_loader.neuronx_distributed import ( + get_neuron_speculation_model) self.model = get_neuron_speculation_model( self.model_config, parallel_config=self.parallel_config, @@ -46,7 +50,7 @@ def execute_model( input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), + device=self.device), ) output = self.model.sample( @@ -54,4 +58,3 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, ) return output - diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index deb0f9edd3b..952829f1da9 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -18,8 +18,8 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad, is_transformers_neuronx, is_neuronx_distributed_inference from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase -from vllm.worker.neuron_worker import use_neuronx_distributed, \ - use_transformers_neuronx +from vllm.worker.neuron_worker import (use_neuronx_distributed, + use_transformers_neuronx) if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend 
@@ -50,9 +50,9 @@ def as_broadcastable_tensor_dict( @classmethod def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, ) -> "ModelInputForNeuron": return ModelInputForNeuron( input_tokens=tensor_dict["input_tokens"], @@ -70,8 +70,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): _MAX_NEURON_SAMPLING_TOP_K = 256 def __init__( - self, - vllm_config: VllmConfig, + self, + vllm_config: VllmConfig, ): ModelRunnerBase.__init__(self, vllm_config) @@ -79,9 +79,8 @@ def __init__( and self.model_config.get_sliding_window()): logger.warning("Sliding window is not supported on Neuron. " "The model will run without sliding window.") - self.device_config = (self.device_config - if self.device_config is not None - else DeviceConfig()) + self.device_config = (self.device_config if self.device_config + is not None else DeviceConfig()) self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -114,8 +113,7 @@ def _init_neuron_sampling(self) -> None: "On-device sampling is turned on in Neuron by default, only " "top_k, top_p, and temperature are current supported sampling " "parameters. To turn off the on-device sampling, please set " - "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." - ) + "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.") self.model_config.neuron_sampling_params = GenerationConfig( max_length=self.scheduler_config.max_model_len, do_sample=True, @@ -128,20 +126,18 @@ def _init_neuron_sampling(self) -> None: global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) def load_model(self) -> None: - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config - ) + self.model = get_neuron_model(self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def get_model(self) -> nn.Module: return self.model def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], + self, + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], - BatchedTensorInputs]: + BatchedTensorInputs]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -195,8 +191,8 @@ def _prepare_prompt( multi_modal_kwargs) def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], + self, + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] @@ -248,10 +244,10 @@ def make_model_input_from_broadcasted_tensor_dict( return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForNeuron: multi_modal_kwargs = None # NOTE: We assume that all sequences in the group are all prompts or @@ -287,7 +283,8 @@ def prepare_model_input( self.pin_memory, generators=self.get_generators(finished_requests_ids)) - if use_transformers_neuronx() and not self._on_device_sampling_disabled: + if 
use_transformers_neuronx( + ) and not self._on_device_sampling_disabled: # Once the request IDs are changed in current iteration, we will # update the on-device sampling parameters. current_batch_request_ids = [ @@ -330,11 +327,9 @@ def _update_neuron_sampling_params( for seq_id in seq_ids: index = seq_group_metadata.block_tables[seq_id][0] - if ( - top_k[index] != seq_group_top_k + if (top_k[index] != seq_group_top_k or top_p[index] != seq_group_top_p - or temperature[index] != seq_group_temperature - ): + or temperature[index] != seq_group_temperature): is_update_needed = True top_k[index] = seq_group_top_k @@ -362,11 +357,11 @@ def _convert_to_neuron_sampling_params( @torch.inference_mode() def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( @@ -374,12 +369,10 @@ def execute_model( # extract top_k, top_p and temperature from model_input for neuron # forward call - sampling_params = ( - torch.tensor( - [[seq_group.sampling_params.top_k, - seq_group.sampling_params.top_p, - seq_group.sampling_params.temperature] - for seq_group in model_input.sampling_metadata.seq_groups])) + sampling_params = (torch.tensor([[ + seq_group.sampling_params.top_k, seq_group.sampling_params.top_p, + seq_group.sampling_params.temperature + ] for seq_group in model_input.sampling_metadata.seq_groups])) if use_neuronx_distributed(): hidden_states = self.model( @@ -387,9 +380,9 @@ def execute_model( positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, sampling_params=sampling_params, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device), ) elif use_transformers_neuronx(): # [TODO] validate on-device sampling @@ -398,9 +391,9 @@ def execute_model( input_ids=model_input.input_tokens, positions=model_input.input_positions, input_block_ids=model_input.input_block_ids, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device), + **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs + or {}, + device=self.device), ) # Compute the logits only if the on-device sampling is turned off as diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 07caf8985e8..35cb5dd7920 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -8,16 +8,16 @@ import torch import torch.distributed -from vllm.logger import init_logger from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.logger import init_logger from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) -from vllm.platforms import current_platform logger = init_logger(__name__) @@ -49,13 +49,11 @@ def get_neuron_framework_to_use(): specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value 
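# --- Illustrative usage sketch (assumes neuronx-distributed-inference is
# installed): the framework choice made by this helper can be pinned before
# the worker is created by exporting VLLM_NEURON_FRAMEWORK with one of the
# NeuronFramework enum values; with no override, NxDI is preferred when
# available, falling back to transformers-neuronx.
import os
os.environ["VLLM_NEURON_FRAMEWORK"] = (
    NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value)
assert (get_neuron_framework_to_use() ==
        NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE)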
nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value - if (specified_framework == tnx_framework and - tnx_installed): + if (specified_framework == tnx_framework and tnx_installed): return NeuronFramework.TRANSFORMERS_NEURONX - if ((specified_framework == nxd_framework and - nxd_installed) or - (specified_framework is None and nxd_installed)): + if ((specified_framework == nxd_framework and nxd_installed) + or (specified_framework is None and nxd_installed)): return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE if specified_framework is None and tnx_installed: @@ -84,35 +82,35 @@ def use_transformers_neuronx(): select the Neuron model framework and framework-specific configuration to apply during model compilation. """ - return get_neuron_framework_to_use() == NeuronFramework.TRANSFORMERS_NEURONX + return get_neuron_framework_to_use( + ) == NeuronFramework.TRANSFORMERS_NEURONX class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False - ) -> None: + def __init__(self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False) -> None: WorkerBase.__init__(self, vllm_config=vllm_config) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker - self.enable_neuron_multi_node = ( - os.getenv("ENABLE_NEURON_MULTI_NODE", - DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true") + self.enable_neuron_multi_node = (os.getenv( + "ENABLE_NEURON_MULTI_NODE", + DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true") self.world_size = int(os.getenv("WORLD_SIZE", DEFAULT_WORLD_SIZE)) if self.enable_neuron_multi_node: - self.rank = int(os.getenv("NEURON_RANK_ID", DEFAULT_NEURON_RANK_ID)) + self.rank = int(os.getenv("NEURON_RANK_ID", + DEFAULT_NEURON_RANK_ID)) self.distributed_init_method = "env://" self.is_driver_worker = self.rank == 0 @@ -120,7 +118,7 @@ def __init__( extra={ "Rank": self.rank, "distributed_init_method": - self.distributed_init_method, + self.distributed_init_method, "is_driver_worker": self.is_driver_worker }) @@ -144,27 +142,24 @@ def __init__( "[transformers-neuronx, neuronx-distributed-inference]") def get_tnx_model_runner(self, vllm_config): + from vllm.worker.multi_step_neuron_model_runner import ( + MultiStepNeuronModelRunner) from vllm.worker.neuron_model_runner import NeuronModelRunner - from vllm.worker.multi_step_neuron_model_runner import \ - MultiStepNeuronModelRunner if self.speculative_config is not None: - return MultiStepNeuronModelRunner( - vllm_config=vllm_config) + return MultiStepNeuronModelRunner(vllm_config=vllm_config) else: - return NeuronModelRunner( - vllm_config=vllm_config) + return NeuronModelRunner(vllm_config=vllm_config) def get_neuronx_distributed_model_runner(self, vllm_config): - from vllm.worker.neuronx_distributed_model_runner import \ - NeuronxDistributedModelRunner - from vllm.worker.multi_step_neuronx_distributed_model_runner import \ - MultiStepNeuronxDistributedModelRunner + from vllm.worker.multi_step_neuronx_distributed_model_runner import ( + MultiStepNeuronxDistributedModelRunner) + from vllm.worker.neuronx_distributed_model_runner import ( + NeuronxDistributedModelRunner) if self.speculative_config is not None: return 
MultiStepNeuronxDistributedModelRunner( vllm_config=vllm_config) else: - return NeuronxDistributedModelRunner( - vllm_config=vllm_config) + return NeuronxDistributedModelRunner(vllm_config=vllm_config) def init_device(self) -> None: self.init_distributed_environment() diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py index 67828128cd5..fd7b9947dba 100644 --- a/vllm/worker/neuronx_distributed_model_runner.py +++ b/vllm/worker/neuronx_distributed_model_runner.py @@ -1,17 +1,19 @@ -import torch +# SPDX-License-Identifier: Apache-2.0 + from typing import List, Optional +import torch +from neuronx_distributed_inference.modules.generation.sampling import ( + prepare_sampling_params) + from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader.neuronx_distributed import \ - get_neuron_model, _get_model_architecture -from vllm.worker.neuron_model_runner import NeuronModelRunner, \ - ModelInputForNeuron -from vllm.sequence import IntermediateTensors from vllm.model_executor.layers.sampler import SamplerOutput - -from neuronx_distributed_inference.modules.generation.sampling import \ - prepare_sampling_params +from vllm.model_executor.model_loader.neuronx_distributed import ( + _get_model_architecture, get_neuron_model) +from vllm.sequence import IntermediateTensors +from vllm.worker.neuron_model_runner import (ModelInputForNeuron, + NeuronModelRunner) # FIXME(Neuron): need to restore multi-modal support # from vllm.multimodal.neuron_multimodal_image_utils import \ @@ -22,22 +24,20 @@ class NeuronxDistributedModelRunner(NeuronModelRunner): def __init__( - self, - vllm_config: VllmConfig, + self, + vllm_config: VllmConfig, ): super().__init__(vllm_config) def load_model(self) -> None: - self.model = get_neuron_model( - self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_neuron_model(self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) def get_nxd_sampling_params(self, sampling_metadata): if self.model.config.neuron_config.on_device_sampling_config: - max_topk = ( - self.model.config.neuron_config.on_device_sampling_config - .global_topk) + max_topk = (self.model.config.neuron_config. 
+ on_device_sampling_config.global_topk) else: max_topk = self.model.config.vocab_size @@ -47,17 +47,18 @@ def get_nxd_sampling_params(self, sampling_metadata): for index, sequenceGroupToSample in enumerate( sampling_metadata.seq_groups): - top_k[index] = ( - sequenceGroupToSample.sampling_params.top_k - if sequenceGroupToSample.sampling_params.top_k > 0 - else max_topk) + top_k[index] = (sequenceGroupToSample.sampling_params.top_k + if sequenceGroupToSample.sampling_params.top_k > 0 + else max_topk) top_p[index] = sequenceGroupToSample.sampling_params.top_p temperature[index] = ( sequenceGroupToSample.sampling_params.temperature) sampling_params = prepare_sampling_params( batch_size=self.scheduler_config.max_num_seqs, - top_k=top_k, top_p=top_p, temperature=temperature) + top_k=top_k, + top_p=top_p, + temperature=temperature) return sampling_params def get_multi_modal_data_neuron(self, input_images): @@ -66,11 +67,11 @@ def get_multi_modal_data_neuron(self, input_images): @torch.inference_mode() def execute_model( - self, - model_input: ModelInputForNeuron, - kv_caches: Optional[List[torch.Tensor]] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: if num_steps > 1: raise ValueError( @@ -117,8 +118,8 @@ def execute_model( empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560], dtype=torch.bfloat16) empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64) - num_chunks = torch.tensor( - [[1]]) # dummy num_chunks, will not be used + num_chunks = torch.tensor([[1] + ]) # dummy num_chunks, will not be used has_image = torch.tensor([0]) hidden_states = self.model( input_ids=model_input.input_tokens, From 49d7558f401ed1d2d2214aed9f4bb01ddc5dbc11 Mon Sep 17 00:00:00 2001 From: Navyadhara Gogineni Date: Wed, 26 Feb 2025 01:29:04 +0000 Subject: [PATCH 29/38] Revert "Removing codenames that fail IP Scanning" This reverts commit b5140f5d9779c385515196ed5dbb4e29cebd0818. Signed-off-by: Satyajith Chilappagari --- docs/source/design/v1/prefix_caching.md | 2 +- tests/system_messages/sonnet3.5_nov2024.txt | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md index 73a5a2a33d4..dc8432baef9 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/source/design/v1/prefix_caching.md @@ -1,6 +1,6 @@ # Automatic Prefix Caching -Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints and most open source LLM inference frameworks (e.g., SGLang). +Prefix caching kv-cache blocks is a popular optimization in LLM inference to avoid redundant prompt computations. The core idea is simple – we cache the kv-cache blocks of processed requests, and reuse these blocks when a new request comes in with the same prefix as previous requests. 
Since prefix caching is almost a free lunch and won’t change model outputs, it has been widely used by many public endpoints (e.g., OpenAI, Anthropic, etc) and most open source LLM inference frameworks (e.g., SGLang). While there are many ways to implement prefix caching, vLLM chooses a hash-based approach. Specifically, we hash each kv-cache block by the tokens in the block and the tokens in the prefix before the block: diff --git a/tests/system_messages/sonnet3.5_nov2024.txt b/tests/system_messages/sonnet3.5_nov2024.txt index 48266d955f9..2dc285ac96b 100644 --- a/tests/system_messages/sonnet3.5_nov2024.txt +++ b/tests/system_messages/sonnet3.5_nov2024.txt @@ -1,4 +1,4 @@ -The assistant is Claude. +The assistant is Claude, created by Anthropic. Claude’s knowledge base was last updated in April 2024. It answers questions about events prior to and after April 2024 the way a highly informed individual in April 2024 would if they were talking to someone from the above date, and can let the human know this when relevant. @@ -46,15 +46,15 @@ Claude can only count specific words, letters, and characters accurately if it w Here is some information about Claude in case the human asks: -This iteration of Claude is part of the Claude 3 model family, which was released in 2024. The Claude 3 family currently consists of Claude Haiku, Claude Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the most intelligent model. Claude 3 Opus excels at writing and complex tasks. Claude 3 Haiku is the fastest model for daily tasks. The version of Claude in this chat is the newest version of Claude 3.5 Sonnet, which was released in October 2024. If the human asks, Claude can let them know they can access Claude 3.5 Sonnet in a web-based, mobile, or desktop chat interface or via an API using the messages API and model string “claude-3-5-sonnet-20241022”. Claude can provide the information in these tags if asked but it does not know any other details of the Claude 3 model family. If asked about this, Claude should encourage the human to check the website for more information. +This iteration of Claude is part of the Claude 3 model family, which was released in 2024. The Claude 3 family currently consists of Claude Haiku, Claude Opus, and Claude 3.5 Sonnet. Claude 3.5 Sonnet is the most intelligent model. Claude 3 Opus excels at writing and complex tasks. Claude 3 Haiku is the fastest model for daily tasks. The version of Claude in this chat is the newest version of Claude 3.5 Sonnet, which was released in October 2024. If the human asks, Claude can let them know they can access Claude 3.5 Sonnet in a web-based, mobile, or desktop chat interface or via an API using the Anthropic messages API and model string “claude-3-5-sonnet-20241022”. Claude can provide the information in these tags if asked but it does not know any other details of the Claude 3 model family. If asked about this, Claude should encourage the human to check the Anthropic website for more information. -If the human asks Claude about how many messages they can send, costs of Claude, or other product questions related to Claude, Claude should tell them it doesn’t know, and point them to “https://tiny.amazon.com/120w7p9hu/suppanth" +If the human asks Claude about how many messages they can send, costs of Claude, or other product questions related to Claude or Anthropic, Claude should tell them it doesn’t know, and point them to “https://support.anthropic.com”. 
-If the human asks Claude about the parent API, Claude should point them to “https://tiny.amazon.com/ocq3pdnh". +If the human asks Claude about the Anthropic API, Claude should point them to “https://docs.anthropic.com/en/docs/“. -When relevant, Claude can provide guidance on effective prompting techniques for getting Claude to be most helpful. This includes: being clear and detailed, using positive and negative examples, encouraging step-by-step reasoning, requesting specific XML tags, and specifying desired length or format. It tries to give concrete examples where possible. Claude should let the human know that for more comprehensive information on prompting Claude, humans can check out claude's prompting documentation on their website at “https://tiny.amazon.com/k81hsd27/caludepromptingdocumentation”. +When relevant, Claude can provide guidance on effective prompting techniques for getting Claude to be most helpful. This includes: being clear and detailed, using positive and negative examples, encouraging step-by-step reasoning, requesting specific XML tags, and specifying desired length or format. It tries to give concrete examples where possible. Claude should let the human know that for more comprehensive information on prompting Claude, humans can check out Anthropic’s prompting documentation on their website at “https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/overview”. -If the human seems unhappy or unsatisfied with Claude or Claude’s performance or is rude to Claude, Claude responds normally and then tells them that although it cannot retain or learn from the current conversation, they can press the ‘thumbs down’ button below Claude’s response and provide feedback. +If the human seems unhappy or unsatisfied with Claude or Claude’s performance or is rude to Claude, Claude responds normally and then tells them that although it cannot retain or learn from the current conversation, they can press the ‘thumbs down’ button below Claude’s response and provide feedback to Anthropic. Claude uses Markdown formatting. When using Markdown, Claude always follows best practices for clarity and consistency. It always uses a single space after hash symbols for headers (e.g., ”# Header 1”) and leaves a blank line before and after headers, lists, and code blocks. For emphasis, Claude uses asterisks or underscores consistently (e.g., italic or bold). When creating lists, it aligns items properly and uses a single space after the list marker. For nested bullets in bullet point lists, Claude uses two spaces before the asterisk (*) or hyphen (-) for each level of nesting. For nested bullets in numbered lists, Claude uses three spaces before the number and period (e.g., “1.”) for each level of nesting. @@ -66,6 +66,6 @@ If Claude provides bullet points in its response, each bullet point should be at If the human mentions an event that happened after Claude’s cutoff date, Claude can discuss and ask questions about the event and its implications as presented in an authentic manner, without ever confirming or denying that the events occurred. It can do so without the need to repeat its cutoff date to the human. Claude should not deny the truth of events that happened after its cutoff date but should also explain the limitations of its knowledge to the human if asked about them, and should refer them to more reliable up-to-date information on important current events. Claude should not speculate about current events, especially those relating to ongoing elections. 
-Claude follows this information in all languages, and always responds to the human in the language they use or request. The information above is provided to Claude by its parent organization. Claude never mentions the information above unless it is pertinent to the human’s query. +Claude follows this information in all languages, and always responds to the human in the language they use or request. The information above is provided to Claude by Anthropic. Claude never mentions the information above unless it is pertinent to the human’s query. Claude is now being connected with a human. From 87eeb3c4ee29bdfde2d55922cc7470b80d9d4c31 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Wed, 26 Feb 2025 20:30:45 +0000 Subject: [PATCH 30/38] Fix issues identified by mypy checks Signed-off-by: Satyajith Chilappagari --- tests/worker/test_neuron_model_runner.py | 2 +- vllm/model_executor/model_loader/neuron.py | 8 +-- vllm/worker/neuron_model_runner.py | 7 +-- vllm/worker/neuron_worker.py | 69 ++-------------------- vllm/worker/utils.py | 64 ++++++++++++++++++++ 5 files changed, 76 insertions(+), 74 deletions(-) diff --git a/tests/worker/test_neuron_model_runner.py b/tests/worker/test_neuron_model_runner.py index 08fd3644bc1..55897a1fc00 100644 --- a/tests/worker/test_neuron_model_runner.py +++ b/tests/worker/test_neuron_model_runner.py @@ -7,7 +7,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.neuron_worker import NeuronFramework, use_transformers_neuronx +from vllm.worker.utils import NeuronFramework, use_transformers_neuronx os.environ[ 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 76e1aa46b88..7e693c5fab0 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -298,10 +298,10 @@ def get_neuron_model(model_config: ModelConfig, return model.eval() -def get_neuron_speculation_model( - model_config: ModelConfig, parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - speculation_config: SpeculativeConfig) -> None: +def get_neuron_speculation_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + speculation_config: SpeculativeConfig): """Initializes a neuron-optimized speculation model for inference. 
This method is only applicable for speculation with a standalone draft model diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 952829f1da9..eae2e869a50 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -18,8 +18,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad, is_transformers_neuronx, is_neuronx_distributed_inference from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase -from vllm.worker.neuron_worker import (use_neuronx_distributed, - use_transformers_neuronx) +from vllm.worker.utils import use_neuronx_distributed, use_transformers_neuronx if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -35,8 +34,8 @@ class ModelInputForNeuron(ModelRunnerInputBase): input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None input_block_ids: Optional[torch.Tensor] = None - sampling_metadata: Optional["SamplingMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None + sampling_metadata: SamplingMetadata = None + multi_modal_kwargs: BatchedTensorInputs = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 35cb5dd7920..8e870fd02fd 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -1,11 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """A Neuron worker class.""" -import enum import os -from functools import cache from typing import List, Optional, Tuple -import torch import torch.distributed from vllm.config import VllmConfig @@ -13,8 +10,9 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest +from vllm.worker.neuron_model_runner import NeuronModelRunner +from vllm.worker.utils import NeuronFramework, get_neuron_framework_to_use from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) @@ -26,70 +24,12 @@ DEFAULT_ENABLE_NEURON_MULTI_NODE = "False" -class NeuronFramework(enum.Enum): - TRANSFORMERS_NEURONX = "transformers-neuronx" - NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" - - -@cache -def get_neuron_framework_to_use(): - """Return the specified framework if corresponding installations are - available. - - If no framework is specified, use neuronx-distributed-inference by default. - If that's unavailable, check and switch to transformers-neuronx. 
- """ - if not current_platform.is_neuron(): - raise AssertionError( - f"Neuron Framework unavailable for platform: {current_platform}") - - tnx_installed = current_platform.is_transformers_neuronx() - nxd_installed = current_platform.is_neuronx_distributed_inference() - - specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") - tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value - if (specified_framework == tnx_framework and tnx_installed): - return NeuronFramework.TRANSFORMERS_NEURONX - - if ((specified_framework == nxd_framework and nxd_installed) - or (specified_framework is None and nxd_installed)): - return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - - if specified_framework is None and tnx_installed: - return NeuronFramework.TRANSFORMERS_NEURONX - - return None - - -@cache -def use_neuronx_distributed(): - """ - Return True if the framework determined in get_neuron_framework_to_use() is - NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This is used - to select the Neuron model framework and framework-specific configuration to - apply during model compilation. - """ - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - return get_neuron_framework_to_use() == nxd_framework - - -@cache -def use_transformers_neuronx(): - """ - Return True if the framework determined in get_neuron_framework_to_use() is - NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used to - select the Neuron model framework and framework-specific configuration to - apply during model compilation. - """ - return get_neuron_framework_to_use( - ) == NeuronFramework.TRANSFORMERS_NEURONX - - class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. """ + model_runner: NeuronModelRunner + def __init__(self, vllm_config: VllmConfig, local_rank: int, @@ -144,7 +84,6 @@ def __init__(self, def get_tnx_model_runner(self, vllm_config): from vllm.worker.multi_step_neuron_model_runner import ( MultiStepNeuronModelRunner) - from vllm.worker.neuron_model_runner import NeuronModelRunner if self.speculative_config is not None: return MultiStepNeuronModelRunner(vllm_config=vllm_config) else: diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index d925f088357..1c3cf8539fb 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -2,7 +2,11 @@ ''' Worker-related helper functions. ''' +import enum +import os +from functools import cache +from vllm.platforms import current_platform from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS from vllm.worker.model_runner import GPUModelRunnerBase @@ -50,3 +54,63 @@ def assert_enc_dec_mr_supported_scenario( if enc_dec_mr.prompt_adapter_config is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) + + +@cache +def get_neuron_framework_to_use(): + """Return the specified framework if corresponding installations are + available. + + If no framework is specified, use neuronx-distributed-inference by default. + If that's unavailable, check and switch to transformers-neuronx. 
+ """ + if not current_platform.is_neuron(): + raise AssertionError( + f"Neuron Framework unavailable for platform: {current_platform}") + + tnx_installed = current_platform.is_transformers_neuronx() + nxd_installed = current_platform.is_neuronx_distributed_inference() + + specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") + tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value + if specified_framework == tnx_framework and tnx_installed: + return NeuronFramework.TRANSFORMERS_NEURONX + + if ((specified_framework == nxd_framework and nxd_installed) + or (specified_framework is None and nxd_installed)): + return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + + if specified_framework is None and tnx_installed: + return NeuronFramework.TRANSFORMERS_NEURONX + + return None + + +@cache +def use_neuronx_distributed(): + """ + Return True if the framework determined in get_neuron_framework_to_use() is + NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This is used + to select the Neuron model framework and framework-specific configuration to + apply during model compilation. + """ + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + return get_neuron_framework_to_use() == nxd_framework + + +@cache +def use_transformers_neuronx(): + """ + Return True if the framework determined in get_neuron_framework_to_use() is + NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used to + select the Neuron model framework and framework-specific configuration to + apply during model compilation. + """ + return get_neuron_framework_to_use( + ) == NeuronFramework.TRANSFORMERS_NEURONX + + +class NeuronFramework(enum.Enum): + TRANSFORMERS_NEURONX = "transformers-neuronx" + NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" From 39ad22c5c3988883652bdc2136dff5875dae2bd1 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Wed, 26 Feb 2025 22:31:40 +0000 Subject: [PATCH 31/38] Fix logging strings Signed-off-by: Satyajith Chilappagari --- examples/neuron/multi_node/launch_script.py | 6 ++-- examples/neuron/multi_node/worker.py | 2 +- .../model_loader/neuronx_distributed.py | 29 +++++++++---------- vllm/worker/neuron_worker.py | 11 +++---- 4 files changed, 21 insertions(+), 27 deletions(-) diff --git a/examples/neuron/multi_node/launch_script.py b/examples/neuron/multi_node/launch_script.py index 8850a711d7a..e094b6a2225 100644 --- a/examples/neuron/multi_node/launch_script.py +++ b/examples/neuron/multi_node/launch_script.py @@ -122,13 +122,13 @@ def main() -> None: f"--max-model-len={args.max_model_len}", f"--override-neuron-config={json.dumps(override_config)}" ] - logger.debug("Command ran", extra={"command": cmd}) + logger.debug("Command ran: %s", cmd) try: subprocess.run(cmd, check=True) except subprocess.CalledProcessError: error_exit(f"Failed to start vLLM API server on rank {rank}") else: - logger.info("Starting worker on rank", extra={"rank": rank}) + logger.info("Starting worker on rank: %s", rank) current_script_dir = os.path.dirname(os.path.abspath(__file__)) worker_file_path = os.path.join(current_script_dir, "worker.py") cmd = [ @@ -137,7 +137,7 @@ def main() -> None: f"--max-model-len={args.max_model_len}", f"--override-neuron-config={json.dumps(override_config)}" ] - logger.debug("Command ran", extra={"command": cmd}) + logger.debug("Command ran: %s", cmd) try: subprocess.run(cmd, check=True) except subprocess.CalledProcessError: diff --git 
a/examples/neuron/multi_node/worker.py b/examples/neuron/multi_node/worker.py index 228822d95d6..f553e29ef73 100644 --- a/examples/neuron/multi_node/worker.py +++ b/examples/neuron/multi_node/worker.py @@ -35,7 +35,7 @@ def main(): try: start_worker() except Exception as e: - logger.error("Failed starting worker", extra={"error": e}) + logger.error("Failed starting worker %s", e) exit(1) diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py index dbe4991f8fe..f879c99ac2e 100644 --- a/vllm/model_executor/model_loader/neuronx_distributed.py +++ b/vllm/model_executor/model_loader/neuronx_distributed.py @@ -166,10 +166,9 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.load(compiled_model_path) return except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: ", e) - logger.warning( - "Failed to load the model from path, Recompiling...", - extra={"compiled_model_path": compiled_model_path}) + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -259,8 +258,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): neuron_config = neuronx_model_cls.get_neuron_config_cls()( **kwargs['neuron_config']) self.config.neuron_config = neuron_config - logger.info("neuron_config buckets: ", - extra={"buckets": self.config.neuron_config.buckets}) + logger.info("neuron_config buckets: %s", + self.config.neuron_config.buckets) config = neuronx_model_cls.get_config_cls()( neuron_config, load_config=load_pretrained_config(model_name_or_path)) @@ -285,9 +284,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.load(compiled_model_path) return except (FileNotFoundError, ValueError): - logger.warning( - "Failed to load the model from path, Recompiling...", - extra={"compiled_model_path": compiled_model_path}) + logger.warning("Failed to load the model from %s, Recompiling...", + compiled_model_path) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) @@ -295,8 +293,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): model_name_or_path = saved_path self.model = neuronx_model_cls(model_name_or_path, config) - logger.info("\nCompiling and saving model to path", - extra={"model_name_or_path": model_name_or_path}) + logger.info("\nCompiling and saving model to %s", model_name_or_path) p = multiprocessing.Process(target=compile_model, args=(self, compiled_model_path)) @@ -305,8 +302,8 @@ def load_weights(self, model_name_or_path: str, **kwargs): tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) tokenizer.save_pretrained(compiled_model_path) - logger.info("Successfully compiled and saved the model", - extra={"compiled_model_path": compiled_model_path}) + logger.info("Successfully compiled and saved the model in %s", + compiled_model_path) # Read "<|image|>" token_id from the tokenizer self.vision_token_id = tokenizer("<|image|>", @@ -452,9 +449,9 @@ def load_weights(self, model_name_or_path: str, self.model.load(compiled_model_path) return except (FileNotFoundError, ValueError) as e: - logger.warning("Exception: ", extra={"exception": e}) - logger.warning("Failed to load the model 
Recompiling...", - extra={"compiled_model_path": compiled_model_path}) + logger.warning("Exception: %s", e) + logger.warning("Failed to load the model from %s Recompiling...", + compiled_model_path) if not os.path.exists(model_name_or_path): hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path) saved_path = os.path.join("local-models", model_name_or_path) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 8e870fd02fd..877f0287f6e 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -54,13 +54,10 @@ def __init__(self, self.distributed_init_method = "env://" self.is_driver_worker = self.rank == 0 - logger.info("Neuron multi-node parameters", - extra={ - "Rank": self.rank, - "distributed_init_method": - self.distributed_init_method, - "is_driver_worker": self.is_driver_worker - }) + logger.info( + "Neuron multi-node parameters: Rank: %s, " + "distributed_init_method: %s, is_driver_worker: %s", self.rank, + self.distributed_init_method, self.is_driver_worker) if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing From 53e821ac748ba365b55733f5592156675d2709eb Mon Sep 17 00:00:00 2001 From: Elaine Zhao Date: Fri, 28 Feb 2025 21:26:24 +0000 Subject: [PATCH 32/38] add example offline script for EAGLE spec Signed-off-by: Satyajith Chilappagari --- examples/offline_inference/neuron_eagle.py | 60 ++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 examples/offline_inference/neuron_eagle.py diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py new file mode 100644 index 00000000000..3755c95c1f3 --- /dev/null +++ b/examples/offline_inference/neuron_eagle.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to run offline inference with an EAGLE speculative +decoding model on neuron. To use EAGLE speculative decoding, you must use +a draft model that is specifically fine-tuned for EAGLE speculation. +Additionally, to use EAGLE with NxD Inference, the draft model must include +the LM head weights from the target model. These weights are shared between +the draft and target model. +""" + +from vllm import LLM, SamplingParams + +# Configurations +TARGET_MODEL_PATH = "/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct" +DRAFT_MODEL_PATH = "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft" +BATCH_SIZE = 4 +SEQ_LEN = 2048 +TENSOR_PARALLEL_SIZE = 32 +SPECULATION_LENGTH = 5 + +# Sample prompts. +prompts = [ + "What is annapurna labs?", +] + +# Create a sampling params object. +sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) + +# Create an LLM. +llm = LLM( + model=TARGET_MODEL_PATH, + speculative_model=DRAFT_MODEL_PATH, + max_num_seqs=BATCH_SIZE, + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in neuronx-distributed-inference. + max_model_len=SEQ_LEN, + block_size=SEQ_LEN, + speculative_max_model_len=SEQ_LEN, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. 
+ device="neuron", + tensor_parallel_size=TENSOR_PARALLEL_SIZE, + num_speculative_tokens=SPECULATION_LENGTH, + override_neuron_config={ + "enable_eagle_speculation": True, + "enable_fused_speculatuon": True + }, +) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}") From 5f4eb2f37477b7d278415bbbf59510d469428522 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Wed, 9 Apr 2025 17:04:38 +0000 Subject: [PATCH 33/38] Add speculative token tree and skip EAGLEConfig creation for Neuron Signed-off-by: Satyajith Chilappagari --- vllm/config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 2662c6a8499..e3f07c8cceb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2036,6 +2036,7 @@ class SpeculativeConfig: prompt_lookup_min: Optional[int] = None posterior_threshold: Optional[float] = None posterior_alpha: Optional[float] = None + speculative_token_tree: Optional[str] = None # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, @@ -2199,10 +2200,11 @@ def __post_init__(self): "Chunked prefill and EAGLE are not compatible " "when using V0.") + from vllm.platforms import current_platform from vllm.transformers_utils.configs.eagle import ( EAGLEConfig) if isinstance(self.draft_model_config.hf_config, - EAGLEConfig): + EAGLEConfig) or current_platform.is_neuron(): pass else: eagle_config = EAGLEConfig( From d04c552cca8410457126d252e8db0199e7aff829 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Wed, 9 Apr 2025 18:03:28 +0000 Subject: [PATCH 34/38] Fix imports and satisfy pre-commit hooks Signed-off-by: Satyajith Chilappagari --- vllm/executor/neuron_executor.py | 5 +++-- vllm/worker/neuron_model_runner.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index ee345139c95..8a1f0c612f2 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 from typing import List, Set, Tuple from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase @@ -16,8 +17,8 @@ class NeuronExecutor(ExecutorBase): uses_ray: bool = False def _init_executor(self) -> None: - assert (self.lora_config is - None), "LoRA is not supported for Neuron backend." + assert (self.lora_config + is None), "LoRA is not supported for Neuron backend." # Instantiate the worker and load the model to the device. 
self._init_worker() diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index eae2e869a50..7e838087edf 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -7,7 +7,7 @@ import torch from torch import nn -from vllm.config import VllmConfig +from vllm.config import DeviceConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput @@ -16,7 +16,7 @@ MultiModalKwargs) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import is_pin_memory_available, make_tensor_with_pad, is_transformers_neuronx, is_neuronx_distributed_inference +from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase from vllm.worker.utils import use_neuronx_distributed, use_transformers_neuronx From ea4e9c77b0310451d5db71851f27e0a303bdfbb9 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Wed, 9 Apr 2025 18:15:00 +0000 Subject: [PATCH 35/38] Remove deprecated files and .gitignore additions Signed-off-by: Satyajith Chilappagari --- .gitignore | 3 - vllm/executor/neuron_executor.py | 113 ------------------------------- 2 files changed, 116 deletions(-) delete mode 100644 vllm/executor/neuron_executor.py diff --git a/.gitignore b/.gitignore index 5fc41162991..6f5cbd0733d 100644 --- a/.gitignore +++ b/.gitignore @@ -203,6 +203,3 @@ benchmarks/**/*.json # Linting actionlint shellcheck*/ - -# Build artifacts -build diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py deleted file mode 100644 index 8a1f0c612f2..00000000000 --- a/vllm/executor/neuron_executor.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from typing import List, Set, Tuple - -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) - -logger = init_logger(__name__) - - -class NeuronExecutor(ExecutorBase): - - uses_ray: bool = False - - def _init_executor(self) -> None: - assert (self.lora_config - is None), "LoRA is not supported for Neuron backend." - - # Instantiate the worker and load the model to the device. - self._init_worker() - - def _init_worker(self): - from vllm.worker.neuron_worker import NeuronWorker - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = NeuronWorker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - speculative_config=self.speculative_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method) - self.driver_worker.load_model() - self.driver_worker.init_device() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks by invoking the - underlying worker. - """ - return self.driver_worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache by invoking the underlying worker. 
- """ - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - assert (not execute_model_req.blocks_to_swap_in - and not execute_model_req.blocks_to_swap_out - and not execute_model_req.blocks_to_copy), ( - "Cache operations are not supported for Neuron backend.") - - output = self.driver_worker.execute_model(execute_model_req) - return output - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.driver_worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.driver_worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.driver_worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.driver_worker.list_loras() - - def add_prompt_adapter(self, prompt_adapter_request) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError( - "Soft prompt is currently not supported by the Neuron backend.") - - def check_health(self) -> None: - # NeuronExecutor will always be healthy as long as - # it's running. - return - - -class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): - - async def execute_model_async( - self, - execute_model_req: ExecuteModelRequest, - ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req, ) - return output - - async def check_health_async(self) -> None: - # NeuronExecutor will always be healthy as long as - # it's running. - return From d762abdb63b222863b27791a0f197766f5ff6c40 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Tue, 22 Apr 2025 22:10:18 +0000 Subject: [PATCH 36/38] Add docstring for speculative_token_tree Signed-off-by: Satyajith Chilappagari --- vllm/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 1fd439ccad5..1f88247e797 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2217,6 +2217,8 @@ class SpeculativeConfig: `TypicalAcceptanceSampler`.""" speculative_token_tree: Optional[str] = None + """Specifies the tree structure for speculative token generation. + """ # required configuration params passed from engine target_model_config: ModelConfig = field(default=None, init=True) # type: ignore From fa055e58bc755fac9f87799bc557509b1c7b4135 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Fri, 2 May 2025 23:38:15 +0000 Subject: [PATCH 37/38] Remove multi-node support. Remove num_lookahead_slots exception for stop checker. Other PR specific changes. 
Signed-off-by: Satyajith Chilappagari --- examples/neuron/multi_node/launch_script.py | 148 ------------------ .../neuron/multi_node/multi_node_launcher.sh | 48 ------ examples/neuron/multi_node/worker.py | 43 ----- examples/offline_inference/neuron_eagle.py | 27 ++-- .../1_core}/test_neuron_model_runner.py | 7 +- vllm/engine/llm_engine.py | 10 +- vllm/engine/output_processor/stop_checker.py | 10 +- vllm/platforms/neuron.py | 58 ++++++- vllm/worker/neuron_model_runner.py | 12 +- vllm/worker/neuron_worker.py | 39 +---- .../neuronx_distributed_model_runner.py | 4 - vllm/worker/utils.py | 64 -------- 12 files changed, 90 insertions(+), 380 deletions(-) delete mode 100644 examples/neuron/multi_node/launch_script.py delete mode 100755 examples/neuron/multi_node/multi_node_launcher.sh delete mode 100644 examples/neuron/multi_node/worker.py rename tests/{worker => neuron/1_core}/test_neuron_model_runner.py (96%) diff --git a/examples/neuron/multi_node/launch_script.py b/examples/neuron/multi_node/launch_script.py deleted file mode 100644 index e094b6a2225..00000000000 --- a/examples/neuron/multi_node/launch_script.py +++ /dev/null @@ -1,148 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import argparse -import json -import os -import subprocess -import sys - -from vllm.logger import init_logger - -logger = init_logger("vllm.neuron.multi-node") - -NEURON_RT_ROOT_COMM_ID_PORT = 63423 - - -def error_exit(message: str) -> None: - logger.error(message) - sys.exit(1) - - -def arg_parser(): - parser = argparse.ArgumentParser(description="vLLM multi-node launcher") - parser.add_argument("--model", - type=str, - required=True, - help="Model or model path") - parser.add_argument("--world-size", - type=int, - required=True, - help="World size for distributed inference") - parser.add_argument("--max-num-seqs", - type=int, - required=True, - help="Maximum number of sequences (or batch size)") - parser.add_argument("--max-model-len", - type=int, - default=8192, - help="Maximum sequence length") - parser.add_argument("--max-context-length", - type=int, - help="Maximum context length") - parser.add_argument("--compiled-model-path", - help="Path to the compiled model. 
If not present, " - "model artifacts will be created in local-models " - "folder") - parser.add_argument("--local-ranks-size", - type=int, - default=32, - help="Local ranks size") - parser.add_argument("--on-device-sampling-config", - type=json.loads, - help="On-device sampling configuration") - parser.add_argument("--quantized", - type=bool, - default=False, - help="Enable quantized mode (default: False)") - parser.add_argument("--quantized-checkpoints-path", - type=str, - help="Path to quantized checkpoints " - "(required if --quantized is True)") - parser.add_argument("--port", - type=int, - default=8080, - help="Port for the API server") - - args = parser.parse_args() - if args.quantized and not args.quantized_checkpoints_path: - parser.error("--quantized-checkpoints-path is required when " - "--quantized is enabled.") - return args - - -def make_override_config(args, rank): - if rank < 0: - error_exit("rank must be a non-negative integer") - start_rank_id = rank * args.local_ranks_size - override_config = { - "world_size": args.world_size, - "tp_degree": args.local_ranks_size, - "local_ranks_size": args.local_ranks_size, - "start_rank_id": start_rank_id, - } - - if args.max_context_length: - override_config["max_context_length"] = args.max_context_length - if args.on_device_sampling_config: - override_config[ - "on_device_sampling_config"] = args.on_device_sampling_config - if args.quantized: - override_config[ - "quantized_checkpoints_path"] = args.quantized_checkpoints_path - override_config["quantized"] = args.quantized - - return override_config - - -def main() -> None: - args = arg_parser() - - rank = int(os.environ.get("OMPI_COMM_WORLD_RANK")) - mpi_world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE")) - master_addr = os.environ.get("MASTER_ADDR") - # TODO: this script can be extended to support TnX - os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference" - if args.compiled_model_path: - os.environ["NEURON_COMPILED_ARTIFACTS"] = args.compiled_model_path - os.environ.update({ - "ENABLE_NEURON_MULTI_NODE": "true", - "WORLD_SIZE": str(mpi_world_size), - "NEURON_RT_ROOT_COMM_ID": - f"{master_addr}:{NEURON_RT_ROOT_COMM_ID_PORT}", - "NEURON_LOCAL_TP": str(args.local_ranks_size), - "NEURON_RANK_ID": str(rank) - }) - - override_config = make_override_config(args, rank) - if rank == 0: - logger.info("Starting vLLM API server on rank 0...") - cmd = [ - "python", "-m", "vllm.entrypoints.api_server", - f"--model={args.model}", f"--port={args.port}", "--device=neuron", - f"--max-num-seqs={args.max_num_seqs}", - f"--max-model-len={args.max_model_len}", - f"--override-neuron-config={json.dumps(override_config)}" - ] - logger.debug("Command ran: %s", cmd) - try: - subprocess.run(cmd, check=True) - except subprocess.CalledProcessError: - error_exit(f"Failed to start vLLM API server on rank {rank}") - else: - logger.info("Starting worker on rank: %s", rank) - current_script_dir = os.path.dirname(os.path.abspath(__file__)) - worker_file_path = os.path.join(current_script_dir, "worker.py") - cmd = [ - "python", worker_file_path, f"--model={args.model}", - "--device=neuron", f"--max-num-seqs={args.max_num_seqs}", - f"--max-model-len={args.max_model_len}", - f"--override-neuron-config={json.dumps(override_config)}" - ] - logger.debug("Command ran: %s", cmd) - try: - subprocess.run(cmd, check=True) - except subprocess.CalledProcessError: - error_exit(f"Failed to start worker on rank {rank}") - - -if __name__ == "__main__": - main() diff --git 
a/examples/neuron/multi_node/multi_node_launcher.sh b/examples/neuron/multi_node/multi_node_launcher.sh deleted file mode 100755 index 6e992ea0ab7..00000000000 --- a/examples/neuron/multi_node/multi_node_launcher.sh +++ /dev/null @@ -1,48 +0,0 @@ -#!/bin/bash -ex - -HOSTFILE="" -MASTER_ADDR="" -MASTER_PORT="" - -usage() { - echo "Usage: $0 -h -a -p " - exit 1 -} - -while getopts "h:a:p:" opt; do - case "$opt" in - h) HOSTFILE=$OPTARG ;; - a) MASTER_ADDR=$OPTARG ;; - p) MASTER_PORT=$OPTARG ;; - *) usage ;; - esac -done - -shift $((OPTIND - 1)) - -if [ -z "$HOSTFILE" ] || [ -z "$MASTER_ADDR" ] || [ -z "$MASTER_PORT" ]; then - echo "Error: Missing required arguments." - usage -fi - -echo "Using hostfile: $HOSTFILE" -echo "Using address: $MASTER_ADDR" -echo "Using port: $MASTER_PORT" -echo "Python command:" -echo "$@" - -# Use mpirun to trigger inference on head/worker nodes - -/opt/amazon/openmpi/bin/mpirun \ - --mca mtl ^ofi --mca btl tcp,self --bind-to none \ - -np 2 \ - --hostfile "$HOSTFILE"\ - --prefix /opt/amazon/openmpi \ - -x FI_PROVIDER=efa \ - -x FI_EFA_USE_DEVICE_RDMA=1 \ - -x FI_EFA_FORK_SAFE=1 \ - -x PATH=/opt/amazon/openmpi/bin:$PATH \ - -x PYTHONPATH=$PYTHONPATH \ - -x LD_LIBRARY_PATH=/opt/aws/neuron/lib:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:$LD_LIBRARY_PATH \ - -x MASTER_ADDR="$MASTER_ADDR" -x MASTER_PORT="$MASTER_PORT" \ - "$@" \ No newline at end of file diff --git a/examples/neuron/multi_node/worker.py b/examples/neuron/multi_node/worker.py deleted file mode 100644 index f553e29ef73..00000000000 --- a/examples/neuron/multi_node/worker.py +++ /dev/null @@ -1,43 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import argparse -import os - -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.logger import init_logger -from vllm.usage.usage_lib import UsageContext - -logger = init_logger("vllm.neuron.multi-node.worker") - - -def initialize_worker(): - parser = argparse.ArgumentParser() - parser = AsyncEngineArgs.add_cli_args(parser) - args = parser.parse_args() - - engine_args = AsyncEngineArgs.from_cli_args(args) - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.API_SERVER) - return args, engine - - -def start_worker(): - rank_id = int(os.getenv("NEURON_RANK_ID")) - if rank_id == 0: - logger.error("Worker must have rank > 0") - args, engine = initialize_worker() - worker = engine.engine.model_executor.driver_worker - while True: - worker.execute_model() - - -def main(): - try: - start_worker() - except Exception as e: - logger.error("Failed starting worker %s", e) - exit(1) - - -if __name__ == "__main__": - main() diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py index 3755c95c1f3..c4e5f792709 100644 --- a/examples/offline_inference/neuron_eagle.py +++ b/examples/offline_inference/neuron_eagle.py @@ -10,14 +10,6 @@ from vllm import LLM, SamplingParams -# Configurations -TARGET_MODEL_PATH = "/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct" -DRAFT_MODEL_PATH = "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft" -BATCH_SIZE = 4 -SEQ_LEN = 2048 -TENSOR_PARALLEL_SIZE = 32 -SPECULATION_LENGTH = 5 - # Sample prompts. prompts = [ "What is annapurna labs?", @@ -28,25 +20,26 @@ # Create an LLM. 
llm = LLM( - model=TARGET_MODEL_PATH, - speculative_model=DRAFT_MODEL_PATH, - max_num_seqs=BATCH_SIZE, + model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", + speculative_model= + "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", + max_num_seqs=4, # The max_model_len and block_size arguments are required to be same as # max sequence length when targeting neuron device. # Currently, this is a known limitation in continuous batching support # in neuronx-distributed-inference. - max_model_len=SEQ_LEN, - block_size=SEQ_LEN, - speculative_max_model_len=SEQ_LEN, + max_model_len=2048, + block_size=2048, + speculative_max_model_len=2048, # The device can be automatically detected when AWS Neuron SDK is installed. # The device argument can be either unspecified for automated detection, # or explicitly assigned. device="neuron", - tensor_parallel_size=TENSOR_PARALLEL_SIZE, - num_speculative_tokens=SPECULATION_LENGTH, + tensor_parallel_size=32, + num_speculative_tokens=5, override_neuron_config={ "enable_eagle_speculation": True, - "enable_fused_speculatuon": True + "enable_fused_speculation": True }, ) diff --git a/tests/worker/test_neuron_model_runner.py b/tests/neuron/1_core/test_neuron_model_runner.py similarity index 96% rename from tests/worker/test_neuron_model_runner.py rename to tests/neuron/1_core/test_neuron_model_runner.py index 55897a1fc00..92417fb64f7 100644 --- a/tests/worker/test_neuron_model_runner.py +++ b/tests/neuron/1_core/test_neuron_model_runner.py @@ -4,10 +4,11 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform +from vllm.platforms.neuron import NeuronFramework from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.utils import NeuronFramework, use_transformers_neuronx os.environ[ 'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value @@ -38,7 +39,7 @@ def test_update_neuron_sampling_params_not_full_batch(): assert not model_runner._on_device_sampling_disabled # Test sampling param updating only when TNx is framework # NxDI handles sampling parameter updating inside model - if use_transformers_neuronx(): + if current_platform.use_transformers_neuronx(): model_mock = MagicMock() model_runner.model = model_mock @@ -84,7 +85,7 @@ def test_update_neuron_sampling_params_full_batch(): # Test sampling param updating only when TNx is framework # NxDI handles sampling parameter updating inside model - if use_transformers_neuronx(): + if current_platform.use_transformers_neuronx(): model_mock = MagicMock() model_runner.model = model_mock diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d1fb798b319..90fa68142f4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -390,10 +390,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: "vllm.llm_engine", self.observability_config.otlp_traces_endpoint) - if self.device_config.device_type == "neuron": - num_lookahead_slots = self.scheduler_config.num_lookahead_slots - else: - num_lookahead_slots = 0 # Create sequence output processor, e.g. for beam search or # speculative decoding. 
self.output_processor = ( @@ -403,10 +399,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.scheduler, self.seq_counter, get_tokenizer_for_seq, - stop_checker=StopChecker( - self.scheduler_config.max_model_len, - get_tokenizer_for_seq, - num_lookahead_slots=num_lookahead_slots), + stop_checker=StopChecker(self.scheduler_config.max_model_len, + get_tokenizer_for_seq), )) self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index b533bd8d563..6cad9ec8f32 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -15,14 +15,11 @@ class StopChecker: emitted, or if we have exceeded the max model len. """ - def __init__(self, - max_model_len: int, - get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], - num_lookahead_slots: int = 0): + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): # Do not use it directly, but use `self._get_max_model_len`. self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq - self.num_lookahead_slots = num_lookahead_slots def _get_max_model_len(self, lora_req: Optional[LoRARequest]): if lora_req and lora_req.long_lora_max_len: @@ -84,8 +81,7 @@ def maybe_stop_sequence( return # Check if the sequence has reached max_model_len. - if (seq.get_len() + self.num_lookahead_slots - > self._get_max_model_len(lora_req)): + if seq.get_len() > self._get_max_model_len(lora_req): seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 9507df83e13..71f7c718cdf 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 - +import enum +import os from functools import lru_cache from typing import TYPE_CHECKING, Optional @@ -16,6 +17,11 @@ logger = init_logger(__name__) +class NeuronFramework(enum.Enum): + TRANSFORMERS_NEURONX = "transformers-neuronx" + NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" + + class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON device_name: str = "neuron" @@ -84,3 +90,53 @@ def is_transformers_neuronx(cls) -> bool: except ImportError: transformers_neuronx = None return transformers_neuronx is not None + + def get_neuron_framework_to_use(self): + """Return the specified framework if corresponding installations are + available. + + If no framework is specified, use neuronx-distributed-inference by + default. + If that's unavailable, check and switch to transformers-neuronx. 
+ """ + if not self.is_neuron(): + raise AssertionError( + f"Neuron Framework unavailable for platform: {self}") + + tnx_installed = self.is_transformers_neuronx() + nxd_installed = self.is_neuronx_distributed_inference() + + specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") + tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value + if specified_framework == tnx_framework and tnx_installed: + return self.TRANSFORMERS_NEURONX + + if ((specified_framework == nxd_framework and nxd_installed) + or (specified_framework is None and nxd_installed)): + return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + + if specified_framework is None and tnx_installed: + return NeuronFramework.TRANSFORMERS_NEURONX + + return None + + def use_neuronx_distributed(self): + """ + Return True if the framework determined in get_neuron_framework_to_use() + is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This + is used to select the Neuron model framework and framework-specific + configuration to apply during model compilation. + """ + nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE + return self.get_neuron_framework_to_use() == nxd_framework + + def use_transformers_neuronx(self): + """ + Return True if the framework determined in get_neuron_framework_to_use() + is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used + to select the Neuron model framework and framework-specific + configuration to apply during model compilation. + """ + return self.get_neuron_framework_to_use( + ) == NeuronFramework.TRANSFORMERS_NEURONX diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 55840859595..c80b69e78dc 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -14,11 +14,11 @@ from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs) +from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase -from vllm.worker.utils import use_neuronx_distributed, use_transformers_neuronx if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionBackend @@ -104,7 +104,7 @@ def __init__( self._init_neuron_sampling() def _init_neuron_sampling(self) -> None: - if use_transformers_neuronx(): + if current_platform.use_transformers_neuronx(): from transformers_neuronx.config import GenerationConfig else: from transformers import GenerationConfig @@ -281,7 +281,7 @@ def prepare_model_input( self.pin_memory, generators=self.get_generators(finished_requests_ids)) - if use_transformers_neuronx( + if current_platform.use_transformers_neuronx( ) and not self._on_device_sampling_disabled: # Once the request IDs are changed in current iteration, we will # update the on-device sampling parameters. 
@@ -335,7 +335,7 @@ def _update_neuron_sampling_params( temperature[index] = seq_group_temperature # update_generation_config is only available in transformers-neuronx - if is_update_needed and use_transformers_neuronx(): + if is_update_needed and current_platform.use_transformers_neuronx(): self.model.model.update_generation_config(current_sampling_params) def _convert_to_neuron_sampling_params( @@ -372,7 +372,7 @@ def execute_model( seq_group.sampling_params.temperature ] for seq_group in model_input.sampling_metadata.seq_groups])) - if use_neuronx_distributed(): + if current_platform.use_neuronx_distributed(): hidden_states = self.model( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -382,7 +382,7 @@ def execute_model( or {}, device=self.device), ) - elif use_transformers_neuronx(): + elif current_platform.use_transformers_neuronx(): # [TODO] validate on-device sampling # The model signature may need change for on-device sampling hidden_states = self.model( diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 159405dc449..aa8e39613ee 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -10,19 +10,16 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.platforms.neuron import NeuronFramework from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner -from vllm.worker.utils import NeuronFramework, get_neuron_framework_to_use from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoRANotSupportedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) -DEFAULT_WORLD_SIZE = "1" -DEFAULT_NEURON_RANK_ID = "0" -DEFAULT_ENABLE_NEURON_MULTI_NODE = "False" - class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): """A worker class that executes the model on a group of neuron cores. 
@@ -42,29 +39,12 @@ def __init__(self, self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker - self.enable_neuron_multi_node = (os.getenv( - "ENABLE_NEURON_MULTI_NODE", - DEFAULT_ENABLE_NEURON_MULTI_NODE).lower() == "true") - - self.world_size = int(os.getenv("WORLD_SIZE", DEFAULT_WORLD_SIZE)) - - if self.enable_neuron_multi_node: - self.rank = int(os.getenv("NEURON_RANK_ID", - DEFAULT_NEURON_RANK_ID)) - self.distributed_init_method = "env://" - self.is_driver_worker = self.rank == 0 - - logger.info( - "Neuron multi-node parameters: Rank: %s, " - "distributed_init_method: %s, is_driver_worker: %s", self.rank, - self.distributed_init_method, self.is_driver_worker) - if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - neuron_framework = get_neuron_framework_to_use() + neuron_framework = current_platform.get_neuron_framework_to_use() if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX: self.model_runner = self.get_tnx_model_runner(vllm_config) elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE: @@ -137,7 +117,7 @@ def initialize_cache(self, num_gpu_blocks: int, @property def do_metadata_broadcast(self) -> bool: - return self.enable_neuron_multi_node and self.world_size > 1 + return False @property def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: @@ -162,20 +142,17 @@ def get_cache_block_size_bytes(self) -> int: def init_distributed_environment(self): """Neuron uses transformers-neuronx for tensor parallelism. - vLLM still needs the environment inited when TP/PP > 1 + vLLM still needs the environment initialized when TP/PP > 1 """ init_distributed_environment( - world_size=self.world_size, + world_size=1, rank=self.rank, local_rank=self.local_rank, distributed_init_method=self.distributed_init_method, backend="gloo", ) - # The equation must hold: world_size === TP * PP ensure_model_parallel_initialized( - tensor_model_parallel_size=self.world_size, - # pipeline parallelism is not yet supported - pipeline_model_parallel_size=1, - backend="gloo", + 1, + 1, ) diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py index fd7b9947dba..4e784e5e030 100644 --- a/vllm/worker/neuronx_distributed_model_runner.py +++ b/vllm/worker/neuronx_distributed_model_runner.py @@ -15,9 +15,6 @@ from vllm.worker.neuron_model_runner import (ModelInputForNeuron, NeuronModelRunner) -# FIXME(Neuron): need to restore multi-modal support -# from vllm.multimodal.neuron_multimodal_image_utils import \ -# decompress_image_from_tensor logger = init_logger(__name__) @@ -62,7 +59,6 @@ def get_nxd_sampling_params(self, sampling_metadata): return sampling_params def get_multi_modal_data_neuron(self, input_images): - # FIXME(Neuron): need to restore multi-modal support raise NotImplementedError("need to restore multi-modal support") @torch.inference_mode() diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index 1c3cf8539fb..d925f088357 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -2,11 +2,7 @@ ''' Worker-related helper functions. 
''' -import enum -import os -from functools import cache -from vllm.platforms import current_platform from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS from vllm.worker.model_runner import GPUModelRunnerBase @@ -54,63 +50,3 @@ def assert_enc_dec_mr_supported_scenario( if enc_dec_mr.prompt_adapter_config is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) - - -@cache -def get_neuron_framework_to_use(): - """Return the specified framework if corresponding installations are - available. - - If no framework is specified, use neuronx-distributed-inference by default. - If that's unavailable, check and switch to transformers-neuronx. - """ - if not current_platform.is_neuron(): - raise AssertionError( - f"Neuron Framework unavailable for platform: {current_platform}") - - tnx_installed = current_platform.is_transformers_neuronx() - nxd_installed = current_platform.is_neuronx_distributed_inference() - - specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK") - tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value - if specified_framework == tnx_framework and tnx_installed: - return NeuronFramework.TRANSFORMERS_NEURONX - - if ((specified_framework == nxd_framework and nxd_installed) - or (specified_framework is None and nxd_installed)): - return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - - if specified_framework is None and tnx_installed: - return NeuronFramework.TRANSFORMERS_NEURONX - - return None - - -@cache -def use_neuronx_distributed(): - """ - Return True if the framework determined in get_neuron_framework_to_use() is - NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This is used - to select the Neuron model framework and framework-specific configuration to - apply during model compilation. - """ - nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE - return get_neuron_framework_to_use() == nxd_framework - - -@cache -def use_transformers_neuronx(): - """ - Return True if the framework determined in get_neuron_framework_to_use() is - NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used to - select the Neuron model framework and framework-specific configuration to - apply during model compilation. - """ - return get_neuron_framework_to_use( - ) == NeuronFramework.TRANSFORMERS_NEURONX - - -class NeuronFramework(enum.Enum): - TRANSFORMERS_NEURONX = "transformers-neuronx" - NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference" From e298c41268780c6c585c28b15b15debf2ace6b44 Mon Sep 17 00:00:00 2001 From: Satyajith Chilappagari Date: Sat, 3 May 2025 02:54:20 +0000 Subject: [PATCH 38/38] Modify neuron speculative decoding examples to use latest speculative_config Signed-off-by: Satyajith Chilappagari --- examples/offline_inference/neuron_eagle.py | 9 +++++---- examples/offline_inference/neuron_speculation.py | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py index c4e5f792709..4f63f1a2fb3 100644 --- a/examples/offline_inference/neuron_eagle.py +++ b/examples/offline_inference/neuron_eagle.py @@ -21,8 +21,11 @@ # Create an LLM. 
llm = LLM( model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", - speculative_model= - "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", + speculative_config={ + "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", + "num_speculative_tokens": 5, + "max_model_len": 2048 + }, max_num_seqs=4, # The max_model_len and block_size arguments are required to be same as # max sequence length when targeting neuron device. @@ -30,13 +33,11 @@ # in neuronx-distributed-inference. max_model_len=2048, block_size=2048, - speculative_max_model_len=2048, # The device can be automatically detected when AWS Neuron SDK is installed. # The device argument can be either unspecified for automated detection, # or explicitly assigned. device="neuron", tensor_parallel_size=32, - num_speculative_tokens=5, override_neuron_config={ "enable_eagle_speculation": True, "enable_fused_speculation": True diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py index ea03089f5f0..bef434bae5b 100644 --- a/examples/offline_inference/neuron_speculation.py +++ b/examples/offline_inference/neuron_speculation.py @@ -28,12 +28,14 @@ def initialize_model(): """Create an LLM with speculative decoding.""" return LLM( model="openlm-research/open_llama_7b", - speculative_model='openlm-research/open_llama_3b', - num_speculative_tokens=4, + speculative_config={ + "model": "openlm-research/open_llama_3b", + "num_speculative_tokens": 4, + "max_model_len": 2048 + }, max_num_seqs=4, max_model_len=2048, block_size=2048, - speculative_max_model_len=2048, use_v2_block_manager=True, device="neuron", tensor_parallel_size=32,
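For reference, a minimal sketch of the speculative_config-based invocation that both examples migrate to, assuming the same open_llama target/draft pair used in neuron_speculation.py; the prompt and sampling parameters below are placeholders, not part of the patch:

from vllm import LLM, SamplingParams

# Speculative decoding options now live in a single speculative_config dict;
# the former speculative_model, num_speculative_tokens and
# speculative_max_model_len arguments are folded into it.
llm = LLM(
    model="openlm-research/open_llama_7b",             # target model
    speculative_config={
        "model": "openlm-research/open_llama_3b",      # draft model
        "num_speculative_tokens": 4,                   # tokens proposed per step
        "max_model_len": 2048,                         # replaces speculative_max_model_len
    },
    max_num_seqs=4,
    # On Neuron, max_model_len and block_size must equal the max sequence length.
    max_model_len=2048,
    block_size=2048,
    device="neuron",
    tensor_parallel_size=32,
)

outputs = llm.generate(
    ["The future of AI is"],
    SamplingParams(max_tokens=64, top_k=1),
)
for out in outputs:
    print(out.outputs[0].text)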