Skip to content

Commit 581279d

Browse files
authored
[OPT] Zero-extend unsigned 16-bit integers when bitcasting (#5714)
The folding rule `BitCastScalarOrVector` was incorrectly handling bitcasting to unsigned integers smaller than 32-bits. It was simply copying the entire 32-bit word containing the integer. This conflicts with the requirement in section 2.2.1 of the SPIR-V spec which states that unsigned numeric types with a bit width less than 32-bits must have the high-order bits set to 0. This change include a refactor of the bit extension code to be able to test it better, and to use it in multiple files. Fixes microsoft/DirectXShaderCompiler#6319.
1 parent 80a1aed commit 581279d

8 files changed

+108
-74
lines changed

source/opt/const_folding_rules.cpp

+4-57
Original file line numberDiff line numberDiff line change
@@ -21,59 +21,6 @@ namespace opt {
2121
namespace {
2222
constexpr uint32_t kExtractCompositeIdInIdx = 0;
2323

24-
// Returns the value obtained by extracting the |number_of_bits| least
25-
// significant bits from |value|, and sign-extending it to 64-bits.
26-
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
27-
if (number_of_bits == 64) return value;
28-
29-
uint64_t mask_for_sign_bit = 1ull << (number_of_bits - 1);
30-
uint64_t mask_for_significant_bits = (mask_for_sign_bit << 1) - 1ull;
31-
if (value & mask_for_sign_bit) {
32-
// Set upper bits to 1
33-
value |= ~mask_for_significant_bits;
34-
} else {
35-
// Clear the upper bits
36-
value &= mask_for_significant_bits;
37-
}
38-
return value;
39-
}
40-
41-
// Returns the value obtained by extracting the |number_of_bits| least
42-
// significant bits from |value|, and zero-extending it to 64-bits.
43-
uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
44-
if (number_of_bits == 64) return value;
45-
46-
uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits);
47-
uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1;
48-
value &= mask_for_bits_to_keep;
49-
return value;
50-
}
51-
52-
// Returns a constant whose value is `value` and type is `type`. This constant
53-
// will be generated by `const_mgr`. The type must be a scalar integer type.
54-
const analysis::Constant* GenerateIntegerConstant(
55-
const analysis::Integer* integer_type, uint64_t result,
56-
analysis::ConstantManager* const_mgr) {
57-
assert(integer_type != nullptr);
58-
59-
std::vector<uint32_t> words;
60-
if (integer_type->width() == 64) {
61-
// In the 64-bit case, two words are needed to represent the value.
62-
words = {static_cast<uint32_t>(result),
63-
static_cast<uint32_t>(result >> 32)};
64-
} else {
65-
// In all other cases, only a single word is needed.
66-
assert(integer_type->width() <= 32);
67-
if (integer_type->IsSigned()) {
68-
result = SignExtendValue(result, integer_type->width());
69-
} else {
70-
result = ZeroExtendValue(result, integer_type->width());
71-
}
72-
words = {static_cast<uint32_t>(result)};
73-
}
74-
return const_mgr->GetConstant(integer_type, words);
75-
}
76-
7724
// Returns a constants with the value NaN of the given type. Only works for
7825
// 32-bit and 64-bit float point types. Returns |nullptr| if an error occurs.
7926
const analysis::Constant* GetNan(const analysis::Type* type,
@@ -1730,7 +1677,7 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
17301677
uint64_t result = op(ia, ib);
17311678

17321679
const analysis::Constant* result_constant =
1733-
GenerateIntegerConstant(integer_type, result, const_mgr);
1680+
const_mgr->GenerateIntegerConstant(integer_type, result);
17341681
return result_constant;
17351682
};
17361683
}
@@ -1745,7 +1692,7 @@ const analysis::Constant* FoldScalarSConvert(
17451692
const analysis::Integer* integer_type = result_type->AsInteger();
17461693
assert(integer_type && "The result type of an SConvert");
17471694
int64_t value = a->GetSignExtendedValue();
1748-
return GenerateIntegerConstant(integer_type, value, const_mgr);
1695+
return const_mgr->GenerateIntegerConstant(integer_type, value);
17491696
}
17501697

17511698
// A scalar folding rule that folds OpUConvert.
@@ -1762,8 +1709,8 @@ const analysis::Constant* FoldScalarUConvert(
17621709
// If the operand was an unsigned value with less than 32-bit, it would have
17631710
// been sign extended earlier, and we need to clear those bits.
17641711
auto* operand_type = a->type()->AsInteger();
1765-
value = ZeroExtendValue(value, operand_type->width());
1766-
return GenerateIntegerConstant(integer_type, value, const_mgr);
1712+
value = utils::ClearHighBits(value, 64 - operand_type->width());
1713+
return const_mgr->GenerateIntegerConstant(integer_type, value);
17671714
}
17681715
} // namespace
17691716

source/opt/constants.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,28 @@ uint32_t ConstantManager::GetNullConstId(const Type* type) {
525525
return GetDefiningInstruction(c)->result_id();
526526
}
527527

528+
const Constant* ConstantManager::GenerateIntegerConstant(
529+
const analysis::Integer* integer_type, uint64_t result) {
530+
assert(integer_type != nullptr);
531+
532+
std::vector<uint32_t> words;
533+
if (integer_type->width() == 64) {
534+
// In the 64-bit case, two words are needed to represent the value.
535+
words = {static_cast<uint32_t>(result),
536+
static_cast<uint32_t>(result >> 32)};
537+
} else {
538+
// In all other cases, only a single word is needed.
539+
assert(integer_type->width() <= 32);
540+
if (integer_type->IsSigned()) {
541+
result = utils::SignExtendValue(result, integer_type->width());
542+
} else {
543+
result = utils::ZeroExtendValue(result, integer_type->width());
544+
}
545+
words = {static_cast<uint32_t>(result)};
546+
}
547+
return GetConstant(integer_type, words);
548+
}
549+
528550
std::vector<const analysis::Constant*> Constant::GetVectorComponents(
529551
analysis::ConstantManager* const_mgr) const {
530552
std::vector<const analysis::Constant*> components;

source/opt/constants.h

+5
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,11 @@ class ConstantManager {
671671
// Returns the id of a OpConstantNull with type of |type|.
672672
uint32_t GetNullConstId(const Type* type);
673673

674+
// Returns a constant whose value is `value` and type is `type`. This constant
675+
// will be generated by `const_mgr`. The type must be a scalar integer type.
676+
const Constant* GenerateIntegerConstant(const analysis::Integer* integer_type,
677+
uint64_t result);
678+
674679
private:
675680
// Creates a Constant instance with the given type and a vector of constant
676681
// defining words. Returns a unique pointer to the created Constant instance

source/opt/fold_spec_constant_op_and_composite_pass.cpp

+1-12
Original file line numberDiff line numberDiff line change
@@ -247,18 +247,7 @@ utils::SmallVector<uint32_t, 2> EncodeIntegerAsWords(const analysis::Type& type,
247247

248248
// Truncate first_word if the |type| has width less than uint32.
249249
if (bit_width < bits_per_word) {
250-
const uint32_t num_high_bits_to_mask = bits_per_word - bit_width;
251-
const bool is_negative_after_truncation =
252-
result_type_signed &&
253-
utils::IsBitAtPositionSet(first_word, bit_width - 1);
254-
255-
if (is_negative_after_truncation) {
256-
// Truncate and sign-extend |first_word|. No padding words will be
257-
// added and |pad_value| can be left as-is.
258-
first_word = utils::SetHighBits(first_word, num_high_bits_to_mask);
259-
} else {
260-
first_word = utils::ClearHighBits(first_word, num_high_bits_to_mask);
261-
}
250+
first_word = utils::SignExtendValue(first_word, bit_width);
262251
}
263252

264253
utils::SmallVector<uint32_t, 2> words = {first_word};

source/opt/folding_rules.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,14 @@ std::vector<uint32_t> GetWordsFromNumericScalarOrVectorConstant(
180180
const analysis::Constant* ConvertWordsToNumericScalarOrVectorConstant(
181181
analysis::ConstantManager* const_mgr, const std::vector<uint32_t>& words,
182182
const analysis::Type* type) {
183-
if (type->AsInteger() || type->AsFloat())
184-
return const_mgr->GetConstant(type, words);
183+
const spvtools::opt::analysis::Integer* int_type = type->AsInteger();
184+
185+
if (int_type && int_type->width() <= 32) {
186+
assert(words.size() == 1);
187+
return const_mgr->GenerateIntegerConstant(int_type, words[0]);
188+
}
189+
190+
if (int_type || type->AsFloat()) return const_mgr->GetConstant(type, words);
185191
if (const auto* vec_type = type->AsVector())
186192
return const_mgr->GetNumericVectorConstantWithWords(vec_type, words);
187193
return nullptr;

source/util/bitutils.h

+25
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,31 @@ T ClearHighBits(T word, size_t num_bits_to_set) {
181181
false);
182182
}
183183

184+
// Returns the value obtained by extracting the |number_of_bits| least
185+
// significant bits from |value|, and sign-extending it to 64-bits.
186+
template <typename T>
187+
T SignExtendValue(T value, uint32_t number_of_bits) {
188+
const uint32_t bit_width = sizeof(value) * 8;
189+
if (number_of_bits == bit_width) return value;
190+
191+
bool is_negative = utils::IsBitAtPositionSet(value, number_of_bits - 1);
192+
if (is_negative) {
193+
value = utils::SetHighBits(value, bit_width - number_of_bits);
194+
} else {
195+
value = utils::ClearHighBits(value, bit_width - number_of_bits);
196+
}
197+
return value;
198+
}
199+
200+
// Returns the value obtained by extracting the |number_of_bits| least
201+
// significant bits from |value|, and zero-extending it to 64-bits.
202+
template <typename T>
203+
T ZeroExtendValue(T value, uint32_t number_of_bits) {
204+
const uint32_t bit_width = sizeof(value) * 8;
205+
if (number_of_bits == bit_width) return value;
206+
return utils::ClearHighBits(value, bit_width - number_of_bits);
207+
}
208+
184209
} // namespace utils
185210
} // namespace spvtools
186211

test/opt/fold_test.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
924924
"%2 = OpBitcast %ushort %short_0xBC00\n" +
925925
"OpReturn\n" +
926926
"OpFunctionEnd",
927-
2, 0xFFFFBC00),
927+
2, 0xBC00),
928928
// Test case 53: Bit-cast half 1 to ushort
929929
InstructionFoldingCase<uint32_t>(
930930
Header() + "%main = OpFunction %void None %void_func\n" +
@@ -940,7 +940,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
940940
"%2 = OpBitcast %short %ushort_0xBC00\n" +
941941
"OpReturn\n" +
942942
"OpFunctionEnd",
943-
2, 0xBC00),
943+
2, 0xFFFFBC00),
944944
// Test case 55: Bit-cast short 0xBC00 to short
945945
InstructionFoldingCase<uint32_t>(
946946
Header() + "%main = OpFunction %void None %void_func\n" +
@@ -996,7 +996,7 @@ INSTANTIATE_TEST_SUITE_P(TestCase, IntegerInstructionFoldingTest,
996996
"%2 = OpBitcast %ubyte %byte_n1\n" +
997997
"OpReturn\n" +
998998
"OpFunctionEnd",
999-
2, 0xFFFFFFFF),
999+
2, 0xFF),
10001000
// Test case 62: Negate 2.
10011001
InstructionFoldingCase<uint32_t>(
10021002
Header() + "%main = OpFunction %void None %void_func\n" +

test/util/bitutils_test.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,46 @@ TEST(BitUtilsTest, IsBitSetAtPositionAll) {
188188
EXPECT_TRUE(IsBitAtPositionSet(max_u64, i));
189189
}
190190
}
191+
192+
struct ExtendedValueTestCase {
193+
uint32_t input;
194+
uint32_t bit_width;
195+
uint32_t expected_result;
196+
};
197+
198+
using SignExtendedValueTest = ::testing::TestWithParam<ExtendedValueTestCase>;
199+
200+
TEST_P(SignExtendedValueTest, SignExtendValue) {
201+
const auto& tc = GetParam();
202+
auto result = SignExtendValue(tc.input, tc.bit_width);
203+
EXPECT_EQ(result, tc.expected_result);
204+
}
205+
INSTANTIATE_TEST_SUITE_P(
206+
SignExtendValue, SignExtendedValueTest,
207+
::testing::Values(ExtendedValueTestCase{1, 1, 0xFFFFFFFF},
208+
ExtendedValueTestCase{1, 2, 0x1},
209+
ExtendedValueTestCase{2, 1, 0x0},
210+
ExtendedValueTestCase{0x8, 4, 0xFFFFFFF8},
211+
ExtendedValueTestCase{0x8765, 16, 0xFFFF8765},
212+
ExtendedValueTestCase{0x7765, 16, 0x7765},
213+
ExtendedValueTestCase{0xDEADBEEF, 32, 0xDEADBEEF}));
214+
215+
using ZeroExtendedValueTest = ::testing::TestWithParam<ExtendedValueTestCase>;
216+
217+
TEST_P(ZeroExtendedValueTest, ZeroExtendValue) {
218+
const auto& tc = GetParam();
219+
auto result = ZeroExtendValue(tc.input, tc.bit_width);
220+
EXPECT_EQ(result, tc.expected_result);
221+
}
222+
223+
INSTANTIATE_TEST_SUITE_P(
224+
ZeroExtendValue, ZeroExtendedValueTest,
225+
::testing::Values(ExtendedValueTestCase{1, 1, 0x1},
226+
ExtendedValueTestCase{1, 2, 0x1},
227+
ExtendedValueTestCase{2, 1, 0x0},
228+
ExtendedValueTestCase{0x8, 4, 0x8},
229+
ExtendedValueTestCase{0xFF8765, 16, 0x8765},
230+
ExtendedValueTestCase{0xDEADBEEF, 32, 0xDEADBEEF}));
191231
} // namespace
192232
} // namespace utils
193233
} // namespace spvtools

0 commit comments

Comments
 (0)