Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document bounds and remove redundant reductions #208

Merged
merged 17 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,18 @@ jobs:
runs-on: ${{ matrix.target.runner }}
steps:
- uses: actions/checkout@v4
- name: native tests
- name: native build
uses: ./.github/actions/multi-functest
with:
compile_mode: native
func: false
nistkat: false
kat: falst
- name: native tests (+debug)
uses: ./.github/actions/multi-functest
with:
compile_mode: native
cflags: "-DMLKEM_DEBUG"
- name: cross tests (opt only)
if: ${{ matrix.target.runner == 'pqcp-arm64' && (success() || failure()) }}
uses: ./.github/actions/multi-functest
Expand Down
2 changes: 1 addition & 1 deletion mk/schemes.mk
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
SOURCES = $(wildcard mlkem/*.c)
SOURCES = $(wildcard mlkem/*.c) $(wildcard mlkem/debug/*.c)
ifeq ($(OPT),1)
SOURCES += $(wildcard mlkem/native/aarch64/*.[csS]) $(wildcard mlkem/native/x86_64/*.[csS])
CPPFLAGS += -DMLKEM_USE_NATIVE
Expand Down
39 changes: 39 additions & 0 deletions mlkem/debug/debug.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-License-Identifier: Apache-2.0
#include "debug.h"

#if defined(MLKEM_DEBUG)

static char debug_buf[256];

void mlkem_debug_check_bounds(const char *file, int line,
const char *description, const int16_t *ptr,
unsigned len, int16_t lower_bound_inclusive,
int16_t upper_bound_inclusive) {
int err = 0;
unsigned i;
for (i = 0; i < len; i++) {
int16_t val = ptr[i];
if (val < lower_bound_inclusive || val > upper_bound_inclusive) {
snprintf(debug_buf, sizeof(debug_buf),
"%s, index %u, value %d out of bounds (%d,%d)", description, i,
(int)val, (int)lower_bound_inclusive,
(int)upper_bound_inclusive);
mlkem_debug_print_error(file, line, debug_buf);
err = 1;
}
}

if (err == 1)
exit(1);
}

void mlkem_debug_print_error(const char *file, int line, const char *msg) {
fprintf(stderr, "[ERROR:%s:%04d] %s\n", file, line, msg);
fflush(stderr);
}

#else /* MLKEM_DEBUG */

int empty_cu_debug;

#endif /* MLKEM_DEBUG */
105 changes: 105 additions & 0 deletions mlkem/debug/debug.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// SPDX-License-Identifier: Apache-2.0
#ifndef MLKEM_DEBUG_H
#define MLKEM_DEBUG_H

#if defined(MLKEM_DEBUG)
#include <stdio.h>
#include <stdlib.h>

/*************************************************
* Name: mlkem_debug_check_bounds
*
* Description: Check whether values in an array of int16_t
* are within specified bounds.
*
* Prints an error message to stderr and calls
* exit(1) if not.
*
* Arguments: - file: filename
* - line: line number
* - description: Textual description of check
* - ptr: Base of array to be checked
* - len: Number of int16_t in ptr
* - lower_bound_inclusive: Inclusive lower bound
* - upper_bound_inclusive: Inclusive upper bound
**************************************************/
void mlkem_debug_check_bounds(const char *file, int line,
const char *description, const int16_t *ptr,
unsigned len, int16_t lower_bound_inclusive,
int16_t upper_bound_inclusive);

/* Print error message to stderr alongside file and line information */
void mlkem_debug_print_error(const char *file, int line, const char *msg);

/* Check absolute bounds in array of int16_t's
* ptr: Base of array, expression of type int16_t*
* len: Number of int16_t in array
* abs_bound: Exclusive upper bound on absolute value to check
* msg: Message to print on failure */
#define BOUND(ptr, len, abs_bound, msg) \
do { \
mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \
(len), -((abs_bound)-1), ((abs_bound)-1)); \
} while (0)

/* Check absolute bounds on coefficients in polynomial or mulcache
* ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check
* abs_bound: Exclusive upper bound on absolute value to check
* msg: Message to print on failure */
#define POLY_BOUND_MSG(ptr, abs_bound, msg) \
BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \
msg)

/* Check absolute bounds on coefficients in polynomial
* ptr: poly* of poly_mulcache* pointer to polynomial (cache) to check
* abs_bound: Exclusive upper bound on absolute value to check */
#define POLY_BOUND(ptr, abs_bound) \
POLY_BOUND_MSG((ptr), (abs_bound), "poly bound for " #ptr)

/* Check absolute bounds on coefficients in vector of polynomials
* ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check
* abs_bound: Exclusive upper bound on absolute value to check */
#define POLYVEC_BOUND(ptr, abs_bound) \
do { \
for (unsigned _debug_polyvec_bound_idx = 0; \
_debug_polyvec_bound_idx < KYBER_K; _debug_polyvec_bound_idx++) \
POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \
"polyvec bound for " #ptr ".vec[i]"); \
} while (0)

// Following AWS-LC to define a C99-compliant static assert
#define MLKEM_CONCAT(left, right) left##right
#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \
typedef struct { \
unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \
} MLKEM_CONCAT(static_assertion_, msg) __attribute__((unused));

#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \
MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix))
#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix))
#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix)
#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix))
#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error)

#else /* MLKEM_DEBUG */

#define BOUND(...) \
do { \
} while (0)
#define POLY_BOUND(...) \
do { \
} while (0)
#define POLYVEC_BOUND(...) \
do { \
} while (0)
#define POLY_BOUND_MSG(...) \
do { \
} while (0)
#define STATIC_ASSERT(...)

#endif /* MLKEM_DEBUG */

#endif /* MLKEM_DEBUG_H */
26 changes: 26 additions & 0 deletions mlkem/indcpa.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "symmetric.h"

#include "arith_native.h"
#include "debug/debug.h"

/*************************************************
* Name: pack_pk
Expand All @@ -29,6 +30,7 @@
**************************************************/
static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], polyvec *pk,
const uint8_t seed[KYBER_SYMBYTES]) {
POLYVEC_BOUND(pk, KYBER_Q);
polyvec_tobytes(r, pk);
memcpy(r + KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES);
}
Expand All @@ -48,6 +50,11 @@ static void unpack_pk(polyvec *pk, uint8_t seed[KYBER_SYMBYTES],
const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) {
polyvec_frombytes(pk, packedpk);
memcpy(seed, packedpk + KYBER_POLYVECBYTES, KYBER_SYMBYTES);

// TODO! pk must be subject to a "modulus check" at the top-level
// crypto_kem_enc_derand(). Once that's done, the reduction is no
// longer necessary here.
polyvec_reduce(pk);
}

/*************************************************
Expand All @@ -60,6 +67,7 @@ static void unpack_pk(polyvec *pk, uint8_t seed[KYBER_SYMBYTES],
*key)
**************************************************/
static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) {
POLYVEC_BOUND(sk, KYBER_Q);
polyvec_tobytes(r, sk);
}

Expand All @@ -76,6 +84,7 @@ static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) {
static void unpack_sk(polyvec *sk,
const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) {
polyvec_frombytes(sk, packedsk);
polyvec_reduce(sk);
}

/*************************************************
Expand Down Expand Up @@ -245,6 +254,9 @@ void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES],
* - const uint8_t *coins: pointer to input randomness
* (of length KYBER_SYMBYTES bytes)
**************************************************/

STATIC_ASSERT(NTT_BOUND + KYBER_Q < INT16_MAX, indcpa_enc_bound_0)

void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES],
const uint8_t coins[KYBER_SYMBYTES]) {
Expand Down Expand Up @@ -289,8 +301,10 @@ void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
poly_tomont(&pkpv.vec[i]);
}

// Arithmetic cannot overflow, see static assertion at the top
polyvec_add(&pkpv, &pkpv, &e);
polyvec_reduce(&pkpv);
polyvec_reduce(&skpv);

pack_sk(sk, &skpv);
pack_pk(pk, &pkpv, publicseed);
Expand All @@ -311,6 +325,12 @@ void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
* - const uint8_t *coins: pointer to input random coins used as
*seed (of length KYBER_SYMBYTES) to deterministically generate all randomness
**************************************************/

// Check that the arithmetic in indcpa_enc() does not overflow
STATIC_ASSERT(INVNTT_BOUND + KYBER_ETA1 < INT16_MAX, indcpa_enc_bound_0)
STATIC_ASSERT(INVNTT_BOUND + KYBER_ETA2 + KYBER_Q < INT16_MAX,
indcpa_enc_bound_1)

void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
const uint8_t m[KYBER_INDCPA_MSGBYTES],
const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
Expand Down Expand Up @@ -355,6 +375,7 @@ void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
polyvec_invntt_tomont(&b);
poly_invntt_tomont(&v);

// Arithmetic cannot overflow, see static assertion at the top
polyvec_add(&b, &b, &ep);
poly_add(&v, &v, &epp);
poly_add(&v, &v, &k);
Expand All @@ -377,6 +398,10 @@ void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
* - const uint8_t *sk: pointer to input secret key
* (of length KYBER_INDCPA_SECRETKEYBYTES)
**************************************************/

// Check that the arithmetic in indcpa_dec() does not overflow
STATIC_ASSERT(INVNTT_BOUND + KYBER_Q < INT16_MAX, indcpa_dec_bound_0)

void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES],
const uint8_t c[KYBER_INDCPA_BYTES],
const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES]) {
Expand All @@ -390,6 +415,7 @@ void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES],
polyvec_basemul_acc_montgomery(&mp, &skpv, &b);
poly_invntt_tomont(&mp);

// Arithmetic cannot overflow, see static assertion at the top
poly_sub(&mp, &v, &mp);
poly_reduce(&mp);

Expand Down
5 changes: 2 additions & 3 deletions mlkem/native/aarch64/intt_123_45_67_twiddles.S
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,8 @@ roots_l34:
.short 0
.short 0
roots_l012:
// layer 0 root modified to include ninv
.short 266 // originally: 1600
.short 2618 // originally: 15749
.short 1600
.short 15749
.short 40
.short 394
.short 749
Expand Down
Loading