From 4450d882140e42dcda571fb22b6a215906e351ac Mon Sep 17 00:00:00 2001 From: Hanno Becker Date: Tue, 25 Feb 2025 07:09:08 +0000 Subject: [PATCH] Poly: Hardcode barrett multiplier In most places in the code, we use magic constants with `check-magic: ...` annotations explaining their origin, rather than computing the magic constant in the code and relying on constant evaluation by the compiler. One place where this was not done is `barrett_reduce()`, where the 'magic' Barrett multiplier was computed in C. This commit aligns `barrett_reduce()` to other routines relying on magic constants, by hardcoding the magic value in C, and adding a `check-magic: ...` annotation to explain its origin. Signed-off-by: Hanno Becker --- mlkem/poly.c | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/mlkem/poly.c b/mlkem/poly.c index 611f2c7d4..4ac241951 100644 --- a/mlkem/poly.c +++ b/mlkem/poly.c @@ -79,21 +79,23 @@ __contract__( ensures(return_value > -MLKEM_Q_HALF && return_value < MLKEM_Q_HALF) ) { - /* - * To divide by MLKEM_Q using Barrett multiplication, the "magic number" - * multiplier is round_to_nearest(2**26/MLKEM_Q) + /* Barrett reduction approximates + * ``` + * round(a/MLKEM_Q) + * = round(a*(2^N/MLKEM_Q))/2^N) + * ~= round(a*round(2^N/MLKEM_Q)/2^N) + * ``` + * Here, we pick N=26. */ - const int BPOWER = 26; - const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + const int32_t magic = 20159; /* check-magic: 20159 == round(2^26 / MLKEM_Q) */ /* - * Compute round_to_nearest(a/MLKEM_Q) using the multiplier - * above and shift by BPOWER places. - * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, - * implementation-defined for negative left argument. Here, - * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + * PORTABILITY: Right-shift on a signed integer is + * implementation-defined for negative left argument. + * Here, we assume it's sign-preserving "arithmetic" shift right. + * See (C99 6.5.7 (5)) */ - const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + const int32_t t = (magic * a + (1 << 25)) >> 26; /* * t is in -10 .. +10, so we need 32-bit math to