diff --git a/mlkem/poly.c b/mlkem/poly.c index 611f2c7d4..4ac241951 100644 --- a/mlkem/poly.c +++ b/mlkem/poly.c @@ -79,21 +79,23 @@ __contract__( ensures(return_value > -MLKEM_Q_HALF && return_value < MLKEM_Q_HALF) ) { - /* - * To divide by MLKEM_Q using Barrett multiplication, the "magic number" - * multiplier is round_to_nearest(2**26/MLKEM_Q) + /* Barrett reduction approximates + * ``` + * round(a/MLKEM_Q) + * = round(a*(2^N/MLKEM_Q))/2^N) + * ~= round(a*round(2^N/MLKEM_Q)/2^N) + * ``` + * Here, we pick N=26. */ - const int BPOWER = 26; - const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + const int32_t magic = 20159; /* check-magic: 20159 == round(2^26 / MLKEM_Q) */ /* - * Compute round_to_nearest(a/MLKEM_Q) using the multiplier - * above and shift by BPOWER places. - * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, - * implementation-defined for negative left argument. Here, - * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + * PORTABILITY: Right-shift on a signed integer is + * implementation-defined for negative left argument. + * Here, we assume it's sign-preserving "arithmetic" shift right. + * See (C99 6.5.7 (5)) */ - const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + const int32_t t = (magic * a + (1 << 25)) >> 26; /* * t is in -10 .. +10, so we need 32-bit math to