
[Major] Fuse bias+gemm and layernorm+quantization for more efficient ViT #254

Open
wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions awq/kernels/csrc/fused_layernorm/layernorm.h
@@ -14,6 +14,7 @@
void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &bias, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant);
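A hypothetical call site for the new signature (a sketch, not code from this PR; `tokens` and `hidden_size` are assumed variables, and `input`, `weight`, `bias` are assumed to be fp16 CUDA tensors):

auto out     = torch::empty({tokens, hidden_size}, input.options().dtype(torch::kInt8));
auto scaling = torch::empty({tokens}, input.options().dtype(torch::kHalf));
rms_norm_general(out, input, weight, bias, scaling, /*epsilon=*/1e-6f,
                 /*use_per_token_quant=*/true);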
107 changes: 23 additions & 84 deletions awq/kernels/csrc/fused_layernorm/layernorm_kernels.cu
@@ -1,5 +1,5 @@
// Inspired by TRT-LLM.
// Modified by Shang Yang and Haotian Tang.
// Inspired by QServe https://github.com/mit-han-lab/qserve/tree/main.
// Modified by Yuming Lou.
// @article{lin2024awq,
// title={AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration},
// author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Chen, Wei-Ming and Wang, Wei-Chen and Xiao, Guangxuan and Dang, Xingyu and Gan, Chuang and Han, Song},
@@ -10,7 +10,6 @@
// }
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "dispatch_utils.h"
#include "utils.cuh"
#include "reduction_utils.cuh"
@@ -41,23 +40,24 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc
* First pass (loop) computes the mean.
* Second computes the variance via Var[x] = E[(x - E[x])²].
* Third pass computes and writes normed_output
*
* with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
* For better speed, we always use the difference-of-squares method (faster but potentially less accurate):
Contributor comment: Let's keep the original template for better flexibility. (One way to keep it is sketched just after this doc comment.)

* First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
* Second pass computes and writes normed_output
*
* It turns out the accuracy doesn't drop.
*
* use_shmem controls if we cache input values into shared memory
*
* Optional: with dynamic scaling, the last pass doesn't write immediately but finds the
* amax per row. A final pass scales to int8 accordingly, and writes output to
* normed_output_quant.
*/
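A minimal sketch of one way to keep that flexibility (an assumed reading of the review comment, not code from this PR): leave USE_DIFF_OF_SQUARES in the template with a default of true, so the fused single-pass path stays the default while the original two-pass variant remains selectable.

template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = true>
__global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output,
                                 const float eps, int tokens, int hidden_dim,
                                 const scale_type* scale_orig_quant_per_tensor,
                                 scale_type* scale_orig_quant_per_token,
                                 int8_t* normed_output_quant, bool use_shmem);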
template <typename T, typename scale_type, bool USE_DIFF_OF_SQUARES = false>
template <typename T, typename scale_type>
__global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps,
int tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token,
int8_t* normed_output_quant, bool use_shmem)
{
constexpr auto num_elems_T = num_elems<T>::value;
constexpr auto num_elems_T = num_elems<T>::value; // number of scalar elements packed in T
using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
using float_packed_t = typename packed_as<float, num_elems_T>::type;
using T_scalar = typename packed_as<T, 1>::type;
@@ -74,7 +74,6 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
float variance = 0.0f;
float local_sum = 0.0f;
float local_var_sum = 0.0f;

const int n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
@@ -83,56 +82,25 @@
{
shmem[i] = val;
}

const float_packed_t val_f = cuda_cast<float_packed_t>(val);
local_sum += cuda_sum<float>(val_f);
if (USE_DIFF_OF_SQUARES)
{
local_var_sum += cuda_sum<float>(val_f * val_f);
}
}

if (USE_DIFF_OF_SQUARES)
{
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];
}
else
{
mean = blockReduceSum(local_sum);
local_var_sum += cuda_sum<float>(val_f * val_f);
}
// Compute mean and variance in a single pass
float packed[2] = {local_sum, local_var_sum};
blockReduceSumV2<float, 2>(packed);
mean = packed[0];
variance = packed[1];

if (threadIdx.x == 0)
{
mean = mean / hidden_dim;
s_mean = mean;
if (USE_DIFF_OF_SQUARES)
{
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]²
s_variance = rsqrtf(variance + eps);
}
__syncthreads();

if (!USE_DIFF_OF_SQUARES)
{
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i];
float_packed_t diff = cuda_cast<float_packed_t>(val) - s_mean;
local_var_sum += cuda_sum<float>(diff * diff);
}
variance = blockReduceSum(local_var_sum);

if (threadIdx.x == 0)
{
s_variance = rsqrtf(variance / hidden_dim + eps);
}
__syncthreads();
}

// Compute LN and Quantize
const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr;
const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
@@ -186,51 +154,21 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta,
}
}
}
}

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename out_type, bool use_quant>
__global__ void
rms_norm_kernel(out_type *__restrict__ out, // [..., hidden_size]
const scalar_t *__restrict__ input, // [..., hidden_size]
const scalar_t *__restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
if constexpr (use_quant) {
out[blockIdx.x * hidden_size + idx] = float_to_int8_rn(
((float)(x * s_variance)) * (float)(weight[idx]));
} else {
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
}
} // namespace vllm

void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
torch::Tensor &input, // [..., hidden_size]
torch::Tensor &weight, // [hidden_size]
torch::Tensor &bias, // [hidden_size]
torch::Tensor &scaling, // [tokens] or [1]
float epsilon,
bool use_per_token_quant) {
bool use_per_token_quant = true) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
dim3 block(std::min(hidden_size / 2, 1024)); // Prevent thread idling when the embedding size exceeds 1024 and is not an integer multiple of it.
Contributor comment: Can we fix this in a more elegant way to improve utilization?

block.x = 32 * ((block.x + 31) / 32);
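One shape a cleaner fix could take (a sketch under assumptions, not the PR's code): derive block.x from the number of packed half2 elements per row, cap it at 1024, and round up to a whole warp in one place.

const int n_packed = hidden_size / 2;          // elements per row after half2 packing
int block_x = std::min(n_packed, 1024);
block_x = 32 * ((block_x + 31) / 32);          // round up to a multiple of the warp size
dim3 block(block_x);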

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -240,7 +178,8 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
// per-token
vllm::generalLayerNorm<T, at::Half><<<grid, block, 0, stream>>>(
reinterpret_cast<T*>(input.data_ptr<scalar_t>()),
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()), nullptr,
reinterpret_cast<T*>(weight.data_ptr<scalar_t>()),
reinterpret_cast<T*>(bias.data_ptr<scalar_t>()),
nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr<at::Half>(),
out.data_ptr<int8_t>(), false
);
@@ -258,4 +197,4 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size]
);
}
});
}
}
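For reviewers sanity-checking the fused path, here is a scalar reference of what generalLayerNorm with per-token int8 quantization is intended to compute (a sketch; it assumes symmetric per-token scales stored as amax/127, and is not code from this PR):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void layernorm_quant_ref(const std::vector<float>& x, const std::vector<float>& gamma,
                         const std::vector<float>& beta, float eps,
                         std::vector<int8_t>& q, float& scale_out) {
    const int n = static_cast<int>(x.size());
    float sum = 0.f, sq = 0.f;
    for (float v : x) { sum += v; sq += v * v; }
    const float mean = sum / n;
    const float var  = sq / n - mean * mean;          // Var[x] = E[x^2] - E[x]^2 (single pass)
    const float rstd = 1.f / std::sqrt(var + eps);
    std::vector<float> y(n);
    float amax = 0.f;
    for (int i = 0; i < n; ++i) {
        y[i] = (x[i] - mean) * rstd * gamma[i] + beta[i];
        amax = std::max(amax, std::fabs(y[i]));
    }
    scale_out = std::max(amax, 1e-8f) / 127.f;        // per-token scale written to `scaling`
    q.resize(n);
    for (int i = 0; i < n; ++i)
        q[i] = static_cast<int8_t>(std::lrintf(y[i] / scale_out));
}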
8 changes: 4 additions & 4 deletions awq/kernels/csrc/pybind.cpp
@@ -10,7 +10,7 @@
#include "rope_new/fused_rope_with_pos.h"
#include "w8a8/w8a8_gemm_cuda.h"
#include "w8a8/quantization.h"
// #include "fused_layernorm/layernorm.h"
#include "fused_layernorm/layernorm.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
@@ -29,7 +29,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("w8a8_gemm_forward_cuda", &w8a8_gemm_forward_cuda, "our w8a8 gemm kernel");
m.def("w8a8_gemm_fuse_bias_forward_cuda", &w8a8_gemm_fuse_bias_forward_cuda, "our w8a8 gemm fused bias kernel");
m.def("invoke_quant", &invoke_quant, "fp16->int8 quantization");
// m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
// py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false,
// "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");
m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"),
py::arg("weight"), py::arg("bias"),py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = true,
"Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel).");
}
58 changes: 32 additions & 26 deletions awq/kernels/csrc/w8a8/w8a8_gemm_cuda.cu
@@ -32,7 +32,7 @@
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int kSmemByteSize = \
(CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B)) * STAGES * \
sizeof(int8_t); \
sizeof(int8_t) + CTA_N * sizeof(float); \
if (kSmemByteSize >= 99 * 1024) \
{ \
printf("This kernel requires %d Bytes of shared memory, which exceeds " \
@@ -41,12 +41,12 @@
return ; \
} \
int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
int num_blocks_n = num_out_channels / CTA_N / 1; \
int num_blocks_n = num_out_channels / CTA_N / 1; \
const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
const int tile_shift = 1 << log_tile; \
dim3 num_blocks(num_blocks_n *tile_shift, \
(num_blocks_m + tile_shift - 1) / tile_shift); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = \
dense_kernel0_fuse_bias<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, \
@@ -300,7 +300,7 @@ template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K,
int STAGES>
__global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restrict__ B,
half2 *__restrict__ wscales, half *__restrict__ ascales,
half *__restrict__ C, half2 *__restrict__ Bias,
half *__restrict__ C, half *__restrict__ Bias,
int M, int N, int K)
{
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
@@ -326,9 +326,10 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
extern __shared__ int8_t mem_shared[];
extern __shared__ int8_t mem_shared[]; // extern: dynamic shared memory, sized at kernel launch; shared across the block
int8_t *A_shared = mem_shared;
int8_t *B_shared = mem_shared + kSmemSizeA;
float *Bias_shared = reinterpret_cast<float *>(mem_shared + kSmemSizeA + kSmemSizeB);
int8_t A_shared_warp_[2][WARP_M * WARP_K /
WARP_SIZE];
int8_t B_shared_warp_[2][WARP_N * WARP_K /
@@ -344,22 +345,22 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN; // Always zero here since CTA_K == WARP_K (no K-slicing) in the configs below
int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
int warp_offset_k = slice_id * WARP_K;

for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0;

int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
int A_hoisted_col = (threadIdx.x % A_threads_per_row);
int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;

int B_hoisted_row = threadIdx.y * B_warp_step_n + (threadIdx.x / B_threads_per_row);
int B_hoisted_col = (threadIdx.x % B_threads_per_row);
int B_hoisted_col_swizzled = B_hoisted_col ^ (B_hoisted_row / 2) & 3;
@@ -374,6 +375,18 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
int8_t *B_hoisted = B + cta_offset_n * K + B_hoisted_row * K +
B_hoisted_col * PACK_SIZE;
bool A_g2s_preds[A_total_global_iters];
//debug
// printf("A: %d ",A_total_global_iters);
// printf("B: %d ",B_total_global_iters);
// printf("prologue_stages: %d ",prologue_stages);
// __shared__ float2 Bias_shared[CTA_N];
#pragma unroll
for (int i = 0; i < CTA_N ; i++)
{
Bias_shared[i] = __half2float(Bias[cta_offset_n+i]);
Contributor comment: Is this the bottleneck?
(A cooperative-load alternative is sketched right after this loop.)

}
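A cooperative alternative to the bias staging loop above (a sketch of one way to address the review question, not the PR's code): stride the CTA_N loads across all threads of the block instead of having every thread read the full row, and let the existing __syncthreads() calls later in the kernel make the values visible before the epilogue reads them.

for (int i = threadIdx.y * blockDim.x + threadIdx.x; i < CTA_N;
     i += blockDim.x * blockDim.y)
{
    Bias_shared[i] = __half2float(Bias[cta_offset_n + i]);
}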


#pragma unroll
for (int i = 0; i < A_total_global_iters; i++)
{
@@ -396,6 +409,8 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
__pipeline_wait_prior(STAGES - 2);
__syncthreads();

// global_to_share_bias<CTA_N,CTA_SIZE>(Bias,Bias_shared,cta_offset_n);

share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0,
WARP_M / INTRIN_M);
@@ -547,14 +562,13 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
if (row_wb < M){
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
if (row_wb < M){
float2 wscale = __half22float2(*(wscales + col_wb / 2));
float ascale = __half2float(ascales[row_wb]);
float2 bias = __half22float2(Bias[col_wb]);
float2 psums = make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
psums.x = psums.x * wscale.x * ascale + bias.x;
psums.y = psums.y * wscale.y * ascale + bias.y;
psums.x = psums.x * wscale.x * ascale + Bias_shared[col_wb % CTA_N];
psums.y = psums.y * wscale.y * ascale + Bias_shared[col_wb % CTA_N + 1];
*reinterpret_cast<half2 *>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
}
};
Expand All @@ -576,7 +590,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
auto kernel = reinterpret_cast<int8_t *>(_kernel.data_ptr<int8_t>());
auto wscales = reinterpret_cast<half2 *>(_wscales.data_ptr());
auto ascales = reinterpret_cast<half *>(_ascales.data_ptr());
auto bias = reinterpret_cast<half2 *>(_bias.data_ptr());
auto bias = reinterpret_cast<half *>(_bias.data_ptr());
// auto options =
// torch::TensorOptions().dtype(torch::kFloat16).device(_in_feats.device());
// at::Tensor _out_feats =
@@ -592,10 +606,10 @@
constexpr int CTA_M = 128;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 128;
constexpr int WARP_M = 64;
Contributor comment: [Issue] Why couldn't we have 128 here?

constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
constexpr int STAGES = 6;
KERNEL_LAUNCH_CODE_FUSE_BIAS
}
else
@@ -604,7 +618,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
constexpr int CTA_N = 64;
constexpr int CTA_K = 64;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_N = 16;
Contributor comment: [Issue] Why couldn't we have 16 here?

constexpr int WARP_K = 64;
constexpr int STAGES = 6;
KERNEL_LAUNCH_CODE_FUSE_BIAS
@@ -613,14 +627,6 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats,
}
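Since STAGES is bumped to 6 and the launch macro now also reserves CTA_N floats for the staged bias, a quick constexpr check of the shared-memory budget may help reviewers (a sketch; it assumes SMEM_PAD_A == SMEM_PAD_B == 0, which may not match the repo's actual padding constants). The WARP_M / WARP_N questions above concern register tiling rather than shared memory, so this arithmetic does not settle them.

constexpr int CTA_M = 128, CTA_N = 128, CTA_K = 64, STAGES = 6;
constexpr int kSmemByteSize =
    (CTA_M * CTA_K + CTA_N * CTA_K) * STAGES * sizeof(int8_t)  // multi-buffered A and B tiles
    + CTA_N * sizeof(float);                                   // newly added bias staging buffer
static_assert(kSmemByteSize == 98816, "128x128x64 with 6 stages needs ~96.5 KiB");
static_assert(kSmemByteSize < 99 * 1024, "stays under the 99 KiB guard in the launch macro");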










template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K,
int STAGES>
__global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B,
2 changes: 1 addition & 1 deletion awq/kernels/setup.py
@@ -40,7 +40,7 @@
"csrc/rope_new/fused_rope_with_pos.cu",
"csrc/w8a8/w8a8_gemm_cuda.cu",
"csrc/w8a8/quantization.cu",
# "csrc/fused_layernorm/layernorm_kernels.cu"
"csrc/fused_layernorm/layernorm_kernels.cu"
],
extra_compile_args=extra_compile_args,
),