Skip to content

Commit 7200dfb

Browse files
authored
create fft module (#292)
1 parent 14236f7 commit 7200dfb

File tree

3 files changed

+79
-77
lines changed

3 files changed

+79
-77
lines changed

cryptography/polynomial/src/domain.rs

+1-77
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::fft::{fft_g1_inplace, fft_scalar_inplace, precompute_twiddle_factors};
12
use crate::poly_coeff::PolyCoeff;
23
use bls12_381::ff::{Field, PrimeField};
34
use bls12_381::{
@@ -223,83 +224,6 @@ impl Domain {
223224
}
224225
}
225226

226-
use std::ops::{Add, Mul, Neg, Sub};
227-
228-
trait FFTElement:
229-
Sized
230-
+ Copy
231-
+ Add<Output = Self>
232-
+ Sub<Output = Self>
233-
+ Mul<Scalar, Output = Self>
234-
+ Neg<Output = Self>
235-
{
236-
}
237-
238-
impl FFTElement for Scalar {}
239-
240-
impl FFTElement for G1Projective {}
241-
242-
fn fft_inplace<T: FFTElement>(twiddle_factors: &[Scalar], a: &mut [T]) {
243-
let n = a.len();
244-
let log_n = log2_pow2(n);
245-
assert_eq!(n, 1 << log_n);
246-
247-
for k in 0..n {
248-
let rk = bitreverse(k as u32, log_n) as usize;
249-
if k < rk {
250-
a.swap(rk, k);
251-
}
252-
}
253-
254-
let mut m = 1;
255-
for s in 0..log_n {
256-
let w_m = twiddle_factors[s as usize];
257-
for k in (0..n).step_by(2 * m) {
258-
let mut w = Scalar::ONE;
259-
for j in 0..m {
260-
let t = if w == Scalar::ONE {
261-
a[k + j + m]
262-
} else if w == -Scalar::ONE {
263-
-a[k + j + m]
264-
} else {
265-
a[k + j + m] * w
266-
};
267-
let u = a[k + j];
268-
a[k + j] = u + t;
269-
a[k + j + m] = u - t;
270-
w *= w_m;
271-
}
272-
}
273-
m *= 2;
274-
}
275-
}
276-
277-
fn fft_scalar_inplace(twiddle_factors: &[Scalar], a: &mut [Scalar]) {
278-
fft_inplace(twiddle_factors, a);
279-
}
280-
281-
fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
282-
fft_inplace(twiddle_factors, a);
283-
}
284-
285-
fn bitreverse(mut n: u32, l: u32) -> u32 {
286-
let mut r = 0;
287-
for _ in 0..l {
288-
r = (r << 1) | (n & 1);
289-
n >>= 1;
290-
}
291-
r
292-
}
293-
fn log2_pow2(n: usize) -> u32 {
294-
n.trailing_zeros()
295-
}
296-
fn precompute_twiddle_factors<F: Field>(omega: &F, n: usize) -> Vec<F> {
297-
let log_n = log2_pow2(n);
298-
(0..log_n)
299-
.map(|s| omega.pow(&[(n / (1 << (s + 1))) as u64]))
300-
.collect()
301-
}
302-
303227
#[cfg(test)]
304228
mod tests {
305229
use crate::poly_coeff::poly_eval;

cryptography/polynomial/src/fft.rs

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
use bls12_381::{ff::Field, G1Projective, Scalar};
2+
use std::ops::{Add, Mul, Neg, Sub};
3+
4+
trait FFTElement:
5+
Sized
6+
+ Copy
7+
+ Add<Output = Self>
8+
+ Sub<Output = Self>
9+
+ Mul<Scalar, Output = Self>
10+
+ Neg<Output = Self>
11+
{
12+
}
13+
14+
impl FFTElement for Scalar {}
15+
16+
impl FFTElement for G1Projective {}
17+
18+
fn fft_inplace<T: FFTElement>(twiddle_factors: &[Scalar], a: &mut [T]) {
19+
let n = a.len();
20+
let log_n = log2_pow2(n);
21+
assert_eq!(n, 1 << log_n);
22+
23+
for k in 0..n {
24+
let rk = bitreverse(k as u32, log_n) as usize;
25+
if k < rk {
26+
a.swap(rk, k);
27+
}
28+
}
29+
30+
let mut m = 1;
31+
for s in 0..log_n {
32+
let w_m = twiddle_factors[s as usize];
33+
for k in (0..n).step_by(2 * m) {
34+
let mut w = Scalar::ONE;
35+
for j in 0..m {
36+
let t = if w == Scalar::ONE {
37+
a[k + j + m]
38+
} else if w == -Scalar::ONE {
39+
-a[k + j + m]
40+
} else {
41+
a[k + j + m] * w
42+
};
43+
let u = a[k + j];
44+
a[k + j] = u + t;
45+
a[k + j + m] = u - t;
46+
w *= w_m;
47+
}
48+
}
49+
m *= 2;
50+
}
51+
}
52+
53+
pub(crate) fn fft_scalar_inplace(twiddle_factors: &[Scalar], a: &mut [Scalar]) {
54+
fft_inplace(twiddle_factors, a);
55+
}
56+
57+
pub(crate) fn fft_g1_inplace(twiddle_factors: &[Scalar], a: &mut [G1Projective]) {
58+
fft_inplace(twiddle_factors, a);
59+
}
60+
61+
fn bitreverse(mut n: u32, l: u32) -> u32 {
62+
let mut r = 0;
63+
for _ in 0..l {
64+
r = (r << 1) | (n & 1);
65+
n >>= 1;
66+
}
67+
r
68+
}
69+
fn log2_pow2(n: usize) -> u32 {
70+
n.trailing_zeros()
71+
}
72+
pub(crate) fn precompute_twiddle_factors<F: Field>(omega: &F, n: usize) -> Vec<F> {
73+
let log_n = log2_pow2(n);
74+
(0..log_n)
75+
.map(|s| omega.pow(&[(n / (1 << (s + 1))) as u64]))
76+
.collect()
77+
}

cryptography/polynomial/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod domain;
2+
mod fft;
23
pub mod poly_coeff;

0 commit comments

Comments
 (0)