@@ -1,21 +1,13 @@
-#include <torch/all.h>
-#include <ATen/cuda/CUDAContext.h>
+#include "type_convert.cuh"
+#include "dispatch_utils.h"
+
+#include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
 
-#include "dispatch_utils.h"
 #ifndef USE_ROCM
-  #include <cuda_bf16.h>
-  #include <cuda_fp16.h>
-  #include <cub/util_type.cuh>
   #include <cub/cub.cuh>
 #else
-  #include <hip/hip_bf16.h>
-  #include <hip/hip_fp16.h>
-  #include <hipcub/util_type.hpp>
   #include <hipcub/hipcub.hpp>
-
-using __nv_bfloat16 = __hip_bfloat16;
-using __nv_bfloat162 = __hip_bfloat162;
 #endif
 
 namespace vllm {
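Here the file-local conversion helpers move out of this translation unit: the new `type_convert.cuh` include takes over the `_typeConvert` / `_f16Vec` definitions deleted in the next hunk, and `dispatch_utils.h` moves up with the other project headers. A minimal usage sketch follows, assuming the helpers were moved into `type_convert.cuh` unchanged; `row_sum_squares` is a hypothetical illustration, not code from this diff:

#include "type_convert.cuh"  // assumed to still provide vllm::_typeConvert / vllm::_f16Vec

namespace vllm {

// Illustrative only: take the packed/vectorized path when a conversion
// specialization exists for this torch scalar type (FP16, and BF16 on
// SM80+ with CUDA >= 12.0), otherwise fall back to scalar code.
// Assumes hidden_size is a multiple of `width` and `row` is 16-byte aligned.
template <typename scalar_t, int width>
__device__ float row_sum_squares(const scalar_t* __restrict__ row, int hidden_size) {
  float acc = 0.0f;
  if constexpr (_typeConvert<scalar_t>::exists) {
    // Reinterpret the row as 16-byte packed vectors and use packed math.
    auto* vec = reinterpret_cast<const _f16Vec<scalar_t, width>*>(row);
    for (int i = threadIdx.x; i < hidden_size / width; i += blockDim.x)
      acc += vec[i].sum_squares();
  } else {
    // Scalar fallback: convert element by element.
    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
      float x = static_cast<float>(row[i]);
      acc += x * x;
    }
  }
  return acc;
}

}  // namespace vllm

The `exists` flag keeps the behaviour described in the old in-file comment: torch types without a conversion specialization simply never take the optimized path.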
@@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
   }
 }
 
-/* Converter structs for the conversion from torch types to HIP/CUDA types,
-   and the associated type conversions within HIP/CUDA. These helpers need
-   to be implemented for now because the relevant type conversion
-   operators/constructors are not consistently implemented by HIP/CUDA, so
-   a generic conversion via type casts cannot be implemented.
-
-   Each struct should have the member static constexpr bool `exists`:
-   If false, the optimized kernel is not used for the corresponding torch type.
-   If true, the struct should be fully defined as shown in the examples below.
- */
-template <typename torch_type>
-struct _typeConvert {
-  static constexpr bool exists = false;
-};
-
-#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
-// CUDA < 12.0 runs into issues with packed type conversion
-template <>
-struct _typeConvert<c10::Half> {
-  static constexpr bool exists = true;
-  using hip_type = __half;
-  using packed_hip_type = __half2;
-
-  __device__ static inline float convert(hip_type x) { return __half2float(x); }
-  __device__ static inline float2 convert(packed_hip_type x) {
-    return __half22float2(x);
-  }
-  __device__ static inline hip_type convert(float x) {
-    return __float2half_rn(x);
-  }
-  __device__ static inline packed_hip_type convert(float2 x) {
-    return __float22half2_rn(x);
-  }
-};
-
-  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-// CUDA_ARCH < 800 does not have BF16 support
-// TODO: Add in ROCm support once public headers handle bf16 maturely
-template <>
-struct _typeConvert<c10::BFloat16> {
-  static constexpr bool exists = true;
-  using hip_type = __nv_bfloat16;
-  using packed_hip_type = __nv_bfloat162;
-
-  __device__ static inline float convert(hip_type x) {
-    return __bfloat162float(x);
-  }
-  __device__ static inline float2 convert(packed_hip_type x) {
-    return __bfloat1622float2(x);
-  }
-  __device__ static inline hip_type convert(float x) {
-    return __float2bfloat16(x);
-  }
-  __device__ static inline packed_hip_type convert(float2 x) {
-    return __float22bfloat162_rn(x);
-  }
-};
-  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
-        // 12000))
-
-/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
-   for appropriate specializations of fused_add_rms_norm_kernel.
-   Only functions that are necessary in that kernel are implemented.
-   Alignment to 16 bytes is required to use 128-bit global memory ops.
- */
-template <typename scalar_t, int width>
-struct alignas(16) _f16Vec {
-  /* Not theoretically necessary that width is a power of 2 but should
-     almost always be the case for optimization purposes */
-  static_assert(width > 0 && (width & (width - 1)) == 0,
-                "Width is not a positive power of 2!");
-  using Converter = _typeConvert<scalar_t>;
-  using T1 = typename Converter::hip_type;
-  using T2 = typename Converter::packed_hip_type;
-  T1 data[width];
-
-  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i + 1]};
-        temp += T2{other.data[i], other.data[i + 1]};
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) data[i] += other.data[i];
-    }
-    return *this;
-  }
-
-  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        T2 temp{data[i], data[i + 1]};
-        temp *= T2{other.data[i], other.data[i + 1]};
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
-    }
-    return *this;
-  }
-
-  __device__ _f16Vec& operator*=(const float scale) {
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
-        temp_f.x *= scale;
-        temp_f.y *= scale;
-        T2 temp = Converter::convert(temp_f);
-        data[i] = temp.x;
-        data[i + 1] = temp.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) {
-        float temp = Converter::convert(data[i]) * scale;
-        data[i] = Converter::convert(temp);
-      }
-    }
-    return *this;
-  }
-
-  __device__ float sum_squares() const {
-    float result = 0.0f;
-    if constexpr (width % 2 == 0) {
-#pragma unroll
-      for (int i = 0; i < width; i += 2) {
-        float2 z = Converter::convert(T2{data[i], data[i + 1]});
-        result += z.x * z.x + z.y * z.y;
-      }
-    } else {
-#pragma unroll
-      for (int i = 0; i < width; ++i) {
-        float x = Converter::convert(data[i]);
-        result += x * x;
-      }
-    }
-    return result;
-  }
-};
-
 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
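The context lines above lead into the FP16/BF16 specialization of fused_add_rms_norm_kernel that consumes these helpers. A rough sketch of that structure, assuming `hidden_size` is a multiple of `width` and 16-byte aligned pointers; `block_reduce_sum` is a placeholder name for the block-wide reduction actually used, not an API from this diff:

// Sketch of the vectorized fused add + RMSNorm path: one block per token row,
// whole _f16Vec packets as the unit of work.
template <typename scalar_t, int width>
__global__ void fused_add_rms_norm_sketch(scalar_t* __restrict__ input,
                                          scalar_t* __restrict__ residual,
                                          const scalar_t* __restrict__ weight,
                                          float epsilon, int hidden_size) {
  using Vec = vllm::_f16Vec<scalar_t, width>;
  const int vec_hidden = hidden_size / width;
  auto* in_v  = reinterpret_cast<Vec*>(input)    + blockIdx.x * vec_hidden;
  auto* res_v = reinterpret_cast<Vec*>(residual) + blockIdx.x * vec_hidden;
  auto* w_v   = reinterpret_cast<const Vec*>(weight);

  // 1) Fused add: accumulate input into the residual while tracking sum of squares.
  float variance = 0.0f;
  for (int i = threadIdx.x; i < vec_hidden; i += blockDim.x) {
    Vec tmp = res_v[i];
    tmp += in_v[i];            // packed __half2 / __nv_bfloat162 adds
    variance += tmp.sum_squares();
    res_v[i] = tmp;
  }

  // 2) Block-wide reduction and normalization factor.
  variance = block_reduce_sum(variance);  // placeholder for the real reduction
  __shared__ float s_inv_rms;
  if (threadIdx.x == 0) s_inv_rms = rsqrtf(variance / hidden_size + epsilon);
  __syncthreads();

  // 3) Scale by 1/RMS and the learned weight, writing the normalized output.
  for (int i = threadIdx.x; i < vec_hidden; i += blockDim.x) {
    Vec tmp = res_v[i];
    tmp *= s_inv_rms;          // float scale via the converter helpers
    tmp *= w_v[i];             // packed multiply by the weight vector
    in_v[i] = tmp;
  }
}

Treating each 16-byte `_f16Vec` as the unit of work is what allows 128-bit global loads/stores and packed half2/bfloat162 arithmetic instead of per-element conversions, which is the optimization the trailing comment describes.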