Skip to content

Commit 46d2b0a

Browse files
committed
[ec] Rewrite scalar_mult_base in C
For performance. This implies the need to get generator points from C as well. The pre-computed tables are stored in static memory, and computed lazily.
1 parent 9f767a9 commit 46d2b0a

7 files changed

+241
-58
lines changed

ec/mirage_crypto_ec.ml

+28-58
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ module type Dsa = sig
6262
module K_gen (H : Mirage_crypto.Hash.S) : sig
6363
val generate : key:priv -> Cstruct.t -> Cstruct.t
6464
end
65+
val force_precomputation : unit -> unit
6566
end
6667

6768
module type Dh_dsa = sig
@@ -108,6 +109,8 @@ module type Foreign = sig
108109

109110
val double_c : out_point -> point -> unit
110111
val add_c : out_point -> point -> point -> unit
112+
val scalar_mult_base_c : out_point -> string -> unit
113+
val force_precomputation_c : unit -> unit
111114
end
112115

113116
module type Field_element = sig
@@ -125,6 +128,7 @@ module type Field_element = sig
125128
val to_octets : field_element -> string
126129
val double_point : point -> point
127130
val add_point : point -> point -> point
131+
val scalar_mult_base_point : scalar -> point
128132
end
129133

130134
module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struct
@@ -213,6 +217,11 @@ module Make_field_element (P : Parameters) (F : Foreign) : Field_element = struc
213217
let tmp = out_point () in
214218
F.add_c tmp a b;
215219
out_p_to_p tmp
220+
221+
let scalar_mult_base_point (Scalar d) =
222+
let tmp = out_point () in
223+
F.scalar_mult_base_c tmp d;
224+
out_p_to_p tmp
216225
end
217226

218227
module type Point = sig
@@ -226,6 +235,8 @@ module type Point = sig
226235
val x_of_finite_point : point -> string
227236
val params_g : point
228237
val select : bool -> then_:point -> else_:point -> point
238+
val scalar_mult_base : scalar -> point
239+
val force_precomputation : unit -> unit
229240
end
230241

231242
module Make_point (P : Parameters) (F : Foreign) : Point = struct
@@ -406,6 +417,9 @@ module Make_point (P : Parameters) (F : Foreign) : Point = struct
406417
of_octets buf
407418
| 0x00 | 0x04 -> Error `Invalid_length
408419
| _ -> Error `Invalid_format
420+
421+
let scalar_mult_base = Fe.scalar_mult_base_point
422+
let force_precomputation = F.force_precomputation_c
409423
end
410424

411425
module type Scalar = sig
@@ -414,8 +428,8 @@ module type Scalar = sig
414428
val of_octets : string -> (scalar, error) result
415429
val to_octets : scalar -> string
416430
val scalar_mult : scalar -> point -> point
417-
418431
val scalar_mult_base : scalar -> point
432+
val force_precomputation : unit -> unit
419433
end
420434

421435
module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
@@ -435,62 +449,6 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
435449

436450
let to_octets (Scalar buf) = rev_string buf
437451

438-
(* Use a sliding window optimization method for scalar multiplication
439-
Hard-coded window size = 4
440-
Implementation inspired from Go's crypto library
441-
https://github.com/golang/go/blob/a5cd894318677359f6d07ee74f9004d28b4d164c/src/crypto/internal/nistec/p256.go#L317
442-
*)
443-
module Precomputed = struct
444-
(* Pre-compute multiples of the generator point *)
445-
let pre_compute_multiples () =
446-
let len = Param.fe_length * 2 in
447-
let one_table _ = Array.init 15 (fun _ -> P.at_infinity ()) in
448-
let table = Array.init len one_table in
449-
let base = ref P.params_g in
450-
for i = 0 to len - 1 do
451-
table.(i).(0) <- !base;
452-
for j = 1 to 14 do
453-
table.(i).(j) <- P.add !base table.(i).(j - 1)
454-
done;
455-
base := P.double !base;
456-
base := P.double !base;
457-
base := P.double !base;
458-
base := P.double !base
459-
done;
460-
table
461-
462-
(* Select the n-th element of the table
463-
without leaking information about [n] *)
464-
let table_select table n =
465-
let p = ref (P.at_infinity ()) in
466-
for i = 1 to 15 do
467-
let cond = not (Eqaf.bool_of_int (n - i)) in
468-
p := P.select cond ~then_:table.(i - 1) ~else_:!p
469-
done;
470-
!p
471-
472-
(* Returns [kG] by decomposing [k] in binary form, and adding
473-
[2^0G * k_0 + 2^1G * k_1 + ...] in constant time using
474-
pre-computed values of 2^iG *)
475-
let scalar_mult_base (Scalar k) tables =
476-
let p = ref (P.at_infinity ()) in
477-
let index = ref 0 in (* Index increases since k is big-endian *)
478-
for i = 0 to String.length k - 1 do
479-
let byte = String.get_uint8 k i in
480-
let winValue = byte land 0b1111 in
481-
p := P.add !p (table_select tables.(!index) winValue);
482-
incr index;
483-
let winValue = byte lsr 4 in
484-
p := P.add !p (table_select tables.(!index) winValue);
485-
incr index
486-
done;
487-
!p
488-
489-
let scalar_mult_base =
490-
let tables = pre_compute_multiples () in
491-
fun d -> scalar_mult_base d tables
492-
end
493-
494452
(* Branchless Montgomery ladder method *)
495453
let scalar_mult (Scalar s) p =
496454
let r0 = ref (P.at_infinity ()) in
@@ -506,7 +464,9 @@ module Make_scalar (Param : Parameters) (P : Point) : Scalar = struct
506464
!r0
507465

508466
(* Specialization of [scalar_mult d p] when [p] is the generator *)
509-
let scalar_mult_base = Precomputed.scalar_mult_base
467+
let scalar_mult_base = P.scalar_mult_base
468+
469+
let force_precomputation = P.force_precomputation
510470
end
511471

512472
module Make_dh (Param : Parameters) (P : Point) (S : Scalar) : Dh = struct
@@ -818,6 +778,8 @@ module Make_dsa (Param : Parameters) (F : Fn) (P : Point) (S : Scalar) (H : Mira
818778

819779
let verify ~key (r, s) digest =
820780
verify_octets ~key (Cstruct.to_string r, Cstruct.to_string s) (Cstruct.to_string digest)
781+
782+
let force_precomputation = S.force_precomputation
821783
end
822784

823785
module P224 : Dh_dsa = struct
@@ -849,6 +811,8 @@ module P224 : Dh_dsa = struct
849811
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p224_select" [@@noalloc]
850812
external double_c : out_point -> point -> unit = "mc_p224_point_double" [@@noalloc]
851813
external add_c : out_point -> point -> point -> unit = "mc_p224_point_add" [@@noalloc]
814+
external scalar_mult_base_c : out_point -> string -> unit = "mc_p224_scalar_mult_base" [@@noalloc]
815+
external force_precomputation_c : unit -> unit = "mc_p224_force_precomputation" [@@noalloc]
852816
end
853817

854818
module Foreign_n = struct
@@ -898,6 +862,8 @@ module P256 : Dh_dsa = struct
898862
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p256_select" [@@noalloc]
899863
external double_c : out_point -> point -> unit = "mc_p256_point_double" [@@noalloc]
900864
external add_c : out_point -> point -> point -> unit = "mc_p256_point_add" [@@noalloc]
865+
external scalar_mult_base_c : out_point -> string -> unit = "mc_p256_scalar_mult_base" [@@noalloc]
866+
external force_precomputation_c : unit -> unit = "mc_p256_force_precomputation" [@@noalloc]
901867
end
902868

903869
module Foreign_n = struct
@@ -948,6 +914,8 @@ module P384 : Dh_dsa = struct
948914
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p384_select" [@@noalloc]
949915
external double_c : out_point -> point -> unit = "mc_p384_point_double" [@@noalloc]
950916
external add_c : out_point -> point -> point -> unit = "mc_p384_point_add" [@@noalloc]
917+
external scalar_mult_base_c : out_point -> string -> unit = "mc_p384_scalar_mult_base" [@@noalloc]
918+
external force_precomputation_c : unit -> unit = "mc_p384_force_precomputation" [@@noalloc]
951919
end
952920

953921
module Foreign_n = struct
@@ -999,6 +967,8 @@ module P521 : Dh_dsa = struct
999967
external select_c : out_field_element -> bool -> field_element -> field_element -> unit = "mc_p521_select" [@@noalloc]
1000968
external double_c : out_point -> point -> unit = "mc_p521_point_double" [@@noalloc]
1001969
external add_c : out_point -> point -> point -> unit = "mc_p521_point_add" [@@noalloc]
970+
external scalar_mult_base_c : out_point -> string -> unit = "mc_p521_scalar_mult_base" [@@noalloc]
971+
external force_precomputation_c : unit -> unit = "mc_p521_force_precomputation" [@@noalloc]
1002972
end
1003973

1004974
module Foreign_n = struct

ec/mirage_crypto_ec.mli

+2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ module type Dsa = sig
129129
(** [generate ~key digest] deterministically takes the given private key
130130
and message digest to a [k] suitable for seeding the signing process. *)
131131
end
132+
133+
val force_precomputation : unit -> unit
132134
end
133135

134136
(** Elliptic curve with Diffie-Hellman and DSA. *)

ec/native/p224_stubs.c

+23
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
#define LEN_PRIME 224
1616
#define CURVE_DESCRIPTION fiat_p224
1717

18+
#define FE_LENGTH 28
19+
20+
// Generator point, see https://neuromancer.sk/std/nist/P-224
21+
static uint8_t gb_x[FE_LENGTH] = {0xb7, 0xe, 0xc, 0xbd, 0x6b, 0xb4, 0xbf, 0x7f, 0x32, 0x13, 0x90, 0xb9, 0x4a, 0x3, 0xc1, 0xd3, 0x56, 0xc2, 0x11, 0x22, 0x34, 0x32, 0x80, 0xd6, 0x11, 0x5c, 0x1d, 0x21};
22+
static uint8_t gb_y[FE_LENGTH] = {0xbd, 0x37, 0x63, 0x88, 0xb5, 0xf7, 0x23, 0xfb, 0x4c, 0x22, 0xdf, 0xe6, 0xcd, 0x43, 0x75, 0xa0, 0x5a, 0x7, 0x47, 0x64, 0x44, 0xd5, 0x81, 0x99, 0x85, 0x0, 0x7e, 0x34};
23+
1824
#include "inversion_template.h"
1925
#include "point_operations.h"
2026

@@ -139,3 +145,20 @@ CAMLprim value mc_p224_select(value out, value bit, value t, value f)
139145
);
140146
CAMLreturn(Val_unit);
141147
}
148+
149+
CAMLprim value mc_p224_scalar_mult_base(value out, value s)
150+
{
151+
CAMLparam2(out, s);
152+
scalar_mult_base(
153+
(WORD *) Bytes_val(Field(out, 0)),
154+
(WORD *) Bytes_val(Field(out, 1)),
155+
(WORD *) Bytes_val(Field(out, 2)),
156+
(uint8_t *) String_val(s),
157+
caml_string_length(s)
158+
);
159+
CAMLreturn(Val_unit);
160+
}
161+
162+
CAMLprim void mc_p224_force_precomputation(void) {
163+
compute_generator_table();
164+
}

ec/native/p256_stubs.c

+24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
#define LEN_PRIME 256
1616
#define CURVE_DESCRIPTION fiat_p256
1717

18+
#define FE_LENGTH 32
19+
20+
// Generator point, see https://neuromancer.sk/std/nist/P-256
21+
static uint8_t gb_x[FE_LENGTH] = {0x6b, 0x17, 0xd1, 0xf2, 0xe1, 0x2c, 0x42, 0x47, 0xf8, 0xbc, 0xe6, 0xe5, 0x63, 0xa4, 0x40, 0xf2, 0x77, 0x3, 0x7d, 0x81, 0x2d, 0xeb, 0x33, 0xa0, 0xf4, 0xa1, 0x39, 0x45, 0xd8, 0x98, 0xc2, 0x96};
22+
static uint8_t gb_y[FE_LENGTH] = {0x4f, 0xe3, 0x42, 0xe2, 0xfe, 0x1a, 0x7f, 0x9b, 0x8e, 0xe7, 0xeb, 0x4a, 0x7c, 0xf, 0x9e, 0x16, 0x2b, 0xce, 0x33, 0x57, 0x6b, 0x31, 0x5e, 0xce, 0xcb, 0xb6, 0x40, 0x68, 0x37, 0xbf, 0x51, 0xf5};
23+
1824
#include "inversion_template.h"
1925
#include "point_operations.h"
2026

@@ -139,3 +145,21 @@ CAMLprim value mc_p256_select(value out, value bit, value t, value f)
139145
);
140146
CAMLreturn(Val_unit);
141147
}
148+
149+
150+
CAMLprim value mc_p256_scalar_mult_base(value out, value s)
151+
{
152+
CAMLparam2(out, s);
153+
scalar_mult_base(
154+
(WORD *) Bytes_val(Field(out, 0)),
155+
(WORD *) Bytes_val(Field(out, 1)),
156+
(WORD *) Bytes_val(Field(out, 2)),
157+
(uint8_t *) String_val(s),
158+
caml_string_length(s)
159+
);
160+
CAMLreturn(Val_unit);
161+
}
162+
163+
CAMLprim void mc_p256_force_precomputation(void) {
164+
compute_generator_table();
165+
}

ec/native/p384_stubs.c

+23
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
#define LEN_PRIME 384
1616
#define CURVE_DESCRIPTION fiat_p384
1717

18+
#define FE_LENGTH 48
19+
20+
// Generator point, see https://neuromancer.sk/std/nist/P-384
21+
static uint8_t gb_x[FE_LENGTH] = {0xaa, 0x87, 0xca, 0x22, 0xbe, 0x8b, 0x5, 0x37, 0x8e, 0xb1, 0xc7, 0x1e, 0xf3, 0x20, 0xad, 0x74, 0x6e, 0x1d, 0x3b, 0x62, 0x8b, 0xa7, 0x9b, 0x98, 0x59, 0xf7, 0x41, 0xe0, 0x82, 0x54, 0x2a, 0x38, 0x55, 0x2, 0xf2, 0x5d, 0xbf, 0x55, 0x29, 0x6c, 0x3a, 0x54, 0x5e, 0x38, 0x72, 0x76, 0xa, 0xb7};
22+
static uint8_t gb_y[FE_LENGTH] = {0x36, 0x17, 0xde, 0x4a, 0x96, 0x26, 0x2c, 0x6f, 0x5d, 0x9e, 0x98, 0xbf, 0x92, 0x92, 0xdc, 0x29, 0xf8, 0xf4, 0x1d, 0xbd, 0x28, 0x9a, 0x14, 0x7c, 0xe9, 0xda, 0x31, 0x13, 0xb5, 0xf0, 0xb8, 0xc0, 0xa, 0x60, 0xb1, 0xce, 0x1d, 0x7e, 0x81, 0x9d, 0x7a, 0x43, 0x1d, 0x7c, 0x90, 0xea, 0xe, 0x5f};
23+
1824
#include "inversion_template.h"
1925
#include "point_operations.h"
2026

@@ -139,3 +145,20 @@ CAMLprim value mc_p384_select(value out, value bit, value t, value f)
139145
);
140146
CAMLreturn(Val_unit);
141147
}
148+
149+
CAMLprim value mc_p384_scalar_mult_base(value out, value s)
150+
{
151+
CAMLparam2(out, s);
152+
scalar_mult_base(
153+
(WORD *) Bytes_val(Field(out, 0)),
154+
(WORD *) Bytes_val(Field(out, 1)),
155+
(WORD *) Bytes_val(Field(out, 2)),
156+
(uint8_t *) String_val(s),
157+
caml_string_length(s)
158+
);
159+
CAMLreturn(Val_unit);
160+
}
161+
162+
CAMLprim void mc_p384_force_precomputation(void) {
163+
compute_generator_table();
164+
}

ec/native/p521_stubs.c

+23
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515
#define LEN_PRIME 521
1616
#define CURVE_DESCRIPTION fiat_p521
1717

18+
#define FE_LENGTH 66
19+
20+
// Generator point, see https://neuromancer.sk/std/nist/P-521
21+
static uint8_t gb_x[FE_LENGTH] = {0x0, 0xc6, 0x85, 0x8e, 0x6, 0xb7, 0x4, 0x4, 0xe9, 0xcd, 0x9e, 0x3e, 0xcb, 0x66, 0x23, 0x95, 0xb4, 0x42, 0x9c, 0x64, 0x81, 0x39, 0x5, 0x3f, 0xb5, 0x21, 0xf8, 0x28, 0xaf, 0x60, 0x6b, 0x4d, 0x3d, 0xba, 0xa1, 0x4b, 0x5e, 0x77, 0xef, 0xe7, 0x59, 0x28, 0xfe, 0x1d, 0xc1, 0x27, 0xa2, 0xff, 0xa8, 0xde, 0x33, 0x48, 0xb3, 0xc1, 0x85, 0x6a, 0x42, 0x9b, 0xf9, 0x7e, 0x7e, 0x31, 0xc2, 0xe5, 0xbd, 0x66};
22+
static uint8_t gb_y[FE_LENGTH] = {0x1, 0x18, 0x39, 0x29, 0x6a, 0x78, 0x9a, 0x3b, 0xc0, 0x4, 0x5c, 0x8a, 0x5f, 0xb4, 0x2c, 0x7d, 0x1b, 0xd9, 0x98, 0xf5, 0x44, 0x49, 0x57, 0x9b, 0x44, 0x68, 0x17, 0xaf, 0xbd, 0x17, 0x27, 0x3e, 0x66, 0x2c, 0x97, 0xee, 0x72, 0x99, 0x5e, 0xf4, 0x26, 0x40, 0xc5, 0x50, 0xb9, 0x1, 0x3f, 0xad, 0x7, 0x61, 0x35, 0x3c, 0x70, 0x86, 0xa2, 0x72, 0xc2, 0x40, 0x88, 0xbe, 0x94, 0x76, 0x9f, 0xd1, 0x66, 0x50};
23+
1824
#include "inversion_template.h"
1925
#include "point_operations.h"
2026

@@ -139,3 +145,20 @@ CAMLprim value mc_p521_select(value out, value bit, value t, value f)
139145
);
140146
CAMLreturn(Val_unit);
141147
}
148+
149+
CAMLprim value mc_p521_scalar_mult_base(value out, value s)
150+
{
151+
CAMLparam2(out, s);
152+
scalar_mult_base(
153+
(WORD *) Bytes_val(Field(out, 0)),
154+
(WORD *) Bytes_val(Field(out, 1)),
155+
(WORD *) Bytes_val(Field(out, 2)),
156+
(uint8_t *) String_val(s),
157+
caml_string_length(s)
158+
);
159+
CAMLreturn(Val_unit);
160+
}
161+
162+
CAMLprim void mc_p521_force_precomputation(void) {
163+
compute_generator_table();
164+
}

0 commit comments

Comments
 (0)