[Major] Fuse bias+gemm and layernorm+quantization for more efficient ViT #254
base: main
@@ -1,5 +1,5 @@
// Inspired by TRT-LLM.
// Modified by Shang Yang and Haotian Tang.
// Inspired by QServe https://github.com/mit-han-lab/qserve/tree/main.
// Modified by Yuming Lou.
// @article{lin2024awq,
// title={AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration},
// author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Chen, Wei-Ming and Wang, Wei-Chen and Xiao, Guangxuan and Dang, Xingyu and Gan, Chuang and Han, Song},
@@ -10,7 +10,6 @@
// }
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dispatch_utils.h"
#include "utils.cuh"
#include "reduction_utils.cuh"
@@ -41,23 +40,24 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc
* First pass (loop) computes the mean.
* Second computes the variance via Var[x] = E[(x - E[x])²].
* Third pass computes and writes normed_output
*
* with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
* For a better speedup, we set USE_DIFF_OF_SQUARES to true (may be faster but less accurate):
* First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
* Second pass computes and writes normed_output
*
* It turns out the accuracy doesn't drop.
*
* use_shmem controls if we cache input values into shared memory
*
* Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
* amax per row. A final pass scales to int8 accordingly, and writes output to
* normed_output_quant.
*/
template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
template <typename T, typename scale_type>
__global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
int tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
int8_t* normed_output_quant, bool use_shmem)
{
constexpr auto num_elems_T = num_elems<T>::value;
constexpr auto num_elems_T = num_elems<T>::value; // number of packed elements per T (1 for scalar types)
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;
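Note on the hunk above: the patch drops the three-pass path and keeps only the difference-of-squares variance, reducing the kernel to a single block reduction plus the normalize/quantize pass. A minimal host-side sketch of that statistic, with illustrative names that are not the kernel's API:

```cpp
// Minimal sketch (not the kernel itself) of the single-pass statistics the fused
// path now relies on: Var[x] = E[x^2] - E[x]^2. Names here are illustrative.
#include <cmath>

inline void mean_inv_std_diff_of_squares(const float* data, int n,
                                         float& mean, float& inv_std, float eps) {
    float sum = 0.0f, sum_sq = 0.0f;
    for (int i = 0; i < n; ++i) {   // in the kernel this loop is strided over threads
        sum    += data[i];          // and the two partial sums are block-reduced together
        sum_sq += data[i] * data[i];
    }
    mean = sum / n;
    const float variance = sum_sq / n - mean * mean;  // E[x^2] - E[x]^2
    inv_std = 1.0f / std::sqrt(variance + eps);       // rsqrtf(variance + eps) on device
}
```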
@@ -74,7 +74,6 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
float variance = 0.0f;
float local_sum = 0.0f;
float local_var_sum = 0.0f;

const int n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
@@ -83,56 +82,25 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
{
shmem[i] = val;
}

const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_sum += cuda_sum<float>(val_f);
if (USE_DIFF_OF_SQUARES)
{
local_var_sum += cuda_sum<float>(val_f * val_f);
}
}

if (USE_DIFF_OF_SQUARES)
{
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];
}
else
{
mean = blockReduceSum(local_sum);
local_var_sum += cuda_sum<float>(val_f * val_f);
}
// Compute mean and variance in a single block reduction
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];

if (threadIdx.x == 0)
{
mean = mean / hidden_dim;
s_mean = mean;
if (USE_DIFF_OF_SQUARES)
{
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
__syncthreads();

if (!USE_DIFF_OF_SQUARES)
{
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);

if (threadIdx.x == 0)
{
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();
}

// Compute LN and Quantize
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
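For context on the two flags checked at the end of this hunk: per-tensor scaling uses a precomputed scale, while per-token scaling derives each row's scale from its amax after normalization and writes it to scale_orig_quant_per_token. A hedged scalar restatement of the per-token case (the function name is hypothetical):

```cpp
// Scalar sketch of per-token dynamic int8 quantization, the mode used by the fused
// path. Illustrative only; the kernel works on packed vectors, uses a block-wide
// amax reduction, and rounds with float_to_int8_rn.
#include <cstdint>
#include <cmath>
#include <algorithm>

void quantize_row_per_token(const float* row, int n, int8_t* out, float* scale_out) {
    float amax = 0.0f;
    for (int i = 0; i < n; ++i)
        amax = std::max(amax, std::fabs(row[i]));     // row-wise absolute maximum
    const float to_int8 = amax > 0.0f ? 127.0f / amax : 1.0f;
    for (int i = 0; i < n; ++i)
        out[i] = static_cast<int8_t>(std::lrintf(row[i] * to_int8));
    *scale_out = amax / 127.0f;   // stored per token so the GEMM can dequantize later
}
```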
@@ -186,51 +154,21 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
}
}
}
}

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename out_type, bool use_quant>
__global__ void
rms_norm_kernel(out_type *__restrict__ out, // [..., hidden_size]
const scalar_t *__restrict__ input, // [..., hidden_size]
const scalar_t *__restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
if constexpr (use_quant) {
out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(
((float)(x * s_variance)) * (float)(weight[idx]));
} else {
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
}
} // namespace vllm

void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &bias, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant) {
bool use_per_token_quant = true) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
dim3 block(std::min(hidden_size / 2, 1024)); // Prevent thread idling when the embedding size is greater than 1024 and not an integer multiple of it.

Review comment: Can we fix this in a more elegant way to improve utilization?

block.x = 32 * ((block.x + 31) / 32);

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
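On the review question above: one possible alternative, a sketch only and not what this patch does, is to derive the block size from the number of strided iterations so that the work per thread stays even. Here n_elems is hidden_size divided by the per-thread packing factor (2 for half2), which is where the current hidden_size / 2 heuristic comes from:

```cpp
// Sketch: pick a block size from the strided iteration count so every thread does
// (nearly) the same number of elements. pick_block_size is a hypothetical helper.
#include <algorithm>

inline int pick_block_size(int n_elems) {
    const int iters = (n_elems + 1023) / 1024;         // strided passes a 1024-thread block would need
    int threads     = (n_elems + iters - 1) / iters;   // elements handled per pass
    threads         = ((threads + 31) / 32) * 32;      // round up to a full warp
    return std::min(threads, 1024);
}
// e.g. n_elems = 1536 -> iters = 2 -> 768 threads, each covering exactly 2 elements
```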
@@ -240,7 +178,8 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
// per-token
vllm::generalLayerNorm<T, at::Half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()),
reinterpret_cast<T*>(bias.data_ptr<scalar_t>()),
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(),
out.data_ptr<int8_t>(), false
);

@@ -258,4 +197,4 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
);
}
});
}
}
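For reference, a hedged sketch of calling the updated entry point from C++; shapes, dtypes, and the epsilon value are assumptions inferred from the signature in this diff, not taken from the PR's tests:

```cpp
// Assumed usage of the fused LayerNorm + per-token int8 quantization entry point.
#include <torch/torch.h>

// Declaration matching the updated signature shown above.
void rms_norm_general(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight,
                      torch::Tensor &bias, torch::Tensor &scaling, float epsilon,
                      bool use_per_token_quant = true);

void example_call(int64_t tokens, int64_t hidden) {
    auto half_opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
    auto input   = torch::randn({tokens, hidden}, half_opts);
    auto weight  = torch::ones({hidden}, half_opts);
    auto bias    = torch::zeros({hidden}, half_opts);
    auto out     = torch::empty({tokens, hidden}, half_opts.dtype(torch::kInt8)); // int8 output
    auto scaling = torch::empty({tokens}, half_opts);                              // per-token scales

    // use_per_token_quant defaults to true in the new signature.
    rms_norm_general(out, input, weight, bias, scaling, /*epsilon=*/1e-6f);
}
```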
@@ -32,7 +32,7 @@
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int kSmemByteSize = \
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B)) * STAGES * \
sizeof(int8_t); \
sizeof(int8_t) + CTA_N * sizeof(float); \
if (kSmemByteSize >= 99 * 1024) \
{ \
printf("This kernel requires %d Bytes of shared memory, which exceeds " \
@@ -41,12 +41,12 @@
return ; \
} \
int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
int num_blocks_n = num_out_channels / CTA_N / 1; \
int num_blocks_n = num_out_channels / CTA_N / 1; \
const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
const int tile_shift = 1 << log_tile; \
dim3 num_blocks(num_blocks_n * tile_shift, \
(num_blocks_m + tile_shift - 1) / tile_shift); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = \
dense_kernel0_fuse_bias<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, \
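The macro above now adds CTA_N * sizeof(float) for the bias tile that sits behind the A/B stages in shared memory. A worked check for the 128x128x64, STAGES=6 configuration chosen later in this file, with SMEM_PAD_A/B assumed to be 0 for the arithmetic:

```cpp
// Worked shared-memory budget for the large-tile config, pads assumed zero.
constexpr int CTA_M = 128, CTA_N = 128, CTA_K = 64, STAGES = 6;
constexpr int kBytesA    = CTA_M * CTA_K * STAGES;                   // 49152
constexpr int kBytesB    = CTA_N * CTA_K * STAGES;                   // 49152
constexpr int kBytesBias = CTA_N * static_cast<int>(sizeof(float));  // 512
constexpr int kSmemByteSize = kBytesA + kBytesB + kBytesBias;        // 98816
static_assert(kSmemByteSize < 99 * 1024, "stays under the 99 KB launch check");
```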
@@ -300,7 +300,7 @@ template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K,
int STAGES>
__global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restrict__ B,
half2 *__restrict__ wscales, half *__restrict__ ascales,
half *__restrict__ C, half2 *__restrict__ Bias,
half *__restrict__ C, half *__restrict__ Bias,
int M, int N, int K)
{
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
@@ -326,9 +326,10 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
extern __shared__ int8_t mem_shared[];
extern __shared__ int8_t mem_shared[]; // extern: dynamic shared memory whose size is set at kernel launch; __shared__: visible to the whole block
int8_t *A_shared = mem_shared;
int8_t *B_shared = mem_shared + kSmemSizeA;
float *Bias_shared = reinterpret_cast<float*>(mem_shared + kSmemSizeA + kSmemSizeB);
int8_t A_shared_warp_[2][WARP_M * WARP_K /
WARP_SIZE];
int8_t B_shared_warp_[2][WARP_N * WARP_K /
@@ -344,22 +345,22 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN; // Always zero if threadIdx.z == 0!
int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
int warp_offset_k = slice_id * WARP_K;

for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0;

int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
int A_hoisted_col = (threadIdx.x % A_threads_per_row);
int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;

int B_hoisted_row = threadIdx.y * B_warp_step_n + (threadIdx.x / B_threads_per_row);
int B_hoisted_col = (threadIdx.x % B_threads_per_row);
int B_hoisted_col_swizzled = B_hoisted_col ^ (B_hoisted_row / 2) & 3;
@@ -374,6 +375,18 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
int8_t *B_hoisted = B + cta_offset_n * K + B_hoisted_row * K +
B_hoisted_col * PACK_SIZE;
bool A_g2s_preds[A_total_global_iters];
// debug
// printf("A: %d ", A_total_global_iters);
// printf("B: %d ", B_total_global_iters);
// printf("prologue_stages: %d ", prologue_stages);
// __shared__ float2 Bias_shared[CTA_N];
#pragma unroll
for (int i = 0; i < CTA_N; i++)
{
Bias_shared[i] = __half2float(Bias[cta_offset_n + i]);

Review comment: Is this the bottleneck?

}

#pragma unroll
for (int i = 0; i < A_total_global_iters; i++)
{
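On the review question attached to the bias preload above: every thread currently reads all CTA_N bias values from global memory. A hedged alternative, not part of this patch, is a strided cooperative load so the block splits the CTA_N elements, in the spirit of the commented-out global_to_share_bias call; load_bias_tile is a hypothetical helper:

```cpp
#include <cuda_fp16.h>

// Sketch: cooperative bias preload where each thread loads a strided subset of the
// CTA_N values instead of all of them. Assumes a 2-D block (WARP_SIZE x NUM_WARPS).
template <int CTA_N>
__device__ void load_bias_tile(const half *__restrict__ Bias, float *Bias_shared,
                               int cta_offset_n) {
    const int tid      = threadIdx.y * blockDim.x + threadIdx.x;
    const int nthreads = blockDim.x * blockDim.y;
    for (int i = tid; i < CTA_N; i += nthreads)            // strided over the whole block
        Bias_shared[i] = __half2float(Bias[cta_offset_n + i]);
    __syncthreads();                                       // bias must be ready before the epilogue
}
```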
@@ -396,6 +409,8 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
__pipeline_wait_prior(STAGES - 2);
__syncthreads();

// global_to_share_bias<CTA_N, CTA_SIZE>(Bias, Bias_shared, cta_offset_n);

share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0,
WARP_M / INTRIN_M);
@@ -547,14 +562,13 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
if (row_wb < M){
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
if (row_wb < M){
float2 wscale = __half22float2(*(wscales + col_wb / 2));
float ascale = __half2float(ascales[row_wb]);
float2 bias = __half22float2(Bias[col_wb]);
float2 psums = make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
psums.x = psums.x * wscale.x * ascale + bias.x;
psums.y = psums.y * wscale.y * ascale + bias.y;
psums.x = psums.x * wscale.x * ascale + Bias_shared[col_wb % CTA_N];
psums.y = psums.y * wscale.y * ascale + Bias_shared[col_wb % CTA_N + 1];
*reinterpret_cast<half2 *>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
}
};
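The per-element math in the epilogue above is out[row][col] = acc * wscale[col] * ascale[row] + bias[col]: the int32 accumulator is dequantized by the per-output-channel weight scale and the per-row activation scale, then the bias is added. A scalar restatement, for reading only:

```cpp
// Scalar restatement of the fused dequantize-and-bias epilogue.
inline float epilogue_element(int acc, float wscale_col, float ascale_row, float bias_col) {
    return static_cast<float>(acc) * wscale_col * ascale_row + bias_col;
}
```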
@@ -576,7 +590,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
auto kernel = reinterpret_cast<int8_t *>(_kernel.data_ptr<int8_t>());
auto wscales = reinterpret_cast<half2 *>(_wscales.data_ptr());
auto ascales = reinterpret_cast<half *>(_ascales.data_ptr());
auto bias = reinterpret_cast<half2 *>(_bias.data_ptr());
auto bias = reinterpret_cast<half *>(_bias.data_ptr());
// auto options =
// torch::TensorOptions().dtype(torch::kFloat16).device(_in_feats.device());
// at::Tensor _out_feats =
@@ -592,10 +606,10 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
constexpr int CTA_M = 128;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 128;
constexpr int WARP_M = 64;

Review comment: [Issue] Why couldn't we have 128 here?

constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
constexpr int STAGES = 6;
KERNEL_LAUNCH_CODE_FUSE_BIAS
}
else
@@ -604,7 +618,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
constexpr int CTA_N = 64;
constexpr int CTA_K = 64;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_N = 16;

Review comment: [Issue] Why couldn't we have 16 here?

constexpr int WARP_K = 64;
constexpr int STAGES = 6;
KERNEL_LAUNCH_CODE_FUSE_BIAS
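On the two [Issue] questions above: the warp tile sizes feed directly into NUM_WARPS in the launch macro, so they set the block shape. A worked example for the first branch (the second branch's CTA_M lies outside the shown hunks, so it is omitted):

```cpp
// Warp count implied by the launch macro for the 128x128x64 branch after this patch.
constexpr int CTA_M = 128, CTA_N = 128, CTA_K = 64;
constexpr int WARP_M = 64, WARP_N = 32, WARP_K = 64;
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); // 2 * 4 * 1 = 8
constexpr int THREADS_PER_BLOCK = 32 * NUM_WARPS;                                  // 256
// With the previous WARP_M = 128 the same formula gives 4 warps (128 threads), so this
// change doubles the warps per CTA; whether that is necessary is the open review question.
```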
@@ -613,14 +627,6 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
}

template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K,
int STAGES>
__global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B,

Review comment: Let's keep the original template for better flexibility.