Skip to content

Commit 9862d40

Browse files
committed
[ROCm]: Fix build from source failure with gcc14 and ROCm 6.3
Solves #13777. See #13777. Fixes build-from-source failures when building for ROCm. Signed-off-by: Arjun Kathuria <arjun.kathuria8@gmail.com>
1 parent 05e1f96 commit 9862d40

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

csrc/quantization/compressed_tensors/int8_quant_kernels.cu

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
2626
float dst = std::nearbyint(x);
2727

2828
// saturate
29-
dst = std::clamp(dst, i8_min, i8_max);
29+
30+
// See https://github.com/pytorch/pytorch/issues/127666
31+
// See https://github.com/llvm/llvm-project/issues/95183
32+
// hip-clang: std::clamp references the host-only __glibcxx_assert_fail when building on
33+
// Arch/gcc14. The following replaces std::clamp usage with similar logic
34+
// dst = std::clamp(dst, i8_min, i8_max);
35+
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
3036
return static_cast<int8_t>(dst);
3137
#else
3238
// CUDA path
@@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
7985
static_cast<int32_t>(std::numeric_limits<int8_t>::max());
8086

8187
// saturate
82-
int32_t dst = std::clamp(x, i8_min, i8_max);
88+
89+
// See https://github.com/pytorch/pytorch/issues/127666
90+
// See https://github.com/llvm/llvm-project/issues/95183
91+
// hip-clang: std::clamp references the host-only __glibcxx_assert_fail when building on
92+
// Arch/gcc14. The following replaces std::clamp usage with similar logic
93+
// int32_t dst = std::clamp(x, i8_min, i8_max);
94+
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
8395
return static_cast<int8_t>(dst);
8496
#else
8597
// CUDA path

csrc/quantization/fused_kernels/quant_conversions.cuh

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
2121
// round
2222
float dst = std::nearbyint(x);
2323
// saturate
24-
dst = std::clamp(dst, i8_min, i8_max);
24+
25+
// See https://github.com/pytorch/pytorch/issues/127666
26+
// See https://github.com/llvm/llvm-project/issues/95183
27+
// hip-clang: std::clamp references the host-only __glibcxx_assert_fail when building on
28+
// Arch/gcc14. The following replaces std::clamp usage with similar logic
29+
// dst = std::clamp(dst, i8_min, i8_max);
30+
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
2531
return static_cast<int8_t>(dst);
2632
#else
2733
// CUDA path

0 commit comments

Comments
 (0)