diff --git a/Cargo.lock b/Cargo.lock index f288caab7..4ef8e1ef3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4664,6 +4664,31 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smt_trie" +version = "0.1.0" +dependencies = [ + "bytes", + "enum-as-inner", + "eth_trie", + "ethereum-types", + "hex", + "hex-literal", + "keccak-hash 0.10.0", + "log", + "num-traits", + "parking_lot", + "plonky2", + "pretty_env_logger", + "rand", + "rlp", + "rlp-derive", + "serde", + "serde_json", + "thiserror", + "uint", +] + [[package]] name = "socket2" version = "0.4.10" diff --git a/Cargo.toml b/Cargo.toml index 75f871ce8..5f254ba8a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = ["mpt_trie", + "smt_trie", "proof_gen", "trace_decoder", "evm_arithmetization", @@ -83,6 +84,7 @@ serde = "1.0.166" serde_json = "1.0.96" serde_path_to_error = "0.1.14" serde_with = "3.4.0" +smt_trie = { path = "smt_trie", version = "0.1.0" } sha2 = "0.10.6" static_assertions = "1.1.0" thiserror = "1.0.49" diff --git a/smt_trie/Cargo.toml b/smt_trie/Cargo.toml new file mode 100644 index 000000000..fb1335350 --- /dev/null +++ b/smt_trie/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "smt_trie" +description = "Types and utility functions for building/working with Polygon Hermez Sparse Merkle Trees." +version = "0.1.0" +authors = ["William Borgeaud "] +readme = "README.md" +categories = ["cryptography"] +edition.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = { workspace = true } +enum-as-inner = { workspace = true } +ethereum-types = { workspace = true } +hex = { workspace = true } +hex-literal = { workspace = true } +keccak-hash = { workspace = true } +log = { workspace = true } +num-traits = { workspace = true } +parking_lot = { workspace = true, features = ["serde"] } +plonky2 = { workspace = true } +rand = { workspace = true } +rlp = { workspace = true } +serde = { workspace = true, features = ["derive", "rc"] } +thiserror = { workspace = true } +uint = { workspace = true } + + +[dev-dependencies] +eth_trie = "0.4.0" +pretty_env_logger = "0.5.0" +rlp-derive = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } diff --git a/smt_trie/README.md b/smt_trie/README.md new file mode 100644 index 000000000..54707f9fe --- /dev/null +++ b/smt_trie/README.md @@ -0,0 +1,2 @@ +Types and functions to work with the Hermez/Polygon zkEVM sparse Merkle tree (SMT) format. +See https://github.com/0xPolygonHermez/zkevm-commonjs for reference implementation. diff --git a/smt_trie/src/bits.rs b/smt_trie/src/bits.rs new file mode 100644 index 000000000..4d2d2ed91 --- /dev/null +++ b/smt_trie/src/bits.rs @@ -0,0 +1,103 @@ +use std::ops::Add; + +use ethereum_types::{BigEndianHash, H256, U256}; +use serde::{Deserialize, Serialize}; + +pub type Bit = bool; + +#[derive( + Copy, Clone, Deserialize, Default, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize, Debug, +)] +pub struct Bits { + /// The number of bits in this sequence. + pub count: usize, + /// A packed encoding of these bits. Only the first (least significant) + /// `count` bits are used. The rest are unused and should be zero. + pub packed: U256, +} + +impl From for Bits { + fn from(packed: U256) -> Self { + Bits { count: 256, packed } + } +} + +impl From for Bits { + fn from(packed: H256) -> Self { + Bits { + count: 256, + packed: packed.into_uint(), + } + } +} + +impl Add for Bits { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + assert!(self.count + rhs.count <= 256, "Overflow"); + Self { + count: self.count + rhs.count, + packed: self.packed * (U256::one() << rhs.count) + rhs.packed, + } + } +} + +impl Bits { + pub fn empty() -> Self { + Bits { + count: 0, + packed: U256::zero(), + } + } + + pub fn is_empty(&self) -> bool { + self.count == 0 + } + + pub fn pop_next_bit(&mut self) -> Bit { + assert!(!self.is_empty(), "Cannot pop from empty bits"); + let b = !(self.packed & U256::one()).is_zero(); + self.packed >>= 1; + self.count -= 1; + b + } + + pub fn get_bit(&self, i: usize) -> Bit { + assert!(i < self.count, "Index out of bounds"); + !(self.packed & (U256::one() << (self.count - 1 - i))).is_zero() + } + + pub fn push_bit(&mut self, bit: Bit) { + self.packed = self.packed * 2 + U256::from(bit as u64); + self.count += 1; + } + + pub fn add_bit(&self, bit: Bit) -> Self { + let mut x = *self; + x.push_bit(bit); + x + } + + pub fn common_prefix(&self, k: &Bits) -> (Self, Option<(Bit, Bit)>) { + let mut a = *self; + let mut b = *k; + while a.count > b.count { + a.pop_next_bit(); + } + while a.count < b.count { + b.pop_next_bit(); + } + if a == b { + return (a, None); + } + let mut a_bit = a.pop_next_bit(); + let mut b_bit = b.pop_next_bit(); + while a != b { + a_bit = a.pop_next_bit(); + b_bit = b.pop_next_bit(); + } + assert_ne!(a_bit, b_bit, "Sanity check."); + (a, Some((a_bit, b_bit))) + } +} diff --git a/smt_trie/src/code.rs b/smt_trie/src/code.rs new file mode 100644 index 000000000..dd6b142b9 --- /dev/null +++ b/smt_trie/src/code.rs @@ -0,0 +1,85 @@ +/// Functions to hash contract bytecode using Poseidon. +/// See `hashContractBytecode()` in https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation. +use ethereum_types::U256; +use plonky2::field::types::Field; +use plonky2::hash::poseidon::{self, Poseidon}; + +use crate::smt::{HashOut, F}; +use crate::utils::hashout2u; + +pub fn hash_contract_bytecode(mut code: Vec) -> HashOut { + poseidon_pad_byte_vec(&mut code); + + poseidon_hash_padded_byte_vec(code) +} + +pub fn poseidon_hash_padded_byte_vec(bytes: Vec) -> HashOut { + let mut capacity = [F::ZERO; poseidon::SPONGE_CAPACITY]; + let mut arr = [F::ZERO; poseidon::SPONGE_WIDTH]; + for blocks in bytes.chunks_exact(poseidon::SPONGE_RATE * 7) { + arr[..poseidon::SPONGE_RATE].copy_from_slice( + &blocks + .chunks_exact(7) + .map(|block| { + let mut bytes = [0u8; poseidon::SPONGE_RATE]; + bytes[..7].copy_from_slice(block); + F::from_canonical_u64(u64::from_le_bytes(bytes)) + }) + .collect::>(), + ); + arr[poseidon::SPONGE_RATE..poseidon::SPONGE_WIDTH].copy_from_slice(&capacity); + capacity = F::poseidon(arr)[0..poseidon::SPONGE_CAPACITY] + .try_into() + .unwrap(); + } + HashOut { elements: capacity } +} + +pub fn poseidon_pad_byte_vec(bytes: &mut Vec) { + bytes.push(0x01); + while bytes.len() % 56 != 0 { + bytes.push(0x00); + } + *bytes.last_mut().unwrap() |= 0x80; +} + +pub fn hash_bytecode_u256(code: Vec) -> U256 { + hashout2u(hash_contract_bytecode(code)) +} + +#[cfg(test)] +mod tests { + use hex_literal::hex; + + use super::*; + + #[test] + fn test_empty_code() { + assert_eq!( + hash_contract_bytecode(vec![]).elements, + [ + 10052403398432742521, + 15195891732843337299, + 2019258788108304834, + 4300613462594703212, + ] + .map(F::from_canonical_u64) + ); + } + + #[test] + fn test_some_code() { + let code = hex!("60806040526004361061003f5760003560e01c80632b68b9c6146100445780633fa4f2451461005b5780635cfb28e714610086578063718da7ee14610090575b600080fd5b34801561005057600080fd5b506100596100b9565b005b34801561006757600080fd5b506100706100f2565b60405161007d9190610195565b60405180910390f35b61008e6100f8565b005b34801561009c57600080fd5b506100b760048036038101906100b29190610159565b610101565b005b60008054906101000a900473ffffffffffffffffffffffffffffffffffffffff1673ffffffffffffffffffffffffffffffffffffffff16ff5b60015481565b34600181905550565b806000806101000a81548173ffffffffffffffffffffffffffffffffffffffff021916908373ffffffffffffffffffffffffffffffffffffffff16021790555050565b600081359050610153816101f1565b92915050565b60006020828403121561016f5761016e6101ec565b5b600061017d84828501610144565b91505092915050565b61018f816101e2565b82525050565b60006020820190506101aa6000830184610186565b92915050565b60006101bb826101c2565b9050919050565b600073ffffffffffffffffffffffffffffffffffffffff82169050919050565b6000819050919050565b600080fd5b6101fa816101b0565b811461020557600080fd5b5056fea26469706673582212207ae6e5d5feddef608b24cca98990c37cf78f8b377163a7c4951a429d90d6120464736f6c63430008070033"); + + assert_eq!( + hash_contract_bytecode(code.to_vec()).elements, + [ + 13311281292453978464, + 8384462470517067887, + 14733964407220681187, + 13541155386998871195 + ] + .map(F::from_canonical_u64) + ); + } +} diff --git a/smt_trie/src/db.rs b/smt_trie/src/db.rs new file mode 100644 index 000000000..f71fad29a --- /dev/null +++ b/smt_trie/src/db.rs @@ -0,0 +1,23 @@ +use std::collections::HashMap; + +use crate::smt::{Key, Node}; + +pub trait Db: Default { + fn get_node(&self, key: &Key) -> Option<&Node>; + fn set_node(&mut self, key: Key, value: Node); +} + +#[derive(Debug, Clone, Default)] +pub struct MemoryDb { + pub db: HashMap, +} + +impl Db for MemoryDb { + fn get_node(&self, key: &Key) -> Option<&Node> { + self.db.get(key) + } + + fn set_node(&mut self, key: Key, value: Node) { + self.db.insert(key, value); + } +} diff --git a/smt_trie/src/keys.rs b/smt_trie/src/keys.rs new file mode 100644 index 000000000..1f122adbb --- /dev/null +++ b/smt_trie/src/keys.rs @@ -0,0 +1,99 @@ +#![allow(clippy::needless_range_loop)] + +/// This module contains functions to generate keys for the SMT. +/// See https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt-utils.js for reference implementation. +use ethereum_types::{Address, U256}; +use plonky2::{field::types::Field, hash::poseidon::Poseidon}; + +use crate::smt::{Key, F}; + +const HASH_ZEROS: [u64; 4] = [ + 4330397376401421145, + 14124799381142128323, + 8742572140681234676, + 14345658006221440202, +]; + +const SMT_KEY_BALANCE: u64 = 0; +const SMT_KEY_NONCE: u64 = 1; +const SMT_KEY_CODE: u64 = 2; +const SMT_KEY_STORAGE: u64 = 3; +const SMT_KEY_LENGTH: u64 = 4; + +pub fn key_balance(addr: Address) -> Key { + let mut arr = [F::ZERO; 12]; + for i in 0..5 { + arr[i] = F::from_canonical_u32(u32::from_be_bytes( + addr.0[16 - 4 * i..16 - 4 * i + 4].try_into().unwrap(), + )); + } + + arr[6] = F::from_canonical_u64(SMT_KEY_BALANCE); + arr[8..12].copy_from_slice(&HASH_ZEROS.map(F::from_canonical_u64)); + + Key(F::poseidon(arr)[0..4].try_into().unwrap()) +} + +pub fn key_nonce(addr: Address) -> Key { + let mut arr = [F::ZERO; 12]; + for i in 0..5 { + arr[i] = F::from_canonical_u32(u32::from_be_bytes( + addr.0[16 - 4 * i..16 - 4 * i + 4].try_into().unwrap(), + )); + } + + arr[6] = F::from_canonical_u64(SMT_KEY_NONCE); + arr[8..12].copy_from_slice(&HASH_ZEROS.map(F::from_canonical_u64)); + + Key(F::poseidon(arr)[0..4].try_into().unwrap()) +} + +pub fn key_code(addr: Address) -> Key { + let mut arr = [F::ZERO; 12]; + for i in 0..5 { + arr[i] = F::from_canonical_u32(u32::from_be_bytes( + addr.0[16 - 4 * i..16 - 4 * i + 4].try_into().unwrap(), + )); + } + + arr[6] = F::from_canonical_u64(SMT_KEY_CODE); + arr[8..12].copy_from_slice(&HASH_ZEROS.map(F::from_canonical_u64)); + + Key(F::poseidon(arr)[0..4].try_into().unwrap()) +} + +pub fn key_storage(addr: Address, slot: U256) -> Key { + let mut arr = [F::ZERO; 12]; + for i in 0..5 { + arr[i] = F::from_canonical_u32(u32::from_be_bytes( + addr.0[16 - 4 * i..16 - 4 * i + 4].try_into().unwrap(), + )); + } + + arr[6] = F::from_canonical_u64(SMT_KEY_STORAGE); + let capacity: [F; 4] = { + let mut arr = [F::ZERO; 12]; + for i in 0..4 { + arr[2 * i] = F::from_canonical_u32(slot.0[i] as u32); + arr[2 * i + 1] = F::from_canonical_u32((slot.0[i] >> 32) as u32); + } + F::poseidon(arr)[0..4].try_into().unwrap() + }; + arr[8..12].copy_from_slice(&capacity); + + Key(F::poseidon(arr)[0..4].try_into().unwrap()) +} + +pub fn key_code_length(addr: Address) -> Key { + let mut arr = [F::ZERO; 12]; + for i in 0..5 { + arr[i] = F::from_canonical_u32(u32::from_be_bytes( + addr.0[16 - 4 * i..16 - 4 * i + 4].try_into().unwrap(), + )); + } + + arr[6] = F::from_canonical_u64(SMT_KEY_LENGTH); + arr[8..12].copy_from_slice(&HASH_ZEROS.map(F::from_canonical_u64)); + + Key(F::poseidon(arr)[0..4].try_into().unwrap()) +} diff --git a/smt_trie/src/lib.rs b/smt_trie/src/lib.rs new file mode 100644 index 000000000..11315f12c --- /dev/null +++ b/smt_trie/src/lib.rs @@ -0,0 +1,8 @@ +pub mod bits; +pub mod code; +pub mod db; +pub mod keys; +pub mod smt; +#[cfg(test)] +mod smt_test; +pub mod utils; diff --git a/smt_trie/src/smt.rs b/smt_trie/src/smt.rs new file mode 100644 index 000000000..356c8c3e9 --- /dev/null +++ b/smt_trie/src/smt.rs @@ -0,0 +1,535 @@ +#![allow(clippy::needless_range_loop)] + +use std::borrow::Borrow; +use std::collections::{HashMap, HashSet}; + +use ethereum_types::U256; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::field::types::{Field, PrimeField64}; +use plonky2::hash::poseidon::{Poseidon, PoseidonHash}; +use plonky2::plonk::config::Hasher; + +use crate::bits::Bits; +use crate::db::Db; +use crate::utils::{ + f2limbs, get_unique_sibling, hash0, hash_key_hash, hashout2u, key2u, limbs2f, u2h, u2k, +}; + +pub(crate) const HASH_TYPE: u8 = 0; +pub(crate) const INTERNAL_TYPE: u8 = 1; +pub(crate) const LEAF_TYPE: u8 = 2; + +pub type F = GoldilocksField; +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct Key(pub [F; 4]); +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct Node(pub [F; 12]); +pub type Hash = PoseidonHash; +pub type HashOut = >::Hash; + +impl Key { + pub fn split(&self) -> Bits { + let mut bits = Bits::empty(); + let mut arr: [_; 4] = std::array::from_fn(|i| self.0[i].to_canonical_u64()); + for _ in 0..64 { + for j in 0..4 { + bits.push_bit(arr[j] & 1 == 1); + arr[j] >>= 1; + } + } + bits + } + + pub fn join(bits: Bits, rem_key: Self) -> Self { + let mut n = [0; 4]; + let mut accs = [0; 4]; + for i in 0..bits.count { + if bits.get_bit(i) { + accs[i % 4] |= 1 << n[i % 4]; + } + n[i % 4] += 1; + } + let key = std::array::from_fn(|i| { + F::from_canonical_u64((rem_key.0[i].to_canonical_u64() << n[i]) | accs[i]) + }); + Key(key) + } + + fn remove_key_bits(&self, nbits: usize) -> Self { + let full_levels = nbits / 4; + let mut auxk = self.0.map(|x| x.to_canonical_u64()); + for i in 0..4 { + let mut n = full_levels; + if full_levels * 4 + i < nbits { + n += 1; + } + auxk[i] >>= n; + } + Key(auxk.map(F::from_canonical_u64)) + } +} + +impl Node { + pub fn is_one_siblings(&self) -> bool { + self.0[8].is_one() + } +} + +/// Sparse Merkle tree (SMT). +/// Represented as a map from keys to leaves and a map from keys to internal +/// nodes. Leaves hold either a value node, representing an account in the state +/// SMT or a value in the storage SMT, or a hash node, representing a hash of a +/// subtree. Internal nodes hold the hashes of their children. +/// The root is the hash of the root internal node. +/// Leaves are hashed using a prefix of 0, internal nodes using a prefix of 1. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Smt { + pub db: D, + pub kv_store: HashMap, + pub root: HashOut, +} + +impl Smt { + /// Returns `Poseidon(x, [0,0,0,0])` and save it in DB. + pub fn hash0(&mut self, x: [F; 8]) -> [F; 4] { + let h = hash0(x); + let a = std::array::from_fn(|i| if i < 8 { x[i] } else { F::ZERO }); + self.db.set_node(Key(h), Node(a)); + h + } + + /// Returns `Poseidon(key || h, [1,0,0,0])` and save it in DB. + pub fn hash_key_hash(&mut self, k: Key, h: [F; 4]) -> [F; 4] { + let a: [_; 8] = std::array::from_fn(|i| if i < 4 { k.0[i] } else { h[i - 4] }); + let a = std::array::from_fn(|i| match i { + j if j < 8 => a[i], + 8 => F::ONE, + _ => F::ZERO, + }); + let h = hash_key_hash(k, h); + self.db.set_node(Key(h), Node(a)); + h + } + + /// Returns the value associated with the key if it is in the SMT, otherwise + /// returns 0. + pub fn get(&self, key: Key) -> U256 { + let keys = key.split(); + let mut level = 0; + let mut acc_key = Bits::empty(); + let mut r = Key(self.root.elements); + + while !r.0.iter().all(F::is_zero) { + let sibling = self.db.get_node(&r).unwrap(); + if sibling.is_one_siblings() { + let found_val_a: [F; 8] = self + .db + .get_node(&Key(sibling.0[4..8].try_into().unwrap())) + .unwrap() + .0[0..8] + .try_into() + .unwrap(); + let found_rem_key = Key(sibling.0[0..4].try_into().unwrap()); + let found_val = limbs2f(found_val_a); + let found_key = Key::join(acc_key, found_rem_key); + return if found_key == key { + assert_eq!( + found_val, + self.kv_store.get(&key).copied().unwrap_or_default() + ); + found_val + } else { + assert!(self + .kv_store + .get(&key) + .copied() + .unwrap_or_default() + .is_zero()); + U256::zero() + }; + } else { + let b = keys.get_bit(level as usize); + r = Key(sibling.0[b as usize * 4..(b as usize + 1) * 4] + .try_into() + .unwrap()); + acc_key.push_bit(b); + level += 1; + } + } + unreachable!() + } + + /// Set the value associated with the key in the SMT. + /// If the value is 0 and the key is in the SMT, the key is removed from the + /// SMT. Reference implementation in https://github.com/0xPolygonHermez/zkevm-commonjs/blob/main/src/smt.js. + pub fn set(&mut self, key: Key, value: U256) { + if value.is_zero() { + self.kv_store.remove(&key); + } else { + self.kv_store.insert(key, value); + } + let mut r = Key(self.root.elements); + let mut new_root = self.root; + let keys = key.split(); + let mut level = 0isize; + let mut acc_key = Bits::empty(); + let mut found_key = None; + let mut found_rem_key = None; + let mut found_old_val_h = None; + let mut siblings = vec![]; + + while !r.0.iter().all(F::is_zero) { + let sibling = self.db.get_node(&r).unwrap(); + siblings.push(*sibling); + if sibling.is_one_siblings() { + found_old_val_h = Some(sibling.0[4..8].try_into().unwrap()); + let found_val_a: [F; 8] = + self.db.get_node(&Key(found_old_val_h.unwrap())).unwrap().0[0..8] + .try_into() + .unwrap(); + found_rem_key = Some(Key(sibling.0[0..4].try_into().unwrap())); + let _found_val = limbs2f(found_val_a); + found_key = Some(Key::join(acc_key, found_rem_key.unwrap())); + break; + } else { + let b = keys.get_bit(level as usize); + r = Key(sibling.0[b as usize * 4..(b as usize + 1) * 4] + .try_into() + .unwrap()); + acc_key.push_bit(b); + level += 1; + } + } + + level -= 1; + if !acc_key.is_empty() { + acc_key.pop_next_bit(); + } + + if value.is_zero() { + if let Some(found_key) = found_key { + if key == found_key { + if level >= 0 { + let i = (keys.get_bit(level as usize) as usize) * 4; + siblings[level as usize].0[i..i + 4].copy_from_slice(&[F::ZERO; 4]); + let mut u_key = get_unique_sibling(siblings[level as usize]); + + if u_key >= 0 { + let k = siblings[level as usize].0 + [u_key as usize * 4..u_key as usize * 4 + 4] + .try_into() + .unwrap(); + siblings[(level + 1) as usize] = *self.db.get_node(&Key(k)).unwrap(); + if siblings[(level + 1) as usize].is_one_siblings() { + let val_h = + siblings[(level + 1) as usize].0[4..8].try_into().unwrap(); + let val_a = self.db.get_node(&Key(val_h)).unwrap().0[0..8] + .try_into() + .unwrap(); + let r_key = + siblings[(level + 1) as usize].0[0..4].try_into().unwrap(); + + let _val = limbs2f(val_a); + + assert!(u_key == 0 || u_key == 1); + let ins_key = Key::join(acc_key.add_bit(u_key != 0), Key(r_key)); + while (u_key >= 0) && (level >= 0) { + level -= 1; + if level >= 0 { + u_key = get_unique_sibling(siblings[level as usize]); + } + } + + let old_key = ins_key.remove_key_bits((level + 1) as usize); + let old_leaf_hash = self.hash_key_hash(old_key, val_h); + + if level >= 0 { + let b = keys.get_bit(level as usize) as usize * 4; + siblings[level as usize].0[b..b + 4] + .copy_from_slice(&old_leaf_hash); + } else { + new_root = HashOut { + elements: old_leaf_hash, + }; + } + } + } else { + panic!() + } + } else { + new_root = HashOut { + elements: [F::ZERO; 4], + }; + } + } + } + } else if let Some(found_key) = found_key { + if key == found_key { + let new_val_h = self.hash0(f2limbs(value)); + let new_leaf_hash = self.hash_key_hash(found_rem_key.unwrap(), new_val_h); + if level >= 0 { + let i = (keys.get_bit(level as usize) as usize) * 4; + siblings[level as usize].0[i..i + 4].copy_from_slice(&new_leaf_hash); + } else { + new_root = HashOut { + elements: new_leaf_hash, + }; + } + } else { + let mut node = [F::ZERO; 8]; + let mut level2 = level + 1; + let found_keys = found_key.split(); + while keys.get_bit(level2 as usize) == found_keys.get_bit(level2 as usize) { + level2 += 1; + } + let old_key = found_key.remove_key_bits(level2 as usize + 1); + let old_leaf_hash = self.hash_key_hash(old_key, found_old_val_h.unwrap()); + + let new_key = key.remove_key_bits(level2 as usize + 1); + let new_val_h = self.hash0(f2limbs(value)); + let new_leaf_hash = self.hash_key_hash(new_key, new_val_h); + + let b = keys.get_bit(level2 as usize) as usize * 4; + let bb = found_keys.get_bit(level2 as usize) as usize * 4; + node[b..b + 4].copy_from_slice(&new_leaf_hash); + node[bb..bb + 4].copy_from_slice(&old_leaf_hash); + + let mut r2 = self.hash0(node); + level2 -= 1; + + while level2 != level { + node = [F::ZERO; 8]; + let b = keys.get_bit(level2 as usize) as usize * 4; + node[b..b + 4].copy_from_slice(&r2); + + r2 = self.hash0(node); + level2 -= 1; + } + + if level >= 0 { + let b = keys.get_bit(level as usize) as usize * 4; + siblings[level as usize].0[b..b + 4].copy_from_slice(&r2); + } else { + new_root = HashOut { elements: r2 }; + } + } + } else { + let new_key = key.remove_key_bits((level + 1) as usize); + let new_val_h = self.hash0(f2limbs(value)); + let new_leaf_hash = self.hash_key_hash(new_key, new_val_h); + + if level >= 0 { + let b = keys.get_bit(level as usize) as usize * 4; + siblings[level as usize].0[b..b + 4].copy_from_slice(&new_leaf_hash); + } else { + new_root = HashOut { + elements: new_leaf_hash, + }; + } + } + siblings.truncate((level + 1) as usize); + + while level >= 0 { + new_root = F::poseidon(siblings[level as usize].0)[0..4] + .try_into() + .unwrap(); + self.db + .set_node(Key(new_root.elements), siblings[level as usize]); + level -= 1; + if level >= 0 { + let b = keys.get_bit(level as usize) as usize * 4; + siblings[level as usize].0[b..b + 4].copy_from_slice(&new_root.elements); + } + } + self.root = new_root; + } + + /// Delete the key in the SMT. + pub fn delete(&mut self, key: Key) { + self.kv_store.remove(&key); + self.set(key, U256::zero()); + } + + /// Set the key to the hash in the SMT. + /// Needs to be called before any call to `set` to avoid issues. + pub fn set_hash(&mut self, key: Bits, hash: HashOut) { + let mut r = Key(self.root.elements); + let mut new_root = self.root; + let mut level = 0isize; + let mut siblings = vec![]; + + for _ in 0..key.count { + let sibling = self.db.get_node(&r).unwrap_or(&Node([F::ZERO; 12])); + siblings.push(*sibling); + if sibling.is_one_siblings() { + panic!("Hit a leaf node."); + } else { + let b = key.get_bit(level as usize); + r = Key(sibling.0[b as usize * 4..(b as usize + 1) * 4] + .try_into() + .unwrap()); + level += 1; + } + } + level -= 1; + assert_eq!( + r, + Key([F::ZERO; 4]), + "Tried to insert a hash node in a non-empty node." + ); + + if level >= 0 { + let b = key.get_bit(level as usize) as usize * 4; + siblings[level as usize].0[b..b + 4].copy_from_slice(&hash.elements); + } else { + new_root = hash; + } + siblings.truncate((level + 1) as usize); + + while level >= 0 { + new_root = F::poseidon(siblings[level as usize].0)[0..4] + .try_into() + .unwrap(); + self.db + .set_node(Key(new_root.elements), siblings[level as usize]); + level -= 1; + if level >= 0 { + let b = key.get_bit(level as usize) as usize * 4; + siblings[level as usize].0[b..b + 4].copy_from_slice(&new_root.elements); + } + } + self.root = new_root; + } + + /// Serialize and prune the SMT into a vector of U256. + /// Starts with a [0, 0] for convenience, that way `ptr=0` is a canonical + /// empty node. Therefore the root of the SMT is at `ptr=2`. + /// `keys` is a list of keys whose prefixes will not be hashed-out in the + /// serialization. + /// Serialization rules: + /// ```pseudocode + /// serialize( HashNode { h } ) = [HASH_TYPE, h] + /// serialize( InternalNode { left, right } ) = [INTERNAL_TYPE, serialize(left).ptr, serialize(right).ptr] + /// serialize( LeafNode { rem_key, value } ) = [LEAF_TYPE, rem_key, value] + /// ``` + pub fn serialize_and_prune, I: IntoIterator>( + &self, + keys: I, + ) -> Vec { + let mut v = vec![U256::zero(); 2]; // For empty hash node. + let key = Key(self.root.elements); + + let mut keys_to_include = HashSet::new(); + for key in keys.into_iter() { + let mut bits = key.borrow().split(); + loop { + keys_to_include.insert(bits); + if bits.is_empty() { + break; + } + bits.pop_next_bit(); + } + } + + serialize(self, key, &mut v, Bits::empty(), &keys_to_include); + if v.len() == 2 { + v.extend([U256::zero(); 2]); + } + v + } + + pub fn serialize(&self) -> Vec { + // Include all keys. + self.serialize_and_prune(self.kv_store.keys()) + } +} + +fn serialize( + smt: &Smt, + key: Key, + v: &mut Vec, + cur_bits: Bits, + keys_to_include: &HashSet, +) -> usize { + if key.0.iter().all(F::is_zero) { + return 0; // `ptr=0` is an empty node. + } + + if !keys_to_include.contains(&cur_bits) || smt.db.get_node(&key).is_none() { + let index = v.len(); + v.push(HASH_TYPE.into()); + v.push(key2u(key)); + index + } else if let Some(node) = smt.db.get_node(&key) { + if node.0.iter().all(F::is_zero) { + panic!("wtf?"); + } + + if node.is_one_siblings() { + let val_h = node.0[4..8].try_into().unwrap(); + let val_a = smt.db.get_node(&Key(val_h)).unwrap().0[0..8] + .try_into() + .unwrap(); + let rem_key = Key(node.0[0..4].try_into().unwrap()); + let val = limbs2f(val_a); + let index = v.len(); + v.push(LEAF_TYPE.into()); + v.push(key2u(rem_key)); + v.push(val); + index + } else { + let key_left = Key(node.0[0..4].try_into().unwrap()); + let key_right = Key(node.0[4..8].try_into().unwrap()); + let index = v.len(); + v.push(INTERNAL_TYPE.into()); + v.push(U256::zero()); + v.push(U256::zero()); + let i_left = + serialize(smt, key_left, v, cur_bits.add_bit(false), keys_to_include).into(); + v[index + 1] = i_left; + let i_right = + serialize(smt, key_right, v, cur_bits.add_bit(true), keys_to_include).into(); + v[index + 2] = i_right; + index + } + } else { + unreachable!() + } +} + +/// Hash a serialized state SMT, i.e., one where leaves hold accounts. +pub fn hash_serialize(v: &[U256]) -> HashOut { + _hash_serialize(v, 2) +} + +pub fn hash_serialize_u256(v: &[U256]) -> U256 { + hashout2u(hash_serialize(v)) +} + +fn _hash_serialize(v: &[U256], ptr: usize) -> HashOut { + assert!(v[ptr] <= u8::MAX.into()); + match v[ptr].as_u64() as u8 { + HASH_TYPE => u2h(v[ptr + 1]), + + INTERNAL_TYPE => { + let mut node = Node([F::ZERO; 12]); + for b in 0..2 { + let child_index = v[ptr + 1 + b]; + let child_hash = _hash_serialize(v, child_index.as_usize()); + node.0[b * 4..(b + 1) * 4].copy_from_slice(&child_hash.elements); + } + F::poseidon(node.0)[0..4].try_into().unwrap() + } + LEAF_TYPE => { + let rem_key = u2k(v[ptr + 1]); + let value = f2limbs(v[ptr + 2]); + let value_h = hash0(value); + let mut node = Node([F::ZERO; 12]); + node.0[8] = F::ONE; + node.0[0..4].copy_from_slice(&rem_key.0); + node.0[4..8].copy_from_slice(&value_h); + F::poseidon(node.0)[0..4].try_into().unwrap() + } + _ => panic!("Should not happen"), + } +} diff --git a/smt_trie/src/smt_test.rs b/smt_trie/src/smt_test.rs new file mode 100644 index 000000000..c086e17dc --- /dev/null +++ b/smt_trie/src/smt_test.rs @@ -0,0 +1,409 @@ +use ethereum_types::U256; +use plonky2::field::types::{Field, Sample}; +use plonky2::hash::hash_types::HashOut; +use rand::seq::SliceRandom; +use rand::{random, thread_rng, Rng}; + +use crate::bits::Bits; +use crate::db::Db; +use crate::smt::HASH_TYPE; +use crate::utils::hashout2u; +use crate::{ + db::MemoryDb, + smt::{hash_serialize, Key, Smt, F}, +}; + +#[test] +fn test_add_and_rem() { + let mut smt = Smt::::default(); + + let k = Key(F::rand_array()); + let v = U256(thread_rng().gen()); + smt.set(k, v); + assert_eq!(v, smt.get(k)); + + smt.set(k, U256::zero()); + assert_eq!(smt.root.elements, [F::ZERO; 4]); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_add_and_rem_hermez() { + let mut smt = Smt::::default(); + + let k = Key([F::ONE, F::ZERO, F::ZERO, F::ZERO]); + let v = U256::from(2); + smt.set(k, v); + assert_eq!(v, smt.get(k)); + assert_eq!( + smt.root.elements, + [ + 16483217357039062949, + 6830539605347455377, + 6826288191577443203, + 8219762152026661456 + ] + .map(F::from_canonical_u64) + ); + + smt.set(k, U256::zero()); + assert_eq!(smt.root.elements, [F::ZERO; 4]); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_update_element_1() { + let mut smt = Smt::::default(); + + let k = Key(F::rand_array()); + let v1 = U256(thread_rng().gen()); + let v2 = U256(thread_rng().gen()); + smt.set(k, v1); + let root = smt.root; + smt.set(k, v2); + smt.set(k, v1); + assert_eq!(smt.root, root); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_add_shared_element_2() { + let mut smt = Smt::::default(); + + let k1 = Key(F::rand_array()); + let k2 = Key(F::rand_array()); + assert_ne!(k1, k2, "Unlucky"); + let v1 = U256(thread_rng().gen()); + let v2 = U256(thread_rng().gen()); + smt.set(k1, v1); + smt.set(k2, v2); + smt.set(k1, U256::zero()); + smt.set(k2, U256::zero()); + assert_eq!(smt.root.elements, [F::ZERO; 4]); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_add_shared_element_3() { + let mut smt = Smt::::default(); + + let k1 = Key(F::rand_array()); + let k2 = Key(F::rand_array()); + let k3 = Key(F::rand_array()); + let v1 = U256(thread_rng().gen()); + let v2 = U256(thread_rng().gen()); + let v3 = U256(thread_rng().gen()); + smt.set(k1, v1); + smt.set(k2, v2); + smt.set(k3, v3); + smt.set(k1, U256::zero()); + smt.set(k2, U256::zero()); + smt.set(k3, U256::zero()); + assert_eq!(smt.root.elements, [F::ZERO; 4]); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_add_remove_128() { + let mut smt = Smt::::default(); + + let kvs = (0..128) + .map(|_| { + let k = Key(F::rand_array()); + let v = U256(thread_rng().gen()); + smt.set(k, v); + (k, v) + }) + .collect::>(); + for &(k, v) in &kvs { + smt.set(k, v); + } + for &(k, _) in &kvs { + smt.set(k, U256::zero()); + } + assert_eq!(smt.root.elements, [F::ZERO; 4]); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_should_read_random() { + let mut smt = Smt::::default(); + + let kvs = (0..128) + .map(|_| { + let k = Key(F::rand_array()); + let v = U256(thread_rng().gen()); + smt.set(k, v); + (k, v) + }) + .collect::>(); + for &(k, v) in &kvs { + smt.set(k, v); + } + for &(k, v) in &kvs { + assert_eq!(smt.get(k), v); + } + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_add_element_similar_key() { + let mut smt = Smt::::default(); + + let k1 = Key([F::ZERO; 4]); + let k2 = Key([F::from_canonical_u16(15), F::ZERO, F::ZERO, F::ZERO]); + let k3 = Key([F::from_canonical_u16(31), F::ZERO, F::ZERO, F::ZERO]); + let v1 = U256::from(2); + let v2 = U256::from(3); + smt.set(k1, v1); + smt.set(k2, v1); + smt.set(k3, v2); + + let expected_root = [ + 442750481621001142, + 12174547650106208885, + 10730437371575329832, + 4693848817100050981, + ] + .map(F::from_canonical_u64); + assert_eq!(smt.root.elements, expected_root); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_leaf_one_level_depth() { + let mut smt = Smt::::default(); + + let k0 = Key([ + 15508201873038097485, + 13226964191399612151, + 16289586894263066011, + 5039894867879804772, + ] + .map(F::from_canonical_u64)); + let k1 = Key([ + 844617937539064431, + 8280782215217712600, + 776954566881514913, + 1946423943169448778, + ] + .map(F::from_canonical_u64)); + let k2 = Key([ + 15434611863279822111, + 11975487827769517766, + 15368078704174133449, + 1970673199824226969, + ] + .map(F::from_canonical_u64)); + let k3 = Key([ + 4947646911082557289, + 4015479196169929139, + 8997983193975654297, + 9607383237755583623, + ] + .map(F::from_canonical_u64)); + let k4 = Key([ + 15508201873038097485, + 13226964191399612151, + 16289586894263066011, + 5039894867879804772, + ] + .map(F::from_canonical_u64)); + + let v0 = U256::from_dec_str( + "8163644824788514136399898658176031121905718480550577527648513153802600646339", + ) + .unwrap(); + let v1 = U256::from_dec_str( + "115792089237316195423570985008687907853269984665640564039457584007913129639934", + ) + .unwrap(); + let v2 = U256::from_dec_str( + "115792089237316195423570985008687907853269984665640564039457584007913129639935", + ) + .unwrap(); + let v3 = U256::from_dec_str("7943875943875408").unwrap(); + let v4 = U256::from_dec_str( + "35179347944617143021579132182092200136526168785636368258055676929581544372820", + ) + .unwrap(); + + smt.set(k0, v0); + smt.set(k1, v1); + smt.set(k2, v2); + smt.set(k3, v3); + smt.set(k4, v4); + + let expected_root = [ + 13590506365193044307, + 13215874698458506886, + 4743455437729219665, + 1933616419393621600, + ] + .map(F::from_canonical_u64); + assert_eq!(smt.root.elements, expected_root); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_no_write_0() { + let mut smt = Smt::::default(); + + let k1 = Key(F::rand_array()); + let k2 = Key(F::rand_array()); + let v = U256(thread_rng().gen()); + smt.set(k1, v); + let root = smt.root; + smt.set(k2, U256::zero()); + assert_eq!(smt.root, root); + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); +} + +#[test] +fn test_set_hash_first_level() { + let mut smt = Smt::::default(); + + let kvs = (0..128) + .map(|_| { + let k = Key(F::rand_array()); + let v = U256(random()); + smt.set(k, v); + (k, v) + }) + .collect::>(); + for &(k, v) in &kvs { + smt.set(k, v); + } + + let first_level = smt.db.get_node(&Key(smt.root.elements)).unwrap(); + let mut hash_smt = Smt::::default(); + let zero = Bits { + count: 1, + packed: U256::zero(), + }; + let one = Bits { + count: 1, + packed: U256::one(), + }; + hash_smt.set_hash( + zero, + HashOut { + elements: first_level.0[0..4].try_into().unwrap(), + }, + ); + hash_smt.set_hash( + one, + HashOut { + elements: first_level.0[4..8].try_into().unwrap(), + }, + ); + + assert_eq!(smt.root, hash_smt.root); + + let ser = hash_smt.serialize(); + assert_eq!(hash_serialize(&ser), hash_smt.root); +} + +#[test] +fn test_set_hash_order() { + let mut smt = Smt::::default(); + + let level = 4; + + let mut khs = (1..1 << level) + .map(|i| { + let k = Bits { + count: level, + packed: i.into(), + }; + let hash = HashOut { + elements: F::rand_array(), + }; + (k, hash) + }) + .collect::>(); + for &(k, v) in &khs { + smt.set_hash(k, v); + } + let key = loop { + // Forgive my laziness + let key = Key(F::rand_array()); + let keys = key.split(); + if (0..level).all(|i| !keys.get_bit(i)) { + break key; + } + }; + let val = U256(random()); + smt.set(key, val); + + let mut second_smt = Smt::::default(); + khs.shuffle(&mut thread_rng()); + for (k, v) in khs { + second_smt.set_hash(k, v); + } + second_smt.set(key, val); + + assert_eq!(smt.root, second_smt.root); + + let ser = second_smt.serialize(); + assert_eq!(hash_serialize(&ser), second_smt.root); +} + +#[test] +fn test_serialize_and_prune() { + let mut smt = Smt::::default(); + + for _ in 0..128 { + let k = Key(F::rand_array()); + let v = U256(random()); + smt.set(k, v); + } + + let ser = smt.serialize(); + assert_eq!(hash_serialize(&ser), smt.root); + + let subset = { + let r: u128 = random(); + smt.kv_store + .keys() + .enumerate() + .filter_map(|(i, k)| if r & (1 << i) != 0 { Some(*k) } else { None }) + .collect::>() + }; + + let pruned_ser = smt.serialize_and_prune(subset); + assert_eq!(hash_serialize(&pruned_ser), smt.root); + assert!(pruned_ser.len() <= ser.len()); + + let trivial_ser = smt.serialize_and_prune::>(vec![]); + assert_eq!( + trivial_ser, + vec![ + U256::zero(), + U256::zero(), + HASH_TYPE.into(), + hashout2u(smt.root) + ] + ); + assert_eq!(hash_serialize(&trivial_ser), smt.root); +} diff --git a/smt_trie/src/utils.rs b/smt_trie/src/utils.rs new file mode 100644 index 000000000..267b6b8e9 --- /dev/null +++ b/smt_trie/src/utils.rs @@ -0,0 +1,89 @@ +use ethereum_types::U256; +use plonky2::field::types::{Field, PrimeField64}; +use plonky2::hash::poseidon::Poseidon; + +use crate::smt::{HashOut, Key, Node, F}; + +/// Returns `Poseidon(x, [0,0,0,0])`. +pub(crate) fn hash0(x: [F; 8]) -> [F; 4] { + F::poseidon(std::array::from_fn(|i| if i < 8 { x[i] } else { F::ZERO }))[0..4] + .try_into() + .unwrap() +} + +/// Returns `Poseidon(x, [1,0,0,0])`. +pub(crate) fn hash1(x: [F; 8]) -> [F; 4] { + F::poseidon(std::array::from_fn(|i| match i { + j if j < 8 => x[i], + 8 => F::ONE, + _ => F::ZERO, + }))[0..4] + .try_into() + .unwrap() +} + +/// Returns `Poseidon(key || h, [1,0,0,0])`. +pub(crate) fn hash_key_hash(k: Key, h: [F; 4]) -> [F; 4] { + hash1(std::array::from_fn( + |i| if i < 4 { k.0[i] } else { h[i - 4] }, + )) +} + +/// Split a U256 into 8 32-bit limbs in little-endian order. +pub(crate) fn f2limbs(x: U256) -> [F; 8] { + std::array::from_fn(|i| F::from_canonical_u32((x >> (32 * i)).low_u32())) +} + +/// Pack 8 32-bit limbs in little-endian order into a U256. +pub(crate) fn limbs2f(limbs: [F; 8]) -> U256 { + limbs + .into_iter() + .enumerate() + .fold(U256::zero(), |acc, (i, x)| { + acc + (U256::from(x.to_canonical_u64()) << (i * 32)) + }) +} + +/// Convert a `HashOut` to a `U256`. +pub fn hashout2u(h: HashOut) -> U256 { + key2u(Key(h.elements)) +} + +/// Convert a `Key` to a `U256`. +pub fn key2u(key: Key) -> U256 { + U256(key.0.map(|x| x.to_canonical_u64())) +} + +/// Convert a `U256` to a `Hashout`. +pub(crate) fn u2h(x: U256) -> HashOut { + HashOut { + elements: x.0.map(F::from_canonical_u64), + } +} + +/// Convert a `U256` to a `Key`. +pub(crate) fn u2k(x: U256) -> Key { + Key(x.0.map(F::from_canonical_u64)) +} + +/// Given a node, return the index of the unique non-zero sibling, or -1 if +/// there is no such sibling. +pub(crate) fn get_unique_sibling(node: Node) -> isize { + let mut nfound = 0; + let mut fnd = 0; + for i in (0..12).step_by(4) { + if !(node.0[i].is_zero() + && node.0[i + 1].is_zero() + && node.0[i + 2].is_zero() + && node.0[i + 3].is_zero()) + { + nfound += 1; + fnd = i as isize / 4; + } + } + if nfound == 1 { + fnd + } else { + -1 + } +}