Skip to content

Commit

Permalink
Add non jumpdests (#36)
Browse files Browse the repository at this point in the history
* Add non-jumpdests verification

* Fix soundness issue

* Add test for non jumpdests

* Apply suggestions from code review

Co-authored-by: David <dvdplm@gmail.com>

* Address reviews

* Update Changelog

* Minor

* Fix changelog

---------

Co-authored-by: David <dvdplm@gmail.com>
  • Loading branch information
4l0n50 and dvdplm authored Feb 26, 2024
1 parent 6fb1df2 commit c9f9b98
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Changed

- Add verification for invalid jumps. [#36](https://github.com/0xPolygonZero/zk_evm/pull/36)
- Refactor accessed lists as sorted linked lists ([#30](https://github.com/0xPolygonZero/zk_evm/pull/30))
- Change visibility of `compact` mod ([#57](https://github.com/0xPolygonZero/zk_evm/pull/57))

Expand Down
2 changes: 2 additions & 0 deletions evm_arithmetization/src/cpu/kernel/asm/core/exception.asm
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ global invalid_jump_jumpi_destination_common:
%shr_const(32)
%jumpi(fault_exception) // This keeps one copy of jump_dest on the stack, but that's fine.
// jump_dest is a valid address; check if it points to a `JUMP_DEST`.
DUP1
%verify_non_jumpdest
%mload_current(@SEGMENT_JUMPDEST_BITS)
// stack: is_valid_jumpdest
%jumpi(panic) // Trap should never have been entered.
Expand Down
33 changes: 31 additions & 2 deletions evm_arithmetization/src/cpu/kernel/asm/core/jumpdest_analysis.asm
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,16 @@ global write_table_if_jumpdest:
// pos 0102030405060708091011121314151617181920212223242526272829303132
PUSH 0x8080808080808080808080808080808080808080808080808080808080808080
AND
%jump_neq_const(0x8080808080808080808080808080808080808080808080808080808080808080, return_pop_opcode)
// If we received a proof it MUST be valid or we abort immediately. This
// is especially important for non-jumpdest proofs. Otherwise a malicious
// prover might mark a valid jumpdest as invalid by providing an invalid proof
// that makes verify_non_jumpdest return prematurely.
%assert_eq_const(0x8080808080808080808080808080808080808080808080808080808080808080)
POP
%add_const(32)

// check the remaining path
%jump(verify_path_and_write_jumpdest_table)
return_pop_opcode:
POP
return:
// stack: proof_prefix_addr, ctx, jumpdest, retdest
Expand Down Expand Up @@ -342,3 +345,29 @@ check_proof:
%jump(jumpdest_analysis)
%%after:
%endmacro

// Non-deterministically find the closest opcode to addr
// and call write_table_if_jumpdest so that `@SEGMENT_JUMPDEST_BITS`
// will contain a 0 if and only if addr is not a jumpdest
// stack: addr, retdest
// stack: (empty)
global verify_non_jumpdest:
// stack: addr, retdest
GET_CONTEXT
SWAP1
// stack: addr, ctx
PROVER_INPUT(jumpdest_table::non_jumpdest_proof)
// stack: proof, addr, ctx,
// Check that proof <= addr as otherwise it allows
// a malicious prover to leave `@SEGMENT_JUMPDEST_BITS` as 0
// at position addr while it shouldn't.
DUP2 DUP2
%assert_le
%write_table_if_jumpdest
JUMP

%macro verify_non_jumpdest
%stack (addr) -> (addr, %%after)
%jump(verify_non_jumpdest)
%%after:
%endmacro
4 changes: 2 additions & 2 deletions evm_arithmetization/src/cpu/kernel/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub(crate) fn simulate_cpu_and_get_user_jumps<F: Field>(

interpreter.run();

log::debug!("jdt = {:?}", interpreter.jumpdest_table);
log::trace!("jumpdest table = {:?}", interpreter.jumpdest_table);

interpreter
.generation_state
Expand Down Expand Up @@ -1176,7 +1176,7 @@ impl<'a, F: Field> Interpreter<'a, F> {
self.push(syscall_info)
}

fn get_jumpdest_bit(&self, offset: usize) -> U256 {
pub(crate) fn get_jumpdest_bit(&self, offset: usize) -> U256 {
if self.generation_state.memory.contexts[self.context()].segments
[Segment::JumpdestBits.unscale()]
.content
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ fn test_packed_verification() -> Result<()> {
interpreter.set_code(CONTEXT, code.clone());
interpreter.generation_state.jumpdest_table = Some(HashMap::from([(3, vec![1, 33])]));

interpreter.run()?;
assert!(interpreter.run().is_err());

assert!(interpreter.get_jumpdest_bits(CONTEXT).is_empty());

Expand All @@ -153,3 +153,62 @@ fn test_packed_verification() -> Result<()> {

Ok(())
}

#[test]
fn test_verify_non_jumpdest() -> Result<()> {
// By default the interpreter will skip jumpdest analysis asm and compute
// the jumpdest table bits natively. We avoid that starting 1 line after
// performing the missing first PROVER_INPUT "by hand"
let verify_non_jumpdest = KERNEL.global_labels["verify_non_jumpdest"];
const CONTEXT: usize = 3; // arbitrary

let add = get_opcode("ADD");
let push2 = get_push_opcode(2);
let jumpdest = get_opcode("JUMPDEST");

#[rustfmt::skip]
let mut code: Vec<u8> = vec![
add,
jumpdest,
push2,
jumpdest, // part of PUSH2
jumpdest, // part of PUSH2
jumpdest,
add,
jumpdest,
];
code.extend(
(0..32)
.rev()
.map(get_push_opcode)
.chain(std::iter::once(jumpdest)),
);
let code_len = code.len();

// If we add 1 to each opcode the jumpdest at position 32 is never a valid
// jumpdest
for i in 8..code_len - 1 {
code[i] += 1;
let mut interpreter: Interpreter<F> =
Interpreter::new_with_kernel(verify_non_jumpdest, vec![]);
interpreter.generation_state.registers.context = CONTEXT;

interpreter.set_code(CONTEXT, code.clone());
code[i] -= 1;

// We check that all non jumpdests are indeed non jumpdests
for (j, &opcode) in code
.iter()
.enumerate()
.filter(|&(j, _)| j != 1 && j != 5 && j != 7)
{
interpreter.generation_state.registers.program_counter = verify_non_jumpdest;
interpreter.push(0xDEADBEEFu32.into());
interpreter.push(j.into());
interpreter.run()?;
assert!(interpreter.stack().is_empty());
assert_eq!(interpreter.get_jumpdest_bit(j), U256::zero());
}
}
Ok(())
}
28 changes: 26 additions & 2 deletions evm_arithmetization/src/generation/prover_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,13 @@ impl<F: Field> GenerationState<F> {
}
}

/// Generate either the next used jump address or the proof for the last
/// jump address.
/// Generate either the next used jump address, the proof for the last
/// jump address, or a non-jumpdest proof.
fn run_jumpdest_table(&mut self, input_fn: &ProverInputFn) -> Result<U256, ProgramError> {
match input_fn.0[1].as_str() {
"next_address" => self.run_next_jumpdest_table_address(),
"next_proof" => self.run_next_jumpdest_table_proof(),
"non_jumpdest_proof" => self.run_next_non_jumpdest_proof(),
_ => Err(ProgramError::ProverInputError(InvalidInput)),
}
}
Expand Down Expand Up @@ -321,6 +322,20 @@ impl<F: Field> GenerationState<F> {
}
}

/// Returns a non-jumpdest proof for the address on the top of the stack. A
/// non-jumpdest proof is the closest address to the address on the top of
/// the stack, if the closses address is >= 32, or zero otherwise.
fn run_next_non_jumpdest_proof(&mut self) -> Result<U256, ProgramError> {
let code = self.get_current_code()?;
let address = u256_to_usize(stack_peek(self, 0)?)?;
let closest_opcode_addr = get_closest_opcode_address(&code, address);
Ok(if closest_opcode_addr < 32 {
U256::zero()
} else {
closest_opcode_addr.into()
})
}

/// Returns a pointer to an element in the list whose value is such that
/// `value <= addr < next_value` and `addr` is the top of the stack.
fn run_next_addresses_insert(&mut self) -> Result<U256, ProgramError> {
Expand Down Expand Up @@ -530,6 +545,15 @@ fn get_proofs_and_jumpdests(
proofs
}

/// Return the largest prev_addr in `code` such that `code[pred_addr]` is an
/// opcode (and not the argument of some PUSHXX) and pred_addr <= address
fn get_closest_opcode_address(code: &[u8], address: usize) -> usize {
let (prev_addr, _) = CodeIterator::until(code, address + 1)
.last()
.unwrap_or((0, 0));
prev_addr
}

/// An iterator over the EVM code contained in `code`, which skips the bytes
/// that are the arguments of a PUSHXX opcode.
struct CodeIterator<'a> {
Expand Down

0 comments on commit c9f9b98

Please sign in to comment.