Skip to content

Commit 1197e02

Browse files
authored
[Build] Guard against older CUDA versions when building CUTLASS 3.x kernels (#5168)
1 parent 6575791 commit 1197e02

File tree

2 files changed

+18 / -3 lines changed (18 additions, 3 deletions)

2 files changed

+18 / -3 lines changed (18 additions, 3 deletions)

csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,9 @@
1+
// clang-format will break include orders
2+
// clang-format off
3+
#include <cudaTypedefs.h>
4+
5+
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
6+
17
#include <torch/extension.h>
28

39
#include <ATen/cuda/CUDAContext.h>
@@ -6,8 +12,6 @@
612
#include <sstream>
713
#include <vector>
814

9-
// clang-format will break include orders
10-
// clang-format off
1115
#include "cutlass/cutlass.h"
1216

1317
#include "cute/tensor.hpp"
@@ -241,3 +245,5 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
241245
}
242246
}
243247
}
248+
249+
#endif

csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
1+
#include <cudaTypedefs.h>
2+
13
#include <c10/cuda/CUDAGuard.h>
2-
#include <cuda_runtime.h>
34
#include <torch/extension.h>
45

56
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
1718
torch::Tensor const& a_scales,
1819
torch::Tensor const& b_scales);
1920

21+
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
2022
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
2123
torch::Tensor const& b,
2224
torch::Tensor const& a_scales,
2325
torch::Tensor const& b_scales);
26+
#endif
2427

2528
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
2629
torch::Tensor const& b, torch::Tensor const& a_scales,
@@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
5154

5255
if (version_num >= 90) {
5356
// Hopper
57+
58+
// Guard against compilation issues for sm90 kernels
59+
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
5460
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
61+
#else
62+
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
63+
#endif
5564
} else if (version_num == 89) {
5665
// Ada Lovelace
5766
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);

0 commit comments

Comments (0)