From d5c329d1dadc6a222419264a06da7fe861639225 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:13:11 -0700 Subject: [PATCH 01/88] adapt for dynamo --- vllm/sequence.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 781bcedde2b..894473deb11 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1149,10 +1149,9 @@ def __eq__(self, other: object) -> bool: return self.embeddings == other.embeddings -class IntermediateTensors( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] +# cannot use msgspec.Struct here because Dynamo does not support it +@dataclass +class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. From 12e29fea6ac2a23fdeff9978a79fbc3c29773e8b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:15:10 -0700 Subject: [PATCH 02/88] fix tpu --- vllm/worker/tpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2472ac25aee..038f0c31f95 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -765,7 +765,8 @@ def forward( slot_mapping = slot_mapping.flatten() attn_metadata.slot_mapping = slot_mapping - hidden_states = self.model( + # directly call `forward` to avoid interference from compilation + hidden_states = self.model.forward( token_ids, position_ids, kv_caches, From 504bd6c8a0d23e7593fb4b1a6f9b9e47b0c078f6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:17:51 -0700 Subject: [PATCH 03/88] add backend --- vllm/compilation/backends.py | 113 ++++++++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index de0b1d8a757..98ab6667a68 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,8 +1,16 @@ import operator +from typing import Callable, Dict, Optional, Tuple, Union +from weakref import ReferenceType import torch import torch.fx as fx +from vllm.logger import init_logger + +from .wrapper import TorchCompileWrapperWithCustomDispatcher + +logger = init_logger(__name__) + def fix_functionalization(graph: fx.Graph): """ @@ -148,9 +156,112 @@ def fix_functionalization(graph: fx.Graph): # print(graph.python_code(root_module="self", verbose=True).src, file=f) -def vllm_backend(graph, example_inputs): +def wrap_inductor(graph, example_inputs, additional_inductor_config): from torch._inductor import config current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx + + if additional_inductor_config is not None: + current_config.update(additional_inductor_config) + if current_config['post_grad_custom_post_pass'] is not None: + logger.warning( + "post_grad_custom_post_pass is already set in the config. " + "Overwriting it with the fix_functionalization") current_config['post_grad_custom_post_pass'] = fix_functionalization return compile_fx(graph, example_inputs, config_patches=current_config) + + +def vllm_backend( + graph, + example_inputs, + model_ref: Optional[ + ReferenceType[TorchCompileWrapperWithCustomDispatcher]] = None, + additional_inductor_config: Optional[Dict] = None) -> Callable: + + # flags for all the seen shapes, whether we need to specialize + runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} + + # if we need to specialize, the compiled graph for that shape + runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {} + + # this is the first compilation, we will compile a graph with + # dynamic shape, as the caller will mark first dimension as dynamic + logger.info("Compiling a graph for general shapes") + graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, + additional_inductor_config) + + # TODO: Dynamo does not pass all dynamic shapes. + # Need to investigate why. It works now because all the dynamic + # shapes have the same value, and either of them can be used. + sym_shape_indices = [ + i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) + ] + + first_run = True + + # this is the function we return to Dynamo to run finally + def compiled_graph_wrapper(*args): + + runtime_shapes: Tuple[int, + ...] = tuple(args[i] for i in sym_shape_indices) + + nonlocal first_run + nonlocal runtime_shapes_to_compile_flags + nonlocal runtime_shapes_to_compiled_graph + + if first_run: + # the first compilation is for profiling, we directly run it + first_run = False + return graph_for_symbolic_shape(*args) + + if model_ref is None: + # no information about the model, we cannot specialize + return graph_for_symbolic_shape(*args) + + model: TorchCompileWrapperWithCustomDispatcher = model_ref() + assert model is not None, "model is garbage collected" + + if runtime_shapes not in runtime_shapes_to_compile_flags: + # we haven't seen this shape before + # query the model if we need to specialize for this shape + runtime_shapes_to_compile_flags[ + runtime_shapes] = model.need_to_specialize(runtime_shapes) + + if not runtime_shapes_to_compile_flags[runtime_shapes]: + # we don't need to specialize for this shape + return graph_for_symbolic_shape(*args) + + if runtime_shapes not in runtime_shapes_to_compiled_graph: + # we need to specialize for this shape, and we haven't compiled + # compile the graph for this shape + logger.info("Compiling a graph for shapes %s", runtime_shapes) + runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor( + graph, args, additional_inductor_config) + + return runtime_shapes_to_compiled_graph[runtime_shapes](*args) + + return compiled_graph_wrapper + + +def select_default_backend(level: int) -> Union[str, Callable]: + if level == 1: + backend = "eager" + return backend + assert level in [2, 3], f"Invalid level {level}" + + from vllm.compilation.backends import vllm_backend + from vllm.plugins import get_inductor_additional_configs + additional_configs = get_inductor_additional_configs() + + if level == 3: + if "max_autotune" in additional_configs and not additional_configs[ + "max_autotune"]: + logger.warning( + "max_autotune is disabled, but is overridden by level 3") + additional_configs['max_autotune'] = True + + from functools import partial + backend = partial(vllm_backend, + additional_inductor_config=additional_configs) + + return backend From 635361355bd893cb4ee153e6e6811ab6fc4e3118 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:20:54 -0700 Subject: [PATCH 04/88] add use_custom_dispatcher --- vllm/compilation/decorators.py | 68 ++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 vllm/compilation/decorators.py diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py new file mode 100644 index 00000000000..113bbc1d310 --- /dev/null +++ b/vllm/compilation/decorators.py @@ -0,0 +1,68 @@ +from typing import List, Optional, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm.attention import AttentionMetadata +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.sequence import IntermediateTensors + + +def support_compile_llama_style(cls: type): + """ + A decorator to add support for compiling the forward method of a class. + If a module's **forward signature** is compatible with llama, this + decorator can be used to enable the compilation of the forward method. + """ + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *args, **kwargs): + old_init(self, *args, **kwargs) + self._use_torch_compile = envs.VLLM_TEST_TORCH_COMPILE_LEVEL > 0 + if self._use_torch_compile: + TorchCompileWrapperWithCustomDispatcher.__init__(self) + + cls.__init__ = __init__ + + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: + if len(self.sizes_to_specialize) == 0: + return False + return runtime_shapes[0] in self.sizes_to_specialize + + cls.need_to_specialize = need_to_specialize + + def __call__( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if not self._use_torch_compile: + return self.forward(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors) + if len(self.compiled_codes) < 1: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(positions, 0) + if intermediate_tensors is not None: + for tensors in intermediate_tensors.tensors.values(): + torch._dynamo.mark_dynamic(tensors, 0) + return self.compiled_callable(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + if not self.use_custom_dispatcher: + return self.forward(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors) + with self.dispatch_to_code(0): + model_output = self.forward(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + cls.__call__ = __call__ + return cls From 77ae8e76821e2c24654ded2053109dd01a0a7a4a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:22:26 -0700 Subject: [PATCH 05/88] update wrapper --- vllm/compilation/wrapper.py | 44 +++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index e923bd36ccc..930ecb1c608 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -1,9 +1,10 @@ import os import sys +import weakref from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Callable, List +from typing import Any, Callable, List, Optional, Tuple import torch @@ -23,7 +24,30 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Callable): + def __init__(self, compiled_callable: Optional[Callable] = None): + + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + # choose the compile backend + + # if the user has set the backend, use it + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() + if backend is None: + from vllm.compilation.backends import select_default_backend + backend = select_default_backend( + envs.VLLM_TEST_TORCH_COMPILE_LEVEL) + if not isinstance(backend, str): + from functools import partial + backend = partial(backend, model_ref=weakref.ref(self)) + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + self.compiled_callable = compiled_callable self.original_code_object = self.__class__.forward.__code__ self.compiled_codes: List[CodeType] = [] @@ -35,6 +59,8 @@ def __init__(self, compiled_callable: Callable): self.use_custom_dispatcher: bool = \ envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + self.sizes_to_specialize = [] + def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. NOTE: this function can have additional arguments beyond the forward @@ -79,3 +105,17 @@ def dispatch_to_code(self, index: int): self.__class__.forward.__code__ = self.compiled_codes[index] yield self.__class__.forward.__code__ = self.original_code_object + + def set_sizes_to_specialize(self, sizes: List[Any]): + """Set the sizes to specialize for the compiled code.""" + self.sizes_to_specialize = sizes + + def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: + """Check if the current runtime shapes need to be specialized. + If not, we can use the graph for general shapes. + If yes, we will compile the graph for the current shapes. + The argument `runtime_shapes` is a tuple of integers, representing + the runtime shapes of the dimensions marked as dynamic during graph + capture. + """ + return False From 4d99a582556028dfa72b59b43770facf95936c45 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:23:23 -0700 Subject: [PATCH 06/88] update envs --- vllm/envs.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 0f46ac4f61f..62c14012d82 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -35,7 +35,6 @@ VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" - VLLM_OPENVINO_DEVICE: str = "CPU" VLLM_OPENVINO_KVCACHE_SPACE: int = 0 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False @@ -202,19 +201,20 @@ def get_default_config_root(): "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH": lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1", - # Internal flag to enable Dynamo graph capture - "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": - lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + # torch.compile optimization level + # 0: no optimization (don't use torch.compile) + # 1: capture the graph, run in eager mode (seconds of compilation time) + # 2: capture the graph, compile with inductor (minutes of compilation time) + # 3: capture the graph, compile with inductor max-autotune (dozens of minutes of compilation time) # noqa + "VLLM_TEST_TORCH_COMPILE_LEVEL": + lambda: int(os.environ.get("VLLM_TEST_TORCH_COMPILE_LEVEL", "0")), + + # Internal flag for Dynamo testing "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": lambda: (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in ("true", "1")), - # Internal flag to control whether we use custom op, - # or use the native pytorch implementation - "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": - lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")), - # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( @@ -303,11 +303,6 @@ def get_default_config_root(): "VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), - # OpenVINO device selection - # default is CPU - "VLLM_OPENVINO_DEVICE": - lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(), - # OpenVINO key-value cache space # default is 4GB "VLLM_OPENVINO_KVCACHE_SPACE": From 2b79376c5dd6e9df5a0e4b2e1831102936161b91 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:24:16 -0700 Subject: [PATCH 07/88] update custom op --- vllm/model_executor/custom_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9102b5e19eb..25b632a4658 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -55,7 +55,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS: + if envs.VLLM_TEST_TORCH_COMPILE_LEVEL >= 2: return self.forward_native if is_hip(): From 7dfddcdecf8afd92f31fea505e2d162ebcb5b917 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:25:04 -0700 Subject: [PATCH 08/88] support llama --- vllm/model_executor/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5ff31e3833e..d44e29b78bd 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,6 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_compile_llama_style from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -344,6 +345,7 @@ def forward( return hidden_states +@support_compile_llama_style class LlamaForCausalLM(nn.Module, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ From abd1a6504df5fb31d6699e6f7afc360ab0cf3345 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:26:06 -0700 Subject: [PATCH 09/88] update plugins --- vllm/plugins/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7939688ef0d..211fedbc6e2 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Union +from typing import Callable, Dict, Optional, Union import vllm.envs as envs @@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]): def get_torch_compile_backend() -> Optional[Union[Callable, str]]: return _torch_compile_backend + + +_inductor_additional_configs: Dict = {} + + +def set_inductor_additional_configs(configs: Dict): + global _inductor_additional_configs + _inductor_additional_configs = configs + + +def get_inductor_additional_configs() -> Dict: + return _inductor_additional_configs From ce1907fb787d6657b31f0e22c2df9cb86873a82c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:28:13 -0700 Subject: [PATCH 10/88] update model runner --- vllm/worker/model_runner.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 51f65cbfcf8..e6e3d6627c2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,10 +14,10 @@ import torch.distributed import torch.nn as nn -import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -47,8 +47,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available, - supports_dynamo) + flatten_2d_lists, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -1125,15 +1124,6 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): - from vllm.compilation.backends import vllm_backend - from vllm.plugins import get_torch_compile_backend - backend = get_torch_compile_backend() or vllm_backend - self.model = torch.compile( - self.model, - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) - def save_sharded_state( self, path: str, @@ -1426,7 +1416,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: batch_size_capture_list = [ bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - + if isinstance(self.model, TorchCompileWrapperWithCustomDispatcher): + self.model.set_sizes_to_specialize(batch_size_capture_list) with self.attn_state.graph_capture( max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the From e1ea867fe975e63ffdb0203ab98c4e43c3d79b7c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:39:55 -0700 Subject: [PATCH 11/88] add support --- .buildkite/test-pipeline.yaml | 6 +- tests/compile/test_full_graph.py | 13 ++- tests/compile/test_full_graph_multi_gpu.py | 22 ---- tests/compile/test_full_graph_smoke.py | 13 --- tests/compile/utils.py | 22 ++-- tests/utils.py | 111 ++++++++++++++------- 6 files changed, 98 insertions(+), 89 deletions(-) delete mode 100644 tests/compile/test_full_graph_multi_gpu.py delete mode 100644 tests/compile/test_full_graph_smoke.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f678436dd05..7e6ee9dfa93 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -118,7 +118,9 @@ steps: - vllm/core/ - tests/distributed - tests/spec_decode/e2e/test_integration_dist_tp4 + - tests/compile commands: + - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py @@ -224,7 +226,7 @@ steps: - vllm/ - tests/compile commands: - - pytest -v -s compile/test_full_graph_smoke.py + - pytest -v -s compile/test_basic_correctness.py - label: "PyTorch Fullgraph Test" # 18min source_file_dependencies: @@ -388,7 +390,7 @@ steps: - tests/distributed/ - vllm/compilation commands: - - pytest -v -s ./compile/test_full_graph_multi_gpu.py + - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 5dd65ad7236..03e377023b9 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,13 +1,16 @@ import pytest -from vllm.compilation.backends import vllm_backend - +from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support @pytest.mark.parametrize("model_info", TEST_MODELS) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): +@pytest.mark.parametrize("optimization_level", [1, 2]) +@fork_new_process_for_each_test +def test_full_graph(model_info, optimization_level): model = model_info[0] model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) + check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1) diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py deleted file mode 100644 index e9883d5254e..00000000000 --- a/tests/compile/test_full_graph_multi_gpu.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from vllm.compilation.backends import vllm_backend -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test -from .utils import TEST_MODELS_SMOKE, check_full_graph_support - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("tp_size", [2]) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -@fork_new_process_for_each_test -def test_full_graph_multi_gpu(model_info, tp_size, backend): - model = model_info[0] - model_kwargs = model_info[1] - - # Skip the test if there are not enough CUDA devices. - if cuda_device_count_stateless() < tp_size: - pytest.skip("Not enough CUDA devices for the test.") - - check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py deleted file mode 100644 index 0c5a95b4ead..00000000000 --- a/tests/compile/test_full_graph_smoke.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest - -from vllm.compilation.backends import vllm_backend - -from .utils import TEST_MODELS_SMOKE, check_full_graph_support - - -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("backend", ["eager", vllm_backend]) -def test_full_graph(model_info, backend): - model = model_info[0] - model_kwargs = model_info[1] - check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 2d06a0946d9..2326dcd94de 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,14 +4,12 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.plugins import set_torch_compile_backend from vllm.utils import is_hip TEST_MODELS_SMOKE = [ - ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { - "quantization": "compressed-tensors" - }), - ("meta-llama/Meta-Llama-3-8B", {}), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", + ["--quantization", "compressed-tensors"]), + ("meta-llama/Meta-Llama-3-8B", []), ] TEST_MODELS = [ @@ -68,20 +66,20 @@ })) -def check_full_graph_support(model, model_kwargs, backend, tp_size=1): +def check_full_graph_support(model, + model_kwargs, + optimization_level, + tp_size=1): # make sure these models can be captured in full graph mode - if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" - os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" + os.environ["VLLM_TEST_TORCH_COMPILE_LEVEL"] = str(optimization_level) + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # Inductor doesn't support fp8/gptq_marlin_24 yet. quantization = model_kwargs.get("quantization") if (quantization == "fp8" or quantization == "gptq_marlin" - or quantization == "gptq_marlin_24") and backend != "eager": + or quantization == "gptq_marlin_24") and optimization_level > 1: return - set_torch_compile_backend(backend) - prompts = [ "Hello, my name is", "The president of the United States is", diff --git a/tests/utils.py b/tests/utils.py index 49bd4f236f6..42ae4ec5273 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -180,9 +180,24 @@ def compare_two_settings(model: str, env1: The first set of environment variables to pass to the API server. env2: The second set of environment variables to pass to the API server. """ + compare_all_settings(model, [arg1, arg2], [env1, env2], max_wait_seconds) + + +def compare_all_settings(model: str, + all_args: List[List[str]], + all_envs: List[Optional[Dict[str, str]]], + max_wait_seconds: Optional[float] = None) -> None: + """ + Launch API server with several different sets of arguments/environments + and compare the results of the API calls with the first set of arguments. + Args: + model: The model to test. + all_args: A list of argument lists to pass to the API server. + all_envs: A list of environment dictionaries to pass to the API server. + """ trust_remote_code = "--trust-remote-code" - if trust_remote_code in arg1 or trust_remote_code in arg2: + if any(trust_remote_code in args for args in all_args): tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) else: @@ -190,8 +205,9 @@ def compare_two_settings(model: str, prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] - results = [] - for args, env in ((arg1, env1), (arg2, env2)): + ref_results: List = [] + for i, (args, env) in enumerate(zip(all_args, all_envs)): + compare_results: List = [] with RemoteOpenAIServer(model, args, env_dict=env, @@ -202,10 +218,13 @@ def compare_two_settings(model: str, models = client.models.list() models = models.data served_model = models[0] - results.append({ - "test": "models_list", - "id": served_model.id, - "root": served_model.root, + (ref_results if i == 0 else compare_results).append({ + "test": + "models_list", + "id": + served_model.id, + "root": + served_model.root, }) # test with text prompt @@ -214,11 +233,15 @@ def compare_two_settings(model: str, max_tokens=5, temperature=0.0) - results.append({ - "test": "single_completion", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "single_completion", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test using token IDs @@ -229,11 +252,15 @@ def compare_two_settings(model: str, temperature=0.0, ) - results.append({ - "test": "token_ids", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "token_ids", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test seeded random sampling @@ -243,11 +270,15 @@ def compare_two_settings(model: str, seed=33, temperature=1.0) - results.append({ - "test": "seeded_sampling", - "text": completion.choices[0].text, - "finish_reason": completion.choices[0].finish_reason, - "usage": completion.usage, + (ref_results if i == 0 else compare_results).append({ + "test": + "seeded_sampling", + "text": + completion.choices[0].text, + "finish_reason": + completion.choices[0].finish_reason, + "usage": + completion.usage, }) # test seeded random sampling with multiple prompts @@ -257,7 +288,7 @@ def compare_two_settings(model: str, seed=33, temperature=1.0) - results.append({ + (ref_results if i == 0 else compare_results).append({ "test": "seeded_sampling", "text": [choice.text for choice in completion.choices], @@ -275,10 +306,13 @@ def compare_two_settings(model: str, temperature=0.0, ) - results.append({ - "test": "simple_list", - "text0": batch.choices[0].text, - "text1": batch.choices[1].text, + (ref_results if i == 0 else compare_results).append({ + "test": + "simple_list", + "text0": + batch.choices[0].text, + "text1": + batch.choices[1].text, }) # test streaming @@ -294,18 +328,25 @@ def compare_two_settings(model: str, assert len(chunk.choices) == 1 choice = chunk.choices[0] texts[choice.index] += choice.text - results.append({ + (ref_results if i == 0 else compare_results).append({ "test": "streaming", "texts": texts, }) - n = len(results) // 2 - arg1_results = results[:n] - arg2_results = results[n:] - for arg1_result, arg2_result in zip(arg1_results, arg2_results): - assert arg1_result == arg2_result, ( - f"Results for {model=} are not the same with {arg1=} and {arg2=}. " - f"{arg1_result=} != {arg2_result=}") + if i > 0: + # if any setting fails, raise an error early + ref_args = all_args[0] + ref_envs = all_envs[0] + compare_args = all_args[i] + compare_envs = all_envs[i] + for ref_result, compare_result in zip(ref_results, + compare_results): + assert ref_result == compare_result, ( + f"Results for {model=} are not the same.\n" + f"{ref_args=} {ref_envs=}\n" + f"{compare_args=} {compare_envs=}\n" + f"{ref_result=}\n" + f"{compare_result=}\n") def init_test_distributed_environment( From 511e07b2c7553a1e0d9c69f643906e1716f43c9a Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 14:40:07 -0700 Subject: [PATCH 12/88] add files --- tests/compile/test_basic_correctness.py | 28 +++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/compile/test_basic_correctness.py diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py new file mode 100644 index 00000000000..9cebe305b5c --- /dev/null +++ b/tests/compile/test_basic_correctness.py @@ -0,0 +1,28 @@ +from typing import Dict, List, Optional + +import pytest + +from vllm.utils import cuda_device_count_stateless + +from ..utils import compare_all_settings +from .utils import TEST_MODELS_SMOKE + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("pp_size", [1, 2]) +@pytest.mark.parametrize("tp_size", [1]) +def test_compile_correctness(model_info, pp_size, tp_size): + # this test is run under multiple suits, with different GPUs. + # make sure we only run the test with correct CUDA devices. + # don't use "<", as it will duplicate the tests. + if cuda_device_count_stateless() != pp_size * tp_size: + pytest.skip("Not correct CUDA devices for the test.") + model = model_info[0] + model_args = model_info[1] + all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 + all_envs: List[Optional[Dict[str, str]]] = [{ + "VLLM_TEST_TORCH_COMPILE_LEVEL": + str(i) + } for i in range(3)] + compare_all_settings(model, all_args, all_envs) From 3bb8950d52f91c5550feaaf54b75810c46737dfe Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 3 Oct 2024 17:26:53 -0700 Subject: [PATCH 13/88] fix not use_custom_dispatcher --- vllm/compilation/decorators.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 113bbc1d310..d5e7e97cfa8 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -48,6 +48,8 @@ def __call__( if not self._use_torch_compile: return self.forward(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) + + # the first compilation needs to have dynamic shapes marked if len(self.compiled_codes) < 1: torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(positions, 0) @@ -56,13 +58,21 @@ def __call__( torch._dynamo.mark_dynamic(tensors, 0) return self.compiled_callable(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) + + # if we don't use custom dispatcher, we can directly call the + # compiled function and let torch.compile handle the dispatching, + # with the overhead of guard evaluation and recompilation. if not self.use_custom_dispatcher: - return self.forward(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors) + return self.compiled_callable(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + + # usually, capturing the model once is enough, and then we can + # dispatch to the compiled code directly, without going through + # the Dynamo guard mechanism. with self.dispatch_to_code(0): model_output = self.forward(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) - return model_output + return model_output cls.__call__ = __call__ return cls From ed573fa96e54fd4062616faf48cd6ae3aaad77f8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 4 Oct 2024 17:44:07 -0700 Subject: [PATCH 14/88] do not test inductor --- tests/compile/test_basic_correctness.py | 6 ++++-- tests/utils.py | 25 ++++++++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 9cebe305b5c..110cea0f499 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) @pytest.mark.parametrize("pp_size", [1, 2]) -@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("tp_size", [1, 2]) def test_compile_correctness(model_info, pp_size, tp_size): # this test is run under multiple suits, with different GPUs. # make sure we only run the test with correct CUDA devices. @@ -21,8 +21,10 @@ def test_compile_correctness(model_info, pp_size, tp_size): model_args = model_info[1] all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 + # don't test VLLM_TEST_TORCH_COMPILE_LEVEL == 2 case + # inductor will change the output, so we cannot compare them. all_envs: List[Optional[Dict[str, str]]] = [{ "VLLM_TEST_TORCH_COMPILE_LEVEL": str(i) - } for i in range(3)] + } for i in range(2)] compare_all_settings(model, all_args, all_envs) diff --git a/tests/utils.py b/tests/utils.py index 53bc0ce3672..84ba7642478 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -196,15 +196,26 @@ def compare_all_settings(model: str, all_envs: A list of environment dictionaries to pass to the API server. """ - trust_remote_code = "--trust-remote-code" - if any(trust_remote_code in args for args in all_args): - tokenizer = AutoTokenizer.from_pretrained(model, - trust_remote_code=True) - else: - tokenizer = AutoTokenizer.from_pretrained(model) + trust_remote_code = False + for args in all_args: + if "--trust-remote-code" in args: + trust_remote_code = True + break + + tokenizer_mode = "auto" + for args in all_args: + if "--tokenizer-mode" in args: + tokenizer_mode = args[args.index("--tokenizer-mode") + 1] + break + + tokenizer = get_tokenizer( + model, + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + ) prompt = "Hello, my name is" - token_ids = tokenizer(prompt)["input_ids"] + token_ids = tokenizer(prompt).input_ids ref_results: List = [] for i, (args, env) in enumerate(zip(all_args, all_envs)): compare_results: List = [] From 93ef0b5521aacdaa69f3b69d91a9fcc9d4b122e0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 4 Oct 2024 21:09:15 -0700 Subject: [PATCH 15/88] add compile context --- vllm/compilation/backends.py | 22 +++++++++++----------- vllm/compilation/decorators.py | 9 +-------- vllm/compilation/wrapper.py | 18 +----------------- vllm/worker/model_runner.py | 15 +++++++++++---- 4 files changed, 24 insertions(+), 40 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 98ab6667a68..09f265d2b2a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,5 +1,6 @@ +import copy import operator -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from weakref import ReferenceType import torch @@ -7,6 +8,7 @@ from vllm.logger import init_logger +from .compile_context import get_compile_context from .wrapper import TorchCompileWrapperWithCustomDispatcher logger = init_logger(__name__) @@ -178,6 +180,8 @@ def vllm_backend( ReferenceType[TorchCompileWrapperWithCustomDispatcher]] = None, additional_inductor_config: Optional[Dict] = None) -> Callable: + sizes_to_specialize: List[int] = copy.deepcopy(get_compile_context()) + # flags for all the seen shapes, whether we need to specialize runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} @@ -214,18 +218,14 @@ def compiled_graph_wrapper(*args): first_run = False return graph_for_symbolic_shape(*args) - if model_ref is None: - # no information about the model, we cannot specialize - return graph_for_symbolic_shape(*args) - - model: TorchCompileWrapperWithCustomDispatcher = model_ref() - assert model is not None, "model is garbage collected" - if runtime_shapes not in runtime_shapes_to_compile_flags: # we haven't seen this shape before - # query the model if we need to specialize for this shape - runtime_shapes_to_compile_flags[ - runtime_shapes] = model.need_to_specialize(runtime_shapes) + # query if we need to specialize for this shape + # we only specialize for the first dimension. + # TODO: investigate if any model needs to specialize + # beyond the first dimension + runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[ + 0] in sizes_to_specialize if not runtime_shapes_to_compile_flags[runtime_shapes]: # we don't need to specialize for this shape diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index d5e7e97cfa8..a4d0be74e88 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import torch @@ -30,13 +30,6 @@ def __init__(self, *args, **kwargs): cls.__init__ = __init__ - def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: - if len(self.sizes_to_specialize) == 0: - return False - return runtime_shapes[0] in self.sizes_to_specialize - - cls.need_to_specialize = need_to_specialize - def __call__( self, input_ids: torch.Tensor, diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 930ecb1c608..50449bad9ff 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,7 +4,7 @@ from abc import abstractmethod from contextlib import contextmanager from types import CodeType -from typing import Any, Callable, List, Optional, Tuple +from typing import Callable, List, Optional import torch @@ -59,8 +59,6 @@ def __init__(self, compiled_callable: Optional[Callable] = None): self.use_custom_dispatcher: bool = \ envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER - self.sizes_to_specialize = [] - def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. NOTE: this function can have additional arguments beyond the forward @@ -105,17 +103,3 @@ def dispatch_to_code(self, index: int): self.__class__.forward.__code__ = self.compiled_codes[index] yield self.__class__.forward.__code__ = self.original_code_object - - def set_sizes_to_specialize(self, sizes: List[Any]): - """Set the sizes to specialize for the compiled code.""" - self.sizes_to_specialize = sizes - - def need_to_specialize(self, runtime_shapes: Tuple[int, ...]) -> bool: - """Check if the current runtime shapes need to be specialized. - If not, we can use the graph for general shapes. - If yes, we will compile the graph for the current shapes. - The argument `runtime_shapes` is a tuple of integers, representing - the runtime shapes of the dimensions marked as dynamic during graph - capture. - """ - return False diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0034200bbdb..5afdf79ca57 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -17,7 +17,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.compilation.compile_context import set_compile_context from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -1277,7 +1277,15 @@ def profile_run(self) -> None: batch_size=batch_size, dtype=self.model_config.dtype, device=self.device) - self.execute_model(model_input, kv_caches, intermediate_tensors) + + graph_batch_size = self.max_batchsize_to_capture + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + if self.model_config.enforce_eager: + batch_size_capture_list = [] + with set_compile_context(batch_size_capture_list): + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return @@ -1415,8 +1423,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: batch_size_capture_list = [ bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] - if isinstance(self.model, TorchCompileWrapperWithCustomDispatcher): - self.model.set_sizes_to_specialize(batch_size_capture_list) + with self.attn_state.graph_capture( max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the From 3cd40dbf011193b1e3202e452bd8525ad3af6eb9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 4 Oct 2024 21:11:51 -0700 Subject: [PATCH 16/88] remove model reference --- vllm/compilation/backends.py | 4 ---- vllm/compilation/compile_context.py | 21 +++++++++++++++++++++ vllm/compilation/wrapper.py | 4 ---- 3 files changed, 21 insertions(+), 8 deletions(-) create mode 100644 vllm/compilation/compile_context.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 09f265d2b2a..d3fcbeaf569 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,7 +1,6 @@ import copy import operator from typing import Callable, Dict, List, Optional, Tuple, Union -from weakref import ReferenceType import torch import torch.fx as fx @@ -9,7 +8,6 @@ from vllm.logger import init_logger from .compile_context import get_compile_context -from .wrapper import TorchCompileWrapperWithCustomDispatcher logger = init_logger(__name__) @@ -176,8 +174,6 @@ def wrap_inductor(graph, example_inputs, additional_inductor_config): def vllm_backend( graph, example_inputs, - model_ref: Optional[ - ReferenceType[TorchCompileWrapperWithCustomDispatcher]] = None, additional_inductor_config: Optional[Dict] = None) -> Callable: sizes_to_specialize: List[int] = copy.deepcopy(get_compile_context()) diff --git a/vllm/compilation/compile_context.py b/vllm/compilation/compile_context.py new file mode 100644 index 00000000000..881c5d9386a --- /dev/null +++ b/vllm/compilation/compile_context.py @@ -0,0 +1,21 @@ +from contextlib import contextmanager +from typing import Any + +_compile_context: Any = None + +def get_compile_context() -> Any: + """Get the current compile context.""" + return _compile_context + +@contextmanager +def set_compile_context(context: Any): + """A context manager that stores the current compile context, + usually it is a list of sizes to specialize. + """ + global _compile_context + prev_context = _compile_context + _compile_context = context + try: + yield + finally: + _compile_context = prev_context diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 50449bad9ff..31bcfa7ac64 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -1,6 +1,5 @@ import os import sys -import weakref from abc import abstractmethod from contextlib import contextmanager from types import CodeType @@ -39,9 +38,6 @@ def __init__(self, compiled_callable: Optional[Callable] = None): from vllm.compilation.backends import select_default_backend backend = select_default_backend( envs.VLLM_TEST_TORCH_COMPILE_LEVEL) - if not isinstance(backend, str): - from functools import partial - backend = partial(backend, model_ref=weakref.ref(self)) compiled_callable = torch.compile( self.forward, From 4e2893086717a31470c9276c47891605f03ab9ed Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 4 Oct 2024 21:14:54 -0700 Subject: [PATCH 17/88] lint --- vllm/compilation/compile_context.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/compilation/compile_context.py b/vllm/compilation/compile_context.py index 881c5d9386a..29db3d4c637 100644 --- a/vllm/compilation/compile_context.py +++ b/vllm/compilation/compile_context.py @@ -3,10 +3,12 @@ _compile_context: Any = None + def get_compile_context() -> Any: """Get the current compile context.""" return _compile_context + @contextmanager def set_compile_context(context: Any): """A context manager that stores the current compile context, From 2ac7274c2d87352ae0751c89aeed102314105767 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 16:52:00 -0700 Subject: [PATCH 18/88] change levels --- tests/compile/test_basic_correctness.py | 6 +++--- tests/compile/utils.py | 2 +- tests/tpu/test_compilation.py | 2 +- tests/tpu/test_custom_dispatcher.py | 2 +- vllm/compilation/backends.py | 8 ++++---- vllm/compilation/decorators.py | 2 +- vllm/compilation/wrapper.py | 5 ++--- vllm/envs.py | 10 ++-------- vllm/model_executor/custom_op.py | 2 +- vllm/platforms/tpu.py | 9 +++++++++ 10 files changed, 25 insertions(+), 23 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 110cea0f499..0bad1623d45 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -21,10 +21,10 @@ def test_compile_correctness(model_info, pp_size, tp_size): model_args = model_info[1] all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 - # don't test VLLM_TEST_TORCH_COMPILE_LEVEL == 2 case + # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case # inductor will change the output, so we cannot compare them. all_envs: List[Optional[Dict[str, str]]] = [{ - "VLLM_TEST_TORCH_COMPILE_LEVEL": + "VLLM_TORCH_COMPILE_LEVEL": str(i) - } for i in range(2)] + } for i in range(3)] compare_all_settings(model, all_args, all_envs) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 2326dcd94de..eb5b2e741f9 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -71,7 +71,7 @@ def check_full_graph_support(model, optimization_level, tp_size=1): # make sure these models can be captured in full graph mode - os.environ["VLLM_TEST_TORCH_COMPILE_LEVEL"] = str(optimization_level) + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level) os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" # Inductor doesn't support fp8/gptq_marlin_24 yet. diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index d8df86b2aaa..264b6c92f45 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -7,7 +7,7 @@ # disable custom dispatcher, let Dynamo takes over # all the control -os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0" +os.environ['VLLM_TORCH_COMPILE_LEVEL'] = "1" temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 69ab67abdd1..2c73ea518f8 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -12,5 +12,5 @@ def test_custom_dispatcher(): compare_two_settings("google/gemma-2b", arg1=["--enforce-eager"], arg2=["--enforce-eager"], - env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}, + env1={"VLLM_TORCH_COMPILE_LEVEL": "1"}, env2={}) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d3fcbeaf569..6e40e63f50a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -240,20 +240,20 @@ def compiled_graph_wrapper(*args): def select_default_backend(level: int) -> Union[str, Callable]: - if level == 1: + if level in [1, 2]: backend = "eager" return backend - assert level in [2, 3], f"Invalid level {level}" + assert level in [3, 4], f"Invalid level {level}" from vllm.compilation.backends import vllm_backend from vllm.plugins import get_inductor_additional_configs additional_configs = get_inductor_additional_configs() - if level == 3: + if level == 4: if "max_autotune" in additional_configs and not additional_configs[ "max_autotune"]: logger.warning( - "max_autotune is disabled, but is overridden by level 3") + "max_autotune is disabled, but is overridden by level 4") additional_configs['max_autotune'] = True from functools import partial diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index a4d0be74e88..729af70345d 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -24,7 +24,7 @@ def support_compile_llama_style(cls: type): def __init__(self, *args, **kwargs): old_init(self, *args, **kwargs) - self._use_torch_compile = envs.VLLM_TEST_TORCH_COMPILE_LEVEL > 0 + self._use_torch_compile = envs.VLLM_TORCH_COMPILE_LEVEL > 0 if self._use_torch_compile: TorchCompileWrapperWithCustomDispatcher.__init__(self) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 31bcfa7ac64..5c68888bcfa 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -36,8 +36,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): backend = get_torch_compile_backend() if backend is None: from vllm.compilation.backends import select_default_backend - backend = select_default_backend( - envs.VLLM_TEST_TORCH_COMPILE_LEVEL) + backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) compiled_callable = torch.compile( self.forward, @@ -53,7 +52,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + envs.VLLM_TORCH_COMPILE_LEVEL >= 2 def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. diff --git a/vllm/envs.py b/vllm/envs.py index 62c14012d82..683627790b8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -206,14 +206,8 @@ def get_default_config_root(): # 1: capture the graph, run in eager mode (seconds of compilation time) # 2: capture the graph, compile with inductor (minutes of compilation time) # 3: capture the graph, compile with inductor max-autotune (dozens of minutes of compilation time) # noqa - "VLLM_TEST_TORCH_COMPILE_LEVEL": - lambda: int(os.environ.get("VLLM_TEST_TORCH_COMPILE_LEVEL", "0")), - - # Internal flag for Dynamo testing - "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": - lambda: - (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in - ("true", "1")), + "VLLM_TORCH_COMPILE_LEVEL": + lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 25b632a4658..09fe35ad617 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -55,7 +55,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if envs.VLLM_TEST_TORCH_COMPILE_LEVEL >= 2: + if envs.VLLM_TORCH_COMPILE_LEVEL >= 2: return self.forward_native if is_hip(): diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index a35777f91ca..03fe49dbe05 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,7 +1,16 @@ +import os + import torch +from vllm.plugins import set_torch_compile_backend + from .interface import Platform, PlatformEnum +if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "2" + +set_torch_compile_backend("openxla") + class TpuPlatform(Platform): _enum = PlatformEnum.TPU From a3c947e88be9beaafc882dbbd6d49ceb88d9b452 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:00:52 -0700 Subject: [PATCH 19/88] add levels --- vllm/compilation/levels.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 vllm/compilation/levels.py diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py new file mode 100644 index 00000000000..e2c17590e62 --- /dev/null +++ b/vllm/compilation/levels.py @@ -0,0 +1,8 @@ +# constants for the levels of the compilation process + +class CompilationLevel: + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + INDUCTOR = 3 + INDUCTOR_MAX_AUTOTUNE = 4 From 1a41c576466a18c7a1a4b79c0802236a54719aac Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:01:40 -0700 Subject: [PATCH 20/88] use const --- tests/compile/test_full_graph.py | 6 +++++- vllm/compilation/levels.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 03e377023b9..f28f9145bb4 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,11 +1,15 @@ import pytest +from vllm.compilation.levels import CompilationLevel + from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support @pytest.mark.parametrize("model_info", TEST_MODELS) -@pytest.mark.parametrize("optimization_level", [1, 2]) +@pytest.mark.parametrize( + "optimization_level", + [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR]) @fork_new_process_for_each_test def test_full_graph(model_info, optimization_level): model = model_info[0] diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py index e2c17590e62..162bf5ae649 100644 --- a/vllm/compilation/levels.py +++ b/vllm/compilation/levels.py @@ -1,5 +1,6 @@ # constants for the levels of the compilation process + class CompilationLevel: NO_COMPILATION = 0 DYNAMO_AS_IS = 1 From db61567933dca6698dfd171a18583bbe9fb0c7b1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:02:59 -0700 Subject: [PATCH 21/88] use const --- tests/compile/test_basic_correctness.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 0bad1623d45..57cf2c72c60 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -2,6 +2,7 @@ import pytest +from vllm.compilation.levels import CompilationLevel from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings @@ -26,5 +27,8 @@ def test_compile_correctness(model_info, pp_size, tp_size): all_envs: List[Optional[Dict[str, str]]] = [{ "VLLM_TORCH_COMPILE_LEVEL": str(i) - } for i in range(3)] + } for i in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_ONCE, + CompilationLevel.DYNAMO_AS_IS + ]] compare_all_settings(model, all_args, all_envs) From 275ede96818e11a6c405cd5764f0d17eefd29854 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:03:21 -0700 Subject: [PATCH 22/88] use const --- tests/compile/test_basic_correctness.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 57cf2c72c60..51e4b1da627 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -26,8 +26,8 @@ def test_compile_correctness(model_info, pp_size, tp_size): # inductor will change the output, so we cannot compare them. all_envs: List[Optional[Dict[str, str]]] = [{ "VLLM_TORCH_COMPILE_LEVEL": - str(i) - } for i in [ + str(level) + } for level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_ONCE, CompilationLevel.DYNAMO_AS_IS ]] From d1f084dbae0096a38b1cf6c05fde30c92a477180 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:05:23 -0700 Subject: [PATCH 23/88] use const --- tests/compile/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index eb5b2e741f9..4473e158c51 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,6 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams +from vllm.compilation.levels import CompilationLevel from vllm.utils import is_hip TEST_MODELS_SMOKE = [ @@ -77,7 +78,8 @@ def check_full_graph_support(model, # Inductor doesn't support fp8/gptq_marlin_24 yet. quantization = model_kwargs.get("quantization") if (quantization == "fp8" or quantization == "gptq_marlin" - or quantization == "gptq_marlin_24") and optimization_level > 1: + or quantization == "gptq_marlin_24" + ) and optimization_level >= CompilationLevel.INDUCTOR: return prompts = [ From 326c5b4e08e97675adfdc4e9b00663e9114c4b8f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:06:38 -0700 Subject: [PATCH 24/88] use const --- tests/tpu/test_compilation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 264b6c92f45..86d9af88e49 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,9 +5,11 @@ import depyf +from vllm.compilation.levels import CompilationLevel + # disable custom dispatcher, let Dynamo takes over # all the control -os.environ['VLLM_TORCH_COMPILE_LEVEL'] = "1" +os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS) temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): From 9b7b0f3b7338dd8695540b9928cfb01f08911842 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:07:23 -0700 Subject: [PATCH 25/88] use const --- tests/tpu/test_custom_dispatcher.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 2c73ea518f8..53e09cc0bd1 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,5 +1,7 @@ import os +from vllm.compilation.levels import CompilationLevel + from ..utils import compare_two_settings # --enforce-eager on TPU causes graph compilation @@ -9,8 +11,9 @@ def test_custom_dispatcher(): - compare_two_settings("google/gemma-2b", - arg1=["--enforce-eager"], - arg2=["--enforce-eager"], - env1={"VLLM_TORCH_COMPILE_LEVEL": "1"}, - env2={}) + compare_two_settings( + "google/gemma-2b", + arg1=["--enforce-eager"], + arg2=["--enforce-eager"], + env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)}, + env2={}) From 9cfa70c03599b7b1a5116582d030b91da023c061 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:08:31 -0700 Subject: [PATCH 26/88] use const --- tests/tpu/test_custom_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 53e09cc0bd1..923d0f16808 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -16,4 +16,4 @@ def test_custom_dispatcher(): arg1=["--enforce-eager"], arg2=["--enforce-eager"], env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)}, - env2={}) + env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)}) From e819be787f0605aa49c5159ddb0f12c8eff577f1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:09:19 -0700 Subject: [PATCH 27/88] use const --- tests/compile/test_basic_correctness.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 51e4b1da627..f5d7e472c52 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -28,7 +28,8 @@ def test_compile_correctness(model_info, pp_size, tp_size): "VLLM_TORCH_COMPILE_LEVEL": str(level) } for level in [ - CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_ONCE, - CompilationLevel.DYNAMO_AS_IS + CompilationLevel.NO_COMPILATION, + CompilationLevel.DYNAMO_AS_IS, + CompilationLevel.DYNAMO_ONCE, ]] compare_all_settings(model, all_args, all_envs) From d9cb1622cacec75fad719083ce148ce649fa65a9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:12:45 -0700 Subject: [PATCH 28/88] use const --- vllm/compilation/backends.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 6e40e63f50a..892ff620990 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from .compile_context import get_compile_context +from .levels import CompilationLevel logger = init_logger(__name__) @@ -240,20 +241,23 @@ def compiled_graph_wrapper(*args): def select_default_backend(level: int) -> Union[str, Callable]: - if level in [1, 2]: + if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: backend = "eager" return backend - assert level in [3, 4], f"Invalid level {level}" + assert level in [ + CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE + ], f"Invalid level {level}" from vllm.compilation.backends import vllm_backend from vllm.plugins import get_inductor_additional_configs additional_configs = get_inductor_additional_configs() - if level == 4: + if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE: if "max_autotune" in additional_configs and not additional_configs[ "max_autotune"]: logger.warning( - "max_autotune is disabled, but is overridden by level 4") + "max_autotune is disabled, but is overridden by level %s", + CompilationLevel.INDUCTOR_MAX_AUTOTUNE) additional_configs['max_autotune'] = True from functools import partial From 825f38435d9a2f30f863847057312c1ca4e11e0b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:15:04 -0700 Subject: [PATCH 29/88] use const --- vllm/platforms/tpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 03fe49dbe05..dbb87f38e37 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -2,12 +2,13 @@ import torch +from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_torch_compile_backend from .interface import Platform, PlatformEnum if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "2" + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) set_torch_compile_backend("openxla") From c785fc893eb44afa12e09af00529735ddcc19cb0 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:16:12 -0700 Subject: [PATCH 30/88] use const --- vllm/model_executor/custom_op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 09fe35ad617..d0e90245ad0 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,6 +1,7 @@ import torch.nn as nn import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel from vllm.platforms import current_platform from vllm.utils import is_cpu, is_hip, is_xpu @@ -55,7 +56,7 @@ def dispatch_forward(self): # NOTE(woosuk): Here we assume that vLLM was built for only one # specific backend. Currently, we do not support dynamic dispatching. - if envs.VLLM_TORCH_COMPILE_LEVEL >= 2: + if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR: return self.forward_native if is_hip(): From 28e9f6f9c8ee84bdec8420902e6bf3b0d4302763 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:17:37 -0700 Subject: [PATCH 31/88] restore --- vllm/envs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index fed1d8f9fc3..94e94031eef 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -35,6 +35,7 @@ VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" + VLLM_OPENVINO_DEVICE: str = "CPU" VLLM_OPENVINO_KVCACHE_SPACE: int = 0 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False @@ -305,6 +306,11 @@ def get_default_config_root(): "VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), + # OpenVINO device selection + # default is CPU + "VLLM_OPENVINO_DEVICE": + lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(), + # OpenVINO key-value cache space # default is 4GB "VLLM_OPENVINO_KVCACHE_SPACE": From 718c5e4a76d6f7b896205720653853bb6c115798 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:20:06 -0700 Subject: [PATCH 32/88] use const --- vllm/compilation/wrapper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 5c68888bcfa..1594b64a61b 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -9,6 +9,8 @@ import vllm.envs as envs +from .levels import CompilationLevel + class TorchCompileWrapperWithCustomDispatcher: """ @@ -52,7 +54,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_TORCH_COMPILE_LEVEL >= 2 + envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. From 03081cd26fa32d4515ea2b8198b5b772aaa841e3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:21:46 -0700 Subject: [PATCH 33/88] use const --- vllm/envs.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 94e94031eef..f217c7e8ac4 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -197,29 +197,10 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), - # Internal flag to enable Dynamo graph capture - "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": - lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), - "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": - lambda: - (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in - ("true", "1")), - - # Internal flag to control whether we use custom op, - # or use the native pytorch implementation - "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": - lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")), - # Internal flag to enable Dynamo fullgraph capture "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), - - # torch.compile optimization level - # 0: no optimization (don't use torch.compile) - # 1: capture the graph, run in eager mode (seconds of compilation time) - # 2: capture the graph, compile with inductor (minutes of compilation time) - # 3: capture the graph, compile with inductor max-autotune (dozens of minutes of compilation time) # noqa "VLLM_TORCH_COMPILE_LEVEL": lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), From fbac08d99a9d2ec8546cd8d93c83bb33a9a9f338 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:24:14 -0700 Subject: [PATCH 34/88] error on inductor for tpu --- vllm/platforms/tpu.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index dbb87f38e37..8ba973b2826 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -2,6 +2,7 @@ import torch +import vllm.envs as envs from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_torch_compile_backend @@ -10,6 +11,9 @@ if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) +assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\ + "TPU does not support Inductor." + set_torch_compile_backend("openxla") From 3c688ea57e4cbe23d5d2a5971724896ca02ad100 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:36:42 -0700 Subject: [PATCH 35/88] fix llava --- vllm/compilation/decorators.py | 19 +++++++++++++------ vllm/model_executor/models/llama.py | 2 +- vllm/model_executor/models/llava.py | 8 +++++--- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 729af70345d..e71114a90d2 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -4,6 +4,7 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata +from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.sequence import IntermediateTensors @@ -24,7 +25,8 @@ def support_compile_llama_style(cls: type): def __init__(self, *args, **kwargs): old_init(self, *args, **kwargs) - self._use_torch_compile = envs.VLLM_TORCH_COMPILE_LEVEL > 0 + self._use_torch_compile = \ + envs.VLLM_TORCH_COMPILE_LEVEL > CompilationLevel.NO_COMPILATION if self._use_torch_compile: TorchCompileWrapperWithCustomDispatcher.__init__(self) @@ -32,25 +34,30 @@ def __init__(self, *args, **kwargs): def __call__( self, - input_ids: torch.Tensor, + input_ids: Optional[torch.Tensor], positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if not self._use_torch_compile: return self.forward(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors) + intermediate_tensors, inputs_embeds) # the first compilation needs to have dynamic shapes marked if len(self.compiled_codes) < 1: - torch._dynamo.mark_dynamic(input_ids, 0) + if input_ids is not None: + torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(positions, 0) + if inputs_embeds is not None: + torch._dynamo.mark_dynamic(inputs_embeds, 0) if intermediate_tensors is not None: for tensors in intermediate_tensors.tensors.values(): torch._dynamo.mark_dynamic(tensors, 0) return self.compiled_callable(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0318d9d0d20..b7417b36282 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -267,6 +267,7 @@ def forward( return hidden_states, residual +@support_compile_llama_style class LlamaModel(nn.Module): def __init__( @@ -434,7 +435,6 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: "factor attribute!") -@support_compile_llama_style class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index a62231b628c..dfd341a1146 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -366,6 +366,8 @@ def forward( input_ids = None inputs_embeds = None else: + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent image_input = self._parse_and_validate_image_input(**kwargs) if image_input is not None: @@ -376,10 +378,10 @@ def forward( inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.config.image_token_index) - - input_ids = None else: - inputs_embeds = None + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, From 32676f80535290f9a1e0d9499476e242eb6a8504 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:43:54 -0700 Subject: [PATCH 36/88] restore tpu --- vllm/compilation/decorators.py | 3 +++ vllm/worker/tpu_model_runner.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index e71114a90d2..04e5adcc441 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -16,6 +16,9 @@ def support_compile_llama_style(cls: type): decorator can be used to enable the compilation of the forward method. """ + if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.NO_COMPILATION: + return cls + # take care of method resolution order # make sure super().__init__ is called on the base class # other than TorchCompileWrapperWithCustomDispatcher diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index c4e8afc591e..12e4215038d 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -762,8 +762,7 @@ def forward( slot_mapping = slot_mapping.flatten() attn_metadata.slot_mapping = slot_mapping - # directly call `forward` to avoid interference from compilation - hidden_states = self.model.forward( + hidden_states = self.model( token_ids, position_ids, kv_caches, From 3ed89da701f39554c2170306d018d6f8b7010d3f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 7 Oct 2024 23:47:03 -0700 Subject: [PATCH 37/88] adjust for tpu --- vllm/utils.py | 18 ++++++++++++++++++ vllm/worker/tpu_model_runner.py | 6 +++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index bec2f951d69..5044b768c7c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1370,3 +1370,21 @@ def dec(self, num=1): @property def value(self): return self._value + + +@contextlib.contextmanager +def temporary_env_var(key, value): + # Save the original value of the environment variable (if it exists) + original_value = os.environ.get(key) + # Set the environment variable to the new temporary value + os.environ[key] = value + try: + # Yield control back to the code block inside the "with" statement + yield + finally: + # Restore the original value after exiting the "with" block + if original_value is None: + del os.environ[key] # If it was originally not set, remove it + else: + os.environ[ + key] = original_value # Otherwise, restore the original value diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 12e4215038d..79874907494 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -20,6 +20,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.utils import temporary_env_var from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -140,7 +141,10 @@ def load_model(self) -> None: with patch( "vllm.model_executor.layers.vocab_parallel_embedding." "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): + return_value=xm_tp_rank), \ + temporary_env_var("VLLM_TORCH_COMPILE_LEVEL", 0): + # patching VLLM_TORCH_COMPILE_LEVEL to 0 so that + # compilation for other platforms is disabled. model = get_model( model_config=self.model_config, load_config=self.load_config, From a3c3e210d97de4012e3edd28838e37afc4300916 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 00:02:59 -0700 Subject: [PATCH 38/88] fix env var --- vllm/model_executor/models/gemma2.py | 2 ++ vllm/worker/tpu_model_runner.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index bd3c1114c92..f83aabf1cc5 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -21,6 +21,7 @@ from transformers import Gemma2Config from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_compile_llama_style from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -238,6 +239,7 @@ def forward( return hidden_states, residual +@support_compile_llama_style class Gemma2Model(nn.Module): def __init__( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 79874907494..056282dd808 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -11,6 +11,7 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -142,7 +143,7 @@ def load_model(self) -> None: "vllm.model_executor.layers.vocab_parallel_embedding." "get_tensor_model_parallel_rank", return_value=xm_tp_rank), \ - temporary_env_var("VLLM_TORCH_COMPILE_LEVEL", 0): + temporary_env_var("VLLM_TORCH_COMPILE_LEVEL", str(CompilationLevel.NO_COMPILATION)): # noqa # patching VLLM_TORCH_COMPILE_LEVEL to 0 so that # compilation for other platforms is disabled. model = get_model( From 30ff04f3b2c5a7a0169282aded913556e95a52c8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 00:53:45 -0700 Subject: [PATCH 39/88] fix calling --- vllm/compilation/decorators.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 04e5adcc441..d4932495168 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -44,7 +44,10 @@ def __call__( intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - if not self._use_torch_compile: + # torch.compile.is_compiling() means we are inside the compilation + # e.g. TPU has the compilation logic in model runner, so we don't + # need to compile the model inside. + if not self._use_torch_compile or torch.compile.is_compiling(): return self.forward(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds) @@ -58,23 +61,22 @@ def __call__( if intermediate_tensors is not None: for tensors in intermediate_tensors.tensors.values(): torch._dynamo.mark_dynamic(tensors, 0) - return self.compiled_callable(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, # with the overhead of guard evaluation and recompilation. - if not self.use_custom_dispatcher: + if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: return self.compiled_callable(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) # usually, capturing the model once is enough, and then we can # dispatch to the compiled code directly, without going through # the Dynamo guard mechanism. with self.dispatch_to_code(0): model_output = self.forward(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output cls.__call__ = __call__ From 13256c4c6d0d50937e307345863016e1bfb4dcea Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 00:53:56 -0700 Subject: [PATCH 40/88] revert tpu --- vllm/worker/tpu_model_runner.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 056282dd808..12e4215038d 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -11,7 +11,6 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -21,7 +20,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import temporary_env_var from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -142,10 +140,7 @@ def load_model(self) -> None: with patch( "vllm.model_executor.layers.vocab_parallel_embedding." "get_tensor_model_parallel_rank", - return_value=xm_tp_rank), \ - temporary_env_var("VLLM_TORCH_COMPILE_LEVEL", str(CompilationLevel.NO_COMPILATION)): # noqa - # patching VLLM_TORCH_COMPILE_LEVEL to 0 so that - # compilation for other platforms is disabled. + return_value=xm_tp_rank): model = get_model( model_config=self.model_config, load_config=self.load_config, From bf0e935009b54452ec522f97290b9f8666887fea Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 00:54:25 -0700 Subject: [PATCH 41/88] revert utils --- vllm/utils.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 5044b768c7c..bec2f951d69 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1370,21 +1370,3 @@ def dec(self, num=1): @property def value(self): return self._value - - -@contextlib.contextmanager -def temporary_env_var(key, value): - # Save the original value of the environment variable (if it exists) - original_value = os.environ.get(key) - # Set the environment variable to the new temporary value - os.environ[key] = value - try: - # Yield control back to the code block inside the "with" statement - yield - finally: - # Restore the original value after exiting the "with" block - if original_value is None: - del os.environ[key] # If it was originally not set, remove it - else: - os.environ[ - key] = original_value # Otherwise, restore the original value From 39571c5c3261e2cbc93bf77ce0a83ac3696a9f28 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 01:01:52 -0700 Subject: [PATCH 42/88] fix typo --- vllm/compilation/decorators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index d4932495168..0bb6955c9bc 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -44,10 +44,10 @@ def __call__( intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - # torch.compile.is_compiling() means we are inside the compilation + # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't # need to compile the model inside. - if not self._use_torch_compile or torch.compile.is_compiling(): + if not self._use_torch_compile or torch.compiler.is_compiling(): return self.forward(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds) From e3aea56888d63aff23d513c6c87e7ba8a3b893af Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 14:17:51 -0700 Subject: [PATCH 43/88] add typing --- vllm/envs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/envs.py b/vllm/envs.py index f217c7e8ac4..2b6e8f20000 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -64,6 +64,7 @@ VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False + VLLM_TORCH_COMPILE_LEVEL: int = 0 def get_default_cache_root(): From 61817954ec16468d178b0d5178c97005b8a26606 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 14:25:25 -0700 Subject: [PATCH 44/88] move DYNAMO_AS_IS to model runner level --- vllm/compilation/decorators.py | 13 +++++++------ vllm/worker/model_runner.py | 15 ++++++++++++++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 0bb6955c9bc..91be4f0ee1e 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -16,7 +16,11 @@ def support_compile_llama_style(cls: type): decorator can be used to enable the compilation of the forward method. """ - if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.NO_COMPILATION: + # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # will handle the compilation, so we don't need to do anything here. + if envs.VLLM_TORCH_COMPILE_LEVEL in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS + ]: return cls # take care of method resolution order @@ -28,10 +32,7 @@ def support_compile_llama_style(cls: type): def __init__(self, *args, **kwargs): old_init(self, *args, **kwargs) - self._use_torch_compile = \ - envs.VLLM_TORCH_COMPILE_LEVEL > CompilationLevel.NO_COMPILATION - if self._use_torch_compile: - TorchCompileWrapperWithCustomDispatcher.__init__(self) + TorchCompileWrapperWithCustomDispatcher.__init__(self) cls.__init__ = __init__ @@ -47,7 +48,7 @@ def __call__( # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't # need to compile the model inside. - if not self._use_torch_compile or torch.compiler.is_compiling(): + if torch.compiler.is_compiling(): return self.forward(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, inputs_embeds) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5afdf79ca57..e9f50b5a911 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -14,10 +14,12 @@ import torch.distributed import torch.nn as nn +import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context +from vllm.compilation.levels import CompilationLevel from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) @@ -46,7 +48,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available) + flatten_2d_lists, is_hip, is_pin_memory_available, + supports_dynamo) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -1123,6 +1126,16 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") + if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ + and supports_dynamo(): + from vllm.compilation.backends import vllm_backend + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() or vllm_backend + self.model = torch.compile( + self.model, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + def save_sharded_state( self, path: str, From 1a80a7bc06b98fa254ef5fc18b1016eb989a1d34 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 15:02:59 -0700 Subject: [PATCH 45/88] fix default context --- vllm/compilation/backends.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 892ff620990..4780358cea5 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -177,7 +177,9 @@ def vllm_backend( example_inputs, additional_inductor_config: Optional[Dict] = None) -> Callable: - sizes_to_specialize: List[int] = copy.deepcopy(get_compile_context()) + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context # flags for all the seen shapes, whether we need to specialize runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} From 92d240b0aad55feb92691352dc1beee9f6d8d7bc Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 15:15:30 -0700 Subject: [PATCH 46/88] use eager for DYNAMO_AS_IS by default --- vllm/worker/model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e9f50b5a911..4d887c3c1a7 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1128,9 +1128,8 @@ def load_model(self) -> None: if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ and supports_dynamo(): - from vllm.compilation.backends import vllm_backend from vllm.plugins import get_torch_compile_backend - backend = get_torch_compile_backend() or vllm_backend + backend = get_torch_compile_backend() or "eager" self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, From f4b0f501cc2a96f4677c4dfc1aa50decd00268f9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 15:27:32 -0700 Subject: [PATCH 47/88] update tests --- tests/compile/test_basic_correctness.py | 22 +++++++++++++++------- tests/compile/utils.py | 6 ------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index f5d7e472c52..4e662454fcf 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -6,20 +6,28 @@ from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings -from .utils import TEST_MODELS_SMOKE -@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) -@pytest.mark.parametrize("pp_size", [1, 2]) -@pytest.mark.parametrize("tp_size", [1, 2]) -def test_compile_correctness(model_info, pp_size, tp_size): +# we cannot afford testing the full Catesian product +# of all models and all levels +@pytest.mark.parametrize( + "model, model_args, pp_size, tp_size, attn_backend, method", + [ + ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate"), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", + ["--quantization", "compressed-tensors" + ], 1, 1, "FLASH_ATTN", "generate"), + ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate"), + # TODO: add multi-modality test for llava + ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate") + ]) +def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, + method): # this test is run under multiple suits, with different GPUs. # make sure we only run the test with correct CUDA devices. # don't use "<", as it will duplicate the tests. if cuda_device_count_stateless() != pp_size * tp_size: pytest.skip("Not correct CUDA devices for the test.") - model = model_info[0] - model_args = model_info[1] all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 4473e158c51..5386eb0e379 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -7,12 +7,6 @@ from vllm.compilation.levels import CompilationLevel from vllm.utils import is_hip -TEST_MODELS_SMOKE = [ - ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", - ["--quantization", "compressed-tensors"]), - ("meta-llama/Meta-Llama-3-8B", []), -] - TEST_MODELS = [ ("facebook/opt-125m", {}), ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { From 896431ab9409e55a754df93956d8328ef53c82b2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 15:29:13 -0700 Subject: [PATCH 48/88] update tests --- tests/compile/test_basic_correctness.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 4e662454fcf..fe3f12eeaa1 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -28,6 +28,8 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, # don't use "<", as it will duplicate the tests. if cuda_device_count_stateless() != pp_size * tp_size: pytest.skip("Not correct CUDA devices for the test.") + import os + os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case @@ -40,4 +42,4 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE, ]] - compare_all_settings(model, all_args, all_envs) + compare_all_settings(model, all_args, all_envs, method=method) From 388d5639ecd07349a0a02e0174c98ee72dad6fb5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 8 Oct 2024 15:53:33 -0700 Subject: [PATCH 49/88] llava uses fullgraph=false --- tests/compile/test_basic_correctness.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index fe3f12eeaa1..b6ec7413978 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -11,18 +11,19 @@ # we cannot afford testing the full Catesian product # of all models and all levels @pytest.mark.parametrize( - "model, model_args, pp_size, tp_size, attn_backend, method", + "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph", [ - ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate"), + ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate", + True), ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", ["--quantization", "compressed-tensors" - ], 1, 1, "FLASH_ATTN", "generate"), - ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate"), + ], 1, 1, "FLASH_ATTN", "generate", True), + ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True), # TODO: add multi-modality test for llava - ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate") + ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False) ]) def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, - method): + method, fullgraph): # this test is run under multiple suits, with different GPUs. # make sure we only run the test with correct CUDA devices. # don't use "<", as it will duplicate the tests. @@ -30,6 +31,8 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, pytest.skip("Not correct CUDA devices for the test.") import os os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend + if not fullgraph: + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"] + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3 # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case From ce7cd8e884380bb5d842f9b49945045ee58842a7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 9 Oct 2024 22:35:54 -0700 Subject: [PATCH 50/88] disable tests first --- .buildkite/test-pipeline.yaml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 00dccca0d37..4925f02033e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -229,12 +229,14 @@ steps: commands: - pytest -v -s compile/test_basic_correctness.py -- label: "PyTorch Fullgraph Test" # 18min - source_file_dependencies: - - vllm/ - - tests/compile - commands: - - pytest -v -s compile/test_full_graph.py +# TODO: re-write in comparison tests, and fix symbolic shape +# for quantization ops. +# - label: "PyTorch Fullgraph Test" # 18min +# source_file_dependencies: +# - vllm/ +# - tests/compile +# commands: +# - pytest -v -s compile/test_full_graph.py - label: Kernels Test %N # 1h each mirror_hardwares: [amd] From 828e4257c1dc16b6c4ddb04f4455f4bb9c67bf21 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 4 Oct 2024 15:14:47 -0400 Subject: [PATCH 51/88] RMSNorm fusion working! --- csrc/torch_bindings.cpp | 6 +- vllm/compilation/backends.py | 12 +++- vllm/compilation/fusion.py | 127 +++++++++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 6 deletions(-) create mode 100644 vllm/compilation/fusion.py diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a0100b4a85e..307d854875a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -94,7 +94,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( - "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> " "()"); ops.impl("rms_norm", torch::kCUDA, &rms_norm); @@ -345,13 +345,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, " "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index de0b1d8a757..d142773966c 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -47,8 +47,8 @@ def fix_functionalization(graph: fx.Graph): # Remove the getitem node for getitem_user in list(user.users): if (getitem_user.op == 'call_function' - and getitem_user.target - == torch.ops.aten.slice_scatter.default): + and getitem_user.target + == torch.ops.aten.slice_scatter.default): # Replace the uses of slice_scatter node # with mm_node getitem_user.replace_all_uses_with(mm_node) @@ -147,10 +147,16 @@ def fix_functionalization(graph: fx.Graph): # with open("after.py", "w") as f: # print(graph.python_code(root_module="self", verbose=True).src, file=f) +from vllm.compilation.fusion import get_fusion_pass +fusion_pass = get_fusion_pass() + +def post_grad_post_passes(graph: fx.Graph): + fusion_pass(graph) + fix_functionalization(graph) def vllm_backend(graph, example_inputs): from torch._inductor import config current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx - current_config['post_grad_custom_post_pass'] = fix_functionalization + current_config['post_grad_custom_post_pass'] = post_grad_post_passes return compile_fx(graph, example_inputs, config_patches=current_config) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py new file mode 100644 index 00000000000..d46aeae033a --- /dev/null +++ b/vllm/compilation/fusion.py @@ -0,0 +1,127 @@ +import logging + +import torch +from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +logger = init_logger(__name__) +logger.setLevel(logging.DEBUG) # TODO + + +# DYNAMIC +@torch.library.custom_op("vllm::fused_rms_norm_quant_dynamic", mutates_args=['result', 'scale', 'azp']) +def fused_rms_norm_quant_dynamic(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, + azp: torch.Tensor, epsilon: float) -> None: + print("vllm::fused_rms_norm_quant_dynamic") + result_rms = torch.empty_like(input) + torch.ops._C.rms_norm(result_rms, input, weight, epsilon) + torch.ops._C.dynamic_scaled_int8_quant(result, result_rms, scale, azp) + + +@torch.library.register_fake("vllm::fused_rms_norm_quant_dynamic") +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, azp: torch.Tensor, + epsilon: float) -> None: + return + + +# TODO epsilon +def rms_pattern(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight, + epsilon=1e-6) + at2 = auto_functionalized(torch.ops._C.dynamic_scaled_int8_quant.default, result=result, input=at1[1], scale=scale, + azp=None) + + # result, scale + # TODO azp + return at2[1:2] + + +def rms_replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.fused_rms_norm_quant_dynamic.default, result=result, + input=input, weight=weight, + epsilon=1e-6, scale=scale, azp=None) + + # result, scale + # TODO azp + return at[1:2] + + +# STATIC +@torch.library.custom_op("vllm::fused_rms_norm_quant_static", mutates_args=['result']) +def fused_rms_norm_quant_static(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, + azp: torch.Tensor, epsilon: float) -> None: + print("vllm::fused_rms_norm_quant_static") + result_rms = torch.empty_like(input) + torch.ops._C.rms_norm(result_rms, input, weight, epsilon) + torch.ops._C.static_scaled_int8_quant(result, result_rms, scale, azp) + + +@torch.library.register_fake("vllm::fused_rms_norm_quant_static") +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, azp: torch.Tensor, + epsilon: float) -> None: + return + + +def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight, + epsilon=1e-5) + at2 = auto_functionalized(torch.ops._C.static_scaled_int8_quant.default, result=result, input=at1[1], scale=scale, + azp=None) + + # result + return at2[1] + + +def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(torch.ops.vllm.fused_rms_norm_quant_static.default, result=result, input=input, + weight=weight, + epsilon=1e-5, scale=scale, azp=None) + + # result + return at[1] + + + +my_patterns = PatternMatcherPass() + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_int8(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int8, device="cuda") + + +def get_patterns(): + my_patterns = PatternMatcherPass() + + inputs = [empty_int8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] + register_replacement(rms_pattern, rms_replacement, inputs, fwd_only, my_patterns) + register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns) + + return my_patterns + + +def get_fusion_pass(): + patterns = get_patterns() + + def fusion_pass(graph: torch.fx.Graph): + """ + Use the pattern matcher + """ + # logger.info("Graph before fusion pass:") + with open("before.py", "w") as f: + print(graph.python_code(root_module="self", verbose=True).src, file=f) + count = patterns.apply(graph) + logger.info(f"Replaced {count} patterns") + with open("after.py", "w") as f: + print(graph.python_code(root_module="self", verbose=True).src, file=f) + + return fusion_pass From 88d13790b2650e5e99ec13a23faae69d8cdecbe0 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 4 Oct 2024 16:52:15 -0400 Subject: [PATCH 52/88] fused with bug --- vllm/compilation/fusion.py | 41 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index d46aeae033a..ab1f2993a77 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -87,6 +87,39 @@ def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, input return at[1] +@torch.library.custom_op("vllm::fused_rms_norm_residual_quant_static", mutates_args=['result', 'input', 'residual']) +def fused_rms_norm_residual_quant_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor, azp: torch.Tensor, + epsilon: float) -> None: + # print("vllm::fused_rms_norm_residual_quant_static") + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) + torch.ops._C.static_scaled_int8_quant(result, input, scale, azp) + + +@torch.library.register_fake("vllm::fused_rms_norm_residual_quant_static") +def _(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, + azp: torch.Tensor, epsilon: float) -> None: + return + + +def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=input, residual=residual, weight=weight, + epsilon=1e-5) + at2 = auto_functionalized(torch.ops._C.static_scaled_int8_quant.default, result=result, input=at1[1], scale=scale, + azp=None) + + # result, residual + return at2[1], at1[2] + + +def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): + at = auto_functionalized(torch.ops.vllm.fused_rms_norm_residual_quant_static.default, result=result, input=input, + residual=residual, weight=weight, epsilon=1e-5, scale=scale, azp=None) + # result, residual + return at[1], at[3] + my_patterns = PatternMatcherPass() @@ -106,6 +139,10 @@ def get_patterns(): register_replacement(rms_pattern, rms_replacement, inputs, fwd_only, my_patterns) register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns) + # with residual + inputs = [empty_int8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] + register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns) + return my_patterns @@ -117,11 +154,11 @@ def fusion_pass(graph: torch.fx.Graph): Use the pattern matcher """ # logger.info("Graph before fusion pass:") - with open("before.py", "w") as f: + with open("before_fusion.py", "w") as f: print(graph.python_code(root_module="self", verbose=True).src, file=f) count = patterns.apply(graph) logger.info(f"Replaced {count} patterns") - with open("after.py", "w") as f: + with open("after_fusion.py", "w") as f: print(graph.python_code(root_module="self", verbose=True).src, file=f) return fusion_pass From deeaef3adb0f337e21a2b2a922833a68c23a41bf Mon Sep 17 00:00:00 2001 From: luka Date: Mon, 7 Oct 2024 18:48:56 -0400 Subject: [PATCH 53/88] Use pattern matcher to match, replace manually, giving correct output --- vllm/compilation/backends.py | 8 ++- vllm/compilation/fusion.py | 102 +++++++++++++++++++++++++++++------ vllm/envs.py | 10 ++++ 3 files changed, 103 insertions(+), 17 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d142773966c..580742c046f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -47,8 +47,8 @@ def fix_functionalization(graph: fx.Graph): # Remove the getitem node for getitem_user in list(user.users): if (getitem_user.op == 'call_function' - and getitem_user.target - == torch.ops.aten.slice_scatter.default): + and getitem_user.target + == torch.ops.aten.slice_scatter.default): # Replace the uses of slice_scatter node # with mm_node getitem_user.replace_all_uses_with(mm_node) @@ -147,13 +147,17 @@ def fix_functionalization(graph: fx.Graph): # with open("after.py", "w") as f: # print(graph.python_code(root_module="self", verbose=True).src, file=f) + from vllm.compilation.fusion import get_fusion_pass + fusion_pass = get_fusion_pass() + def post_grad_post_passes(graph: fx.Graph): fusion_pass(graph) fix_functionalization(graph) + def vllm_backend(graph, example_inputs): from torch._inductor import config current_config = config.shallow_copy_dict() diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index ab1f2993a77..79634a8d2cd 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,9 +1,11 @@ import logging +import operator import torch -from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only +from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only, Match from torch._higher_order_ops.auto_functionalize import auto_functionalized +from vllm import envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -104,13 +106,13 @@ def _(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=input, residual=residual, weight=weight, - epsilon=1e-5) - at2 = auto_functionalized(torch.ops._C.static_scaled_int8_quant.default, result=result, input=at1[1], scale=scale, + at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=input, residual=residual, weight=weight, + epsilon=1e-5) + at1 = auto_functionalized(torch.ops._C.static_scaled_int8_quant.default, result=result, input=at[1], scale=scale, azp=None) # result, residual - return at2[1], at1[2] + return at1[1], at[2] def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, @@ -133,32 +135,102 @@ def empty_int8(*args, **kwargs): def get_patterns(): - my_patterns = PatternMatcherPass() + my_patterns = PatternMatcherPass(pass_name="fusion_pass") inputs = [empty_int8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] register_replacement(rms_pattern, rms_replacement, inputs, fwd_only, my_patterns) register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns) + matches = [] + + def record_match_fn(match: Match): + matches.append(match) + return False + # with residual inputs = [empty_int8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] - register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns) - - return my_patterns + register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns, + extra_check=record_match_fn) + + return my_patterns, matches + + +def process_matches(matches, graph: torch.fx.Graph): + for match in matches: + nodes = list(graph.nodes) + # TODO this is an expensive check + if not all(node in nodes for node in match.nodes): + raise ValueError(f"Broken match: not all nodes in graph: {[node for node in match.nodes if node not in nodes]}") + last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) + with graph.inserting_after(last_node_in_match): + kwargs = match.kwargs + kwargs["azp"] = None + kwargs["epsilon"] = 1e-5 + + fused_node = graph.call_function(auto_functionalized, + (torch.ops.vllm.fused_rms_norm_residual_quant_static.default,), + kwargs=kwargs) + + graph.inserting_after(fused_node) + result_node_new = graph.call_function(operator.getitem, (fused_node, 1)) + residual_node_new = graph.call_function(operator.getitem, (fused_node, 3)) + + # find the output and the residual + def find_auto_fn(op): + for node in match.nodes: + if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: + return node + return None + + def find_getitem(node, idx): + for user in node.users: + if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: + return user + return None + + rms_node = find_auto_fn(torch.ops._C.fused_add_rms_norm.default) + quant_node = find_auto_fn(torch.ops._C.static_scaled_int8_quant.default) + assert rms_node is not None + assert quant_node is not None + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # meta["val"] is used by de-functionalization + rms_val = rms_node.meta["val"] + quant_val = quant_node.meta["val"] + fused_node.meta["val"] = (None, quant_val[1], rms_val[1], rms_val[2]) + + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for node in match.nodes for match in matches) def get_fusion_pass(): - patterns = get_patterns() + patterns, matches = get_patterns() + + def dump_graph(graph: torch.fx.Graph, stage: str): + if stage in envs.VLLM_TORCH_COMPILE_FUSION_DUMP: + with open(f"{stage}.py", "w") as f: + print(graph.python_code(root_module="self", verbose=True).src, file=f) def fusion_pass(graph: torch.fx.Graph): """ Use the pattern matcher """ - # logger.info("Graph before fusion pass:") - with open("before_fusion.py", "w") as f: - print(graph.python_code(root_module="self", verbose=True).src, file=f) + matches.clear() + dump_graph(graph, "before_fusion") + count = patterns.apply(graph) logger.info(f"Replaced {count} patterns") - with open("after_fusion.py", "w") as f: - print(graph.python_code(root_module="self", verbose=True).src, file=f) + dump_graph(graph, "after_pattern_match") + + # Manually process multi-output matches (and run DCE) + process_matches(matches, graph) + logger.info(f"Post-processed {len(matches)} matches") + dump_graph(graph, "after_fusion") return fusion_pass diff --git a/vllm/envs.py b/vllm/envs.py index 97767bf5b5a..16ca1e653cd 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -12,6 +12,11 @@ VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = False + VLLM_TEST_DYNAMO_GRAPH_CAPTURE: int = 0 + VLLM_DYNAMO_USE_CUSTOM_DISPATCHER: bool = True + VLLM_TEST_COMPILE_NO_CUSTOM_OPS: int = 0 + VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE: bool = True + VLLM_TORCH_COMPILE_FUSION_DUMP: List[str] = [] LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 @@ -216,6 +221,11 @@ def get_default_config_root(): lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + # Internal flag for dumping the model graph before and after fusion + "VLLM_TORCH_COMPILE_FUSION_DUMP": + lambda: list( + os.environ.get("VLLM_TORCH_COMPILE_FUSION_DUMP", "").split(",")), + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": From 927d2dd47290083ff705ec1171b651eef9b9ebf4 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 8 Oct 2024 10:40:14 -0400 Subject: [PATCH 54/88] add quant_layernorm kernel (not modified yet) --- CMakeLists.txt | 1 + csrc/layernorm_quant_kernels.cu | 357 ++++++++++++++++++++++++++++++++ 2 files changed, 358 insertions(+) create mode 100644 csrc/layernorm_quant_kernels.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 4be524808a2..d884ac4c4ce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -204,6 +204,7 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" + "csrc/layernorm_quant_kernels.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu new file mode 100644 index 00000000000..7a7a25d2173 --- /dev/null +++ b/csrc/layernorm_quant_kernels.cu @@ -0,0 +1,357 @@ +#include +#include +#include + +#include "dispatch_utils.h" +#ifndef USE_ROCM + #include + #include + #include + #include +#else + #include + #include + #include + #include + +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; +#endif + +namespace vllm { + +// TODO(woosuk): Further optimize this kernel. +template +__global__ void rms_norm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float)input[blockIdx.x * hidden_size + idx]; + variance += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + out[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +/* Converter structs for the conversion from torch types to HIP/CUDA types, + and the associated type conversions within HIP/CUDA. These helpers need + to be implemented for now because the relevant type conversion + operators/constructors are not consistently implemented by HIP/CUDA, so + a generic conversion via type casts cannot be implemented. + + Each struct should have the member static constexpr bool `exists`: + If false, the optimized kernel is not used for the corresponding torch type. + If true, the struct should be fully defined as shown in the examples below. + */ +template +struct _typeConvert { + static constexpr bool exists = false; +}; + +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +// CUDA < 12.0 runs into issues with packed type conversion +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __half; + using packed_hip_type = __half2; + + __device__ static inline float convert(hip_type x) { return __half2float(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } +}; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// CUDA_ARCH < 800 does not have BF16 support +// TODO: Add in ROCm support once public headers handle bf16 maturely +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __nv_bfloat16; + using packed_hip_type = __nv_bfloat162; + + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } +}; + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) + +/* Vector POD struct to generate vectorized and packed FP16/BF16 ops + for appropriate specializations of fused_add_rms_norm_kernel. + Only functions that are necessary in that kernel are implemented. + Alignment to 16 bytes is required to use 128-bit global memory ops. + */ +template +struct alignas(16) _f16Vec { + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ + static_assert(width > 0 && (width & (width - 1)) == 0, + "Width is not a positive power of 2!"); + using Converter = _typeConvert; + using T1 = typename Converter::hip_type; + using T2 = typename Converter::packed_hip_type; + T1 data[width]; + + __device__ _f16Vec& operator+=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const float scale) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); + temp_f.x *= scale; + temp_f.y *= scale; + T2 temp = Converter::convert(temp_f); + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float temp = Converter::convert(data[i]) * scale; + data[i] = Converter::convert(temp); + } + } + return *this; + } + + __device__ float sum_squares() const { + float result = 0.0f; + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + result += z.x * z.x + z.y * z.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float x = Converter::convert(data[i]); + result += x * x; + } + } + return result; + } +}; + +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16Vec>); + static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); + + const int vec_hidden_size = hidden_size / width; + __shared__ float s_variance; + float variance = 0.0f; + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = input_v[id]; + temp += residual_v[id]; + variance += temp.sum_squares(); + residual_v[id] = temp; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = residual_v[id]; + temp *= s_variance; + temp *= weight_v[idx]; + input_v[id] = temp; + } +} + +/* Generic fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. + */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + scalar_t z = input[blockIdx.x * hidden_size + idx]; + z += residual[blockIdx.x * hidden_size + idx]; + float x = (float)z; + variance += x * x; + residual[blockIdx.x * hidden_size + idx] = z; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +} // namespace vllm + +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); +} + +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. + Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto res_ptr = reinterpret_cast(residual.data_ptr()); + auto wt_ptr = reinterpret_cast(weight.data_ptr()); + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { + LAUNCH_FUSED_ADD_RMS_NORM(0); + } +} From fc3fde62b843386e70388b23878297986806e3d5 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 8 Oct 2024 11:01:31 -0400 Subject: [PATCH 55/88] fixes --- vllm/compilation/backends.py | 2 +- vllm/compilation/fusion.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 580742c046f..6a07290b94b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -91,7 +91,7 @@ def fix_functionalization(graph: fx.Graph): kwargs = node.kwargs input = kwargs['input'] - out = kwargs['out'] + out = kwargs['result'] weight = kwargs['weight'] epsilon = kwargs['epsilon'] # Create a new call to torch.ops._C.rotary_embedding.default diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 79634a8d2cd..350c12803fe 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -36,9 +36,8 @@ def rms_pattern(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Ten at2 = auto_functionalized(torch.ops._C.dynamic_scaled_int8_quant.default, result=result, input=at1[1], scale=scale, azp=None) - # result, scale - # TODO azp - return at2[1:2] + # result, scale (multi-output not currently working) + return at2[1:3] def rms_replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -47,9 +46,8 @@ def rms_replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch input=input, weight=weight, epsilon=1e-6, scale=scale, azp=None) - # result, scale - # TODO azp - return at[1:2] + # result, scale (multi-output not currently working) + return at[1:3] # STATIC @@ -206,7 +204,7 @@ def find_getitem(node, idx): # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for node in match.nodes for match in matches) + assert all(node not in graph.nodes for node in (match for match in matches)) def get_fusion_pass(): @@ -214,6 +212,7 @@ def get_fusion_pass(): def dump_graph(graph: torch.fx.Graph, stage: str): if stage in envs.VLLM_TORCH_COMPILE_FUSION_DUMP: + logger.info("Printing graph to %s", f"{stage}.py") with open(f"{stage}.py", "w") as f: print(graph.python_code(root_module="self", verbose=True).src, file=f) From ef8e0f5ab332941ad3b3f9b680b923f23dafd798 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 8 Oct 2024 11:01:48 -0400 Subject: [PATCH 56/88] out -> result for fp8 quant ops --- csrc/torch_bindings.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 307d854875a..037775f025f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -317,18 +317,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute FP8 quantized tensor for given scaling factor. ops.def( - "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); + "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. ops.def( - "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) -> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // Compute dynamic-per-token FP8 quantized tensor and scaling factor. ops.def( - "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, " + "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " "Tensor! scale, Tensor? scale_ub) -> " "()"); ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, From c7d3d18c6eb8e029396fe765f4374c2b174ca215 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 8 Oct 2024 11:01:59 -0400 Subject: [PATCH 57/88] change int8 to fp8 --- vllm/compilation/fusion.py | 50 ++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 350c12803fe..245b3016593 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -13,18 +13,17 @@ # DYNAMIC -@torch.library.custom_op("vllm::fused_rms_norm_quant_dynamic", mutates_args=['result', 'scale', 'azp']) +@torch.library.custom_op("vllm::fused_rms_norm_quant_dynamic", mutates_args=['result', 'scale']) def fused_rms_norm_quant_dynamic(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - azp: torch.Tensor, epsilon: float) -> None: + epsilon: float) -> None: print("vllm::fused_rms_norm_quant_dynamic") result_rms = torch.empty_like(input) torch.ops._C.rms_norm(result_rms, input, weight, epsilon) - torch.ops._C.dynamic_scaled_int8_quant(result, result_rms, scale, azp) + torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale) @torch.library.register_fake("vllm::fused_rms_norm_quant_dynamic") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, azp: torch.Tensor, - epsilon: float) -> None: +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: return @@ -33,8 +32,7 @@ def rms_pattern(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Ten scale: torch.Tensor): at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight, epsilon=1e-6) - at2 = auto_functionalized(torch.ops._C.dynamic_scaled_int8_quant.default, result=result, input=at1[1], scale=scale, - azp=None) + at2 = auto_functionalized(torch.ops._C.dynamic_scaled_fp8_quant.default, result=result, input=at1[1], scale=scale) # result, scale (multi-output not currently working) return at2[1:3] @@ -44,7 +42,7 @@ def rms_replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch scale: torch.Tensor): at = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.fused_rms_norm_quant_dynamic.default, result=result, input=input, weight=weight, - epsilon=1e-6, scale=scale, azp=None) + epsilon=1e-6, scale=scale) # result, scale (multi-output not currently working) return at[1:3] @@ -53,16 +51,15 @@ def rms_replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch # STATIC @torch.library.custom_op("vllm::fused_rms_norm_quant_static", mutates_args=['result']) def fused_rms_norm_quant_static(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - azp: torch.Tensor, epsilon: float) -> None: + epsilon: float) -> None: print("vllm::fused_rms_norm_quant_static") result_rms = torch.empty_like(input) torch.ops._C.rms_norm(result_rms, input, weight, epsilon) - torch.ops._C.static_scaled_int8_quant(result, result_rms, scale, azp) + torch.ops._C.static_scaled_fp8_quant(result, result_rms, scale) @torch.library.register_fake("vllm::fused_rms_norm_quant_static") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, azp: torch.Tensor, - epsilon: float) -> None: +def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: return @@ -70,8 +67,7 @@ def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: to scale: torch.Tensor): at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight, epsilon=1e-5) - at2 = auto_functionalized(torch.ops._C.static_scaled_int8_quant.default, result=result, input=at1[1], scale=scale, - azp=None) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result=result, input=at1[1], scale=scale) # result return at2[1] @@ -80,8 +76,7 @@ def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: to def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at = auto_functionalized(torch.ops.vllm.fused_rms_norm_quant_static.default, result=result, input=input, - weight=weight, - epsilon=1e-5, scale=scale, azp=None) + weight=weight, epsilon=1e-5, scale=scale) # result return at[1] @@ -89,16 +84,15 @@ def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, input @torch.library.custom_op("vllm::fused_rms_norm_residual_quant_static", mutates_args=['result', 'input', 'residual']) def fused_rms_norm_residual_quant_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor, azp: torch.Tensor, - epsilon: float) -> None: + weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: # print("vllm::fused_rms_norm_residual_quant_static") torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) - torch.ops._C.static_scaled_int8_quant(result, input, scale, azp) + torch.ops._C.static_scaled_fp8_quant(result, input, scale) @torch.library.register_fake("vllm::fused_rms_norm_residual_quant_static") def _(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - azp: torch.Tensor, epsilon: float) -> None: + epsilon: float) -> None: return @@ -106,8 +100,7 @@ def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, resid scale: torch.Tensor): at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=input, residual=residual, weight=weight, epsilon=1e-5) - at1 = auto_functionalized(torch.ops._C.static_scaled_int8_quant.default, result=result, input=at[1], scale=scale, - azp=None) + at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result=result, input=at[1], scale=scale) # result, residual return at1[1], at[2] @@ -116,7 +109,7 @@ def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, resid def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at = auto_functionalized(torch.ops.vllm.fused_rms_norm_residual_quant_static.default, result=result, input=input, - residual=residual, weight=weight, epsilon=1e-5, scale=scale, azp=None) + residual=residual, weight=weight, epsilon=1e-5, scale=scale) # result, residual return at[1], at[3] @@ -128,14 +121,14 @@ def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") -def empty_int8(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.int8, device="cuda") +def empty_fp8(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float8_e4m3fn, device="cuda") def get_patterns(): my_patterns = PatternMatcherPass(pass_name="fusion_pass") - inputs = [empty_int8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] + inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] register_replacement(rms_pattern, rms_replacement, inputs, fwd_only, my_patterns) register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns) @@ -146,7 +139,7 @@ def record_match_fn(match: Match): return False # with residual - inputs = [empty_int8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] + inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns, extra_check=record_match_fn) @@ -162,7 +155,6 @@ def process_matches(matches, graph: torch.fx.Graph): last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): kwargs = match.kwargs - kwargs["azp"] = None kwargs["epsilon"] = 1e-5 fused_node = graph.call_function(auto_functionalized, @@ -187,7 +179,7 @@ def find_getitem(node, idx): return None rms_node = find_auto_fn(torch.ops._C.fused_add_rms_norm.default) - quant_node = find_auto_fn(torch.ops._C.static_scaled_int8_quant.default) + quant_node = find_auto_fn(torch.ops._C.static_scaled_fp8_quant.default) assert rms_node is not None assert quant_node is not None From 9d0bf7f072d87deb85e345c5ad59ab648ca69969 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 9 Oct 2024 01:58:32 +0000 Subject: [PATCH 58/88] Added layernorm_quant kernels for static fp8 quant, including tests. Fusion seems to work --- csrc/layernorm_quant_kernels.cu | 82 +++++++++----- csrc/ops.h | 7 ++ csrc/quantization/fp8/common.cu | 178 +---------------------------- csrc/quantization/fp8/common.cuh | 185 +++++++++++++++++++++++++++++++ csrc/torch_bindings.cpp | 13 +++ tests/kernels/test_layernorm.py | 80 +++++++++++-- vllm/compilation/fusion.py | 14 +-- 7 files changed, 343 insertions(+), 216 deletions(-) create mode 100644 csrc/quantization/fp8/common.cuh diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 7a7a25d2173..ea380e6211a 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -1,3 +1,5 @@ +#include "quantization/fp8/common.cuh" + #include #include #include @@ -22,10 +24,11 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. template -__global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] +__global__ void rms_norm_static_fp8_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; @@ -44,10 +47,14 @@ __global__ void rms_norm_kernel( } __syncthreads(); + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)input[blockIdx.x * hidden_size + idx]; + float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + scaled_fp8_conversion(out_norm, scale_inv); } } @@ -206,10 +213,12 @@ struct alignas(16) _f16Vec { memory latency bottleneck. */ template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> -fused_add_rms_norm_kernel( +fused_add_rms_norm_static_fp8_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] const float epsilon, const int num_tokens, const int hidden_size) { // Sanity checks on our vector struct and type-punned pointer arithmetic static_assert(std::is_pod_v<_f16Vec>); @@ -245,12 +254,19 @@ fused_add_rms_norm_kernel( } __syncthreads(); + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; _f16Vec temp = residual_v[id]; temp *= s_variance; temp *= weight_v[idx]; - input_v[id] = temp; +#pragma unroll + for (int i = 0; i < width; ++i) { + out[id * width + i] = + scaled_fp8_conversion(float(temp.data[i]), scale_inv); + } } } @@ -259,10 +275,12 @@ fused_add_rms_norm_kernel( */ template __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> -fused_add_rms_norm_kernel( +fused_add_rms_norm_static_fp8_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; @@ -284,19 +302,24 @@ fused_add_rms_norm_kernel( } __syncthreads(); + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; + float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; + out[blockIdx.x * hidden_size + idx] = + scaled_fp8_conversion(out_norm, scale_inv); } } } // namespace vllm -void rms_norm(torch::Tensor& out, // [..., hidden_size] - torch::Tensor& input, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - double epsilon) { +void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -305,26 +328,31 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, hidden_size); + vllm::rms_norm_static_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), scale.data_ptr(), epsilon, + num_tokens, hidden_size); }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>(input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), epsilon, \ - num_tokens, hidden_size); \ +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_static_fp8_quant_kernel \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + residual.data_ptr(), weight.data_ptr(), \ + scale.data_ptr(), epsilon, num_tokens, hidden_size); \ }); -void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] - torch::Tensor& residual, // [..., hidden_size] - torch::Tensor& weight, // [hidden_size] - double epsilon) { +void fused_add_rms_norm_static_fp8_quant( + torch::Tensor& out, // [..., hidden_size], + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/ops.h b/csrc/ops.h index fce545f95a7..21d2c37c203 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -32,6 +32,13 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& weight, torch::Tensor& scale, double epsilon); + +void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& residual, torch::Tensor& weight, + torch::Tensor& scale, double epsilon); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 7e23f922577..c9cadbd42f2 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,185 +1,13 @@ -#include -#include -#include - -#include +#include "common.cuh" #include "cuda_compat.h" #include "dispatch_utils.h" -#ifndef USE_ROCM - #include - #include -#else - #include - #include -#endif - -#ifndef USE_ROCM -using FP8_TYPE = c10::Float8_e4m3fn; -C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = - std::numeric_limits::max(); -#else - #include "amd/hip_float8.h" -using FP8_TYPE = c10::Float8_e4m3fnuz; -// Using the default max value from pytorch (240.0) will cause accuracy -// issue when running dynamic quantization. Here use 224.0f for rocm. -constexpr auto FP8_E4M3_MAX = 224.0f; -#endif +#include +#include namespace vllm { -__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { - float old; - old = (value >= 0) - ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) - : __uint_as_float( - atomicMin((unsigned int*)addr, __float_as_uint(value))); - - return old; -} - -template -__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, - float const scale) { - float x = 0.0f; - if constexpr (is_scale_inverted) { - x = val * scale; - } else { - x = val / scale; - } - - float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); -#ifndef USE_ROCM - return static_cast(r); -#else - // Use hardware cvt instruction for fp8 on rocm - return c10::Float8_e4m3fnuz(hip_fp8(r).data, - c10::Float8_e4m3fnuz::from_bits()); -#endif -} - -// Compute the absolute maximum m of the input tensor and store -// m / float8_e4m3::max() in *scale. Each thread block performs a -// reduction tree and the memory in scale is atomically updated. -// So to get the right answer, *scale needs to be initialized to -// a value <= 0.0 and we need to wait for all thread blocks to -// finish before consuming *scale. -template -__global__ void segmented_max_reduction(float* __restrict__ scale, - const scalar_t* __restrict__ input, - int64_t num_elems) { - __shared__ float cache[1024]; - int64_t i = blockDim.x * blockIdx.x + threadIdx.x; - - // First store maximum for all values processes by - // the current thread in cache[threadIdx.x] - scalar_t tmp = 0.0; - while (i < num_elems) { - float x = static_cast(input[i]); - tmp = max(tmp, fabs(x)); - i += blockDim.x * gridDim.x; - } - cache[threadIdx.x] = tmp; - - __syncthreads(); - - // Now perform parallel reduction within the thread block - int ib = blockDim.x / 2; - while (ib != 0) { - if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { - cache[threadIdx.x] = cache[threadIdx.x + ib]; - } - __syncthreads(); - ib /= 2; - } - // Finally, since cache[0] contains the maximum for this thread block, - // atomically write the max to the target location - if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX); - } -} - -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -typedef struct __align__(4) { - FP8_TYPE x; - FP8_TYPE y; - FP8_TYPE z; - FP8_TYPE w; -} -float8x4_t; - -template -__device__ float thread_max_vec(scalar_t const* __restrict__ input, - int64_t const num_elems, int const tid, - int const step) { - // Vectorized input/output to better utilize memory bandwidth. - vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - - int64_t const num_vec_elems = num_elems >> 2; - float absmax_val = 0.0f; - -#pragma unroll 4 - for (int64_t i = tid; i < num_vec_elems; i += step) { - vec4_t in_vec = vectorized_in[i]; - absmax_val = max(absmax_val, fabs(in_vec.x)); - absmax_val = max(absmax_val, fabs(in_vec.y)); - absmax_val = max(absmax_val, fabs(in_vec.z)); - absmax_val = max(absmax_val, fabs(in_vec.w)); - } - - // Handle the remaining elements if num_elems is not divisible by 4 - for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { - absmax_val = max(absmax_val, fabs(input[i])); - } - - return absmax_val; -} - -template -__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, - scalar_t const* __restrict__ input, - float const scale, - int64_t const num_elems, - int const tid, int const step) { - // Vectorized input/output to better utilize memory bandwidth. - vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - float8x4_t* vectorized_out = reinterpret_cast(out); - - int64_t const num_vec_elems = num_elems >> 2; - -#pragma unroll 4 - for (int64_t i = tid; i < num_vec_elems; i += step) { - vec4_t in_vec = vectorized_in[i]; - float8x4_t out_vec; - - out_vec.x = scaled_fp8_conversion( - static_cast(in_vec.x), scale); - out_vec.y = scaled_fp8_conversion( - static_cast(in_vec.y), scale); - out_vec.z = scaled_fp8_conversion( - static_cast(in_vec.z), scale); - out_vec.w = scaled_fp8_conversion( - static_cast(in_vec.w), scale); - vectorized_out[i] = out_vec; - } - - // Handle the remaining elements if num_elems is not divisible by 4 - for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { - out[i] = scaled_fp8_conversion( - static_cast(input[i]), scale); - } -} - template __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out, const scalar_t* __restrict__ input, diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh new file mode 100644 index 00000000000..00b1cc312de --- /dev/null +++ b/csrc/quantization/fp8/common.cuh @@ -0,0 +1,185 @@ +#pragma once + +// TODO(luka) remove unnecessary includes +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#include +#include +#include +#include + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#ifndef USE_ROCM +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = + std::numeric_limits::max(); +#else + #include "amd/hip_float8.h" +using FP8_TYPE = c10::Float8_e4m3fnuz; +// Using the default max value from pytorch (240.0) will cause accuracy +// issue when running dynamic quantization. Here use 224.0f for rocm. +constexpr auto FP8_E4M3_MAX = 224.0f; +#endif + +namespace vllm { + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); + + return old; +} + +template +__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, + float const scale) { + float x = 0.0f; + if constexpr (is_scale_inverted) { + x = val * scale; + } else { + x = val / scale; + } + + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); +#ifndef USE_ROCM + return static_cast(r); +#else + // Use hardware cvt instruction for fp8 on rocm + return c10::Float8_e4m3fnuz(hip_fp8(r).data, + c10::Float8_e4m3fnuz::from_bits()); +#endif +} + +// Compute the absolute maximum m of the input tensor and store +// m / float8_e4m3::max() in *scale. Each thread block performs a +// reduction tree and the memory in scale is atomically updated. +// So to get the right answer, *scale needs to be initialized to +// a value <= 0.0 and we need to wait for all thread blocks to +// finish before consuming *scale. +template +__global__ void segmented_max_reduction(float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { + __shared__ float cache[1024]; + int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + + // First store maximum for all values processes by + // the current thread in cache[threadIdx.x] + scalar_t tmp = 0.0; + while (i < num_elems) { + float x = static_cast(input[i]); + tmp = max(tmp, fabs(x)); + i += blockDim.x * gridDim.x; + } + cache[threadIdx.x] = tmp; + + __syncthreads(); + + // Now perform parallel reduction within the thread block + int ib = blockDim.x / 2; + while (ib != 0) { + if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { + cache[threadIdx.x] = cache[threadIdx.x + ib]; + } + __syncthreads(); + ib /= 2; + } + // Finally, since cache[0] contains the maximum for this thread block, + // atomically write the max to the target location + if (threadIdx.x == 0) { + atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX); + } +} + +template +struct __align__(8) vec4_t { + scalar_t x; + scalar_t y; + scalar_t z; + scalar_t w; +}; + +typedef struct __align__(4) { + FP8_TYPE x; + FP8_TYPE y; + FP8_TYPE z; + FP8_TYPE w; +} +float8x4_t; + +template +__device__ float thread_max_vec(scalar_t const* __restrict__ input, + int64_t const num_elems, int const tid, + int const step) { + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); + + int64_t const num_vec_elems = num_elems >> 2; + float absmax_val = 0.0f; + +#pragma unroll 4 + for (int64_t i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + absmax_val = max(absmax_val, fabs(in_vec.x)); + absmax_val = max(absmax_val, fabs(in_vec.y)); + absmax_val = max(absmax_val, fabs(in_vec.z)); + absmax_val = max(absmax_val, fabs(in_vec.w)); + } + + // Handle the remaining elements if num_elems is not divisible by 4 + for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + absmax_val = max(absmax_val, fabs(input[i])); + } + + return absmax_val; +} + +template +__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, + scalar_t const* __restrict__ input, + float const scale, + int64_t const num_elems, + int const tid, int const step) { + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); + float8x4_t* vectorized_out = reinterpret_cast(out); + + int64_t const num_vec_elems = num_elems >> 2; + +#pragma unroll 4 + for (int64_t i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + float8x4_t out_vec; + + out_vec.x = scaled_fp8_conversion( + static_cast(in_vec.x), scale); + out_vec.y = scaled_fp8_conversion( + static_cast(in_vec.y), scale); + out_vec.z = scaled_fp8_conversion( + static_cast(in_vec.z), scale); + out_vec.w = scaled_fp8_conversion( + static_cast(in_vec.w), scale); + vectorized_out[i] = out_vec; + } + + // Handle the remaining elements if num_elems is not divisible by 4 + for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + out[i] = scaled_fp8_conversion( + static_cast(input[i]), scale); + } +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 037775f025f..37dc0dd21fe 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -104,6 +104,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + // Layernorm-quant + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, Tensor scale, float epsilon) -> " + "()"); + ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, &rms_norm_static_fp8_quant); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor! residual, Tensor weight, " + "Tensor scale, float epsilon) -> ()"); + ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 382079d472e..a26db360b7d 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -1,13 +1,15 @@ import pytest import torch +from tests.kernels.quant_utils import FP8_DTYPE + from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, +HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] @@ -24,12 +26,12 @@ @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_rms_norm( - num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, - seed: int, - device: str, + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int, + device: str, ) -> None: seed_everything(seed) torch.set_default_device(device) @@ -59,3 +61,67 @@ def test_rms_norm( else: opcheck(torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon)) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("add_residual", ADD_RESIDUAL) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_fused_rms_norm_quant( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + quant_scale: float, + seed: int, + device: str, +) -> None: + seed_everything(seed) + torch.set_default_device(device) + + weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + if add_residual: + residual = torch.randn_like(x) * scale + residual_fused = residual.clone() + else: + residual = residual_fused = None + + out_norm = torch.empty_like(x) + out_quant = torch.empty_like(x, dtype=FP8_DTYPE) + out_quant_fused = torch.empty_like(out_quant) + + quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32) + + if add_residual: + torch.ops._C.fused_add_rms_norm_static_fp8_quant( + out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6) + + # Unfused kernel is in-place so it goes second + # Also use a separate clone of x to avoid modifying the input + x_unfused = x.clone() + torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) + torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, quant_scale_t) + + torch.cuda.synchronize() + torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2) + + opcheck(torch.ops._C.fused_add_rms_norm_static_fp8_quant, + (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) + else: + torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight, quant_scale_t, 1e-6) + + torch.ops._C.rms_norm(out_norm, x, weight, 1e-6) + torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t) + + opcheck(torch.ops._C.rms_norm_static_fp8_quant, + (out_quant_fused, x, weight, quant_scale_t, 1e-6)) + + torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32), + out_quant.to(dtype=torch.float32), + atol=1e-3, rtol=1e-3) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 245b3016593..bc7a56eb491 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -75,8 +75,8 @@ def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: to def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops.vllm.fused_rms_norm_quant_static.default, result=result, input=input, - weight=weight, epsilon=1e-5, scale=scale) + at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, result=result, input=input, + weight=weight, scale=scale, epsilon=1e-5) # result return at[1] @@ -108,10 +108,10 @@ def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, resid def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops.vllm.fused_rms_norm_residual_quant_static.default, result=result, input=input, - residual=residual, weight=weight, epsilon=1e-5, scale=scale) + at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, result=result, input=input, + residual=residual, weight=weight, scale=scale, epsilon=1e-5) # result, residual - return at[1], at[3] + return at[1], at[2] my_patterns = PatternMatcherPass() @@ -158,12 +158,12 @@ def process_matches(matches, graph: torch.fx.Graph): kwargs["epsilon"] = 1e-5 fused_node = graph.call_function(auto_functionalized, - (torch.ops.vllm.fused_rms_norm_residual_quant_static.default,), + (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,), kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - residual_node_new = graph.call_function(operator.getitem, (fused_node, 3)) + residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) # find the output and the residual def find_auto_fn(op): From d33c17990b01bc056c2aeb1e2b15e75b42d16788 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 9 Oct 2024 02:10:58 +0000 Subject: [PATCH 59/88] fix_functionalization for layernorm_quant kernels --- vllm/compilation/backends.py | 62 ++++++++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 6a07290b94b..77e91eb2d50 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -84,16 +84,41 @@ def fix_functionalization(graph: fx.Graph): user.replace_all_uses_with(replace_node) nodes_to_remove.append(user) nodes_to_remove.append(node) + elif node.args[0] == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: + # manual replace for fused_add_rms_norm_static_fp8_quant + # this is the most effective optimization for llama + # failing to do this will result in many unnecessary copies + + kwargs = node.kwargs + + result = kwargs['result'] + residual = kwargs['residual'] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, kwargs=kwargs) + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + if user.args[1] == 1: + replace_node = result + elif user.args[1] == 2: + replace_node = residual + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) elif node.args[0] == torch.ops._C.rms_norm.default: # manual replace for rms_norm kwargs = node.kwargs - input = kwargs['input'] - out = kwargs['result'] - weight = kwargs['weight'] - epsilon = kwargs['epsilon'] + replace_node = kwargs['result'] # Create a new call to torch.ops._C.rotary_embedding.default # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa with graph.inserting_before(node): @@ -101,11 +126,28 @@ def fix_functionalization(graph: fx.Graph): # NOTE: don't run dead code elimination, # otherwise this op will be removed graph.call_function( - torch.ops._C.rms_norm.default, - args=(out, input, weight, epsilon), - ) + torch.ops._C.rms_norm.default, kwargs=kwargs) - replace_node = out + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.rms_norm_static_fp8_quant.default: + # manual replace for rms_norm + + kwargs = node.kwargs + + replace_node = kwargs['result'] + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.rms_norm_static_fp8_quant.default, kwargs=kwargs) for user in list(node.users): if user.op == 'call_function' and user.target == operator.getitem: # noqa @@ -144,8 +186,8 @@ def fix_functionalization(graph: fx.Graph): graph.erase_node(node) # debug code, if we want to see the graph after the transformation - # with open("after.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) + with open("after.py", "w") as f: + print(graph.python_code(root_module="self", verbose=True).src, file=f) from vllm.compilation.fusion import get_fusion_pass From f7ac7efc310a90735f9e0713d6dc0cf92a4f1e1c Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 9 Oct 2024 16:18:41 +0000 Subject: [PATCH 60/88] env var for disabling fusion --- vllm/compilation/fusion.py | 10 +++++++++- vllm/envs.py | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index bc7a56eb491..dd6e9f2b83a 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -151,7 +151,8 @@ def process_matches(matches, graph: torch.fx.Graph): nodes = list(graph.nodes) # TODO this is an expensive check if not all(node in nodes for node in match.nodes): - raise ValueError(f"Broken match: not all nodes in graph: {[node for node in match.nodes if node not in nodes]}") + raise ValueError( + f"Broken match: not all nodes in graph: {[node for node in match.nodes if node not in nodes]}") last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): kwargs = match.kwargs @@ -199,7 +200,14 @@ def find_getitem(node, idx): assert all(node not in graph.nodes for node in (match for match in matches)) +def noop_pass(graph: torch.fx.Graph): + pass + + def get_fusion_pass(): + if not envs.VLLM_TORCH_COMPILE_FUSION: + return noop_pass + patterns, matches = get_patterns() def dump_graph(graph: torch.fx.Graph, stage: str): diff --git a/vllm/envs.py b/vllm/envs.py index 16ca1e653cd..bafbdc73373 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -16,6 +16,7 @@ VLLM_DYNAMO_USE_CUSTOM_DISPATCHER: bool = True VLLM_TEST_COMPILE_NO_CUSTOM_OPS: int = 0 VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE: bool = True + VLLM_TORCH_COMPILE_FUSION: bool = True VLLM_TORCH_COMPILE_FUSION_DUMP: List[str] = [] LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -221,6 +222,10 @@ def get_default_config_root(): lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + # Internal flag to enable fusion in torch.compile + "VLLM_TORCH_COMPILE_FUSION": lambda: bool( + os.environ.get("VLLM_TORCH_COMPILE_FUSION", "1") != "0"), + # Internal flag for dumping the model graph before and after fusion "VLLM_TORCH_COMPILE_FUSION_DUMP": lambda: list( From 36e8938c0a4f81ca846a71e0a28b838b46eb1574 Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 9 Oct 2024 18:36:48 +0000 Subject: [PATCH 61/88] Fix for fusion assert --- vllm/compilation/fusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index dd6e9f2b83a..7374dfa26cc 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -197,7 +197,7 @@ def find_getitem(node, idx): # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for node in (match for match in matches)) + assert all(node not in graph.nodes for match in matches for node in match.nodes) def noop_pass(graph: torch.fx.Graph): From 733d9f4fa28b5520b987df13dbe77b0b1d15f2a5 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 10 Oct 2024 15:55:13 +0000 Subject: [PATCH 62/88] Clean up fusion.py --- vllm/compilation/fusion.py | 68 -------------------------------------- 1 file changed, 68 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 7374dfa26cc..2cf593e7186 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -9,59 +9,6 @@ from vllm.logger import init_logger logger = init_logger(__name__) -logger.setLevel(logging.DEBUG) # TODO - - -# DYNAMIC -@torch.library.custom_op("vllm::fused_rms_norm_quant_dynamic", mutates_args=['result', 'scale']) -def fused_rms_norm_quant_dynamic(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float) -> None: - print("vllm::fused_rms_norm_quant_dynamic") - result_rms = torch.empty_like(input) - torch.ops._C.rms_norm(result_rms, input, weight, epsilon) - torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale) - - -@torch.library.register_fake("vllm::fused_rms_norm_quant_dynamic") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: - return - - -# TODO epsilon -def rms_pattern(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight, - epsilon=1e-6) - at2 = auto_functionalized(torch.ops._C.dynamic_scaled_fp8_quant.default, result=result, input=at1[1], scale=scale) - - # result, scale (multi-output not currently working) - return at2[1:3] - - -def rms_replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.fused_rms_norm_quant_dynamic.default, result=result, - input=input, weight=weight, - epsilon=1e-6, scale=scale) - - # result, scale (multi-output not currently working) - return at[1:3] - - -# STATIC -@torch.library.custom_op("vllm::fused_rms_norm_quant_static", mutates_args=['result']) -def fused_rms_norm_quant_static(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float) -> None: - print("vllm::fused_rms_norm_quant_static") - result_rms = torch.empty_like(input) - torch.ops._C.rms_norm(result_rms, input, weight, epsilon) - torch.ops._C.static_scaled_fp8_quant(result, result_rms, scale) - - -@torch.library.register_fake("vllm::fused_rms_norm_quant_static") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: - return - def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): @@ -82,20 +29,6 @@ def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, input return at[1] -@torch.library.custom_op("vllm::fused_rms_norm_residual_quant_static", mutates_args=['result', 'input', 'residual']) -def fused_rms_norm_residual_quant_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor, epsilon: float) -> None: - # print("vllm::fused_rms_norm_residual_quant_static") - torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) - torch.ops._C.static_scaled_fp8_quant(result, input, scale) - - -@torch.library.register_fake("vllm::fused_rms_norm_residual_quant_static") -def _(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float) -> None: - return - - def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=input, residual=residual, weight=weight, @@ -129,7 +62,6 @@ def get_patterns(): my_patterns = PatternMatcherPass(pass_name="fusion_pass") inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] - register_replacement(rms_pattern, rms_replacement, inputs, fwd_only, my_patterns) register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns) matches = [] From d1f8ae8149050c86699f407869b0079a1ecb2629 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 10 Oct 2024 10:09:32 -0700 Subject: [PATCH 63/88] add supports_dynamo in the decorator --- vllm/compilation/decorators.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 91be4f0ee1e..b790e5550ad 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -7,6 +7,7 @@ from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.sequence import IntermediateTensors +from vllm.utils import supports_dynamo def support_compile_llama_style(cls: type): @@ -20,7 +21,7 @@ def support_compile_llama_style(cls: type): # will handle the compilation, so we don't need to do anything here. if envs.VLLM_TORCH_COMPILE_LEVEL in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ]: + ] or not supports_dynamo(): return cls # take care of method resolution order From 5073da79183ba75bf8973a77e851d9c7df502fe6 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 11 Oct 2024 01:47:44 +0000 Subject: [PATCH 64/88] fix example_inputs dtype --- vllm/compilation/fusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 2cf593e7186..e375d200d62 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -61,7 +61,7 @@ def empty_fp8(*args, **kwargs): def get_patterns(): my_patterns = PatternMatcherPass(pass_name="fusion_pass") - inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] + inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda", dtype=torch.float32)] register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns) matches = [] @@ -71,7 +71,7 @@ def record_match_fn(match: Match): return False # with residual - inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda")] + inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda", dtype=torch.float32)] register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns, extra_check=record_match_fn) From 71379be39ce1cffff3482566930bff623c223c57 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 17 Oct 2024 21:14:37 +0000 Subject: [PATCH 65/88] extract common type conversion stuff to .cuh file --- csrc/layernorm_kernels.cu | 162 +------------------------------ csrc/layernorm_quant_kernels.cu | 160 +----------------------------- csrc/type_convert.cuh | 167 ++++++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 316 deletions(-) create mode 100644 csrc/type_convert.cuh diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 7a7a25d2173..8b521815735 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,21 +1,16 @@ +#include "type_convert.cuh" +#include "quantization/fp8/common.cuh" + #include #include #include #include "dispatch_utils.h" + #ifndef USE_ROCM - #include - #include - #include #include #else - #include - #include - #include #include - -using __nv_bfloat16 = __hip_bfloat16; -using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { @@ -51,155 +46,6 @@ __global__ void rms_norm_kernel( } } -/* Converter structs for the conversion from torch types to HIP/CUDA types, - and the associated type conversions within HIP/CUDA. These helpers need - to be implemented for now because the relevant type conversion - operators/constructors are not consistently implemented by HIP/CUDA, so - a generic conversion via type casts cannot be implemented. - - Each struct should have the member static constexpr bool `exists`: - If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. - */ -template -struct _typeConvert { - static constexpr bool exists = false; -}; - -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) -// CUDA < 12.0 runs into issues with packed type conversion -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __half; - using packed_hip_type = __half2; - - __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { - return __half22float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2half_rn(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22half2_rn(x); - } -}; - - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// CUDA_ARCH < 800 does not have BF16 support -// TODO: Add in ROCm support once public headers handle bf16 maturely -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __nv_bfloat16; - using packed_hip_type = __nv_bfloat162; - - __device__ static inline float convert(hip_type x) { - return __bfloat162float(x); - } - __device__ static inline float2 convert(packed_hip_type x) { - return __bfloat1622float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2bfloat16(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22bfloat162_rn(x); - } -}; - #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= - // 12000)) - -/* Vector POD struct to generate vectorized and packed FP16/BF16 ops - for appropriate specializations of fused_add_rms_norm_kernel. - Only functions that are necessary in that kernel are implemented. - Alignment to 16 bytes is required to use 128-bit global memory ops. - */ -template -struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ - static_assert(width > 0 && (width & (width - 1)) == 0, - "Width is not a positive power of 2!"); - using Converter = _typeConvert; - using T1 = typename Converter::hip_type; - using T2 = typename Converter::packed_hip_type; - T1 data[width]; - - __device__ _f16Vec& operator+=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] += other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] *= other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const float scale) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); - temp_f.x *= scale; - temp_f.y *= scale; - T2 temp = Converter::convert(temp_f); - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float temp = Converter::convert(data[i]) * scale; - data[i] = Converter::convert(temp); - } - } - return *this; - } - - __device__ float sum_squares() const { - float result = 0.0f; - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i + 1]}); - result += z.x * z.x + z.y * z.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float x = Converter::convert(data[i]); - result += x * x; - } - } - return result; - } -}; - /* Function specialization in the case of FP16/BF16 tensors. Additional optimizations we can make in this case are packed and vectorized operations, which help with the diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index ea380e6211a..35c9232dcd9 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -1,3 +1,4 @@ +#include "type_convert.cuh" #include "quantization/fp8/common.cuh" #include @@ -5,19 +6,11 @@ #include #include "dispatch_utils.h" + #ifndef USE_ROCM - #include - #include - #include #include #else - #include - #include - #include #include - -using __nv_bfloat16 = __hip_bfloat16; -using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { @@ -58,155 +51,6 @@ __global__ void rms_norm_static_fp8_quant_kernel( } } -/* Converter structs for the conversion from torch types to HIP/CUDA types, - and the associated type conversions within HIP/CUDA. These helpers need - to be implemented for now because the relevant type conversion - operators/constructors are not consistently implemented by HIP/CUDA, so - a generic conversion via type casts cannot be implemented. - - Each struct should have the member static constexpr bool `exists`: - If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. - */ -template -struct _typeConvert { - static constexpr bool exists = false; -}; - -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) -// CUDA < 12.0 runs into issues with packed type conversion -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __half; - using packed_hip_type = __half2; - - __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { - return __half22float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2half_rn(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22half2_rn(x); - } -}; - - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// CUDA_ARCH < 800 does not have BF16 support -// TODO: Add in ROCm support once public headers handle bf16 maturely -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __nv_bfloat16; - using packed_hip_type = __nv_bfloat162; - - __device__ static inline float convert(hip_type x) { - return __bfloat162float(x); - } - __device__ static inline float2 convert(packed_hip_type x) { - return __bfloat1622float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2bfloat16(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22bfloat162_rn(x); - } -}; - #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= - // 12000)) - -/* Vector POD struct to generate vectorized and packed FP16/BF16 ops - for appropriate specializations of fused_add_rms_norm_kernel. - Only functions that are necessary in that kernel are implemented. - Alignment to 16 bytes is required to use 128-bit global memory ops. - */ -template -struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ - static_assert(width > 0 && (width & (width - 1)) == 0, - "Width is not a positive power of 2!"); - using Converter = _typeConvert; - using T1 = typename Converter::hip_type; - using T2 = typename Converter::packed_hip_type; - T1 data[width]; - - __device__ _f16Vec& operator+=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] += other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] *= other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const float scale) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); - temp_f.x *= scale; - temp_f.y *= scale; - T2 temp = Converter::convert(temp_f); - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float temp = Converter::convert(data[i]) * scale; - data[i] = Converter::convert(temp); - } - } - return *this; - } - - __device__ float sum_squares() const { - float result = 0.0f; - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i + 1]}); - result += z.x * z.x + z.y * z.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float x = Converter::convert(data[i]); - result += x * x; - } - } - return result; - } -}; - /* Function specialization in the case of FP16/BF16 tensors. Additional optimizations we can make in this case are packed and vectorized operations, which help with the diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh new file mode 100644 index 00000000000..8840bb9ddca --- /dev/null +++ b/csrc/type_convert.cuh @@ -0,0 +1,167 @@ +#pragma once + +#include + +#ifndef USE_ROCM + #include + #include +// #include +#else + #include + #include + #include + +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; +#endif + +namespace vllm { +/* Converter structs for the conversion from torch types to HIP/CUDA types, + and the associated type conversions within HIP/CUDA. These helpers need + to be implemented for now because the relevant type conversion + operators/constructors are not consistently implemented by HIP/CUDA, so + a generic conversion via type casts cannot be implemented. + + Each struct should have the member static constexpr bool `exists`: + If false, the optimized kernel is not used for the corresponding torch type. + If true, the struct should be fully defined as shown in the examples below. + */ +template +struct _typeConvert { + static constexpr bool exists = false; +}; + +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +// CUDA < 12.0 runs into issues with packed type conversion +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __half; + using packed_hip_type = __half2; + + __device__ static inline float convert(hip_type x) { return __half2float(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } +}; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// CUDA_ARCH < 800 does not have BF16 support +// TODO: Add in ROCm support once public headers handle bf16 maturely +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __nv_bfloat16; + using packed_hip_type = __nv_bfloat162; + + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } +}; + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) + +/* Vector POD struct to generate vectorized and packed FP16/BF16 ops + for appropriate specializations of fused_add_rms_norm_kernel. + Only functions that are necessary in that kernel are implemented. + Alignment to 16 bytes is required to use 128-bit global memory ops. + */ +template +struct alignas(16) _f16Vec { + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ + static_assert(width > 0 && (width & (width - 1)) == 0, + "Width is not a positive power of 2!"); + using Converter = _typeConvert; + using T1 = typename Converter::hip_type; + using T2 = typename Converter::packed_hip_type; + T1 data[width]; + + __device__ _f16Vec& operator+=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const float scale) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); + temp_f.x *= scale; + temp_f.y *= scale; + T2 temp = Converter::convert(temp_f); + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float temp = Converter::convert(data[i]) * scale; + data[i] = Converter::convert(temp); + } + } + return *this; + } + + __device__ float sum_squares() const { + float result = 0.0f; + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + result += z.x * z.x + z.y * z.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float x = Converter::convert(data[i]); + result += x * x; + } + } + return result; + } +}; +} // namespace vllm \ No newline at end of file From f33d59b7e2d70179a1c3196702ea6455afc3b3aa Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 17 Oct 2024 21:46:18 +0000 Subject: [PATCH 66/88] format --- csrc/ops.h | 13 ++-- csrc/quantization/fp8/common.cuh | 2 +- csrc/torch_bindings.cpp | 18 +++-- tests/kernels/test_layernorm.py | 51 +++++++------ vllm/compilation/backends.py | 22 +++--- vllm/compilation/fusion.py | 121 +++++++++++++++++++++++-------- vllm/envs.py | 9 +-- 7 files changed, 156 insertions(+), 80 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 8522fa51f29..f5ba79969dd 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,11 +33,14 @@ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& weight, torch::Tensor& scale, double epsilon); - -void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, - torch::Tensor& residual, torch::Tensor& weight, - torch::Tensor& scale, double epsilon); + torch::Tensor& weight, torch::Tensor& scale, + double epsilon); + +void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, + torch::Tensor& scale, double epsilon); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 00b1cc312de..2c67e23ea23 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -182,4 +182,4 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, } } -} // namespace vllm \ No newline at end of file +} // namespace vllm \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 2ce52c2a689..a8a6158a31d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -107,15 +107,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Layernorm-quant // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( - "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, Tensor scale, float epsilon) -> " + "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " + "Tensor scale, float epsilon) -> " "()"); - ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, &rms_norm_static_fp8_quant); + ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, + &rms_norm_static_fp8_quant); // In-place fused Add and RMS Normalization. ops.def( - "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor! residual, Tensor weight, " + "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " + "Tensor! residual, Tensor weight, " "Tensor scale, float epsilon) -> ()"); - ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant); + ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, + &fused_add_rms_norm_static_fp8_quant); // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. @@ -328,12 +332,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute FP8 quantized tensor for given scaling factor. ops.def( - "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); + "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " + "()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. ops.def( - "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) -> " + "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " + "-> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index a26db360b7d..750ccfa29d0 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -2,7 +2,6 @@ import torch from tests.kernels.quant_utils import FP8_DTYPE - from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm from vllm.utils import seed_everything @@ -26,12 +25,12 @@ @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_rms_norm( - num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, - seed: int, - device: str, + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int, + device: str, ) -> None: seed_everything(seed) torch.set_default_device(device) @@ -71,13 +70,13 @@ def test_rms_norm( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_fused_rms_norm_quant( - num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, - quant_scale: float, - seed: int, - device: str, + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + quant_scale: float, + seed: int, + device: str, ) -> None: seed_everything(seed) torch.set_default_device(device) @@ -106,22 +105,30 @@ def test_fused_rms_norm_quant( # Also use a separate clone of x to avoid modifying the input x_unfused = x.clone() torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, quant_scale_t) + torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, + quant_scale_t) torch.cuda.synchronize() - torch.testing.assert_close(residual_fused, residual, atol=1e-2, rtol=1e-2) - - opcheck(torch.ops._C.fused_add_rms_norm_static_fp8_quant, - (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) + torch.testing.assert_close(residual_fused, + residual, + atol=1e-2, + rtol=1e-2) + + opcheck( + torch.ops._C.fused_add_rms_norm_static_fp8_quant, + (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) else: - torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight, quant_scale_t, 1e-6) + torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight, + quant_scale_t, 1e-6) torch.ops._C.rms_norm(out_norm, x, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, quant_scale_t) + torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, + quant_scale_t) opcheck(torch.ops._C.rms_norm_static_fp8_quant, (out_quant_fused, x, weight, quant_scale_t, 1e-6)) torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32), out_quant.to(dtype=torch.float32), - atol=1e-3, rtol=1e-3) + atol=1e-3, + rtol=1e-3) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 6bd8bf59cc4..11737e1cedc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -5,6 +5,7 @@ import torch import torch.fx as fx +from vllm.compilation.fusion import get_fusion_pass from vllm.logger import init_logger from .compile_context import get_compile_context @@ -93,7 +94,8 @@ def fix_functionalization(graph: fx.Graph): user.replace_all_uses_with(replace_node) nodes_to_remove.append(user) nodes_to_remove.append(node) - elif node.args[0] == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: + elif (node.args[0] == + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default): # manual replace for fused_add_rms_norm_static_fp8_quant # this is the most effective optimization for llama # failing to do this will result in many unnecessary copies @@ -109,7 +111,9 @@ def fix_functionalization(graph: fx.Graph): # NOTE: don't run dead code elimination, # otherwise this op will be removed graph.call_function( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, kwargs=kwargs) + torch.ops._C.fused_add_rms_norm_static_fp8_quant. + default, + kwargs=kwargs) for user in list(node.users): if user.op == 'call_function' and user.target == operator.getitem: # noqa @@ -134,8 +138,8 @@ def fix_functionalization(graph: fx.Graph): # just insert the call to the custom op # NOTE: don't run dead code elimination, # otherwise this op will be removed - graph.call_function( - torch.ops._C.rms_norm.default, kwargs=kwargs) + graph.call_function(torch.ops._C.rms_norm.default, + kwargs=kwargs) for user in list(node.users): if user.op == 'call_function' and user.target == operator.getitem: # noqa @@ -143,7 +147,8 @@ def fix_functionalization(graph: fx.Graph): nodes_to_remove.append(user) nodes_to_remove.append(node) - elif node.args[0] == torch.ops._C.rms_norm_static_fp8_quant.default: + elif node.args[ + 0] == torch.ops._C.rms_norm_static_fp8_quant.default: # manual replace for rms_norm kwargs = node.kwargs @@ -156,7 +161,8 @@ def fix_functionalization(graph: fx.Graph): # NOTE: don't run dead code elimination, # otherwise this op will be removed graph.call_function( - torch.ops._C.rms_norm_static_fp8_quant.default, kwargs=kwargs) + torch.ops._C.rms_norm_static_fp8_quant.default, + kwargs=kwargs) for user in list(node.users): if user.op == 'call_function' and user.target == operator.getitem: # noqa @@ -199,8 +205,6 @@ def fix_functionalization(graph: fx.Graph): # print(graph.python_code(root_module="self", verbose=True).src, file=f) -from vllm.compilation.fusion import get_fusion_pass - fusion_pass = get_fusion_pass() @@ -219,7 +223,7 @@ def wrap_inductor(graph, example_inputs, additional_inductor_config): if current_config['post_grad_custom_post_pass'] is not None: logger.warning( "post_grad_custom_post_pass is already set in the config. " - "Overwriting it with the post_grad_post_passes") # TODO combine + "Overwriting it with the post_grad_post_passes") # TODO combine current_config['post_grad_custom_post_pass'] = post_grad_post_passes return compile_fx(graph, example_inputs, config_patches=current_config) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e375d200d62..99cb289bf1c 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,48 +1,75 @@ -import logging import operator import torch -from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only, Match from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, + fwd_only, register_replacement) from vllm import envs from vllm.logger import init_logger logger = init_logger(__name__) -def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + +def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.rms_norm.default, result=result_rms, input=input, weight=weight, + at1 = auto_functionalized(torch.ops._C.rms_norm.default, + result=result_rms, + input=input, + weight=weight, epsilon=1e-5) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result=result, input=at1[1], scale=scale) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) # result return at2[1] -def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, +def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, result=result, input=input, - weight=weight, scale=scale, epsilon=1e-5) + at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=1e-5) # result return at[1] -def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, +def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=input, residual=residual, weight=weight, + at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, + input=input, + residual=residual, + weight=weight, epsilon=1e-5) - at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result=result, input=at[1], scale=scale) + at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at[1], + scale=scale) # result, residual return at1[1], at[2] -def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, +def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, result=result, input=input, - residual=residual, weight=weight, scale=scale, epsilon=1e-5) + at = auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=1e-5) # result, residual return at[1], at[2] @@ -55,14 +82,24 @@ def empty_bf16(*args, **kwargs): def empty_fp8(*args, **kwargs): - return torch.empty(*args, **kwargs, dtype=torch.float8_e4m3fn, device="cuda") + return torch.empty(*args, + **kwargs, + dtype=torch.float8_e4m3fn, + device="cuda") def get_patterns(): my_patterns = PatternMatcherPass(pass_name="fusion_pass") - inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda", dtype=torch.float32)] - register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, my_patterns) + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_bf16(1, 5), + torch.empty(1, 1, device="cuda", dtype=torch.float32) + ] + register_replacement(rms_pattern_static, rms_replacement_static, inputs, + fwd_only, my_patterns) matches = [] @@ -71,8 +108,18 @@ def record_match_fn(match: Match): return False # with residual - inputs = [empty_fp8(5, 4), empty_bf16(5, 4), empty_bf16(5, 4), empty_bf16(1, 5), torch.empty(1, 1, device="cuda", dtype=torch.float32)] - register_replacement(rms_pattern_residual_static, rms_replacement_residual_static, inputs, fwd_only, my_patterns, + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_bf16(1, 5), + torch.empty(1, 1, device="cuda", dtype=torch.float32) + ] + register_replacement(rms_pattern_residual_static, + rms_replacement_residual_static, + inputs, + fwd_only, + my_patterns, extra_check=record_match_fn) return my_patterns, matches @@ -84,30 +131,38 @@ def process_matches(matches, graph: torch.fx.Graph): # TODO this is an expensive check if not all(node in nodes for node in match.nodes): raise ValueError( - f"Broken match: not all nodes in graph: {[node for node in match.nodes if node not in nodes]}") + f"Broken match: not all nodes in graph: " + f"{[node for node in match.nodes if node not in nodes]}") last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): kwargs = match.kwargs kwargs["epsilon"] = 1e-5 - fused_node = graph.call_function(auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,), - kwargs=kwargs) + fused_node = graph.call_function( + auto_functionalized, + (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ), + kwargs=kwargs) graph.inserting_after(fused_node) - result_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) + result_node_new = graph.call_function(operator.getitem, + (fused_node, 1)) + residual_node_new = graph.call_function(operator.getitem, + (fused_node, 2)) + + def is_func(node, target): + return node.op == "call_function" and node.target == target # find the output and the residual - def find_auto_fn(op): + def find_auto_fn(op, + match=match): # need to bind match to silence lint for node in match.nodes: - if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: + if is_func(node, auto_functionalized) and node.args[0] == op: return node return None def find_getitem(node, idx): for user in node.users: - if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: + if is_func(node, operator.getitem) and user.args[1] == idx: return user return None @@ -129,7 +184,8 @@ def find_getitem(node, idx): # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches for node in match.nodes) + assert all(node not in graph.nodes for match in matches + for node in match.nodes) def noop_pass(graph: torch.fx.Graph): @@ -146,7 +202,8 @@ def dump_graph(graph: torch.fx.Graph, stage: str): if stage in envs.VLLM_TORCH_COMPILE_FUSION_DUMP: logger.info("Printing graph to %s", f"{stage}.py") with open(f"{stage}.py", "w") as f: - print(graph.python_code(root_module="self", verbose=True).src, file=f) + print(graph.python_code(root_module="self", verbose=True).src, + file=f) def fusion_pass(graph: torch.fx.Graph): """ @@ -156,12 +213,12 @@ def fusion_pass(graph: torch.fx.Graph): dump_graph(graph, "before_fusion") count = patterns.apply(graph) - logger.info(f"Replaced {count} patterns") + logger.info("Replaced %s patterns", count) dump_graph(graph, "after_pattern_match") # Manually process multi-output matches (and run DCE) process_matches(matches, graph) - logger.info(f"Post-processed {len(matches)} matches") + logger.info("Post-processed %s matches", len(matches)) dump_graph(graph, "after_fusion") return fusion_pass diff --git a/vllm/envs.py b/vllm/envs.py index 44c34ac78b4..9d1682061e4 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -221,14 +221,13 @@ def get_default_config_root(): lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), # Internal flag to enable fusion in torch.compile - "VLLM_TORCH_COMPILE_FUSION": lambda: bool( - os.environ.get("VLLM_TORCH_COMPILE_FUSION", "1") != "0"), + "VLLM_TORCH_COMPILE_FUSION": + lambda: bool(os.environ.get("VLLM_TORCH_COMPILE_FUSION", "1") != "0"), # Internal flag for dumping the model graph before and after fusion "VLLM_TORCH_COMPILE_FUSION_DUMP": - lambda: list( - os.environ.get("VLLM_TORCH_COMPILE_FUSION_DUMP", "").split(",")), - + lambda: list( + os.environ.get("VLLM_TORCH_COMPILE_FUSION_DUMP", "").split(",")), # local rank of the process in the distributed setting, used to determine # the GPU device id From b053c0bbd3b502412868b096f0f2ed7e8cb8d92b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Mon, 21 Oct 2024 11:40:08 -0400 Subject: [PATCH 67/88] PR comments Remove expensive check --- vllm/compilation/fusion.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 99cb289bf1c..46b14011925 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -128,11 +128,6 @@ def record_match_fn(match: Match): def process_matches(matches, graph: torch.fx.Graph): for match in matches: nodes = list(graph.nodes) - # TODO this is an expensive check - if not all(node in nodes for node in match.nodes): - raise ValueError( - f"Broken match: not all nodes in graph: " - f"{[node for node in match.nodes if node not in nodes]}") last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): kwargs = match.kwargs From 0de6baae80338af755930e0fbf9d713348e29fe3 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 22 Oct 2024 17:52:29 +0000 Subject: [PATCH 68/88] refactored fusion pass into a class --- vllm/compilation/backends.py | 4 +- vllm/compilation/fusion.py | 250 ++++++++++++++---------------- vllm/compilation/inductor_pass.py | 24 +++ vllm/envs.py | 12 +- 4 files changed, 149 insertions(+), 141 deletions(-) create mode 100644 vllm/compilation/inductor_pass.py diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 11737e1cedc..8f809f14e27 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -5,7 +5,7 @@ import torch import torch.fx as fx -from vllm.compilation.fusion import get_fusion_pass +from vllm.compilation.fusion import FusionPass from vllm.logger import init_logger from .compile_context import get_compile_context @@ -205,7 +205,7 @@ def fix_functionalization(graph: fx.Graph): # print(graph.python_code(root_module="self", verbose=True).src, file=f) -fusion_pass = get_fusion_pass() +fusion_pass = FusionPass() def post_grad_post_passes(graph: fx.Graph): diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 46b14011925..f728d5c1bb8 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,6 +6,7 @@ fwd_only, register_replacement) from vllm import envs +from vllm.compilation.inductor_pass import InductorPass from vllm.logger import init_logger logger = init_logger(__name__) @@ -74,146 +75,129 @@ def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, return at[1], at[2] -my_patterns = PatternMatcherPass() - - def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") def empty_fp8(*args, **kwargs): - return torch.empty(*args, - **kwargs, - dtype=torch.float8_e4m3fn, - device="cuda") - - -def get_patterns(): - my_patterns = PatternMatcherPass(pass_name="fusion_pass") - - inputs = [ - empty_fp8(5, 4), - empty_bf16(5, 4), - empty_bf16(5, 4), - empty_bf16(1, 5), - torch.empty(1, 1, device="cuda", dtype=torch.float32) - ] - register_replacement(rms_pattern_static, rms_replacement_static, inputs, - fwd_only, my_patterns) - - matches = [] - - def record_match_fn(match: Match): - matches.append(match) + fp8 = torch.float8_e4m3fn + return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +class FusionPass(InductorPass): + + def __init__(self): + self.my_patterns = PatternMatcherPass(pass_name="fusion_pass") + self.matches = [] + + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_bf16(1, 5), + empty_fp32(1, 1) + ] + register_replacement(rms_pattern_static, rms_replacement_static, + inputs, fwd_only, self.my_patterns) + + # with residual + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_bf16(1, 5), + empty_fp32(1, 1) + ] + register_replacement(rms_pattern_residual_static, + rms_replacement_residual_static, + inputs, + fwd_only, + self.my_patterns, + extra_check=lambda m: self.record_match(m)) + + def record_match(self, match: Match) -> bool: + # TODO(luka): add better comment + self.matches.append(match) return False - # with residual - inputs = [ - empty_fp8(5, 4), - empty_bf16(5, 4), - empty_bf16(5, 4), - empty_bf16(1, 5), - torch.empty(1, 1, device="cuda", dtype=torch.float32) - ] - register_replacement(rms_pattern_residual_static, - rms_replacement_residual_static, - inputs, - fwd_only, - my_patterns, - extra_check=record_match_fn) - - return my_patterns, matches - - -def process_matches(matches, graph: torch.fx.Graph): - for match in matches: - nodes = list(graph.nodes) - last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) - with graph.inserting_after(last_node_in_match): - kwargs = match.kwargs - kwargs["epsilon"] = 1e-5 - - fused_node = graph.call_function( - auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ), - kwargs=kwargs) - - graph.inserting_after(fused_node) - result_node_new = graph.call_function(operator.getitem, - (fused_node, 1)) - residual_node_new = graph.call_function(operator.getitem, - (fused_node, 2)) - - def is_func(node, target): - return node.op == "call_function" and node.target == target - - # find the output and the residual - def find_auto_fn(op, - match=match): # need to bind match to silence lint - for node in match.nodes: - if is_func(node, auto_functionalized) and node.args[0] == op: - return node - return None - - def find_getitem(node, idx): - for user in node.users: - if is_func(node, operator.getitem) and user.args[1] == idx: - return user - return None - - rms_node = find_auto_fn(torch.ops._C.fused_add_rms_norm.default) - quant_node = find_auto_fn(torch.ops._C.static_scaled_fp8_quant.default) - assert rms_node is not None - assert quant_node is not None - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 1 - - # meta["val"] is used by de-functionalization - rms_val = rms_node.meta["val"] - quant_val = quant_node.meta["val"] - fused_node.meta["val"] = (None, quant_val[1], rms_val[1], rms_val[2]) - - find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) - - # Finally, remove matched nodes - graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches - for node in match.nodes) - - -def noop_pass(graph: torch.fx.Graph): - pass - - -def get_fusion_pass(): - if not envs.VLLM_TORCH_COMPILE_FUSION: - return noop_pass - - patterns, matches = get_patterns() - - def dump_graph(graph: torch.fx.Graph, stage: str): - if stage in envs.VLLM_TORCH_COMPILE_FUSION_DUMP: - logger.info("Printing graph to %s", f"{stage}.py") - with open(f"{stage}.py", "w") as f: - print(graph.python_code(root_module="self", verbose=True).src, - file=f) - - def fusion_pass(graph: torch.fx.Graph): - """ - Use the pattern matcher - """ - matches.clear() - dump_graph(graph, "before_fusion") - - count = patterns.apply(graph) + def process_matches(self, graph: torch.fx.Graph): + # TODO(luka): add better comments (whole function) + for match in self.matches: + nodes = list(graph.nodes) + last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) + with graph.inserting_after(last_node_in_match): + kwargs = match.kwargs + kwargs["epsilon"] = 1e-5 + + fused_node = graph.call_function( + auto_functionalized, + (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + ), + kwargs=kwargs) + + graph.inserting_after(fused_node) + result_node_new = graph.call_function(operator.getitem, + (fused_node, 1)) + residual_node_new = graph.call_function( + operator.getitem, (fused_node, 2)) + + def is_func(node, target): + return node.op == "call_function" and node.target == target + + # find the output and the residual + def find_auto_fn(match: Match, op): + for node in match.nodes: + if is_func(node, + auto_functionalized) and node.args[0] == op: + return node + return None + + def find_getitem(node, idx): + for user in node.users: + if is_func(node, operator.getitem) and user.args[1] == idx: + return user + return None + + rms_node = find_auto_fn(match, + torch.ops._C.fused_add_rms_norm.default) + quant_node = find_auto_fn( + match, torch.ops._C.static_scaled_fp8_quant.default) + assert rms_node is not None + assert quant_node is not None + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # meta["val"] is used by de-functionalization + rms_val = rms_node.meta["val"] + quant_val = quant_node.meta["val"] + fused_node.meta["val"] = (None, quant_val[1], rms_val[1], + rms_val[2]) + + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in self.matches + for node in match.nodes) + + def __call__(self, graph: torch.fx.Graph): + if not envs.VLLM_TORCH_COMPILE_FUSION: + return + + self.dump_graph(graph, "before_fusion") + + count = self.my_patterns.apply(graph) logger.info("Replaced %s patterns", count) - dump_graph(graph, "after_pattern_match") + self.dump_graph(graph, "after_pattern_match") # Manually process multi-output matches (and run DCE) - process_matches(matches, graph) - logger.info("Post-processed %s matches", len(matches)) - dump_graph(graph, "after_fusion") - - return fusion_pass + self.process_matches(graph) + logger.info("Post-processed %s matches", len(self.matches)) + self.dump_graph(graph, "after_fusion") + self.matches.clear() diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py new file mode 100644 index 00000000000..8c2f594d6c1 --- /dev/null +++ b/vllm/compilation/inductor_pass.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod + +import torch + +from vllm import envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class InductorPass(ABC): + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + # TODO(luka): rename env var to VLLM_TORCH_COMPILE_DUMP + if stage in envs.VLLM_TORCH_COMPILE_DUMP: + filename = f"{stage}.py" # TODO(luka): add rank + logger.info("Printing graph to %s", filename) + with open(filename, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + print(src, file=f) + + @abstractmethod + def __call__(self, graph: torch.fx.Graph): + raise NotImplementedError diff --git a/vllm/envs.py b/vllm/envs.py index 9d1682061e4..14dd58320f8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -66,7 +66,7 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_TORCH_COMPILE_FUSION: bool = True - VLLM_TORCH_COMPILE_FUSION_DUMP: List[str] = [] + VLLM_TORCH_COMPILE_DUMP: List[str] = [] VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] @@ -220,14 +220,14 @@ def get_default_config_root(): "VLLM_CUSTOM_OPS": lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), - # Internal flag to enable fusion in torch.compile + # Internal flag to enable fusion in torch.compile (default on) "VLLM_TORCH_COMPILE_FUSION": lambda: bool(os.environ.get("VLLM_TORCH_COMPILE_FUSION", "1") != "0"), - # Internal flag for dumping the model graph before and after fusion - "VLLM_TORCH_COMPILE_FUSION_DUMP": - lambda: list( - os.environ.get("VLLM_TORCH_COMPILE_FUSION_DUMP", "").split(",")), + # Internal flag for dumping the model graph at different stages of + # custom pass compilation + "VLLM_TORCH_COMPILE_DUMP": + lambda: list(os.environ.get("VLLM_TORCH_COMPILE_DUMP", "").split(",")), # local rank of the process in the distributed setting, used to determine # the GPU device id From f3e7d315b1ac5e12fcfd3278e72bbc9e8527a92b Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 22 Oct 2024 18:11:51 +0000 Subject: [PATCH 69/88] PR comments: backends.py - fix_func pass fixed - support additional passes through config --- vllm/compilation/backends.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8f809f14e27..89cb799a0d2 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -105,7 +105,8 @@ def fix_functionalization(graph: fx.Graph): result = kwargs['result'] residual = kwargs['residual'] - # Create a new call to torch.ops._C.rotary_embedding.default + # Create a new call to + # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default with graph.inserting_before(node): # just insert the call to the custom op # NOTE: don't run dead code elimination, @@ -132,8 +133,7 @@ def fix_functionalization(graph: fx.Graph): kwargs = node.kwargs replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rotary_embedding.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + # Create a new call to torch.ops._C.rms_norm.default with graph.inserting_before(node): # just insert the call to the custom op # NOTE: don't run dead code elimination, @@ -148,14 +148,13 @@ def fix_functionalization(graph: fx.Graph): nodes_to_remove.append(node) elif node.args[ - 0] == torch.ops._C.rms_norm_static_fp8_quant.default: - # manual replace for rms_norm + 0] == torch.ops._C.rms_norm_static_fp8_quant.default: # noqa + # manual replace for rms_norm_static_fp8_quant kwargs = node.kwargs replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rotary_embedding.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa with graph.inserting_before(node): # just insert the call to the custom op # NOTE: don't run dead code elimination, @@ -178,7 +177,7 @@ def fix_functionalization(graph: fx.Graph): input = kwargs['input'] out = kwargs['out'] - # Create a new call to torch.ops._C.rotary_embedding.default + # Create a new call to torch.ops._C.silu_and_mul.default # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa with graph.inserting_before(node): # just insert the call to the custom op @@ -208,7 +207,7 @@ def fix_functionalization(graph: fx.Graph): fusion_pass = FusionPass() -def post_grad_post_passes(graph: fx.Graph): +def default_post_grad_post_passes(graph: fx.Graph): fusion_pass(graph) fix_functionalization(graph) @@ -220,11 +219,20 @@ def wrap_inductor(graph, example_inputs, additional_inductor_config): if additional_inductor_config is not None: current_config.update(additional_inductor_config) + + # If a custom post pass is given in config, + # run it after the default post passes. if current_config['post_grad_custom_post_pass'] is not None: - logger.warning( - "post_grad_custom_post_pass is already set in the config. " - "Overwriting it with the post_grad_post_passes") # TODO combine - current_config['post_grad_custom_post_pass'] = post_grad_post_passes + config_pass = current_config['post_grad_custom_post_pass'] + + def combined_pass(graph): + default_post_grad_post_passes(graph) + config_pass(graph) + + current_config['post_grad_custom_post_pass'] = combined_pass + + current_config[ + 'post_grad_custom_post_pass'] = default_post_grad_post_passes # noqa return compile_fx(graph, example_inputs, config_patches=current_config) @@ -232,7 +240,6 @@ def vllm_backend( graph, example_inputs, additional_inductor_config: Optional[Dict] = None) -> Callable: - context = get_compile_context() context = copy.deepcopy(context) if context is not None else [] sizes_to_specialize: List[int] = context From b2ab033bdbb0705a958ad0a942f94481624acf5a Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 22 Oct 2024 19:42:04 +0000 Subject: [PATCH 70/88] Fix node bug in find_getitem --- vllm/compilation/fusion.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index f728d5c1bb8..c67aec34980 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -94,6 +94,8 @@ def __init__(self): self.my_patterns = PatternMatcherPass(pass_name="fusion_pass") self.matches = [] + # Fuse rms_norm + static_scaled_fp8_quant into + # rms_norm_static_fp8_quant inputs = [ empty_fp8(5, 4), empty_bf16(5, 4), @@ -104,7 +106,10 @@ def __init__(self): register_replacement(rms_pattern_static, rms_replacement_static, inputs, fwd_only, self.my_patterns) - # with residual + # Fuse fused_add_rms_norm + static_scaled_fp8_quant into + # fused_add_rms_norm_static_fp8_quant + # Because pattern has 2 outputs, we need to manually process the match + # (see process_matches) inputs = [ empty_fp8(5, 4), empty_bf16(5, 4), @@ -120,12 +125,19 @@ def __init__(self): extra_check=lambda m: self.record_match(m)) def record_match(self, match: Match) -> bool: - # TODO(luka): add better comment + # Hijack the extra_check to record the match and + # save it for post-processing. self.matches.append(match) + + # Return False to prevent automatic replacement. return False def process_matches(self, graph: torch.fx.Graph): - # TODO(luka): add better comments (whole function) + """ + Manually process multi-output matches and replace them with fused nodes. + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ for match in self.matches: nodes = list(graph.nodes) last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) @@ -151,14 +163,13 @@ def is_func(node, target): # find the output and the residual def find_auto_fn(match: Match, op): for node in match.nodes: - if is_func(node, - auto_functionalized) and node.args[0] == op: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa return node return None def find_getitem(node, idx): for user in node.users: - if is_func(node, operator.getitem) and user.args[1] == idx: + if is_func(user, operator.getitem) and user.args[1] == idx: return user return None From 86b79dd6a0e4d47fa4f4fd32e7f6b45e0cc2ac7c Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 22 Oct 2024 21:12:43 +0000 Subject: [PATCH 71/88] PR comments: - add TP rank to dump filename - improve search for last node in match - massively improve comments Signed-off-by: luka --- vllm/compilation/fusion.py | 104 ++++++++++++++++++++++-------- vllm/compilation/inductor_pass.py | 9 ++- 2 files changed, 84 insertions(+), 29 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index c67aec34980..ba8abf0d973 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,4 +1,5 @@ import operator +from typing import Iterable, Optional import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -88,6 +89,44 @@ def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") +# Utilities for post-processing multi-output matches +def is_func(node: torch.fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node], + op) -> Optional[torch.fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: torch.fx.Node, + idx: int) -> Optional[torch.fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret + + class FusionPass(InductorPass): def __init__(self): @@ -139,11 +178,27 @@ def process_matches(self, graph: torch.fx.Graph): matches is broken: https://github.com/pytorch/pytorch/issues/137280 """ for match in self.matches: - nodes = list(graph.nodes) - last_node_in_match = max(match.nodes, key=lambda x: nodes.index(x)) + # To avoid use-before-definition errors, insert replacement nodes + # after the last node in the match. + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(graph.nodes): + if last_node_in_match in match.nodes: + break + else: + raise ValueError("No nodes in graph") + + # Insert a new auto_functionalized node for the fused operation, + # as well as getitem nodes to extract the result and residual. + # The auto_functionalized node returns a tuple of + # (None, result, residual) - None is the function return value. + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] with graph.inserting_after(last_node_in_match): kwargs = match.kwargs - kwargs["epsilon"] = 1e-5 + kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm fused_node = graph.call_function( auto_functionalized, @@ -157,38 +212,33 @@ def process_matches(self, graph: torch.fx.Graph): residual_node_new = graph.call_function( operator.getitem, (fused_node, 2)) - def is_func(node, target): - return node.op == "call_function" and node.target == target - - # find the output and the residual - def find_auto_fn(match: Match, op): - for node in match.nodes: - if is_func(node, auto_functionalized) and node.args[0] == op: # noqa - return node - return None + # Last part of replacement is rebinding the users of nodes in the + # match to use the new nodes. - def find_getitem(node, idx): - for user in node.users: - if is_func(user, operator.getitem) and user.args[1] == idx: - return user - return None - - rms_node = find_auto_fn(match, + # Find the nodes in the match that we need to rebind + rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) quant_node = find_auto_fn( - match, torch.ops._C.static_scaled_fp8_quant.default) - assert rms_node is not None - assert quant_node is not None + match.nodes, torch.ops._C.static_scaled_fp8_quant.default) assert len(rms_node.users) == 2 assert len(quant_node.users) == 1 - # meta["val"] is used by de-functionalization - rms_val = rms_node.meta["val"] - quant_val = quant_node.meta["val"] - fused_node.meta["val"] = (None, quant_val[1], rms_val[1], - rms_val[2]) + # meta["val"] is used by de-functionalization and has to contain the + # value of the node (tuple of tensors) that would be returned by the + # functionalized node during tracing. + + rms_tup = rms_node.meta["val"] + quant_tup = quant_node.meta["val"] + + # The result of fused_node must be a tuple with the first element + # None (the function return value) and the remaining elements + # representing the mutated inputs. + fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2]) + fused_node.meta["val"] = fused_tup + # Find the getitem nodes and replace their uses with the new nodes. + # The old nodes will be removed by DCE at the end of the pass. find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 8c2f594d6c1..4d75841947e 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -3,6 +3,9 @@ import torch from vllm import envs +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) from vllm.logger import init_logger logger = init_logger(__name__) @@ -11,9 +14,11 @@ class InductorPass(ABC): def dump_graph(self, graph: torch.fx.Graph, stage: str): - # TODO(luka): rename env var to VLLM_TORCH_COMPILE_DUMP if stage in envs.VLLM_TORCH_COMPILE_DUMP: - filename = f"{stage}.py" # TODO(luka): add rank + # Make sure filename includes rank in the distributed setting + rank = f"-{get_tp_rank()}" if get_tp_world_size() > 1 else "" + filename = f"{stage}{rank}.py" + logger.info("Printing graph to %s", filename) with open(filename, "w") as f: src = graph.python_code(root_module="self", verbose=True).src From e3d3f09b18a2d8cd04d5b4271717647713c5ecdb Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 22 Oct 2024 22:07:24 +0000 Subject: [PATCH 72/88] PR comments: - remove unnecessary includes - add comment in layernorm quant Signed-off-by: luka --- csrc/layernorm_kernels.cu | 7 ++----- csrc/layernorm_quant_kernels.cu | 13 +++++++++---- csrc/quantization/fp8/common.cu | 9 ++++++--- csrc/quantization/fp8/common.cuh | 16 +--------------- csrc/type_convert.cuh | 2 -- 5 files changed, 18 insertions(+), 29 deletions(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 8b521815735..fb6882f3e7c 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,12 +1,9 @@ #include "type_convert.cuh" -#include "quantization/fp8/common.cuh" +#include "dispatch_utils.h" -#include -#include +#include #include -#include "dispatch_utils.h" - #ifndef USE_ROCM #include #else diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index 35c9232dcd9..c18e2a4e4ab 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -1,12 +1,17 @@ +/* + * This file contains the CUDA kernels for the fused quantized layernorm. + * The kernels correspond to the kernels in layernorm_kernels.cu, except they + * also produce quantized output directly. + * Currently, only static fp8 quantization is supported. + */ + #include "type_convert.cuh" #include "quantization/fp8/common.cuh" +#include "dispatch_utils.h" -#include -#include +#include #include -#include "dispatch_utils.h" - #ifndef USE_ROCM #include #else diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 5c5f8331a53..e4f6615ede1 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,11 +1,14 @@ #include "common.cuh" - -#include "cuda_compat.h" #include "dispatch_utils.h" -#include #include +#ifndef USE_ROCM + #include +#else + #include +#endif + namespace vllm { template diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 2c67e23ea23..fd9c7f98a79 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -1,22 +1,8 @@ #pragma once -// TODO(luka) remove unnecessary includes -#include "cuda_compat.h" -#include "dispatch_utils.h" - -#include -#include -#include +#include #include -#ifndef USE_ROCM - #include - #include -#else - #include - #include -#endif - #ifndef USE_ROCM using FP8_TYPE = c10::Float8_e4m3fn; C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 8840bb9ddca..21b9d0ae515 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -5,11 +5,9 @@ #ifndef USE_ROCM #include #include -// #include #else #include #include - #include using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; From a40aba7ff20ef14103f465ca35f8c5bfaa111be6 Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 22 Oct 2024 22:13:28 +0000 Subject: [PATCH 73/88] yapf fix Signed-off-by: luka --- vllm/compilation/inductor_pass.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 4d75841947e..1ca768dfc88 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -3,9 +3,11 @@ import torch from vllm import envs +# yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( get_tensor_model_parallel_world_size as get_tp_world_size) +# yapf: enable from vllm.logger import init_logger logger = init_logger(__name__) From 46420f023c58cd8d2eff3eda11e4ab79fe3dbbc6 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 25 Oct 2024 21:17:01 +0000 Subject: [PATCH 74/88] Unit test for rmsnorm-quant fusion Signed-off-by: luka --- tests/compile/backend.py | 33 ++++++++++++++ tests/compile/test_fusion.py | 88 ++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 tests/compile/backend.py create mode 100644 tests/compile/test_fusion.py diff --git a/tests/compile/backend.py b/tests/compile/backend.py new file mode 100644 index 00000000000..c06c15bb179 --- /dev/null +++ b/tests/compile/backend.py @@ -0,0 +1,33 @@ +from copy import deepcopy +from typing import Callable + +import torch + + +class TestBackend(): + """ + This class provides a simple Inductor backend that can be used for testing. + It takes a list of custom passes and runs them after Inductor's passes. + It also saves the graph before and after the custom passes for inspection. + """ + + def __init__(self, *args: Callable[[torch.fx.Graph], None]): + self.custom_passes = args + from torch._inductor import config + self.current_config = config.shallow_copy_dict() + self.current_config['post_grad_custom_post_pass'] = self.post_pass + + def __call__(self, graph: torch.fx.GraphModule, example_inputs): + from torch._inductor.compile_fx import compile_fx + return compile_fx(graph, + example_inputs, + config_patches=self.current_config) + + def post_pass(self, graph: torch.fx.Graph): + self.graph_pre_pass = deepcopy(graph) + for pass_ in self.custom_passes: + pass_(graph) + + self.graph_post_pass = deepcopy(graph) + # assign by reference, will reflect the final state of the graph + self.final_graph = graph diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py new file mode 100644 index 00000000000..81cd66795d3 --- /dev/null +++ b/tests/compile/test_fusion.py @@ -0,0 +1,88 @@ +import pytest +import torch +from compressed_tensors.quantization import FP8_DTYPE + +from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant +from vllm.compilation.fusion import (FusionPass, find_auto_fn, + find_auto_fn_maybe) +from vllm.model_executor.layers.layernorm import RMSNorm + +from .backend import TestBackend + + +class TestModel(torch.nn.Module): + + def __init__(self, hidden_size: int, eps: float, *args, **kwargs): + super().__init__(*args, **kwargs) + self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)] + self.w = [ + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + for _ in range(2) + ] + + def forward(self, x): + resid = torch.relu(x) + y = self.norm[0](x) + yq, s0 = scaled_fp8_quant(y, self.scale[0]) + x2 = cutlass_scaled_mm(yq, + self.w[0], + s0, + self.scale[1], + out_dtype=x.dtype) + # make sure resid is used for replacement to work + y2, resid = self.norm[1](x2, resid) + yq2, s2 = scaled_fp8_quant(y2, self.scale[2]) + x3 = cutlass_scaled_mm(yq2, + self.w[1], + s2, + self.scale[3], + out_dtype=x.dtype) + y3, resid = self.norm[2](x3, resid) # use resid here + return y3 + + +# Init does pattern registration, which can only happen once +fusion_pass = FusionPass() + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) +@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): + torch.set_default_device("cuda") + torch.set_default_dtype(torch.float16) + + if eps != 1e-5: + pytest.skip("Only test eps=1e-5 for now") + + backend = TestBackend(fusion_pass) + model = TestModel(hidden_size, eps) + + x = torch.rand(num_tokens, hidden_size) + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + # Check that it gives the same answer + torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3) + + # Check substitution worked + pre_nodes = backend.graph_pre_pass.nodes + post_nodes = backend.graph_post_pass.nodes + + rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default + add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + + # In pre-nodes, fp8 quant should be present and fused kernels should not + assert find_auto_fn_maybe(pre_nodes, rms_quant) is None + assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None + find_auto_fn(pre_nodes, fp8_quant) + + # In post-nodes, fused kernels should be present and fp8 quant should not + find_auto_fn(post_nodes, rms_quant) + find_auto_fn(post_nodes, add_rms_quant) + assert find_auto_fn_maybe(post_nodes, fp8_quant) is None From a1c3d91aa119d3cd93da28875ffca797f2f3be41 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 25 Oct 2024 21:29:38 +0000 Subject: [PATCH 75/88] Skip on non-CUDA Signed-off-by: luka --- tests/compile/test_fusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 81cd66795d3..456100747c8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -2,6 +2,7 @@ import torch from compressed_tensors.quantization import FP8_DTYPE +import vllm.envs as envs from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) @@ -50,6 +51,8 @@ def forward(self, x): @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", + reason="Only test on CUDA") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) From bc5d6ba80d46b591bd3f9d7860c541c9b69a657d Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 29 Oct 2024 18:56:38 +0000 Subject: [PATCH 76/88] Fix FP8 HIP type in common.cuh Signed-off-by: luka --- csrc/quantization/fp8/common.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index fd9c7f98a79..d7c0297d533 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -1,13 +1,14 @@ #pragma once -#include #include #ifndef USE_ROCM + #include using FP8_TYPE = c10::Float8_e4m3fn; C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); #else + #include #include "amd/hip_float8.h" using FP8_TYPE = c10::Float8_e4m3fnuz; // Using the default max value from pytorch (240.0) will cause accuracy From 1966e6a3550d50764bd859d2f5d51e563395317e Mon Sep 17 00:00:00 2001 From: luka Date: Tue, 29 Oct 2024 19:16:30 +0000 Subject: [PATCH 77/88] Fix seed_everything Signed-off-by: luka --- tests/kernels/test_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 02af9d64c7d..727769e0718 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -78,7 +78,7 @@ def test_fused_rms_norm_quant( seed: int, device: str, ) -> None: - seed_everything(seed) + current_platform.seed_everything(seed) torch.set_default_device(device) weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1) From 980b56dd5cec929bbe321cc04a84cbeeac8c967e Mon Sep 17 00:00:00 2001 From: luka Date: Wed, 30 Oct 2024 18:17:36 +0000 Subject: [PATCH 78/88] Add support for passes to VllmBackend, add fusion back in Signed-off-by: luka --- vllm/compilation/backends.py | 34 +++++++++++++++++++++++++++++----- vllm/compilation/config.py | 14 +------------- 2 files changed, 30 insertions(+), 18 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d5b924c03de..49eae023c68 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,17 +1,18 @@ import copy import dataclasses import operator -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, + Union) import torch import torch.fx as fx from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors +from vllm.utils import combine_fx_passes, weak_ref_tensors -from .fusion import FusionPass from .config import CompilationConfig from .counter import compilation_counter +from .fusion import FusionPass from .levels import CompilationLevel logger = init_logger(__name__) @@ -206,6 +207,7 @@ def fix_functionalization(graph: fx.Graph): # with open("after.py", "w") as f: # print(graph.python_code(root_module="self", verbose=True).src, file=f) + def wrap_inductor(graph, example_inputs, additional_inductor_config, @@ -296,6 +298,13 @@ class VllmBackend: The major work of this backend is to split the graph into piecewise graphs, and pass them to the piecewise backend. + + This backend also handles custom passes and adds them to Inductor config. + The order of the post-grad post-passes is: + 1. post_grad_passes (constructor parameter) + 2. config["post_grad_custom_post_pass"] + 3. fix_functionalization + This way, all passes operate on a functionalized graph. """ compilation_configs: CompilationConfig @@ -307,14 +316,27 @@ class VllmBackend: split_gm: fx.GraphModule piecewise_graphs: List[SplitItem] returned_callable: Callable + # Inductor passes to run on the graph pre-defunctionalization + post_grad_passes: Sequence[Callable] - def __init__(self, ): + def __init__(self, post_grad_passes: Sequence[Callable]): # every instance of VllmBackend has its own graph pool self.graph_pool = torch.cuda.graph_pool_handle() + self.post_grad_passes = post_grad_passes # `torch.compile` is JIT compiled, so we don't need to # do anything here + def add_passes_to_config(self): + config = self.compilation_configs.inductor_compile_config + custom_postgrad_pass = config["post_grad_custom_post_pass"] + passes = list(self.post_grad_passes) + if custom_postgrad_pass is not None: + passes = passes + [custom_postgrad_pass] + + passes = passes + [fix_functionalization] + config["post_grad_custom_post_pass"] = combine_fx_passes(passes) + def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: compilation_counter.num_graphs_seen += 1 @@ -328,6 +350,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # we get the sizes to capture for cudagraph # from compilation context self.compilation_configs = CompilationConfig.select_and_init_config() + self.add_passes_to_config() self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.non_cudagraph_ops) @@ -528,4 +551,5 @@ def select_default_backend(level: int) -> Union[str, Callable]: return backend_str assert level == CompilationLevel.PIECEWISE - return VllmBackend() + passes = [FusionPass()] + return VllmBackend(passes) diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py index 514f2b93ef6..07724a5f78f 100644 --- a/vllm/compilation/config.py +++ b/vllm/compilation/config.py @@ -81,7 +81,7 @@ def model_post_init(self, __context: Any) -> None: if not isinstance(v, str): assert callable(v), ( f"pass {k} should be a function or a qualified name") - self.inductor_passes[k] = v + self.inductor_compile_config[k] = v continue # resolve function from qualified name @@ -91,18 +91,6 @@ def model_post_init(self, __context: Any) -> None: func = __import__(module).__dict__[func_name] self.inductor_compile_config[k] = func - from vllm.compilation.backends import fix_functionalization - from vllm.utils import combine_fx_passes - if "post_grad_custom_post_pass" in self.inductor_compile_config: - self.inductor_compile_config[ - "post_grad_custom_post_pass"] = combine_fx_passes( - fix_functionalization, - self.inductor_compile_config["post_grad_custom_post_pass"], - ) - else: - self.inductor_compile_config[ - "post_grad_custom_post_pass"] = fix_functionalization - def init_during_runtime(self): """To complete the initialization of config, we need to know the compile context, which is only available From 8b2def5217b6b7033679b1bf82767fc27c61f0d9 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 16:06:20 +0000 Subject: [PATCH 79/88] PR comments: - move extra env variables to config - fix pass adding logic Signed-off-by: luka --- vllm/compilation/backends.py | 22 ++++++++++++++-------- vllm/compilation/config.py | 11 +++++++++++ vllm/compilation/fusion.py | 13 ++++++------- vllm/compilation/inductor_pass.py | 21 ++++++++++++--------- vllm/envs.py | 12 +----------- 5 files changed, 44 insertions(+), 35 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 49eae023c68..4445e1e107d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -319,7 +319,7 @@ class VllmBackend: # Inductor passes to run on the graph pre-defunctionalization post_grad_passes: Sequence[Callable] - def __init__(self, post_grad_passes: Sequence[Callable]): + def __init__(self, post_grad_passes: Sequence[Callable] = ()): # every instance of VllmBackend has its own graph pool self.graph_pool = torch.cuda.graph_pool_handle() self.post_grad_passes = post_grad_passes @@ -328,14 +328,21 @@ def __init__(self, post_grad_passes: Sequence[Callable]): # do anything here def add_passes_to_config(self): - config = self.compilation_configs.inductor_compile_config - custom_postgrad_pass = config["post_grad_custom_post_pass"] + config = self.compilation_configs passes = list(self.post_grad_passes) - if custom_postgrad_pass is not None: - passes = passes + [custom_postgrad_pass] + if config.enable_fusion: + passes = passes + [FusionPass(config)] + + inductor_config = config.inductor_compile_config + if "post_grad_custom_post_pass" in inductor_config: + passes = passes + [inductor_config["post_grad_custom_post_pass"]] + + # add the fix_functionalization pass last, so that all other + # passes operate on a functionalized graph passes = passes + [fix_functionalization] - config["post_grad_custom_post_pass"] = combine_fx_passes(passes) + combined_pass = combine_fx_passes(passes) + inductor_config["post_grad_custom_post_pass"] = combined_pass def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: @@ -551,5 +558,4 @@ def select_default_backend(level: int) -> Union[str, Callable]: return backend_str assert level == CompilationLevel.PIECEWISE - passes = [FusionPass()] - return VllmBackend(passes) + return VllmBackend() diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py index 07724a5f78f..72377533140 100644 --- a/vllm/compilation/config.py +++ b/vllm/compilation/config.py @@ -1,4 +1,5 @@ import copy +from pathlib import Path from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, PrivateAttr @@ -50,6 +51,12 @@ class CompilationConfig(BaseModel): name because the config uses json format. If we pass the config from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - Custom inductor passes: + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graph. Default is . + - enable_fusion: whether to enable the custom fusion pass. + TODO better pass enabling system. Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -72,6 +79,10 @@ class CompilationConfig(BaseModel): cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + # not configurable, computed after init compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index ba8abf0d973..87e961893d3 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,12 +1,12 @@ import operator -from typing import Iterable, Optional +from typing import Iterable, List, Optional import torch from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm import envs +from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass from vllm.logger import init_logger @@ -129,9 +129,11 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: class FusionPass(InductorPass): - def __init__(self): + def __init__(self, config: CompilationConfig): + super().__init__(config) + self.my_patterns = PatternMatcherPass(pass_name="fusion_pass") - self.matches = [] + self.matches: List[Match] = [] # Fuse rms_norm + static_scaled_fp8_quant into # rms_norm_static_fp8_quant @@ -248,9 +250,6 @@ def process_matches(self, graph: torch.fx.Graph): for node in match.nodes) def __call__(self, graph: torch.fx.Graph): - if not envs.VLLM_TORCH_COMPILE_FUSION: - return - self.dump_graph(graph, "before_fusion") count = self.my_patterns.apply(graph) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 1ca768dfc88..97cc8b59c0d 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -2,7 +2,7 @@ import torch -from vllm import envs +from vllm.compilation.config import CompilationConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( @@ -15,17 +15,20 @@ class InductorPass(ABC): + @abstractmethod + def __call__(self, graph: torch.fx.Graph): + raise NotImplementedError + + def __init__(self, config: CompilationConfig): + self.config = config + def dump_graph(self, graph: torch.fx.Graph, stage: str): - if stage in envs.VLLM_TORCH_COMPILE_DUMP: + if stage in self.config.dump_graph_stages: # Make sure filename includes rank in the distributed setting rank = f"-{get_tp_rank()}" if get_tp_world_size() > 1 else "" - filename = f"{stage}{rank}.py" + filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" - logger.info("Printing graph to %s", filename) - with open(filename, "w") as f: + logger.info("Printing graph to %s", filepath) + with open(filepath, "w") as f: src = graph.python_code(root_module="self", verbose=True).src print(src, file=f) - - @abstractmethod - def __call__(self, graph: torch.fx.Graph): - raise NotImplementedError diff --git a/vllm/envs.py b/vllm/envs.py index d9dc6d2b92c..8c11553b2f2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -67,8 +67,7 @@ VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 - VLLM_TORCH_COMPILE_FUSION: bool = True - VLLM_TORCH_COMPILE_DUMP: List[str] = [] + VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False @@ -228,15 +227,6 @@ def get_default_config_root(): "VLLM_CUSTOM_OPS": lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), - # Internal flag to enable fusion in torch.compile (default on) - "VLLM_TORCH_COMPILE_FUSION": - lambda: bool(os.environ.get("VLLM_TORCH_COMPILE_FUSION", "1") != "0"), - - # Internal flag for dumping the model graph at different stages of - # custom pass compilation - "VLLM_TORCH_COMPILE_DUMP": - lambda: list(os.environ.get("VLLM_TORCH_COMPILE_DUMP", "").split(",")), - # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": From 619d634f8f92cffd16273686ccc1c008fea4964f Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 16:09:35 +0000 Subject: [PATCH 80/88] Fix fusion pass init in test Signed-off-by: luka --- tests/compile/test_fusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 456100747c8..a5abbfae4ff 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -4,6 +4,7 @@ import vllm.envs as envs from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant +from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) from vllm.model_executor.layers.layernorm import RMSNorm @@ -44,7 +45,7 @@ def forward(self, x): # Init does pattern registration, which can only happen once -fusion_pass = FusionPass() +fusion_pass = FusionPass(CompilationConfig(enable_fusion=True)) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) From f47e3585703d127617a7ff26442e9b76ddbccf72 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 16:56:50 +0000 Subject: [PATCH 81/88] Fusion test: use apply_fp8_linear Signed-off-by: luka --- tests/compile/test_fusion.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index a5abbfae4ff..95ba49aa171 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,11 +3,12 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear) from .backend import TestBackend @@ -26,20 +27,12 @@ def __init__(self, hidden_size: int, eps: float, *args, **kwargs): def forward(self, x): resid = torch.relu(x) y = self.norm[0](x) - yq, s0 = scaled_fp8_quant(y, self.scale[0]) - x2 = cutlass_scaled_mm(yq, - self.w[0], - s0, - self.scale[1], - out_dtype=x.dtype) + + x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1]) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - yq2, s2 = scaled_fp8_quant(y2, self.scale[2]) - x3 = cutlass_scaled_mm(yq2, - self.w[1], - s2, - self.scale[3], - out_dtype=x.dtype) + + x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3]) y3, resid = self.norm[2](x3, resid) # use resid here return y3 From a252997911a197f429bf833235baa8723e51a11d Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 18:56:08 +0000 Subject: [PATCH 82/88] Add redundant reshapes removal pass. Signed-off-by: luka --- tests/compile/test_fusion.py | 11 ++++- vllm/compilation/backends.py | 3 ++ vllm/compilation/reshapes.py | 78 ++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 vllm/compilation/reshapes.py diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 95ba49aa171..869d1195cde 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -6,6 +6,7 @@ from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) +from vllm.compilation.reshapes import RedundantReshapesPass from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear) @@ -38,7 +39,9 @@ def forward(self, x): # Init does pattern registration, which can only happen once -fusion_pass = FusionPass(CompilationConfig(enable_fusion=True)) +config = CompilationConfig(enable_fusion=True) +reshape_pass = RedundantReshapesPass(config) +fusion_pass = FusionPass(config) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -54,10 +57,14 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): if eps != 1e-5: pytest.skip("Only test eps=1e-5 for now") - backend = TestBackend(fusion_pass) + # Reshape pass is needed for the fusion pass to work + backend = TestBackend(reshape_pass, fusion_pass) model = TestModel(hidden_size, eps) + # First dimension dynamic x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + result = model(x) model2 = torch.compile(model, backend=backend) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4445e1e107d..e6568dc5063 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -14,6 +14,7 @@ from .counter import compilation_counter from .fusion import FusionPass from .levels import CompilationLevel +from .reshapes import RedundantReshapesPass logger = init_logger(__name__) @@ -331,6 +332,8 @@ def add_passes_to_config(self): config = self.compilation_configs passes = list(self.post_grad_passes) + passes = passes + [RedundantReshapesPass(config)] + if config.enable_fusion: passes = passes + [FusionPass(config)] diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py new file mode 100644 index 00000000000..99174992fff --- /dev/null +++ b/vllm/compilation/reshapes.py @@ -0,0 +1,78 @@ +from typing import Union + +import torch.fx +from torch import SymInt + +from vllm.compilation.fusion import is_func +from vllm.compilation.inductor_pass import InductorPass +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class RedundantReshapesPass(InductorPass): + """ + This is an inductor pass that removes redundant reshape operations. + It is required for RMSNorm-quant fusion to work properly. + That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. + + Example graph: + + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + """ + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, "before_reshapes") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + + if all( + self.dims_equivalent(s, i_s) + for s, i_s in zip(shape, input_shape)): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + logger.info("Removed %s no-op reshapes", count) + + self.dump_graph(graph, "after_reshapes") + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. + + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim From 1b9717f248679d2aaa0f7ae27816ec3a46555604 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 18:56:22 +0000 Subject: [PATCH 83/88] Fix graph dumping when TP not initialized Signed-off-by: luka --- vllm/compilation/inductor_pass.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 97cc8b59c0d..b388a45363b 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -7,6 +7,7 @@ from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init # yapf: enable from vllm.logger import init_logger @@ -25,7 +26,8 @@ def __init__(self, config: CompilationConfig): def dump_graph(self, graph: torch.fx.Graph, stage: str): if stage in self.config.dump_graph_stages: # Make sure filename includes rank in the distributed setting - rank = f"-{get_tp_rank()}" if get_tp_world_size() > 1 else "" + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" logger.info("Printing graph to %s", filepath) From daca8902360d088b85285bb46963924a9f6f934e Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 20:22:59 +0000 Subject: [PATCH 84/88] Reshape add edge-cases Signed-off-by: luka --- vllm/compilation/reshapes.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py index 99174992fff..0d284246d25 100644 --- a/vllm/compilation/reshapes.py +++ b/vllm/compilation/reshapes.py @@ -38,6 +38,13 @@ def __call__(self, graph: torch.fx.Graph): if is_func(node, torch.ops.aten.reshape.default): input, shape = node.args[:2] input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue if all( self.dims_equivalent(s, i_s) From 429db0a711cb0c392c23e0d217bbda740685c318 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 31 Oct 2024 20:53:39 +0000 Subject: [PATCH 85/88] Singleton pattern matcher for fusion pass Signed-off-by: luka --- vllm/compilation/fusion.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 87e961893d3..40dded2c7c2 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -128,12 +128,27 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: class FusionPass(InductorPass): + # Patterns can only be registered once, so we store them as a class variable + _patterns: PatternMatcherPass = None def __init__(self, config: CompilationConfig): super().__init__(config) - self.my_patterns = PatternMatcherPass(pass_name="fusion_pass") self.matches: List[Match] = [] + self.register_patterns() + + @property + def patterns(self) -> PatternMatcherPass: + assert self.__class__._patterns is not None, \ + "Accessing patterns before they were registered" + return self.__class__._patterns + + def register_patterns(self): + # Only register patterns once + if self.__class__._patterns is not None: + return + + self.__class__._patterns = PatternMatcherPass(pass_name="fusion_pass") # Fuse rms_norm + static_scaled_fp8_quant into # rms_norm_static_fp8_quant @@ -145,7 +160,7 @@ def __init__(self, config: CompilationConfig): empty_fp32(1, 1) ] register_replacement(rms_pattern_static, rms_replacement_static, - inputs, fwd_only, self.my_patterns) + inputs, fwd_only, self.patterns) # Fuse fused_add_rms_norm + static_scaled_fp8_quant into # fused_add_rms_norm_static_fp8_quant @@ -162,7 +177,7 @@ def __init__(self, config: CompilationConfig): rms_replacement_residual_static, inputs, fwd_only, - self.my_patterns, + self.patterns, extra_check=lambda m: self.record_match(m)) def record_match(self, match: Match) -> bool: @@ -252,7 +267,7 @@ def process_matches(self, graph: torch.fx.Graph): def __call__(self, graph: torch.fx.Graph): self.dump_graph(graph, "before_fusion") - count = self.my_patterns.apply(graph) + count = self.patterns.apply(graph) logger.info("Replaced %s patterns", count) self.dump_graph(graph, "after_pattern_match") From e0b904e54a35942b8f6b6f1d4bea7cce610bcd94 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 8 Nov 2024 16:02:51 +0000 Subject: [PATCH 86/88] singleton fusion pass Signed-off-by: luka --- tests/compile/test_fusion.py | 2 +- vllm/compilation/backends.py | 2 +- vllm/compilation/fusion.py | 45 +++++++++++++++++++++++------------- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 869d1195cde..e4d3defafb9 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -41,7 +41,7 @@ def forward(self, x): # Init does pattern registration, which can only happen once config = CompilationConfig(enable_fusion=True) reshape_pass = RedundantReshapesPass(config) -fusion_pass = FusionPass(config) +fusion_pass = FusionPass.instance(config) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index e6568dc5063..8259d1c1519 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -335,7 +335,7 @@ def add_passes_to_config(self): passes = passes + [RedundantReshapesPass(config)] if config.enable_fusion: - passes = passes + [FusionPass(config)] + passes = passes + [FusionPass.instance(config)] inductor_config = config.inductor_compile_config if "post_grad_custom_post_pass" in inductor_config: diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 40dded2c7c2..2a0cf0002c9 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -128,27 +128,40 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: class FusionPass(InductorPass): - # Patterns can only be registered once, so we store them as a class variable - _patterns: PatternMatcherPass = None + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + It also manually processes multi-output matches, as those are broken in + the torch pattern matcher. - def __init__(self, config: CompilationConfig): - super().__init__(config) + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ - self.matches: List[Match] = [] - self.register_patterns() + _instance: 'Optional[FusionPass]' = None - @property - def patterns(self) -> PatternMatcherPass: - assert self.__class__._patterns is not None, \ - "Accessing patterns before they were registered" - return self.__class__._patterns + @classmethod + def instance(cls, config: CompilationConfig): + """ + Get the singleton instance of the FusionPass. + If the instance exists, the config is updated but + initialization is not repeated. + """ + if cls._instance is None: + cls._instance = FusionPass(config) + else: + cls._instance.config = config + return cls._instance - def register_patterns(self): - # Only register patterns once - if self.__class__._patterns is not None: - return + def __init__(self, config: CompilationConfig): + assert self.__class__._instance is None, \ + "FusionPass singleton instance already exists" + super().__init__(config) - self.__class__._patterns = PatternMatcherPass(pass_name="fusion_pass") + self.matches: List[Match] = [] + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="fusion_pass") # Fuse rms_norm + static_scaled_fp8_quant into # rms_norm_static_fp8_quant From d9375df090e4af22c671d7702086107857d10103 Mon Sep 17 00:00:00 2001 From: luka Date: Fri, 8 Nov 2024 16:37:19 +0000 Subject: [PATCH 87/88] format Signed-off-by: luka --- tests/compile/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index c06c15bb179..9d5c6827437 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -4,7 +4,7 @@ import torch -class TestBackend(): +class TestBackend: """ This class provides a simple Inductor backend that can be used for testing. It takes a list of custom passes and runs them after Inductor's passes. From d0a9e3725619a2e3a17f8743c51ec4d826159f58 Mon Sep 17 00:00:00 2001 From: luka Date: Thu, 7 Nov 2024 18:54:55 +0000 Subject: [PATCH 88/88] Add print Signed-off-by: luka --- vllm/compilation/inductor_pass.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index b388a45363b..b23351fa197 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -33,4 +33,6 @@ def dump_graph(self, graph: torch.fx.Graph, stage: str): logger.info("Printing graph to %s", filepath) with open(filepath, "w") as f: src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) print(src, file=f)