Skip to content

Commit

Permalink
Sampling: Change fallback logic to native implementation
Browse files Browse the repository at this point in the history
From a perspective of the native backend, the rejection sampling
routine `rej_uniform` is special in that it is not expected to
be replaced by native code in its entirety, but only in special cases.
Outside of those special cases, the default C implementation is used.

Previously, this fallback logic was implemented as follows:
First, the native backend would be called. Upon success, the function
would return immediately. Otherwise, it would fall back to the default
implementation. Success/Failure would be communicated through a special
return value -1.

There are two problems with this logic:
- It appears very difficult to reason about in CBMC: Specifically, when
  we call the native backend, we shift the input buffer by the amount of
  data that has already been successfully sampled, and CBMC struggles
  reasoning about that.
- We call the native backend with a potentially unaligned buffer, which
  seems unnatural. This is not an issue for the existing backends,
  because their bounds checks ensure that the native implementation only
  takes effect at the beginning of rejection sampling, where the output
  buffer is aligned; yet, alignment is not guaranteed in general.

This commit simplifies the fallback logic by invoking the native backend
only upon the first call to `rej_uniform`, when no coefficients have yet
been sampled. This solves both issues above.

Signed-off-by: Hanno Becker <beckphan@amazon.co.uk>
  • Loading branch information
hanno-becker committed Jan 23, 2025
1 parent 47c3fa6 commit 1b2db0c
Showing 1 changed file with 15 additions and 22 deletions.
37 changes: 15 additions & 22 deletions mlkem/sampling.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ __contract__(
requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0)
requires(memory_no_alias(r, sizeof(int16_t) * target))
requires(memory_no_alias(buf, buflen))
requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q))
requires(array_bound(r, 0, offset, 0, MLKEM_Q))
assigns(memory_slice(r, sizeof(int16_t) * target))
ensures(offset <= return_value && return_value <= target)
ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q))
ensures(array_bound(r, 0, return_value, 0, MLKEM_Q))
)
{
unsigned int ctr, pos;
Expand Down Expand Up @@ -66,7 +66,6 @@ __contract__(
return ctr;
}

#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM)
/*************************************************
* Name: rej_uniform
*
Expand All @@ -87,7 +86,7 @@ __contract__(
* Must be a multiple of 3.
*
* Note: Strictly speaking, only a few values of buflen near UINT_MAX need
* excluding. The limit of 4096 is somewhat arbitary but sufficient for all
* excluding. The limit of 128 is somewhat arbitary but sufficient for all
* uses of this function. Similarly, the actual limit for target is UINT_MAX/2.
*
* Returns the new offset of sampled 16-bit integers, at most target,
Expand All @@ -110,33 +109,27 @@ __contract__(
requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0)
requires(memory_no_alias(r, sizeof(int16_t) * target))
requires(memory_no_alias(buf, buflen))
requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q))
requires(array_bound(r, 0, offset, 0, MLKEM_Q))
assigns(memory_slice(r, sizeof(int16_t) * target))
ensures(offset <= return_value && return_value <= target)
ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q))
ensures(array_bound(r, 0, return_value, 0, MLKEM_Q))
)
{
return rej_uniform_scalar(r, target, offset, buf, buflen);
}
#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */
static unsigned int rej_uniform(int16_t *r, unsigned int target,
unsigned int offset, const uint8_t *buf,
unsigned int buflen)
{
int ret;

/* Sample from large buffer with full lane as much as possible. */
ret = rej_uniform_native(r + offset, target - offset, buf, buflen);
if (ret != -1)
#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM)
if (offset == 0)
{
unsigned res = offset + (unsigned)ret;
debug_assert_bound(r, res, 0, MLKEM_Q);
return res;
int ret = rej_uniform_native(r, target, buf, buflen);
if (ret != -1)
{
unsigned res = (unsigned)ret;
debug_assert_bound(r, res, 0, MLKEM_Q);
return res;
}
}
#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */

return rej_uniform_scalar(r, target, offset, buf, buflen);
}
#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */

#ifndef MLKEM_GEN_MATRIX_NBLOCKS
#define MLKEM_GEN_MATRIX_NBLOCKS \
Expand Down

0 comments on commit 1b2db0c

Please sign in to comment.