Skip to content

Commit 8306017

Browse files
committedJan 13, 2025
Improve dft performance on arm64
1 parent 6aea976 commit 8306017

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed
 

Diff for: ‎src/dft/bitrev.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,22 @@ constexpr inline static size_t bitrev_table_log2N = ilog2(arraysize(data::bitrev
4949
template <size_t Bits>
5050
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x)
5151
{
52+
#ifdef CMT_ARCH_NEON
53+
return __builtin_bitreverse32(x) >> (32 - Bits);
54+
#else
5255
if constexpr (Bits > bitrev_table_log2N)
5356
return bitreverse<Bits>(x);
5457

5558
return data::bitrev_table[x] >> (bitrev_table_log2N - Bits);
59+
#endif
5660
}
5761

5862
template <bool use_table>
5963
CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t<use_table>)
6064
{
65+
#ifdef CMT_ARCH_NEON
66+
return __builtin_bitreverse32(x) >> (32 - bits);
67+
#else
6168
if constexpr (use_table)
6269
{
6370
return data::bitrev_table[x] >> (bitrev_table_log2N - bits);
@@ -66,10 +73,17 @@ CMT_GNU_CONSTEXPR inline u32 bitrev_using_table(u32 x, size_t bits, cbool_t<use_
6673
{
6774
return bitreverse<32>(x) >> (32 - bits);
6875
}
76+
#endif
6977
}
7078

7179
CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits)
7280
{
81+
#ifdef CMT_ARCH_NEON
82+
x = __builtin_bitreverse32(x);
83+
x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
84+
x = x >> (32 - bits);
85+
return x;
86+
#else
7387
if (bits > bitrev_table_log2N)
7488
{
7589
if (bits <= 16)
@@ -82,6 +96,7 @@ CMT_GNU_CONSTEXPR inline u32 dig4rev_using_table(u32 x, size_t bits)
8296
x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
8397
x = x >> (bitrev_table_log2N - bits);
8498
return x;
99+
#endif
85100
}
86101

87102
template <size_t log2n, size_t bitrev, typename T>

Diff for: ‎src/dft/fft-impl.hpp

+24-8
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,30 @@ template <typename T>
5252
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection;
5353

5454
template <>
55-
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<float>{ (1ull << 15) - 1 };
55+
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<float>{
56+
#ifdef CMT_ARCH_NEON
57+
0
58+
#else
59+
(1ull << 15) - 1
60+
#endif
61+
};
5662

5763
template <>
5864
inline std::bitset<DFT_MAX_STAGES> fft_algorithm_selection<double>{ 0 };
5965

6066
template <typename T>
61-
constexpr bool inline use_autosort(size_t log2n)
67+
inline bool use_autosort(size_t log2n)
6268
{
6369
return fft_algorithm_selection<T>[log2n];
6470
}
6571

72+
#ifndef CMT_ARCH_NEON
73+
#define KFR_AUTOSORT_FOR_2048
6674
#define KFR_AUTOSORT_FOR_128D
6775
#define KFR_AUTOSORT_FOR_256D
6876
#define KFR_AUTOSORT_FOR_512
6977
#define KFR_AUTOSORT_FOR_1024
70-
#define KFR_AUTOSORT_FOR_2048
78+
#endif
7179

7280
#ifdef CMT_ARCH_AVX
7381
template <>
@@ -855,7 +863,11 @@ template <typename T>
855863
struct fft_config
856864
{
857865
constexpr static inline const bool recursion = true;
858-
constexpr static inline const bool prefetch = true;
866+
#ifdef CMT_ARCH_NEON
867+
constexpr static inline const bool prefetch = false;
868+
#else
869+
constexpr static inline const bool prefetch = true;
870+
#endif
859871
constexpr static inline const size_t process_width =
860872
const_max(static_cast<size_t>(1), vector_capacity<T> / 16);
861873
};
@@ -1606,7 +1618,7 @@ struct fft_specialization<T, 10> : fft_final_stage_impl<T, false, 1024>
16061618
{
16071619
fft_final_stage_impl<T, false, 1024>::template do_execute<inverse>(out, in, nullptr);
16081620
if (this->need_reorder)
1609-
fft_reorder(out, 10, cfalse);
1621+
fft_reorder(out, csize_t<10>{}, cbool_t<always_br2>{});
16101622
}
16111623
};
16121624
#endif
@@ -1649,8 +1661,6 @@ struct fft_specialization<T, 11> : dft_stage<T>
16491661
radix8_autosort_pass_last(256, csize<width>, no, no, no, cbool<inverse>, out, out, tw);
16501662
}
16511663
};
1652-
1653-
#else
16541664
#endif
16551665

16561666
template <bool is_even, bool first, typename T, bool autosort>
@@ -1768,7 +1778,13 @@ KFR_INTRINSIC void init_fft(dft_plan<T>* self, size_t size, dft_order)
17681778
{
17691779
const size_t log2n = ilog2(size);
17701780
cswitch(
1771-
csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11>(), log2n,
1781+
csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
1782+
#ifdef KFR_AUTOSORT_FOR_2048
1783+
,
1784+
11
1785+
#endif
1786+
>(),
1787+
log2n,
17721788
[&](auto log2n)
17731789
{
17741790
(void)log2n;

Diff for: ‎tests/dft_test.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ constexpr ctypes_t<float, double> dft_float_types{};
3333
constexpr ctypes_t<float> dft_float_types{};
3434
#endif
3535

36-
#if defined(CMT_ARCH_X86) && !defined(KFR_NO_PERF_TESTS)
36+
#if !defined(KFR_NO_PERF_TESTS)
3737

3838
static void full_barrier()
3939
{
40-
#ifdef CMT_COMPILER_GNU
40+
#if defined(CMT_ARCH_NEON)
41+
asm volatile("dmb ish" ::: "memory");
42+
#elif defined(CMT_COMPILER_GNU)
4143
asm volatile("mfence" ::: "memory");
4244
#else
4345
_ReadWriteBarrier();
@@ -235,7 +237,7 @@ TEST(fft_accuracy)
235237

236238
if (is_even(size))
237239
{
238-
index_t csize = dft_plan_real<float_type>::complex_size_for(size, dft_pack_format::CCs);
240+
index_t csize = dft_plan_real<float_type>::complex_size_for(size, dft_pack_format::CCs);
239241
univector<float_type> in = truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
240242

241243
univector<complex<float_type>> out = truncate(dimensions<1>(scalar(qnan)), csize);

0 commit comments

Comments
 (0)