diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..7c028c03d --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,16 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## Unreleased + +### Changed + +- Refactor accessed lists as sorted linked lists ([#30](https://github.com/0xPolygonZero/zk_evm/pull/30)) +- Change visibility of `compact` mod ([#57](https://github.com/0xPolygonZero/zk_evm/pull/57)) + +## [0.1.0] - 2024-02-21 +* Initial release. diff --git a/Cargo.toml b/Cargo.toml index 52545235f..c3258a6b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,19 +5,30 @@ resolver = "2" [workspace.dependencies] bytes = "1.5.0" enum-as-inner = "0.6.0" +env_logger = "0.10.0" ethereum-types = "0.14.1" hex = "0.4.3" hex-literal = "0.4.1" keccak-hash = "0.10.0" log = "0.4.20" num = "0.4.1" +rand = "0.8.5" rlp = "0.5.2" rlp-derive = "0.1.0" serde = "1.0.166" +serde_json = "1.0.96" +thiserror = "1.0.49" + +# plonky2-related dependencies +plonky2 = "0.2.0" +plonky2_maybe_rayon = "0.2.0" +plonky2_util = "0.2.0" +starky = "0.2.0" + [workspace.package] edition = "2021" license = "MIT OR Apache-2.0" repository = "https://github.com/0xPolygonZero/zk_evm" homepage = "https://github.com/0xPolygonZero/zk_evm" -keywords = ["cryptography", "SNARK", "PLONK", "FRI", "plonky2", "EVM", "ETHEREUM"] +keywords = ["cryptography", "STARK", "plonky2", "ethereum", "zk"] diff --git a/README.md b/README.md index ef89cd023..6bcce7f3e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,38 @@ # zk_evm +A collection of libraries to prove Ethereum blocks with Polygon Zero Type 1 zkEVM, +powered by [starky and plonky2](https://github.com/0xPolygonZero/plonky2) proving systems. + + +## Directory structure + +This repository contains the following Rust crates: + +* [mpt_trie](./mpt_trie/README.md): A collection of types and functions to work with Ethereum Merkle Patricie Tries. + +* [trace_decoder](./trace_decoder/README.md): Flexible protocol designed to process Ethereum clients trace payloads into an IR format that can be +understood by the zkEVM prover. + +* [evm_arithmetization](./evm_arithmetization/README.md): Defines all the STARK constraints and recursive circuits to generate succinct proofs of EVM execution. +It uses starky and plonky2 as proving backend: https://github.com/0xPolygonZero/plonky2. + +* [proof_gen](./proof_gen/README.md): A convenience library for generating proofs from inputs already in Intermediate Representation (IR) format. + + +## Documentation + +Documentation is still incomplete and will be improved over time, a lot of useful material can +be found in the [docs](./docs/) section, including: + +* [sequence diagrams](./docs/usage_seq_diagrams.md) for the proof generation flow +* [zkEVM specifications](./docs/arithmetization/zkevm.pdf), detailing the underlying EVM proving statement + + +## Building + +The zkEVM stack currently requires the `nightly` toolchain, although we may transition to `stable` in the future. +Note that the prover uses the [Jemalloc](http://jemalloc.net/) memory allocator due to its superior performance. + ## License Licensed under either of @@ -11,4 +44,5 @@ at your option. ### Contribution -Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. +Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, +as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..d8e87e74e --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,17 @@ +# Polygon Technology Security Information + +## Link to vulnerability disclosure details (Bug Bounty). +- Websites and Applications: https://hackerone.com/polygon-technology +- Smart Contracts: https://immunefi.com/bounty/polygon + +## Languages that our team speaks and understands. +Preferred-Languages: en + +## Security-related job openings at Polygon. +https://polygon.technology/careers + +## Polygon security contact details. +security@polygon.technology + +## The URL for accessing the security.txt file. +Canonical: https://polygon.technology/security.txt diff --git a/docs/usage_seq_diagrams.md b/docs/usage_seq_diagrams.md index e0cf6bb43..71b308a0b 100644 --- a/docs/usage_seq_diagrams.md +++ b/docs/usage_seq_diagrams.md @@ -1,14 +1,15 @@ # Usage Diagrams -These are some hacked together diagrams showing how the protocol will (likely) be used. Also included what the old Edge proof generation process looked like as a reference. -## Proof Protocol +These are some diagrams showing how the protocol is implemented. + +## Proof Generation ```mermaid sequenceDiagram proof protocol client->>proof scheduler: protocol_payload - proof scheduler->>protocol decoder (lib): protolcol_payload - Note over proof scheduler,protocol decoder (lib): "txn_proof_gen_ir" are the payloads sent to Paladin for a txn - protocol decoder (lib)->>proof scheduler: [txn_proof_gen_ir] + proof scheduler->>trace decoder (lib): protocol_payload + Note over proof scheduler,trace decoder (lib): "txn_proof_gen_ir" are the payloads sent to Paladin for a txn + trace decoder (lib)->>proof scheduler: [txn_proof_gen_ir] proof scheduler->>paladin: [txn_proof_gen_ir] Note over proof scheduler,paladin: Paladin schedules jobs on multiple machines and returns a block proof loop txn_proof_gen_ir @@ -19,18 +20,3 @@ sequenceDiagram Note over proof scheduler,checkpoint contract: Note: Might send to an external service instead that compresses the proof proof scheduler->>checkpoint contract: block_proof ``` - -## Edge Proof Generation - -```mermaid -sequenceDiagram - edge->>zero provers (leader): block_trace - zero provers (leader)->>trace parsing lib: block_trace - Note over zero provers (leader),trace parsing lib: "txn_proof_gen_ir" are the payloads sent to each worker for a txn - trace parsing lib->>zero provers (leader): [txn_proof_gen_ir] - loop txn_proof_gen_ir - zero provers (leader)->>zero provers (worker): proof_gen_payload (txn, agg, block) - zero provers (worker)->>zero provers (leader): generated_proof (txn, agg, block) - end - zero provers (leader)->>checkpoint contract: block_proof -``` diff --git a/evm_arithmetization/Cargo.toml b/evm_arithmetization/Cargo.toml index 324f65b7a..d5a98d1c0 100644 --- a/evm_arithmetization/Cargo.toml +++ b/evm_arithmetization/Cargo.toml @@ -14,23 +14,23 @@ keywords.workspace = true [dependencies] anyhow = "1.0.40" bytes = { workspace = true } -env_logger = "0.10.0" -ethereum-types = "0.14.0" +env_logger = { workspace = true } +ethereum-types = { workspace = true } hex = { workspace = true, optional = true } hex-literal = { workspace = true } itertools = "0.11.0" keccak-hash = { workspace = true } log = { workspace = true } -plonky2_maybe_rayon = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "710225c9e0ac5822b2965ce74951cf000bbb8a2c" } +plonky2_maybe_rayon = { workspace = true } num = { workspace = true } num-bigint = "0.4.3" once_cell = "1.13.0" pest = "2.1.3" pest_derive = "2.1.0" -plonky2 = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "710225c9e0ac5822b2965ce74951cf000bbb8a2c" } -plonky2_util = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "710225c9e0ac5822b2965ce74951cf000bbb8a2c" } -starky = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "710225c9e0ac5822b2965ce74951cf000bbb8a2c" } -rand = "0.8.5" +plonky2 = { workspace = true } +plonky2_util = { workspace = true } +starky = { workspace = true } +rand = { workspace = true } rand_chacha = "0.3.1" rlp = { workspace = true } rlp-derive = { workspace = true } @@ -38,7 +38,7 @@ serde = { workspace = true, features = ["derive"] } static_assertions = "1.1.0" hashbrown = { version = "0.14.0" } tiny-keccak = "2.0.2" -serde_json = "1.0" +serde_json = { workspace = true } # Local dependencies mpt_trie = { version = "0.1.0", path = "../mpt_trie" } diff --git a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm index 30afe27c4..76d183e5f 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/core/access_lists.asm @@ -1,11 +1,55 @@ /// Access lists for addresses and storage keys. -/// The access list is stored in an array. The length of the array is stored in the global metadata. -/// For storage keys, the address and key are stored as two consecutive elements. -/// The array is stored in the SEGMENT_ACCESSED_ADDRESSES segment for addresses and in the SEGMENT_ACCESSED_STORAGE_KEYS segment for storage keys. +/// The access list is stored in a sorted linked list in SEGMENT_ACCESSED_ADDRESSES for addresses and +/// SEGMENT_ACCESSED_STORAGE_KEYS segment for storage keys. The length of +/// the segments is stored in the global metadata. /// Both arrays are stored in the kernel memory (context=0). -/// Searching and inserting is done by doing a linear search through the array. +/// Searching and inserting is done by guessing the predecessor in the list. /// If the address/storage key isn't found in the array, it is inserted at the end. -/// TODO: Look into using a more efficient data structure for the access lists. + +// Initialize the set of accessed addresses and storage keys with an empty list of the form (@U256_MAX)⮌ +// which is written as [@U256_MAX, @SEGMENT_ACCESSED_ADDRESSES] in SEGMENT_ACCESSED_ADDRESSES +// and as [@U256_MAX, _, _, @SEGMENT_ACCESSED_STORAGE_KEYS] in SEGMENT_ACCESSED_STORAGE_KEYS. +// Initialize SEGMENT_ACCESSED_ADDRESSES +global init_access_lists: + // stack: (empty) + // Store @U256_MAX at the beggining of the segment + PUSH @SEGMENT_ACCESSED_ADDRESSES // ctx == virt == 0 + DUP1 + PUSH @U256_MAX + MSTORE_GENERAL + // Store @SEGMENT_ACCESSED_ADDRESSES at address 1 + %increment + DUP1 + PUSH @SEGMENT_ACCESSED_ADDRESSES + MSTORE_GENERAL + + // Store the segment scaled length + %increment + %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) + // stack: (empty) + + // Initialize SEGMENT_ACCESSED_STORAGE_KEYS + // Store @U256_MAX at the beggining of the segment + PUSH @SEGMENT_ACCESSED_STORAGE_KEYS // ctx == virt == 0 + DUP1 + PUSH @U256_MAX + MSTORE_GENERAL + // Store @SEGMENT_ACCESSED_STORAGE_KEYS at address 3 + %add_const(3) + DUP1 + PUSH @SEGMENT_ACCESSED_STORAGE_KEYS + MSTORE_GENERAL + + // Store the segment scaled length + %increment + %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) + JUMP + +%macro init_access_lists + PUSH %%after + %jump(init_access_lists) +%%after: +%endmacro %macro insert_accessed_addresses %stack (addr) -> (addr, %%after) @@ -19,76 +63,138 @@ POP %endmacro +// Multiply the ptr at the top of the stack by 2 +// and abort if 2*ptr - @SEGMENT_ACCESSED_ADDRESSES >= @GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN +// In this way ptr must be pointing to the begining of a node. +%macro get_valid_addr_ptr + // stack: ptr + %mul_const(2) + PUSH @SEGMENT_ACCESSED_ADDRESSES + DUP2 + SUB + %assert_lt_const(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) + // stack: 2*ptr +%endmacro + + /// Inserts the address into the access list if it is not already present. /// Return 1 if the address was inserted, 0 if it was already present. global insert_accessed_addresses: // stack: addr, retdest + PROVER_INPUT(access_lists::address_insert) + // stack: pred_ptr/2, addr, retdest + %get_valid_addr_ptr + // stack: pred_ptr, addr, retdest + DUP1 + MLOAD_GENERAL + // stack: pred_addr, pred_ptr, addr, retdest + // If pred_add < addr OR pred_ptr == @SEGMENT_ACCESSED_ADDRESSES + DUP2 + %eq_const(@SEGMENT_ACCESSED_ADDRESSES) + // pred_ptr == start, pred_addr, pred_ptr, addr, retdest + DUP2 DUP5 GT + // addr > pred_addr, pred_ptr == start, pred_addr, pred_ptr, addr, retdest + OR + // (addr > pred_addr) || (pred_ptr == start), pred_addr, pred_ptr, addr, retdest + %jumpi(insert_new_address) + // Here, addr <= pred_addr. Assert that `addr == pred_addr`. + // stack: pred_addr, pred_ptr, addr, retdest + DUP3 + // stack: addr, pred_addr, pred_ptr, addr, retdest + %assert_eq + + // stack: pred_ptr, addr, retdest + // Check that this is not a deleted node + %increment + MLOAD_GENERAL + %jump_neq_const(@U256_MAX, address_found) + // We should have found the address. + PANIC +address_found: + // The address was already in the list + %stack (addr, retdest) -> (retdest, 0) // Return 0 to indicate that the address was already present. + JUMP + +insert_new_address: + // stack: pred_addr, pred_ptr, addr, retdest + POP + // get the value of the next address + %increment + // stack: next_ptr_ptr, addr, retdest %mload_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) - // stack: len, addr, retdest - PUSH @SEGMENT_ACCESSED_ADDRESSES ADD - PUSH @SEGMENT_ACCESSED_ADDRESSES -insert_accessed_addresses_loop: - // `i` and `len` are both scaled by SEGMENT_ACCESSED_ADDRESSES - %stack (i, len, addr, retdest) -> (i, len, i, len, addr, retdest) - EQ %jumpi(insert_address) - // stack: i, len, addr, retdest + DUP2 + MLOAD_GENERAL + // stack: next_ptr, new_ptr, next_ptr_ptr, addr, retdest + // Check that this is not a deleted node + DUP1 + %eq_const(@U256_MAX) + %assert_zero DUP1 MLOAD_GENERAL - // stack: loaded_addr, i, len, addr, retdest + // stack: next_val, next_ptr, new_ptr, next_ptr_ptr, addr, retdest + DUP5 + // Here, (addr > pred_addr) || (pred_ptr == @SEGMENT_ACCESSED_STORAGE_KEYS). + // We should have (addr < next_val), meaning the new value can be inserted between pred_ptr and next_ptr. + %assert_lt + // stack: next_ptr, new_ptr, next_ptr_ptr, addr, retdest + SWAP2 + DUP2 + // stack: new_ptr, next_ptr_ptr, new_ptr, next_ptr, addr, retdest + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr, retdest + DUP1 DUP4 - // stack: addr, loaded_addr, i, len, addr, retdest - EQ %jumpi(insert_accessed_addresses_found) - // stack: i, len, addr, retdest + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr, retdest %increment - %jump(insert_accessed_addresses_loop) - -insert_address: - %stack (i, len, addr, retdest) -> (i, addr, len, retdest) - DUP2 %journal_add_account_loaded // Add a journal entry for the loaded account. - %swap_mstore // Store new address at the end of the array. - // stack: len, retdest + DUP1 + // stack: new_next_ptr, new_next_ptr, next_ptr, addr, retdest + SWAP2 + MSTORE_GENERAL + // stack: new_next_ptr, addr, retdest %increment - %sub_const(@SEGMENT_ACCESSED_ADDRESSES) // unscale `len` - %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) // Store new length. - PUSH 1 // Return 1 to indicate that the address was inserted. - SWAP1 JUMP - -insert_accessed_addresses_found: - %stack (i, len, addr, retdest) -> (retdest, 0) // Return 0 to indicate that the address was already present. + %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) + // stack: addr, retdest + %journal_add_account_loaded + PUSH 1 + SWAP1 JUMP /// Remove the address from the access list. /// Panics if the address is not in the access list. +/// Otherwise it guesses the node before the address (pred) +/// such that (pred)->(next)->(next_next), where the (next) node +/// stores the address. It writes the link (pred)->(next_next) +/// and (next) is marked as deleted by writting U256_MAX in its +/// next node pointer. global remove_accessed_addresses: // stack: addr, retdest - %mload_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) - // stack: len, addr, retdest - PUSH @SEGMENT_ACCESSED_ADDRESSES ADD - PUSH @SEGMENT_ACCESSED_ADDRESSES -remove_accessed_addresses_loop: - // `i` and `len` are both scaled by SEGMENT_ACCESSED_ADDRESSES - %stack (i, len, addr, retdest) -> (i, len, i, len, addr, retdest) - EQ %jumpi(panic) - // stack: i, len, addr, retdest - DUP1 MLOAD_GENERAL - // stack: loaded_addr, i, len, addr, retdest + PROVER_INPUT(access_lists::address_remove) + // stack: pred_ptr/2, addr, retdest + %get_valid_addr_ptr + // stack: pred_ptr, addr, retdest + %increment + // stack: next_ptr_ptr, addr, retdest + DUP1 + MLOAD_GENERAL + // stack: next_ptr, next_ptr_ptr, addr, retdest + DUP1 + MLOAD_GENERAL + // stack: next_val, next_ptr, next_ptr_ptr, addr, retdest DUP4 - // stack: addr, loaded_addr, i, len, addr, retdest - EQ %jumpi(remove_accessed_addresses_found) - // stack: i, len, addr, retdest + %assert_eq + // stack: next_ptr, next_ptr_ptr, addr, retdest %increment - %jump(remove_accessed_addresses_loop) -remove_accessed_addresses_found: - %stack (i, len, addr, retdest) -> (len, 1, i, retdest) - SUB // len -= 1 - PUSH @SEGMENT_ACCESSED_ADDRESSES - DUP2 SUB // unscale `len` - %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_ADDRESSES_LEN) // Decrement the access list length. - // stack: len-1, i, retdest - MLOAD_GENERAL // Load the last address in the access list. - // stack: last_addr, i, retdest + // stack: next_next_ptr_ptr, next_ptr_ptr, addr, retdest + DUP1 + MLOAD_GENERAL + // stack: next_next_ptr, next_next_ptr_ptr, next_ptr_ptr, addr, retdest + SWAP1 + PUSH @U256_MAX + MSTORE_GENERAL + // stack: next_next_ptr, next_ptr_ptr, addr, retdest MSTORE_GENERAL - // Store the last address at the position of the removed address. + POP JUMP @@ -99,105 +205,172 @@ remove_accessed_addresses_found: // stack: cold_access, original_value %endmacro +// Multiply the ptr at the top of the stack by 4 +// and abort if 4*ptr - SEGMENT_ACCESSED_STORAGE_KEYS >= @GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN +// In this way ptr must be poiting to the begining of a node. +%macro get_valid_storage_ptr + // stack: ptr + %mul_const(4) + PUSH @SEGMENT_ACCESSED_STORAGE_KEYS + DUP2 + SUB + %assert_lt_const(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) + // stack: 2*ptr +%endmacro + /// Inserts the storage key and value into the access list if it is not already present. /// `value` should be the current storage value at the slot `(addr, key)`. -/// Return `1, original_value` if the storage key was inserted, `0, original_value` if it was already present. +/// Return `1, value` if the storage key was inserted, `0, original_value` if it was already present. global insert_accessed_storage_keys: // stack: addr, key, value, retdest - %mload_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) - // stack: len, addr, key, value, retdest - PUSH @SEGMENT_ACCESSED_STORAGE_KEYS ADD - PUSH @SEGMENT_ACCESSED_STORAGE_KEYS -insert_accessed_storage_keys_loop: - // `i` and `len` are both scaled by SEGMENT_ACCESSED_STORAGE_KEYS - %stack (i, len, addr, key, value, retdest) -> (i, len, i, len, addr, key, value, retdest) - EQ %jumpi(insert_storage_key) - // stack: i, len, addr, key, value, retdest - DUP1 %increment MLOAD_GENERAL - // stack: loaded_key, i, len, addr, key, value, retdest - DUP2 MLOAD_GENERAL - // stack: loaded_addr, loaded_key, i, len, addr, key, value, retdest - DUP5 EQ - // stack: loaded_addr==addr, loaded_key, i, len, addr, key, value, retdest - SWAP1 DUP6 EQ - // stack: loaded_key==key, loaded_addr==addr, i, len, addr, key, value, retdest - MUL // AND - %jumpi(insert_accessed_storage_keys_found) - // stack: i, len, addr, key, value, retdest + PROVER_INPUT(access_lists::storage_insert) + // stack: pred_ptr/4, addr, key, value, retdest + %get_valid_storage_ptr + // stack: pred_ptr, addr, key, value, retdest + DUP1 + MLOAD_GENERAL + DUP1 + // stack: pred_addr, pred_addr, pred_ptr, addr, key, value, retdest + DUP4 GT + DUP3 %eq_const(@SEGMENT_ACCESSED_STORAGE_KEYS) + ADD // OR + %jumpi(insert_storage_key) + // stack: pred_addr, pred_ptr, addr, key, value, retdest + // We know that addr <= pred_addr. It must hold that pred_addr == addr. + DUP3 + %assert_eq + // stack: pred_ptr, addr, key, value, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: pred_key, pred_ptr, addr, key, value, retdest + DUP1 DUP5 + GT + // stack: key > pred_key, pred_key, pred_ptr, addr, key, value, retdest + %jumpi(insert_storage_key) + // stack: pred_key, pred_ptr, addr, key, value, retdest + DUP4 + // We know that key <= pred_key. It must hold that pred_key == key. + %assert_eq + // stack: pred_ptr, addr, key, value, retdest + // Check that this is not a deleted node + DUP1 %add_const(3) - %jump(insert_accessed_storage_keys_loop) + MLOAD_GENERAL + %jump_neq_const(@U256_MAX, storage_key_found) + // The storage key is not in the list. + PANIC +storage_key_found: + // The address was already in the list + // stack: pred_ptr, addr, key, value, retdest + %add_const(2) + MLOAD_GENERAL + %stack (original_value, addr, key, value, retdest) -> (retdest, 0, original_value) // Return 0 to indicate that the address was already present. + JUMP insert_storage_key: - // stack: i, len, addr, key, value, retdest - DUP4 DUP4 %journal_add_storage_loaded // Add a journal entry for the loaded storage key. - // stack: i, len, addr, key, value, retdest - - %stack(dst, len, addr, key, value) -> (addr, dst, dst, key, dst, value, dst, @SEGMENT_ACCESSED_STORAGE_KEYS, value) - MSTORE_GENERAL // Store new address at the end of the array. - // stack: dst, key, dst, value, dst, segment, value, retdest - %increment SWAP1 - MSTORE_GENERAL // Store new key after that - // stack: dst, value, dst, segment, value, retdest - %add_const(2) SWAP1 - MSTORE_GENERAL // Store new value after that - // stack: dst, segment, value, retdest + // stack: pred_addr or pred_key, pred_ptr, addr, key, value, retdest + POP + // Insert a new storage key + // stack: pred_ptr, addr, key, value, retdest + // get the value of the next address %add_const(3) - SUB // unscale dst - %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) // Store new length. - %stack (value, retdest) -> (retdest, 1, value) // Return 1 to indicate that the storage key was inserted. - JUMP - -insert_accessed_storage_keys_found: - // stack: i, len, addr, key, value, retdest - %add_const(2) + // stack: next_ptr_ptr, addr, key, value, retdest + %mload_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) + DUP2 + MLOAD_GENERAL + // stack: next_ptr, new_ptr, next_ptr_ptr, addr, key, value, retdest + // Check that this is not a deleted node + DUP1 + %eq_const(@U256_MAX) + %assert_zero + DUP1 + MLOAD_GENERAL + // stack: next_val, next_ptr, new_ptr, next_ptr_ptr, addr, key, value, retdest + DUP5 + // Check that addr < next_val OR (next_val == addr AND key < next_key) + DUP2 DUP2 + LT + // stack: addr < next_val, addr, next_val, next_ptr, new_ptr, next_ptr_ptr, addr, key, value, retdest + SWAP2 + EQ + // stack: next_val == addr, addr < next_val, next_ptr, new_ptr, next_ptr_ptr, addr, key, value, retdest + DUP3 %increment MLOAD_GENERAL - %stack (original_value, len, addr, key, value, retdest) -> (retdest, 0, original_value) // Return 0 to indicate that the storage key was already present. + DUP8 + LT + // stack: next_key > key, next_val == addr, addr < next_val, next_ptr, new_ptr, next_ptr_ptr, addr, key, value, retdest + AND + OR + %assert_nonzero + // stack: next_ptr, new_ptr, next_ptr_ptr, addr, key, value, retdest + SWAP2 + DUP2 + MSTORE_GENERAL + // stack: new_ptr, next_ptr, addr, key, value, retdest + DUP1 + DUP4 + MSTORE_GENERAL // store addr + // stack: new_ptr, next_ptr, addr, key, value, retdest + %increment + DUP1 + DUP5 + MSTORE_GENERAL // store key + %increment + DUP1 + DUP6 + MSTORE_GENERAL // store value + // stack: new_ptr + 2, next_ptr, addr, key, value, retdest + %increment + DUP1 + // stack: new_next_ptr, new_next_ptr, next_ptr, addr, key, value, retdest + SWAP2 + MSTORE_GENERAL + // stack: new_next_ptr, addr, key, value, retdest + %increment + %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) + // stack: addr, key, value, retdest + %stack (addr, key, value, retdest) -> (addr, key, retdest, 1, value) + %journal_add_storage_loaded JUMP /// Remove the storage key and its value from the access list. /// Panics if the key is not in the list. global remove_accessed_storage_keys: // stack: addr, key, retdest - %mload_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) - // stack: len, addr, key, retdest - PUSH @SEGMENT_ACCESSED_STORAGE_KEYS ADD - PUSH @SEGMENT_ACCESSED_STORAGE_KEYS -remove_accessed_storage_keys_loop: - // `i` and `len` are both scaled by SEGMENT_ACCESSED_STORAGE_KEYS - %stack (i, len, addr, key, retdest) -> (i, len, i, len, addr, key, retdest) - EQ %jumpi(panic) - // stack: i, len, addr, key, retdest - DUP1 %increment MLOAD_GENERAL - // stack: loaded_key, i, len, addr, key, retdest - DUP2 MLOAD_GENERAL - // stack: loaded_addr, loaded_key, i, len, addr, key, retdest - DUP5 EQ - // stack: loaded_addr==addr, loaded_key, i, len, addr, key, retdest - SWAP1 DUP6 EQ - // stack: loaded_key==key, loaded_addr==addr, i, len, addr, key, retdest + PROVER_INPUT(access_lists::storage_remove) + // stack: pred_ptr/4, addr, key, retdest + %get_valid_storage_ptr + // stack: pred_ptr, addr, key, retdest + %add_const(3) + // stack: next_ptr_ptr, addr, key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_ptr, next_ptr_ptr, addr, key, retdest + DUP1 + %increment + MLOAD_GENERAL + // stack: next_key, next_ptr, next_ptr_ptr, addr, key, retdest + DUP5 + EQ + DUP2 + MLOAD_GENERAL + // stack: next_addr, next_key == key, next_ptr, next_ptr_ptr, addr, key, retdest + DUP5 + EQ MUL // AND - %jumpi(remove_accessed_storage_keys_found) - // stack: i, len, addr, key, retdest + // stack: next_addr == addr AND next_key == key, next_ptr, next_ptr_ptr, addr, key, retdest + %assert_nonzero + // stack: next_ptr, next_ptr_ptr, addr, key, retdest %add_const(3) - %jump(remove_accessed_storage_keys_loop) - -remove_accessed_storage_keys_found: - %stack (i, len, addr, key, retdest) -> (len, 3, i, retdest) - SUB - PUSH @SEGMENT_ACCESSED_STORAGE_KEYS - DUP2 SUB // unscale - %mstore_global_metadata(@GLOBAL_METADATA_ACCESSED_STORAGE_KEYS_LEN) // Decrease the access list length. - // stack: len-3, i, retdest - DUP1 %add_const(2) MLOAD_GENERAL - // stack: last_value, len-3, i, retdest - DUP2 %add_const(1) MLOAD_GENERAL - // stack: last_key, last_value, len-3, i, retdest - DUP3 MLOAD_GENERAL - // stack: last_addr, last_key, last_value, len-3, i, retdest - DUP5 %swap_mstore // Move the last tuple to the position of the removed tuple. - // stack: last_key, last_value, len-3, i, retdest - DUP4 %add_const(1) %swap_mstore - // stack: last_value, len-3, i, retdest - DUP3 %add_const(2) %swap_mstore - // stack: len-3, i, retdest - %pop2 JUMP + // stack: next_next_ptr_ptr, next_ptr_ptr, addr, key, retdest + DUP1 + MLOAD_GENERAL + // stack: next_next_ptr, next_next_ptr_ptr, next_ptr_ptr, addr, key, retdest + SWAP1 + PUSH @U256_MAX + MSTORE_GENERAL + // stack: next_next_ptr, next_ptr_ptr, addr, key, retdest + MSTORE_GENERAL + %pop2 + JUMP \ No newline at end of file diff --git a/evm_arithmetization/src/cpu/kernel/asm/journal/account_loaded.asm b/evm_arithmetization/src/cpu/kernel/asm/journal/account_loaded.asm index 6c3c4ba04..d7da0a788 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/journal/account_loaded.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/journal/account_loaded.asm @@ -1,7 +1,9 @@ // struct AccountLoaded { address } %macro journal_add_account_loaded + // stack: address %journal_add_1(@JOURNAL_ENTRY_ACCOUNT_LOADED) + // stack: (empty) %endmacro global revert_account_loaded: diff --git a/evm_arithmetization/src/cpu/kernel/asm/journal/journal.asm b/evm_arithmetization/src/cpu/kernel/asm/journal/journal.asm index 9ba435087..39b6c9f1b 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/journal/journal.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/journal/journal.asm @@ -24,6 +24,7 @@ // stack: journal_size %increment %mstore_global_metadata(@GLOBAL_METADATA_JOURNAL_LEN) + // stack: (empty) %endmacro %macro journal_data_size @@ -65,6 +66,7 @@ %append_journal_data // stack: ptr %append_journal + // stack: (empty) %endmacro %macro journal_add_2(type) @@ -78,6 +80,7 @@ SWAP1 %append_journal_data // stack: ptr %append_journal + // stack: (empty) %endmacro %macro journal_add_3(type) @@ -93,6 +96,7 @@ SWAP1 %append_journal_data // stack: ptr %append_journal + // stack: (empty) %endmacro %macro journal_add_4(type) @@ -110,6 +114,7 @@ SWAP1 %append_journal_data // stack: ptr %append_journal + // stack: (empty) %endmacro %macro journal_load_1 diff --git a/evm_arithmetization/src/cpu/kernel/asm/main.asm b/evm_arithmetization/src/cpu/kernel/asm/main.asm index 1307f6d5d..780c4b6a6 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/main.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/main.asm @@ -12,6 +12,9 @@ global main: // Initialise the shift table %shift_table_init + // Initialize accessed addresses and storage keys lists + %init_access_lists + // Initialize the RLP DATA pointer to its initial position (ctx == virt == 0, segment = RLP) PUSH @SEGMENT_RLP_RAW %mstore_global_metadata(@GLOBAL_METADATA_RLP_DATA_SIZE) diff --git a/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm b/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm index 6c517407b..dc73721b3 100644 --- a/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm +++ b/evm_arithmetization/src/cpu/kernel/asm/util/assertions.asm @@ -34,10 +34,8 @@ global panic: %endmacro %macro assert_lt - // %assert_zero is cheaper than %assert_nonzero, so we will leverage the - // fact that (x < y) == !(x >= y). - GE - %assert_zero + LT + %assert_nonzero %endmacro %macro assert_lt(ret) diff --git a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs index d3ac0c629..33c1024b0 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/account_code.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/account_code.rs @@ -198,6 +198,12 @@ fn test_extcodecopy() -> Result<()> { [Segment::ContextMetadata.unscale()] .set(GasLimit.unscale(), U256::from(1000000000000u64)); + // Pre-initialize the accessed addresses list. + let init_accessed_addresses = KERNEL.global_labels["init_access_lists"]; + interpreter.generation_state.registers.program_counter = init_accessed_addresses; + interpreter.push(0xdeadbeefu32.into()); + interpreter.run()?; + let extcodecopy = KERNEL.global_labels["sys_extcodecopy"]; // Put random data in main memory and the `KernelAccountCode` segment for @@ -327,6 +333,12 @@ fn sstore() -> Result<()> { let initial_stack = vec![]; let mut interpreter: Interpreter = Interpreter::new_with_kernel(0, initial_stack); + // Pre-initialize the accessed addresses list. + let init_accessed_addresses = KERNEL.global_labels["init_access_lists"]; + interpreter.generation_state.registers.program_counter = init_accessed_addresses; + interpreter.push(0xdeadbeefu32.into()); + interpreter.run()?; + // Prepare the interpreter by inserting the account in the state trie. prepare_interpreter_all_accounts(&mut interpreter, trie_inputs, addr, &code)?; @@ -417,6 +429,12 @@ fn sload() -> Result<()> { let initial_stack = vec![]; let mut interpreter: Interpreter = Interpreter::new_with_kernel(0, initial_stack); + // Pre-initialize the accessed addresses list. + let init_accessed_addresses = KERNEL.global_labels["init_access_lists"]; + interpreter.generation_state.registers.program_counter = init_accessed_addresses; + interpreter.push(0xdeadbeefu32.into()); + interpreter.run()?; + // Prepare the interpreter by inserting the account in the state trie. prepare_interpreter_all_accounts(&mut interpreter, trie_inputs, addr, &code)?; interpreter.run()?; diff --git a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs index 4ee38e92c..d878f1732 100644 --- a/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs +++ b/evm_arithmetization/src/cpu/kernel/tests/core/access_lists.rs @@ -1,7 +1,8 @@ use std::collections::HashSet; use anyhow::Result; -use ethereum_types::{Address, U256}; +use ethereum_types::{Address, H160, U256}; +use hashbrown::hash_map::rayon::IntoParIter; use plonky2::field::goldilocks_field::GoldilocksField as F; use rand::{thread_rng, Rng}; @@ -10,68 +11,175 @@ use crate::cpu::kernel::constants::global_metadata::GlobalMetadata::{ AccessedAddressesLen, AccessedStorageKeysLen, }; use crate::cpu::kernel::interpreter::Interpreter; -use crate::memory::segments::Segment::{AccessedAddresses, AccessedStorageKeys}; +use crate::memory::segments::Segment::{self, AccessedAddresses, AccessedStorageKeys}; +use crate::memory::segments::SEGMENT_SCALING_FACTOR; use crate::witness::memory::MemoryAddress; +#[test] +fn test_init_access_lists() -> Result<()> { + let init_label = KERNEL.global_labels["init_access_lists"]; + + // Check the initial state of the access list in the kernel. + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::::new_with_kernel(init_label, initial_stack); + interpreter.run()?; + + assert!(interpreter.stack().is_empty()); + + let acc_addr_list: Vec = (0..2) + .map(|i| { + interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + Segment::AccessedAddresses, + i, + )) + }) + .collect(); + assert_eq!( + vec![U256::MAX, (Segment::AccessedAddresses as usize).into(),], + acc_addr_list + ); + + let acc_storage_keys: Vec = (0..4) + .map(|i| { + interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + Segment::AccessedStorageKeys, + i, + )) + }) + .collect(); + + assert_eq!( + vec![ + U256::MAX, + U256::zero(), + U256::zero(), + (Segment::AccessedStorageKeys as usize).into() + ], + acc_storage_keys + ); + + Ok(()) +} + +#[test] +fn test_list_iterator() -> Result<()> { + let init_label = KERNEL.global_labels["init_access_lists"]; + + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::::new_with_kernel(init_label, initial_stack); + interpreter.run()?; + + // test the list iterator + let mut list = interpreter + .generation_state + .get_addresses_access_list() + .expect("Since we called init_access_lists there must be a list"); + + let Some((pos_0, next_val_0, _)) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(pos_0, 0); + assert_eq!(next_val_0, U256::MAX); + let Some((pos_0, next_val_0, _)) = list.next() else { + return Err(anyhow::Error::msg("Couldn't get value")); + }; + assert_eq!(pos_0, 0); + Ok(()) +} + +#[test] +fn test_insert_address() -> Result<()> { + let init_label = KERNEL.global_labels["init_access_lists"]; + + // Test for address already in list. + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::::new_with_kernel(init_label, initial_stack); + interpreter.run()?; + + let insert_accessed_addresses = KERNEL.global_labels["insert_accessed_addresses"]; + + let retaddr = 0xdeadbeefu32.into(); + let mut rng = thread_rng(); + let mut address: H160 = rng.gen(); + + assert!(address != H160::zero(), "Cosmic luck or bad RNG?"); + + interpreter.push(retaddr); + interpreter.push(U256::from(address.0.as_slice())); + interpreter.generation_state.registers.program_counter = insert_accessed_addresses; + + interpreter.run()?; + assert_eq!(interpreter.stack(), &[U256::one()]); + assert_eq!( + interpreter + .generation_state + .memory + .get(MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap()), + U256::from(Segment::AccessedAddresses as usize + 4) + ); + + Ok(()) +} + #[test] fn test_insert_accessed_addresses() -> Result<()> { + let init_access_lists = KERNEL.global_labels["init_access_lists"]; + + // Test for address already in list. + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::::new_with_kernel(init_access_lists, initial_stack); + interpreter.run()?; + let insert_accessed_addresses = KERNEL.global_labels["insert_accessed_addresses"]; let retaddr = 0xdeadbeefu32.into(); let mut rng = thread_rng(); - let n = rng.gen_range(1..10); - let addresses = (0..n) + let n = 10; + let mut addresses = (0..n) .map(|_| rng.gen::
()) .collect::>() .into_iter() .collect::>(); - let addr_in_list = addresses[rng.gen_range(0..n)]; let addr_not_in_list = rng.gen::
(); assert!( !addresses.contains(&addr_not_in_list), "Cosmic luck or bad RNG?" ); - // Test for address already in list. - let initial_stack = vec![retaddr, U256::from(addr_in_list.0.as_slice())]; - let mut interpreter: Interpreter = - Interpreter::new_with_kernel(insert_accessed_addresses, initial_stack); + let offset = Segment::AccessedAddresses as usize; for i in 0..n { let addr = U256::from(addresses[i].0.as_slice()); - interpreter - .generation_state - .memory - .set(MemoryAddress::new(0, AccessedAddresses, i), addr); + interpreter.push(0xdeadbeefu32.into()); + interpreter.push(addr); + interpreter.generation_state.registers.program_counter = insert_accessed_addresses; + interpreter.run()?; + assert_eq!(interpreter.pop().unwrap(), U256::one()); } - interpreter.generation_state.memory.set( - MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - U256::from(n), - ); - interpreter.run()?; - assert_eq!(interpreter.stack(), &[U256::zero()]); - assert_eq!( - interpreter - .generation_state - .memory - .get(MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap()), - U256::from(n) - ); - // Test for address not in list. - let initial_stack = vec![retaddr, U256::from(addr_not_in_list.0.as_slice())]; - let mut interpreter: Interpreter = - Interpreter::new_with_kernel(insert_accessed_addresses, initial_stack); for i in 0..n { - let addr = U256::from(addresses[i].0.as_slice()); - interpreter - .generation_state - .memory - .set(MemoryAddress::new(0, AccessedAddresses, i), addr); + // Test for address already in list. + let addr_in_list = addresses[i]; + interpreter.push(retaddr); + interpreter.push(U256::from(addr_in_list.0.as_slice())); + interpreter.generation_state.registers.program_counter = insert_accessed_addresses; + interpreter.run()?; + assert_eq!(interpreter.pop().unwrap(), U256::zero()); + assert_eq!( + interpreter + .generation_state + .memory + .get(MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap()), + U256::from(offset + 2 * (n + 1)) + ); } - interpreter.generation_state.memory.set( - MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap(), - U256::from(n), - ); + + // Test for address not in list. + interpreter.push(retaddr); + interpreter.push(U256::from(addr_not_in_list.0.as_slice())); + interpreter.generation_state.registers.program_counter = insert_accessed_addresses; + interpreter.run()?; assert_eq!(interpreter.stack(), &[U256::one()]); assert_eq!( @@ -79,13 +187,14 @@ fn test_insert_accessed_addresses() -> Result<()> { .generation_state .memory .get(MemoryAddress::new_bundle(U256::from(AccessedAddressesLen as usize)).unwrap()), - U256::from(n + 1) + U256::from(offset + 2 * (n + 2)) ); assert_eq!( - interpreter - .generation_state - .memory - .get(MemoryAddress::new(0, AccessedAddresses, n)), + interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + AccessedAddresses, + 2 * (n + 1) + )), U256::from(addr_not_in_list.0.as_slice()) ); @@ -94,12 +203,19 @@ fn test_insert_accessed_addresses() -> Result<()> { #[test] fn test_insert_accessed_storage_keys() -> Result<()> { + let init_access_lists = KERNEL.global_labels["init_access_lists"]; + + // Test for address already in list. + let initial_stack = vec![0xdeadbeefu32.into()]; + let mut interpreter = Interpreter::::new_with_kernel(init_access_lists, initial_stack); + interpreter.run()?; + let insert_accessed_storage_keys = KERNEL.global_labels["insert_accessed_storage_keys"]; let retaddr = 0xdeadbeefu32.into(); let mut rng = thread_rng(); - let n = rng.gen_range(1..10); - let storage_keys = (0..n) + let n = 10; + let mut storage_keys = (0..n) .map(|_| (rng.gen::
(), U256(rng.gen()), U256(rng.gen()))) .collect::>() .into_iter() @@ -111,72 +227,47 @@ fn test_insert_accessed_storage_keys() -> Result<()> { "Cosmic luck or bad RNG?" ); - // Test for storage key already in list. - let initial_stack = vec![ - retaddr, - storage_key_in_list.2, - storage_key_in_list.1, - U256::from(storage_key_in_list.0 .0.as_slice()), - ]; - let mut interpreter: Interpreter = - Interpreter::new_with_kernel(insert_accessed_storage_keys, initial_stack); + let offset = Segment::AccessedStorageKeys as usize; for i in 0..n { let addr = U256::from(storage_keys[i].0 .0.as_slice()); - interpreter - .generation_state - .memory - .set(MemoryAddress::new(0, AccessedStorageKeys, 3 * i), addr); - interpreter.generation_state.memory.set( - MemoryAddress::new(0, AccessedStorageKeys, 3 * i + 1), - storage_keys[i].1, - ); - interpreter.generation_state.memory.set( - MemoryAddress::new(0, AccessedStorageKeys, 3 * i + 2), - storage_keys[i].2, - ); + let key = storage_keys[i].1; + let value = storage_keys[i].2; + interpreter.push(retaddr); + interpreter.push(value); + interpreter.push(key); + interpreter.push(addr); + interpreter.generation_state.registers.program_counter = insert_accessed_storage_keys; + interpreter.run()?; + assert_eq!(interpreter.pop().unwrap(), U256::one()); + assert_eq!(interpreter.pop().unwrap(), value); } - interpreter.generation_state.memory.set( - MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - U256::from(3 * n), - ); - interpreter.run()?; - assert_eq!(interpreter.stack(), &[storage_key_in_list.2, U256::zero()]); - assert_eq!( - interpreter - .generation_state - .memory - .get(MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap()), - U256::from(3 * n) - ); - // Test for storage key not in list. - let initial_stack = vec![ - retaddr, - storage_key_not_in_list.2, - storage_key_not_in_list.1, - U256::from(storage_key_not_in_list.0 .0.as_slice()), - ]; - let mut interpreter: Interpreter = - Interpreter::new_with_kernel(insert_accessed_storage_keys, initial_stack); - for i in 0..n { - let addr = U256::from(storage_keys[i].0 .0.as_slice()); - interpreter - .generation_state - .memory - .set(MemoryAddress::new(0, AccessedStorageKeys, 3 * i), addr); - interpreter.generation_state.memory.set( - MemoryAddress::new(0, AccessedStorageKeys, 3 * i + 1), - storage_keys[i].1, - ); - interpreter.generation_state.memory.set( - MemoryAddress::new(0, AccessedStorageKeys, 3 * i + 2), - storage_keys[i].2, + for i in 0..10 { + // Test for storage key already in list. + let (addr, key, value) = storage_keys[i]; + interpreter.push(retaddr); + interpreter.push(value); + interpreter.push(key); + interpreter.push(U256::from(addr.0.as_slice())); + interpreter.generation_state.registers.program_counter = insert_accessed_storage_keys; + interpreter.run()?; + assert_eq!(interpreter.pop().unwrap(), U256::zero()); + assert_eq!(interpreter.pop().unwrap(), value); + assert_eq!( + interpreter.generation_state.memory.get( + MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap() + ), + U256::from(offset + 4 * (n + 1)) ); } - interpreter.generation_state.memory.set( - MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap(), - U256::from(3 * n), - ); + + // Test for storage key not in list. + interpreter.push(retaddr); + interpreter.push(storage_key_not_in_list.2); + interpreter.push(storage_key_not_in_list.1); + interpreter.push(U256::from(storage_key_not_in_list.0 .0.as_slice())); + interpreter.generation_state.registers.program_counter = insert_accessed_storage_keys; + interpreter.run()?; assert_eq!( interpreter.stack(), @@ -187,20 +278,21 @@ fn test_insert_accessed_storage_keys() -> Result<()> { .generation_state .memory .get(MemoryAddress::new_bundle(U256::from(AccessedStorageKeysLen as usize)).unwrap()), - U256::from(3 * (n + 1)) + U256::from(offset + 4 * (n + 2)) ); assert_eq!( - interpreter - .generation_state - .memory - .get(MemoryAddress::new(0, AccessedStorageKeys, 3 * n,)), + interpreter.generation_state.memory.get(MemoryAddress::new( + 0, + AccessedStorageKeys, + 4 * (n + 1) + )), U256::from(storage_key_not_in_list.0 .0.as_slice()) ); assert_eq!( interpreter.generation_state.memory.get(MemoryAddress::new( 0, AccessedStorageKeys, - 3 * n + 1, + 4 * (n + 1) + 1 )), storage_key_not_in_list.1 ); @@ -208,7 +300,7 @@ fn test_insert_accessed_storage_keys() -> Result<()> { interpreter.generation_state.memory.get(MemoryAddress::new( 0, AccessedStorageKeys, - 3 * n + 2, + 4 * (n + 1) + 2 )), storage_key_not_in_list.2 ); diff --git a/evm_arithmetization/src/generation/prover_input.rs b/evm_arithmetization/src/generation/prover_input.rs index 9a2a38c80..766f880e4 100644 --- a/evm_arithmetization/src/generation/prover_input.rs +++ b/evm_arithmetization/src/generation/prover_input.rs @@ -55,6 +55,7 @@ impl GenerationState { "withdrawal" => self.run_withdrawal(), "num_bits" => self.run_num_bits(), "jumpdest_table" => self.run_jumpdest_table(input_fn), + "access_lists" => self.run_access_lists(input_fn), _ => Err(ProgramError::ProverInputError(InvalidFunction)), } } @@ -249,6 +250,18 @@ impl GenerationState { } } + /// Generates either the next used jump address or the proof for the last + /// jump address. + fn run_access_lists(&mut self, input_fn: &ProverInputFn) -> Result { + match input_fn.0[1].as_str() { + "address_insert" => self.run_next_addresses_insert(), + "storage_insert" => self.run_next_storage_insert(), + "address_remove" => self.run_next_addresses_remove(), + "storage_remove" => self.run_next_storage_remove(), + _ => Err(ProgramError::ProverInputError(InvalidInput)), + } + } + /// Returns the next used jump address. fn run_next_jumpdest_table_address(&mut self) -> Result { let context = u256_to_usize(stack_peek(self, 0)? >> CONTEXT_SCALING_FACTOR)?; @@ -307,6 +320,61 @@ impl GenerationState { )) } } + + /// Returns a pointer to an element in the list whose value is such that + /// `value <= addr < next_value` and `addr` is the top of the stack. + fn run_next_addresses_insert(&mut self) -> Result { + let addr = stack_peek(self, 0)?; + for (curr_ptr, next_addr, _) in self.get_addresses_access_list()? { + if next_addr > addr { + // In order to avoid pointers to the next ptr, we use the fact + // that valid pointers and Segment::AccessedAddresses are always even + return Ok(((Segment::AccessedAddresses as usize + curr_ptr) / 2usize).into()); + } + } + Ok((Segment::AccessedAddresses as usize).into()) + } + + /// Returns a pointer to an element in the list whose value is such that + /// `value < addr == next_value` and addr is the top of the stack. + /// If the element is not in the list returns loops forever + fn run_next_addresses_remove(&mut self) -> Result { + let addr = stack_peek(self, 0)?; + for (curr_ptr, next_addr, _) in self.get_addresses_access_list()? { + if next_addr == addr { + return Ok(((Segment::AccessedAddresses as usize + curr_ptr) / 2usize).into()); + } + } + Ok((Segment::AccessedAddresses as usize).into()) + } + + /// Returns a pointer to the predecessor of the top of the stack in the + /// accessed storage keys list. + fn run_next_storage_insert(&mut self) -> Result { + let addr = stack_peek(self, 0)?; + let key = stack_peek(self, 1)?; + for (curr_ptr, next_addr, next_key) in self.get_storage_keys_access_list()? { + if next_addr > addr || (next_addr == addr && next_key > key) { + // In order to avoid pointers to the key, value or next ptr, we use the fact + // that valid pointers and Segment::AccessedAddresses are always multiples of 4 + return Ok(((Segment::AccessedStorageKeys as usize + curr_ptr) / 4usize).into()); + } + } + Ok((Segment::AccessedAddresses as usize).into()) + } + + /// Returns a pointer to the predecessor of the top of the stack in the + /// accessed storage keys list. + fn run_next_storage_remove(&mut self) -> Result { + let addr = stack_peek(self, 0)?; + let key = stack_peek(self, 1)?; + for (curr_ptr, next_addr, next_key) in self.get_storage_keys_access_list()? { + if (next_addr == addr && next_key == key) || next_addr == U256::MAX { + return Ok(((Segment::AccessedStorageKeys as usize + curr_ptr) / 4usize).into()); + } + } + Ok((Segment::AccessedStorageKeys as usize).into()) + } } impl GenerationState { @@ -383,6 +451,46 @@ impl GenerationState { } } } + + pub(crate) fn get_addresses_access_list(&self) -> Result { + // `GlobalMetadata::AccessedAddressesLen` stores the value of the next available + // virtual address in the segment. In order to get the length we need + // to substract `Segment::AccessedAddresses` as usize. + let acc_addr_len = + u256_to_usize(self.get_global_metadata(GlobalMetadata::AccessedAddressesLen))? + - Segment::AccessedAddresses as usize; + AccList::from_mem_and_segment( + &self.memory.contexts[0].segments[Segment::AccessedAddresses.unscale()].content + [..acc_addr_len], + Segment::AccessedAddresses, + ) + } + + fn get_global_metadata(&self, data: GlobalMetadata) -> U256 { + self.memory.get(MemoryAddress::new( + 0, + Segment::GlobalMetadata, + data.unscale(), + )) + } + + pub(crate) fn get_storage_keys_access_list(&self) -> Result { + // GlobalMetadata::AccessedStorageKeysLen stores the value of the next available + // virtual address in the segment. In order to get the length we need + // to substract Segment::AccessedStorageKeys as usize + let acc_storage_len = u256_to_usize( + self.memory.get(MemoryAddress::new( + 0, + Segment::GlobalMetadata, + GlobalMetadata::AccessedStorageKeysLen.unscale(), + )) - Segment::AccessedStorageKeys as usize, + )?; + AccList::from_mem_and_segment( + &self.memory.contexts[0].segments[Segment::AccessedStorageKeys.unscale()].content + [..acc_storage_len], + Segment::AccessedStorageKeys, + ) + } } /// For all address in `jumpdest_table` smaller than `largest_address`, @@ -468,6 +576,63 @@ impl<'a> Iterator for CodeIterator<'a> { } } +// Iterates over a linked list implemented using a vector `access_list_mem`. +// In this representation, the values of nodes are stored in the range +// `access_list_mem[i..i + node_size - 1]`, and `access_list_mem[i + node_size - +// 1]` holds the address of the next node, where i = node_size * j. +pub(crate) struct AccList<'a> { + access_list_mem: &'a [U256], + node_size: usize, + offset: usize, + pos: usize, +} + +impl<'a> AccList<'a> { + fn from_mem_and_segment( + access_list_mem: &'a [U256], + segment: Segment, + ) -> Result { + if access_list_mem.is_empty() { + return Err(ProgramError::ProverInputError(InvalidInput)); + } + Ok(Self { + access_list_mem, + node_size: match segment { + Segment::AccessedAddresses => 2, + Segment::AccessedStorageKeys => 4, + _ => return Err(ProgramError::ProverInputError(InvalidInput)), + }, + offset: segment as usize, + pos: 0, + }) + } +} + +impl<'a> Iterator for AccList<'a> { + type Item = (usize, U256, U256); + + fn next(&mut self) -> Option { + let addr = self.access_list_mem[self.pos]; + if let Ok(new_pos) = u256_to_usize(self.access_list_mem[self.pos + self.node_size - 1]) { + let old_pos = self.pos; + self.pos = new_pos - self.offset; + if self.node_size == 2 { + // addresses + Some((old_pos, self.access_list_mem[self.pos], U256::zero())) + } else { + // storage_keys + Some(( + old_pos, + self.access_list_mem[self.pos], + self.access_list_mem[self.pos + 1], + )) + } + } else { + None + } + } +} + enum EvmField { Bls381Base, Bls381Scalar, diff --git a/evm_arithmetization/src/keccak/keccak_stark.rs b/evm_arithmetization/src/keccak/keccak_stark.rs index 969a0357f..e1a0484d6 100644 --- a/evm_arithmetization/src/keccak/keccak_stark.rs +++ b/evm_arithmetization/src/keccak/keccak_stark.rs @@ -275,14 +275,8 @@ impl, const D: usize> Stark for KeccakStark, const D: usize> Stark for KeccakStark>( let local_values = vars.get_local_values(); let next_values = vars.get_next_values(); + // Constrain the flags to be either 0 or 1. + for i in 0..NUM_ROUNDS { + let current_round_flag = local_values[reg_step(i)]; + yield_constr.constraint(current_round_flag * (current_round_flag - F::ONE)); + } + // Initially, the first step flag should be 1 while the others should be 0. yield_constr.constraint_first_row(local_values[reg_step(0)] - F::ONE); for i in 1..NUM_ROUNDS { @@ -25,17 +31,23 @@ pub(crate) fn eval_round_flags>( } // Flags should circularly increment, or be all zero for padding rows. + let current_any_flag = (0..NUM_ROUNDS) + .map(|i| local_values[reg_step(i)]) + .sum::

(); let next_any_flag = (0..NUM_ROUNDS).map(|i| next_values[reg_step(i)]).sum::

(); + // Padding row should only start after the last round row. + let last_round_flag = local_values[reg_step(NUM_ROUNDS - 1)]; + let padding_constraint = + (next_any_flag - F::ONE) * current_any_flag * (last_round_flag - F::ONE); for i in 0..NUM_ROUNDS { let current_round_flag = local_values[reg_step(i)]; let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)]; - yield_constr.constraint_transition(next_any_flag * (next_round_flag - current_round_flag)); + yield_constr.constraint_transition( + next_any_flag * (next_round_flag - current_round_flag) + padding_constraint, + ); } // Padding rows should always be followed by padding rows. - let current_any_flag = (0..NUM_ROUNDS) - .map(|i| local_values[reg_step(i)]) - .sum::

(); yield_constr.constraint_transition(next_any_flag * (current_any_flag - F::ONE)); } @@ -48,6 +60,14 @@ pub(crate) fn eval_round_flags_recursively, const D let local_values = vars.get_local_values(); let next_values = vars.get_next_values(); + // Constrain the flags to be either 0 or 1. + for i in 0..NUM_ROUNDS { + let current_round_flag = local_values[reg_step(i)]; + let constraint = + builder.mul_sub_extension(current_round_flag, current_round_flag, current_round_flag); + yield_constr.constraint(builder, constraint); + } + // Initially, the first step flag should be 1 while the others should be 0. let step_0_minus_1 = builder.sub_extension(local_values[reg_step(0)], one); yield_constr.constraint_first_row(builder, step_0_minus_1); @@ -56,19 +76,25 @@ pub(crate) fn eval_round_flags_recursively, const D } // Flags should circularly increment, or be all zero for padding rows. + let current_any_flag = + builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)])); let next_any_flag = builder.add_many_extension((0..NUM_ROUNDS).map(|i| next_values[reg_step(i)])); + // Padding row should only start after the last round row. + let last_round_flag = local_values[reg_step(NUM_ROUNDS - 1)]; + let padding_constraint = { + let tmp = builder.mul_sub_extension(current_any_flag, next_any_flag, current_any_flag); + builder.mul_sub_extension(tmp, last_round_flag, tmp) + }; for i in 0..NUM_ROUNDS { let current_round_flag = local_values[reg_step(i)]; let next_round_flag = next_values[reg_step((i + 1) % NUM_ROUNDS)]; - let diff = builder.sub_extension(next_round_flag, current_round_flag); - let constraint = builder.mul_extension(next_any_flag, diff); + let flag_diff = builder.sub_extension(next_round_flag, current_round_flag); + let constraint = builder.mul_add_extension(next_any_flag, flag_diff, padding_constraint); yield_constr.constraint_transition(builder, constraint); } // Padding rows should always be followed by padding rows. - let current_any_flag = - builder.add_many_extension((0..NUM_ROUNDS).map(|i| local_values[reg_step(i)])); let constraint = builder.mul_sub_extension(next_any_flag, current_any_flag, next_any_flag); yield_constr.constraint_transition(builder, constraint); } diff --git a/mpt_trie/Cargo.toml b/mpt_trie/Cargo.toml index 8a4c2a41b..3951ef625 100644 --- a/mpt_trie/Cargo.toml +++ b/mpt_trie/Cargo.toml @@ -20,7 +20,7 @@ ethereum-types = { workspace = true } hex = { workspace = true } keccak-hash = { workspace = true } parking_lot = { version = "0.12.1", features = ["serde"] } -thiserror = "1.0.40" +thiserror = { workspace = true } log = { workspace = true } num = { workspace = true, optional = true } num-traits = "0.2.15" @@ -31,10 +31,9 @@ serde = { workspace = true, features = ["derive", "rc"] } [dev-dependencies] eth_trie = "0.4.0" pretty_env_logger = "0.5.0" -rand = "0.8.5" -rlp-derive = "0.1.0" -serde = { version = "1.0.160", features = ["derive"] } -serde_json = "1.0.96" +rand = { workspace = true } +rlp-derive = { workspace = true } +serde_json = { workspace = true } [features] default = ["trie_debug"] diff --git a/mpt_trie/src/debug_tools/common.rs b/mpt_trie/src/debug_tools/common.rs index b63bba989..cb38d5c47 100644 --- a/mpt_trie/src/debug_tools/common.rs +++ b/mpt_trie/src/debug_tools/common.rs @@ -1,65 +1,9 @@ -use std::fmt::{self, Display}; - +//! Common utilities for the debugging tools. use crate::{ - nibbles::{Nibble, Nibbles}, + nibbles::Nibbles, partial_trie::{Node, PartialTrie}, - utils::TrieNodeType, }; -#[derive(Clone, Debug, Eq, Hash, PartialEq)] -pub(super) enum PathSegment { - Empty, - Hash, - Branch(Nibble), - Extension(Nibbles), - Leaf(Nibbles), -} - -impl Display for PathSegment { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - PathSegment::Empty => write!(f, "Empty"), - PathSegment::Hash => write!(f, "Hash"), - PathSegment::Branch(nib) => write!(f, "Branch({})", nib), - PathSegment::Extension(nibs) => write!(f, "Extension({})", nibs), - PathSegment::Leaf(nibs) => write!(f, "Leaf({})", nibs), - } - } -} - -impl PathSegment { - pub(super) fn node_type(&self) -> TrieNodeType { - match self { - PathSegment::Empty => TrieNodeType::Empty, - PathSegment::Hash => TrieNodeType::Hash, - PathSegment::Branch(_) => TrieNodeType::Branch, - PathSegment::Extension(_) => TrieNodeType::Extension, - PathSegment::Leaf(_) => TrieNodeType::Leaf, - } - } - - pub(super) fn get_key_piece_from_seg_if_present(&self) -> Option { - match self { - PathSegment::Empty | PathSegment::Hash => None, - PathSegment::Branch(nib) => Some(Nibbles::from_nibble(*nib)), - PathSegment::Extension(nibs) | PathSegment::Leaf(nibs) => Some(*nibs), - } - } -} - -pub(super) fn get_segment_from_node_and_key_piece( - n: &Node, - k_piece: &Nibbles, -) -> PathSegment { - match TrieNodeType::from(n) { - TrieNodeType::Empty => PathSegment::Empty, - TrieNodeType::Hash => PathSegment::Hash, - TrieNodeType::Branch => PathSegment::Branch(k_piece.get_nibble(0)), - TrieNodeType::Extension => PathSegment::Extension(*k_piece), - TrieNodeType::Leaf => PathSegment::Leaf(*k_piece), - } -} - /// Get the key piece from the given node if applicable. /// /// Note that there is no specific [`Nibble`] associated with a branch like @@ -86,42 +30,3 @@ pub(super) fn get_key_piece_from_node(n: &Node) -> Nibbles { Node::Extension { nibbles, child: _ } | Node::Leaf { nibbles, value: _ } => *nibbles, } } - -#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] -pub struct NodePath(pub(super) Vec); - -impl NodePath { - pub(super) fn dup_and_append(&self, seg: PathSegment) -> Self { - let mut duped_vec = self.0.clone(); - duped_vec.push(seg); - - Self(duped_vec) - } - - pub(super) fn append(&mut self, seg: PathSegment) { - self.0.push(seg); - } - - fn write_elem(f: &mut fmt::Formatter<'_>, seg: &PathSegment) -> fmt::Result { - write!(f, "{}", seg) - } -} - -impl Display for NodePath { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let num_elems = self.0.len(); - - // For everything but the last elem. - for seg in self.0.iter().take(num_elems.saturating_sub(1)) { - Self::write_elem(f, seg)?; - write!(f, " --> ")?; - } - - // Avoid the extra `-->` for the last elem. - if let Some(seg) = self.0.last() { - Self::write_elem(f, seg)?; - } - - Ok(()) - } -} diff --git a/mpt_trie/src/debug_tools/diff.rs b/mpt_trie/src/debug_tools/diff.rs index b043d72e7..d6271cfb1 100644 --- a/mpt_trie/src/debug_tools/diff.rs +++ b/mpt_trie/src/debug_tools/diff.rs @@ -30,7 +30,8 @@ use std::{fmt::Display, ops::Deref}; use ethereum_types::H256; -use super::common::{get_key_piece_from_node, get_segment_from_node_and_key_piece, NodePath}; +use super::common::get_key_piece_from_node; +use crate::utils::{get_segment_from_node_and_key_piece, TriePath}; use crate::{ nibbles::Nibbles, partial_trie::{HashedPartialTrie, Node, PartialTrie}, @@ -38,7 +39,10 @@ use crate::{ }; #[derive(Debug, Eq, PartialEq)] +/// The difference between two Tries, represented as the highest +/// point of a structural divergence. pub struct TrieDiff { + /// The highest point of structural divergence. pub latest_diff_res: Option, // TODO: Later add a second pass for finding diffs from the bottom up (`earliest_diff_res`). } @@ -77,10 +81,15 @@ impl DiffDetectionState { /// A point (node) between the two tries where the children differ. #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub struct DiffPoint { + /// The depth of the point in both tries. pub depth: usize, - pub path: NodePath, + /// The path of the point in both tries. + pub path: TriePath, + /// The node key in both tries. pub key: Nibbles, + /// The node info in the first trie. pub a_info: NodeInfo, + /// The node info in the second trie. pub b_info: NodeInfo, } @@ -89,7 +98,7 @@ impl DiffPoint { child_a: &HashedPartialTrie, child_b: &HashedPartialTrie, parent_k: Nibbles, - path: NodePath, + path: TriePath, ) -> Self { let a_key = parent_k.merge_nibbles(&get_key_piece_from_node(child_a)); let b_key = parent_k.merge_nibbles(&get_key_piece_from_node(child_b)); @@ -214,7 +223,7 @@ impl DepthNodeDiffState { parent_k: &Nibbles, child_a: &HashedPartialTrie, child_b: &HashedPartialTrie, - path: NodePath, + path: TriePath, ) { if field .as_ref() @@ -234,7 +243,7 @@ struct DepthDiffPerCallState<'a> { curr_depth: usize, // Horribly inefficient, but these are debug tools, so I think we get a pass. - curr_path: NodePath, + curr_path: TriePath, } impl<'a> DepthDiffPerCallState<'a> { @@ -251,7 +260,7 @@ impl<'a> DepthDiffPerCallState<'a> { b, curr_key, curr_depth, - curr_path: NodePath::default(), + curr_path: TriePath::default(), } } @@ -403,7 +412,7 @@ fn get_value_from_node(n: &Node) -> Option<&Vec> { #[cfg(test)] mod tests { - use super::{create_diff_between_tries, DiffPoint, NodeInfo, NodePath}; + use super::{create_diff_between_tries, DiffPoint, NodeInfo, TriePath}; use crate::{ nibbles::Nibbles, partial_trie::{HashedPartialTrie, PartialTrie}, @@ -439,7 +448,7 @@ mod tests { let expected = DiffPoint { depth: 0, - path: NodePath(vec![]), + path: TriePath(vec![]), key: Nibbles::default(), a_info: expected_a, b_info: expected_b, diff --git a/mpt_trie/src/debug_tools/query.rs b/mpt_trie/src/debug_tools/query.rs index e75c2ed76..dd7d3b654 100644 --- a/mpt_trie/src/debug_tools/query.rs +++ b/mpt_trie/src/debug_tools/query.rs @@ -5,13 +5,11 @@ use std::fmt::{self, Display}; use ethereum_types::H256; -use super::common::{ - get_key_piece_from_node_pulling_from_key_for_branches, get_segment_from_node_and_key_piece, - NodePath, PathSegment, -}; +use super::common::get_key_piece_from_node_pulling_from_key_for_branches; use crate::{ nibbles::Nibbles, partial_trie::{Node, PartialTrie, WrappedNode}, + utils::{get_segment_from_node_and_key_piece, TriePath, TrieSegment}, }; /// Params controlling how much information is reported in the query output. @@ -44,6 +42,7 @@ impl Default for DebugQueryParams { } #[derive(Debug, Default)] +/// A wrapper for `DebugQueryParams`. pub struct DebugQueryParamsBuilder { params: DebugQueryParams, } @@ -67,6 +66,7 @@ impl DebugQueryParamsBuilder { self } + /// Builds a new debug query for a given key. pub fn build>(self, k: K) -> DebugQuery { DebugQuery { k: k.into(), @@ -153,9 +153,13 @@ fn count_non_empty_branch_children_from_mask(mask: u16) -> usize { } #[derive(Clone, Debug)] +/// The result of a debug query contains information +/// of the path used for searching for a key in the trie. pub struct DebugQueryOutput { k: Nibbles, - node_path: NodePath, + + /// The nodes hit during the query. + pub node_path: TriePath, extra_node_info: Vec>, node_found: bool, params: DebugQueryParams, @@ -195,7 +199,7 @@ impl DebugQueryOutput { fn new(k: Nibbles, params: DebugQueryParams) -> Self { Self { k, - node_path: NodePath::default(), + node_path: TriePath::default(), extra_node_info: Vec::default(), node_found: false, params, @@ -215,7 +219,7 @@ impl DebugQueryOutput { // TODO: Make the output easier to read... fn fmt_node_based_on_debug_params( f: &mut fmt::Formatter<'_>, - seg: &PathSegment, + seg: &TrieSegment, extra_seg_info: &Option, params: &DebugQueryParams, ) -> fmt::Result { diff --git a/mpt_trie/src/debug_tools/stats.rs b/mpt_trie/src/debug_tools/stats.rs index 38429e6ed..6a23ab496 100644 --- a/mpt_trie/src/debug_tools/stats.rs +++ b/mpt_trie/src/debug_tools/stats.rs @@ -10,6 +10,8 @@ use num_traits::ToPrimitive; use crate::partial_trie::{Node, PartialTrie}; #[derive(Debug, Default)] +/// Statistics for a given trie, consisting of node count aggregated +/// by time, lowest depth and average depth of leaf and hash nodes. pub struct TrieStats { name: Option, counts: NodeCounts, @@ -31,6 +33,7 @@ impl Display for TrieStats { } impl TrieStats { + /// Compares with the statistics of another trie. pub fn compare(&self, other: &Self) -> TrieComparison { TrieComparison { node_comp: self.counts.compare(&other.counts), @@ -240,10 +243,13 @@ impl DepthStats { } } +/// Returns trie statistics consisting of node type counts as well as depth +/// statistics. pub fn get_trie_stats(trie: &T) -> TrieStats { get_trie_stats_common(trie, None) } +/// Returns trie statistics with a given name. pub fn get_trie_stats_with_name(trie: &T, name: String) -> TrieStats { get_trie_stats_common(trie, Some(name)) } diff --git a/mpt_trie/src/lib.rs b/mpt_trie/src/lib.rs index aa2df51e1..dc44cf80d 100644 --- a/mpt_trie/src/lib.rs +++ b/mpt_trie/src/lib.rs @@ -1,6 +1,7 @@ //! Utilities and types for working with Ethereum partial tries. //! -//! While there are other Ethereum trie libraries (such as [eth_trie](https://docs.rs/eth_trie/0.1.0/eth_trie), these libraries are not a good fit if: +//! While there are other Ethereum trie libraries (such as [eth_trie](https://docs.rs/eth_trie/0.1.0/eth_trie)), +//! these libraries are not a good fit if: //! - You only need a portion of an existing larger trie. //! - You need this partial trie to produce the same hash as the full trie. //! @@ -11,13 +12,17 @@ //! hash of the node it replaces. #![allow(incomplete_features)] +#![deny(rustdoc::broken_intra_doc_links)] +#![deny(missing_debug_implementations)] +#![deny(missing_docs)] pub mod nibbles; pub mod partial_trie; +pub mod special_query; mod trie_hashing; pub mod trie_ops; pub mod trie_subsets; -mod utils; +pub mod utils; #[cfg(feature = "trie_debug")] pub mod debug_tools; diff --git a/mpt_trie/src/nibbles.rs b/mpt_trie/src/nibbles.rs index c0b48e9bc..f5d63a043 100644 --- a/mpt_trie/src/nibbles.rs +++ b/mpt_trie/src/nibbles.rs @@ -1,3 +1,6 @@ +//! Define [`Nibbles`] and how to convert bytes, hex prefix encodings and +//! strings into nibbles. + use std::mem::size_of; use std::{ fmt::{self, Debug}, @@ -18,7 +21,9 @@ use uint::FromHexError; use crate::utils::{create_mask_of_1s, is_even}; // Use a whole byte for a Nibble just for convenience +/// A Nibble has 4 bits and is stored as `u8`. pub type Nibble = u8; +/// Used for the internal representation of a sequence of nibbles. pub type NibblesIntern = U512; /// Because there are two different ways to convert to `Nibbles`, we don't want @@ -46,25 +51,32 @@ pub trait ToNibbles { } #[derive(Debug, Error)] +/// Errors encountered when converting from `Bytes` to `Nibbles`. pub enum BytesToNibblesError { #[error("Tried constructing `Nibbles` from a zero byte slice")] + /// The size is zero. ZeroSizedKey, #[error("Tried constructing `Nibbles` from a byte slice with more than 33 bytes (len: {0})")] + /// The slice is too large. TooManyBytes(usize), } #[derive(Debug, Error)] +/// Errors encountered when converting to hex prefix encoding to nibbles. pub enum FromHexPrefixError { #[error("Tried to convert a hex prefix byte string into `Nibbles` with invalid flags at the start: {0:#04b}")] + /// The hex prefix encoding flag is invalid. InvalidFlags(Nibble), #[error("Tried to convert a hex prefix byte string into `Nibbles` that was longer than 64 bytes: (length: {0}, bytes: {1})")] + /// The hex prefix encoding is too large. TooLong(String, usize), } #[derive(Debug, Error)] #[error(transparent)] +/// An error encountered when converting a string to a sequence of nibbles. pub struct StrToNibblesError(#[from] FromHexError); /// The default conversion to nibbles will be to be precise down to the @@ -351,11 +363,11 @@ impl Nibbles { /// Appends `Nibbles` to the front. /// /// # Panics - /// Panics if appending the `Nibble` causes an overflow (total nibbles > + /// Panics if appending the `Nibbles` causes an overflow (total nibbles > /// 64). pub fn push_nibbles_front(&mut self, n: &Self) { let new_count = self.count + n.count; - assert!(new_count <= 64); + self.nibbles_append_safety_asserts(new_count); let shift_amt = 4 * self.count; @@ -363,6 +375,21 @@ impl Nibbles { self.packed |= n.packed << shift_amt; } + /// Appends `Nibbles` to the back. + /// + /// # Panics + /// Panics if appending the `Nibbles` causes an overflow (total nibbles > + /// 64). + pub fn push_nibbles_back(&mut self, n: &Self) { + let new_count = self.count + n.count; + self.nibbles_append_safety_asserts(new_count); + + let shift_amt = 4 * n.count; + + self.count = new_count; + self.packed = (self.packed << shift_amt) | n.packed; + } + /// Gets the nibbles at the range specified, where `0` is the next nibble. /// /// # Panics @@ -723,6 +750,7 @@ impl Nibbles { // TODO: Make not terrible at some point... Consider moving away from `U256` // internally? + /// Returns the nibbles bytes in big-endian format. pub fn bytes_be(&self) -> Vec { let mut byte_buf = [0; 64]; self.packed.to_big_endian(&mut byte_buf); @@ -752,7 +780,12 @@ impl Nibbles { assert!(n < 16); } + fn nibbles_append_safety_asserts(&self, new_count: usize) { + assert!(new_count <= 64); + } + // TODO: REMOVE BEFORE NEXT CRATE VERSION! THIS IS A TEMP HACK! + /// Converts to u256 returning an error if not possible. pub fn try_into_u256(&self) -> Result { match self.count <= 64 { false => Err(format!( @@ -773,6 +806,9 @@ mod tests { use super::{Nibble, Nibbles, ToNibbles}; use crate::nibbles::FromHexPrefixError; + const LONG_ZERO_NIBS_STR_LEN_63: &str = + "0x000000000000000000000000000000000000000000000000000000000000000"; + #[test] fn get_nibble_works() { let n = Nibbles::from(0x1234); @@ -884,6 +920,69 @@ mod tests { assert_eq!(res, expected_resulting_nibbles); } + #[test] + fn push_nibble_front_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1, |n| n.push_nibble_front(0x1)); + test_and_assert_nib_push_func(0x1, 0x21, |n| n.push_nibble_front(0x2)); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x1000000000000000000000000000000000000000000000000000000000000000") + .unwrap(), + |n| n.push_nibble_front(0x1), + ); + } + + #[test] + fn push_nibble_back_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1, |n| n.push_nibble_back(0x1)); + test_and_assert_nib_push_func(0x1, 0x12, |n| n.push_nibble_back(0x2)); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x0000000000000000000000000000000000000000000000000000000000000001") + .unwrap(), + |n| n.push_nibble_back(0x1), + ); + } + + #[test] + fn push_nibbles_front_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1234, |n| { + n.push_nibbles_front(&0x1234.into()) + }); + test_and_assert_nib_push_func(0x1234, 0x5671234, |n| n.push_nibbles_front(&0x567.into())); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x1000000000000000000000000000000000000000000000000000000000000000") + .unwrap(), + |n| n.push_nibbles_front(&0x1.into()), + ); + } + + #[test] + fn push_nibbles_back_works() { + test_and_assert_nib_push_func(Nibbles::default(), 0x1234, |n| { + n.push_nibbles_back(&0x1234.into()) + }); + test_and_assert_nib_push_func(0x1234, 0x1234567, |n| n.push_nibbles_back(&0x567.into())); + test_and_assert_nib_push_func( + Nibbles::from_str(LONG_ZERO_NIBS_STR_LEN_63).unwrap(), + Nibbles::from_str("0x0000000000000000000000000000000000000000000000000000000000000001") + .unwrap(), + |n| n.push_nibbles_back(&0x1.into()), + ); + } + + fn test_and_assert_nib_push_func, E: Into>( + starting_nibs: S, + expected: E, + f: F, + ) { + let mut nibs = starting_nibs.into(); + (f)(&mut nibs); + + assert_eq!(nibs, expected.into()); + } + #[test] fn get_next_nibbles_works() { let n: Nibbles = 0x1234.into(); @@ -1165,7 +1264,7 @@ mod tests { fn nibbles_from_h256_works() { assert_eq!( format!("{:x}", Nibbles::from_h256_be(H256::from_low_u64_be(0))), - "0x0000000000000000000000000000000000000000000000000000000000000000" + "0x0000000000000000000000000000000000000000000000000000000000000000", ); assert_eq!( format!("{:x}", Nibbles::from_h256_be(H256::from_low_u64_be(2048))), diff --git a/mpt_trie/src/partial_trie.rs b/mpt_trie/src/partial_trie.rs index dca6df6a1..c1a04c4bf 100644 --- a/mpt_trie/src/partial_trie.rs +++ b/mpt_trie/src/partial_trie.rs @@ -54,6 +54,7 @@ pub trait PartialTrie: + TrieNodeIntern + Sized { + /// Creates a new partial trie from a node. fn new(n: Node) -> Self; /// Inserts a node into the trie. @@ -111,6 +112,7 @@ pub trait PartialTrie: /// Part of the trait that is not really part of the public interface but /// implementor of other node types still need to implement. pub trait TrieNodeIntern { + /// Returns the hash of the rlp encoding of self. fn hash_intern(&self) -> EncodedNode; } @@ -134,17 +136,26 @@ where Hash(H256), /// A branch node, which consists of 16 children and an optional value. Branch { + /// A slice containing the 16 children of this branch node. children: [WrappedNode; 16], + /// The payload of this node. value: Vec, }, /// An extension node, which consists of a list of nibbles and a single /// child. Extension { + /// The path of this extension. nibbles: Nibbles, + /// The child of this extension node. child: WrappedNode, }, /// A leaf node, which consists of a list of nibbles and a value. - Leaf { nibbles: Nibbles, value: Vec }, + Leaf { + /// The path of this leaf node. + nibbles: Nibbles, + /// The payload of this node + value: Vec, + }, } impl Eq for Node {} diff --git a/mpt_trie/src/special_query.rs b/mpt_trie/src/special_query.rs new file mode 100644 index 000000000..503331aa0 --- /dev/null +++ b/mpt_trie/src/special_query.rs @@ -0,0 +1,156 @@ +//! Specialized queries that users of the library may need that require +//! knowledge of the private internal trie state. + +use crate::{ + nibbles::Nibbles, + partial_trie::{Node, PartialTrie, WrappedNode}, + utils::TrieSegment, +}; + +/// An iterator for a trie query. Note that this iterator is lazy. +#[derive(Debug)] +pub struct TriePathIter { + /// The next node in the trie to query with the remaining key. + curr_node: WrappedNode, + + /// The remaining part of the key as we traverse down the trie. + curr_key: Nibbles, + + // Although wrapping `curr_node` in an option might be more "Rust like", the logic is a lot + // cleaner with a bool. + terminated: bool, +} + +impl Iterator for TriePathIter { + type Item = TrieSegment; + + fn next(&mut self) -> Option { + if self.terminated { + return None; + } + + match self.curr_node.as_ref() { + Node::Empty => { + self.terminated = true; + Some(TrieSegment::Empty) + } + Node::Hash(_) => { + self.terminated = true; + Some(TrieSegment::Hash) + } + Node::Branch { children, .. } => { + // Our query key has ended. Stop here. + if self.curr_key.is_empty() { + self.terminated = true; + return None; + } + + let nib = self.curr_key.pop_next_nibble_front(); + self.curr_node = children[nib as usize].clone(); + + Some(TrieSegment::Branch(nib)) + } + Node::Extension { nibbles, child } => { + match self + .curr_key + .nibbles_are_identical_up_to_smallest_count(nibbles) + { + false => { + // Only a partial match. Stop. + self.terminated = true; + None + } + true => { + pop_nibbles_clamped(&mut self.curr_key, nibbles.count); + let res = Some(TrieSegment::Extension(*nibbles)); + self.curr_node = child.clone(); + + res + } + } + } + Node::Leaf { nibbles, .. } => { + self.terminated = true; + + match self.curr_key == *nibbles { + false => None, + true => Some(TrieSegment::Leaf(*nibbles)), + } + } + } + } +} + +/// Attempts to pop `n` nibbles from the given [`Nibbles`] and "clamp" the +/// nibbles popped by not popping more nibbles than there are. +fn pop_nibbles_clamped(nibbles: &mut Nibbles, n: usize) -> Nibbles { + let n_nibs_to_pop = nibbles.count.min(n); + nibbles.pop_nibbles_front(n_nibs_to_pop) +} + +/// Returns all nodes in the trie that are traversed given a query (key). +/// +/// Note that if the key does not match the entire key of a node (eg. the +/// remaining key is `0x34` but the next key is a leaf with the key `0x3456`), +/// then the leaf will not appear in the query output. +pub fn path_for_query(trie: &Node, k: K) -> TriePathIter +where + K: Into, +{ + TriePathIter { + curr_node: trie.clone().into(), + curr_key: k.into(), + terminated: false, + } +} + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use super::path_for_query; + use crate::{nibbles::Nibbles, testing_utils::handmade_trie_1, utils::TrieSegment}; + + #[test] + fn query_iter_works() { + let (trie, ks) = handmade_trie_1(); + + // ks --> vec![0x1234, 0x1324, 0x132400005_u64, 0x2001, 0x2002]; + let res = vec![ + vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(2), + TrieSegment::Leaf(0x34.into()), + ], + vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(3), + TrieSegment::Extension(0x24.into()), + ], + vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(3), + TrieSegment::Extension(0x24.into()), + TrieSegment::Branch(0), + TrieSegment::Leaf(Nibbles::from_str("0x0005").unwrap()), + ], + vec![ + TrieSegment::Branch(2), + TrieSegment::Extension(Nibbles::from_str("0x00").unwrap()), + TrieSegment::Branch(0x1), + TrieSegment::Leaf(Nibbles::default()), + ], + vec![ + TrieSegment::Branch(2), + TrieSegment::Extension(Nibbles::from_str("0x00").unwrap()), + TrieSegment::Branch(0x2), + TrieSegment::Leaf(Nibbles::default()), + ], + ]; + + for (q, expected) in ks.into_iter().zip(res.into_iter()) { + let res: Vec<_> = path_for_query(&trie.node, q).collect(); + assert_eq!(res, expected) + } + } +} diff --git a/mpt_trie/src/trie_ops.rs b/mpt_trie/src/trie_ops.rs index 7883e475c..3fe632d3a 100644 --- a/mpt_trie/src/trie_ops.rs +++ b/mpt_trie/src/trie_ops.rs @@ -152,6 +152,8 @@ struct BranchStackEntry { } #[derive(Debug)] +/// An iterator that ranges over all the leafs and hash nodes +/// of the trie, in lexicographic order. pub struct PartialTrieIter { curr_key_after_last_branch: Nibbles, trie_stack: Vec>, diff --git a/mpt_trie/src/trie_subsets.rs b/mpt_trie/src/trie_subsets.rs index aa5f805e9..974b92afc 100644 --- a/mpt_trie/src/trie_subsets.rs +++ b/mpt_trie/src/trie_subsets.rs @@ -8,6 +8,7 @@ use std::sync::Arc; use ethereum_types::H256; +use log::trace; use thiserror::Error; use crate::{ @@ -17,12 +18,14 @@ use crate::{ utils::{bytes_to_h256, TrieNodeType}, }; +/// The output type of trie_subset operations. pub type SubsetTrieResult = Result; /// Errors that may occur when creating a subset [`PartialTrie`]. #[derive(Debug, Error)] pub enum SubsetTrieError { #[error("Tried to mark nodes in a tracked trie for a key that does not exist! (Key: {0}, trie: {1})")] + /// The key does not exist in the trie. UnexpectedKey(Nibbles, String), } @@ -204,6 +207,7 @@ impl TrackedNodeInfo { } } +// TODO: Make this interface also work with &[ ... ]... /// Create a [`PartialTrie`] subset from a base trie given an iterator of keys /// of nodes that may or may not exist in the trie. All nodes traversed by the /// keys will not be hashed out in the trie subset. If the key does not exist in @@ -219,6 +223,7 @@ where create_trie_subset_intern(&mut tracked_trie, keys_involved.into_iter()) } +// TODO: Make this interface also work with &[ ... ]... /// Create [`PartialTrie`] subsets from a given base `PartialTrie` given a /// iterator of keys per subset needed. See [`create_trie_subset`] for more /// info. @@ -257,42 +262,68 @@ where Ok(create_partial_trie_subset_from_tracked_trie(tracked_trie)) } +/// For a given key, mark every node that we encounter that is part of the key. +/// Note that this means non-existent keys passed into this function will mark +/// nodes to not be hashed that are part of the given key. For example: +/// - Relevant nodes in trie: [B(0x), B(0x1), L(0x123)] +/// - For the key `0x1`, the marked nodes would be [B(0x), B(0x1)]. +/// - For the key `0x12`, the marked nodes still would be [B(0x), B(0x1)]. +/// - For the key `0x123`, the marked nodes would be [B(0x), B(0x1), L(0x123)]. fn mark_nodes_that_are_needed( trie: &mut TrackedNode, curr_nibbles: &mut Nibbles, ) -> SubsetTrieResult<()> { - trie.info.touched = true; + trace!( + "Sub-trie marking at {:x}, (type: {})", + curr_nibbles, + TrieNodeType::from(trie.info.underlying_node.deref()) + ); match &mut trie.node { - TrackedNodeIntern::Empty => Ok(()), + TrackedNodeIntern::Empty => { + trie.info.touched = true; + } TrackedNodeIntern::Hash => match curr_nibbles.is_empty() { - false => Err(SubsetTrieError::UnexpectedKey( - *curr_nibbles, - format!("{:?}", trie), - )), - true => Ok(()), + false => { + return Err(SubsetTrieError::UnexpectedKey( + *curr_nibbles, + format!("{:?}", trie), + )) + } + true => { + trie.info.touched = true; + } }, // Note: If we end up supporting non-fixed sized keys, then we need to also check value. TrackedNodeIntern::Branch(children) => { + trie.info.touched = true; + // Check against branch value. if curr_nibbles.is_empty() { return Ok(()); } let nib = curr_nibbles.pop_next_nibble_front(); - mark_nodes_that_are_needed(&mut children[nib as usize], curr_nibbles) + return mark_nodes_that_are_needed(&mut children[nib as usize], curr_nibbles); } TrackedNodeIntern::Extension(child) => { let nibbles = trie.info.get_nibbles_expected(); let r = curr_nibbles.pop_nibbles_front(nibbles.count); - match r.nibbles_are_identical_up_to_smallest_count(nibbles) { - false => Ok(()), - true => mark_nodes_that_are_needed(child, curr_nibbles), + if r.nibbles_are_identical_up_to_smallest_count(nibbles) { + trie.info.touched = true; + return mark_nodes_that_are_needed(child, curr_nibbles); + } + } + TrackedNodeIntern::Leaf => { + let (k, _) = trie.info.get_leaf_nibbles_and_value_expected(); + if k == curr_nibbles { + trie.info.touched = true; } } - TrackedNodeIntern::Leaf => Ok(()), } + + Ok(()) } fn create_partial_trie_subset_from_tracked_trie( @@ -331,10 +362,14 @@ fn create_partial_trie_subset_from_tracked_trie( fn reset_tracked_trie_state(tracked_node: &mut TrackedNode) { match tracked_node.node { - TrackedNodeIntern::Branch(ref mut children) => { - children.iter_mut().for_each(|c| c.info.reset()) + TrackedNodeIntern::Branch(ref mut children) => children.iter_mut().for_each(|c| { + c.info.reset(); + reset_tracked_trie_state(c); + }), + TrackedNodeIntern::Extension(ref mut child) => { + child.info.reset(); + reset_tracked_trie_state(child); } - TrackedNodeIntern::Extension(ref mut child) => child.info.reset(), TrackedNodeIntern::Empty | TrackedNodeIntern::Hash | TrackedNodeIntern::Leaf => { tracked_node.info.reset() } @@ -343,17 +378,20 @@ fn reset_tracked_trie_state(tracked_node: &mut TrackedNode) { #[cfg(test)] mod tests { - use std::{collections::HashSet, iter::once}; + use std::{ + collections::{HashMap, HashSet}, + iter::once, + }; use ethereum_types::H256; use super::{create_trie_subset, create_trie_subsets}; use crate::{ nibbles::Nibbles, - partial_trie::{HashedPartialTrie, Node, PartialTrie}, + partial_trie::{Node, PartialTrie}, testing_utils::{ - create_trie_with_large_entry_nodes, generate_n_random_fixed_trie_value_entries, - handmade_trie_1, TrieType, + common_setup, create_trie_with_large_entry_nodes, + generate_n_random_fixed_trie_value_entries, handmade_trie_1, TrieType, }, trie_ops::ValOrHash, utils::TrieNodeType, @@ -384,36 +422,56 @@ mod tests { } } + fn get_all_nodes_in_trie(trie: &TrieType) -> Vec { + get_nodes_in_trie_intern(trie, false) + } + fn get_all_non_empty_and_hash_nodes_in_trie(trie: &TrieType) -> Vec { + get_nodes_in_trie_intern(trie, true) + } + + fn get_nodes_in_trie_intern( + trie: &TrieType, + return_on_empty_or_hash: bool, + ) -> Vec { let mut nodes = Vec::new(); - get_all_non_empty_and_non_hash_nodes_in_trie_intern(trie, Nibbles::default(), &mut nodes); + get_nodes_in_trie_intern_rec( + trie, + Nibbles::default(), + &mut nodes, + return_on_empty_or_hash, + ); nodes } - fn get_all_non_empty_and_non_hash_nodes_in_trie_intern( + fn get_nodes_in_trie_intern_rec( trie: &TrieType, mut curr_nibbles: Nibbles, nodes: &mut Vec, + return_on_empty_or_hash: bool, ) { match &trie.node { - Node::Empty | Node::Hash(_) => return, + Node::Empty | Node::Hash(_) => match return_on_empty_or_hash { + false => (), + true => return, + }, Node::Branch { children, .. } => { for (i, c) in children.iter().enumerate() { - get_all_non_empty_and_non_hash_nodes_in_trie_intern( + get_nodes_in_trie_intern_rec( c, curr_nibbles.merge_nibble(i as u8), nodes, + return_on_empty_or_hash, ) } } - Node::Extension { nibbles, child } => { - get_all_non_empty_and_non_hash_nodes_in_trie_intern( - child, - curr_nibbles.merge_nibbles(nibbles), - nodes, - ) - } + Node::Extension { nibbles, child } => get_nodes_in_trie_intern_rec( + child, + curr_nibbles.merge_nibbles(nibbles), + nodes, + return_on_empty_or_hash, + ), Node::Leaf { nibbles, .. } => curr_nibbles = curr_nibbles.merge_nibbles(nibbles), }; @@ -428,6 +486,8 @@ mod tests { #[test] fn empty_trie_does_not_return_err_on_query() { + common_setup(); + let trie = TrieType::default(); let nibbles: Nibbles = 0x1234.into(); let res = create_trie_subset(&trie, once(nibbles)); @@ -437,6 +497,8 @@ mod tests { #[test] fn non_existent_key_does_not_return_err() { + common_setup(); + let mut trie = TrieType::default(); trie.insert(0x1234, vec![0, 1, 2]); let res = create_trie_subset(&trie, once(0x5678)); @@ -446,7 +508,9 @@ mod tests { #[test] fn encountering_a_hash_node_returns_err() { - let trie = HashedPartialTrie::new(Node::Hash(H256::zero())); + common_setup(); + + let trie = TrieType::new(Node::Hash(H256::zero())); let res = create_trie_subset(&trie, once(0x1234)); assert!(res.is_err()) @@ -454,6 +518,8 @@ mod tests { #[test] fn single_node_trie_is_queryable() { + common_setup(); + let mut trie = TrieType::default(); trie.insert(0x1234, vec![0, 1, 2]); let trie_subset = create_trie_subset(&trie, once(0x1234)).unwrap(); @@ -463,6 +529,8 @@ mod tests { #[test] fn multi_node_trie_returns_proper_subset() { + common_setup(); + let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x56, 0x12345_u64]); let trie_subset = create_trie_subset(&trie, vec![0x1234, 0x56]).unwrap(); @@ -475,6 +543,8 @@ mod tests { #[test] fn intermediate_nodes_are_included_in_subset() { + common_setup(); + let (trie, ks_nibbles) = handmade_trie_1(); let trie_subset_all = create_trie_subset(&trie, ks_nibbles.iter().cloned()).unwrap(); @@ -571,8 +641,55 @@ mod tests { ))); } + fn assert_nodes_are_leaf_nodes, I: IntoIterator>( + trie: &TrieType, + keys: I, + ) { + assert_keys_point_to_nodes_of_type( + trie, + keys.into_iter().map(|k| (k.into(), TrieNodeType::Leaf)), + ) + } + + fn assert_nodes_are_hash_nodes, I: IntoIterator>( + trie: &TrieType, + keys: I, + ) { + assert_keys_point_to_nodes_of_type( + trie, + keys.into_iter().map(|k| (k.into(), TrieNodeType::Hash)), + ) + } + + fn assert_keys_point_to_nodes_of_type( + trie: &TrieType, + keys: impl Iterator, + ) { + let nodes = get_all_nodes_in_trie(trie); + let keys_to_node_types: HashMap<_, _> = + HashMap::from_iter(nodes.into_iter().map(|n| (n.nibbles.reverse(), n.n_type))); + + for (k, expected_n_type) in keys { + let actual_n_type_opt = keys_to_node_types.get(&k); + + match actual_n_type_opt { + Some(actual_n_type) => { + if *actual_n_type != expected_n_type { + panic!("Expected trie node at {:x} to be a {} node but it wasn't! (found a {} node instead)", k, expected_n_type, actual_n_type) + } + } + None => panic!( + "Expected a {} node at {:x} but no node was found!", + expected_n_type, k + ), + } + } + } + #[test] fn all_leafs_of_keys_to_create_subset_are_included_in_subset_for_giant_trie() { + common_setup(); + let (_, trie_subsets, keys_of_subsets) = create_massive_trie_and_subsets(9009); for (sub_trie, ks_used) in trie_subsets.into_iter().zip(keys_of_subsets.into_iter()) { @@ -583,17 +700,38 @@ mod tests { #[test] fn hash_of_single_leaf_trie_partial_trie_matches_original_trie() { - let mut trie = TrieType::default(); - trie.insert(0x1234, vec![0]); + let trie = create_trie_with_large_entry_nodes(&[0x0]); let base_hash = trie.hash(); - let partial_trie = create_trie_subset(&trie, vec![0x1234]).unwrap(); + let partial_trie = create_trie_subset(&trie, [0x1234]).unwrap(); assert_eq!(base_hash, partial_trie.hash()); } + #[test] + fn sub_trie_that_includes_branch_but_not_children_hashes_out_children() { + common_setup(); + + let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x12345, 0x12346, 0x1234f]); + let partial_trie = create_trie_subset(&trie, [0x1234f]).unwrap(); + + assert_nodes_are_hash_nodes(&partial_trie, [0x12345, 0x12346]); + } + + #[test] + fn sub_trie_for_non_existent_key_that_hits_branch_leaf_hashes_out_leaf() { + common_setup(); + + let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x1234589, 0x12346]); + let partial_trie = create_trie_subset(&trie, [0x1234567]).unwrap(); + + // Note that `0x1234589` gets hashed at the branch slot at `0x12345`. + assert_nodes_are_hash_nodes(&partial_trie, [0x12345, 0x12346]); + } + #[test] fn hash_of_branch_partial_tries_matches_original_trie() { + common_setup(); let trie = create_trie_with_large_entry_nodes(&[0x1234, 0x56, 0x12345]); let base_hash: H256 = trie.hash(); @@ -611,6 +749,8 @@ mod tests { #[test] fn hash_of_giant_random_partial_tries_matches_original_trie() { + common_setup(); + let (base_trie, trie_subsets, _) = create_massive_trie_and_subsets(9010); let base_hash = base_trie.hash(); @@ -619,6 +759,37 @@ mod tests { .all(|p_tree| p_tree.hash() == base_hash)) } + #[test] + fn giant_random_partial_tries_hashes_leaves_correctly() { + common_setup(); + + let (base_trie, trie_subsets, leaf_keys_per_trie) = create_massive_trie_and_subsets(9011); + let all_keys: Vec = base_trie.keys().collect(); + + for (partial_trie, leaf_trie_keys) in + trie_subsets.into_iter().zip(leaf_keys_per_trie.into_iter()) + { + let leaf_keys_lookup: HashSet = leaf_trie_keys.iter().cloned().collect(); + let keys_of_hash_nodes = all_keys + .iter() + .filter(|k| !leaf_keys_lookup.contains(k)) + .cloned(); + + assert_nodes_are_leaf_nodes(&partial_trie, leaf_trie_keys); + + // We have no idea were the paths to the hashed out nodes will start in the + // trie, so the best we can do is to check that they don't exist (if we traverse + // over a `Hash` node, we return `None`.) + assert_all_keys_do_not_exist(&partial_trie, keys_of_hash_nodes); + } + } + + fn assert_all_keys_do_not_exist(trie: &TrieType, ks: impl Iterator) { + for k in ks { + assert!(trie.get(k).is_none()); + } + } + fn create_massive_trie_and_subsets(seed: u64) -> (TrieType, Vec, Vec>) { let trie_size = MASSIVE_TEST_NUM_SUB_TRIES * MASSIVE_TEST_NUM_SUB_TRIE_SIZE; diff --git a/mpt_trie/src/utils.rs b/mpt_trie/src/utils.rs index 9b87d8606..59400ee82 100644 --- a/mpt_trie/src/utils.rs +++ b/mpt_trie/src/utils.rs @@ -1,17 +1,36 @@ -use std::{fmt::Display, ops::BitAnd, sync::Arc}; +//! Various types and logic that don't fit well into any other module. + +use std::{ + borrow::Borrow, + fmt::{self, Display}, + ops::BitAnd, + sync::Arc, +}; use ethereum_types::{H256, U512}; use num_traits::PrimInt; -use crate::partial_trie::{Node, PartialTrie}; +use crate::{ + nibbles::{Nibble, Nibbles}, + partial_trie::{Node, PartialTrie}, +}; #[derive(Clone, Debug, Eq, Hash, PartialEq)] /// Simplified trie node type to make logging cleaner. -pub(crate) enum TrieNodeType { +pub enum TrieNodeType { + /// Empty node. Empty, + + /// Hash node. Hash, + + /// Branch node. Branch, + + /// Extension node. Extension, + + /// Leaf node. Leaf, } @@ -58,3 +77,190 @@ pub(crate) fn create_mask_of_1s(amt: usize) -> U512 { pub(crate) fn bytes_to_h256(b: &[u8; 32]) -> H256 { keccak_hash::H256::from_slice(b) } + +/// Minimal key information of "segments" (nodes) used to construct trie +/// "traces" of a trie query. Unlike [`TrieNodeType`], this type also contains +/// the key piece of the node if applicable (eg. [`Node::Empty`] & +/// [`Node::Hash`] do not have associated key pieces). +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub enum TrieSegment { + /// Empty node. + Empty, + + /// Hash node. + Hash, + + /// Branch node along with the nibble of the child taken. + Branch(Nibble), + + /// Extension node along with the key piece of the node. + Extension(Nibbles), + + /// Leaf node along wth the key piece of the node. + Leaf(Nibbles), +} + +/// Trait for a type that can be converted into a trie key ([`Nibbles`]). +pub trait IntoTrieKey { + /// Reconstruct the key of the type. + fn into_key(self) -> Nibbles; +} + +impl, T: Iterator> IntoTrieKey for T { + fn into_key(self) -> Nibbles { + let mut key = Nibbles::default(); + + for seg in self { + match seg.borrow() { + TrieSegment::Empty | TrieSegment::Hash => (), + TrieSegment::Branch(nib) => key.push_nibble_back(*nib), + TrieSegment::Extension(nibs) | TrieSegment::Leaf(nibs) => { + key.push_nibbles_back(nibs) + } + } + } + + key + } +} + +/// A vector of path segments representing a path in the trie. +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +pub struct TriePath(pub Vec); + +impl Display for TriePath { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let num_elems = self.0.len(); + + // For everything but the last elem. + for seg in self.0.iter().take(num_elems.saturating_sub(1)) { + Self::write_elem(f, seg)?; + write!(f, " --> ")?; + } + + // Avoid the extra `-->` for the last elem. + if let Some(seg) = self.0.last() { + Self::write_elem(f, seg)?; + } + + Ok(()) + } +} + +impl IntoIterator for TriePath { + type Item = TrieSegment; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl From> for TriePath { + fn from(v: Vec) -> Self { + Self(v) + } +} + +impl FromIterator for TriePath { + fn from_iter>(iter: T) -> Self { + Self(Vec::from_iter(iter)) + } +} + +impl TriePath { + /// Get an iterator of the individual path segments in the [`TriePath`]. + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub(crate) fn dup_and_append(&self, seg: TrieSegment) -> Self { + let mut duped_vec = self.0.clone(); + duped_vec.push(seg); + + Self(duped_vec) + } + + pub(crate) fn append(&mut self, seg: TrieSegment) { + self.0.push(seg); + } + + fn write_elem(f: &mut fmt::Formatter<'_>, seg: &TrieSegment) -> fmt::Result { + write!(f, "{}", seg) + } +} + +impl Display for TrieSegment { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TrieSegment::Empty => write!(f, "Empty"), + TrieSegment::Hash => write!(f, "Hash"), + TrieSegment::Branch(nib) => write!(f, "Branch({})", nib), + TrieSegment::Extension(nibs) => write!(f, "Extension({})", nibs), + TrieSegment::Leaf(nibs) => write!(f, "Leaf({})", nibs), + } + } +} + +impl TrieSegment { + /// Get the node type of the [`TrieSegment`]. + pub fn node_type(&self) -> TrieNodeType { + match self { + TrieSegment::Empty => TrieNodeType::Empty, + TrieSegment::Hash => TrieNodeType::Hash, + TrieSegment::Branch(_) => TrieNodeType::Branch, + TrieSegment::Extension(_) => TrieNodeType::Extension, + TrieSegment::Leaf(_) => TrieNodeType::Leaf, + } + } + + /// Extracts the key piece used by the segment (if applicable). + pub fn get_key_piece_from_seg_if_present(&self) -> Option { + match self { + TrieSegment::Empty | TrieSegment::Hash => None, + TrieSegment::Branch(nib) => Some(Nibbles::from_nibble(*nib)), + TrieSegment::Extension(nibs) | TrieSegment::Leaf(nibs) => Some(*nibs), + } + } +} + +/// Creates a [`TrieSegment`] given a node and a key we are querying. +/// +/// This function is intended to be used during a trie query as we are +/// traversing down a trie. Depending on the current node, we pop off nibbles +/// and use these to create `TrieSegment`s. +pub fn get_segment_from_node_and_key_piece( + n: &Node, + k_piece: &Nibbles, +) -> TrieSegment { + match TrieNodeType::from(n) { + TrieNodeType::Empty => TrieSegment::Empty, + TrieNodeType::Hash => TrieSegment::Hash, + TrieNodeType::Branch => TrieSegment::Branch(k_piece.get_nibble(0)), + TrieNodeType::Extension => TrieSegment::Extension(*k_piece), + TrieNodeType::Leaf => TrieSegment::Leaf(*k_piece), + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::{IntoTrieKey, TriePath, TrieSegment}; + use crate::nibbles::Nibbles; + + #[test] + fn path_from_query_works() { + let query_path: TriePath = vec![ + TrieSegment::Branch(1), + TrieSegment::Branch(2), + TrieSegment::Extension(0x34.into()), + TrieSegment::Leaf(0x567.into()), + ] + .into(); + + let reconstructed_key = query_path.iter().into_key(); + assert_eq!(reconstructed_key, Nibbles::from_str("0x1234567").unwrap()); + } +} diff --git a/proof_gen/Cargo.toml b/proof_gen/Cargo.toml index 019bcbd9f..4b980f45f 100644 --- a/proof_gen/Cargo.toml +++ b/proof_gen/Cargo.toml @@ -13,7 +13,7 @@ keywords.workspace = true ethereum-types = { workspace = true } log = { workspace = true } paste = "1.0.14" -plonky2 = { git = "https://github.com/0xPolygonZero/plonky2.git", rev = "710225c9e0ac5822b2965ce74951cf000bbb8a2c" } +plonky2 = { workspace = true } serde = { workspace = true } # Local dependencies diff --git a/proof_gen/README.md b/proof_gen/README.md index b16e55589..03bb6cfa1 100644 --- a/proof_gen/README.md +++ b/proof_gen/README.md @@ -1,9 +1,7 @@ -# Plonky block proof generator +# Proof generator Library for generating proofs from proof IR. -For the time being, the only library that produces proof IR is currently [plonky-edge-block-trace-parser](https://github.com/0xPolygonZero/plonky-edge-block-trace-parser). Down the road, the IR will be produced by decoding the proof gen protocol. - # General Usage (Extremely rough, will change) diff --git a/trace_decoder/Cargo.toml b/trace_decoder/Cargo.toml index c48892cc5..565aa666c 100644 --- a/trace_decoder/Cargo.toml +++ b/trace_decoder/Cargo.toml @@ -1,5 +1,6 @@ [package] name = "trace_decoder" +description = "Processes trace payloads into Intermediate Representation (IR) format." authors = ["Polygon Zero "] version = "0.1.0" edition.workspace = true @@ -23,7 +24,7 @@ rlp = { workspace = true } rlp-derive = { workspace = true } serde = { workspace = true } serde_with = "3.4.0" -thiserror = "1.0.49" +thiserror = { workspace = true } # Local dependencies mpt_trie = { version = "0.1.0", path = "../mpt_trie" } diff --git a/trace_decoder/README.md b/trace_decoder/README.md index e83296a55..948caa603 100644 --- a/trace_decoder/README.md +++ b/trace_decoder/README.md @@ -1,10 +1,10 @@ -# Proof protocol decoder +# Trace decoder A flexible protocol that clients (eg. full nodes) can use to easily generate block proofs for different chains. ## Specification -Temporary [high-level overview and comparison](docs/usage_seq_diagrams.md) to what the old Edge setup used to look like. The specification itself is in the repo [here](trace_decoder/src/trace_protocol.rs). +Temporary [high-level overview](docs/usage_seq_diagrams.md). The specification itself is in the repo [here](trace_decoder/src/trace_protocol.rs). Because processing the incoming proof protocol payload is not a resource bottleneck, the design is not worrying too much about performance. Instead, the core focus is flexibility in clients creating their own implementation, where the protocol supports multiple ways to provide different pieces of data. For example, there are multiple different formats available to provide the trie pre-images in, and the implementor can choose whichever is closest to its own internal data structures. diff --git a/trace_decoder/src/decoding.rs b/trace_decoder/src/decoding.rs index 4c3f0172a..482918b56 100644 --- a/trace_decoder/src/decoding.rs +++ b/trace_decoder/src/decoding.rs @@ -20,8 +20,7 @@ use crate::{ processed_block_trace::{NodesUsedByTxn, ProcessedBlockTrace, StateTrieWrites, TxnMetaState}, types::{ HashedAccountAddr, HashedNodeAddr, HashedStorageAddrNibbles, OtherBlockData, TrieRootHash, - TxnIdx, TxnProofGenIR, EMPTY_ACCOUNT_BYTES_RLPED, EMPTY_TRIE_HASH, - ZERO_STORAGE_SLOT_VAL_RLPED, + TxnIdx, TxnProofGenIR, EMPTY_ACCOUNT_BYTES_RLPED, ZERO_STORAGE_SLOT_VAL_RLPED, }, utils::{hash, update_val_if_some}, }; @@ -361,20 +360,35 @@ impl ProcessedBlockTrace { withdrawals: Vec<(Address, U256)>, dummies_already_added: bool, ) -> TraceParsingResult<()> { + let withdrawals_with_hashed_addrs_iter = withdrawals + .iter() + .map(|(addr, v)| (*addr, hash(addr.as_bytes()), *v)); + match dummies_already_added { // If we have no actual dummy proofs, then we create one and append it to the // end of the block. false => { - // Guaranteed to have a real txn. - let txn_idx_of_dummy_entry = - txn_ir.last().unwrap().txn_number_before.low_u64() as usize + 1; - - // Dummy state will be the state after the final txn. - let mut withdrawal_dummy = - create_dummy_gen_input(other_data, extra_data, final_trie_state); + // TODO: Decide if we want this allocation... + // To avoid double hashing the addrs, but I don't know if the extra `Vec` + // allocation is worth it. + let withdrawals_with_hashed_addrs: Vec<_> = + withdrawals_with_hashed_addrs_iter.collect(); + + // Dummy state will be the state after the final txn. Also need to include the + // account nodes that were accessed by the withdrawals. + let withdrawal_addrs = withdrawals_with_hashed_addrs + .iter() + .cloned() + .map(|(_, h_addr, _)| h_addr); + let mut withdrawal_dummy = create_dummy_gen_input_with_state_addrs_accessed( + other_data, + extra_data, + final_trie_state, + withdrawal_addrs, + )?; Self::update_trie_state_from_withdrawals( - &withdrawals, + withdrawals_with_hashed_addrs, &mut final_trie_state.state, )?; @@ -387,7 +401,7 @@ impl ProcessedBlockTrace { } true => { Self::update_trie_state_from_withdrawals( - &withdrawals, + withdrawals_with_hashed_addrs_iter, &mut final_trie_state.state, )?; @@ -404,22 +418,21 @@ impl ProcessedBlockTrace { /// Withdrawals update balances in the account trie, so we need to update /// our local trie state. fn update_trie_state_from_withdrawals<'a>( - withdrawals: impl IntoIterator + 'a, + withdrawals: impl IntoIterator + 'a, state: &mut HashedPartialTrie, ) -> TraceParsingResult<()> { - for (addr, amt) in withdrawals { - let h_addr = hash(addr.as_bytes()); + for (addr, h_addr, amt) in withdrawals { let h_addr_nibs = Nibbles::from_h256_be(h_addr); let acc_bytes = state .get(h_addr_nibs) .ok_or(TraceParsingError::MissingWithdrawalAccount( - *addr, h_addr, *amt, + addr, h_addr, amt, ))?; let mut acc_data = account_from_rlped_bytes(acc_bytes)?; - acc_data.balance += *amt; + acc_data.balance += amt; state.insert(h_addr_nibs, rlp::encode(&acc_data).to_vec()); } @@ -486,12 +499,37 @@ fn create_dummy_gen_input( extra_data: &ExtraBlockData, final_tries: &PartialTrieState, ) -> TxnProofGenIR { - let tries = create_dummy_proof_trie_inputs(final_tries); + let sub_tries = create_dummy_proof_trie_inputs( + final_tries, + create_fully_hashed_out_sub_partial_trie(&final_tries.state), + ); + create_dummy_gen_input_common(other_data, extra_data, sub_tries) +} + +fn create_dummy_gen_input_with_state_addrs_accessed( + other_data: &OtherBlockData, + extra_data: &ExtraBlockData, + final_tries: &PartialTrieState, + account_addrs_accessed: impl Iterator, +) -> TraceParsingResult { + let sub_tries = create_dummy_proof_trie_inputs( + final_tries, + create_minimal_state_partial_trie(&final_tries.state, account_addrs_accessed)?, + ); + Ok(create_dummy_gen_input_common( + other_data, extra_data, sub_tries, + )) +} +fn create_dummy_gen_input_common( + other_data: &OtherBlockData, + extra_data: &ExtraBlockData, + sub_tries: TrieInputs, +) -> GenerationInputs { let trie_roots_after = TrieRoots { - state_root: tries.state_trie.hash(), - transactions_root: tries.transactions_trie.hash(), - receipts_root: tries.receipts_trie.hash(), + state_root: sub_tries.state_trie.hash(), + transactions_root: sub_tries.transactions_trie.hash(), + receipts_root: sub_tries.receipts_trie.hash(), }; // Sanity checks @@ -506,7 +544,7 @@ fn create_dummy_gen_input( GenerationInputs { signed_txn: None, - tries, + tries: sub_tries, trie_roots_after, checkpoint_state_trie_root: extra_data.checkpoint_state_trie_root, block_metadata: other_data.b_data.b_meta.clone(), @@ -519,17 +557,11 @@ fn create_dummy_gen_input( } } -impl TxnMetaState { - fn txn_bytes(&self) -> Vec { - match self.txn_bytes.as_ref() { - Some(v) => v.clone(), - None => Vec::default(), - } - } -} - -fn create_dummy_proof_trie_inputs(final_trie_state: &PartialTrieState) -> TrieInputs { - let partial_sub_storage_tries: Vec<_> = final_trie_state +fn create_dummy_proof_trie_inputs( + final_tries_at_end_of_block: &PartialTrieState, + state_trie: HashedPartialTrie, +) -> TrieInputs { + let partial_sub_storage_tries: Vec<_> = final_tries_at_end_of_block .storage .iter() .map(|(hashed_acc_addr, s_trie)| { @@ -541,9 +573,13 @@ fn create_dummy_proof_trie_inputs(final_trie_state: &PartialTrieState) -> TrieIn .collect(); TrieInputs { - state_trie: create_fully_hashed_out_sub_partial_trie(&final_trie_state.state), - transactions_trie: create_fully_hashed_out_sub_partial_trie(&final_trie_state.txn), - receipts_trie: create_fully_hashed_out_sub_partial_trie(&final_trie_state.receipt), + state_trie, + transactions_trie: create_fully_hashed_out_sub_partial_trie( + &final_tries_at_end_of_block.txn, + ), + receipts_trie: create_fully_hashed_out_sub_partial_trie( + &final_tries_at_end_of_block.receipt, + ), storage_tries: partial_sub_storage_tries, } } @@ -605,3 +641,12 @@ fn account_from_rlped_bytes(bytes: &[u8]) -> TraceParsingResult { rlp::decode(bytes) .map_err(|err| TraceParsingError::AccountDecode(hex::encode(bytes), err.to_string())) } + +impl TxnMetaState { + fn txn_bytes(&self) -> Vec { + match self.txn_bytes.as_ref() { + Some(v) => v.clone(), + None => Vec::default(), + } + } +} diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs index 314b433d6..a1858d93e 100644 --- a/trace_decoder/src/lib.rs +++ b/trace_decoder/src/lib.rs @@ -5,7 +5,7 @@ #![allow(unused)] #![allow(private_interfaces)] -mod compact; +pub mod compact; pub mod decoding; mod deserializers; pub mod processed_block_trace; diff --git a/trace_decoder/src/processed_block_trace.rs b/trace_decoder/src/processed_block_trace.rs index fa200df64..a37d17327 100644 --- a/trace_decoder/src/processed_block_trace.rs +++ b/trace_decoder/src/processed_block_trace.rs @@ -2,7 +2,7 @@ use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::iter::once; -use ethereum_types::{Address, U256}; +use ethereum_types::{Address, H256, U256}; use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp}; use mpt_trie::nibbles::Nibbles; use mpt_trie::partial_trie::{HashedPartialTrie, PartialTrie}; @@ -67,19 +67,6 @@ impl BlockTrace { print_value_and_hash_nodes_of_storage_trie(h_addr, s_trie); } - let resolve_code_hash_fn = |c_hash: &_| { - let resolve_code_hash_fn_ref = &p_meta.resolve_code_hash_fn; - let extra_code_hash_mappings_ref = &pre_image_data.extra_code_hash_mappings; - - match extra_code_hash_mappings_ref { - Some(m) => m - .get(c_hash) - .cloned() - .unwrap_or_else(|| (resolve_code_hash_fn_ref)(c_hash)), - None => (resolve_code_hash_fn_ref)(c_hash), - } - }; - let all_accounts_in_pre_image: Vec<_> = pre_image_data .tries .state @@ -94,10 +81,15 @@ impl BlockTrace { }) .collect(); + let mut code_hash_resolver = CodeHashResolving { + client_code_hash_resolve_f: &p_meta.resolve_code_hash_fn, + extra_code_hash_mappings: pre_image_data.extra_code_hash_mappings.unwrap_or_default(), + }; + let txn_info = self .txn_info .into_iter() - .map(|t| t.into_processed_txn_info(&all_accounts_in_pre_image, &resolve_code_hash_fn)) + .map(|t| t.into_processed_txn_info(&all_accounts_in_pre_image, &mut code_hash_resolver)) .collect::>(); ProcessedBlockTrace { @@ -206,11 +198,36 @@ pub(crate) struct ProcessedTxnInfo { pub(crate) meta: TxnMetaState, } +struct CodeHashResolving { + /// If we have not seen this code hash before, use the resolve function that + /// the client passes down to us. This will likely be an rpc call/cache + /// check. + client_code_hash_resolve_f: F, + + /// Code hash mappings that we have constructed from parsing the block + /// trace. If there are any txns that create contracts, then they will also + /// get added here as we process the deltas. + extra_code_hash_mappings: HashMap>, +} + +impl CodeHashResolving { + fn resolve(&mut self, c_hash: &CodeHash) -> Vec { + match self.extra_code_hash_mappings.get(c_hash) { + Some(code) => code.clone(), + None => (self.client_code_hash_resolve_f)(c_hash), + } + } + + fn insert_code(&mut self, c_hash: H256, code: Vec) { + self.extra_code_hash_mappings.insert(c_hash, code); + } +} + impl TxnInfo { fn into_processed_txn_info( self, all_accounts_in_pre_image: &[(HashedAccountAddr, AccountRlp)], - code_hash_resolve_f: &F, + code_hash_resolver: &mut CodeHashResolving, ) -> ProcessedTxnInfo { let mut nodes_used_by_txn = NodesUsedByTxn::default(); let mut contract_code_accessed = create_empty_code_access_map(); @@ -271,11 +288,13 @@ impl TxnInfo { ContractCodeUsage::Read(c_hash) => { contract_code_accessed .entry(c_hash) - .or_insert_with(|| code_hash_resolve_f(&c_hash)); + .or_insert_with(|| code_hash_resolver.resolve(&c_hash)); } ContractCodeUsage::Write(c_bytes) => { let c_hash = hash(&c_bytes); - contract_code_accessed.insert(c_hash, c_bytes.0); + + contract_code_accessed.insert(c_hash, c_bytes.0.clone()); + code_hash_resolver.insert_code(c_hash, c_bytes.0); } } }