
Commit d0af941

Add support for building CUDA extension on Windows (#396)
* Enable FP6-LLM kernel build on Windows
* fix benchmark script
* update setup.py
* update
* fix indent
* add -t=0 for linux

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
1 parent f5b6ec9 commit d0af941

File tree

7 files changed: +83 −48 lines changed


benchmarks/benchmark_fp6_llm.py

+3 −3

@@ -7,12 +7,12 @@


 def benchmark(m: int, k: int, n: int):
-    fp6_weight = torch.randint(256, size=(n, k // 4 * 3), dtype=torch.uint8, device="cuda")
+    fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
     scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
-    fp6_linear = Fp6LlmLinear(fp6_weight.view(torch.int32), scales)
+    fp6_linear = Fp6LlmLinear(fp6_weight, scales)

     fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
-    fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight.view(-1), n, k, dtype=torch.half) * scales[:, None]
+    fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None]

     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
     fp6_output = fp6_linear(fp16_act)
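Note on the shape change: each FP6 (e3m2) value occupies 6 bits, so a row of k packed values takes k * 6 / 8 = k * 3 // 4 bytes of uint8 storage, which the new expression computes directly (k // 4 * 3 gives the same count whenever k is a multiple of 4). The other two edits track the updated Fp6LlmLinear and from_tc_float6_e3m2 signatures, which now appear to take the packed uint8 tensor directly rather than an int32 view.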

setup.py

+36 −14

@@ -35,6 +35,7 @@ def read_version(file_path="version.txt"):
     CUDAExtension,
     BuildExtension,
     CUDA_HOME,
+    IS_WINDOWS
 )


@@ -52,20 +53,41 @@ def get_extensions():
     use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
     extension = CUDAExtension if use_cuda else CppExtension

-    extra_link_args = []
-    extra_compile_args = {
-        "cxx": [
-            "-O3" if not debug_mode else "-O0",
-            "-fdiagnostics-color=always",
-        ],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-        ]
-    }
-    if debug_mode:
-        extra_compile_args["cxx"].append("-g")
-        extra_compile_args["nvcc"].append("-g")
-        extra_link_args.extend(["-O0", "-g"])
+    if not IS_WINDOWS:
+        extra_link_args = []
+        extra_compile_args = {
+            "cxx": [
+                "-O3" if not debug_mode else "-O0",
+                "-fdiagnostics-color=always",
+            ],
+            "nvcc": [
+                "-O3" if not debug_mode else "-O0",
+                "-t=0",
+            ]
+        }
+
+        if debug_mode:
+            extra_compile_args["cxx"].append("-g")
+            extra_compile_args["nvcc"].append("-g")
+            extra_link_args.extend(["-O0", "-g"])
+
+    else:
+        extra_link_args = []
+        extra_compile_args = {
+            "cxx": [
+                "/O2" if not debug_mode else "/Od",
+                "/permissive-"
+            ],
+            "nvcc": [
+                "-O3" if not debug_mode else "-O0",
+                "-t=0",
+            ]
+        }
+
+        if debug_mode:
+            extra_compile_args["cxx"].append("/ZI")
+            extra_compile_args["nvcc"].append("-g")
+            extra_link_args.append("/DEBUG")

     this_dir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(this_dir, "torchao", "csrc")
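For reference on the new flags: /O2 and /Od are roughly the MSVC counterparts of -O3 and -O0, /permissive- turns on MSVC's standards-conformance mode, /ZI emits debug information, and /DEBUG is the linker switch that produces a PDB. The nvcc option -t=0 (added on both branches, and called out in the commit message as "-t=0 for linux") lets nvcc use as many parallel compilation threads as there are CPUs when building for multiple GPU architectures.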

torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh

+5 −4

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //
-// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh
+// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh

 #include "configs.h"
 #include "utils_gmem.cuh"
@@ -133,11 +133,12 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales,
     uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
     uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
     // Trible-Buffer for B Tile
-    half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
+    // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR.
+    half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
 #ifdef PIPELINE_LEVEL_SMEM
-    half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
+    half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
 #endif
-    half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
+    half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
     //
     bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter;
     // Copying A tile from Global to Register, Bypassing L1, using double-buffer
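The recurring change in this header and the ones below is where __restrict__ sits in a pointer-to-array declaration. A minimal standalone sketch of the two forms (hypothetical names, not repository code), assuming MSVC rejects the qualifier when it is applied to the array type rather than to the pointer:

#include <cuda_fp16.h>

#define WARP_K 64
#define PAD    8

__global__ void restrict_placement_demo() {
    __shared__ half smem[4 * (WARP_K + PAD)];

    // View the flat buffer as rows of (WARP_K + PAD) half values.
    half (*tile)[WARP_K + PAD] = reinterpret_cast<half (*)[WARP_K + PAD]>(smem);

    // Rejected by MSVC: __restrict__ applied to the array type.
    //   half __restrict__ (*read_ptr)[WARP_K + PAD] = tile;
    // Portable placement: __restrict__ after the '*', qualifying the pointer itself.
    half (* __restrict__ read_ptr)[WARP_K + PAD] = tile;

    if (threadIdx.x < WARP_K)
        read_ptr[0][threadIdx.x] = __float2half(0.0f);
}

The original placement compiled with nvcc on Linux but not with MSVC as the host compiler, which is what the MODIFICATION NOTEs in the diffs record.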

torchao/csrc/cuda/fp6_llm/ptx_mma.cuh

+10 −5

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //
-// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh
+// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh

 /***************************************************************************
  * Copyright 2023 The FLash-LLM Authors. All rights reserved.
@@ -36,11 +36,14 @@
 #include <assert.h>
 #include "configs.h"

+// MODIFICATION NOTE: to support MSVC
+// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4]
+// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR)
 #ifdef PIPELINE_LEVEL_SMEM
 template <typename TilingConfig>
-__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4],
-                                                  half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
-                                                  int slice_id) {
+__device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4],
+                                                  half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
+                                                  int slice_id) {
 #ifdef DEBUG_MODE
     static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) );
 #endif
@@ -112,8 +115,10 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[
 }
 #endif

+// MODIFICATION NOTE: to support MSVC, the function signature is changed from
+// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b).
 __device__ __forceinline__ void
-MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b)
+MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b)
 {
     asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
                  "{ %0, %1, %2, %3},"

torchao/csrc/cuda/fp6_llm/utils_core.cuh

+5 −3

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //
-// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh
+// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh

 #ifndef UTILS_CORE_CUH
 #define UTILS_CORE_CUH
@@ -35,12 +35,13 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u
     }
 }

+// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
 template <typename TilingConfig>
 __device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4],
                                                      uint32_t (*b)[4],
                                                      uint32_t* __restrict__ A1_SPTR_read,
                                                      uint32_t* __restrict__ A2_SPTR_read,
-                                                     half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
+                                                     half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
                                                      uint32_t* RPTR_Scales)
 {
     // Writing registers
@@ -53,13 +54,14 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t (
     B_FromSharedToReg<TilingConfig>(b, B_SPTR_read, 0); // Loading B from shared to registers
 }

+// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
 template <typename TilingConfig>
 __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16],
                                                uint32_t (*a)[4],
                                                uint32_t (*b)[4],
                                                uint32_t* __restrict__ A1_SPTR_read,
                                                uint32_t* __restrict__ A2_SPTR_read,
-                                               half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
+                                               half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
                                                uint32_t* RPTR_Scales,
                                                int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
 {

torchao/csrc/cuda/fp6_llm/utils_gmem.cuh

+7 −6

@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //
-// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh
+// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh

 #ifndef UTILS_GMEM_CUH
 #define UTILS_GMEM_CUH
@@ -57,17 +57,18 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc
     for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8];
 }

+// MODIFICATION NOTE: to support MSVC, half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below.
 /*
  * (1) Copying X rows * 64 columns of FP16 values, originally in row major
  * (2) Copying 64 rows * X columns of FP16 values, originally in column major
  * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads
  */
 template<int MaxNumOfLinesToCopy, int BLOCK_WARPS>
-__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
-                                                       const half* GlobalPTR,
-                                                       const int GlobalStride,
-                                                       const int NumOfLinesLeft, // To support arbitrary N dimensions.
-                                                       bool Pred = true) {
+__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8],
+                                                       const half* GlobalPTR,
+                                                       const int GlobalStride,
+                                                       const int NumOfLinesLeft, // To support arbitrary N dimensions.
+                                                       bool Pred = true) {
     // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time
     const int NumOfThreads = BLOCK_WARPS * WARP_SIZE;
     const int NumOfGroups = NumOfThreads / 8;

torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh

+17 −13

@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //
-// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh
+// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh
+// To support MSVC, all instances of u_int32_t are changed to uint32_t.

 #ifndef UTILS_PARALLELDEQUANT_CUH
 #define UTILS_PARALLELDEQUANT_CUH
@@ -26,7 +27,7 @@
  * Outputs: R1, R2
  * Note: Simplified Exponent calculation is applied.
  */
-__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) {
+__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) {
     *R2 = *R1 & 0x80808080;
     *R1 = *R1 >> 2;
     *R1 = *R1 & 0x1f1f1f1f;
@@ -41,7 +42,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2)
  * Outputs: R1, R2
  * Note: Simplified Exponent calculation is NOT applied.
  */
-__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) {
+__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) {
     //*R2 = *R1 & 0x80808080;
     *R2 = *R1 & 0xc0c0c0c0;
     *R1 = *R1 >> 2;
@@ -63,7 +64,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_
     //*R2 = 0x3c003c00;
 }

-__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) {
+__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) {
     half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
     half* FP16_2 = FP16_1 + 1;
     uint32_t output;
@@ -73,16 +74,19 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc
     return output;
 }

-__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4],
-                                                   u_int32_t __restrict__ *read_RPTR_Frag1,
-                                                   u_int32_t __restrict__ *read_RPTR_Frag2,
-                                                   u_int32_t *Scales) {
-    u_int32_t *OutputRegs = reinterpret_cast<u_int32_t*> (Reg);
-    u_int32_t *Frag1_PTR = read_RPTR_Frag1;
-    u_int32_t *Frag2_PTR = read_RPTR_Frag2;
+// MODIFICATION NOTE: to support MSVC
+// - u_int32_t __restrict__ Reg[][4] is changed to below.
+// - u_int32_t __restrict__ *read_RPTR_Frag1 is changed to below. similarly for read_RPTR_Frag2
+__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4],
+                                                   uint32_t * __restrict__ read_RPTR_Frag1,
+                                                   uint32_t * __restrict__ read_RPTR_Frag2,
+                                                   uint32_t * Scales) {
+    uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg);
+    uint32_t *Frag1_PTR = read_RPTR_Frag1;
+    uint32_t *Frag2_PTR = read_RPTR_Frag2;
     half *Scale_RPTR = reinterpret_cast<half*>(Scales);
-    u_int32_t Packed_FP6 = 0;
-    u_int32_t tmp = 0;
+    uint32_t Packed_FP6 = 0;
+    uint32_t tmp = 0;
     // Dequantizing 32 FP6, each Loop dequantizing 4 FP6
     #pragma unroll(8)
     for(int i=0; i<8; i++) {
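Background on the typedef swap in this file: u_int32_t is a BSD/glibc typedef from sys/types.h that MSVC does not provide, while uint32_t is the standard fixed-width type from <cstdint> and is available on every toolchain. A trivial sketch (not repository code):

#include <cstdint>   // standard home of uint32_t on all supported compilers

// Portable spelling: builds with nvcc on both Linux (gcc) and Windows (MSVC).
__device__ __forceinline__ uint32_t mask_low6(uint32_t x) { return x & 0x3fu; }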
