Skip to content

Commit

Permalink
Masking cleanup (#53)
Browse files Browse the repository at this point in the history
* added masked pwm intt opcode
using pwm intt masked opcode in sequencer

* Enable masking for pwm_intt op

* Clean up

* Revert vf change

* Clean up

* Fix two share mult, clean up latency params

* Add input flops to shares, refresh shares before INTT, optimize twiddle reg, adjust delays

* Fix delay in pwm

* Remove w_delay, lint fixes

* MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/masking_cleanup' with updated timestamp and hash after successful run

* Parameterize delays

* MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/masking_cleanup' with updated timestamp and hash after successful run

* Use different random input for twiddle

* MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/masking_cleanup' with updated timestamp and hash after successful run

---------

Co-authored-by: Nitsirks <michnorris@microsoft.com>
  • Loading branch information
upadhyayulakiran and Nitsirks authored Dec 12, 2024
1 parent 31291aa commit c23f8f0
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 224 deletions.
2 changes: 1 addition & 1 deletion .github/workflow_metadata/pr_hash
Original file line number Diff line number Diff line change
@@ -1 +1 @@
148a6d381422de56ae26bc8c4288130b67b86f624ee2adb675b36c18e09bc5319f1cc53b9c3268c98892d594e9a28b44
4f57be3471046889f34027ace2f1c510ede448243a09622503b3ede22983243c960758c271e18642e9a08a28b76f30d3
2 changes: 1 addition & 1 deletion .github/workflow_metadata/pr_timestamp
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1733339945
1733959211
33 changes: 26 additions & 7 deletions src/abr_libs/rtl/abr_masked_N_bit_mult_two_share.sv
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
// - Final output is obtained by combining the reshared and masked intermediate results.
// - It requires fresh randomness.
// - This design assumes that both x and y are secret, although y input from top level is usually public
// - It has one cycle latency and can accept a new input set at every clock.
// - It has a two-cycle latency and can accept a new input set on every clock cycle.
//
//======================================================================

Expand All @@ -43,6 +43,7 @@

// Intermediate calculation logic for multiplication operations
logic [WIDTH-1:0] calculation [3:0];
logic [WIDTH-1:0] calculation_reg [1:0];
logic [WIDTH-1:0] calculation_rand [1:0];
logic [WIDTH-1:0] final_res [1:0];
logic [WIDTH-1:0] x0, x1, y0, y1;
Expand All @@ -53,12 +54,30 @@
calculation[1] = WIDTH'(x[1] * y[0]); // Multiplication of the second share x and first share y
calculation[2] = WIDTH'(x[0] * y[1]); // Multiplication of the first share x and second share y
calculation[3] = WIDTH'(x[1] * y[1]); // Multiplication of the second share x and second share y

calculation_rand[0] = calculation[2] + random;
calculation_rand[1] = calculation[1] - random;

final_res[0] = calculation[0] + calculation_rand[0];
final_res[1] = calculation[3] + calculation_rand[1];
end
// Pipeline register stage for the partial products.
// - calculation_reg holds the two same-index products (calculation[0], calculation[3]).
// - calculation_rand holds the two cross products, one with `random` added and the
//   other with `random` subtracted.
// All four registers are cleared by the asynchronous active-low reset and by the
// synchronous zeroize request.
always_ff @(posedge clk or negedge rst_n) begin
    if (!rst_n) begin
        // Asynchronous reset: clear every pipeline register
        calculation_reg[0]  <= 'h0;
        calculation_reg[1]  <= 'h0;
        calculation_rand[0] <= 'h0;
        calculation_rand[1] <= 'h0;
    end
    else if (zeroize) begin
        // Synchronous zeroize: clear the same set of registers
        calculation_reg[0]  <= 'h0;
        calculation_reg[1]  <= 'h0;
        calculation_rand[0] <= 'h0;
        calculation_rand[1] <= 'h0;
    end
    else begin
        // Same-index products are captured unchanged
        calculation_reg[0]  <= calculation[0];
        calculation_reg[1]  <= calculation[3];
        // Cross products are captured with +random / -random applied
        calculation_rand[0] <= calculation[2] + random;
        calculation_rand[1] <= calculation[1] - random;
    end
end
// Output recombination: each result share is the registered same-index product
// plus the registered cross product of the same index.
always_comb begin
    for (int i = 0; i < 2; i++) begin
        final_res[i] = calculation_reg[i] + calculation_rand[i];
    end
end

// Final output assignment
Expand Down
2 changes: 1 addition & 1 deletion src/mldsa_top/rtl/mldsa_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,7 @@ mldsa_seq_sec mldsa_seq_sec_inst
INTT_raw_signal <= 'h0;
end
else begin
if (seq_en) begin
if (sec_seq_en) begin
unique case(sec_prog_cntr_nxt)
MLDSA_SIGN_VALID_S : begin //NTT(C)
NTT_raw_signal <= 'h1;
Expand Down
21 changes: 8 additions & 13 deletions src/ntt_top/rtl/ntt_butterfly2x2.sv
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ module ntt_butterfly2x2
#(
parameter REG_SIZE = 23,
parameter MLDSA_Q = 23'd8380417,
parameter MLDSA_Q_DIV2_ODD = (MLDSA_Q + 1) / 2,
parameter BF_LATENCY = 10, //5 cycles per butterfly * 2 instances in serial = 10 clks
parameter PWM_LATENCY = 5, //latency of modular multiplier + modular addition to perform accumulation
parameter PWA_LATENCY = 1, //latency of modular addition
parameter PWS_LATENCY = 1, //latency of modular subtraction
parameter BF_STAGE1_LATENCY = BF_LATENCY/2
parameter MLDSA_Q_DIV2_ODD = (MLDSA_Q + 1) / 2
)
(
//Clock and reset
Expand Down Expand Up @@ -66,9 +61,9 @@ module ntt_butterfly2x2
logic [REG_SIZE-1:0] w01;
logic [REG_SIZE-1:0] w10;
logic [REG_SIZE-1:0] w11;
logic [BF_STAGE1_LATENCY-1:0][REG_SIZE-1:0] w10_reg, w11_reg; //Shift w10 by 5 cycles to match 1st stage BF latency
logic [UNMASKED_BF_STAGE1_LATENCY-1:0][REG_SIZE-1:0] w10_reg, w11_reg; //Shift w10 by 5 cycles to match 1st stage BF latency
logic pwo_mode;
logic [BF_LATENCY-1:0] ready_reg;
logic [UNMASKED_BF_LATENCY-1:0] ready_reg;

//Each butterfly unit takes u, v, w inputs and produces
//u, v outputs for the next stage to consume. Each butterfly
Expand All @@ -90,8 +85,8 @@ module ntt_butterfly2x2
w11_reg <= 'h0;
end
else begin
w10_reg <= {uvw_i.w10_i, w10_reg[BF_STAGE1_LATENCY-1:1]};
w11_reg <= {uvw_i.w11_i, w11_reg[BF_STAGE1_LATENCY-1:1]};
w10_reg <= {uvw_i.w10_i, w10_reg[UNMASKED_BF_STAGE1_LATENCY-1:1]};
w11_reg <= {uvw_i.w11_i, w11_reg[UNMASKED_BF_STAGE1_LATENCY-1:1]};
end
end

Expand Down Expand Up @@ -221,9 +216,9 @@ module ntt_butterfly2x2
ready_reg <= 'b0;
else begin
unique case(mode)
ct: ready_reg <= {enable, ready_reg[BF_LATENCY-1:1]};
gs: ready_reg <= {enable, ready_reg[BF_LATENCY-1:1]};
pwm: ready_reg <= accumulate ? {5'h0, enable, ready_reg[PWM_LATENCY-1:1]} : {6'h0, enable, ready_reg[PWM_LATENCY-2:1]};
ct: ready_reg <= {enable, ready_reg[UNMASKED_BF_LATENCY-1:1]};
gs: ready_reg <= {enable, ready_reg[UNMASKED_BF_LATENCY-1:1]};
pwm: ready_reg <= accumulate ? {5'h0, enable, ready_reg[UNMASKED_PWM_LATENCY-1:1]} : {6'h0, enable, ready_reg[UNMASKED_PWM_LATENCY-2:1]};
pwa: ready_reg <= {9'h0, enable};
pws: ready_reg <= {9'h0, enable};
default: ready_reg <= 'h0;
Expand Down
48 changes: 23 additions & 25 deletions src/ntt_top/rtl/ntt_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ module ntt_ctrl
parameter MLDSA_Q_DIV2_ODD = (MLDSA_Q+1)/2,
parameter MLDSA_N = 256,
parameter MLDSA_LOGN = 8,
parameter MEM_ADDR_WIDTH = 15,
parameter BF_LATENCY = 10, //5 cycles per butterfly * 2 instances in serial = 10 clks
parameter NTT_BUF_LATENCY = 4
parameter MEM_ADDR_WIDTH = 15
)
(
input wire clk,
Expand Down Expand Up @@ -90,13 +88,9 @@ localparam INTT_READ_ADDR_STEP = 1;
localparam INTT_WRITE_ADDR_STEP = 16;
localparam PWO_READ_ADDR_STEP = 1;
localparam PWO_WRITE_ADDR_STEP = 1;
localparam PWM_LATENCY = 5;
localparam MASKED_BF_STAGE1_LATENCY = 264; //TODO check
localparam MASKED_PWM_LATENCY = 209; //For 1 masked pwm operation

localparam [MEM_ADDR_WIDTH-1:0] MEM_LAST_ADDR = 63;
localparam INTT_WRBUF_LATENCY = 13; //includes BF latency + mem latency for shuffled reads to begin
localparam MASKED_PWM_INTT_WRBUF_LATENCY = 481; //masked PWM+INTT latency + mem latency for shuffled reads to begin

//FSM states
ntt_read_state_t read_fsm_state_ps, read_fsm_state_ns;
ntt_write_state_t write_fsm_state_ps, write_fsm_state_ns;
Expand All @@ -115,11 +109,12 @@ logic [3:0] chunk_count;
logic [1:0] index_rand_offset, index_count, mem_rd_index_ofst;
logic [1:0] buf_rdptr_int;
logic [1:0] buf_rdptr_f;
logic [BF_LATENCY:0][1:0] buf_rdptr_reg;
logic [UNMASKED_BF_LATENCY:0][1:0] buf_rdptr_reg;
//logic [INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg;
logic [MASKED_PWM_INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg;
logic [MASKED_BF_STAGE1_LATENCY:0][3:0] chunk_count_reg;
// logic [MASKED_PWM_INTT_WRBUF_LATENCY:0] chunk_count_reg;
logic [MASKED_INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg;
// logic [MASKED_BF_STAGE1_LATENCY:0][3:0] chunk_count_reg;
logic [MASKED_INTT_WRBUF_LATENCY-3:0][3:0] chunk_count_reg; //buf latency not rqd

logic latch_chunk_rand_offset, latch_index_rand_offset;
logic last_rd_addr, last_wr_addr;
logic mem_wr_en_fsm, mem_wr_en_reg;
Expand Down Expand Up @@ -366,9 +361,9 @@ always_comb begin
pw_mem_rd_addr_a_nxt = pw_base_addr_a + (4*chunk_count) + (PWO_READ_ADDR_STEP*mem_rd_index_ofst);
pw_mem_rd_addr_b_nxt = pw_base_addr_b + (4*chunk_count) + (PWO_READ_ADDR_STEP*mem_rd_index_ofst);
pw_mem_rd_addr_c_nxt = accumulate ? pw_base_addr_c + ((4*chunk_count)+(PWO_READ_ADDR_STEP*mem_rd_index_ofst)) : 'h0; //TODO check timing
pw_mem_wr_addr_c_nxt = accumulate ? pw_base_addr_c + (4*chunk_count_reg[PWM_LATENCY-2]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[PWM_LATENCY-2])
pw_mem_wr_addr_c_nxt = accumulate ? pw_base_addr_c + (4*chunk_count_reg[UNMASKED_PWM_LATENCY-2]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[UNMASKED_PWM_LATENCY-2])
: (pwa_mode | pws_mode) ? pw_base_addr_c + (4*chunk_count_reg[7]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[7])
: pw_base_addr_c + (4*chunk_count_reg[PWM_LATENCY-1]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[PWM_LATENCY-1]); //2
: pw_base_addr_c + (4*chunk_count_reg[UNMASKED_PWM_LATENCY-1]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[UNMASKED_PWM_LATENCY-1]); //2
end

//PWO addr
Expand Down Expand Up @@ -412,24 +407,24 @@ end


//------------------------------------------
//Twiddle addr logic - TODO: shuffling+masking (adjust latency)
//Twiddle addr logic
//------------------------------------------
always_comb begin
unique case(rounds_count)
'h0: begin
twiddle_end_addr = ct_mode ? 'd0 : 'd63;
twiddle_offset = 'h0;
twiddle_rand_offset = ct_mode ? 'h0 : pwm_intt_mode ? 7'((4*chunk_count_reg[MASKED_BF_STAGE1_LATENCY]) + buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1]) : 7'((4*chunk_count_reg[BF_LATENCY]) + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
twiddle_rand_offset = ct_mode ? 'h0 : pwm_intt_mode ? 7'((4*chunk_count_reg[MASKED_INTT_WRBUF_LATENCY-MASKED_PWM_LATENCY-3]) + buf_wrptr_reg[MASKED_INTT_WRBUF_LATENCY-MASKED_PWM_LATENCY-1]) : 7'((4*chunk_count_reg[UNMASKED_BF_LATENCY]) + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]); //pwm_intt mode only applies to round 0. Other rounds follow gs calc
end
'h1: begin
twiddle_end_addr = ct_mode ? 'd3 : 'd15;
twiddle_offset = ct_mode ? 'd1 : 'd64;
twiddle_rand_offset = ct_mode ? 7'(buf_rdptr_int) : pwm_intt_mode ? 7'((chunk_count_reg[MASKED_BF_STAGE1_LATENCY] % 4)*4 + buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1]) : 7'((chunk_count_reg[BF_LATENCY] % 4)*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
twiddle_rand_offset = ct_mode ? 7'(buf_rdptr_int) : 7'((chunk_count_reg[UNMASKED_BF_LATENCY] % 4)*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
end
'h2: begin
twiddle_end_addr = ct_mode ? 'd15 : 'd3;
twiddle_offset = ct_mode ? 'd5 : 'd80;
twiddle_rand_offset = ct_mode ? 7'((chunk_count % 'd4)*'d4 + buf_rdptr_int) : pwm_intt_mode ? 7'(buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1]) : 7'(buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
twiddle_rand_offset = ct_mode ? 7'((chunk_count % 'd4)*'d4 + buf_rdptr_int) : 7'(buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
end
'h3: begin
twiddle_end_addr = ct_mode ? 'd63 : 'd0;
Expand Down Expand Up @@ -578,16 +573,19 @@ always_ff @(posedge clk or negedge reset_n) begin
buf_wrptr_reg <= 'h0;
end
else if (ct_mode & (buf_rden_ntt | butterfly_ready)) begin
buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[BF_LATENCY:1]};
buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[UNMASKED_BF_LATENCY:1]};
end
else if ((gs_mode & (incr_mem_rd_addr | butterfly_ready))) begin
buf_wrptr_reg <= {{(MASKED_PWM_INTT_WRBUF_LATENCY-INTT_WRBUF_LATENCY){2'h0}}, mem_rd_index_ofst, buf_wrptr_reg[INTT_WRBUF_LATENCY-1:1]};
buf_wrptr_reg <= {{(MASKED_INTT_WRBUF_LATENCY-INTT_WRBUF_LATENCY){2'h0}}, mem_rd_index_ofst, buf_wrptr_reg[INTT_WRBUF_LATENCY-1:1]};
end
else if (pwo_mode & (incr_pw_rd_addr | butterfly_ready)) begin
buf_rdptr_reg <= {mem_rd_index_ofst, buf_rdptr_reg[BF_LATENCY:1]}; //TODO: create new reg with apt name for PWO
buf_rdptr_reg <= {mem_rd_index_ofst, buf_rdptr_reg[UNMASKED_BF_LATENCY:1]}; //TODO: create new reg with apt name for PWO
end
else if ((pwm_intt_mode)) begin
buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[MASKED_INTT_WRBUF_LATENCY-1:1]};
end
else if ((pwm_intt_mode)) begin
buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1:1]};
buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[MASKED_INTT_WRBUF_LATENCY-1:1]};
end
else begin
buf_rdptr_reg <= 'h0;
Expand Down Expand Up @@ -627,11 +625,11 @@ always_ff @(posedge clk or negedge reset_n) begin
chunk_count_reg <= 'h0;
end
//chunk update can't use incr_mem_rd_addr in pwm_intt mode.
else if (pwm_intt_mode & incr_pw_rd_addr) begin
chunk_count_reg <= {chunk_count, chunk_count_reg[MASKED_BF_STAGE1_LATENCY:1]};
else if (pwm_intt_mode/* & incr_pw_rd_addr*/) begin
chunk_count_reg <= {chunk_count, chunk_count_reg[MASKED_INTT_WRBUF_LATENCY-3:1]};
end
else if (buf_rden_ntt | butterfly_ready | (gs_mode & incr_mem_rd_addr) | (pwo_mode & incr_pw_rd_addr)) begin //TODO: replace gs condition with an fsm generated flag perhaps?
chunk_count_reg <= {{(MASKED_BF_STAGE1_LATENCY+1-BF_LATENCY){4'h0}}, chunk_count, chunk_count_reg[BF_LATENCY:1]};
chunk_count_reg <= {{(MASKED_BF_STAGE1_LATENCY+1-UNMASKED_BF_LATENCY){4'h0}}, chunk_count, chunk_count_reg[UNMASKED_BF_LATENCY:1]};
end
end

Expand Down
17 changes: 17 additions & 0 deletions src/ntt_top/rtl/ntt_defines_pkg.sv
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ parameter NTT_REG_SIZE = REG_SIZE-1;
parameter MASKED_WIDTH = 46;
// parameter MEM_DEPTH = 2**MLDSA_MEM_ADDR_WIDTH;

//----------------------
//Latency params for NTT
//----------------------
parameter INTT_WRBUF_LATENCY = 13;
parameter UNMASKED_BF_LATENCY = 10; //5 cycles per butterfly * 2 instances in serial = 10 clks
parameter UNMASKED_PWM_LATENCY = 5; //latency of modular multiplier + modular addition to perform accumulation
parameter UNMASKED_PWA_LATENCY = 1; //latency of modular addition
parameter UNMASKED_PWS_LATENCY = 1; //latency of modular subtraction
parameter UNMASKED_BF_STAGE1_LATENCY = UNMASKED_BF_LATENCY/2;

parameter MASKED_PWM_LATENCY = 211; //For 1 masked pwm operation
parameter MASKED_BF_STAGE1_LATENCY = 266; //For 1 masked butterfly operation
parameter MASKED_PWM_MASKED_INTT_LATENCY = MASKED_PWM_LATENCY + MASKED_BF_STAGE1_LATENCY; //PWM+stage1 INTT latency
parameter MASKED_INTT_LATENCY = MASKED_BF_STAGE1_LATENCY + UNMASKED_BF_STAGE1_LATENCY; //masked INTT latency
parameter MASKED_PWM_INTT_LATENCY = MASKED_PWM_LATENCY + MASKED_INTT_LATENCY + 1; //TODO: adjust for PWMA case. Adding 1 cyc as a placeholder for it
parameter MASKED_ADD_SUB_LATENCY = 53; //For 1 masked add/sub operation
parameter MASKED_INTT_WRBUF_LATENCY = MASKED_PWM_LATENCY + MASKED_INTT_LATENCY + 3; //masked PWM+INTT latency + mem latency for shuffled reads to begin (does not include PWMA case)

// typedef enum logic [2:0] {ct, gs, pwm, pwa, pws} mode_t;
//TODO: tb has issue with enums in top level ports. For now, using this workaround
Expand Down
Loading

0 comments on commit c23f8f0

Please sign in to comment.