diff --git a/CHANGELOG.md b/CHANGELOG.md index 329aa5cbd..6959c794c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed - -* Refactor accessed lists as sorted linked lists ([#30](https://github.com/0xPolygonZero/zk_evm/pull/30)) -* Unify interpreter and prover witness generation ([#56](https://github.com/0xPolygonZero/zk_evm/pull/56)) +- Add verification for invalid jumps. [#36](https://github.com/0xPolygonZero/zk_evm/pull/36) +- Refactor accessed lists as sorted linked lists ([#30](https://github.com/0xPolygonZero/zk_evm/pull/30)) +- Change visibility of `compact` mod ([#57](https://github.com/0xPolygonZero/zk_evm/pull/57)) +- Unify interpreter and prover witness generation ([#56](https://github.com/0xPolygonZero/zk_evm/pull/56)) ## [0.1.0] - 2024-02-21 -* Initial release. \ No newline at end of file +* Initial release. diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm b/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm index 6ce2d676d..a35754c57 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/exception.asm @@ -132,6 +132,8 @@ global invalid_jump_jumpi_destination_common: %shr_const(32) %jumpi(fault_exception) // This keeps one copy of jump_dest on the stack, but that's fine. // jump_dest is a valid address; check if it points to a `JUMP_DEST`. + DUP1 + %verify_non_jumpdest %mload_current(@SEGMENT_JUMPDEST_BITS) // stack: is_valid_jumpdest %jumpi(panic) // Trap should never have been entered. 
diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm b/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm index 934d1f629..49b59fe63 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm @@ -277,13 +277,16 @@ global write_table_if_jumpdest: // pos 0102030405060708091011121314151617181920212223242526272829303132 PUSH 0x8080808080808080808080808080808080808080808080808080808080808080 AND - %jump_neq_const(0x8080808080808080808080808080808080808080808080808080808080808080, return_pop_opcode) + // If we received a proof it MUST be valid or we abort immediately. This + // is especially important for non-jumpdest proofs. Otherwise a malicious + // prover might mark a valid jumpdest as invalid by providing an invalid proof + // that makes verify_non_jumpdest return prematurely. + %assert_eq_const(0x8080808080808080808080808080808080808080808080808080808080808080) POP %add_const(32) // check the remaining path %jump(verify_path_and_write_jumpdest_table) -return_pop_opcode: POP return: // stack: proof_prefix_addr, ctx, jumpdest, retdest @@ -342,3 +345,29 @@ check_proof: %jump(jumpdest_analysis) %%after: %endmacro + +// Non-deterministically find the closest opcode to addr +// and call write_table_if_jumpdest so that `@SEGMENT_JUMPDEST_BITS` +// will contain a 0 if and only if addr is not a jumpdest +// stack: addr, retdest +// stack: (empty) +global verify_non_jumpdest: + // stack: addr, retdest + GET_CONTEXT + SWAP1 + // stack: addr, ctx + PROVER_INPUT(jumpdest_table::non_jumpdest_proof) + // stack: proof, addr, ctx, + // Check that proof <= addr as otherwise it allows + // a malicious prover to leave `@SEGMENT_JUMPDEST_BITS` as 0 + // at position addr while it shouldn't. 
+ DUP2 DUP2 + %assert_le + %write_table_if_jumpdest + JUMP + +%macro verify_non_jumpdest + %stack (addr) -> (addr, %%after) + %jump(verify_non_jumpdest) +%%after: +%endmacro diff --git a/evm_arithmetization/src/cpu/kernel/interpreter.rs b/evm_arithmetization/src/cpu/kernel/interpreter.rs index 6159a79e0..e6e444e1d 100644 --- a/evm_arithmetization/src/cpu/kernel/interpreter.rs +++ b/evm_arithmetization/src/cpu/kernel/interpreter.rs @@ -125,7 +125,7 @@ pub(crate) fn simulate_cpu_and_get_user_jumps( interpreter.run(); - log::debug!("jdt = {:?}", interpreter.jumpdest_table); + log::trace!("jumpdest table = {:?}", interpreter.jumpdest_table); interpreter .generation_state @@ -655,7 +655,7 @@ impl Interpreter { KERNEL.offset_label(self.generation_state.registers.program_counter) } - fn get_jumpdest_bit(&self, offset: usize) -> U256 { + pub(crate) fn get_jumpdest_bit(&self, offset: usize) -> U256 { if self.generation_state.memory.contexts[self.context()].segments [Segment::JumpdestBits.unscale()] .content diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs index 96a532834..feddd421a 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/jumpdest_analysis.rs @@ -144,7 +144,7 @@ fn test_packed_verification() -> Result<()> { interpreter.set_code(CONTEXT, code.clone()); interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])])); - interpreter.run()?; + assert!(interpreter.run().is_err()); assert!(interpreter.get_jumpdest_bits(CONTEXT).is_empty()); @@ -153,3 +153,62 @@ fn test_packed_verification() -> Result<()> { Ok(()) } + +#[test] +fn test_verify_non_jumpdest() -> Result<()> { + // By default the interpreter will skip jumpdest analysis asm and compute + // the jumpdest table bits natively. 
We avoid that starting 1 line after + // performing the missing first PROVER_INPUT "by hand" + let verify_non_jumpdest = KERNEL.global_labels["verify_non_jumpdest"]; + const CONTEXT: usize = 3; // arbitrary + + let add = get_opcode("ADD"); + let push2 = get_push_opcode(2); + let jumpdest = get_opcode("JUMPDEST"); + + #[rustfmt::skip] + let mut code: Vec = vec![ + add, + jumpdest, + push2, + jumpdest, // part of PUSH2 + jumpdest, // part of PUSH2 + jumpdest, + add, + jumpdest, + ]; + code.extend( + (0..32) + .rev() + .map(get_push_opcode) + .chain(std::iter::once(jumpdest)), + ); + let code_len = code.len(); + + // If we add 1 to each opcode the jumpdest at position 32 is never a valid + // jumpdest + for i in 8..code_len - 1 { + code[i] += 1; + let mut interpreter: Interpreter = + Interpreter::new_with_kernel(verify_non_jumpdest, vec![]); + interpreter.generation_state.registers.context = CONTEXT; + + interpreter.set_code(CONTEXT, code.clone()); + code[i] -= 1; + + // We check that all non jumpdests are indeed non jumpdests + for (j, &opcode) in code + .iter() + .enumerate() + .filter(|&(j, _)| j != 1 && j != 5 && j != 7) + { + interpreter.generation_state.registers.program_counter = verify_non_jumpdest; + interpreter.push(0xDEADBEEFu32.into()); + interpreter.push(j.into()); + interpreter.run()?; + assert!(interpreter.stack().is_empty()); + assert_eq!(interpreter.get_jumpdest_bit(j), U256::zero()); + } + } + Ok(()) +} diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 869c72496..f5d635e6b 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -241,12 +241,13 @@ impl GenerationState { } } - /// Generate either the next used jump address or the proof for the last - /// jump address. + /// Generate either the next used jump address, the proof for the last + /// jump address, or a non-jumpdest proof. 
fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result { match input_fn.0[1].as_str() { "next_address" => self.run_next_jumpdest_table_address(), "next_proof" => self.run_next_jumpdest_table_proof(), + "non_jumpdest_proof" => self.run_next_non_jumpdest_proof(), _ => Err(ProgramError::ProverInputError(InvalidInput)), } } @@ -322,6 +323,20 @@ impl GenerationState { } } + + /// Returns a non-jumpdest proof for the address on the top of the stack. A + /// non-jumpdest proof is the closest address to the address on the top of + /// the stack, if the closest address is >= 32, or zero otherwise. + fn run_next_non_jumpdest_proof(&mut self) -> Result { + let code = self.get_current_code()?; + let address = u256_to_usize(stack_peek(self, 0)?)?; + let closest_opcode_addr = get_closest_opcode_address(&code, address); + Ok(if closest_opcode_addr < 32 { + U256::zero() + } else { + closest_opcode_addr.into() + }) + } + /// Returns a pointer to an element in the list whose value is such that /// `value <= addr < next_value` and `addr` is the top of the stack. fn run_next_addresses_insert(&mut self) -> Result { @@ -540,6 +555,15 @@ fn get_proofs_and_jumpdests( proofs } +/// Return the largest prev_addr in `code` such that `code[prev_addr]` is an +/// opcode (and not the argument of some PUSHXX) and prev_addr <= address +fn get_closest_opcode_address(code: &[u8], address: usize) -> usize { + let (prev_addr, _) = CodeIterator::until(code, address + 1) + .last() + .unwrap_or((0, 0)); + prev_addr +} + /// An iterator over the EVM code contained in `code`, which skips the bytes /// that are the arguments of a PUSHXX opcode. struct CodeIterator<'a> { diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs index 314b433d6..a1858d93e 100644 --- a/trace_decoder/src/lib.rs +++ b/trace_decoder/src/lib.rs @@ -5,7 +5,7 @@ #![allow(unused)] #![allow(private_interfaces)] -mod compact; +pub mod compact; pub mod decoding; mod deserializers; pub mod processed_block_trace;