From be97857ecfc17e601d5b468fe41bf236832a9c4c Mon Sep 17 00:00:00 2001
From: Louym21
Date: Mon, 13 Jan 2025 03:20:31 -0500
Subject: [PATCH] [Minor] Fused some kernels

---
 awq/kernels/csrc/fused_layernorm/layernorm.h  |   1 +
 .../csrc/fused_layernorm/layernorm_kernels.cu | 107 ++++--------------
 awq/kernels/csrc/pybind.cpp                   |   8 +-
 awq/kernels/csrc/w8a8/w8a8_gemm_cuda.cu       |  58 +++++-----
 awq/kernels/setup.py                          |   2 +-
 awq/quantize/w8a8_linear.py                   |  38 +++----
 tinychat/modules/fused_siglipdecoder.py       |  76 ++++++++++---
 tinychat/nvila_demo.py                        |  18 ---
 8 files changed, 133 insertions(+), 175 deletions(-)

diff --git a/awq/kernels/csrc/fused_layernorm/layernorm.h b/awq/kernels/csrc/fused_layernorm/layernorm.h
index 3639a15..9e5d740 100644
--- a/awq/kernels/csrc/fused_layernorm/layernorm.h
+++ b/awq/kernels/csrc/fused_layernorm/layernorm.h
@@ -14,6 +14,7 @@
 void rms_norm_general(torch::Tensor &out,     // [..., hidden_size]
                       torch::Tensor &input,   // [..., hidden_size]
                       torch::Tensor &weight,  // [hidden_size]
+                      torch::Tensor &bias,    // [hidden_size]
                       torch::Tensor &scaling, // [tokens] or [1]
                       float epsilon, bool use_per_token_quant);
diff --git a/awq/kernels/csrc/fused_layernorm/layernorm_kernels.cu b/awq/kernels/csrc/fused_layernorm/layernorm_kernels.cu
index f2ec103..fa33761 100644
--- a/awq/kernels/csrc/fused_layernorm/layernorm_kernels.cu
+++ b/awq/kernels/csrc/fused_layernorm/layernorm_kernels.cu
@@ -1,5 +1,5 @@
-// Inspired by TRT-LLM.
-// Modified by Shang Yang and Haotian Tang.
+// Inspired by QServe https://github.com/mit-han-lab/qserve/tree/main.
+// Modified by Yuming Lou.
 // @article{lin2024awq,
 //   title={AWQ: Activation-aware Weight Quantization for On-Device LLM Compression and Acceleration},
 //   author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Chen, Wei-Ming and Wang, Wei-Chen and Xiao, Guangxuan and Dang, Xingyu and Gan, Chuang and Han, Song},
@@ -10,7 +10,6 @@
 // }
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
-
 #include "dispatch_utils.h"
 #include "utils.cuh"
 #include "reduction_utils.cuh"
@@ -41,10 +40,11 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc
  * First pass (loop) computes the mean.
  * Second computes the variance via Var[x] = E[(x - E[x])²].
  * Third pass computes and writes normed_output
- *
- * with USE_DIFF_OF_SQUARES set to true (may be faster but less accurate):
+ * For better speed, we always set USE_DIFF_OF_SQUARES to true (potentially less accurate):
  * First pass (loop) computes the mean and variance via Var[x] = E[x²] - E[x]²
  * Second pass computes and writes normed_output
+ *
+ * In practice, the accuracy does not drop.
  *
  * use_shmem controls if we cache input values into shared memory
  *
@@ -52,12 +52,12 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc
  * amax per row. A final pass scales to int8 accordingly, and writes output to
  * normed_output_quant.
*/ -template +template __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, int tokens, int hidden_dim, const scale_type* scale_orig_quant_per_tensor, scale_type* scale_orig_quant_per_token, int8_t* normed_output_quant, bool use_shmem) { - constexpr auto num_elems_T = num_elems::value; + constexpr auto num_elems_T = num_elems::value;//1 using int8_packed_t = typename packed_as::type; using float_packed_t = typename packed_as::type; using T_scalar = typename packed_as::type; @@ -74,7 +74,6 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, float variance = 0.0f; float local_sum = 0.0f; float local_var_sum = 0.0f; - const int n_elems = hidden_dim / num_elems_T; for (int i = tidx; i < n_elems; i += blockDim.x) { @@ -83,56 +82,25 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, { shmem[i] = val; } - const float_packed_t val_f = cuda_cast(val); local_sum += cuda_sum(val_f); - if (USE_DIFF_OF_SQUARES) - { - local_var_sum += cuda_sum(val_f * val_f); - } - } - - if (USE_DIFF_OF_SQUARES) - { - float packed[2] = {local_sum, local_var_sum}; - blockReduceSumV2(packed); - mean = packed[0]; - variance = packed[1]; - } - else - { - mean = blockReduceSum(local_sum); + local_var_sum += cuda_sum(val_f * val_f); } + //Compute mean + float packed[2] = {local_sum, local_var_sum}; + blockReduceSumV2(packed); + mean = packed[0]; + variance = packed[1]; if (threadIdx.x == 0) { mean = mean / hidden_dim; s_mean = mean; - if (USE_DIFF_OF_SQUARES) - { - variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]² - s_variance = rsqrtf(variance + eps); - } + variance = (variance / hidden_dim) - (mean * mean); // Var[x] = E[x²] - E[x]² + s_variance = rsqrtf(variance + eps); } __syncthreads(); - - if (!USE_DIFF_OF_SQUARES) - { - for (int i = tidx; i < n_elems; i += blockDim.x) - { - const T val = use_shmem ? shmem[i] : input[bidx * n_elems + i]; - float_packed_t diff = cuda_cast(val) - s_mean; - local_var_sum += cuda_sum(diff * diff); - } - variance = blockReduceSum(local_var_sum); - - if (threadIdx.x == 0) - { - s_variance = rsqrtf(variance / hidden_dim + eps); - } - __syncthreads(); - } - + // Compute LN and Quantize const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; const float_packed_t scale_orig_quant @@ -186,51 +154,21 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, } } } -} -// TODO(woosuk): Further optimize this kernel. 
-template -__global__ void -rms_norm_kernel(out_type *__restrict__ out, // [..., hidden_size] - const scalar_t *__restrict__ input, // [..., hidden_size] - const scalar_t *__restrict__ weight, // [hidden_size] - const float epsilon, const int num_tokens, - const int hidden_size) { - __shared__ float s_variance; - float variance = 0.0f; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * hidden_size + idx]; - variance += x * x; - } - variance = blockReduceSum(variance); - if (threadIdx.x == 0) { - s_variance = rsqrtf(variance / hidden_size + epsilon); - } - __syncthreads(); - - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * hidden_size + idx]; - if constexpr (use_quant) { - out[blockIdx.x * hidden_size + idx] = float_to_int8_rn( - ((float)(x * s_variance)) * (float)(weight[idx])); - } else { - out[blockIdx.x * hidden_size + idx] = - ((scalar_t)(x * s_variance)) * weight[idx]; - } - } -} +} // namespace vllm void rms_norm_general(torch::Tensor &out, // [..., hidden_size] torch::Tensor &input, // [..., hidden_size] torch::Tensor &weight, // [hidden_size] + torch::Tensor &bias, // [hidden_size] torch::Tensor &scaling, // [tokens] or [1] float epsilon, - bool use_per_token_quant) { + bool use_per_token_quant = true) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); + dim3 block(std::min(hidden_size/2, 1024));//Prevent thread idling when the embedding size is greater than 1024 and not an integer multiple of it. block.x = 32 * ((block.x + 31) / 32); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -240,7 +178,8 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size] // per-token vllm::generalLayerNorm<<>>( reinterpret_cast(input.data_ptr()), - reinterpret_cast(weight.data_ptr()), nullptr, + reinterpret_cast(weight.data_ptr()), + reinterpret_cast(bias.data_ptr()), nullptr, epsilon, num_tokens, hidden_size, nullptr, scaling.data_ptr(), out.data_ptr(), false ); @@ -258,4 +197,4 @@ void rms_norm_general(torch::Tensor &out, // [..., hidden_size] ); } }); -} +} \ No newline at end of file diff --git a/awq/kernels/csrc/pybind.cpp b/awq/kernels/csrc/pybind.cpp index ea0d8e6..20548b1 100644 --- a/awq/kernels/csrc/pybind.cpp +++ b/awq/kernels/csrc/pybind.cpp @@ -10,7 +10,7 @@ #include "rope_new/fused_rope_with_pos.h" #include "w8a8/w8a8_gemm_cuda.h" #include "w8a8/quantization.h" -// #include "fused_layernorm/layernorm.h" +#include "fused_layernorm/layernorm.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) @@ -29,7 +29,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("w8a8_gemm_forward_cuda", &w8a8_gemm_forward_cuda, "our w8a8 gemm kernel"); m.def("w8a8_gemm_fuse_bias_forward_cuda", &w8a8_gemm_fuse_bias_forward_cuda, "our w8a8 gemm fused bias kernel"); m.def("invoke_quant", &invoke_quant, "fp16->int8 quantization"); - // m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"), - // py::arg("weight"), py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = false, - // "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel)."); + m.def("rms_norm_general", &rms_norm_general, py::arg("out"), py::arg("input"), + py::arg("weight"), py::arg("bias"),py::arg("scaling"), py::arg("epsilon"), py::arg("use_per_token_quant") = true, + "Apply Root Mean Square (RMS) Normalization to the input tensor (TRTLLM kernel)."); } 
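The rebuilt extension now exports `rms_norm_general` with the extra `bias` argument; the real call site is the `RMSNormGeneral` module added to `tinychat/modules/fused_siglipdecoder.py` further down in this patch. As a quick, hedged sanity check it can also be driven directly. Buffer dtypes and the `int8 * per-token scale` dequantization convention below are assumptions, not taken from the patch:

```python
# Sketch only: exercises the rms_norm_general binding declared above.
# Assumptions: fp16 activations, fp32 per-token scales, dequant = int8 * scale.
import torch
import awq_inference_engine

tokens, hidden = 64, 1152  # 1152 matches the SigLIP hidden size used later in this patch
x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
weight = torch.randn(hidden, dtype=torch.float16, device="cuda")
bias = torch.randn(hidden, dtype=torch.float16, device="cuda")

out_int8 = torch.empty(tokens, hidden, dtype=torch.int8, device="cuda")
scales = torch.empty(tokens, dtype=torch.float32, device="cuda")  # per-token scales

awq_inference_engine.rms_norm_general(out_int8, x, weight, bias, scales, 1e-6, True)

# Compare the dequantized output against an eager LayerNorm reference.
ref = torch.nn.functional.layer_norm(
    x.float(), (hidden,), weight.float(), bias.float(), eps=1e-6
)
print((out_int8.float() * scales.view(-1, 1) - ref).abs().max())
```
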
diff --git a/awq/kernels/csrc/w8a8/w8a8_gemm_cuda.cu b/awq/kernels/csrc/w8a8/w8a8_gemm_cuda.cu index 31bf285..b62b762 100644 --- a/awq/kernels/csrc/w8a8/w8a8_gemm_cuda.cu +++ b/awq/kernels/csrc/w8a8/w8a8_gemm_cuda.cu @@ -32,7 +32,7 @@ constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \ constexpr int kSmemByteSize = \ (CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B)) * STAGES * \ - sizeof(int8_t); \ + sizeof(int8_t) + CTA_N * sizeof(float); \ if (kSmemByteSize >= 99 * 1024) \ { \ printf("This kernel requires %d Bytes of shared memory, which exceeds " \ @@ -41,12 +41,12 @@ return ; \ } \ int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \ - int num_blocks_n = num_out_channels / CTA_N / 1; \ + int num_blocks_n = num_out_channels / CTA_N / 1; \ const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \ const int tile_shift = 1 << log_tile; \ dim3 num_blocks(num_blocks_n *tile_shift, \ (num_blocks_m + tile_shift - 1) / tile_shift); \ - dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ + dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \ auto kernel_func = \ dense_kernel0_fuse_bias; \ cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, \ @@ -300,7 +300,7 @@ template __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restrict__ B, half2 *__restrict__ wscales, half *__restrict__ ascales, - half *__restrict__ C, half2 *__restrict__ Bias, + half *__restrict__ C, half *__restrict__ Bias, int M, int N, int K) { constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N; @@ -326,9 +326,10 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB; constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES; constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES; - extern __shared__ int8_t mem_shared[]; + extern __shared__ int8_t mem_shared[]; //extern: dynamic share, decided in kernel launch; shared: within block int8_t *A_shared = mem_shared; int8_t *B_shared = mem_shared + kSmemSizeA; + float *Bias_shared= reinterpret_cast(mem_shared + kSmemSizeA + kSmemSizeB); int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE]; int8_t B_shared_warp_[2][WARP_N * WARP_K / @@ -344,14 +345,14 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri int cta_offset_m = blockIdx_m * CTA_M; int cta_offset_n = blockIdx_n * CTA_N; int warp_mn = threadIdx.y % NUM_WARPS_MN; - int slice_id = threadIdx.y / NUM_WARPS_MN; + int slice_id = threadIdx.y / NUM_WARPS_MN; // Always zero if threadIdx.z==0! 
int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M; int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N; int warp_offset_k = slice_id * WARP_K; - + for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++) C_warp[i] = 0; - + int gemm_iters = (K + CTA_K - 1) / CTA_K; int k_0_0_ld = 0; int k_0_0 = 0; @@ -359,7 +360,7 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row); int A_hoisted_col = (threadIdx.x % A_threads_per_row); int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3; - + int B_hoisted_row = threadIdx.y * B_warp_step_n + (threadIdx.x / B_threads_per_row); int B_hoisted_col = (threadIdx.x % B_threads_per_row); int B_hoisted_col_swizzled = B_hoisted_col ^ (B_hoisted_row / 2) & 3; @@ -374,6 +375,18 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri int8_t *B_hoisted = B + cta_offset_n * K + B_hoisted_row * K + B_hoisted_col * PACK_SIZE; bool A_g2s_preds[A_total_global_iters]; + //debug + // printf("A: %d ",A_total_global_iters); + // printf("B: %d ",B_total_global_iters); + // printf("prologue_stages: %d ",prologue_stages); + // __shared__ float2 Bias_shared[CTA_N]; + #pragma unroll + for (int i = 0; i < CTA_N ; i++) + { + Bias_shared[i] = __half2float(Bias[cta_offset_n+i]); + } + + #pragma unroll for (int i = 0; i < A_total_global_iters; i++) { @@ -396,6 +409,8 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri __pipeline_wait_prior(STAGES - 2); __syncthreads(); +// global_to_share_bias(Bias,Bias_shared,cta_offset_n); + share_to_reg_one_stage_A( A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M); @@ -547,14 +562,13 @@ __global__ void dense_kernel0_fuse_bias(int8_t *__restrict__ A, int8_t *__restri for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2) { int row_wb = row_wb_1 + (local_id % 4) / 2 * 8; - if (row_wb < M){ - int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2); + int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2); + if (row_wb < M){ float2 wscale = __half22float2(*(wscales + col_wb / 2)); float ascale = __half2float(ascales[row_wb]); - float2 bias = __half22float2(Bias[col_wb]); float2 psums = make_float2(__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1])); - psums.x = psums.x * wscale.x * ascale + bias.x; - psums.y = psums.y * wscale.y * ascale + bias.y; + psums.x = psums.x * wscale.x * ascale + Bias_shared[col_wb % CTA_N]; + psums.y = psums.y * wscale.y * ascale + Bias_shared[col_wb % CTA_N + 1]; *reinterpret_cast(C + row_wb * N + col_wb) = __float22half2_rn(psums); } }; @@ -576,7 +590,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats, auto kernel = reinterpret_cast(_kernel.data_ptr()); auto wscales = reinterpret_cast(_wscales.data_ptr()); auto ascales = reinterpret_cast(_ascales.data_ptr()); - auto bias = reinterpret_cast(_bias.data_ptr()); + auto bias = reinterpret_cast(_bias.data_ptr()); // auto options = // torch::TensorOptions().dtype(torch::kFloat16).device(_in_feats.device()); // at::Tensor _out_feats = @@ -592,10 +606,10 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats, constexpr int CTA_M = 128; constexpr int CTA_N = 128; constexpr int CTA_K = 64; - constexpr int WARP_M = 128; + constexpr int WARP_M = 64; constexpr int WARP_N = 32; constexpr int WARP_K = 64; - constexpr int STAGES = 3; + constexpr int STAGES = 6; 
KERNEL_LAUNCH_CODE_FUSE_BIAS } else @@ -604,7 +618,7 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats, constexpr int CTA_N = 64; constexpr int CTA_K = 64; constexpr int WARP_M = 32; - constexpr int WARP_N = 32; + constexpr int WARP_N = 16; constexpr int WARP_K = 64; constexpr int STAGES = 6; KERNEL_LAUNCH_CODE_FUSE_BIAS @@ -613,14 +627,6 @@ void w8a8_gemm_fuse_bias_forward_cuda(torch::Tensor _in_feats, } - - - - - - - - template __global__ void dense_kernel0(int8_t *__restrict__ A, int8_t *__restrict__ B, diff --git a/awq/kernels/setup.py b/awq/kernels/setup.py index d654f14..ae9efd4 100644 --- a/awq/kernels/setup.py +++ b/awq/kernels/setup.py @@ -40,7 +40,7 @@ "csrc/rope_new/fused_rope_with_pos.cu", "csrc/w8a8/w8a8_gemm_cuda.cu", "csrc/w8a8/quantization.cu", - # "csrc/fused_layernorm/layernorm_kernels.cu" + "csrc/fused_layernorm/layernorm_kernels.cu" ], extra_compile_args=extra_compile_args, ), diff --git a/awq/quantize/w8a8_linear.py b/awq/quantize/w8a8_linear.py index 0336454..6d343e7 100644 --- a/awq/quantize/w8a8_linear.py +++ b/awq/quantize/w8a8_linear.py @@ -97,30 +97,26 @@ def apply_weights( if len(x.shape) > 2: assert 0, "Not implemented" x = x.view(-1, x_shape[-1]) - # If use awq_inference_engine.w8a8_gemm_forward_cuda - awq_inference_engine.w8a8_gemm_forward_cuda( - x, self.weight, self.dequant_scale, input_scale, output_buffer - ) - - # If use awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda - # awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda( - # x, self.weight, self.dequant_scale.half(), input_scale.half(), output_buffer, bias - # ) + if bias is not None: + # If use awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda + awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda( + x, self.weight, self.dequant_scale, input_scale, output_buffer, bias + ) + else: + # If use awq_inference_engine.w8a8_gemm_forward_cuda + awq_inference_engine.w8a8_gemm_forward_cuda( + x, self.weight, self.dequant_scale, input_scale, output_buffer + ) + + if len(x.shape) > 2: assert 0, "Not implemented 2" output_buffer = output_buffer.view(*x_shape[:-1], -1) def forward(self, input_, input_scale, output_buffer): # Matrix multiply. 
- - # If use awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda - # self.apply_weights(input_, input_scale, output_buffer, self.bias) - # If use awq_inference_engine.w8a8_gemm_forward_cuda - self.apply_weights(input_, input_scale, output_buffer) - output_bias = self.bias - if output_bias is not None: - output_buffer += output_bias + self.apply_weights(input_, input_scale, output_buffer, self.bias) @classmethod def from_linear( @@ -230,15 +226,7 @@ def forward(self, input): scales = input.abs().max(dim=-1, keepdim=True)[0] scales.clamp_(min=1e-5).div_(self.maxv) input.div_(scales).round_().mul_(scales) - # print(scales.abs().max(dim=-1, keepdim=True)[0].reshape(-1)) - # print(torch.sum(input==0)/input.numel()) output = torch.functional.F.linear(input, self.weight, self.bias) - # output=input.float()@(self.weight.float().T) - # # print(self.weight) - # if self.bias is not None: - # output=output+self.bias.float() - # # print(output[0,0]) - # # print(torch.sum(torch.isnan(output))/output.numel()) return output @classmethod diff --git a/tinychat/modules/fused_siglipdecoder.py b/tinychat/modules/fused_siglipdecoder.py index 7b4a55a..d002544 100644 --- a/tinychat/modules/fused_siglipdecoder.py +++ b/tinychat/modules/fused_siglipdecoder.py @@ -20,7 +20,7 @@ import awq_inference_engine - +@torch.no_grad() class QuantSiglipEncoder(nn.Module): def __init__(self, module: SiglipEncoder, bsz=64, seqlen=1024): super().__init__() @@ -80,7 +80,7 @@ def forward( attentions=None, ) - +@torch.no_grad() class QuantSiglipMLP(nn.Module): def __init__(self, siglipmlp, init_only=False): super().__init__() @@ -117,7 +117,7 @@ def forward(self, buffer: ActivationBuffer) -> torch.Tensor: buffer.in_out_fc2_act_buffer, ) - +@torch.no_grad() class QuantSiglipFlashAttention2(nn.Module): def __init__( self, @@ -175,15 +175,15 @@ def forward( ) # buffer.in_out_fc2_act_buffer=self.out_proj(buffer.in_out_fc2_act_buffer) - +@torch.no_grad() class QuantSiglipEncoderLayer(nn.Module): def __init__(self, module: SiglipEncoderLayer): super().__init__() self.embed_dim = module.embed_dim self.self_attn = QuantSiglipFlashAttention2(module.self_attn) - self.layer_norm1 = module.layer_norm1.cuda() + self.layer_norm1 = RMSNormGeneral(module.layer_norm1.weight.data, module.layer_norm1.bias.data, module.layer_norm1.eps, True).cuda() self.mlp = QuantSiglipMLP(module.mlp) - self.layer_norm2 = module.layer_norm2.cuda() + self.layer_norm2 = RMSNormGeneral(module.layer_norm2.weight.data, module.layer_norm2.bias.data, module.layer_norm2.eps, True).cuda() self.quant = self.invoke_quant_norm def invoke_quant_norm(self, buffer, normfn_output): @@ -205,28 +205,70 @@ def forward( # FP16 in FP16 out # Self Attention residual = hidden_states - normfn_output = self.layer_norm1(hidden_states) + self.layer_norm1( + hidden_states.reshape(-1, self.embed_dim), + buffer.quantized_hidden_states_buffer, + buffer.quantized_scale_buffer + ) # INT8 quantization - # normfn_output=torch.clip(normfn_output,min=-CLIP_RANGE,max=CLIP_RANGE) - self.quant(buffer, normfn_output.reshape(-1, 1152)) + # INT8 -> FP16 self.self_attn(buffer, bsz, seqlen) hidden_states = ( - residual.reshape(-1, residual.shape[-1]) + buffer.in_out_fc2_act_buffer + residual.reshape(-1, self.embed_dim) + buffer.in_out_fc2_act_buffer ) # Fully Connected residual = hidden_states - normfn_output = self.layer_norm2(hidden_states) - # FP16 -> INT8 - # normfn_output=torch.clip(normfn_output,min=-CLIP_RANGE,max=CLIP_RANGE) - normfn_output = self.quant( - buffer, - normfn_output, + + self.layer_norm2( 
+            hidden_states.reshape(-1, self.embed_dim),
+            buffer.quantized_hidden_states_buffer,
+            buffer.quantized_scale_buffer
         )
         # INT8 -> FP16
         self.mlp(buffer)
         hidden_states = (
-            residual.reshape(-1, residual.shape[-1]) + buffer.in_out_fc2_act_buffer
+            residual.reshape(-1, self.embed_dim) + buffer.in_out_fc2_act_buffer
         )
         return hidden_states
+
+@torch.no_grad()
+class RMSNormGeneral(nn.Module):
+    """Fused LayerNorm + INT8 quantization (per-token or per-tensor scales).
+
+    Despite the name, this computes x -> w * (x - E[x]) / sqrt(Var[x] + eps) + b
+    (standard LayerNorm) and writes the int8-quantized result into the output buffer.
+    """
+
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+        eps: float = 1e-6,
+        use_per_token_quant: bool = True,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        self.bias = nn.Parameter(bias, requires_grad=False)
+        self.variance_epsilon = eps
+        self.use_per_token_quant = use_per_token_quant
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        quantized_hidden_states_buffer: torch.Tensor,
+        quantized_scale_buffer: torch.Tensor,
+        quantized_sum_buffer: torch.Tensor = None,
+    ) -> None:
+        # quantized_sum_buffer is unused; it is kept only for interface consistency
+        awq_inference_engine.rms_norm_general(
+            quantized_hidden_states_buffer,
+            x,
+            self.weight.data,
+            self.bias.data,
+            quantized_scale_buffer,
+            self.variance_epsilon,
+            self.use_per_token_quant,
+        )
+
diff --git a/tinychat/nvila_demo.py b/tinychat/nvila_demo.py
index 0944616..5538219 100644
--- a/tinychat/nvila_demo.py
+++ b/tinychat/nvila_demo.py
@@ -244,21 +244,3 @@ def main(args):
     )
     args = parser.parse_args()
     main(args)
-"""
-python nvila_demo.py --model-path /home/yuming/workspace/qwen/models/nvila-video \
-    --quant_path /home/yuming/workspace/awq4nvila/nvila-video-w4-g128.pt \
-    --media ../figures/nvila_demo_video.mp4 \
-    --act_scale_path /home/yuming/workspace/awq4nvila/nvila-video-VT-smooth-scale.pt \
-    --all --chunk --vis-image
-
-python nvila_demo.py --model-path Efficient-Large-Model/nvila-internal-8b-v1 \
-    --quant_path /home/yuming/workspace/awq4nvila/nvila-internal-8b-v1-w4-g128.pt \
-    --media ../figures/vila-logo.jpg \
-    --act_scale_path /home/yuming/workspace/awq4nvila/nvila-internal-8b-v1-VT-smooth-scale.pt \
-    --all --chunk --vis-image
-
-python nvila_demo.py --model-path /home/yuming/workspace/qwen/models/nvila-lite-internal-8b-v1 \
-    --quant_path /home/yuming/workspace/awq4nvila/nvila-lite-internal-8b-v1-w4-g128.pt \
-    --act_scale_path /home/yuming/workspace/awq4nvila/nvila-lite-internal-8b-v1-VT-smooth-scale.pt --all \
-    --media ../figures/vila-logo.jpg --chunk --vis-image
-"""
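Two appendix-style checks, not part of the patch. First, the new branch in `apply_weights()` (awq/quantize/w8a8_linear.py above) implies that the fused-bias GEMM should agree with the original GEMM followed by a Python-side bias add. A hedged sketch of that comparison; the shapes, scale layouts, and dtypes are assumptions, chosen so that K and N are multiples of the kernel's tile sizes:

```python
# Sketch only: fused-bias GEMM vs. unfused GEMM + bias add.
# Assumptions: int8 activations/weights, fp16 per-channel/per-token scales and bias.
import torch
import awq_inference_engine

M, N, K = 1024, 4096, 1152
x = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
w = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda")
wscale = (torch.rand(N, dtype=torch.float16, device="cuda") + 0.5) / 127  # per output channel
ascale = (torch.rand(M, dtype=torch.float16, device="cuda") + 0.5) / 127  # per token
bias = torch.randn(N, dtype=torch.float16, device="cuda")

out_unfused = torch.empty(M, N, dtype=torch.float16, device="cuda")
out_fused = torch.empty(M, N, dtype=torch.float16, device="cuda")

awq_inference_engine.w8a8_gemm_forward_cuda(x, w, wscale, ascale, out_unfused)
out_unfused += bias  # old path: bias added in Python, as in the removed forward() code
awq_inference_engine.w8a8_gemm_fuse_bias_forward_cuda(x, w, wscale, ascale, out_fused, bias)

print((out_unfused - out_fused).abs().max())  # expect agreement up to fp16 rounding
```
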
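Second, the shared-memory accounting in `KERNEL_LAUNCH_CODE_FUSE_BIAS` now reserves an extra `CTA_N * sizeof(float)` bytes for the bias tile on top of the staged A/B tiles. A back-of-the-envelope check for the large-M configuration (STAGES bumped from 3 to 6), assuming the `SMEM_PAD_A` / `SMEM_PAD_B` macros are zero (the real values in `w8a8_gemm_cuda.cu` take precedence):

```python
# Sketch only: mirrors the kSmemByteSize formula from KERNEL_LAUNCH_CODE_FUSE_BIAS.
# Pad values are assumed to be 0 here.
def smem_bytes(cta_m, cta_n, cta_k, stages, pad_a=0, pad_b=0):
    a_tile = cta_m * (cta_k + pad_a)   # staged A tile, int8
    b_tile = cta_n * (cta_k + pad_b)   # staged B tile, int8
    bias_tile = cta_n * 4              # new in this patch: CTA_N floats for the bias
    return (a_tile + b_tile) * stages + bias_tile

# Large-M branch after this patch: CTA_M=128, CTA_N=128, CTA_K=64, STAGES=6.
print(smem_bytes(128, 128, 64, 6), "bytes vs. the launch-time cap of", 99 * 1024)
# 98816 bytes with zero padding, so the deeper pipeline plus the bias tile still fits
# under the 99 KB limit checked in the macro.
```
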