Skip to content

[ROCm]: Fix build from source failure with gcc14 and ROCm 6.3 #13779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
float dst = std::nearbyint(x);

// saturate
dst = std::clamp(dst, i8_min, i8_max);

// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: dst = std::min(i8_max, max(dst, i8_min))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: dst = std::min(i8_max, max(dst, i8_min))

Suggest not to change to use std::min and std::max, as this would leads more comparison in average, and is less efficient: I have commented/suggested the same in the PR ( pytorch/pytorch#127812 ) in PyTorch when we first addressed this type of issues:
image

This pure "<" with ">" comparison is also align with what the clamp's original's documentation.

return static_cast<int8_t>(dst);
#else
// CUDA path
Expand Down Expand Up @@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
static_cast<int32_t>(std::numeric_limits<int8_t>::max());

// saturate
int32_t dst = std::clamp(x, i8_min, i8_max);

// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// int32_t dst = std::clamp(x, i8_min, i8_max);
int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x;
return static_cast<int8_t>(dst);
#else
// CUDA path
Expand Down
8 changes: 7 additions & 1 deletion csrc/quantization/fused_kernels/quant_conversions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
// round
float dst = std::nearbyint(x);
// saturate
dst = std::clamp(dst, i8_min, i8_max);

// See https://github.com/pytorch/pytorch/issues/127666
// See https://github.com/llvm/llvm-project/issues/95183
// hip-clang std::clamp __glibcxx_assert_fail host function when building on
// Arch/gcc14. The following replaces std::clamp usage with similar logic
// dst = std::clamp(dst, i8_min, i8_max);
dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
return static_cast<int8_t>(dst);
#else
// CUDA path
Expand Down