Skip to content

Commit

Permalink
Poly: Hardcode barrett multiplier
Browse files Browse the repository at this point in the history
In most places in the code, we use magic constants with
`check-magic: ...` annotations explaining their origin,
rather than computing the magic constant in the code
and relying on constant evaluation by the compiler.

One place where this was not done is `barrett_reduce()`,
where the 'magic' Barrett multiplier was computed in C.

This commit aligns `barrett_reduce()` to other routines
relying on magic constants, by hardcoding the magic value
in C, and adding a `check-magic: ...` annotation to explain
its origin.

Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
  • Loading branch information
hanno-becker committed Feb 25, 2025
1 parent e319849 commit 4450d88
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions mlkem/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,23 @@ __contract__(
ensures(return_value > -MLKEM_Q_HALF && return_value < MLKEM_Q_HALF)
)
{
/*
* To divide by MLKEM_Q using Barrett multiplication, the "magic number"
* multiplier is round_to_nearest(2**26/MLKEM_Q)
/* Barrett reduction approximates
* ```
* round(a/MLKEM_Q)
* = round(a*(2^N/MLKEM_Q))/2^N)
* ~= round(a*round(2^N/MLKEM_Q)/2^N)
* ```
* Here, we pick N=26.
*/
const int BPOWER = 26;
const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q;
const int32_t magic = 20159; /* check-magic: 20159 == round(2^26 / MLKEM_Q) */

/*
* Compute round_to_nearest(a/MLKEM_Q) using the multiplier
* above and shift by BPOWER places.
* PORTABILITY: Right-shift on a signed integer is, strictly-speaking,
* implementation-defined for negative left argument. Here,
* we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5))
* PORTABILITY: Right-shift on a signed integer is
* implementation-defined for negative left argument.
* Here, we assume it's sign-preserving "arithmetic" shift right.
* See (C99 6.5.7 (5))
*/
const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER;
const int32_t t = (magic * a + (1 << 25)) >> 26;

/*
* t is in -10 .. +10, so we need 32-bit math to
Expand Down

0 comments on commit 4450d88

Please sign in to comment.