
Commit 4f93dfe

[torch.compile] Fuse RMSNorm with quant (#9138)
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@126.com>
1 parent e1b5a82 commit 4f93dfe

17 files changed: 1335 additions & 368 deletions
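As background, "fuse RMSNorm with quant" refers to computing the RMS normalization and the quantization of its output inside a single kernel, so the normalized activations never take a round trip through global memory in the unquantized dtype. The sketch below only illustrates that idea; it is not the kernel added by this commit. It assumes fp32 input, a static per-tensor scale, and int8 output, and every name in it is made up.

// Illustrative only: one block normalizes one row and quantizes it in the
// same pass (static per-tensor scale, int8 output).
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void rms_norm_quant_sketch(const float* __restrict__ input,
                                      const float* __restrict__ weight,
                                      int8_t* __restrict__ out, float epsilon,
                                      float scale, int hidden) {
  extern __shared__ float red[];  // blockDim.x floats of scratch
  const float* row = input + (size_t)blockIdx.x * hidden;
  int8_t* out_row = out + (size_t)blockIdx.x * hidden;

  // 1) Block-wide sum of squares for this row.
  float local = 0.f;
  for (int i = threadIdx.x; i < hidden; i += blockDim.x) {
    float x = row[i];
    local += x * x;
  }
  red[threadIdx.x] = local;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // blockDim.x: power of two
    if (threadIdx.x < s) red[threadIdx.x] += red[threadIdx.x + s];
    __syncthreads();
  }
  const float inv_rms = rsqrtf(red[0] / hidden + epsilon);

  // 2) Normalize, apply the weight, and quantize without writing the fp32
  //    result back to global memory.
  for (int i = threadIdx.x; i < hidden; i += blockDim.x) {
    float y = row[i] * inv_rms * weight[i];
    int q = __float2int_rn(y / scale);
    out_row[i] = static_cast<int8_t>(max(-128, min(127, q)));
  }
}

int main() {
  const int rows = 4, hidden = 1024, threads = 256;
  float *d_in, *d_w;
  int8_t* d_out;
  cudaMalloc(&d_in, rows * hidden * sizeof(float));
  cudaMalloc(&d_w, hidden * sizeof(float));
  cudaMalloc(&d_out, rows * hidden * sizeof(int8_t));
  cudaMemset(d_in, 0, rows * hidden * sizeof(float));
  cudaMemset(d_w, 0, hidden * sizeof(float));
  rms_norm_quant_sketch<<<rows, threads, threads * sizeof(float)>>>(
      d_in, d_w, d_out, 1e-6f, 0.05f, hidden);
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d_in);
  cudaFree(d_w);
  cudaFree(d_out);
  return 0;
}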

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ set(VLLM_EXT_SRC
   "csrc/pos_encoding_kernels.cu"
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
+  "csrc/layernorm_quant_kernels.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"

csrc/layernorm_kernels.cu

Lines changed: 4 additions & 161 deletions
@@ -1,21 +1,13 @@
-#include <torch/all.h>
-#include <ATen/cuda/CUDAContext.h>
+#include "type_convert.cuh"
+#include "dispatch_utils.h"
+
+#include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
 
-#include "dispatch_utils.h"
 #ifndef USE_ROCM
-  #include <cuda_bf16.h>
-  #include <cuda_fp16.h>
-  #include <cub/util_type.cuh>
   #include <cub/cub.cuh>
 #else
-  #include <hip/hip_bf16.h>
-  #include <hip/hip_fp16.h>
-  #include <hipcub/util_type.hpp>
   #include <hipcub/hipcub.hpp>
-
-using __nv_bfloat16 = __hip_bfloat16;
-using __nv_bfloat162 = __hip_bfloat162;
 #endif
 
 namespace vllm {
@@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
   }
 }
 
-/* Converter structs for the conversion from torch types to HIP/CUDA types,
-   and the associated type conversions within HIP/CUDA. These helpers need
-   to be implemented for now because the relevant type conversion
-   operators/constructors are not consistently implemented by HIP/CUDA, so
-   a generic conversion via type casts cannot be implemented.
-
-   Each struct should have the member static constexpr bool `exists`:
-   If false, the optimized kernel is not used for the corresponding torch type.
-   If true, the struct should be fully defined as shown in the examples below.
- */
-template <typename torch_type>
-struct _typeConvert {
-  static constexpr bool exists = false;
-};
-
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
-// CUDA < 12.0 runs into issues with packed type conversion
-template <>
-struct _typeConvert<c10::Half> {
-  static constexpr bool exists = true;
-  using hip_type = __half;
-  using packed_hip_type = __half2;
-
-  __device__ static inline float convert(hip_type x) { return __half2float(x); }
-  __device__ static inline float2 convert(packed_hip_type x) {
-    return __half22float2(x);
-  }
-  __device__ static inline hip_type convert(float x) {
-    return __float2half_rn(x);
-  }
-  __device__ static inline packed_hip_type convert(float2 x) {
-    return __float22half2_rn(x);
-  }
-};
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-// CUDA_ARCH < 800 does not have BF16 support
-// TODO: Add in ROCm support once public headers handle bf16 maturely
-template <>
-struct _typeConvert<c10::BFloat16> {
-  static constexpr bool exists = true;
-  using hip_type = __nv_bfloat16;
-  using packed_hip_type = __nv_bfloat162;
-
-  __device__ static inline float convert(hip_type x) {
-    return __bfloat162float(x);
-  }
-  __device__ static inline float2 convert(packed_hip_type x) {
-    return __bfloat1622float2(x);
-  }
-  __device__ static inline hip_type convert(float x) {
-    return __float2bfloat16(x);
-  }
-  __device__ static inline packed_hip_type convert(float2 x) {
-    return __float22bfloat162_rn(x);
-  }
-};
-#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
-        // 12000))
-
-/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
-   for appropriate specializations of fused_add_rms_norm_kernel.
-   Only functions that are necessary in that kernel are implemented.
-   Alignment to 16 bytes is required to use 128-bit global memory ops.
- */
-template <typename scalar_t, int width>
-struct alignas(16) _f16Vec {
-  /* Not theoretically necessary that width is a power of 2 but should
-     almost always be the case for optimization purposes */
-  static_assert(width > 0 && (width & (width - 1)) == 0,
-                "Width is not a positive power of 2!");
-  using Converter = _typeConvert<scalar_t>;
-  using T1 = typename Converter::hip_type;
-  using T2 = typename Converter::packed_hip_type;
-  T1 data[width];
-
-  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i + 1]};
-        temp += T2{other.data[i], other.data[i + 1]};
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) data[i] += other.data[i];
-    }
-    return *this;
-  }
-
-  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i + 1]};
-        temp *= T2{other.data[i], other.data[i + 1]};
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
-    }
-    return *this;
-  }
-
-  __device__ _f16Vec& operator*=(const float scale) {
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
-        temp_f.x *= scale;
-        temp_f.y *= scale;
-        T2 temp = Converter::convert(temp_f);
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) {
-        float temp = Converter::convert(data[i]) * scale;
-        data[i] = Converter::convert(temp);
-      }
-    }
-    return *this;
-  }
-
-  __device__ float sum_squares() const {
-    float result = 0.0f;
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        float2 z = Converter::convert(T2{data[i], data[i + 1]});
-        result += z.x * z.x + z.y * z.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) {
-        float x = Converter::convert(data[i]);
-        result += x * x;
-      }
-    }
-    return result;
-  }
-};
-
 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
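The _f16Vec helper deleted above (along with the _typeConvert specializations, apparently relocated into the new type_convert.cuh header that is now included, so the quant kernels can share them) relies on its alignas(16) so that a whole vector of FP16/BF16 values can move through global memory as a single 128-bit transaction. Below is a minimal, self-contained sketch of that access pattern; it is not vLLM code and all names are invented.

// Illustrative only: a 16-byte POD wrapper over eight halves lets each thread
// load and store 128 bits at a time through a reinterpret_cast'ed pointer.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

struct alignas(16) Half8 {
  __half data[8];  // 8 x 16 bits = 128 bits
};

__global__ void scale_half8(__half* __restrict__ x, float scale, int n_vec) {
  // Safe: cudaMalloc returns pointers aligned well beyond 16 bytes.
  auto* vec = reinterpret_cast<Half8*>(x);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_vec) {
    Half8 v = vec[i];  // the compiler can emit one 128-bit load here
#pragma unroll
    for (int j = 0; j < 8; ++j)
      v.data[j] = __float2half(__half2float(v.data[j]) * scale);
    vec[i] = v;  // and one 128-bit store here
  }
}

int main() {
  const int n = 4096, n_vec = n / 8;
  __half* d_x;
  cudaMalloc(&d_x, n * sizeof(__half));
  cudaMemset(d_x, 0, n * sizeof(__half));
  scale_half8<<<(n_vec + 255) / 256, 256>>>(d_x, 2.0f, n_vec);
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d_x);
  return 0;
}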
