Skip to content

Commit 14236f7

Browse files
authored
add a trait to dedup code (#291)
1 parent 0fc446a commit 14236f7

File tree

1 file changed

+22
-50
lines changed

1 file changed

+22
-50
lines changed

cryptography/polynomial/src/domain.rs

+22-50
Original file line numberDiff line numberDiff line change
@@ -223,60 +223,27 @@ impl Domain {
223223
}
224224
}
225225

226-
/// Computes a FFT of the field elements(scalars).
227-
///
228-
/// Note: This is essentially multiple inner products.
229-
///
230-
/// TODO: This method is still duplicated below
231-
fn fft_scalar_inplace(twiddle_factors: &[Scalar], a: &mut [Scalar]) {
232-
let n = a.len();
233-
let log_n = log2_pow2(n);
234-
assert_eq!(n, 1 << log_n);
235-
236-
// Bit-reversal permutation
237-
for k in 0..n {
238-
let rk = bitreverse(k as u32, log_n) as usize;
239-
if k < rk {
240-
a.swap(rk, k);
241-
}
242-
}
243-
244-
let mut m = 1;
245-
for s in 0..log_n {
246-
let w_m = twiddle_factors[s as usize];
247-
for k in (0..n).step_by(2 * m) {
248-
let mut w = Scalar::ONE;
249-
250-
for j in 0..m {
251-
let t = if w == Scalar::ONE {
252-
a[k + j + m]
253-
} else if w == -Scalar::ONE {
254-
-a[k + j + m]
255-
} else {
256-
a[k + j + m] * w
257-
};
226+
use std::ops::{Add, Mul, Neg, Sub};
227+
228+
trait FFTElement:
229+
Sized
230+
+ Copy
231+
+ Add<Output = Self>
232+
+ Sub<Output = Self>
233+
+ Mul<Scalar, Output = Self>
234+
+ Neg<Output = Self>
235+
{
236+
}
258237

259-
let u = a[k + j];
238+
impl FFTElement for Scalar {}
260239

261-
a[k + j] = u + t;
262-
a[k + j + m] = u - t;
263-
264-
w *= w_m;
265-
}
266-
}
267-
m *= 2;
268-
}
269-
}
240+
impl FFTElement for G1Projective {}
270241

271-
/// Computes a FFT of the group elements(points).
272-
///
273-
/// Note: This is essentially multiple multi-scalar multiplications.
274-
fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
242+
fn fft_inplace<T: FFTElement>(twiddle_factors: &[Scalar], a: &mut [T]) {
275243
let n = a.len();
276244
let log_n = log2_pow2(n);
277245
assert_eq!(n, 1 << log_n);
278246

279-
// Bit-reversal permutation
280247
for k in 0..n {
281248
let rk = bitreverse(k as u32, log_n) as usize;
282249
if k < rk {
@@ -294,12 +261,9 @@ fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
294261
a[k + j + m]
295262
} else if w == -Scalar::ONE {
296263
-a[k + j + m]
297-
} else if a[k + j + m].is_identity().into() {
298-
G1Projective::identity()
299264
} else {
300265
a[k + j + m] * w
301266
};
302-
303267
let u = a[k + j];
304268
a[k + j] = u + t;
305269
a[k + j + m] = u - t;
@@ -310,6 +274,14 @@ fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
310274
}
311275
}
312276

277+
fn fft_scalar_inplace(twiddle_factors: &[Scalar], a: &mut [Scalar]) {
278+
fft_inplace(twiddle_factors, a);
279+
}
280+
281+
fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
282+
fft_inplace(twiddle_factors, a);
283+
}
284+
313285
fn bitreverse(mut n: u32, l: u32) -> u32 {
314286
let mut r = 0;
315287
for _ in 0..l {

0 commit comments

Comments
 (0)