Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify interpreter and witness generation logic. #56

Merged
merged 16 commits into from
Feb 29, 2024
35 changes: 24 additions & 11 deletions evm_arithmetization/src/cpu/kernel/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use mpt_trie::partial_trie::PartialTrie;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::types::Field;

use crate::byte_packing::byte_packing_stark::BytePackingOp;
use crate::cpu::columns::CpuColumnsView;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::kernel::constants::global_metadata::GlobalMetadata;
Expand All @@ -20,6 +22,8 @@ use crate::generation::state::{
all_withdrawals_prover_inputs_reversed, GenerationState, GenerationStateCheckpoint,
};
use crate::generation::{state::State, GenerationInputs};
use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp;
use crate::memory::segments::Segment;
use crate::util::h2u;
use crate::witness::errors::ProgramError;
Expand All @@ -31,8 +35,7 @@ use crate::witness::state::RegistersState;
use crate::witness::transition::{
decode, fill_op_flag, get_op_special_length, log_kernel_instruction, Transition,
};

type F = GoldilocksField;
use crate::{arithmetic, keccak, logic};

/// Halt interpreter execution whenever a jump to this offset is done.
const DEFAULT_HALT_OFFSET: usize = 0xdeadbeef;
Expand Down Expand Up @@ -746,10 +749,6 @@ impl<F: Field> State<F> for Interpreter<F> {
}
}

/// Always returns `false`: this `State` implementation is the `Interpreter`,
/// not a `GenerationState`.
fn is_generation(&mut self) -> bool {
false
}

fn insert_preinitialized_segment(&mut self, segment: Segment, values: MemorySegmentState) {
self.generation_state
.memory
Expand Down Expand Up @@ -790,19 +789,33 @@ impl<F: Field> State<F> for Interpreter<F> {
self.clock += 1
}

fn get_clock(&mut self) -> usize {
fn get_clock(&self) -> usize {
self.clock
}

/// No-op: the `Interpreter` drops the CPU row `val` instead of recording it.
fn push_cpu(&mut self, val: CpuColumnsView<F>) {}

/// No-op: the `Interpreter` drops the logic operation `op` instead of recording it.
fn push_logic(&mut self, op: logic::Operation) {}

/// No-op: the `Interpreter` drops the arithmetic operation `op` instead of recording it.
fn push_arithmetic(&mut self, op: arithmetic::Operation) {}

/// No-op: the `Interpreter` drops the byte-packing operation `op` instead of recording it.
fn push_byte_packing(&mut self, op: BytePackingOp) {}

/// No-op: the `Interpreter` drops the Keccak `input` and its `clock` instead of recording them.
fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS], clock: usize) {}

/// No-op: the `Interpreter` drops the Keccak byte `input` and its `clock`
/// without converting or recording them.
fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES], clock: usize) {}

/// No-op: the `Interpreter` drops the Keccak sponge operation `op` instead of recording it.
fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) {}

/// Rolls the wrapped `generation_state` back to `checkpoint` by delegating
/// to its own `rollback`.
fn rollback(&mut self, checkpoint: GenerationStateCheckpoint) {
self.generation_state.rollback(checkpoint)
}

fn get_context(&mut self) -> usize {
fn get_context(&self) -> usize {
self.context()
}

fn get_halt_context(&mut self) -> Option<usize> {
fn get_halt_context(&self) -> Option<usize> {
self.halt_context
}

Expand All @@ -816,8 +829,8 @@ impl<F: Field> State<F> for Interpreter<F> {
self.apply_memops();
}

fn get_stack(&mut self) -> Vec<U256> {
self.stack().clone()
fn get_stack(&self) -> Vec<U256> {
self.stack()
}

fn get_halt_offsets(&self) -> Vec<usize> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#[cfg(test)]
mod bn {
use std::collections::HashMap;

use anyhow::Result;
use ethereum_types::U256;
Expand Down
6 changes: 3 additions & 3 deletions evm_arithmetization/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ fn simulate_cpu<F: Field>(state: &mut GenerationState<F>) -> anyhow::Result<()>
state.run_cpu()?;

let pc = state.registers.program_counter;
// Padding
// Setting the values of padding rows.
let mut row = CpuColumnsView::<F>::default();
row.clock = F::from_canonical_usize(state.traces.clock());
row.context = F::from_canonical_usize(state.registers.context);
Expand All @@ -325,8 +325,8 @@ fn simulate_cpu<F: Field>(state: &mut GenerationState<F>) -> anyhow::Result<()>
row.stack_len = F::from_canonical_usize(state.registers.stack_len);

loop {
// If our trace length is a power of 2, stop.
state.traces.push_cpu(true, row);
// Padding to a power of 2.
state.push_cpu(row);
row.clock += F::ONE;
if state.traces.clock().is_power_of_two() {
break;
Expand Down
16 changes: 4 additions & 12 deletions evm_arithmetization/src/generation/prover_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,6 @@ impl<F: Field> GenerationState<F> {
/// Simulate the user's code and store all the jump addresses with their
/// respective contexts.
fn generate_jumpdest_table(&mut self) -> Result<(), ProgramError> {
let _checkpoint = self.checkpoint();

// Simulate the user's code and (unnecessarily) part of the kernel code,
// skipping the validate table call
self.jumpdest_table = simulate_cpu_and_get_user_jumps("terminate_common", self);
Expand Down Expand Up @@ -632,34 +630,28 @@ impl<'a> AccList<'a> {
}
}

fn get_val(value: Option<U256>) -> U256 {
match value {
Some(v) => v,
None => 0.into(),
}
}
impl<'a> Iterator for AccList<'a> {
type Item = (usize, U256, U256);

fn next(&mut self) -> Option<Self::Item> {
if let Ok(new_pos) =
u256_to_usize(get_val(self.access_list_mem[self.pos + self.node_size - 1]))
u256_to_usize(self.access_list_mem[self.pos + self.node_size - 1].unwrap_or_default())
{
let old_pos = self.pos;
self.pos = new_pos - self.offset;
if self.node_size == 2 {
// addresses
Some((
old_pos,
get_val(self.access_list_mem[self.pos]),
self.access_list_mem[self.pos].unwrap_or_default(),
U256::zero(),
))
} else {
// storage_keys
Some((
old_pos,
get_val(self.access_list_mem[self.pos]),
get_val(self.access_list_mem[self.pos + 1]),
self.access_list_mem[self.pos].unwrap_or_default(),
self.access_list_mem[self.pos + 1].unwrap_or_default(),
))
}
} else {
Expand Down
99 changes: 75 additions & 24 deletions evm_arithmetization/src/generation/state.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
use std::collections::HashMap;
use std::mem::size_of;

use anyhow::bail;
use ethereum_types::{Address, BigEndianHash, H160, H256, U256};
use itertools::Itertools;
use keccak_hash::keccak;
use log::log_enabled;
use plonky2::field::types::Field;

use super::mpt::{load_all_mpts, TrieRootPtrs};
use super::TrieInputs;
use crate::byte_packing::byte_packing_stark::BytePackingOp;
use crate::cpu::kernel::aggregator::KERNEL;
use crate::cpu::kernel::constants::context_metadata::ContextMetadata;
use crate::cpu::membus::NUM_GP_CHANNELS;
use crate::cpu::stack::MAX_USER_STACK_SIZE;
use crate::generation::rlp::all_rlp_prover_inputs_reversed;
use crate::generation::CpuColumnsView;
use crate::generation::GenerationInputs;
use crate::keccak_sponge::columns::KECCAK_WIDTH_BYTES;
use crate::keccak_sponge::keccak_sponge_stark::KeccakSpongeOp;
use crate::memory::segments::Segment;
use crate::util::u256_to_usize;
use crate::witness::errors::ProgramError;
Expand All @@ -31,56 +36,103 @@ use crate::witness::transition::{
use crate::witness::util::{
fill_channel_with_value, mem_read_gp_with_log_and_fill, stack_peek, stack_pop_with_log_and_fill,
};
use crate::{arithmetic, keccak, logic};

/// A State is either an `Interpreter` (used for tests and jumpdest analysis) or
/// a `GenerationState`.
pub(crate) trait State<F: Field> {
/// Returns a `State`'s `Checkpoint`.
/// Returns a `State`'s latest `Checkpoint`.
fn checkpoint(&mut self) -> GenerationStateCheckpoint;

/// Returns `true` by default. Implementors that are not a `GenerationState`
/// override this (the `Interpreter` implementation returns `false`).
fn is_generation(&mut self) -> bool {
true
}
/// Increments the `gas_used` register by a value `n`.
fn incr_gas(&mut self, n: u64);

/// Increments the `program_counter` register by a value `n`.
fn incr_pc(&mut self, n: usize);

/// Returns a `State`'s registers.
/// Returns a `RegistersState`.
fn get_registers(&self) -> RegistersState;

/// Returns a `State`'s mutable registers.
/// Returns a mutable reference to the `State`'s registers.
fn get_mut_registers(&mut self) -> &mut RegistersState;

/// Returns the value stored at address `address` in a `State`.
/// Returns the value stored at address `address` in a `State`, or 0 if the
/// memory is unset at this position.
fn get_from_memory(&mut self, address: MemoryAddress) -> U256;

/// Returns a mutable `GenerationState` from a `State`.
/// Returns a mutable reference to a `State`'s `GenerationState`.
fn get_mut_generation_state(&mut self) -> &mut GenerationState<F>;

// /// Returns true if a `State` is a `GenerationState` and false otherwise.
// fn is_generation_state(&mut self) -> bool;

/// Increments the clock of an `Interpreter`.
fn incr_interpreter_clock(&mut self);

/// Returns the value of a `State`'s clock.
fn get_clock(&mut self) -> usize;
fn get_clock(&self) -> usize;

/// Rolls back a `State`.
fn rollback(&mut self, checkpoint: GenerationStateCheckpoint);

/// Returns a `State`'s stack.
fn get_stack(&mut self) -> Vec<U256>;
fn get_stack(&self) -> Vec<U256>;

/// Returns the current context.
fn get_context(&mut self) -> usize;
fn get_context(&self) -> usize;

fn get_halt_context(&mut self) -> Option<usize> {
/// Returns the context in which the jumpdest analysis should end.
fn get_halt_context(&self) -> Option<usize> {
None
}

/// Appends the CPU row `val` to the `cpu` trace of the underlying
/// `GenerationState`.
fn push_cpu(&mut self, val: CpuColumnsView<F>) {
    let traces = &mut self.get_mut_generation_state().traces;
    traces.cpu.push(val);
}

/// Appends the logic operation `op` to the `logic_ops` trace of the
/// underlying `GenerationState`.
fn push_logic(&mut self, op: logic::Operation) {
    let traces = &mut self.get_mut_generation_state().traces;
    traces.logic_ops.push(op);
}

/// Appends the arithmetic operation `op` to the `arithmetic_ops` trace of
/// the underlying `GenerationState`.
fn push_arithmetic(&mut self, op: arithmetic::Operation) {
    let traces = &mut self.get_mut_generation_state().traces;
    traces.arithmetic_ops.push(op);
}

/// Appends the memory operation `op` to the `memory_ops` trace of the
/// underlying `GenerationState`.
fn push_memory(&mut self, op: MemoryOp) {
    let traces = &mut self.get_mut_generation_state().traces;
    traces.memory_ops.push(op);
}

/// Appends the byte-packing operation `op` to the `byte_packing_ops` trace
/// of the underlying `GenerationState`.
fn push_byte_packing(&mut self, op: BytePackingOp) {
    let traces = &mut self.get_mut_generation_state().traces;
    traces.byte_packing_ops.push(op);
}

/// Appends the pair `(input, clock)` to the `keccak_inputs` trace of the
/// underlying `GenerationState`.
fn push_keccak(&mut self, input: [u64; keccak::keccak_stark::NUM_INPUTS], clock: usize) {
    let entry = (input, clock);
    let traces = &mut self.get_mut_generation_state().traces;
    traces.keccak_inputs.push(entry);
}

/// Packs the bytes of `input` into little-endian `u64` words and forwards
/// them, together with `clock`, to `push_keccak`.
///
/// # Panics
///
/// Panics if a chunk is not exactly `size_of::<u64>()` bytes long, or if the
/// number of words does not match the array length `push_keccak` expects.
fn push_keccak_bytes(&mut self, input: [u8; KECCAK_WIDTH_BYTES], clock: usize) {
    let word_len = size_of::<u64>();
    let words = input
        .chunks(word_len)
        .map(|bytes| u64::from_le_bytes(bytes.try_into().unwrap()))
        .collect_vec();
    self.push_keccak(words.try_into().unwrap(), clock);
}

/// Appends the Keccak sponge operation `op` to the `keccak_sponge_ops`
/// trace of the underlying `GenerationState`.
fn push_keccak_sponge(&mut self, op: KeccakSpongeOp) {
    let traces = &mut self.get_mut_generation_state().traces;
    traces.keccak_sponge_ops.push(op);
}

/// Returns the content of the `KernelGeneral` segment of a `State`.
fn mem_get_kernel_content(&self) -> Vec<Option<U256>>;

Expand All @@ -96,8 +148,8 @@ pub(crate) trait State<F: Field> {

fn is_preinitialized_segment(&self, segment: usize) -> bool;

/// Simulates a CPU. It only generates the traces if the `State` is a
/// `GenerationState`. Otherwise, it simply simulates all operations.
/// Simulates the CPU. It only generates the traces if the `State` is a
/// `GenerationState`.
fn run_cpu(&mut self) -> anyhow::Result<()>
where
Self: Transition<F>,
Expand All @@ -106,12 +158,12 @@ pub(crate) trait State<F: Field> {
let halt_offsets = self.get_halt_offsets();

loop {
// If we've reached the kernel's halt routine.
let registers = self.get_registers();
let pc = registers.program_counter;

let halt = registers.is_kernel && halt_offsets.contains(&pc);

// If we've reached the kernel's halt routine, halt.
if halt {
if let Some(halt_context) = self.get_halt_context() {
if registers.context == halt_context {
Expand Down Expand Up @@ -149,8 +201,7 @@ pub(crate) trait State<F: Field> {
let checkpoint = self.checkpoint();

let (row, _) = self.base_row();
let is_generation = self.is_generation();
generate_exception(exc_code, self, row, is_generation);
generate_exception(exc_code, self, row);

self.apply_ops(checkpoint);

Expand Down Expand Up @@ -432,7 +483,7 @@ impl<F: Field> State<F> for GenerationState<F> {
/// No-op for `GenerationState`: this implementation has no interpreter clock
/// to increment (its `get_clock` reads `self.traces.clock()` instead).
fn incr_interpreter_clock(&mut self) {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd rather have the trait define the blanket no-op impl, so that we don't have to keep a reference to something that's interpreter specific here.

Copy link
Contributor Author

@LindaGuiga LindaGuiga Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, in the end, I think it's better to increment the interpreter clock in push_cpu, since this is when the generation "clock" is increased. This would get rid of that method altogether, and would ensure a better unification (we noticed that the clock of the interpreter was not increased when calling generate_exception).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah it's even better


/// Returns the value of a `State`'s clock.
fn get_clock(&mut self) -> usize {
fn get_clock(&self) -> usize {
self.traces.clock()
}

Expand All @@ -442,11 +493,11 @@ impl<F: Field> State<F> for GenerationState<F> {
}

/// Returns a `State`'s stack.
fn get_stack(&mut self) -> Vec<U256> {
fn get_stack(&self) -> Vec<U256> {
self.stack()
}

fn get_context(&mut self) -> usize {
fn get_context(&self) -> usize {
self.registers.context
}

Expand Down Expand Up @@ -536,7 +587,7 @@ impl<F: Field> Transition<F> for GenerationState<F> {
MemoryOpKind::Read,
self.registers.stack_top,
);
self.traces.push_memory(mem_op);
self.push_memory(mem_op);
}
self.registers.is_stack_top_read = false;

Expand Down
5 changes: 5 additions & 0 deletions evm_arithmetization/src/witness/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ impl MemoryState {
Some(val)
}

/// Returns a memory value, or 0 if the memory is unset. If we have some
/// preinitialized segments (in interpreter mode), then the values might not
/// be stored in memory yet. If the value in memory is not set and the
/// address is part of the preinitialized segment, then we return the
/// preinitialized value instead.
pub(crate) fn get_with_init(&self, address: MemoryAddress) -> U256 {
match self.get(address) {
Some(val) => val,
Expand Down
Loading