
[torch.compile] Fuse RMSNorm with quant #9138


Merged
104 commits merged on Nov 8, 2024
Changes from 84 commits

Commits (104)
d5c329d
adapt for dynamo
youkaichao Oct 3, 2024
12e29fe
fix tpu
youkaichao Oct 3, 2024
504bd6c
add backend
youkaichao Oct 3, 2024
6353613
add use_custom_dispatcher
youkaichao Oct 3, 2024
77ae8e7
update wrapper
youkaichao Oct 3, 2024
4d99a58
update envs
youkaichao Oct 3, 2024
2b79376
update custom op
youkaichao Oct 3, 2024
7dfddcd
support llama
youkaichao Oct 3, 2024
abd1a65
update plugins
youkaichao Oct 3, 2024
ce1907f
update model runner
youkaichao Oct 3, 2024
e1ea867
add support
youkaichao Oct 3, 2024
511e07b
add files
youkaichao Oct 3, 2024
3bb8950
fix not use_custom_dispatcher
youkaichao Oct 4, 2024
c4d7189
Merge branch 'main' into compile_integration
youkaichao Oct 5, 2024
ed573fa
do not test inductor
youkaichao Oct 5, 2024
93ef0b5
add compile context
youkaichao Oct 5, 2024
3cd40db
remove model reference
youkaichao Oct 5, 2024
4e28930
lint
youkaichao Oct 5, 2024
2ac7274
change levels
youkaichao Oct 7, 2024
34fe820
Merge branch 'main' into compile_integration
youkaichao Oct 8, 2024
a3c947e
add levels
youkaichao Oct 8, 2024
1a41c57
use const
youkaichao Oct 8, 2024
db61567
use const
youkaichao Oct 8, 2024
275ede9
use const
youkaichao Oct 8, 2024
d1f084d
use const
youkaichao Oct 8, 2024
326c5b4
use const
youkaichao Oct 8, 2024
9b7b0f3
use const
youkaichao Oct 8, 2024
9cfa70c
use const
youkaichao Oct 8, 2024
e819be7
use const
youkaichao Oct 8, 2024
d9cb162
use const
youkaichao Oct 8, 2024
825f384
use const
youkaichao Oct 8, 2024
c785fc8
use const
youkaichao Oct 8, 2024
28e9f6f
restore
youkaichao Oct 8, 2024
718c5e4
use const
youkaichao Oct 8, 2024
03081cd
use const
youkaichao Oct 8, 2024
fbac08d
error on inductor for tpu
youkaichao Oct 8, 2024
3c688ea
fix llava
youkaichao Oct 8, 2024
32676f8
restore tpu
youkaichao Oct 8, 2024
5ae34df
Merge branch 'main' into compile_integration
youkaichao Oct 8, 2024
3ed89da
adjust for tpu
youkaichao Oct 8, 2024
a3c3e21
fix env var
youkaichao Oct 8, 2024
30ff04f
fix calling
youkaichao Oct 8, 2024
13256c4
revert tpu
youkaichao Oct 8, 2024
bf0e935
revert utils
youkaichao Oct 8, 2024
39571c5
fix typo
youkaichao Oct 8, 2024
e3aea56
add typing
youkaichao Oct 8, 2024
6181795
move DYNAMO_AS_IS to model runner level
youkaichao Oct 8, 2024
1a80a7b
fix default context
youkaichao Oct 8, 2024
92d240b
use eager for DYNAMO_AS_IS by default
youkaichao Oct 8, 2024
f4b0f50
update tests
youkaichao Oct 8, 2024
896431a
update tests
youkaichao Oct 8, 2024
388d563
llava uses fullgraph=false
youkaichao Oct 8, 2024
3642b77
Merge branch 'main' into compile_integration
youkaichao Oct 9, 2024
3e3ea58
Merge branch 'main' into compile_integration
youkaichao Oct 10, 2024
ce7cd8e
disable tests first
youkaichao Oct 10, 2024
828e425
RMSNorm fusion working!
ProExpertProg Oct 4, 2024
88d1379
fused with bug
ProExpertProg Oct 4, 2024
deeaef3
Use pattern matcher to match, replace manually, giving correct output
ProExpertProg Oct 7, 2024
927d2dd
add quant_layernorm kernel (not modified yet)
ProExpertProg Oct 8, 2024
fc3fde6
fixes
ProExpertProg Oct 8, 2024
ef8e0f5
out -> result for fp8 quant ops
ProExpertProg Oct 8, 2024
c7d3d18
change int8 to fp8
ProExpertProg Oct 8, 2024
9d0bf7f
Added layernorm_quant kernels for static fp8 quant, including tests. …
ProExpertProg Oct 9, 2024
d33c179
fix_functionalization for layernorm_quant kernels
ProExpertProg Oct 9, 2024
f7ac7ef
env var for disabling fusion
ProExpertProg Oct 9, 2024
36e8938
Fix for fusion assert
ProExpertProg Oct 9, 2024
733d9f4
Clean up fusion.py
ProExpertProg Oct 10, 2024
ab41d84
Merge branch 'main' into compile_integration
youkaichao Oct 10, 2024
d1f8ae8
add supports_dynamo in the decorator
youkaichao Oct 10, 2024
8dbff55
Merge branch 'compile_integration' into luka/rms-norm-fusion-kaichao
ProExpertProg Oct 10, 2024
7477c2f
Merge remote-tracking branch 'upstream/main' into luka/rms-norm-fusion
ProExpertProg Oct 11, 2024
5073da7
fix example_inputs dtype
ProExpertProg Oct 11, 2024
71379be
extract common type conversion stuff to .cuh file
ProExpertProg Oct 17, 2024
da75630
Merge remote-tracking branch 'refs/remotes/upstream/main' into luka/r…
ProExpertProg Oct 17, 2024
f33d59b
format
ProExpertProg Oct 17, 2024
b053c0b
PR comments
ProExpertProg Oct 21, 2024
0de6baa
refactored fusion pass into a class
ProExpertProg Oct 22, 2024
f3e7d31
PR comments: backends.py
ProExpertProg Oct 22, 2024
95f8985
Merge branch 'main' into luka/rms-norm-fusion
ProExpertProg Oct 22, 2024
b2ab033
Fix node bug in find_getitem
ProExpertProg Oct 22, 2024
86b79dd
PR comments:
ProExpertProg Oct 22, 2024
e3d3f09
PR comments:
ProExpertProg Oct 22, 2024
a40aba7
yapf fix
ProExpertProg Oct 22, 2024
70fb2fe
Merge remote-tracking branch 'upstream/main' into luka/rms-norm-fusion
ProExpertProg Oct 23, 2024
46420f0
Unit test for rmsnorm-quant fusion
ProExpertProg Oct 25, 2024
a1c3d91
Skip on non-CUDA
ProExpertProg Oct 25, 2024
7986e07
Merge remote-tracking branch 'upstream/main' into luka/rms-norm-fusion
ProExpertProg Oct 28, 2024
bc5d6ba
Fix FP8 HIP type in common.cuh
ProExpertProg Oct 29, 2024
c245f54
Merge remote-tracking branch 'upstream/main' into luka/rms-norm-fusion
ProExpertProg Oct 29, 2024
1966e6a
Fix seed_everything
ProExpertProg Oct 29, 2024
0dff724
Merge remote-tracking branch 'upstream/main' into luka/rms-norm-fusion
ProExpertProg Oct 30, 2024
980b56d
Add support for passes to VllmBackend, add fusion back in
ProExpertProg Oct 30, 2024
863d657
Merge remote-tracking branch 'upstream/main' into luka/rms-norm-fusion
ProExpertProg Oct 31, 2024
8b2def5
PR comments:
ProExpertProg Oct 31, 2024
619d634
Fix fusion pass init in test
ProExpertProg Oct 31, 2024
f47e358
Fusion test: use apply_fp8_linear
ProExpertProg Oct 31, 2024
a252997
Add redundant reshapes removal pass.
ProExpertProg Oct 31, 2024
1b9717f
Fix graph dumping when TP not initialized
ProExpertProg Oct 31, 2024
daca890
Reshape add edge-cases
ProExpertProg Oct 31, 2024
429db0a
Singleton pattern matcher for fusion pass
ProExpertProg Oct 31, 2024
e0b904e
singleton fusion pass
ProExpertProg Nov 8, 2024
d73933b
Merge remote-tracking branch 'upstream/main' into luka/rms-norm-fusion
ProExpertProg Nov 8, 2024
d9375df
format
ProExpertProg Nov 8, 2024
d0a9e37
Add print
ProExpertProg Nov 7, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -191,6 +191,7 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
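The new csrc/layernorm_quant_kernels.cu source registered above holds the fused RMSNorm + static FP8 quantization kernels that the torch.compile fusion pass rewrites matched rms_norm followed by static-FP8-quant subgraphs to call. The sketch below only illustrates that fusion at the kernel level and is not the vLLM implementation: the kernel name, the naive shared-memory reduction, and the fixed 256-thread block are assumptions made for brevity.

// Hedged sketch of a fused RMSNorm + static FP8 quant kernel (hypothetical
// name and layout, not the vLLM kernel). One block handles one token; the
// block size is assumed to be 256 threads and a power of two.
#include <cuda_fp16.h>
#include <cuda_fp8.h>

template <typename scalar_t>
__global__ void fused_rms_norm_static_fp8_sketch(
    __nv_fp8_e4m3* __restrict__ out,      // [num_tokens, hidden_size]
    const scalar_t* __restrict__ input,   // [num_tokens, hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // single static quantization scale
    const float epsilon, const int hidden_size) {
  __shared__ float s_partial[256];
  __shared__ float s_inv_rms;
  const scalar_t* row = input + blockIdx.x * (int64_t)hidden_size;

  // Pass 1: sum of squares for this token.
  float local = 0.0f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(row[i]);
    local += x * x;
  }
  s_partial[threadIdx.x] = local;
  __syncthreads();

  // Naive shared-memory tree reduction (vLLM uses cub::BlockReduce instead).
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride)
      s_partial[threadIdx.x] += s_partial[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0)
    s_inv_rms = rsqrtf(s_partial[0] / hidden_size + epsilon);
  __syncthreads();

  // Pass 2: normalize, apply the RMSNorm weight, and quantize to FP8 with the
  // static scale -- the work that is otherwise split across an rms_norm
  // kernel and a separate static scaled-FP8-quant kernel.
  const float inv_scale = 1.0f / *scale;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(row[i]) * s_inv_rms;
    x *= static_cast<float>(weight[i]);
    out[blockIdx.x * (int64_t)hidden_size + i] = __nv_fp8_e4m3(x * inv_scale);
  }
}

Launched as fused_rms_norm_static_fp8_sketch<<<num_tokens, 256>>>(...), this does in one pass over each row what the unfused graph does in two kernels plus an intermediate tensor, which is the memory-traffic saving the fusion pass targets.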
165 changes: 4 additions & 161 deletions csrc/layernorm_kernels.cu
@@ -1,21 +1,13 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "type_convert.cuh"
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {
@@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
}
}

/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.

Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;

__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;

__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];

__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}

__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
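The converter and vector helpers removed above were moved into type_convert.cuh (now included at the top of this file, per the "extract common type conversion stuff to .cuh file" commit), presumably so the existing layernorm kernels and the new layernorm_quant kernels can share them. As a rough illustration of how _typeConvert/_f16Vec are meant to be used — a hedged sketch, not the actual vLLM kernel, assuming hidden_size is divisible by the vector width, rows are 16-byte aligned, and the block size is a multiple of 32 — the vectorized path reinterprets each row as packed vectors and accumulates the squared sum with sum_squares():

// Hedged sketch: vectorized sum-of-squares over one token using the _f16Vec
// helper from type_convert.cuh. Assumes hidden_size % width == 0, 16-byte
// aligned rows, blockDim.x a multiple of 32, out_sumsq zero-initialized, and
// _typeConvert<scalar_t>::exists == true.
#include "type_convert.cuh"

template <typename scalar_t, int width = 8>
__global__ void rms_sum_squares_vectorized_sketch(
    float* __restrict__ out_sumsq,       // [num_tokens]
    const scalar_t* __restrict__ input,  // [num_tokens, hidden_size]
    const int hidden_size) {
  using Vec = vllm::_f16Vec<scalar_t, width>;
  const int vec_hidden = hidden_size / width;
  const Vec* row =
      reinterpret_cast<const Vec*>(input + blockIdx.x * (int64_t)hidden_size);

  // Each 128-bit load pulls in `width` fp16/bf16 values; sum_squares()
  // converts them to float pairwise and accumulates.
  float local = 0.0f;
  for (int i = threadIdx.x; i < vec_hidden; i += blockDim.x)
    local += row[i].sum_squares();

  // Warp-level reduction; a real kernel would use a full block reduction
  // (e.g. cub::BlockReduce, as the original file does).
  for (int offset = 16; offset > 0; offset >>= 1)
    local += __shfl_down_sync(0xffffffff, local, offset);
  if ((threadIdx.x & 31) == 0) atomicAdd(&out_sumsq[blockIdx.x], local);
}

Factoring these helpers into a shared header avoids duplicating the per-type conversion specializations in every kernel file that wants the packed 16-bit path.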