Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Polyeval{Instance, Witness} #315

Merged
merged 5 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ abomonation_derive = { version = "0.1.0", package = "abomonation_derive_ng" }
tracing = "0.1.37"
cfg-if = "1.0.0"
once_cell = "1.18.0"
itertools = "0.12.0"
itertools = "0.12.0" # zip_eq
rand = "0.8.5"
ref-cast = "1.0.20"
derive_more = "0.99.17"
ref-cast = "1.0.20" # allocation-less conversion in multilinear polys
derive_more = "0.99.17" # lightens impl macros for pasta
static_assertions = "1.1.0"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
Expand All @@ -55,8 +55,10 @@ grumpkin-msm = { git = "https://github.com/lurk-lab/grumpkin-msm", branch = "dev
# see https://github.com/rust-random/rand/pull/948
getrandom = { version = "0.2.0", default-features = false, features = ["js"] }

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
proptest = "1.2.0"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
pprof = { version = "0.13" }
criterion = { version = "0.5", features = ["html_reports"] }

Expand Down
1 change: 0 additions & 1 deletion src/r1cs/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! This module defines R1CS related types and a folding scheme for Relaxed R1CS
mod sparse;
#[cfg(test)]
pub(crate) mod util;

use crate::{
Expand Down
24 changes: 24 additions & 0 deletions src/r1cs/util.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ff::PrimeField;
use group::Group;
#[cfg(not(target_arch = "wasm32"))]
use proptest::prelude::*;

Expand All @@ -24,3 +25,26 @@ impl<F: PrimeField> Arbitrary for FWrap<F> {
strategy.boxed()
}
}

/// Wrapper struct around a Group element that implements additional traits
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GWrap<G>(pub G);

impl<G: Group> Copy for GWrap<G> {}

#[cfg(not(target_arch = "wasm32"))]
/// Trait implementation for generating `GWrap<F>` instances with proptest
impl<G: Group> Arbitrary for GWrap<G> {
type Parameters = ();
type Strategy = BoxedStrategy<Self>;

fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
use rand::rngs::StdRng;
use rand_core::SeedableRng;

let strategy = any::<[u8; 32]>()
.prop_map(|seed| Self(G::random(StdRng::from_seed(seed))))
.no_shrink();
strategy.boxed()
}
}
2 changes: 1 addition & 1 deletion src/spartan/batched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
};

let (batched_u, batched_w, sc_proof_batch, claims_batch_left) =
batch_eval_prove(u_vec, w_vec, &mut transcript)?;
batch_eval_prove(u_vec, &w_vec, &mut transcript)?;

let eval_arg = EE::prove(
ck,
Expand Down
7 changes: 4 additions & 3 deletions src/spartan/batched_ppsnark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
|comm_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau| {
let u = PolyEvalInstance::<E>::batch(
comm_Az_Bz_Cz.as_slice(),
&[], // ignored by the function
vec![], // ignored by the function
evals_Az_Bz_Cz_at_tau.as_slice(),
&c,
);
Expand Down Expand Up @@ -701,7 +701,8 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
let num_vars_u = w_vec.iter().map(|w| w.p.len().log_2()).collect::<Vec<_>>();
let u_batch =
PolyEvalInstance::<E>::batch_diff_size(&comms_vec, &evals_vec, &num_vars_u, rand_sc, c);
let w_batch = PolyEvalWitness::<E>::batch_diff_size(w_vec, c);
let w_batch =
PolyEvalWitness::<E>::batch_diff_size(&w_vec.iter().by_ref().collect::<Vec<_>>(), c);

let eval_arg = EE::prove(
ck,
Expand Down Expand Up @@ -819,7 +820,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> BatchedRelaxedR1CSSNARKTrait<E>
|comm_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau| {
let u = PolyEvalInstance::<E>::batch(
comm_Az_Bz_Cz.as_slice(),
&tau_coords,
tau_coords.clone(),
evals_Az_Bz_Cz_at_tau.as_slice(),
&c,
);
Expand Down
224 changes: 162 additions & 62 deletions src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use crate::{
use ff::Field;
use itertools::Itertools as _;
use polys::multilinear::SparsePolynomial;

use rayon::{iter::IntoParallelRefIterator, prelude::*};
use ref_cast::RefCast;

// Creates a vector of the first `n` powers of `s`.
fn powers<E: Engine>(s: &E::Scalar, n: usize) -> Vec<E::Scalar> {
Expand All @@ -35,7 +37,8 @@ fn powers<E: Engine>(s: &E::Scalar, n: usize) -> Vec<E::Scalar> {
}

/// A type that holds a witness to a polynomial evaluation instance
#[derive(Debug)]
#[repr(transparent)]
#[derive(Debug, RefCast)]
struct PolyEvalWitness<E: Engine> {
p: Vec<E::Scalar>, // polynomial
}
Expand All @@ -47,39 +50,43 @@ impl<E: Engine> PolyEvalWitness<E> {
///
/// We allow the input polynomials to have different sizes, and interpret smaller ones as
/// being padded with 0 to the maximum size of all polynomials.
fn batch_diff_size(W: Vec<Self>, s: E::Scalar) -> Self {
fn batch_diff_size(W: &[&Self], s: E::Scalar) -> Self {
let powers = powers::<E>(&s, W.len());

let size_max = W.iter().map(|w| w.p.len()).max().unwrap();
let p_vec = W.par_iter().map(|w| &w.p);
// Scale the input polynomials by the power of s
let p = W
.into_par_iter()
.zip_eq(powers.par_iter())
.map(|(mut w, s)| {
if *s != E::Scalar::ONE {
w.p.par_iter_mut().for_each(|e| *e *= s);
}
w.p
})
.reduce(
|| vec![E::Scalar::ZERO; size_max],
|left, right| {
// Sum into the largest polynomial
let (mut big, small) = if left.len() > right.len() {
(left, right)
let p = zip_with!((p_vec, powers.par_iter()), |v, weight| {
// compute the weighted sum for each vector
v.iter()
.map(|&x| {
if *weight != E::Scalar::ONE {
x * *weight
} else {
(right, left)
};
x
}
})
.collect::<Vec<_>>()
})
.reduce(
|| vec![E::Scalar::ZERO; size_max],
|left, right| {
// Sum into the largest polynomial
let (mut big, small) = if left.len() > right.len() {
(left, right)
} else {
(right, left)
};

#[allow(clippy::disallowed_methods)]
big
.par_iter_mut()
.zip(small.par_iter())
.for_each(|(b, s)| *b += s);
#[allow(clippy::disallowed_methods)]
big
.par_iter_mut()
.zip(small.par_iter())
.for_each(|(b, s)| *b += s);

big
},
);
big
},
);

Self { p }
}
Expand All @@ -95,22 +102,8 @@ impl<E: Engine> PolyEvalWitness<E> {
.iter()
.skip(1)
.for_each(|p| assert_eq!(p.len(), p_vec[0].len()));

let powers_of_s = powers::<E>(s, p_vec.len());

let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
acc.into_iter().zip_eq(v).map(|(x, y)| x + y).collect()
},
);

Self { p }
let instances = p_vec.iter().map(|p| Self::ref_cast(p)).collect::<Vec<_>>();
Self::batch_diff_size(&instances, *s)
}
}

Expand Down Expand Up @@ -150,15 +143,14 @@ impl<E: Engine> PolyEvalInstance<E> {

// vᵢ = L₀(x_lo)⋅Pᵢ(x_hi)
lagrange_eval * eval
})
.collect::<Vec<_>>();
});

// C = ∑ᵢ γⁱ⋅Cᵢ
let comm_joint = zip_with!(iter, (c_vec, powers), |c, g_i| *c * *g_i)
.fold(Commitment::<E>::default(), |acc, item| acc + item);

// v = ∑ᵢ γⁱ⋅vᵢ
let eval_joint = zip_with!((evals_scaled.into_iter(), powers.iter()), |e, g_i| e * g_i).sum();
let eval_joint = zip_with!((evals_scaled, powers.iter()), |e, g_i| e * g_i).sum();

Self {
c: comm_joint,
Expand All @@ -167,22 +159,9 @@ impl<E: Engine> PolyEvalInstance<E> {
}
}

fn batch(c_vec: &[Commitment<E>], x: &[E::Scalar], e_vec: &[E::Scalar], s: &E::Scalar) -> Self {
let num_instances = c_vec.len();
assert_eq!(e_vec.len(), num_instances);

let powers_of_s = powers::<E>(s, num_instances);
// Weighted sum of evaluations
let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum();
// Weighted sum of commitments
let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p)
.reduce(Commitment::<E>::default, |acc, item| acc + item);

Self {
c,
x: x.to_vec(),
e,
}
fn batch(c_vec: &[Commitment<E>], x: Vec<E::Scalar>, e_vec: &[E::Scalar], s: &E::Scalar) -> Self {
let sizes = vec![x.len(); e_vec.len()];
Self::batch_diff_size(c_vec, e_vec, &sizes, x, *s)
}
}

Expand Down Expand Up @@ -225,3 +204,124 @@ fn compute_eval_table_sparse<E: Engine>(

(A_evals, B_evals, C_evals)
}

#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
use super::*;
use crate::provider::PallasEngine;
use crate::r1cs::util::{FWrap, GWrap};
use pasta_curves::pallas::Point as PallasPoint;
use pasta_curves::Fq as Scalar;
use proptest::collection::vec;
use proptest::prelude::*;

impl<E: Engine> PolyEvalWitness<E> {
fn alt_batch(p_vec: &[&Vec<E::Scalar>], s: &E::Scalar) -> Self {
p_vec
.iter()
.skip(1)
.for_each(|p| assert_eq!(p.len(), p_vec[0].len()));

let powers_of_s = powers::<E>(s, p_vec.len());

let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * *weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
acc.into_iter().zip_eq(v).map(|(x, y)| x + y).collect()
},
);

Self { p }
}
}

impl<E: Engine> PolyEvalInstance<E> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now with generic PolyEvalInstance we can derive various implementations of it, right? Do we use that already (or have plans to create and use) and if so where concretely (more exactly - do we have / need alternative implementation which is different comparing to the previous reference one implemented at Solidity side of things)?

Copy link
Contributor Author

@huitseeker huitseeker Feb 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PolyEvalInstance has been generic since the beginning, and the type itself has not changed:
https://github.com/microsoft/Nova/blob/37871d04ba04771bf19c51e5b336e45fdb86c7f7/src/spartan/mod.rs#L117-L121
I'm not sure I understand the question?

The idea here is that batch_diff_size is a more general function that performs a re-scaling in case the passed-in arguments have a different # of variables. When the # of variables is the same for every argument, batch and batch_diff_size produce exactly the same output, which allows implementing one in terms of the other.

Copy link
Contributor

@storojs72 storojs72 Feb 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the question is actually - does new PolyEvalInstanse preserve compatibility with older one and where alt_batch is planning to be used (since as far as I see it is absent in upstream)

Copy link
Contributor Author

@huitseeker huitseeker Feb 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alt_batch isn't new, it's the moved, old implementation of batch. it is here in a tests module, gated by #[cfg(test)] and is not planned to be used anywhere. It is here specifically to ensure (through fuzzing unit tests) that PolyEval{Instance, Witness}::batch remains backward-compatible.

fn alt_batch(
c_vec: &[Commitment<E>],
x: Vec<E::Scalar>,
e_vec: &[E::Scalar],
s: &E::Scalar,
) -> Self {
let num_instances = c_vec.len();
assert_eq!(e_vec.len(), num_instances);

let powers_of_s = powers::<E>(s, num_instances);
// Weighted sum of evaluations
let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum();
// Weighted sum of commitments
let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p)
.reduce(Commitment::<E>::default, |acc, item| acc + item);

Self { c, x, e }
}
}

proptest! {
#[test]
fn test_pe_witness_batch_diff_size_batch(
s in any::<FWrap<Scalar>>(),
vecs in (50usize..100).prop_flat_map(|size| vec(
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size..=size), // even-sized vec
1..5))
)
{
// when the vectors are the same size, batch_diff_size and batch agree
let res = PolyEvalWitness::<PallasEngine>::alt_batch(&vecs.iter().by_ref().collect::<Vec<_>>(), &s.0);
let witnesses = vecs.iter().map(PolyEvalWitness::ref_cast).collect::<Vec<_>>();
let res2 = PolyEvalWitness::<PallasEngine>::batch_diff_size(&witnesses, s.0);

prop_assert_eq!(res.p, res2.p);
}

#[test]
fn test_pe_witness_batch_diff_size_pad_batch(
s in any::<FWrap<Scalar>>(),
vecs in (50usize..100).prop_flat_map(|size| vec(
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size-10..=size), // even-sized vec
1..10))
)
{
let size = vecs.iter().map(|v| v.len()).max().unwrap_or(0);
// when the vectors are not the same size, batch agrees with the padded version of the input
let padded_vecs = vecs.iter().cloned().map(|mut v| {v.resize(size, Scalar::ZERO); v}).collect::<Vec<_>>();
let res = PolyEvalWitness::<PallasEngine>::alt_batch(&padded_vecs.iter().by_ref().collect::<Vec<_>>(), &s.0);
let witnesses = vecs.iter().map(PolyEvalWitness::ref_cast).collect::<Vec<_>>();
let res2 = PolyEvalWitness::<PallasEngine>::batch_diff_size(&witnesses, s.0);

prop_assert_eq!(res.p, res2.p);
}

#[test]
fn test_pe_instance_batch_diff_size_batch(
s in any::<FWrap<Scalar>>(),
vecs_tuple in (50usize..100).prop_flat_map(|size|
(vec(any::<GWrap<PallasPoint>>().prop_map(|f| f.0), size..=size),
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size..=size),
vec(any::<FWrap<Scalar>>().prop_map(|f| f.0), size..=size)
), // even-sized vecs
)
)
{
let (c_vec, e_vec, x_vec) = vecs_tuple;
let c_vecs = c_vec.into_iter().map(|c| Commitment::<PallasEngine>{ comm: c }).collect::<Vec<_>>();
// when poly evals are all for the max # of variables, batch_diff_size and batch agree
let res = PolyEvalInstance::<PallasEngine>::alt_batch(
&c_vecs,
x_vec.clone(),
&e_vec,
&s.0);

let sizes = vec![x_vec.len(); x_vec.len()];
let res2 = PolyEvalInstance::<PallasEngine>::batch_diff_size(&c_vecs, &e_vec, &sizes, x_vec.clone(), s.0);

prop_assert_eq!(res.c, res2.c);
prop_assert_eq!(res.x, res2.x);
prop_assert_eq!(res.e, res2.e);
}
}
}
Loading