Skip to content

Commit

Permalink
add high performance moe kernel; fix a16w8 compile bug for sm<80 (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
laiwenzh authored Feb 11, 2025
1 parent 22807e4 commit 069c74e
Show file tree
Hide file tree
Showing 9 changed files with 3,465 additions and 98 deletions.
39 changes: 21 additions & 18 deletions csrc/core/kernel/cuda/gemm_lowp/gemm_a16w8_perc_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1320,33 +1320,32 @@ struct ComputeTile_A16W8_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
// dequant B
#pragma unroll
for (int i = 0; i < WARP_NITER / 2; ++i) {
cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i],
BF_frag[reg_buf_idx][2 * i]);
typename HalfType<FType>::T2 B_zero_x2 =
num2num2(static_cast<typename HalfType<FType>::T1>(0.f));
typename HalfType<FType>::T2 B_zero_y2 =
num2num2(static_cast<typename HalfType<FType>::T1>(0.f));
if (has_zp) {
BF_frag[reg_buf_idx][2 * i][0] =
__hsub2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_zero[i].x));
BF_frag[reg_buf_idx][2 * i][1] =
__hsub2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_zero[i].x));
B_zero_x2 = num2num2(B_zero[i].x);
B_zero_y2 = num2num2(B_zero[i].y);
}

BF_frag[reg_buf_idx][2 * i][0] =
__hmul2(BF_frag[reg_buf_idx][2 * i][0], num2num2(B_scale[i].x));
BF_frag[reg_buf_idx][2 * i][1] =
__hmul2(BF_frag[reg_buf_idx][2 * i][1], num2num2(B_scale[i].x));
cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i],
BF_frag[reg_buf_idx][2 * i]);

BF_frag[reg_buf_idx][2 * i][0] = dequantize_func(
BF_frag[reg_buf_idx][2 * i][0], num2num2(B_scale[i].x), B_zero_x2);
BF_frag[reg_buf_idx][2 * i][1] = dequantize_func(
BF_frag[reg_buf_idx][2 * i][1], num2num2(B_scale[i].x), B_zero_x2);

cvt_8bx4_to_16bx4_bias128(BQ_frag[reg_buf_idx][2 * i + 1],
BF_frag[reg_buf_idx][2 * i + 1]);
if (has_zp) {
BF_frag[reg_buf_idx][2 * i + 1][0] =
__hsub2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_zero[i].y));
BF_frag[reg_buf_idx][2 * i + 1][1] =
__hsub2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_zero[i].y));
}

BF_frag[reg_buf_idx][2 * i + 1][0] =
__hmul2(BF_frag[reg_buf_idx][2 * i + 1][0], num2num2(B_scale[i].y));
dequantize_func(BF_frag[reg_buf_idx][2 * i + 1][0],
num2num2(B_scale[i].y), B_zero_y2);
BF_frag[reg_buf_idx][2 * i + 1][1] =
__hmul2(BF_frag[reg_buf_idx][2 * i + 1][1], num2num2(B_scale[i].y));
dequantize_func(BF_frag[reg_buf_idx][2 * i + 1][1],
num2num2(B_scale[i].y), B_zero_y2);
}
}

Expand Down Expand Up @@ -1677,6 +1676,10 @@ void ampere_hgemm_A16W8_perc_f16_f16_MtilexNtilex32_mma16816_multistage_AN_BTN32
const uint32_t K, void* workspace, const int sm_version,
const SplitKParams fused_gemm_params, const float alpha,
cudaStream_t stream) {
if (sm_version < 0x0800) {
throw std::runtime_error(
"this kernel is not supported on devices below sm80");
}
int Mtile = fused_gemm_params.Mtile;
int grid_x = (M + Mtile - 1) / Mtile;
int Ntile = fused_gemm_params.Ntile;
Expand Down
8 changes: 8 additions & 0 deletions csrc/core/kernel/cuda/gemm_lowp/gemm_lowp_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,15 @@ __device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>(
}

static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __bfloat162bfloat162(x);
#else
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 3
__builtin_unreachable();
#else
return nv_bfloat162{};
#endif // __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 3
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}

static __device__ half2 inline num2num2(const half x) {
Expand Down
Loading

0 comments on commit 069c74e

Please sign in to comment.