Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Masking cleanup #53

Merged
merged 19 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
21617bc
added masked pwm intt opcode
Nitsirks Nov 21, 2024
0c9e28c
Merge branch 'user/dev/michnorris/masking_opcodes' into user/dev/kupa…
upadhyayulakiran Nov 21, 2024
f88fb75
Enable masking for pwm_intt op
upadhyayulakiran Dec 2, 2024
3ac04f5
Clean up
upadhyayulakiran Dec 2, 2024
deaa51a
Merge branch 'main' into user/dev/kupadhyayula/masking_fixes
upadhyayulakiran Dec 3, 2024
f75356d
Revert vf change
upadhyayulakiran Dec 3, 2024
f890968
Clean up
upadhyayulakiran Dec 3, 2024
6677cae
Fix two share mult, clean up latency params
upadhyayulakiran Dec 5, 2024
e300ba3
Merge branch 'main' into user/dev/kupadhyayula/masking_cleanup
upadhyayulakiran Dec 5, 2024
abdca46
Add input flops to shares, refresh shares before INTT, optimize twidd…
upadhyayulakiran Dec 10, 2024
42811ae
Fix delay in pwm
upadhyayulakiran Dec 10, 2024
69809d6
Remove w_delay, lint fixes
upadhyayulakiran Dec 10, 2024
de9841f
MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/masking_cl…
upadhyayulakiran Dec 10, 2024
1c1bf80
Parameterize delays
upadhyayulakiran Dec 11, 2024
fdcf6a9
Merge branch 'user/dev/kupadhyayula/masking_cleanup' of ssh://github.…
upadhyayulakiran Dec 11, 2024
7511a1d
MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/masking_cl…
upadhyayulakiran Dec 11, 2024
d9fa77d
Use different random input for twiddle
upadhyayulakiran Dec 11, 2024
7be318f
Merge branch 'user/dev/kupadhyayula/masking_cleanup' of ssh://github.…
upadhyayulakiran Dec 11, 2024
e0019e8
MICROSOFT AUTOMATED PIPELINE: Stamp 'user/dev/kupadhyayula/masking_cl…
upadhyayulakiran Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflow_metadata/pr_hash
Original file line number Diff line number Diff line change
@@ -1 +1 @@
148a6d381422de56ae26bc8c4288130b67b86f624ee2adb675b36c18e09bc5319f1cc53b9c3268c98892d594e9a28b44
5d113844520ff7d46a3ba862712f0caafe67cdb2163fc7d9aa16812197145e666d1ee25ebf40061744711b6e364eebcf
2 changes: 1 addition & 1 deletion .github/workflow_metadata/pr_timestamp
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1733339945
1733877428
33 changes: 26 additions & 7 deletions src/abr_libs/rtl/abr_masked_N_bit_mult_two_share.sv
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
// - Final output is obtained by combining the reshared and masked intermediate results.
// - It requires fresh randomness.
// - This design assumes that both x and y are secret, although the y input from the top level is usually public.
// - It has one cycle latency and can accept a new input set at every clock.
// - It has a two-cycle latency and can accept a new input set on every clock cycle.
//
//======================================================================

Expand All @@ -43,6 +43,7 @@

// Intermediate calculation logic for multiplication operations
logic [WIDTH-1:0] calculation [3:0];
logic [WIDTH-1:0] calculation_reg [1:0];
logic [WIDTH-1:0] calculation_rand [1:0];
logic [WIDTH-1:0] final_res [1:0];
logic [WIDTH-1:0] x0, x1, y0, y1;
Expand All @@ -53,12 +54,30 @@
calculation[1] = WIDTH'(x[1] * y[0]); // Multiplication of the second share x and first share y
calculation[2] = WIDTH'(x[0] * y[1]); // Multiplication of the first share x and second share y
calculation[3] = WIDTH'(x[1] * y[1]); // Multiplication of the second share x and second share y

calculation_rand[0] = calculation[2] + random;
calculation_rand[1] = calculation[1] - random;

final_res[0] = calculation[0] + calculation_rand[0];
final_res[1] = calculation[3] + calculation_rand[1];
end
// Pipeline register stage for the two-share multiplier.
// Cleared by the asynchronous active-low reset (rst_n) and by the synchronous
// zeroize input. Otherwise it captures the same-share products and the
// randomized cross-share products, so the per-share sums are formed from
// registered values only.
always_ff @(posedge clk or negedge rst_n) begin
    if (!rst_n) begin
        calculation_rand[0] <= 'h0;
        calculation_rand[1] <= 'h0;
        calculation_reg[0]  <= 'h0;
        calculation_reg[1]  <= 'h0;
    end
    else if (zeroize) begin
        calculation_rand[0] <= 'h0;
        calculation_rand[1] <= 'h0;
        calculation_reg[0]  <= 'h0;
        calculation_reg[1]  <= 'h0;
    end
    else begin
        // Same-share products (x0*y0 and x1*y1) are registered unchanged.
        calculation_reg[0]  <= calculation[0];
        calculation_reg[1]  <= calculation[3];
        // Cross-share products are re-masked with +random / -random, so the
        // random term cancels when the two output shares are added together.
        calculation_rand[0] <= calculation[2] + random;
        calculation_rand[1] <= calculation[1] - random;
    end
end
// Per-share result: registered same-share product plus the registered,
// randomized cross-share product of the matching index.
always_comb begin
    for (int i = 0; i < 2; i++) begin
        final_res[i] = calculation_reg[i] + calculation_rand[i];
    end
end

// Final output assignment
Expand Down
2 changes: 1 addition & 1 deletion src/mldsa_top/rtl/mldsa_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,7 @@ mldsa_seq_sec mldsa_seq_sec_inst
INTT_raw_signal <= 'h0;
end
else begin
if (seq_en) begin
if (sec_seq_en) begin
unique case(sec_prog_cntr_nxt)
MLDSA_SIGN_VALID_S : begin //NTT(C)
NTT_raw_signal <= 'h1;
Expand Down
21 changes: 8 additions & 13 deletions src/ntt_top/rtl/ntt_butterfly2x2.sv
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@ module ntt_butterfly2x2
#(
parameter REG_SIZE = 23,
parameter MLDSA_Q = 23'd8380417,
parameter MLDSA_Q_DIV2_ODD = (MLDSA_Q + 1) / 2,
parameter BF_LATENCY = 10, //5 cycles per butterfly * 2 instances in serial = 10 clks
parameter PWM_LATENCY = 5, //latency of modular multiplier + modular addition to perform accumulation
parameter PWA_LATENCY = 1, //latency of modular addition
parameter PWS_LATENCY = 1, //latency of modular subtraction
parameter BF_STAGE1_LATENCY = BF_LATENCY/2
parameter MLDSA_Q_DIV2_ODD = (MLDSA_Q + 1) / 2
)
(
//Clock and reset
Expand Down Expand Up @@ -66,9 +61,9 @@ module ntt_butterfly2x2
logic [REG_SIZE-1:0] w01;
logic [REG_SIZE-1:0] w10;
logic [REG_SIZE-1:0] w11;
logic [BF_STAGE1_LATENCY-1:0][REG_SIZE-1:0] w10_reg, w11_reg; //Shift w10 by 5 cycles to match 1st stage BF latency
logic [UNMASKED_BF_STAGE1_LATENCY-1:0][REG_SIZE-1:0] w10_reg, w11_reg; //Shift w10 by 5 cycles to match 1st stage BF latency
logic pwo_mode;
logic [BF_LATENCY-1:0] ready_reg;
logic [UNMASKED_BF_LATENCY-1:0] ready_reg;

//Each butterfly unit takes u, v, w inputs and produces
//u, v outputs for the next stage to consume. Each butterfly
Expand All @@ -90,8 +85,8 @@ module ntt_butterfly2x2
w11_reg <= 'h0;
end
else begin
w10_reg <= {uvw_i.w10_i, w10_reg[BF_STAGE1_LATENCY-1:1]};
w11_reg <= {uvw_i.w11_i, w11_reg[BF_STAGE1_LATENCY-1:1]};
w10_reg <= {uvw_i.w10_i, w10_reg[UNMASKED_BF_STAGE1_LATENCY-1:1]};
w11_reg <= {uvw_i.w11_i, w11_reg[UNMASKED_BF_STAGE1_LATENCY-1:1]};
end
end

Expand Down Expand Up @@ -221,9 +216,9 @@ module ntt_butterfly2x2
ready_reg <= 'b0;
else begin
unique case(mode)
ct: ready_reg <= {enable, ready_reg[BF_LATENCY-1:1]};
gs: ready_reg <= {enable, ready_reg[BF_LATENCY-1:1]};
pwm: ready_reg <= accumulate ? {5'h0, enable, ready_reg[PWM_LATENCY-1:1]} : {6'h0, enable, ready_reg[PWM_LATENCY-2:1]};
ct: ready_reg <= {enable, ready_reg[UNMASKED_BF_LATENCY-1:1]};
gs: ready_reg <= {enable, ready_reg[UNMASKED_BF_LATENCY-1:1]};
pwm: ready_reg <= accumulate ? {5'h0, enable, ready_reg[UNMASKED_PWM_LATENCY-1:1]} : {6'h0, enable, ready_reg[UNMASKED_PWM_LATENCY-2:1]};
pwa: ready_reg <= {9'h0, enable};
pws: ready_reg <= {9'h0, enable};
default: ready_reg <= 'h0;
Expand Down
48 changes: 23 additions & 25 deletions src/ntt_top/rtl/ntt_ctrl.sv
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ module ntt_ctrl
parameter MLDSA_Q_DIV2_ODD = (MLDSA_Q+1)/2,
parameter MLDSA_N = 256,
parameter MLDSA_LOGN = 8,
parameter MEM_ADDR_WIDTH = 15,
parameter BF_LATENCY = 10, //5 cycles per butterfly * 2 instances in serial = 10 clks
parameter NTT_BUF_LATENCY = 4
parameter MEM_ADDR_WIDTH = 15
)
(
input wire clk,
Expand Down Expand Up @@ -90,13 +88,9 @@ localparam INTT_READ_ADDR_STEP = 1;
localparam INTT_WRITE_ADDR_STEP = 16;
localparam PWO_READ_ADDR_STEP = 1;
localparam PWO_WRITE_ADDR_STEP = 1;
localparam PWM_LATENCY = 5;
localparam MASKED_BF_STAGE1_LATENCY = 264; //TODO check
localparam MASKED_PWM_LATENCY = 209; //For 1 masked pwm operation

localparam [MEM_ADDR_WIDTH-1:0] MEM_LAST_ADDR = 63;
localparam INTT_WRBUF_LATENCY = 13; //includes BF latency + mem latency for shuffled reads to begin
localparam MASKED_PWM_INTT_WRBUF_LATENCY = 481; //masked PWM+INTT latency + mem latency for shuffled reads to begin

//FSM states
ntt_read_state_t read_fsm_state_ps, read_fsm_state_ns;
ntt_write_state_t write_fsm_state_ps, write_fsm_state_ns;
Expand All @@ -115,11 +109,12 @@ logic [3:0] chunk_count;
logic [1:0] index_rand_offset, index_count, mem_rd_index_ofst;
logic [1:0] buf_rdptr_int;
logic [1:0] buf_rdptr_f;
logic [BF_LATENCY:0][1:0] buf_rdptr_reg;
logic [UNMASKED_BF_LATENCY:0][1:0] buf_rdptr_reg;
//logic [INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg;
logic [MASKED_PWM_INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg;
logic [MASKED_BF_STAGE1_LATENCY:0][3:0] chunk_count_reg;
// logic [MASKED_PWM_INTT_WRBUF_LATENCY:0] chunk_count_reg;
logic [MASKED_INTT_WRBUF_LATENCY-1:0][1:0] buf_wrptr_reg;
// logic [MASKED_BF_STAGE1_LATENCY:0][3:0] chunk_count_reg;
logic [MASKED_INTT_WRBUF_LATENCY-3:0][3:0] chunk_count_reg; //buf latency not rqd

logic latch_chunk_rand_offset, latch_index_rand_offset;
logic last_rd_addr, last_wr_addr;
logic mem_wr_en_fsm, mem_wr_en_reg;
Expand Down Expand Up @@ -366,9 +361,9 @@ always_comb begin
pw_mem_rd_addr_a_nxt = pw_base_addr_a + (4*chunk_count) + (PWO_READ_ADDR_STEP*mem_rd_index_ofst);
pw_mem_rd_addr_b_nxt = pw_base_addr_b + (4*chunk_count) + (PWO_READ_ADDR_STEP*mem_rd_index_ofst);
pw_mem_rd_addr_c_nxt = accumulate ? pw_base_addr_c + ((4*chunk_count)+(PWO_READ_ADDR_STEP*mem_rd_index_ofst)) : 'h0; //TODO check timing
pw_mem_wr_addr_c_nxt = accumulate ? pw_base_addr_c + (4*chunk_count_reg[PWM_LATENCY-2]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[PWM_LATENCY-2])
pw_mem_wr_addr_c_nxt = accumulate ? pw_base_addr_c + (4*chunk_count_reg[UNMASKED_PWM_LATENCY-2]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[UNMASKED_PWM_LATENCY-2])
: (pwa_mode | pws_mode) ? pw_base_addr_c + (4*chunk_count_reg[7]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[7])
: pw_base_addr_c + (4*chunk_count_reg[PWM_LATENCY-1]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[PWM_LATENCY-1]); //2
: pw_base_addr_c + (4*chunk_count_reg[UNMASKED_PWM_LATENCY-1]) + (PWO_WRITE_ADDR_STEP*buf_rdptr_reg[UNMASKED_PWM_LATENCY-1]); //2
end

//PWO addr
Expand Down Expand Up @@ -412,24 +407,24 @@ end


//------------------------------------------
//Twiddle addr logic - TODO: shuffling+masking (adjust latency)
//Twiddle addr logic
//------------------------------------------
always_comb begin
unique case(rounds_count)
'h0: begin
twiddle_end_addr = ct_mode ? 'd0 : 'd63;
twiddle_offset = 'h0;
twiddle_rand_offset = ct_mode ? 'h0 : pwm_intt_mode ? 7'((4*chunk_count_reg[MASKED_BF_STAGE1_LATENCY]) + buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1]) : 7'((4*chunk_count_reg[BF_LATENCY]) + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
twiddle_rand_offset = ct_mode ? 'h0 : pwm_intt_mode ? 7'((4*chunk_count_reg[MASKED_INTT_WRBUF_LATENCY-MASKED_PWM_LATENCY-3]) + buf_wrptr_reg[MASKED_INTT_WRBUF_LATENCY-MASKED_PWM_LATENCY-1]) : 7'((4*chunk_count_reg[UNMASKED_BF_LATENCY]) + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]); //pwm_intt mode only applies to round 0. Other rounds follow gs calc
end
'h1: begin
twiddle_end_addr = ct_mode ? 'd3 : 'd15;
twiddle_offset = ct_mode ? 'd1 : 'd64;
twiddle_rand_offset = ct_mode ? 7'(buf_rdptr_int) : pwm_intt_mode ? 7'((chunk_count_reg[MASKED_BF_STAGE1_LATENCY] % 4)*4 + buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1]) : 7'((chunk_count_reg[BF_LATENCY] % 4)*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
twiddle_rand_offset = ct_mode ? 7'(buf_rdptr_int) : 7'((chunk_count_reg[UNMASKED_BF_LATENCY] % 4)*4 + buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
end
'h2: begin
twiddle_end_addr = ct_mode ? 'd15 : 'd3;
twiddle_offset = ct_mode ? 'd5 : 'd80;
twiddle_rand_offset = ct_mode ? 7'((chunk_count % 'd4)*'d4 + buf_rdptr_int) : pwm_intt_mode ? 7'(buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1]) : 7'(buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
twiddle_rand_offset = ct_mode ? 7'((chunk_count % 'd4)*'d4 + buf_rdptr_int) : 7'(buf_wrptr_reg[INTT_WRBUF_LATENCY-1]);
end
'h3: begin
twiddle_end_addr = ct_mode ? 'd63 : 'd0;
Expand Down Expand Up @@ -578,16 +573,19 @@ always_ff @(posedge clk or negedge reset_n) begin
buf_wrptr_reg <= 'h0;
end
else if (ct_mode & (buf_rden_ntt | butterfly_ready)) begin
buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[BF_LATENCY:1]};
buf_rdptr_reg <= {buf_rdptr_int, buf_rdptr_reg[UNMASKED_BF_LATENCY:1]};
end
else if ((gs_mode & (incr_mem_rd_addr | butterfly_ready))) begin
buf_wrptr_reg <= {{(MASKED_PWM_INTT_WRBUF_LATENCY-INTT_WRBUF_LATENCY){2'h0}}, mem_rd_index_ofst, buf_wrptr_reg[INTT_WRBUF_LATENCY-1:1]};
buf_wrptr_reg <= {{(MASKED_INTT_WRBUF_LATENCY-INTT_WRBUF_LATENCY){2'h0}}, mem_rd_index_ofst, buf_wrptr_reg[INTT_WRBUF_LATENCY-1:1]};
end
else if (pwo_mode & (incr_pw_rd_addr | butterfly_ready)) begin
buf_rdptr_reg <= {mem_rd_index_ofst, buf_rdptr_reg[BF_LATENCY:1]}; //TODO: create new reg with apt name for PWO
buf_rdptr_reg <= {mem_rd_index_ofst, buf_rdptr_reg[UNMASKED_BF_LATENCY:1]}; //TODO: create new reg with apt name for PWO
end
else if ((pwm_intt_mode)) begin
buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[MASKED_INTT_WRBUF_LATENCY-1:1]};
end
else if ((pwm_intt_mode)) begin
buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[MASKED_PWM_INTT_WRBUF_LATENCY-1:1]};
buf_wrptr_reg <= {mem_rd_index_ofst, buf_wrptr_reg[MASKED_INTT_WRBUF_LATENCY-1:1]};
end
else begin
buf_rdptr_reg <= 'h0;
Expand Down Expand Up @@ -627,11 +625,11 @@ always_ff @(posedge clk or negedge reset_n) begin
chunk_count_reg <= 'h0;
end
//chunk update can't use incr_mem_rd_addr in pwm_intt mode.
else if (pwm_intt_mode & incr_pw_rd_addr) begin
chunk_count_reg <= {chunk_count, chunk_count_reg[MASKED_BF_STAGE1_LATENCY:1]};
else if (pwm_intt_mode/* & incr_pw_rd_addr*/) begin
chunk_count_reg <= {chunk_count, chunk_count_reg[MASKED_INTT_WRBUF_LATENCY-3:1]};
end
else if (buf_rden_ntt | butterfly_ready | (gs_mode & incr_mem_rd_addr) | (pwo_mode & incr_pw_rd_addr)) begin //TODO: replace gs condition with an fsm generated flag perhaps?
chunk_count_reg <= {{(MASKED_BF_STAGE1_LATENCY+1-BF_LATENCY){4'h0}}, chunk_count, chunk_count_reg[BF_LATENCY:1]};
chunk_count_reg <= {{(MASKED_BF_STAGE1_LATENCY+1-UNMASKED_BF_LATENCY){4'h0}}, chunk_count, chunk_count_reg[UNMASKED_BF_LATENCY:1]};
end
end

Expand Down
17 changes: 17 additions & 0 deletions src/ntt_top/rtl/ntt_defines_pkg.sv
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ parameter NTT_REG_SIZE = REG_SIZE-1;
parameter MASKED_WIDTH = 46;
// parameter MEM_DEPTH = 2**MLDSA_MEM_ADDR_WIDTH;

//----------------------
//Latency params for NTT
//----------------------
parameter INTT_WRBUF_LATENCY = 13;
parameter UNMASKED_BF_LATENCY = 10; //5 cycles per butterfly * 2 instances in serial = 10 clks
parameter UNMASKED_PWM_LATENCY = 5; //latency of modular multiplier + modular addition to perform accumulation
parameter UNMASKED_PWA_LATENCY = 1; //latency of modular addition
parameter UNMASKED_PWS_LATENCY = 1; //latency of modular subtraction
parameter UNMASKED_BF_STAGE1_LATENCY = UNMASKED_BF_LATENCY/2;

parameter MASKED_PWM_LATENCY = 211; //For 1 masked pwm operation
parameter MASKED_BF_STAGE1_LATENCY = 266; //For 1 masked butterfly operation
parameter MASKED_PWM_MASKED_INTT_LATENCY = MASKED_PWM_LATENCY + MASKED_BF_STAGE1_LATENCY; //PWM+stage1 INTT latency
parameter MASKED_INTT_LATENCY = MASKED_BF_STAGE1_LATENCY + UNMASKED_BF_STAGE1_LATENCY; //masked INTT latency
parameter MASKED_PWM_INTT_LATENCY = MASKED_PWM_LATENCY + MASKED_INTT_LATENCY + 1; //TODO: adjust for PWMA case. Adding 1 cyc as a placeholder for it
parameter MASKED_ADD_SUB_LATENCY = 53; //For 1 masked add/sub operation
parameter MASKED_INTT_WRBUF_LATENCY = MASKED_PWM_LATENCY + MASKED_INTT_LATENCY + 3; //masked PWM+INTT latency + mem latency for shuffled reads to begin (does not include PWMA case)

// typedef enum logic [2:0] {ct, gs, pwm, pwa, pws} mode_t;
//TODO: tb has issue with enums in top level ports. For now, using this workaround
Expand Down
Loading