diff --git a/Cargo.toml b/Cargo.toml index 0277a2534..767178329 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,12 +19,13 @@ serde = "1.0.166" serde_json = "1.0.96" serde-big-array = "0.5.1" thiserror = "1.0.49" +hashbrown = "0.14.0" # plonky2-related dependencies -plonky2 = "0.2.2" +plonky2 = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" } plonky2_maybe_rayon = "0.2.0" -plonky2_util = "0.2.0" -starky = "0.4.0" +plonky2_util = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" } +starky = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "dc77c77f2b06500e16ad4d7f1c2b057903602eed" } [workspace.package] diff --git a/evm_arithmetization/Cargo.toml b/evm_arithmetization/Cargo.toml index 7835399a4..0a3894be1 100644 --- a/evm_arithmetization/Cargo.toml +++ b/evm_arithmetization/Cargo.toml @@ -36,7 +36,7 @@ rlp = { workspace = true } rlp-derive = { workspace = true } serde = { workspace = true, features = ["derive"] } static_assertions = "1.1.0" -hashbrown = { version = "0.14.0" } +hashbrown = { workspace = true } tiny-keccak = "2.0.2" serde_json = { workspace = true } serde-big-array = { workspace = true } diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index b36c5390a..886093b29 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -422,7 +422,7 @@ impl Interpreter { } pub(crate) fn run(&mut self) -> Result<(RegistersState, Option), anyhow::Error> { - let (final_registers, final_mem) = self.run_cpu(self.max_cpu_len_log, self.is_dummy)?; + let (final_registers, final_mem) = self.run_cpu(self.max_cpu_len_log)?; #[cfg(debug_assertions)] { diff --git a/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs b/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs index 0db8b13da..efb37bcd3 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/bignum/mod.rs @@ -9,7 +9,6 @@ use itertools::Itertools; use num::{BigUint, One, Zero}; use num_bigint::RandBigInt; use plonky2::field::goldilocks_field::GoldilocksField as F; -use plonky2_util::ceil_div_usize; use rand::Rng; use crate::cpu::kernel::aggregator::KERNEL; @@ -90,7 +89,7 @@ fn max_bignum(bit_size: usize) -> BigUint { } fn bignum_len(a: &BigUint) -> usize { - ceil_div_usize(a.bits() as usize, BIGNUM_LIMB_BITS) + (a.bits() as usize).div_ceil(BIGNUM_LIMB_BITS) } fn run_test(fn_label: &str, memory: Vec, stack: Vec) -> Result<(Vec, Vec)> { diff --git a/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs index a657a3e8e..e37407034 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/init_exc_stop.rs @@ -13,8 +13,12 @@ use crate::generation::state::State; use crate::generation::TrieInputs; use crate::generation::NUM_EXTRA_CYCLES_AFTER; use crate::generation::NUM_EXTRA_CYCLES_BEFORE; +use crate::memory::segments::Segment; use crate::proof::BlockMetadata; +use crate::proof::RegistersData; +use crate::proof::RegistersIdx; use crate::proof::TrieRoots; +use crate::witness::memory::MemoryAddress; use crate::witness::state::RegistersState; use crate::{proof::BlockHashes, GenerationInputs, Node}; @@ -89,19 +93,49 @@ fn test_init_exc_stop() { "Incorrect registers for dummy run." 
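Note on the `ceil_div_usize` removals that recur throughout this diff (here in `bignum_len`, later in `u256_to_trimmed_be_bytes`, `num_logic_ctls` and `PACKED_LEN`): the `plonky2_util` helper is replaced by the standard library's `usize::div_ceil`, stable since Rust 1.73. A minimal standalone sketch of the equivalence, not taken from the crate:

```rust
/// Old helper from `plonky2_util`, reproduced here for comparison only.
fn ceil_div_usize(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

fn main() {
    for (a, b) in [(0usize, 8usize), (1, 8), (8, 8), (9, 8), (255, 8), (256, 32)] {
        assert_eq!(a.div_ceil(b), ceil_div_usize(a, b));
    }
    // Plain ceiling division: ceil(300 / 128) = 3.
    assert_eq!(300usize.div_ceil(128), 3);
}
```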
); - let main_offset = KERNEL.global_labels["main"]; - let mut interpreter: Interpreter = - Interpreter::new_dummy_with_generation_inputs(initial_offset, vec![], &inputs); + let exc_stop_offset = KERNEL.global_labels["exc_stop"]; + + let pc_u256 = U256::from(interpreter.get_registers().program_counter); + let exit_info = pc_u256 + (U256::one() << 32); + interpreter.push(exit_info).unwrap(); + interpreter.get_mut_registers().program_counter = exc_stop_offset; interpreter.halt_offsets = vec![KERNEL.global_labels["halt_final"]]; interpreter.set_is_kernel(true); interpreter.clock = 0; + + // Set the program counter and `is_kernel` at the end of the execution. The + // `registers_before` and `registers_after` are stored contiguously in the + // `RegistersState` segment. We need to update `registers_after` here, hence the + // offset by `RegistersData::SIZE`. + let regs_to_set = [ + ( + MemoryAddress { + context: 0, + segment: Segment::RegistersStates.unscale(), + virt: RegistersData::SIZE + RegistersIdx::ProgramCounter as usize, + }, + pc_u256, + ), + ( + MemoryAddress { + context: 0, + segment: Segment::RegistersStates.unscale(), + virt: RegistersData::SIZE + RegistersIdx::IsKernel as usize, + }, + U256::one(), + ), + ]; + interpreter.set_memory_multi_addresses(®s_to_set); + interpreter.run().expect("Running dummy exc_stop failed."); - // The "-1" comes from the fact that we stop 1 cycle before the max, to allow - // for one padding row, which is needed for CPU STARK. + // The "-2" comes from the fact that: + // - we stop 1 cycle before the max, to allow for one padding row, which is + // needed for CPU STARK. + // - we need one additional cycle to enter `exc_stop`. assert_eq!( interpreter.get_clock(), - NUM_EXTRA_CYCLES_BEFORE + NUM_EXTRA_CYCLES_AFTER - 1, + NUM_EXTRA_CYCLES_AFTER - 2, "NUM_EXTRA_CYCLES_AFTER is set incorrectly." ); } diff --git a/evm_arithmetization/src/cpu/kernel/utils.rs b/evm_arithmetization/src/cpu/kernel/utils.rs index adda086e8..082086d17 100644 --- a/evm_arithmetization/src/cpu/kernel/utils.rs +++ b/evm_arithmetization/src/cpu/kernel/utils.rs @@ -1,7 +1,6 @@ use core::fmt::Debug; use ethereum_types::U256; -use plonky2_util::ceil_div_usize; /// Enumerate the length `W` windows of `vec`, and run `maybe_replace` on each /// one. @@ -28,7 +27,7 @@ where } pub(crate) fn u256_to_trimmed_be_bytes(u256: &U256) -> Vec { - let num_bytes = ceil_div_usize(u256.bits(), 8); + let num_bytes = u256.bits().div_ceil(8); // `byte` is little-endian, so we manually reverse it. 
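On the address arithmetic in the `init_exc_stop` test above: `registers_before` and `registers_after` are stored back-to-back in the `RegistersStates` segment, each `RegistersData::SIZE` (= 6) words wide and indexed by `RegistersIdx`, both introduced in `proof.rs` later in this diff. A standalone sketch that mirrors those constants instead of importing the crate:

```rust
// Values mirror `RegistersData::SIZE` and `RegistersIdx` from `proof.rs` below;
// nothing here is imported from the crate.
const REGISTERS_DATA_SIZE: usize = 6;

#[allow(dead_code)]
#[derive(Clone, Copy)]
enum RegistersIdx {
    ProgramCounter = 0,
    IsKernel = 1,
    StackLen = 2,
    StackTop = 3,
    Context = 4,
    GasUsed = 5,
}

/// Virtual offset of a `registers_after` field: it sits right after `registers_before`.
fn registers_after_offset(idx: RegistersIdx) -> usize {
    REGISTERS_DATA_SIZE + idx as usize
}

fn main() {
    // The test writes the final program counter at offset 6 and `is_kernel` at offset 7.
    assert_eq!(registers_after_offset(RegistersIdx::ProgramCounter), 6);
    assert_eq!(registers_after_offset(RegistersIdx::IsKernel), 7);
}
```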
(0..num_bytes).rev().map(|i| u256.byte(i)).collect() } diff --git a/evm_arithmetization/src/fixed_recursive_verifier.rs b/evm_arithmetization/src/fixed_recursive_verifier.rs index 7daf76bc0..59035b909 100644 --- a/evm_arithmetization/src/fixed_recursive_verifier.rs +++ b/evm_arithmetization/src/fixed_recursive_verifier.rs @@ -15,7 +15,7 @@ use plonky2::fri::oracle::PolynomialBatch; use plonky2::fri::FriParams; use plonky2::gates::constant::ConstantGate; use plonky2::gates::noop::NoopGate; -use plonky2::hash::hash_types::{MerkleCapTarget, RichField}; +use plonky2::hash::hash_types::{MerkleCapTarget, RichField, NUM_HASH_OUT_ELTS}; use plonky2::iop::challenger::RecursiveChallenger; use plonky2::iop::target::{BoolTarget, Target}; use plonky2::iop::witness::{PartialWitness, WitnessWrite}; @@ -49,7 +49,10 @@ use crate::proof::{ FinalPublicValues, MemCapTarget, PublicValues, PublicValuesTarget, RegistersDataTarget, TrieRoots, TrieRootsTarget, TARGET_HASH_SIZE, }; -use crate::prover::{check_abort_signal, generate_all_data_segments, prove, GenerationSegmentData}; +use crate::prover::{ + check_abort_signal, generate_all_data_segments, prove, GenerationSegmentData, + SegmentDataIterator, +}; use crate::recursive_verifier::{ add_common_recursion_gates, add_virtual_public_values, get_memory_extra_looking_sum_circuit, recursive_stark_circuit, set_public_value_targets, PlonkWrapperCircuit, PublicInputs, @@ -64,12 +67,14 @@ use crate::witness::state::RegistersState; /// this size. const THRESHOLD_DEGREE_BITS: usize = 13; +#[derive(Clone)] pub struct ProverOutputData where F: RichField + Extendable, C: GenericConfig, C::Hasher: AlgebraicHasher, { + pub is_dummy: bool, pub proof_with_pis: ProofWithPublicInputs, pub public_values: PublicValues, } @@ -187,7 +192,7 @@ where { pub circuit: CircuitData, lhs: AggregationChildTarget, - rhs: AggregationChildTarget, + rhs: AggregationChildWithDummyTarget, public_values: PublicValuesTarget, cyclic_vk: VerifierCircuitTarget, } @@ -220,7 +225,7 @@ where let cyclic_vk = buffer.read_target_verifier_circuit()?; let public_values = PublicValuesTarget::from_buffer(buffer)?; let lhs = AggregationChildTarget::from_buffer(buffer)?; - let rhs = AggregationChildTarget::from_buffer(buffer)?; + let rhs = AggregationChildWithDummyTarget::from_buffer(buffer)?; Ok(Self { circuit, lhs, @@ -272,6 +277,52 @@ impl AggregationChildTarget { } } +#[derive(Eq, PartialEq, Debug)] +struct AggregationChildWithDummyTarget { + is_agg: BoolTarget, + is_dummy: BoolTarget, + agg_proof: ProofWithPublicInputsTarget, + real_proof: ProofWithPublicInputsTarget, +} + +impl AggregationChildWithDummyTarget { + fn to_buffer(&self, buffer: &mut Vec) -> IoResult<()> { + buffer.write_target_bool(self.is_agg)?; + buffer.write_target_bool(self.is_dummy)?; + buffer.write_target_proof_with_public_inputs(&self.agg_proof)?; + buffer.write_target_proof_with_public_inputs(&self.real_proof)?; + Ok(()) + } + + fn from_buffer(buffer: &mut Buffer) -> IoResult { + let is_agg = buffer.read_target_bool()?; + let is_dummy = buffer.read_target_bool()?; + let agg_proof = buffer.read_target_proof_with_public_inputs()?; + let real_proof = buffer.read_target_proof_with_public_inputs()?; + Ok(Self { + is_agg, + is_dummy, + agg_proof, + real_proof, + }) + } + + // `len_mem_cap` is the length of the Merkle + // caps for `MemBefore` and `MemAfter`. 
+ fn public_values>( + &self, + builder: &mut CircuitBuilder, + len_mem_cap: usize, + ) -> PublicValuesTarget { + let agg_pv = + PublicValuesTarget::from_public_inputs(&self.agg_proof.public_inputs, len_mem_cap); + let segment_pv = + PublicValuesTarget::from_public_inputs(&self.real_proof.public_inputs, len_mem_cap); + + PublicValuesTarget::select(builder, self.is_agg, agg_pv, segment_pv) + } +} + /// Data for the transaction aggregation circuit, which is used to compress two /// proofs into one. Each inner proof can be either a segment aggregation proof /// or another transaction aggregation proof. @@ -802,14 +853,20 @@ where let public_values = add_virtual_public_values(&mut builder, cap_before_len); let cyclic_vk = builder.add_verifier_data_public_inputs(); + // The right hand side child might be dummy. let lhs_segment = Self::add_segment_agg_child(&mut builder, root); - let rhs_segment = Self::add_segment_agg_child(&mut builder, root); + let rhs_segment = + Self::add_segment_agg_child_with_dummy(&mut builder, root, lhs_segment.proof.clone()); let lhs_pv = lhs_segment.public_values(&mut builder, cap_before_len); let rhs_pv = rhs_segment.public_values(&mut builder, cap_before_len); - // All the block metadata is the same for both segments. It is also the case for - // extra_block_data. + let is_dummy = rhs_segment.is_dummy; + let one = builder.one(); + let is_not_dummy = builder.sub(one, is_dummy.target); + let is_not_dummy = BoolTarget::new_unsafe(is_not_dummy); + + // Always connect the lhs to the aggregation public values. TrieRootsTarget::connect( &mut builder, public_values.trie_roots_before, @@ -818,76 +875,111 @@ where TrieRootsTarget::connect( &mut builder, public_values.trie_roots_after, - rhs_pv.trie_roots_after, + lhs_pv.trie_roots_after, ); - TrieRootsTarget::connect( + BlockMetadataTarget::connect( + &mut builder, + public_values.block_metadata, + lhs_pv.block_metadata, + ); + BlockHashesTarget::connect( + &mut builder, + public_values.block_hashes, + lhs_pv.block_hashes, + ); + ExtraBlockDataTarget::connect( &mut builder, + public_values.extra_block_data, + lhs_pv.extra_block_data, + ); + RegistersDataTarget::connect( + &mut builder, + public_values.registers_before.clone(), + lhs_pv.registers_before.clone(), + ); + MemCapTarget::connect( + &mut builder, + public_values.mem_before.clone(), + lhs_pv.mem_before.clone(), + ); + + // If the rhs is a real proof, all the block metadata must be the same for both + // segments. It is also the case for the extra block data. 
+ TrieRootsTarget::conditional_assert_eq( + &mut builder, + is_not_dummy, public_values.trie_roots_before, rhs_pv.trie_roots_before, ); - TrieRootsTarget::connect( + TrieRootsTarget::conditional_assert_eq( &mut builder, + is_not_dummy, public_values.trie_roots_after, - lhs_pv.trie_roots_after, + rhs_pv.trie_roots_after, ); - BlockMetadataTarget::connect( + BlockMetadataTarget::conditional_assert_eq( &mut builder, + is_not_dummy, public_values.block_metadata, rhs_pv.block_metadata, ); - BlockMetadataTarget::connect( - &mut builder, - public_values.block_metadata, - lhs_pv.block_metadata, - ); - BlockHashesTarget::connect( + BlockHashesTarget::conditional_assert_eq( &mut builder, + is_not_dummy, public_values.block_hashes, rhs_pv.block_hashes, ); - BlockHashesTarget::connect( - &mut builder, - public_values.block_hashes, - lhs_pv.block_hashes, - ); - ExtraBlockDataTarget::connect( + ExtraBlockDataTarget::conditional_assert_eq( &mut builder, + is_not_dummy, public_values.extra_block_data, rhs_pv.extra_block_data, ); - ExtraBlockDataTarget::connect( - &mut builder, - public_values.extra_block_data, - lhs_pv.extra_block_data, - ); - // Connect registers and merkle caps between segments. - RegistersDataTarget::connect( + // If the rhs is a real proof: Connect registers and merkle caps between + // segments. + RegistersDataTarget::conditional_assert_eq( &mut builder, + is_not_dummy, public_values.registers_after.clone(), rhs_pv.registers_after.clone(), ); - RegistersDataTarget::connect( + RegistersDataTarget::conditional_assert_eq( &mut builder, - public_values.registers_before.clone(), - lhs_pv.registers_before.clone(), + is_not_dummy, + lhs_pv.registers_after.clone(), + rhs_pv.registers_before.clone(), ); - RegistersDataTarget::connect( + MemCapTarget::conditional_assert_eq( &mut builder, - lhs_pv.registers_after, - rhs_pv.registers_before.clone(), + is_not_dummy, + public_values.mem_after.clone(), + rhs_pv.mem_after.clone(), ); - MemCapTarget::connect( + MemCapTarget::conditional_assert_eq( &mut builder, - public_values.mem_before.clone(), - lhs_pv.mem_before.clone(), + is_not_dummy, + lhs_pv.mem_after.clone(), + rhs_pv.mem_before.clone(), ); - MemCapTarget::connect( + + // If the rhs is a dummy, then the lhs must be a segment. + let constr = builder.mul(is_dummy.target, lhs_segment.is_agg.target); + builder.assert_zero(constr); + + // If the rhs is a dummy, then the aggregation PVs are equal to the lhs PVs. + MemCapTarget::conditional_assert_eq( &mut builder, + is_dummy, public_values.mem_after.clone(), - rhs_pv.mem_after, + lhs_pv.mem_after, + ); + RegistersDataTarget::conditional_assert_eq( + &mut builder, + is_dummy, + public_values.registers_after.clone(), + lhs_pv.registers_after, ); - MemCapTarget::connect(&mut builder, lhs_pv.mem_after, rhs_pv.mem_before); // Pad to match the root circuit's degree. 
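Review note on the `conditional_assert_eq` wiring above: each call enforces equality only when its boolean flag is set, which is what allows the right-hand child to be a dummy. Semantically this is equivalent to the standard `cond * (a - b) = 0` gadget; a standalone sketch of that semantics over plain integers (not circuit code, and not the crate's implementation):

```rust
// Plain-integer stand-in for field elements, just to show the semantics.
fn conditional_assert_eq(cond: i64, a: i64, b: i64) -> bool {
    debug_assert!(cond == 0 || cond == 1);
    cond * (a - b) == 0
}

fn main() {
    assert!(conditional_assert_eq(1, 5, 5));  // real rhs: values must match
    assert!(!conditional_assert_eq(1, 5, 7)); // real rhs with a mismatch is rejected
    assert!(conditional_assert_eq(0, 5, 7));  // dummy rhs: no constraint at all
}
```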
while log2_ceil(builder.num_gates()) < root.circuit.common.degree_bits() { @@ -1186,6 +1278,36 @@ where } } + fn add_segment_agg_child_with_dummy( + builder: &mut CircuitBuilder, + root: &RootCircuitData, + dummy_proof: ProofWithPublicInputsTarget, + ) -> AggregationChildWithDummyTarget { + let common = &root.circuit.common; + let root_vk = builder.constant_verifier_data(&root.circuit.verifier_only); + let is_agg = builder.add_virtual_bool_target_safe(); + let agg_proof = builder.add_virtual_proof_with_pis(common); + let is_dummy = builder.add_virtual_bool_target_safe(); + let real_proof = builder.add_virtual_proof_with_pis(common); + + let segment_proof = builder.select_proof_with_pis(is_dummy, &dummy_proof, &real_proof); + builder + .conditionally_verify_cyclic_proof::( + is_agg, + &agg_proof, + &segment_proof, + &root_vk, + common, + ) + .expect("Failed to build cyclic recursion circuit"); + AggregationChildWithDummyTarget { + is_agg, + is_dummy, + agg_proof, + real_proof, + } + } + fn add_txn_agg_child( builder: &mut CircuitBuilder, segment_agg: &SegmentAggregationCircuitData, @@ -1424,6 +1546,7 @@ where let root_proof = self.root.circuit.prove(root_inputs)?; Ok(ProverOutputData { + is_dummy: false, proof_with_pis: root_proof, public_values: all_proof.public_values, }) @@ -1440,21 +1563,34 @@ where timing: &mut TimingTree, abort_signal: Option>, ) -> anyhow::Result>> { - let mut all_data_segments = - generate_all_data_segments::(Some(max_cpu_len_log), &generation_inputs)?; - let mut proofs = Vec::with_capacity(all_data_segments.len()); - for mut data in all_data_segments { + let mut it_segment_data = SegmentDataIterator { + inputs: &generation_inputs, + partial_next_data: None, + max_cpu_len_log: Some(max_cpu_len_log), + }; + + let mut proofs = vec![]; + + for mut next_data in it_segment_data { let proof = self.prove_segment( all_stark, config, generation_inputs.clone(), - &mut data, + &mut next_data.1, timing, abort_signal.clone(), )?; proofs.push(proof); } + // Since aggregations require at least two segment proofs, add a dummy proof if + // there is only one proof. + if proofs.len() == 1 { + let mut first_proof = proofs[0].clone(); + first_proof.is_dummy = true; + proofs.push(first_proof); + } + Ok(proofs) } @@ -1567,34 +1703,29 @@ where /// /// - `lhs_is_agg`: a boolean indicating whether the left child proof is an /// aggregation proof or a regular segment proof. - /// - `lhs_proof`: the left child proof. - /// - `lhs_public_values`: the public values associated to the right child - /// proof. + /// - `lhs_proof`: the left child prover output data. /// - `rhs_is_agg`: a boolean indicating whether the right child proof is an /// aggregation proof or a regular transaction proof. - /// - `rhs_proof`: the right child proof. - /// - `rhs_public_values`: the public values associated to the right child - /// proof. + /// - `rhs_proof`: the right child prover output data. /// /// # Outputs /// - /// This method outputs a tuple of [`ProofWithPublicInputs`] and - /// its [`PublicValues`]. Only the proof with public inputs is necessary - /// for a verifier to assert correctness of the computation, - /// but the public values are output for the prover convenience, as these - /// are necessary during proof aggregation. + /// This method outputs a [`ProverOutputData`]. 
Only the proof with + /// public inputs is necessary for a verifier to assert correctness of + /// the computation, but the public values and `is_dummy` are output for the + /// prover convenience, as these are necessary during proof aggregation. pub fn prove_segment_aggregation( &self, lhs_is_agg: bool, - lhs_proof: &ProofWithPublicInputs, - lhs_public_values: PublicValues, - + lhs_prover_output: &ProverOutputData, rhs_is_agg: bool, - rhs_proof: &ProofWithPublicInputs, - rhs_public_values: PublicValues, - ) -> anyhow::Result<(ProofWithPublicInputs, PublicValues)> { + rhs_prover_output: &ProverOutputData, + ) -> anyhow::Result> { let mut agg_inputs = PartialWitness::new(); + let lhs_proof = &lhs_prover_output.proof_with_pis; + let rhs_proof = &rhs_prover_output.proof_with_pis; + let rhs_is_dummy = rhs_prover_output.is_dummy; Self::set_dummy_if_necessary( &self.segment_aggregation.lhs, lhs_is_agg, @@ -1603,12 +1734,25 @@ where lhs_proof, ); - Self::set_dummy_if_necessary( + let len_mem_cap = self + .segment_aggregation + .public_values + .mem_before + .mem_cap + .0 + .len(); + + // If rhs is dummy, the rhs proof is also set to be the lhs. + let real_rhs_proof = if rhs_is_dummy { lhs_proof } else { rhs_proof }; + + Self::set_dummy_if_necessary_with_dummy( &self.segment_aggregation.rhs, rhs_is_agg, + rhs_is_dummy, &self.segment_aggregation.circuit, &mut agg_inputs, - rhs_proof, + real_rhs_proof, + len_mem_cap, ); agg_inputs.set_verifier_data_target( @@ -1617,24 +1761,34 @@ where ); // Aggregates both `PublicValues` from the provided proofs into a single one. + + let lhs_public_values = &lhs_prover_output.public_values; + let rhs_public_values = &rhs_prover_output.public_values; + + let real_public_values = if rhs_is_dummy { + lhs_public_values.clone() + } else { + rhs_public_values.clone() + }; + let agg_public_values = PublicValues { - trie_roots_before: lhs_public_values.trie_roots_before, - trie_roots_after: rhs_public_values.trie_roots_after, + trie_roots_before: lhs_public_values.trie_roots_before.clone(), + trie_roots_after: real_public_values.trie_roots_after, extra_block_data: ExtraBlockData { checkpoint_state_trie_root: lhs_public_values .extra_block_data .checkpoint_state_trie_root, txn_number_before: lhs_public_values.extra_block_data.txn_number_before, - txn_number_after: rhs_public_values.extra_block_data.txn_number_after, + txn_number_after: real_public_values.extra_block_data.txn_number_after, gas_used_before: lhs_public_values.extra_block_data.gas_used_before, - gas_used_after: rhs_public_values.extra_block_data.gas_used_after, + gas_used_after: real_public_values.extra_block_data.gas_used_after, }, - block_metadata: rhs_public_values.block_metadata, - block_hashes: rhs_public_values.block_hashes, - registers_before: lhs_public_values.registers_before, - registers_after: rhs_public_values.registers_after, - mem_before: lhs_public_values.mem_before, - mem_after: rhs_public_values.mem_after, + block_metadata: real_public_values.block_metadata, + block_hashes: real_public_values.block_hashes, + registers_before: lhs_public_values.registers_before.clone(), + registers_after: real_public_values.registers_after, + mem_before: lhs_public_values.mem_before.clone(), + mem_after: real_public_values.mem_after, }; set_public_value_targets( @@ -1647,7 +1801,12 @@ where })?; let aggregation_proof = self.segment_aggregation.circuit.prove(agg_inputs)?; - Ok((aggregation_proof, agg_public_values)) + let agg_output = ProverOutputData { + is_dummy: false, + proof_with_pis: aggregation_proof, + 
public_values: agg_public_values, + }; + Ok(agg_output) } pub fn verify_segment_aggregation( @@ -1814,6 +1973,34 @@ where agg_inputs.set_proof_with_pis_target(&agg_child.proof, proof); } + /// If the proof is not an aggregation, we set the cyclic vk to a dummy + /// value, so that it corresponds to the aggregation cyclic vk. If the proof + /// is dummy, we set `is_dummy` to `true`. Note that only the rhs can be + /// dummy. + fn set_dummy_if_necessary_with_dummy( + agg_child: &AggregationChildWithDummyTarget, + is_agg: bool, + is_dummy: bool, + circuit: &CircuitData, + agg_inputs: &mut PartialWitness, + proof: &ProofWithPublicInputs, + len_mem_cap: usize, + ) { + agg_inputs.set_bool_target(agg_child.is_agg, is_agg); + agg_inputs.set_bool_target(agg_child.is_dummy, is_dummy); + if is_agg { + agg_inputs.set_proof_with_pis_target(&agg_child.agg_proof, proof); + } else { + Self::set_dummy_proof_with_cyclic_vk_pis( + circuit, + agg_inputs, + &agg_child.agg_proof, + proof, + ); + } + agg_inputs.set_proof_with_pis_target(&agg_child.real_proof, proof); + } + /// Create a final block proof, once all transactions of a given block have /// been combined into a single aggregation proof. /// diff --git a/evm_arithmetization/src/generation/mod.rs b/evm_arithmetization/src/generation/mod.rs index 387072220..e79099e6a 100644 --- a/evm_arithmetization/src/generation/mod.rs +++ b/evm_arithmetization/src/generation/mod.rs @@ -474,7 +474,7 @@ fn simulate_cpu( max_cpu_len_log: Option, is_dummy: bool, ) -> anyhow::Result<(RegistersState, Option)> { - let (final_registers, mem_after) = state.run_cpu(max_cpu_len_log, is_dummy)?; + let (final_registers, mem_after) = state.run_cpu(max_cpu_len_log)?; let pc = state.registers.program_counter; // Setting the values of padding rows. diff --git a/evm_arithmetization/src/generation/state.rs b/evm_arithmetization/src/generation/state.rs index cabbb60f6..36108bf0b 100644 --- a/evm_arithmetization/src/generation/state.rs +++ b/evm_arithmetization/src/generation/state.rs @@ -77,13 +77,9 @@ pub(crate) trait State { fn get_context(&self) -> usize; /// Checks whether we have reached the maximal cpu length. - fn at_end_segment(&self, opt_max_cpu_len: Option, is_dummy: bool) -> bool { - if let Some(max_cpu_len_log) = opt_max_cpu_len { - if is_dummy { - self.get_clock() == max_cpu_len_log - NUM_EXTRA_CYCLES_AFTER - } else { - self.get_clock() == (1 << max_cpu_len_log) - NUM_EXTRA_CYCLES_AFTER - } + fn at_end_segment(&self, opt_cycle_limit: Option) -> bool { + if let Some(cycle_limit) = opt_cycle_limit { + self.get_clock() == cycle_limit } else { false } @@ -92,10 +88,8 @@ pub(crate) trait State { /// Checks whether we have reached the `halt` label in kernel mode. fn at_halt(&self) -> bool { let halt = KERNEL.global_labels["halt"]; - let halt_final = KERNEL.global_labels["halt_final"]; let registers = self.get_registers(); - registers.is_kernel - && (registers.program_counter == halt || registers.program_counter == halt_final) + registers.is_kernel && (registers.program_counter == halt) } /// Returns the context in which the jumpdest analysis should end. 
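On the simplified `at_end_segment` above: the `is_dummy` flag is gone, and `run_cpu` (next hunk) precomputes a single cycle limit of `(1 << max_cpu_len_log) - NUM_EXTRA_CYCLES_AFTER`. A standalone sketch of the check; `80` below is a placeholder, not the crate's actual `NUM_EXTRA_CYCLES_AFTER`:

```rust
// `NUM_EXTRA_CYCLES_AFTER` is a crate constant; 80 below is only a placeholder.
fn cycle_limit(max_cpu_len_log: Option<usize>, num_extra_cycles_after: usize) -> Option<usize> {
    max_cpu_len_log.map(|log| (1usize << log) - num_extra_cycles_after)
}

fn at_end_segment(clock: usize, opt_cycle_limit: Option<usize>) -> bool {
    opt_cycle_limit.map_or(false, |limit| clock == limit)
}

fn main() {
    let limit = cycle_limit(Some(20), 80);
    assert_eq!(limit, Some((1 << 20) - 80));
    assert!(at_end_segment((1 << 20) - 80, limit));
    assert!(!at_end_segment(12_345, None)); // no limit: run until `halt`
}
```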
@@ -176,13 +170,15 @@ pub(crate) trait State { fn run_cpu( &mut self, max_cpu_len_log: Option, - is_dummy: bool, ) -> anyhow::Result<(RegistersState, Option)> where Self: Transition, { let halt_offsets = self.get_halt_offsets(); + let cycle_limit = + max_cpu_len_log.map(|max_len_log| (1 << max_len_log) - NUM_EXTRA_CYCLES_AFTER); + let mut final_registers = RegistersState::default(); let final_mem = self.get_active_memory(); let mut running = true; @@ -192,7 +188,7 @@ pub(crate) trait State { let pc = registers.program_counter; let halt_final = registers.is_kernel && halt_offsets.contains(&pc); - if running && (self.at_halt() || self.at_end_segment(max_cpu_len_log, is_dummy)) { + if running && (self.at_halt() || self.at_end_segment(cycle_limit)) { running = false; final_registers = registers; diff --git a/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs b/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs index db1efae99..73418c40f 100644 --- a/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs +++ b/evm_arithmetization/src/keccak_sponge/keccak_sponge_stark.rs @@ -13,7 +13,6 @@ use plonky2::iop::ext_target::ExtensionTarget; use plonky2::timed; use plonky2::util::timing::TimingTree; use plonky2::util::transpose; -use plonky2_util::ceil_div_usize; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::evaluation_frame::StarkEvaluationFrame; use starky::lookup::{Column, Filter, Lookup}; @@ -137,7 +136,7 @@ pub(crate) fn ctl_looking_memory(i: usize) -> Vec> { /// Returns the number of `KeccakSponge` tables looking into the `LogicStark`. pub(crate) const fn num_logic_ctls() -> usize { const U8S_PER_CTL: usize = 32; - ceil_div_usize(KECCAK_RATE_BYTES, U8S_PER_CTL) + KECCAK_RATE_BYTES.div_ceil(U8S_PER_CTL) } /// Creates the vector of `Columns` required to perform the `i`th logic CTL. diff --git a/evm_arithmetization/src/logic.rs b/evm_arithmetization/src/logic.rs index c5f952465..be389450c 100644 --- a/evm_arithmetization/src/logic.rs +++ b/evm_arithmetization/src/logic.rs @@ -10,7 +10,6 @@ use plonky2::hash::hash_types::RichField; use plonky2::iop::ext_target::ExtensionTarget; use plonky2::timed; use plonky2::util::timing::TimingTree; -use plonky2_util::ceil_div_usize; use starky::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer}; use starky::evaluation_frame::StarkEvaluationFrame; use starky::lookup::{Column, Filter}; @@ -28,7 +27,7 @@ const VAL_BITS: usize = 256; pub(crate) const PACKED_LIMB_BITS: usize = 32; /// Number of field elements needed to store each input/output at the specified /// packing. -const PACKED_LEN: usize = ceil_div_usize(VAL_BITS, PACKED_LIMB_BITS); +const PACKED_LEN: usize = VAL_BITS.div_ceil(PACKED_LIMB_BITS); /// `LogicStark` columns. pub(crate) mod columns { diff --git a/evm_arithmetization/src/proof.rs b/evm_arithmetization/src/proof.rs index 2aa73f072..f946b63cf 100644 --- a/evm_arithmetization/src/proof.rs +++ b/evm_arithmetization/src/proof.rs @@ -378,7 +378,19 @@ pub struct RegistersData { /// Gas used so far. 
pub gas_used: U256, } + +pub(crate) enum RegistersIdx { + ProgramCounter = 0, + IsKernel = 1, + StackLen = 2, + StackTop = 3, + Context = 4, + GasUsed = 5, +} + impl RegistersData { + pub(crate) const SIZE: usize = 6; + pub fn from_public_inputs(pis: &[F]) -> Self { assert!(pis.len() == RegistersDataTarget::SIZE); @@ -859,6 +871,28 @@ impl TrieRootsTarget { builder.connect(tr0.receipts_root[i], tr1.receipts_root[i]); } } + + /// If `condition`, asserts that `tr0 == tr1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + tr0: Self, + tr1: Self, + ) { + for i in 0..8 { + builder.conditional_assert_eq(condition.target, tr0.state_root[i], tr1.state_root[i]); + builder.conditional_assert_eq( + condition.target, + tr0.transactions_root[i], + tr1.transactions_root[i], + ); + builder.conditional_assert_eq( + condition.target, + tr0.receipts_root[i], + tr1.receipts_root[i], + ); + } + } } /// Circuit version of `BlockMetadata`. @@ -979,6 +1013,45 @@ impl BlockMetadataTarget { builder.connect(bm0.block_bloom[i], bm1.block_bloom[i]) } } + + /// If `condition`, asserts that `bm0 == bm1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + bm0: Self, + bm1: Self, + ) { + for i in 0..5 { + builder.conditional_assert_eq( + condition.target, + bm0.block_beneficiary[i], + bm1.block_beneficiary[i], + ); + } + builder.conditional_assert_eq(condition.target, bm0.block_timestamp, bm1.block_timestamp); + builder.conditional_assert_eq(condition.target, bm0.block_number, bm1.block_number); + builder.conditional_assert_eq(condition.target, bm0.block_difficulty, bm1.block_difficulty); + for i in 0..8 { + builder.conditional_assert_eq( + condition.target, + bm0.block_random[i], + bm1.block_random[i], + ); + } + builder.conditional_assert_eq(condition.target, bm0.block_gaslimit, bm1.block_gaslimit); + builder.conditional_assert_eq(condition.target, bm0.block_chain_id, bm1.block_chain_id); + for i in 0..2 { + builder.conditional_assert_eq( + condition.target, + bm0.block_base_fee[i], + bm1.block_base_fee[i], + ) + } + builder.conditional_assert_eq(condition.target, bm0.block_gas_used, bm1.block_gas_used); + for i in 0..64 { + builder.conditional_assert_eq(condition.target, bm0.block_bloom[i], bm1.block_bloom[i]) + } + } } /// Circuit version of `BlockHashes`. @@ -1044,6 +1117,21 @@ impl BlockHashesTarget { builder.connect(bm0.cur_hash[i], bm1.cur_hash[i]); } } + + /// If `condition`, asserts that `bm0 == bm1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + bm0: Self, + bm1: Self, + ) { + for i in 0..2048 { + builder.conditional_assert_eq(condition.target, bm0.prev_hashes[i], bm1.prev_hashes[i]); + } + for i in 0..8 { + builder.conditional_assert_eq(condition.target, bm0.cur_hash[i], bm1.cur_hash[i]); + } + } } /// Circuit version of `ExtraBlockData`. @@ -1070,7 +1158,7 @@ pub struct ExtraBlockDataTarget { impl ExtraBlockDataTarget { /// Number of `Target`s required for the extra block data. - const SIZE: usize = 12; + pub const SIZE: usize = 12; /// Extracts the extra block data `Target`s from the public input `Target`s. /// The provided `pis` should start with the extra vblock data. @@ -1135,6 +1223,30 @@ impl ExtraBlockDataTarget { builder.connect(ed0.gas_used_before, ed1.gas_used_before); builder.connect(ed0.gas_used_after, ed1.gas_used_after); } + + /// If `condition`, asserts that `ed0 == ed1`. 
+ pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + ed0: Self, + ed1: Self, + ) { + for i in 0..8 { + builder.conditional_assert_eq( + condition.target, + ed0.checkpoint_state_trie_root[i], + ed1.checkpoint_state_trie_root[i], + ); + } + builder.conditional_assert_eq( + condition.target, + ed0.txn_number_before, + ed1.txn_number_before, + ); + builder.conditional_assert_eq(condition.target, ed0.txn_number_after, ed1.txn_number_after); + builder.conditional_assert_eq(condition.target, ed0.gas_used_before, ed1.gas_used_before); + builder.conditional_assert_eq(condition.target, ed0.gas_used_after, ed1.gas_used_after); + } } /// Circuit version of `RegistersData`. @@ -1158,7 +1270,7 @@ pub struct RegistersDataTarget { impl RegistersDataTarget { /// Number of `Target`s required for the extra block data. - const SIZE: usize = 13; + pub const SIZE: usize = 13; /// Extracts the extra block data `Target`s from the public input `Target`s. /// The provided `pis` should start with the extra vblock data. @@ -1216,6 +1328,23 @@ impl RegistersDataTarget { builder.connect(rd0.context, rd1.context); builder.connect(rd0.gas_used, rd1.gas_used); } + + /// If `condition`, asserts that `rd0 == rd1`. + pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + rd0: Self, + rd1: Self, + ) { + builder.conditional_assert_eq(condition.target, rd0.program_counter, rd1.program_counter); + builder.conditional_assert_eq(condition.target, rd0.is_kernel, rd1.is_kernel); + builder.conditional_assert_eq(condition.target, rd0.stack_len, rd1.stack_len); + for i in 0..8 { + builder.conditional_assert_eq(condition.target, rd0.stack_top[i], rd1.stack_top[i]); + } + builder.conditional_assert_eq(condition.target, rd0.context, rd1.context); + builder.conditional_assert_eq(condition.target, rd0.gas_used, rd1.gas_used); + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -1283,4 +1412,22 @@ impl MemCapTarget { } } } + + /// If `condition`, asserts that `mc0 == mc1`. 
+ pub(crate) fn conditional_assert_eq, const D: usize>( + builder: &mut CircuitBuilder, + condition: BoolTarget, + mc0: Self, + mc1: Self, + ) { + for i in 0..mc0.mem_cap.0.len() { + for j in 0..NUM_HASH_OUT_ELTS { + builder.conditional_assert_eq( + condition.target, + mc0.mem_cap.0[i].elements[j], + mc1.mem_cap.0[i].elements[j], + ); + } + } + } } diff --git a/evm_arithmetization/src/prover.rs b/evm_arithmetization/src/prover.rs index 0858a1109..a660d88bf 100644 --- a/evm_arithmetization/src/prover.rs +++ b/evm_arithmetization/src/prover.rs @@ -5,6 +5,7 @@ use anyhow::{anyhow, Result}; use itertools::Itertools; use once_cell::sync::Lazy; use plonky2::field::extension::Extendable; +use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::polynomial::PolynomialValues; use plonky2::field::types::Field; use plonky2::fri::oracle::PolynomialBatch; @@ -25,7 +26,7 @@ use starky::stark::Stark; use crate::all_stark::{AllStark, Table, NUM_TABLES}; use crate::cpu::kernel::aggregator::KERNEL; use crate::cpu::kernel::interpreter::{set_registers_and_run, ExtraSegmentData, Interpreter}; -use crate::generation::state::GenerationState; +use crate::generation::state::{GenerationState, State}; use crate::generation::{generate_traces, GenerationInputs}; use crate::get_challenges::observe_public_values; use crate::memory::segments::Segment; @@ -54,6 +55,11 @@ pub struct GenerationSegmentData { } impl GenerationSegmentData { + /// Indicates if this segment is a dummy one. + pub fn is_dummy(&self) -> bool { + self.is_dummy + } + /// Retrieves the index of this segment. pub fn segment_index(&self) -> usize { self.segment_index @@ -517,6 +523,87 @@ fn build_segment_data( } } +pub struct SegmentDataIterator<'a> { + pub partial_next_data: Option, + pub inputs: &'a GenerationInputs, + pub max_cpu_len_log: Option, +} + +type F = GoldilocksField; +impl<'a> Iterator for SegmentDataIterator<'a> { + type Item = (GenerationInputs, GenerationSegmentData); + + fn next(&mut self) -> Option { + let cur_and_next_data = generate_next_segment::( + self.max_cpu_len_log, + self.inputs, + self.partial_next_data.clone(), + ); + + if cur_and_next_data.is_some() { + let (data, next_data) = cur_and_next_data.expect("Data cannot be `None`"); + self.partial_next_data = next_data; + Some((self.inputs.clone(), data)) + } else { + None + } + } +} + +/// Returns the data for the current segment, as well as the data -- except +/// registers_after -- for the next segment. +pub(crate) fn generate_next_segment( + max_cpu_len_log: Option, + inputs: &GenerationInputs, + partial_segment_data: Option, +) -> Option<(GenerationSegmentData, Option)> { + let mut interpreter = Interpreter::::new_with_generation_inputs( + KERNEL.global_labels["init"], + vec![], + inputs, + max_cpu_len_log, + ); + + // Get the (partial) current segment data, if it is provided. Otherwise, + // initialize it. + let mut segment_data = if let Some(partial) = partial_segment_data { + if partial.registers_after.program_counter == KERNEL.global_labels["halt"] { + return None; + } + interpreter + .get_mut_generation_state() + .set_segment_data(&partial); + interpreter.generation_state.memory = partial.memory.clone(); + partial + } else { + build_segment_data(0, None, None, None, &interpreter) + }; + + let segment_index = segment_data.segment_index; + + // Run the interpreter to get `registers_after` and the partial data for the + // next segment. 
+ if let Ok((updated_registers, mem_after)) = + set_registers_and_run(segment_data.registers_after, &mut interpreter) + { + // Set `registers_after` correctly and push the data. + let before_registers = segment_data.registers_after; + + let partial_segment_data = Some(build_segment_data( + segment_index + 1, + Some(updated_registers), + Some(updated_registers), + mem_after, + &interpreter, + )); + + segment_data.registers_after = updated_registers; + Some((segment_data, partial_segment_data)) + } else { + None + } +} + /// Returns a vector containing the data required to generate all the segments /// of a transaction. pub fn generate_all_data_segments( @@ -555,39 +642,6 @@ pub fn generate_all_data_segments( ); } - // We need at least two segments to prove a segment aggregation. - if all_seg_data.len() == 1 { - let mut interpreter = Interpreter::::new_dummy_with_generation_inputs( - KERNEL.global_labels["init"], - vec![], - inputs, - ); - - let dummy_seg = GenerationSegmentData { - is_dummy: true, - registers_before: RegistersState::new(), - registers_after: RegistersState::new(), - max_cpu_len_log: interpreter.get_max_cpu_len_log(), - ..all_seg_data[0].clone() - }; - let (updated_registers, mem_after) = - set_registers_and_run(dummy_seg.registers_after, &mut interpreter)?; - let mut mem_after = mem_after - .expect("The interpreter was running, so it should have returned a MemoryState"); - // During the interpreter initialization, we set the trie data and initialize - // `RlpRaw`. But we do not want to pass this information to the first actual - // segment in `MemBefore` since the values are not actually accessed in the - // dummy generation. - mem_after.contexts[0].segments[Segment::RlpRaw.unscale()].content = vec![]; - mem_after.contexts[0].segments[Segment::TrieData.unscale()].content = vec![]; - all_seg_data[0].memory = mem_after; - - all_seg_data.insert(0, dummy_seg); - - // We need to update the index of the non-dummy segment, now at position 1. - all_seg_data[1].segment_index += 1; - } - Ok(all_seg_data) } diff --git a/evm_arithmetization/tests/empty_txn_list.rs b/evm_arithmetization/tests/empty_txn_list.rs index 579501411..b4a46c2e4 100644 --- a/evm_arithmetization/tests/empty_txn_list.rs +++ b/evm_arithmetization/tests/empty_txn_list.rs @@ -79,17 +79,7 @@ fn test_empty_txn_list() -> anyhow::Result<()> { let all_circuits = AllRecursiveCircuits::::new( &all_stark, // Minimal ranges to prove an empty list - &[ - 16..17, - 8..10, - 7..11, - 4..15, - 8..11, - 4..13, - 11..18, - 8..18, - 7..17, - ], + &[16..17, 8..9, 8..10, 5..8, 8..9, 4..6, 16..17, 16..17, 7..17], &config, ); @@ -188,42 +178,43 @@ fn test_empty_txn_list() -> anyhow::Result<()> { ); // We can duplicate the proofs here because the state hasn't mutated. 
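Stepping back to the new `SegmentDataIterator` in `prover.rs` above: segments are now produced lazily, each `next()` call yielding the current segment's data while stashing the partially built data for the following segment, until the halt label is reached. A standalone sketch of that pattern with toy state in place of `GenerationSegmentData`:

```rust
// Toy state stands in for `GenerationSegmentData`; the halt check stands in for
// `registers_after.program_counter == KERNEL.global_labels["halt"]`.
struct SegmentIter {
    partial_next: Option<u32>,
    halt_at: u32,
}

impl Iterator for SegmentIter {
    type Item = u32;

    fn next(&mut self) -> Option<u32> {
        let current = self.partial_next.unwrap_or(0);
        if current >= self.halt_at {
            return None; // mirrors the early return once the halt label is reached
        }
        // "Run" the segment and stash the partial data for the next one.
        self.partial_next = Some(current + 1);
        Some(current)
    }
}

fn main() {
    let segments: Vec<u32> = SegmentIter { partial_next: None, halt_at: 3 }.collect();
    assert_eq!(segments, vec![0, 1, 2]);
}
```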
- let (segmented_agg_proof, segmented_agg_public_values) = all_circuits - .prove_segment_aggregation( - false, - &segment_proofs_data[0].proof_with_pis, - segment_proofs_data[0].public_values.clone(), - false, - &segment_proofs_data[1].proof_with_pis, - segment_proofs_data[1].public_values.clone(), - )?; - all_circuits.verify_segment_aggregation(&segmented_agg_proof)?; - - let (segmented_agg_proof, segmented_agg_public_values) = all_circuits - .prove_segment_aggregation( - true, - &segmented_agg_proof, - segmented_agg_public_values, - false, - &segment_proofs_data[2].proof_with_pis, - segment_proofs_data[2].public_values.clone(), - )?; - all_circuits.verify_segment_aggregation(&segmented_agg_proof)?; + let aggregation_output_data = all_circuits.prove_segment_aggregation( + false, + &segment_proofs_data[0], + false, + &segment_proofs_data[1], + )?; + all_circuits.verify_segment_aggregation(&aggregation_output_data.proof_with_pis)?; + + let aggregation_output_data = all_circuits.prove_segment_aggregation( + true, + &aggregation_output_data, + false, + &segment_proofs_data[2], + )?; + all_circuits.verify_segment_aggregation(&aggregation_output_data.proof_with_pis)?; // Test retrieved public values from the proof public inputs. let retrieved_public_values = PublicValues::from_public_inputs( - &segmented_agg_proof.public_inputs, - segmented_agg_public_values.mem_before.mem_cap.len(), + &aggregation_output_data.proof_with_pis.public_inputs, + aggregation_output_data + .public_values + .mem_before + .mem_cap + .len(), + ); + assert_eq!( + retrieved_public_values, + aggregation_output_data.public_values ); - assert_eq!(retrieved_public_values, segmented_agg_public_values); let (txn_proof, txn_public_values) = all_circuits.prove_transaction_aggregation( false, - &segmented_agg_proof, - segmented_agg_public_values.clone(), + &aggregation_output_data.proof_with_pis, + aggregation_output_data.public_values.clone(), false, - &segmented_agg_proof, - segmented_agg_public_values, + &aggregation_output_data.proof_with_pis, + aggregation_output_data.public_values, )?; all_circuits.verify_txn_aggregation(&txn_proof)?; diff --git a/evm_arithmetization/tests/log_opcode.rs b/evm_arithmetization/tests/log_opcode.rs index 5562ab803..0008f462e 100644 --- a/evm_arithmetization/tests/log_opcode.rs +++ b/evm_arithmetization/tests/log_opcode.rs @@ -450,13 +450,13 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { &all_stark, &[ 16..17, - 8..15, - 7..17, - 4..15, + 11..15, + 12..17, 8..11, - 4..13, - 16..20, - 8..18, + 8..9, + 6..12, + 17..20, + 16..17, 7..17, ], &config, @@ -469,12 +469,22 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { &all_stark, &config, inputs_first, - max_cpu_len_log, + // We want only one segment. + 20, &mut timing, None, )?; - assert_eq!(segment_proofs_data_first.len(), 2); + assert_eq!(segment_proofs_data_first.len(), 2); // second one is a dummy segment + + let segment_agg_prover_output_data_first = all_circuits.prove_segment_aggregation( + false, + &segment_proofs_data_first[0], + false, + &segment_proofs_data_first[1], + )?; + all_circuits + .verify_segment_aggregation(&segment_agg_prover_output_data_first.proof_with_pis)?; // The gas used and transaction number are fed to the next transaction, so the // two proofs can be correctly aggregated. 
@@ -614,35 +624,22 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { all_circuits.verify_root(proof.clone())?; } - let (segment_agg_proof_first, updated_agg_public_values_first) = all_circuits - .prove_segment_aggregation( - false, - &segment_proofs_data_first[0].proof_with_pis, - segment_proofs_data_first[0].public_values.clone(), - false, - &segment_proofs_data_first[1].proof_with_pis, - segment_proofs_data_first[1].public_values.clone(), - )?; - all_circuits.verify_segment_aggregation(&segment_agg_proof_first)?; - - let (segment_agg_proof_second, updated_agg_public_values_second) = all_circuits - .prove_segment_aggregation( - false, - &segment_proofs_data_second[0].proof_with_pis, - segment_proofs_data_second[0].public_values.clone(), - false, - &segment_proofs_data_second[1].proof_with_pis, - segment_proofs_data_second[1].public_values.clone(), - )?; - all_circuits.verify_segment_aggregation(&segment_agg_proof_second)?; + let segment_agg_prover_output_data_second = all_circuits.prove_segment_aggregation( + false, + &segment_proofs_data_second[0], + false, + &segment_proofs_data_second[1], + )?; + all_circuits + .verify_segment_aggregation(&segment_agg_prover_output_data_second.proof_with_pis)?; let (txn_proof, txn_pv) = all_circuits.prove_transaction_aggregation( false, - &segment_agg_proof_first, - updated_agg_public_values_first, + &segment_agg_prover_output_data_first.proof_with_pis, + segment_agg_prover_output_data_first.public_values, false, - &segment_agg_proof_second, - updated_agg_public_values_second, + &segment_agg_prover_output_data_second.proof_with_pis, + segment_agg_prover_output_data_second.public_values, )?; let (first_block_proof, _block_public_values) = @@ -714,23 +711,21 @@ fn test_log_with_aggreg() -> anyhow::Result<()> { all_circuits.verify_root(proof.clone())?; } - let (segment_agg_proof, updated_agg_public_values) = all_circuits.prove_segment_aggregation( + let segment_agg_prover_output_data = all_circuits.prove_segment_aggregation( false, - &segment_proofs_data[0].proof_with_pis, - segment_proofs_data[0].public_values.clone(), + &segment_proofs_data[0], false, - &segment_proofs_data[1].proof_with_pis, - segment_proofs_data[1].public_values.clone(), + &segment_proofs_data[1], )?; - all_circuits.verify_segment_aggregation(&segment_agg_proof)?; + all_circuits.verify_segment_aggregation(&segment_agg_prover_output_data.proof_with_pis)?; let (second_txn_proof, second_txn_pvs) = all_circuits.prove_transaction_aggregation( false, - &segment_agg_proof, - updated_agg_public_values.clone(), + &segment_agg_prover_output_data.proof_with_pis, + segment_agg_prover_output_data.public_values.clone(), false, - &segment_agg_proof, - updated_agg_public_values, + &segment_agg_prover_output_data.proof_with_pis, + segment_agg_prover_output_data.public_values, )?; let (second_block_proof, _block_public_values) = all_circuits.prove_block( None, // We don't specify a previous proof, considering block 1 as the new checkpoint. 
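Review note on the reworked `prove_segment_aggregation` exercised by the tests above: "before" values always come from the left child, while "after" values come from the right child unless it is a dummy, in which case the left child supplies both. A standalone sketch of that merge rule, reduced to a single register pair for illustration:

```rust
#[derive(Clone, Debug, PartialEq)]
struct Pv {
    registers_before: u64,
    registers_after: u64,
}

fn aggregate(lhs: &Pv, rhs: &Pv, rhs_is_dummy: bool) -> Pv {
    let after_source = if rhs_is_dummy { lhs } else { rhs };
    Pv {
        registers_before: lhs.registers_before,
        registers_after: after_source.registers_after,
    }
}

fn main() {
    let lhs = Pv { registers_before: 0, registers_after: 7 };
    let rhs = Pv { registers_before: 7, registers_after: 9 };
    // Real rhs: the aggregate spans both segments.
    assert_eq!(aggregate(&lhs, &rhs, false), Pv { registers_before: 0, registers_after: 9 });
    // Dummy rhs (single-segment case): the aggregate equals the lhs.
    assert_eq!(aggregate(&lhs, &lhs, true), lhs);
}
```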
diff --git a/proof_gen/Cargo.toml b/proof_gen/Cargo.toml index b0f6b680e..044dd3e72 100644 --- a/proof_gen/Cargo.toml +++ b/proof_gen/Cargo.toml @@ -15,6 +15,7 @@ log = { workspace = true } paste = "1.0.14" plonky2 = { workspace = true } serde = { workspace = true } +hashbrown = { workspace = true } # Local dependencies evm_arithmetization = { version = "0.1.3", path = "../evm_arithmetization" } diff --git a/proof_gen/src/proof_gen.rs b/proof_gen/src/proof_gen.rs index 39075ab02..cd6d48a52 100644 --- a/proof_gen/src/proof_gen.rs +++ b/proof_gen/src/proof_gen.rs @@ -3,10 +3,13 @@ use std::sync::{atomic::AtomicBool, Arc}; -use evm_arithmetization::{prover::GenerationSegmentData, AllStark, GenerationInputs, StarkConfig}; +use evm_arithmetization::{ + fixed_recursive_verifier::ProverOutputData, prover::GenerationSegmentData, AllStark, + GenerationInputs, StarkConfig, +}; +use hashbrown::HashMap; use plonky2::{ gates::noop::NoopGate, - iop::witness::PartialWitness, plonk::{circuit_builder::CircuitBuilder, circuit_data::CircuitConfig}, util::timing::TimingTree, }; @@ -17,7 +20,7 @@ use crate::{ SegmentAggregatableProof, TxnAggregatableProof, }, prover_state::ProverState, - types::{Config, Field, PlonkyProofIntern, EXTENSION_DEGREE}, + types::{Field, PlonkyProofIntern, EXTENSION_DEGREE}, }; /// A type alias for `Result`. @@ -70,23 +73,46 @@ pub fn generate_segment_proof( /// Generates an aggregation proof from two child proofs. /// /// Note that the child proofs may be either transaction or aggregation proofs. +/// +/// If a transaction only contains a single segment, this function must still be +/// called to generate a `GeneratedSegmentAggProof`. In that case, you can set +/// `has_dummy` to `true`, and provide an arbitrary proof for the right child. pub fn generate_segment_agg_proof( p_state: &ProverState, lhs_child: &SegmentAggregatableProof, rhs_child: &SegmentAggregatableProof, + has_dummy: bool, ) -> ProofGenResult { - let (intern, p_vals) = p_state + if has_dummy { + assert!( + !lhs_child.is_agg(), + "Cannot have a dummy segment with an aggregation." 
+ ); + } + + let lhs_prover_output_data = ProverOutputData { + is_dummy: false, + proof_with_pis: lhs_child.intern().clone(), + public_values: lhs_child.public_values(), + }; + let rhs_prover_output_data = ProverOutputData { + is_dummy: has_dummy, + proof_with_pis: rhs_child.intern().clone(), + public_values: rhs_child.public_values(), + }; + let agg_output_data = p_state .state .prove_segment_aggregation( lhs_child.is_agg(), - lhs_child.intern(), - lhs_child.public_values(), + &lhs_prover_output_data, rhs_child.is_agg(), - rhs_child.intern(), - rhs_child.public_values(), + &rhs_prover_output_data, ) .map_err(|err| err.to_string())?; + let p_vals = agg_output_data.public_values; + let intern = agg_output_data.proof_with_pis; + Ok(GeneratedSegmentAggProof { p_vals, intern }) } @@ -154,13 +180,6 @@ pub fn dummy_proof() -> ProofGenResult { builder.add_gate(NoopGate, vec![]); let circuit_data = builder.build::<_>(); - let inputs = PartialWitness::new(); - - plonky2::plonk::prover::prove::( - &circuit_data.prover_only, - &circuit_data.common, - inputs, - &mut TimingTree::default(), - ) - .map_err(|e| ProofGenError(e.to_string())) + plonky2::recursion::dummy_circuit::dummy_proof(&circuit_data, HashMap::default()) + .map_err(|e| ProofGenError(e.to_string())) } diff --git a/proof_gen/src/proof_types.rs b/proof_gen/src/proof_types.rs index 94f5ff3be..fea8f845f 100644 --- a/proof_gen/src/proof_types.rs +++ b/proof_gen/src/proof_types.rs @@ -58,7 +58,7 @@ pub struct GeneratedBlockProof { #[derive(Clone, Debug, Deserialize, Serialize)] pub enum SegmentAggregatableProof { /// The underlying proof is a segment proof. - Txn(GeneratedSegmentProof), + Seg(GeneratedSegmentProof), /// The underlying proof is an aggregation proof. Agg(GeneratedSegmentAggProof), } @@ -68,6 +68,9 @@ pub enum SegmentAggregatableProof { /// away whether or not the proof was a txn or agg proof. #[derive(Clone, Debug, Deserialize, Serialize)] pub enum TxnAggregatableProof { + /// The underlying proof is a segment proof. It first needs to be aggregated + /// with another segment proof, or a dummy one. + Segment(GeneratedSegmentProof), /// The underlying proof is a transaction proof. Txn(GeneratedSegmentAggProof), /// The underlying proof is an aggregation proof. 
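Note on the `has_dummy` flag in `generate_segment_agg_proof` above: a dummy right child is only legal when the left child is a plain segment proof, matching the circuit-side constraint `is_dummy * lhs.is_agg == 0` added in `fixed_recursive_verifier.rs`. A standalone sketch of that guard; the helper name is hypothetical, not part of the crate:

```rust
// Hypothetical helper, not part of the crate.
fn check_dummy_precondition(lhs_is_agg: bool, has_dummy: bool) -> Result<(), String> {
    if has_dummy && lhs_is_agg {
        return Err("Cannot have a dummy segment with an aggregation.".to_string());
    }
    Ok(())
}

fn main() {
    assert!(check_dummy_precondition(false, true).is_ok()); // segment + dummy: fine
    assert!(check_dummy_precondition(true, true).is_err()); // aggregation + dummy: rejected
    assert!(check_dummy_precondition(true, false).is_ok()); // two real children: fine
}
```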
@@ -77,21 +80,21 @@ pub enum TxnAggregatableProof { impl SegmentAggregatableProof { pub(crate) fn public_values(&self) -> PublicValues { match self { - SegmentAggregatableProof::Txn(info) => info.p_vals.clone(), + SegmentAggregatableProof::Seg(info) => info.p_vals.clone(), SegmentAggregatableProof::Agg(info) => info.p_vals.clone(), } } pub(crate) const fn is_agg(&self) -> bool { match self { - SegmentAggregatableProof::Txn(_) => false, + SegmentAggregatableProof::Seg(_) => false, SegmentAggregatableProof::Agg(_) => true, } } pub(crate) const fn intern(&self) -> &PlonkyProofIntern { match self { - SegmentAggregatableProof::Txn(info) => &info.intern, + SegmentAggregatableProof::Seg(info) => &info.intern, SegmentAggregatableProof::Agg(info) => &info.intern, } } @@ -100,6 +103,7 @@ impl SegmentAggregatableProof { impl TxnAggregatableProof { pub(crate) fn public_values(&self) -> PublicValues { match self { + TxnAggregatableProof::Segment(info) => info.p_vals.clone(), TxnAggregatableProof::Txn(info) => info.p_vals.clone(), TxnAggregatableProof::Agg(info) => info.p_vals.clone(), } @@ -107,6 +111,7 @@ impl TxnAggregatableProof { pub(crate) fn is_agg(&self) -> bool { match self { + TxnAggregatableProof::Segment(_) => false, TxnAggregatableProof::Txn(_) => false, TxnAggregatableProof::Agg(_) => true, } @@ -114,6 +119,7 @@ impl TxnAggregatableProof { pub(crate) fn intern(&self) -> &PlonkyProofIntern { match self { + TxnAggregatableProof::Segment(info) => &info.intern, TxnAggregatableProof::Txn(info) => &info.intern, TxnAggregatableProof::Agg(info) => &info.intern, } @@ -122,7 +128,7 @@ impl TxnAggregatableProof { impl From for SegmentAggregatableProof { fn from(v: GeneratedSegmentProof) -> Self { - Self::Txn(v) + Self::Seg(v) } } @@ -148,9 +154,7 @@ impl From for TxnAggregatableProof { fn from(v: SegmentAggregatableProof) -> Self { match v { SegmentAggregatableProof::Agg(agg) => TxnAggregatableProof::Txn(agg), - SegmentAggregatableProof::Txn(_) => { - panic!("Should be an aggregation by now. Missing segment?") - } + SegmentAggregatableProof::Seg(seg) => TxnAggregatableProof::Segment(seg), } } }
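Finally, a standalone sketch of the behavioural change in the last hunk: lifting a bare segment proof into `TxnAggregatableProof` no longer panics and instead yields the new `Segment` variant, which still needs one segment-aggregation round (possibly against a dummy) before transaction aggregation. Toy string payloads stand in for the real proof types:

```rust
#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum SegmentAggregatableProof {
    Seg(&'static str),
    Agg(&'static str),
}

#[allow(dead_code)]
#[derive(Debug, PartialEq)]
enum TxnAggregatableProof {
    Segment(&'static str),
    Txn(&'static str),
    Agg(&'static str),
}

impl From<SegmentAggregatableProof> for TxnAggregatableProof {
    fn from(v: SegmentAggregatableProof) -> Self {
        match v {
            SegmentAggregatableProof::Agg(agg) => TxnAggregatableProof::Txn(agg),
            // Previously this arm panicked ("Should be an aggregation by now.").
            SegmentAggregatableProof::Seg(seg) => TxnAggregatableProof::Segment(seg),
        }
    }
}

fn main() {
    assert_eq!(
        TxnAggregatableProof::from(SegmentAggregatableProof::Seg("segment proof")),
        TxnAggregatableProof::Segment("segment proof")
    );
}
```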