Skip to content

Commit 8b081e2

Browse files
AVX2 vectorization for very large bitsets (#4422)
Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
1 parent a6c2a72 commit 8b081e2

File tree

3 files changed

+133
-24
lines changed

3 files changed

+133
-24
lines changed

benchmarks/src/bitset_to_string.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ namespace {
4343

4444
BENCHMARK(BM_bitset_to_string<15, char>);
4545
BENCHMARK(BM_bitset_to_string<64, char>);
46+
BENCHMARK(BM_bitset_to_string<512, char>);
4647
BENCHMARK(BM_bitset_to_string_large_single<char>);
4748
BENCHMARK(BM_bitset_to_string<7, wchar_t>);
4849
BENCHMARK(BM_bitset_to_string<64, wchar_t>);
50+
BENCHMARK(BM_bitset_to_string<512, wchar_t>);
4951
BENCHMARK(BM_bitset_to_string_large_single<wchar_t>);
5052

5153
BENCHMARK_MAIN();

stl/src/vector_algorithms.cpp

+92
Original file line numberDiff line numberDiff line change
@@ -2169,6 +2169,17 @@ __declspec(noalias) size_t
21692169

21702170
#ifndef _M_ARM64EC
21712171
namespace {
2172+
// Expands the 32 bits of _Val into 32 byte-sized elements, one element per bit.
// For each bit, the corresponding byte of the result is taken from _Px0 when the bit is 0
// and from _Px1 when the bit is 1. The most significant bit of _Val lands in the lowest byte
// of the result, so storing the result writes the bits in most-significant-first order.
__m256i __forceinline _Bitset_to_string_1_step_avx(const uint32_t _Val, const __m256i _Px0, const __m256i _Px1) {
2173+
const __m128i _Vx0 = _mm_cvtsi32_si128(_Val);
2174+
// Replicate each source byte four times, reversing byte order: the lowest dword gets byte 3 of _Val,
// the highest dword gets byte 0.
const __m128i _Vx1 = _mm_shuffle_epi8(_Vx0, _mm_set_epi32(0x00000000, 0x01010101, 0x02020202, 0x03030303));
2175+
const __m256i _Vx2 = _mm256_castsi128_si256(_Vx1);
2176+
// Duplicate every dword so that each 64-bit lane holds eight copies of a single source byte.
const __m256i _Vx3 = _mm256_permutevar8x32_epi32(_Vx2, _mm256_set_epi32(3, 3, 2, 2, 1, 1, 0, 0));
2177+
// Within each 64-bit lane, byte k keeps only bit (7 - k) of the replicated source byte.
const __m256i _Msk = _mm256_and_si256(_Vx3, _mm256_set1_epi64x(0x0102040810204080));
2178+
// _Ex0 is all-ones in every byte whose tested bit was clear.
const __m256i _Ex0 = _mm256_cmpeq_epi8(_Msk, _mm256_setzero_si256());
2179+
// Select _Px0 where the bit was clear, _Px1 where it was set.
const __m256i _Ex1 = _mm256_blendv_epi8(_Px1, _Px0, _Ex0);
2180+
return _Ex1;
2181+
}
2182+
21722183
__m128i __forceinline _Bitset_to_string_1_step(const uint16_t _Val, const __m128i _Px0, const __m128i _Px1) {
21732184
const __m128i _Vx0 = _mm_cvtsi32_si128(_Val);
21742185
const __m128i _Vx1 = _mm_unpacklo_epi8(_Vx0, _Vx0);
@@ -2180,6 +2191,18 @@ namespace {
21802191
return _Ex1;
21812192
}
21822193

2194+
// Expands the 16 bits of _Val into 16 two-byte elements, one element per bit.
// For each bit, the corresponding 16-bit element of the result is taken from _Px0 when the bit is 0
// and from _Px1 when the bit is 1. The most significant bit of _Val lands in the lowest element
// of the result, so storing the result writes the bits in most-significant-first order.
__m256i __forceinline _Bitset_to_string_2_step_avx(const uint16_t _Val, const __m256i _Px0, const __m256i _Px1) {
2195+
const __m128i _Vx0 = _mm_cvtsi32_si128(_Val);
2196+
// Fill the low 8 bytes with the high byte of _Val and the high 8 bytes with the low byte of _Val.
const __m128i _Vx1 = _mm_shuffle_epi8(_Vx0, _mm_set_epi32(0x00000000, 0x00000000, 0x01010101, 0x01010101));
2197+
const __m256i _Vx2 = _mm256_castsi128_si256(_Vx1);
2198+
// Duplicate each 64-bit lane: the low 128 bits hold the high byte of _Val, the high 128 bits hold the low byte.
const __m256i _Vx3 = _mm256_permute4x64_epi64(_Vx2, _MM_SHUFFLE(1, 1, 0, 0));
2199+
// Within each 128-bit half, 16-bit element k keeps only bit (7 - k) of the replicated source byte.
const __m256i _Msk = _mm256_and_si256(
2200+
_Vx3, _mm256_set_epi64x(0x0001000200040008, 0x0010002000400080, 0x0001000200040008, 0x0010002000400080));
2201+
// _Ex0 is all-ones in every 16-bit element whose tested bit was clear.
const __m256i _Ex0 = _mm256_cmpeq_epi16(_Msk, _mm256_setzero_si256());
2202+
// The byte-granular blend is safe here because each 16-bit element of _Ex0 is uniformly all-ones or all-zeros.
const __m256i _Ex1 = _mm256_blendv_epi8(_Px1, _Px0, _Ex0);
2203+
return _Ex1;
2204+
}
2205+
21832206
__m128i __forceinline _Bitset_to_string_2_step(const uint8_t _Val, const __m128i _Px0, const __m128i _Px1) {
21842207
const __m128i _Vx = _mm_set1_epi16(_Val);
21852208
const __m128i _Msk = _mm_and_si128(_Vx, _mm_set_epi64x(0x0001000200040008, 0x0010002000400080));
@@ -2195,6 +2218,38 @@ extern "C" {
21952218
__declspec(noalias) void __stdcall __std_bitset_to_string_1(
21962219
char* const _Dest, const void* _Src, size_t _Size_bits, const char _Elem0, const char _Elem1) noexcept {
21972220
#ifndef _M_ARM64EC
2221+
if (_Use_avx2() && _Size_bits >= 256) {
2222+
const __m256i _Px0 = _mm256_broadcastb_epi8(_mm_cvtsi32_si128(_Elem0));
2223+
const __m256i _Px1 = _mm256_broadcastb_epi8(_mm_cvtsi32_si128(_Elem1));
2224+
if (_Size_bits >= 32) {
2225+
char* _Pos = _Dest + _Size_bits;
2226+
_Size_bits &= 0x1F;
2227+
char* const _Stop_at = _Dest + _Size_bits;
2228+
do {
2229+
uint32_t _Val;
2230+
memcpy(&_Val, _Src, 4);
2231+
const __m256i _Elems = _Bitset_to_string_1_step_avx(_Val, _Px0, _Px1);
2232+
_Pos -= 32;
2233+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Pos), _Elems);
2234+
_Advance_bytes(_Src, 4);
2235+
} while (_Pos != _Stop_at);
2236+
}
2237+
2238+
if (_Size_bits > 0) {
2239+
__assume(_Size_bits < 32);
2240+
uint32_t _Val = 0;
2241+
memcpy(&_Val, _Src, (_Size_bits + 7) / 8);
2242+
const __m256i _Elems = _Bitset_to_string_1_step_avx(_Val, _Px0, _Px1);
2243+
char _Tmp[32];
2244+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Tmp), _Elems);
2245+
const char* const _Tmpd = _Tmp + (32 - _Size_bits);
2246+
memcpy(_Dest, _Tmpd, _Size_bits);
2247+
}
2248+
2249+
_mm256_zeroupper(); // TRANSITION, DevCom-10331414
2250+
return;
2251+
}
2252+
21982253
if (_Use_sse2()) {
21992254
const __m128i _Px0 = _mm_set1_epi8(_Elem0 ^ _Elem1);
22002255
const __m128i _Px1 = _mm_set1_epi8(_Elem1);
@@ -2241,6 +2296,43 @@ __declspec(noalias) void __stdcall __std_bitset_to_string_1(
22412296
__declspec(noalias) void __stdcall __std_bitset_to_string_2(
22422297
wchar_t* const _Dest, const void* _Src, size_t _Size_bits, const wchar_t _Elem0, const wchar_t _Elem1) noexcept {
22432298
#ifndef _M_ARM64EC
2299+
if (_Use_avx2() && _Size_bits >= 256) {
2300+
const __m256i _Px0 = _mm256_broadcastw_epi16(_mm_cvtsi32_si128(_Elem0));
2301+
const __m256i _Px1 = _mm256_broadcastw_epi16(_mm_cvtsi32_si128(_Elem1));
2302+
2303+
if (_Size_bits >= 16) {
2304+
wchar_t* _Pos = _Dest + _Size_bits;
2305+
_Size_bits &= 0xF;
2306+
wchar_t* const _Stop_at = _Dest + _Size_bits;
2307+
do {
2308+
uint16_t _Val;
2309+
memcpy(&_Val, _Src, 2);
2310+
const __m256i _Elems = _Bitset_to_string_2_step_avx(_Val, _Px0, _Px1);
2311+
_Pos -= 16;
2312+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Pos), _Elems);
2313+
_Advance_bytes(_Src, 2);
2314+
} while (_Pos != _Stop_at);
2315+
}
2316+
2317+
if (_Size_bits > 0) {
2318+
__assume(_Size_bits < 16);
2319+
uint16_t _Val;
2320+
if (_Size_bits > 8) {
2321+
memcpy(&_Val, _Src, 2);
2322+
} else {
2323+
_Val = *reinterpret_cast<const uint8_t*>(_Src);
2324+
}
2325+
const __m256i _Elems = _Bitset_to_string_2_step_avx(_Val, _Px0, _Px1);
2326+
wchar_t _Tmp[16];
2327+
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Tmp), _Elems);
2328+
const wchar_t* const _Tmpd = _Tmp + (16 - _Size_bits);
2329+
memcpy(_Dest, _Tmpd, _Size_bits * 2);
2330+
}
2331+
2332+
_mm256_zeroupper(); // TRANSITION, DevCom-10331414
2333+
return;
2334+
}
2335+
22442336
if (_Use_sse2()) {
22452337
const __m128i _Px0 = _mm_set1_epi16(_Elem0 ^ _Elem1);
22462338
const __m128i _Px1 = _mm_set1_epi16(_Elem1);

tests/std/tests/VSO_0000000_vector_algorithms/test.cpp

+39-24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <random>
1616
#include <string>
1717
#include <type_traits>
18+
#include <utility>
1819
#include <vector>
1920

2021
#if _HAS_CXX20
@@ -474,6 +475,43 @@ void test_one_container() {
474475
test_two_containers<Container, list<int>>();
475476
}
476477

478+
// Round-trip test for bitset<N>::to_string with char and wchar_t elements.
// Builds a random string of exactly N '0'/'1' characters (and the matching wide string) from gen,
// constructs a bitset<N> from the narrow string, and asserts that converting back reproduces both strings.
// Always returns true; the return value exists so the call can be used as an initializer expression.
template <size_t N>
479+
bool test_randomized_bitset(mt19937_64& gen) {
480+
string str;
481+
wstring wstr;
482+
str.reserve(N);
483+
wstr.reserve(N);
484+
485+
while (str.size() != N) {
486+
uint64_t random_value = gen();
487+
488+
// Consume up to 64 random bits, stopping early once N characters exist so that
// sizes that are not a multiple of 64 are produced exactly.
for (int bits = 0; bits < 64 && str.size() != N; ++bits) {
489+
const auto character = '0' + (random_value & 1);
490+
str.push_back(static_cast<char>(character));
491+
wstr.push_back(static_cast<wchar_t>(character));
492+
random_value >>= 1;
493+
}
494+
}
495+
496+
const bitset<N> b(str);
497+
498+
assert(b.to_string() == str);
499+
// `template` is required because b's type depends on the template parameter N.
assert(b.template to_string<wchar_t>() == wstr);
500+
501+
return true;
502+
}
503+
504+
// Runs test_randomized_bitset<Base + V>(gen) for every V in the index pack Vals.
// The calls are expanded inside a braced array initializer; the resulting bools are not used.
// NOTE(review): an empty Vals pack would make this a zero-length array initializer - callers should pass at least
// one index.
template <size_t Base, size_t... Vals>
505+
void test_randomized_bitset_base(index_sequence<Vals...>, mt19937_64& gen) {
506+
bool ignored[] = {test_randomized_bitset<Base + Vals>(gen)...};
507+
// Silence unused-variable diagnostics; only the side effects of the calls matter.
(void) ignored;
508+
}
509+
510+
// Runs the randomized bitset round-trip test for Count consecutive sizes: Base, Base + 1, ..., Base + Count - 1.
template <size_t Base, size_t Count>
511+
void test_randomized_bitset_base_count(mt19937_64& gen) {
512+
// make_index_sequence<Count> supplies the offsets 0 .. Count - 1 that are added to Base.
test_randomized_bitset_base<Base>(make_index_sequence<Count>{}, gen);
513+
}
514+
477515
void test_bitset(mt19937_64& gen) {
478516
assert(bitset<0>(0x0ULL).to_string() == "");
479517
assert(bitset<0>(0xFEDCBA9876543210ULL).to_string() == "");
@@ -515,30 +553,7 @@ void test_bitset(mt19937_64& gen) {
515553
assert(bitset<75>(0xFEDCBA9876543210ULL).to_string<char32_t>()
516554
== U"000000000001111111011011100101110101001100001110110010101000011001000010000"); // not vectorized
517555

518-
{
519-
constexpr size_t N = 2048;
520-
521-
string str;
522-
wstring wstr;
523-
str.reserve(N);
524-
wstr.reserve(N);
525-
526-
while (str.size() != N) {
527-
uint64_t random_value = gen();
528-
529-
for (int bits = 0; bits < 64; ++bits) {
530-
const auto character = '0' + (random_value & 1);
531-
str.push_back(static_cast<char>(character));
532-
wstr.push_back(static_cast<wchar_t>(character));
533-
random_value >>= 1;
534-
}
535-
}
536-
537-
const bitset<N> b(str);
538-
539-
assert(b.to_string() == str);
540-
assert(b.to_string<wchar_t>() == wstr);
541-
}
556+
test_randomized_bitset_base_count<512 - 5, 32 + 10>(gen);
542557
}
543558

544559
void test_various_containers() {

0 commit comments

Comments
 (0)