diff --git a/CMakeLists.txt b/CMakeLists.txt index 25c0865a90a..376565583d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -191,6 +191,7 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" + "csrc/layernorm_quant_kernels.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 7a7a25d2173..fb6882f3e7c 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,21 +1,13 @@ -#include -#include +#include "type_convert.cuh" +#include "dispatch_utils.h" + +#include #include -#include "dispatch_utils.h" #ifndef USE_ROCM - #include - #include - #include #include #else - #include - #include - #include #include - -using __nv_bfloat16 = __hip_bfloat16; -using __nv_bfloat162 = __hip_bfloat162; #endif namespace vllm { @@ -51,155 +43,6 @@ __global__ void rms_norm_kernel( } } -/* Converter structs for the conversion from torch types to HIP/CUDA types, - and the associated type conversions within HIP/CUDA. These helpers need - to be implemented for now because the relevant type conversion - operators/constructors are not consistently implemented by HIP/CUDA, so - a generic conversion via type casts cannot be implemented. - - Each struct should have the member static constexpr bool `exists`: - If false, the optimized kernel is not used for the corresponding torch type. - If true, the struct should be fully defined as shown in the examples below. - */ -template -struct _typeConvert { - static constexpr bool exists = false; -}; - -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) -// CUDA < 12.0 runs into issues with packed type conversion -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __half; - using packed_hip_type = __half2; - - __device__ static inline float convert(hip_type x) { return __half2float(x); } - __device__ static inline float2 convert(packed_hip_type x) { - return __half22float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2half_rn(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22half2_rn(x); - } -}; - - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// CUDA_ARCH < 800 does not have BF16 support -// TODO: Add in ROCm support once public headers handle bf16 maturely -template <> -struct _typeConvert { - static constexpr bool exists = true; - using hip_type = __nv_bfloat16; - using packed_hip_type = __nv_bfloat162; - - __device__ static inline float convert(hip_type x) { - return __bfloat162float(x); - } - __device__ static inline float2 convert(packed_hip_type x) { - return __bfloat1622float2(x); - } - __device__ static inline hip_type convert(float x) { - return __float2bfloat16(x); - } - __device__ static inline packed_hip_type convert(float2 x) { - return __float22bfloat162_rn(x); - } -}; - #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= - // 12000)) - -/* Vector POD struct to generate vectorized and packed FP16/BF16 ops - for appropriate specializations of fused_add_rms_norm_kernel. - Only functions that are necessary in that kernel are implemented. - Alignment to 16 bytes is required to use 128-bit global memory ops. 
- */ -template -struct alignas(16) _f16Vec { - /* Not theoretically necessary that width is a power of 2 but should - almost always be the case for optimization purposes */ - static_assert(width > 0 && (width & (width - 1)) == 0, - "Width is not a positive power of 2!"); - using Converter = _typeConvert; - using T1 = typename Converter::hip_type; - using T2 = typename Converter::packed_hip_type; - T1 data[width]; - - __device__ _f16Vec& operator+=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp += T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] += other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const _f16Vec& other) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - T2 temp{data[i], data[i + 1]}; - temp *= T2{other.data[i], other.data[i + 1]}; - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) data[i] *= other.data[i]; - } - return *this; - } - - __device__ _f16Vec& operator*=(const float scale) { - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); - temp_f.x *= scale; - temp_f.y *= scale; - T2 temp = Converter::convert(temp_f); - data[i] = temp.x; - data[i + 1] = temp.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float temp = Converter::convert(data[i]) * scale; - data[i] = Converter::convert(temp); - } - } - return *this; - } - - __device__ float sum_squares() const { - float result = 0.0f; - if constexpr (width % 2 == 0) { -#pragma unroll - for (int i = 0; i < width; i += 2) { - float2 z = Converter::convert(T2{data[i], data[i + 1]}); - result += z.x * z.x + z.y * z.y; - } - } else { -#pragma unroll - for (int i = 0; i < width; ++i) { - float x = Converter::convert(data[i]); - result += x * x; - } - } - return result; - } -}; - /* Function specialization in the case of FP16/BF16 tensors. Additional optimizations we can make in this case are packed and vectorized operations, which help with the diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu new file mode 100644 index 00000000000..c18e2a4e4ab --- /dev/null +++ b/csrc/layernorm_quant_kernels.cu @@ -0,0 +1,234 @@ +/* + * This file contains the CUDA kernels for the fused quantized layernorm. + * The kernels correspond to the kernels in layernorm_kernels.cu, except they + * also produce quantized output directly. + * Currently, only static fp8 quantization is supported. + */ + +#include "type_convert.cuh" +#include "quantization/fp8/common.cuh" +#include "dispatch_utils.h" + +#include +#include + +#ifndef USE_ROCM + #include +#else + #include +#endif + +namespace vllm { + +// TODO(woosuk): Further optimize this kernel. 
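The kernel that follows mirrors rms_norm_kernel but writes fp8 output directly. As a point of reference, a minimal PyTorch sketch of the math it implements (RMSNorm computed in fp32, cast back to the input dtype, weight applied, then static fp8 quantization); the helper name is illustrative and this is not the CUDA implementation itself:

import torch

def rms_norm_static_fp8_quant_ref(x: torch.Tensor, weight: torch.Tensor,
                                  scale: torch.Tensor,
                                  epsilon: float) -> torch.Tensor:
    xf = x.float()
    # rsqrt of the mean of squares over the hidden dimension
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    # normalize in fp32, cast back to the input dtype, then apply the weight
    normed = (xf * inv_rms).to(x.dtype) * weight
    # static quantization: divide by the scale and clamp to the fp8 e4m3 range
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    q = (normed.float() / scale).clamp(min=-fp8_max, max=fp8_max)
    return q.to(torch.float8_e4m3fn)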
+template +__global__ void rms_norm_static_fp8_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float)input[blockIdx.x * hidden_size + idx]; + variance += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * hidden_size + idx]; + float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; + out[blockIdx.x * hidden_size + idx] = + scaled_fp8_conversion(out_norm, scale_inv); + } +} + +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_static_fp8_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, const int num_tokens, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16Vec>); + static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); + + const int vec_hidden_size = hidden_size / width; + __shared__ float s_variance; + float variance = 0.0f; + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. 
Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = input_v[id]; + temp += residual_v[id]; + variance += temp.sum_squares(); + residual_v[id] = temp; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = residual_v[id]; + temp *= s_variance; + temp *= weight_v[idx]; +#pragma unroll + for (int i = 0; i < width; ++i) { + out[id * width + i] = + scaled_fp8_conversion(float(temp.data[i]), scale_inv); + } + } +} + +/* Generic fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. + */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_static_fp8_quant_kernel( + FP8_TYPE* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + scalar_t z = input[blockIdx.x * hidden_size + idx]; + z += residual[blockIdx.x * hidden_size + idx]; + float x = (float)z; + variance += x * x; + residual[blockIdx.x * hidden_size + idx] = z; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)residual[blockIdx.x * hidden_size + idx]; + float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; + out[blockIdx.x * hidden_size + idx] = + scaled_fp8_conversion(out_norm, scale_inv); + } +} + +} // namespace vllm + +void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_static_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), scale.data_ptr(), 
epsilon, + num_tokens, hidden_size); + }); +} + +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_static_fp8_quant_kernel \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + residual.data_ptr(), weight.data_ptr(), \ + scale.data_ptr(), epsilon, num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm_static_fp8_quant( + torch::Tensor& out, // [..., hidden_size], + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. + Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto res_ptr = reinterpret_cast(residual.data_ptr()); + auto wt_ptr = reinterpret_cast(weight.data_ptr()); + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { + LAUNCH_FUSED_ADD_RMS_NORM(0); + } +} diff --git a/csrc/ops.h b/csrc/ops.h index e0775ee1891..672e608e9c4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -56,6 +56,16 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double epsilon); +void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& weight, torch::Tensor& scale, + double epsilon); + +void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, + torch::Tensor& scale, double epsilon); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index f2c609c1b68..e4f6615ede1 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,185 +1,16 @@ -#include -#include -#include - -#include - -#include "cuda_compat.h" +#include "common.cuh" #include "dispatch_utils.h" +#include + #ifndef USE_ROCM - #include #include #else - #include #include #endif -#ifndef USE_ROCM -using FP8_TYPE = c10::Float8_e4m3fn; -C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = - std::numeric_limits::max(); -#else - #include "amd/hip_float8.h" -using FP8_TYPE = c10::Float8_e4m3fnuz; -// Using the default max value from pytorch (240.0) will cause accuracy -// issue when running dynamic quantization. Here use 224.0f for rocm. 
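For reference, a PyTorch sketch of the semantics the fused_add_rms_norm_static_fp8_quant kernels and wrapper above implement: the residual tensor is overwritten with input + residual, and the fp8 output is the RMS-normalized, weighted value divided by the static scale. The helper below is illustrative only; dtype handling follows the kernel and the clamp assumes the CUDA e4m3 range.

import torch

def fused_add_rms_norm_static_fp8_quant_ref(input: torch.Tensor,
                                            residual: torch.Tensor,
                                            weight: torch.Tensor,
                                            scale: torch.Tensor,
                                            epsilon: float) -> torch.Tensor:
    z = input + residual
    residual.copy_(z)  # the kernel updates the residual in place
    zf = z.float()
    inv_rms = torch.rsqrt(zf.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    normed = (zf * inv_rms).to(input.dtype) * weight
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    return ((normed.float() / scale)
            .clamp(min=-fp8_max, max=fp8_max)
            .to(torch.float8_e4m3fn))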
-constexpr auto FP8_E4M3_MAX = 224.0f; -#endif - namespace vllm { -__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { - float old; - old = (value >= 0) - ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) - : __uint_as_float( - atomicMin((unsigned int*)addr, __float_as_uint(value))); - - return old; -} - -template -__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, - float const scale) { - float x = 0.0f; - if constexpr (is_scale_inverted) { - x = val * scale; - } else { - x = val / scale; - } - - float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); -#ifndef USE_ROCM - return static_cast(r); -#else - // Use hardware cvt instruction for fp8 on rocm - return c10::Float8_e4m3fnuz(hip_fp8(r).data, - c10::Float8_e4m3fnuz::from_bits()); -#endif -} - -// Compute the absolute maximum m of the input tensor and store -// m / float8_e4m3::max() in *scale. Each thread block performs a -// reduction tree and the memory in scale is atomically updated. -// So to get the right answer, *scale needs to be initialized to -// a value <= 0.0 and we need to wait for all thread blocks to -// finish before consuming *scale. -template -__global__ void segmented_max_reduction(float* __restrict__ scale, - const scalar_t* __restrict__ input, - int64_t num_elems) { - __shared__ float cache[1024]; - int64_t i = blockDim.x * blockIdx.x + threadIdx.x; - - // First store maximum for all values processes by - // the current thread in cache[threadIdx.x] - scalar_t tmp = 0.0; - while (i < num_elems) { - float x = static_cast(input[i]); - tmp = max(tmp, fabs(x)); - i += blockDim.x * gridDim.x; - } - cache[threadIdx.x] = tmp; - - __syncthreads(); - - // Now perform parallel reduction within the thread block - int ib = blockDim.x / 2; - while (ib != 0) { - if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { - cache[threadIdx.x] = cache[threadIdx.x + ib]; - } - __syncthreads(); - ib /= 2; - } - // Finally, since cache[0] contains the maximum for this thread block, - // atomically write the max to the target location - if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX); - } -} - -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -typedef struct __align__(4) { - FP8_TYPE x; - FP8_TYPE y; - FP8_TYPE z; - FP8_TYPE w; -} -float8x4_t; - -template -__device__ float thread_max_vec(scalar_t const* __restrict__ input, - int64_t const num_elems, int const tid, - int const step) { - // Vectorized input/output to better utilize memory bandwidth. - vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - - int64_t const num_vec_elems = num_elems >> 2; - float absmax_val = 0.0f; - -#pragma unroll 4 - for (int64_t i = tid; i < num_vec_elems; i += step) { - vec4_t in_vec = vectorized_in[i]; - absmax_val = max(absmax_val, fabs(in_vec.x)); - absmax_val = max(absmax_val, fabs(in_vec.y)); - absmax_val = max(absmax_val, fabs(in_vec.z)); - absmax_val = max(absmax_val, fabs(in_vec.w)); - } - - // Handle the remaining elements if num_elems is not divisible by 4 - for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { - absmax_val = max(absmax_val, fabs(input[i])); - } - - return absmax_val; -} - -template -__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, - scalar_t const* __restrict__ input, - float const scale, - int64_t const num_elems, - int const tid, int const step) { - // Vectorized input/output to better utilize memory bandwidth. 
- vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - float8x4_t* vectorized_out = reinterpret_cast(out); - - int64_t const num_vec_elems = num_elems >> 2; - -#pragma unroll 4 - for (int64_t i = tid; i < num_vec_elems; i += step) { - vec4_t in_vec = vectorized_in[i]; - float8x4_t out_vec; - - out_vec.x = scaled_fp8_conversion( - static_cast(in_vec.x), scale); - out_vec.y = scaled_fp8_conversion( - static_cast(in_vec.y), scale); - out_vec.z = scaled_fp8_conversion( - static_cast(in_vec.z), scale); - out_vec.w = scaled_fp8_conversion( - static_cast(in_vec.w), scale); - vectorized_out[i] = out_vec; - } - - // Handle the remaining elements if num_elems is not divisible by 4 - for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { - out[i] = scaled_fp8_conversion( - static_cast(input[i]), scale); - } -} - template __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out, const scalar_t* __restrict__ input, diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh new file mode 100644 index 00000000000..d7c0297d533 --- /dev/null +++ b/csrc/quantization/fp8/common.cuh @@ -0,0 +1,172 @@ +#pragma once + +#include + +#ifndef USE_ROCM + #include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = + std::numeric_limits::max(); +#else + #include + #include "amd/hip_float8.h" +using FP8_TYPE = c10::Float8_e4m3fnuz; +// Using the default max value from pytorch (240.0) will cause accuracy +// issue when running dynamic quantization. Here use 224.0f for rocm. +constexpr auto FP8_E4M3_MAX = 224.0f; +#endif + +namespace vllm { + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { + float old; + old = (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float( + atomicMin((unsigned int*)addr, __float_as_uint(value))); + + return old; +} + +template +__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, + float const scale) { + float x = 0.0f; + if constexpr (is_scale_inverted) { + x = val * scale; + } else { + x = val / scale; + } + + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); +#ifndef USE_ROCM + return static_cast(r); +#else + // Use hardware cvt instruction for fp8 on rocm + return c10::Float8_e4m3fnuz(hip_fp8(r).data, + c10::Float8_e4m3fnuz::from_bits()); +#endif +} + +// Compute the absolute maximum m of the input tensor and store +// m / float8_e4m3::max() in *scale. Each thread block performs a +// reduction tree and the memory in scale is atomically updated. +// So to get the right answer, *scale needs to be initialized to +// a value <= 0.0 and we need to wait for all thread blocks to +// finish before consuming *scale. 
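scaled_fp8_conversion and segmented_max_reduction, now shared through this header, fix the scale convention used by every fp8 path here: the per-tensor scale is amax(input) / FP8_E4M3_MAX, and quantization divides by that scale (or multiplies by a precomputed inverse). A short PyTorch sketch of the dynamic per-tensor case, assuming the CUDA e4m3 limit (ROCm clamps to 224.0 as noted above):

import torch

def dynamic_per_tensor_fp8_quant_ref(x: torch.Tensor):
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    scale = x.abs().amax().float() / fp8_max        # what segmented_max_reduction accumulates
    q = (x.float() / scale).clamp(min=-fp8_max, max=fp8_max)
    return q.to(torch.float8_e4m3fn), scale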
+template +__global__ void segmented_max_reduction(float* __restrict__ scale, + const scalar_t* __restrict__ input, + int64_t num_elems) { + __shared__ float cache[1024]; + int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + + // First store maximum for all values processes by + // the current thread in cache[threadIdx.x] + scalar_t tmp = 0.0; + while (i < num_elems) { + float x = static_cast(input[i]); + tmp = max(tmp, fabs(x)); + i += blockDim.x * gridDim.x; + } + cache[threadIdx.x] = tmp; + + __syncthreads(); + + // Now perform parallel reduction within the thread block + int ib = blockDim.x / 2; + while (ib != 0) { + if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) { + cache[threadIdx.x] = cache[threadIdx.x + ib]; + } + __syncthreads(); + ib /= 2; + } + // Finally, since cache[0] contains the maximum for this thread block, + // atomically write the max to the target location + if (threadIdx.x == 0) { + atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX); + } +} + +template +struct __align__(8) vec4_t { + scalar_t x; + scalar_t y; + scalar_t z; + scalar_t w; +}; + +typedef struct __align__(4) { + FP8_TYPE x; + FP8_TYPE y; + FP8_TYPE z; + FP8_TYPE w; +} +float8x4_t; + +template +__device__ float thread_max_vec(scalar_t const* __restrict__ input, + int64_t const num_elems, int const tid, + int const step) { + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vectorized_in = + reinterpret_cast const*>(input); + + int64_t const num_vec_elems = num_elems >> 2; + float absmax_val = 0.0f; + +#pragma unroll 4 + for (int64_t i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + absmax_val = max(absmax_val, fabs(in_vec.x)); + absmax_val = max(absmax_val, fabs(in_vec.y)); + absmax_val = max(absmax_val, fabs(in_vec.z)); + absmax_val = max(absmax_val, fabs(in_vec.w)); + } + + // Handle the remaining elements if num_elems is not divisible by 4 + for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + absmax_val = max(absmax_val, fabs(input[i])); + } + + return absmax_val; +} + +template +__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, + scalar_t const* __restrict__ input, + float const scale, + int64_t const num_elems, + int const tid, int const step) { + // Vectorized input/output to better utilize memory bandwidth. 
+ vec4_t const* vectorized_in = + reinterpret_cast const*>(input); + float8x4_t* vectorized_out = reinterpret_cast(out); + + int64_t const num_vec_elems = num_elems >> 2; + +#pragma unroll 4 + for (int64_t i = tid; i < num_vec_elems; i += step) { + vec4_t in_vec = vectorized_in[i]; + float8x4_t out_vec; + + out_vec.x = scaled_fp8_conversion( + static_cast(in_vec.x), scale); + out_vec.y = scaled_fp8_conversion( + static_cast(in_vec.y), scale); + out_vec.z = scaled_fp8_conversion( + static_cast(in_vec.z), scale); + out_vec.w = scaled_fp8_conversion( + static_cast(in_vec.w), scale); + vectorized_out[i] = out_vec; + } + + // Handle the remaining elements if num_elems is not divisible by 4 + for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { + out[i] = scaled_fp8_conversion( + static_cast(input[i]), scale); + } +} + +} // namespace vllm \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 971a45d50ff..229fd554d3e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -101,7 +101,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( - "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> " "()"); ops.impl("rms_norm", torch::kCUDA, &rms_norm); @@ -111,6 +111,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "float epsilon) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + // Layernorm-quant + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " + "Tensor scale, float epsilon) -> " + "()"); + ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, + &rms_norm_static_fp8_quant); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " + "Tensor! residual, Tensor weight, " + "Tensor scale, float epsilon) -> ()"); + ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, + &fused_add_rms_norm_static_fp8_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( @@ -322,18 +339,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute FP8 quantized tensor for given scaling factor. ops.def( - "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); + "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " + "()"); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. ops.def( - "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " + "-> " "()"); ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // Compute dynamic-per-token FP8 quantized tensor and scaling factor. ops.def( - "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, " + "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " "Tensor! scale, Tensor? scale_ub) -> " "()"); ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, @@ -341,13 +360,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Compute int8 quantized tensor for given scaling factor. ops.def( - "static_scaled_int8_quant(Tensor! 
out, Tensor input, Tensor scale," + "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," "Tensor? azp) -> ()"); ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( - "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, " "Tensor!? azp) -> ()"); ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, &dynamic_scaled_int8_quant); diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh new file mode 100644 index 00000000000..21b9d0ae515 --- /dev/null +++ b/csrc/type_convert.cuh @@ -0,0 +1,165 @@ +#pragma once + +#include + +#ifndef USE_ROCM + #include + #include +#else + #include + #include + +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; +#endif + +namespace vllm { +/* Converter structs for the conversion from torch types to HIP/CUDA types, + and the associated type conversions within HIP/CUDA. These helpers need + to be implemented for now because the relevant type conversion + operators/constructors are not consistently implemented by HIP/CUDA, so + a generic conversion via type casts cannot be implemented. + + Each struct should have the member static constexpr bool `exists`: + If false, the optimized kernel is not used for the corresponding torch type. + If true, the struct should be fully defined as shown in the examples below. + */ +template +struct _typeConvert { + static constexpr bool exists = false; +}; + +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +// CUDA < 12.0 runs into issues with packed type conversion +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __half; + using packed_hip_type = __half2; + + __device__ static inline float convert(hip_type x) { return __half2float(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } +}; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// CUDA_ARCH < 800 does not have BF16 support +// TODO: Add in ROCm support once public headers handle bf16 maturely +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __nv_bfloat16; + using packed_hip_type = __nv_bfloat162; + + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } +}; + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) + +/* Vector POD struct to generate vectorized and packed FP16/BF16 ops + for appropriate specializations of fused_add_rms_norm_kernel. + Only functions that are necessary in that kernel are implemented. + Alignment to 16 bytes is required to use 128-bit global memory ops. 
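With the schemas above registered under the _C library, the fused ops are callable from Python through torch.ops, the same way the new tests further down invoke them. A usage sketch; shapes and the scale value are arbitrary, and the vLLM extension must already be built and loaded so that torch.ops._C is populated:

import torch
import vllm._custom_ops  # noqa: F401  (loads the extension that registers torch.ops._C)

hidden_size = 1024
x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)

# RMSNorm + static fp8 quantization, writing into `out`
torch.ops._C.rms_norm_static_fp8_quant(out, x, weight, scale, 1e-6)

# Fused residual-add variant; `residual` is also updated in place
torch.ops._C.fused_add_rms_norm_static_fp8_quant(out, x, residual, weight,
                                                 scale, 1e-6)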
+ */ +template +struct alignas(16) _f16Vec { + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ + static_assert(width > 0 && (width & (width - 1)) == 0, + "Width is not a positive power of 2!"); + using Converter = _typeConvert; + using T1 = typename Converter::hip_type; + using T2 = typename Converter::packed_hip_type; + T1 data[width]; + + __device__ _f16Vec& operator+=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const float scale) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); + temp_f.x *= scale; + temp_f.y *= scale; + T2 temp = Converter::convert(temp_f); + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float temp = Converter::convert(data[i]) * scale; + data[i] = Converter::convert(temp); + } + } + return *this; + } + + __device__ float sum_squares() const { + float result = 0.0f; + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + result += z.x * z.x + z.y * z.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float x = Converter::convert(data[i]); + result += x * x; + } + } + return result; + } +}; +} // namespace vllm \ No newline at end of file diff --git a/tests/compile/backend.py b/tests/compile/backend.py new file mode 100644 index 00000000000..9d5c6827437 --- /dev/null +++ b/tests/compile/backend.py @@ -0,0 +1,33 @@ +from copy import deepcopy +from typing import Callable + +import torch + + +class TestBackend: + """ + This class provides a simple Inductor backend that can be used for testing. + It takes a list of custom passes and runs them after Inductor's passes. + It also saves the graph before and after the custom passes for inspection. 
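The 16-byte alignment requirement on _f16Vec comes straight from the launch heuristics: a 128-bit global memory transaction moves 16 bytes, i.e. eight 2-byte FP16/BF16 elements, so the width-8 kernels are only chosen when every pointer is 16-byte aligned and hidden_size is divisible by 8. A small Python mirror of that check (the helper name is invented here; data_ptr() plays the role of the C++ pointer test):

import torch

def can_use_width8(*tensors: torch.Tensor, hidden_size: int) -> bool:
    # 16 bytes per 128-bit load == 8 fp16/bf16 elements
    aligned = all(t.data_ptr() % 16 == 0 for t in tensors)
    return aligned and hidden_size % 8 == 0

x = torch.randn(4, 1024, dtype=torch.float16)
weight = torch.randn(1024, dtype=torch.float16)
residual = torch.randn_like(x)
print(can_use_width8(x, residual, weight, hidden_size=1024))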
+ """ + + def __init__(self, *args: Callable[[torch.fx.Graph], None]): + self.custom_passes = args + from torch._inductor import config + self.current_config = config.shallow_copy_dict() + self.current_config['post_grad_custom_post_pass'] = self.post_pass + + def __call__(self, graph: torch.fx.GraphModule, example_inputs): + from torch._inductor.compile_fx import compile_fx + return compile_fx(graph, + example_inputs, + config_patches=self.current_config) + + def post_pass(self, graph: torch.fx.Graph): + self.graph_pre_pass = deepcopy(graph) + for pass_ in self.custom_passes: + pass_(graph) + + self.graph_post_pass = deepcopy(graph) + # assign by reference, will reflect the final state of the graph + self.final_graph = graph diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py new file mode 100644 index 00000000000..e4d3defafb9 --- /dev/null +++ b/tests/compile/test_fusion.py @@ -0,0 +1,92 @@ +import pytest +import torch +from compressed_tensors.quantization import FP8_DTYPE + +import vllm.envs as envs +from vllm.compilation.config import CompilationConfig +from vllm.compilation.fusion import (FusionPass, find_auto_fn, + find_auto_fn_maybe) +from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear) + +from .backend import TestBackend + + +class TestModel(torch.nn.Module): + + def __init__(self, hidden_size: int, eps: float, *args, **kwargs): + super().__init__(*args, **kwargs) + self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)] + self.w = [ + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + for _ in range(2) + ] + + def forward(self, x): + resid = torch.relu(x) + y = self.norm[0](x) + + x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1]) + # make sure resid is used for replacement to work + y2, resid = self.norm[1](x2, resid) + + x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3]) + y3, resid = self.norm[2](x3, resid) # use resid here + return y3 + + +# Init does pattern registration, which can only happen once +config = CompilationConfig(enable_fusion=True) +reshape_pass = RedundantReshapesPass(config) +fusion_pass = FusionPass.instance(config) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) +@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", + reason="Only test on CUDA") +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): + torch.set_default_device("cuda") + torch.set_default_dtype(torch.float16) + + if eps != 1e-5: + pytest.skip("Only test eps=1e-5 for now") + + # Reshape pass is needed for the fusion pass to work + backend = TestBackend(reshape_pass, fusion_pass) + model = TestModel(hidden_size, eps) + + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + # Check that it gives the same answer + torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3) + + # Check substitution worked + pre_nodes = backend.graph_pre_pass.nodes + post_nodes = backend.graph_post_pass.nodes + + rms_quant = 
torch.ops._C.rms_norm_static_fp8_quant.default + add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + + # In pre-nodes, fp8 quant should be present and fused kernels should not + assert find_auto_fn_maybe(pre_nodes, rms_quant) is None + assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None + find_auto_fn(pre_nodes, fp8_quant) + + # In post-nodes, fused kernels should be present and fp8 quant should not + find_auto_fn(post_nodes, rms_quant) + find_auto_fn(post_nodes, add_rms_quant) + assert find_auto_fn_maybe(post_nodes, fp8_quant) is None diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 9dfa2cbe45e..727769e0718 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -1,13 +1,14 @@ import pytest import torch +from tests.kernels.quant_utils import FP8_DTYPE from tests.kernels.utils import opcheck from vllm.model_executor.layers.layernorm import RMSNorm from vllm.platforms import current_platform DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, +HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] @@ -59,3 +60,75 @@ def test_rms_norm( else: opcheck(torch.ops._C.rms_norm, (out, x, layer.weight.data, layer.variance_epsilon)) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("add_residual", ADD_RESIDUAL) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_fused_rms_norm_quant( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + quant_scale: float, + seed: int, + device: str, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + + weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + if add_residual: + residual = torch.randn_like(x) * scale + residual_fused = residual.clone() + else: + residual = residual_fused = None + + out_norm = torch.empty_like(x) + out_quant = torch.empty_like(x, dtype=FP8_DTYPE) + out_quant_fused = torch.empty_like(out_quant) + + quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32) + + if add_residual: + torch.ops._C.fused_add_rms_norm_static_fp8_quant( + out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6) + + # Unfused kernel is in-place so it goes second + # Also use a separate clone of x to avoid modifying the input + x_unfused = x.clone() + torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) + torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, + quant_scale_t) + + torch.cuda.synchronize() + torch.testing.assert_close(residual_fused, + residual, + atol=1e-2, + rtol=1e-2) + + opcheck( + torch.ops._C.fused_add_rms_norm_static_fp8_quant, + (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) + else: + torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight, + quant_scale_t, 1e-6) + + torch.ops._C.rms_norm(out_norm, x, weight, 1e-6) + torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm, + quant_scale_t) + + 
opcheck(torch.ops._C.rms_norm_static_fp8_quant, + (out_quant_fused, x, weight, quant_scale_t, 1e-6)) + + torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32), + out_quant.to(dtype=torch.float32), + atol=1e-3, + rtol=1e-3) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index abd1d16acca..f5fff344a1f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -2,7 +2,8 @@ import dataclasses import operator from contextlib import ExitStack -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, + Union) from unittest.mock import patch import torch @@ -10,11 +11,13 @@ import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import weak_ref_tensors +from vllm.utils import combine_fx_passes, weak_ref_tensors from .config import CompilationConfig from .counter import compilation_counter +from .fusion import FusionPass from .levels import CompilationLevel +from .reshapes import RedundantReshapesPass logger = init_logger(__name__) @@ -99,28 +102,74 @@ def fix_functionalization(graph: fx.Graph): user.replace_all_uses_with(replace_node) nodes_to_remove.append(user) nodes_to_remove.append(node) + elif (node.args[0] == + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default): + # manual replace for fused_add_rms_norm_static_fp8_quant + # this is the most effective optimization for llama + # failing to do this will result in many unnecessary copies + + kwargs = node.kwargs + + result = kwargs['result'] + residual = kwargs['residual'] + + # Create a new call to + # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.fused_add_rms_norm_static_fp8_quant. 
+ default, + kwargs=kwargs) + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + if user.args[1] == 1: + replace_node = result + elif user.args[1] == 2: + replace_node = residual + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) elif node.args[0] == torch.ops._C.rms_norm.default: # manual replace for rms_norm kwargs = node.kwargs - input = kwargs['input'] - out = kwargs['out'] - weight = kwargs['weight'] - epsilon = kwargs['epsilon'] - # Create a new call to torch.ops._C.rotary_embedding.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + replace_node = kwargs['result'] + # Create a new call to torch.ops._C.rms_norm.default with graph.inserting_before(node): # just insert the call to the custom op # NOTE: don't run dead code elimination, # otherwise this op will be removed - graph.call_function( - torch.ops._C.rms_norm.default, - args=(out, input, weight, epsilon), - ) + graph.call_function(torch.ops._C.rms_norm.default, + kwargs=kwargs) - replace_node = out + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[ + 0] == torch.ops._C.rms_norm_static_fp8_quant.default: # noqa + # manual replace for rms_norm_static_fp8_quant + + kwargs = node.kwargs + + replace_node = kwargs['result'] + # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.rms_norm_static_fp8_quant.default, + kwargs=kwargs) for user in list(node.users): if user.op == 'call_function' and user.target == operator.getitem: # noqa @@ -136,7 +185,7 @@ def fix_functionalization(graph: fx.Graph): input = kwargs['input'] out = kwargs['out'] - # Create a new call to torch.ops._C.rotary_embedding.default + # Create a new call to torch.ops._C.silu_and_mul.default # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa with graph.inserting_before(node): # just insert the call to the custom op @@ -319,6 +368,13 @@ class VllmBackend: The major work of this backend is to split the graph into piecewise graphs, and pass them to the piecewise backend. + + This backend also handles custom passes and adds them to Inductor config. + The order of the post-grad post-passes is: + 1. post_grad_passes (constructor parameter) + 2. config["post_grad_custom_post_pass"] + 3. fix_functionalization + This way, all passes operate on a functionalized graph. 
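The pass pipeline described in the docstring above is driven by two things: extra passes handed to the VllmBackend constructor and the new CompilationConfig knobs added in config.py below. A hedged sketch of both from the user side; the example pass and the dump directory are placeholders, the stage names are the ones FusionPass dumps, and VllmBackend still resolves its own config via select_and_init_config at call time:

import torch
from vllm.compilation.backends import VllmBackend
from vllm.compilation.config import CompilationConfig

def count_nodes(graph: torch.fx.Graph):
    # placeholder custom pass; runs before the reshape/fusion passes
    print("post-grad graph has", len(list(graph.nodes)), "nodes")

backend = VllmBackend(post_grad_passes=[count_nodes])

# Knobs consumed by the built-in passes (fields added in config.py below)
config = CompilationConfig(
    enable_fusion=True,                                   # register FusionPass
    dump_graph_stages=["before_fusion", "after_fusion"],  # stages FusionPass dumps
    dump_graph_dir="fx_dumps",                            # coerced to pathlib.Path
)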
""" compilation_configs: CompilationConfig @@ -330,8 +386,10 @@ class VllmBackend: split_gm: fx.GraphModule piecewise_graphs: List[SplitItem] returned_callable: Callable + # Inductor passes to run on the graph pre-defunctionalization + post_grad_passes: Sequence[Callable] - def __init__(self, ): + def __init__(self, post_grad_passes: Sequence[Callable] = ()): global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -340,10 +398,30 @@ def __init__(self, ): # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool + self.post_grad_passes = post_grad_passes # `torch.compile` is JIT compiled, so we don't need to # do anything here + def add_passes_to_config(self): + config = self.compilation_configs + passes = list(self.post_grad_passes) + + passes = passes + [RedundantReshapesPass(config)] + + if config.enable_fusion: + passes = passes + [FusionPass.instance(config)] + + inductor_config = config.inductor_compile_config + if "post_grad_custom_post_pass" in inductor_config: + passes = passes + [inductor_config["post_grad_custom_post_pass"]] + + # add the fix_functionalization pass last, so that all other + # passes operate on a functionalized graph + passes = passes + [fix_functionalization] + combined_pass = combine_fx_passes(passes) + inductor_config["post_grad_custom_post_pass"] = combined_pass + def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: compilation_counter.num_graphs_seen += 1 @@ -357,6 +435,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # we get the sizes to capture for cudagraph # from compilation context self.compilation_configs = CompilationConfig.select_and_init_config() + self.add_passes_to_config() self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.non_cudagraph_ops) diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py index 514f2b93ef6..72377533140 100644 --- a/vllm/compilation/config.py +++ b/vllm/compilation/config.py @@ -1,4 +1,5 @@ import copy +from pathlib import Path from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, PrivateAttr @@ -50,6 +51,12 @@ class CompilationConfig(BaseModel): name because the config uses json format. If we pass the config from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - Custom inductor passes: + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graph. Default is . + - enable_fusion: whether to enable the custom fusion pass. + TODO better pass enabling system. 
Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -72,6 +79,10 @@ class CompilationConfig(BaseModel): cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + # not configurable, computed after init compile_sizes: List[int] = PrivateAttr capture_sizes: List[int] = PrivateAttr @@ -81,7 +92,7 @@ def model_post_init(self, __context: Any) -> None: if not isinstance(v, str): assert callable(v), ( f"pass {k} should be a function or a qualified name") - self.inductor_passes[k] = v + self.inductor_compile_config[k] = v continue # resolve function from qualified name @@ -91,18 +102,6 @@ def model_post_init(self, __context: Any) -> None: func = __import__(module).__dict__[func_name] self.inductor_compile_config[k] = func - from vllm.compilation.backends import fix_functionalization - from vllm.utils import combine_fx_passes - if "post_grad_custom_post_pass" in self.inductor_compile_config: - self.inductor_compile_config[ - "post_grad_custom_post_pass"] = combine_fx_passes( - fix_functionalization, - self.inductor_compile_config["post_grad_custom_post_pass"], - ) - else: - self.inductor_compile_config[ - "post_grad_custom_post_pass"] = fix_functionalization - def init_during_runtime(self): """To complete the initialization of config, we need to know the compile context, which is only available diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py new file mode 100644 index 00000000000..2a0cf0002c9 --- /dev/null +++ b/vllm/compilation/fusion.py @@ -0,0 +1,291 @@ +import operator +from typing import Iterable, List, Optional + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, + fwd_only, register_replacement) + +from vllm.compilation.config import CompilationConfig +from vllm.compilation.inductor_pass import InductorPass +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(torch.ops._C.rms_norm.default, + result=result_rms, + input=input, + weight=weight, + epsilon=1e-5) + at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at1[1], + scale=scale) + + # result + return at2[1] + + +def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=1e-5) + + # result + return at[1] + + +def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, + input=input, + residual=residual, + weight=weight, + epsilon=1e-5) + at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, + result=result, + input=at[1], + scale=scale) + + # result, residual + return at1[1], at[2] + + +def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor): 
+ at = auto_functionalized( + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=1e-5) + # result, residual + return at[1], at[2] + + +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") + + +def empty_fp8(*args, **kwargs): + fp8 = torch.float8_e4m3fn + return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") + + +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") + + +# Utilities for post-processing multi-output matches +def is_func(node: torch.fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node], + op) -> Optional[torch.fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: torch.fx.Node, + idx: int) -> Optional[torch.fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret + + +class FusionPass(InductorPass): + """ + This pass fuses a pre-defined set of custom ops into fused ops. + It uses the torch pattern matcher to find the patterns and replace them. + It also manually processes multi-output matches, as those are broken in + the torch pattern matcher. + + Because patterns can only be registered once, the pass is a singleton. + This will be addressed in a future version of PyTorch: + https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 + """ + + _instance: 'Optional[FusionPass]' = None + + @classmethod + def instance(cls, config: CompilationConfig): + """ + Get the singleton instance of the FusionPass. + If the instance exists, the config is updated but + initialization is not repeated. 
+ """ + if cls._instance is None: + cls._instance = FusionPass(config) + else: + cls._instance.config = config + return cls._instance + + def __init__(self, config: CompilationConfig): + assert self.__class__._instance is None, \ + "FusionPass singleton instance already exists" + super().__init__(config) + + self.matches: List[Match] = [] + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="fusion_pass") + + # Fuse rms_norm + static_scaled_fp8_quant into + # rms_norm_static_fp8_quant + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_bf16(1, 5), + empty_fp32(1, 1) + ] + register_replacement(rms_pattern_static, rms_replacement_static, + inputs, fwd_only, self.patterns) + + # Fuse fused_add_rms_norm + static_scaled_fp8_quant into + # fused_add_rms_norm_static_fp8_quant + # Because pattern has 2 outputs, we need to manually process the match + # (see process_matches) + inputs = [ + empty_fp8(5, 4), + empty_bf16(5, 4), + empty_bf16(5, 4), + empty_bf16(1, 5), + empty_fp32(1, 1) + ] + register_replacement(rms_pattern_residual_static, + rms_replacement_residual_static, + inputs, + fwd_only, + self.patterns, + extra_check=lambda m: self.record_match(m)) + + def record_match(self, match: Match) -> bool: + # Hijack the extra_check to record the match and + # save it for post-processing. + self.matches.append(match) + + # Return False to prevent automatic replacement. + return False + + def process_matches(self, graph: torch.fx.Graph): + """ + Manually process multi-output matches and replace them with fused nodes. + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ + for match in self.matches: + # To avoid use-before-definition errors, insert replacement nodes + # after the last node in the match. + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(graph.nodes): + if last_node_in_match in match.nodes: + break + else: + raise ValueError("No nodes in graph") + + # Insert a new auto_functionalized node for the fused operation, + # as well as getitem nodes to extract the result and residual. + # The auto_functionalized node returns a tuple of + # (None, result, residual) - None is the function return value. + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + with graph.inserting_after(last_node_in_match): + kwargs = match.kwargs + kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm + + fused_node = graph.call_function( + auto_functionalized, + (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, + ), + kwargs=kwargs) + + graph.inserting_after(fused_node) + result_node_new = graph.call_function(operator.getitem, + (fused_node, 1)) + residual_node_new = graph.call_function( + operator.getitem, (fused_node, 2)) + + # Last part of replacement is rebinding the users of nodes in the + # match to use the new nodes. 
+ + # Find the nodes in the match that we need to rebind + rms_node = find_auto_fn(match.nodes, + torch.ops._C.fused_add_rms_norm.default) + quant_node = find_auto_fn( + match.nodes, torch.ops._C.static_scaled_fp8_quant.default) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # meta["val"] is used by de-functionalization and has to contain the + # value of the node (tuple of tensors) that would be returned by the + # functionalized node during tracing. + + rms_tup = rms_node.meta["val"] + quant_tup = quant_node.meta["val"] + + # The result of fused_node must be a tuple with the first element + # None (the function return value) and the remaining elements + # representing the mutated inputs. + fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2]) + fused_node.meta["val"] = fused_tup + + # Find the getitem nodes and replace their uses with the new nodes. + # The old nodes will be removed by DCE at the end of the pass. + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in self.matches + for node in match.nodes) + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, "before_fusion") + + count = self.patterns.apply(graph) + logger.info("Replaced %s patterns", count) + self.dump_graph(graph, "after_pattern_match") + + # Manually process multi-output matches (and run DCE) + self.process_matches(graph) + logger.info("Post-processed %s matches", len(self.matches)) + self.dump_graph(graph, "after_fusion") + self.matches.clear() diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py new file mode 100644 index 00000000000..b23351fa197 --- /dev/null +++ b/vllm/compilation/inductor_pass.py @@ -0,0 +1,38 @@ +from abc import ABC, abstractmethod + +import torch + +from vllm.compilation.config import CompilationConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class InductorPass(ABC): + + @abstractmethod + def __call__(self, graph: torch.fx.Graph): + raise NotImplementedError + + def __init__(self, config: CompilationConfig): + self.config = config + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + if stage in self.config.dump_graph_stages: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("Printing graph to %s", filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py new file mode 100644 index 00000000000..0d284246d25 --- /dev/null +++ b/vllm/compilation/reshapes.py @@ -0,0 +1,85 @@ +from typing import Union + +import torch.fx +from torch import SymInt + +from vllm.compilation.fusion import is_func +from vllm.compilation.inductor_pass import InductorPass +from vllm.logger import init_logger 
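+
+# NOTE: this pass compares reshape arguments against the shape metadata stored
+# in node.meta["val"], so it assumes shape propagation has already populated
+# that metadata on the graph.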
+ +logger = init_logger(__name__) + + +class RedundantReshapesPass(InductorPass): + """ + This is an inductor pass that removes redundant reshape operations. + It is required for RMSNorm-quant fusion to work properly. + That's because apply_fp8_linear adds a reshape, which is redundant + in the 2D-case. + + Example graph: + + getitem_1: "f16[s0, 4096]" = ... + view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096]) + at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + + Can be replaced with: + getitem_1: "f16[s0, 4096]" = ... + at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...) + out: "f8e4m3fn[s0, 4096]" = at[1] + """ + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, "before_reshapes") + count = 0 + # Remove no-op reshapes/views: + for node in graph.nodes: + if is_func(node, torch.ops.aten.reshape.default): + input, shape = node.args[:2] + input_shape = input.meta["val"].shape + if len(shape) != len(input_shape): + # Reshape changing rank, skip + continue + + if shape.count(-1) > 1: + # Invalid reshape args, skip + continue + + if all( + self.dims_equivalent(s, i_s) + for s, i_s in zip(shape, input_shape)): + node.replace_all_uses_with(input) + graph.erase_node(node) + count += 1 + + logger.info("Removed %s no-op reshapes", count) + + self.dump_graph(graph, "after_reshapes") + + def dims_equivalent(self, dim: Union[int, torch.fx.Node], + i_dim: Union[int, SymInt]) -> bool: + """ + This function checks if two dimensions are equivalent. + :param dim: The dimension arg to reshape + :param i_dim: The corresponding dimension in the input tensor + :return: Are the dimensions equivalent? + + There are three cases in which the dimensions are equivalent: + 1. The dimensions are equal (both integers) + 2. The reshape dimension is -1 (i.e. inferred) + 3. The dimensions both correspond to the same SymInt + + While case 2 does not guarantee the dimensions are equal, + they are equal if all other dimensions are equal. + + In case 3, the reshape dimension is a torch.fx.Node, + and its value is a SymInt. That value is equal to the + input dimension. + + """ + # Case 1 and 2 + if dim == i_dim or dim == -1: + return True + # Case 3 + return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim diff --git a/vllm/envs.py b/vllm/envs.py index 9e596a699e4..154246c69f1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -68,6 +68,7 @@ VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 + VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False @@ -226,6 +227,7 @@ def get_default_config_root(): # and disabled when running with Inductor (compile_level >= Inductor). "VLLM_CUSTOM_OPS": lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK":
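
A standalone illustration of the redundancy that RedundantReshapesPass removes (this sketch is not part of the diff; the sizes and dtype are arbitrary stand-ins for a [s0, 4096] activation): a rank-preserving reshape whose target dimensions are equivalent to the input's is a no-op view, so dropping the node cannot change results.

import torch

# The kind of reshape apply_fp8_linear inserts around a 2D activation.
x = torch.empty(8, 4096, dtype=torch.float16)
y = torch.reshape(x, [-1, 4096])  # -1 is inferred; dims_equivalent case 2

assert y.shape == x.shape              # same rank, equivalent dimensions
assert y.data_ptr() == x.data_ptr()    # shares storage: removing the node is safe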