Skip to content

Commit 59b0d07

Browse files
authored
faster avx512 exp implementation (ggml-org#7551)
* faster avx512 exp implementation * x->r * improve accuracy, handle special cases * remove `e`
1 parent d5c0582 commit 59b0d07

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

ggml.c

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,32 +2315,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
23152315
const __m512 r = _mm512_set1_ps(0x1.8p23f);
23162316
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
23172317
const __m512 n = _mm512_sub_ps(z, r);
2318-
const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2319-
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
2320-
const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
2321-
const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
2322-
const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
2323-
const __m512 u = _mm512_mul_ps(b, b);
2324-
const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2325-
_mm512_set1_ps(0x1.573e2ep-5f)), u,
2326-
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2327-
_mm512_set1_ps(0x1.fffdb6p-2f))),
2328-
u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
2329-
if (_mm512_kortestz(c, c))
2330-
return _mm512_fmadd_ps(j, k, k);
2331-
const __m512i g = _mm512_and_si512(
2332-
_mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
2333-
_mm512_set1_epi32(0x82000000u));
2334-
const __m512 s1 =
2335-
_mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
2336-
const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
2318+
const __m512 b =
2319+
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
2320+
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
23372321
const __mmask16 d =
23382322
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
2339-
return _mm512_mask_blend_ps(
2340-
d, _mm512_mask_blend_ps(
2341-
c, _mm512_fmadd_ps(k, j, k),
2342-
_mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
2343-
_mm512_mul_ps(s1, s1));
2323+
const __m512 u = _mm512_mul_ps(b, b);
2324+
const __m512 j = _mm512_fmadd_ps(
2325+
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
2326+
_mm512_set1_ps(0x1.573e2ep-5f)),
2327+
u,
2328+
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
2329+
_mm512_set1_ps(0x1.fffdb6p-2f))),
2330+
u,
2331+
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
2332+
const __m512 res = _mm512_scalef_ps(j, n);
2333+
if (_mm512_kortestz(d, d))
2334+
return res;
2335+
const __m512 zero = _mm512_setzero_ps();
2336+
const __m512 alt = _mm512_mask_blend_ps(
2337+
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
2338+
return _mm512_mask_blend_ps(d, res, alt);
23442339
}
23452340

23462341
// computes silu x/(1+exp(-x)) in single precision vector

0 commit comments

Comments
 (0)