diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..24ec1c56 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,20 @@ +# Documentation available at editorconfig.org + +root=true + +[*] +ident_style = space +ident_size = 4 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.rs] +max_line_length = 100 + +[*.md] +trim_trailing_whitespace = false + +[*.yml] +ident_size = 2 diff --git a/CHANGELOG.md b/CHANGELOG.md index 356099cd..b8a6559d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## 0.8.0 (TBD) + +* Implemented the `PartialMmr` data structure (#195). +* Updated Winterfell dependency to v0.7 (#200) +* Implemented RPX hash function (#201). + ## 0.7.1 (2023-10-10) * Fixed RPO Falcon signature build on Windows. @@ -12,7 +18,6 @@ * Implemented benchmarking for `TieredSmt` (#182). * Added more leaf traversal methods for `MerkleStore` (#185). * Added SVE acceleration for RPO hash function (#189). -* Implemented the `PartialMmr` datastructure (#195). ## 0.6.0 (2023-06-25) diff --git a/Cargo.toml b/Cargo.toml index ec93d063..daccafca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,12 @@ [package] name = "miden-crypto" -version = "0.7.1" +version = "0.8.0" description = "Miden Cryptographic primitives" authors = ["miden contributors"] readme = "README.md" license = "MIT" repository = "https://github.com/0xPolygonMiden/crypto" -documentation = "https://docs.rs/miden-crypto/0.7.1" +documentation = "https://docs.rs/miden-crypto/0.8.0" categories = ["cryptography", "no-std"] keywords = ["miden", "crypto", "hash", "merkle"] edition = "2021" @@ -42,16 +42,16 @@ sve = ["std"] blake3 = { version = "1.5", default-features = false } clap = { version = "4.4", features = ["derive"], optional = true } libc = { version = "0.2", default-features = false, optional = true } -rand_utils = { version = "0.6", package = "winter-rand-utils", optional = true } +rand_utils = { version = "0.7", package = "winter-rand-utils", optional = true } serde = { version = "1.0", features = [ "derive" ], default-features = false, optional = true } -winter_crypto = { version = "0.6", package = "winter-crypto", default-features = false } -winter_math = { version = "0.6", package = "winter-math", default-features = false } -winter_utils = { version = "0.6", package = "winter-utils", default-features = false } +winter_crypto = { version = "0.7", package = "winter-crypto", default-features = false } +winter_math = { version = "0.7", package = "winter-math", default-features = false } +winter_utils = { version = "0.7", package = "winter-utils", default-features = false } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } proptest = "1.3" -rand_utils = { version = "0.6", package = "winter-rand-utils" } +rand_utils = { version = "0.7", package = "winter-rand-utils" } [build-dependencies] cc = { version = "1.0", features = ["parallel"], optional = true } diff --git a/README.md b/README.md index 7ec16b25..f4de47ed 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ This crate contains cryptographic primitives used in Polygon Miden. * [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) hash function with 256-bit, 192-bit, or 160-bit output. The 192-bit and 160-bit outputs are obtained by truncating the 256-bit output of the standard BLAKE3. * [RPO](https://eprint.iacr.org/2022/1577) hash function with 256-bit output. This hash function is an algebraic hash function suitable for recursive STARKs. +* [RPX](https://eprint.iacr.org/2023/1045) hash function with 256-bit output. Similar to RPO, this hash function is suitable for recursive STARKs but it is about 2x faster as compared to RPO. For performance benchmarks of these hash functions and their comparison to other popular hash functions please see [here](./benches/). @@ -16,6 +17,7 @@ For performance benchmarks of these hash functions and their comparison to other * `MerkleTree`: a regular fully-balanced binary Merkle tree. The depth of this tree can be at most 64. * `Mmr`: a Merkle mountain range structure designed to function as an append-only log. * `PartialMerkleTree`: a partial view of a Merkle tree where some sub-trees may not be known. This is similar to a collection of Merkle paths all resolving to the same root. The length of the paths can be at most 64. +* `PartialMmr`: a partial view of a Merkle mountain range structure. * `SimpleSmt`: a Sparse Merkle Tree (with no compaction), mapping 64-bit keys to 4-element values. * `TieredSmt`: a Sparse Merkle tree (with compaction), mapping 4-element keys to 4-element values. diff --git a/benches/README.md b/benches/README.md index a1dddd0a..1ba848ab 100644 --- a/benches/README.md +++ b/benches/README.md @@ -6,6 +6,7 @@ In the Miden VM, we make use of different hash functions. Some of these are "tra * **Poseidon** as specified [here](https://eprint.iacr.org/2019/458.pdf) and implemented [here](https://github.com/mir-protocol/plonky2/blob/806b88d7d6e69a30dc0b4775f7ba275c45e8b63b/plonky2/src/hash/poseidon_goldilocks.rs) (but in pure Rust, without vectorized instructions). * **Rescue Prime (RP)** as specified [here](https://eprint.iacr.org/2020/1143) and implemented [here](https://github.com/novifinancial/winterfell/blob/46dce1adf0/crypto/src/hash/rescue/rp64_256/mod.rs). * **Rescue Prime Optimized (RPO)** as specified [here](https://eprint.iacr.org/2022/1577) and implemented in this crate. +* **Rescue Prime Extended (RPX)** a variant of the [xHash](https://eprint.iacr.org/2023/1045) hash function as implemented in this crate. ## Comparison and Instructions @@ -15,28 +16,28 @@ The second scenario is that of sequential hashing where we take a sequence of le #### Scenario 1: 2-to-1 hashing `h(a,b)` -| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | -| ------------------- | ------ | --------| --------- | --------- | ------- | -| Apple M1 Pro | 80 ns | 245 ns | 1.5 us | 9.1 us | 5.4 us | -| Apple M2 | 76 ns | 233 ns | 1.3 us | 7.9 us | 5.0 us | -| Amazon Graviton 3 | 108 ns | | | | 5.3 us | -| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 us | 9.1 us | 5.5 us | -| Intel Core i5-8279U | 80 ns | | | | 8.7 us | -| Intel Xeon 8375C | 67 ns | | | | 8.2 us | +| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 | +| ------------------- | ------ | ------- | --------- | --------- | ------- | ------- | +| Apple M1 Pro | 76 ns | 245 ns | 1.5 µs | 9.1 µs | 5.2 µs | 2.7 µs | +| Apple M2 Max | 71 ns | 233 ns | 1.3 µs | 7.9 µs | 4.6 µs | 2.4 µs | +| Amazon Graviton 3 | 108 ns | | | | 5.3 µs | 3.1 µs | +| AMD Ryzen 9 5950X | 64 ns | 273 ns | 1.2 µs | 9.1 µs | 5.5 µs | | +| Intel Core i5-8279U | 68 ns | 536 ns | 2.0 µs | 13.6 µs | 8.5 µs | 4.4 µs | +| Intel Xeon 8375C | 67 ns | | | | 8.2 µs | | #### Scenario 2: Sequential hashing of 100 elements `h([a_0,...,a_99])` -| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | -| ------------------- | -------| ------- | --------- | --------- | ------- | -| Apple M1 Pro | 1.0 us | 1.5 us | 19.4 us | 118 us | 70 us | -| Apple M2 | 1.0 us | 1.5 us | 17.4 us | 103 us | 65 us | -| Amazon Graviton 3 | 1.4 us | | | | 69 us | -| AMD Ryzen 9 5950X | 0.8 us | 1.7 us | 15.7 us | 120 us | 72 us | -| Intel Core i5-8279U | 1.0 us | | | | 116 us | -| Intel Xeon 8375C | 0.8 ns | | | | 110 us | +| Function | BLAKE3 | SHA3 | Poseidon | Rp64_256 | RPO_256 | RPX_256 | +| ------------------- | -------| ------- | --------- | --------- | ------- | ------- | +| Apple M1 Pro | 1.0 µs | 1.5 µs | 19.4 µs | 118 µs | 69 µs | 35 µs | +| Apple M2 Max | 0.9 µs | 1.5 µs | 17.4 µs | 103 µs | 60 µs | 31 µs | +| Amazon Graviton 3 | 1.4 µs | | | | 69 µs | 41 µs | +| AMD Ryzen 9 5950X | 0.8 µs | 1.7 µs | 15.7 µs | 120 µs | 72 µs | | +| Intel Core i5-8279U | 0.9 µs | | | | 107 µs | 56 µs | +| Intel Xeon 8375C | 0.8 µs | | | | 110 µs | | Notes: -- On Graviton 3, RPO256 is run with SVE acceleration enabled. +- On Graviton 3, RPO256 and RPX256 are run with SVE acceleration enabled. ### Instructions Before you can run the benchmarks, you'll need to make sure you have Rust [installed](https://www.rust-lang.org/tools/install). After that, to run the benchmarks for RPO and BLAKE3, clone the current repository, and from the root directory of the repo run the following: diff --git a/benches/hash.rs b/benches/hash.rs index 271c1f52..ea5e1e07 100644 --- a/benches/hash.rs +++ b/benches/hash.rs @@ -3,6 +3,7 @@ use miden_crypto::{ hash::{ blake::Blake3_256, rpo::{Rpo256, RpoDigest}, + rpx::{Rpx256, RpxDigest}, }, Felt, }; @@ -57,6 +58,54 @@ fn rpo256_sequential(c: &mut Criterion) { }); } +fn rpx256_2to1(c: &mut Criterion) { + let v: [RpxDigest; 2] = [Rpx256::hash(&[1_u8]), Rpx256::hash(&[2_u8])]; + c.bench_function("RPX256 2-to-1 hashing (cached)", |bench| { + bench.iter(|| Rpx256::merge(black_box(&v))) + }); + + c.bench_function("RPX256 2-to-1 hashing (random)", |bench| { + bench.iter_batched( + || { + [ + Rpx256::hash(&rand_value::().to_le_bytes()), + Rpx256::hash(&rand_value::().to_le_bytes()), + ] + }, + |state| Rpx256::merge(&state), + BatchSize::SmallInput, + ) + }); +} + +fn rpx256_sequential(c: &mut Criterion) { + let v: [Felt; 100] = (0..100) + .into_iter() + .map(Felt::new) + .collect::>() + .try_into() + .expect("should not fail"); + c.bench_function("RPX256 sequential hashing (cached)", |bench| { + bench.iter(|| Rpx256::hash_elements(black_box(&v))) + }); + + c.bench_function("RPX256 sequential hashing (random)", |bench| { + bench.iter_batched( + || { + let v: [Felt; 100] = (0..100) + .into_iter() + .map(|_| Felt::new(rand_value())) + .collect::>() + .try_into() + .expect("should not fail"); + v + }, + |state| Rpx256::hash_elements(&state), + BatchSize::SmallInput, + ) + }); +} + fn blake3_2to1(c: &mut Criterion) { let v: [::Digest; 2] = [Blake3_256::hash(&[1_u8]), Blake3_256::hash(&[2_u8])]; @@ -106,5 +155,13 @@ fn blake3_sequential(c: &mut Criterion) { }); } -criterion_group!(hash_group, rpo256_2to1, rpo256_sequential, blake3_2to1, blake3_sequential); +criterion_group!( + hash_group, + rpx256_2to1, + rpx256_sequential, + rpo256_2to1, + rpo256_sequential, + blake3_2to1, + blake3_sequential +); criterion_main!(hash_group); diff --git a/src/dsa/rpo_falcon512/falcon_c/falcon.c b/src/dsa/rpo_falcon512/falcon_c/falcon.c index cd7bed56..11dc5de7 100644 --- a/src/dsa/rpo_falcon512/falcon_c/falcon.c +++ b/src/dsa/rpo_falcon512/falcon_c/falcon.c @@ -112,6 +112,7 @@ int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( return 0; } +/* see falcon.h */ int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( uint8_t *pk, uint8_t *sk @@ -126,6 +127,63 @@ int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( return PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo(pk, sk, seed); } +/* see falcon.h */ +int PQCLEAN_FALCON512_CLEAN_crypto_pk_from_sk_rpo( + const uint8_t *sk, + uint8_t *pk) +{ + uint8_t b[FALCON_KEYGEN_TEMP_9]; + int8_t f[512], g[512]; + uint16_t h[512]; + size_t u, v; + + /* + * Decode the private key. + */ + if (sk[0] != 0x50 + 9) + { + return -1; + } + u = 1; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( + f, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); + if (v == 0) + { + return -1; + } + u += v; + v = PQCLEAN_FALCON512_CLEAN_trim_i8_decode( + g, 9, PQCLEAN_FALCON512_CLEAN_max_fg_bits[9], + sk + u, PQCLEAN_FALCON512_CLEAN_CRYPTO_SECRETKEYBYTES - u); + if (v == 0) + { + return -1; + } + + /* + * Compute public key h = g.f^(-1) mod X^N+1 mod q. + */ + if (!PQCLEAN_FALCON512_CLEAN_compute_public(h, f, g, 9, (uint8_t *)b)) + { + return -1; + } + + /* + * Encode public key. + */ + pk[0] = 0x00 + 9; + v = PQCLEAN_FALCON512_CLEAN_modq_encode( + pk + 1, PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1, + h, 9); + if (v != PQCLEAN_FALCON512_CLEAN_CRYPTO_PUBLICKEYBYTES - 1) + { + return -1; + } + + return 0; +} + /* * Compute the signature. nonce[] receives the nonce and must have length * NONCELEN bytes. sigbuf[] receives the signature value (without nonce diff --git a/src/dsa/rpo_falcon512/falcon_c/falcon.h b/src/dsa/rpo_falcon512/falcon_c/falcon.h index bdcc3ec7..559d7657 100644 --- a/src/dsa/rpo_falcon512/falcon_c/falcon.h +++ b/src/dsa/rpo_falcon512/falcon_c/falcon.h @@ -31,6 +31,13 @@ int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_rpo( int PQCLEAN_FALCON512_CLEAN_crypto_sign_keypair_from_seed_rpo( uint8_t *pk, uint8_t *sk, unsigned char *seed); +/* + * Generate the public key from the secret key (sk). Public key goes into pk[]. + * + * Return value: 0 on success, -1 on error. + */ +int PQCLEAN_FALCON512_CLEAN_crypto_pk_from_sk_rpo(const uint8_t *sk, uint8_t *pk); + /* * Compute a signature on a provided message (m, mlen), with a given * private key (sk). Signature is written in sig[], with length written diff --git a/src/dsa/rpo_falcon512/ffi.rs b/src/dsa/rpo_falcon512/ffi.rs index 4508ff2c..6f33b299 100644 --- a/src/dsa/rpo_falcon512/ffi.rs +++ b/src/dsa/rpo_falcon512/ffi.rs @@ -25,6 +25,11 @@ extern "C" { seed: *const u8, ) -> c_int; + /// Generates the public key from the private key. Public key goes into pk[]. + /// + /// Return value: 0 on success, -1 on error. + pub fn PQCLEAN_FALCON512_CLEAN_crypto_pk_from_sk_rpo(sk: *const u8, pk: *mut u8) -> c_int; + /// Compute a signature on a provided message (m, mlen), with a given private key (sk). /// Signature is written in sig[], with length written into *siglen. Signature length is /// variable; maximum signature length (in bytes) is 666. @@ -103,11 +108,10 @@ mod tests { #[test] fn falcon_ffi() { unsafe { - //let mut rng = rand::thread_rng(); - // --- generate a key pair from a seed ---------------------------- let mut pk = [0u8; PK_LEN]; + let mut pk_gen = [0u8; PK_LEN]; let mut sk = [0u8; SK_LEN]; let seed: [u8; NONCE_LEN] = rand_array(); @@ -120,6 +124,11 @@ mod tests { ) ); + // --- Generate public key from private key and check correctness - + + PQCLEAN_FALCON512_CLEAN_crypto_pk_from_sk_rpo(sk.as_ptr(), pk_gen.as_mut_ptr()); + assert_eq!(pk, pk_gen); + // --- sign a message and make sure it verifies ------------------- let mlen: usize = rand_value::() as usize; diff --git a/src/dsa/rpo_falcon512/keys.rs b/src/dsa/rpo_falcon512/keys.rs index 91702450..168f1eea 100644 --- a/src/dsa/rpo_falcon512/keys.rs +++ b/src/dsa/rpo_falcon512/keys.rs @@ -42,6 +42,27 @@ impl From for Word { } } +// SECRET KEY +// ================================================================================================ + +/// Derives the (expanded) public key associated to a given secret key. +/// +/// # Errors +/// Returns an error if decoding sk fails or if sk is not a valid secret key. +#[cfg(feature = "std")] +pub fn sk_to_pk_bytes(sk: SecretKeyBytes) -> Result { + let mut pk = [0u8; PK_LEN]; + + let res = + unsafe { ffi::PQCLEAN_FALCON512_CLEAN_crypto_pk_from_sk_rpo(sk.as_ptr(), pk.as_mut_ptr()) }; + + if res == 0 { + Ok(pk) + } else { + Err(FalconError::KeyGenerationFailed) + } +} + // KEY PAIR // ================================================================================================ @@ -147,7 +168,12 @@ impl KeyPair { }; if res == 0 { - Ok(Signature { sig, pk: self.public_key }) + Ok(Signature { + sig, + pk: self.public_key, + pk_polynomial: Default::default(), + sig_polynomial: Default::default(), + }) } else { Err(FalconError::SigGenerationFailed) } @@ -177,9 +203,21 @@ impl Deserializable for KeyPair { #[cfg(all(test, feature = "std"))] mod tests { - use super::{super::Felt, KeyPair, NonceBytes, Word}; + use super::{super::Felt, sk_to_pk_bytes, KeyPair, NonceBytes, Word}; use rand_utils::{rand_array, rand_vector}; + #[test] + fn test_pk_from_sk() { + // generate random keys + let keys = KeyPair::new().unwrap(); + let pk_expected = keys.expanded_public_key(); + let sk = keys.secret_key; + + let pk = sk_to_pk_bytes(sk).unwrap(); + + assert_eq!(pk, pk_expected); + } + #[test] fn test_falcon_verification() { // generate random keys diff --git a/src/dsa/rpo_falcon512/signature.rs b/src/dsa/rpo_falcon512/signature.rs index afcde98e..df98915b 100644 --- a/src/dsa/rpo_falcon512/signature.rs +++ b/src/dsa/rpo_falcon512/signature.rs @@ -4,6 +4,7 @@ use super::{ SIG_L2_BOUND, ZERO, }; use crate::utils::string::ToString; +use core::cell::OnceCell; // FALCON SIGNATURE // ================================================================================================ @@ -43,6 +44,10 @@ use crate::utils::string::ToString; pub struct Signature { pub(super) pk: PublicKeyBytes, pub(super) sig: SignatureBytes, + + // Cached polynomial decoding for public key and signatures + pub(super) pk_polynomial: OnceCell, + pub(super) sig_polynomial: OnceCell, } impl Signature { @@ -51,10 +56,11 @@ impl Signature { /// Returns the public key polynomial h. pub fn pub_key_poly(&self) -> Polynomial { - // TODO: memoize - // we assume that the signature was constructed with a valid public key, and thus - // expect() is OK here. - Polynomial::from_pub_key(&self.pk).expect("invalid public key") + *self.pk_polynomial.get_or_init(|| { + // we assume that the signature was constructed with a valid public key, and thus + // expect() is OK here. + Polynomial::from_pub_key(&self.pk).expect("invalid public key") + }) } /// Returns the nonce component of the signature represented as field elements. @@ -70,10 +76,11 @@ impl Signature { // Returns the polynomial representation of the signature in Z_p[x]/(phi). pub fn sig_poly(&self) -> Polynomial { - // TODO: memoize - // we assume that the signature was constructed with a valid signature, and thus - // expect() is OK here. - Polynomial::from_signature(&self.sig).expect("invalid signature") + *self.sig_polynomial.get_or_init(|| { + // we assume that the signature was constructed with a valid signature, and thus + // expect() is OK here. + Polynomial::from_signature(&self.sig).expect("invalid signature") + }) } // HASH-TO-POINT @@ -123,12 +130,14 @@ impl Deserializable for Signature { let sig: SignatureBytes = source.read_array()?; // make sure public key and signature can be decoded correctly - Polynomial::from_pub_key(&pk) - .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; - Polynomial::from_signature(&sig[41..]) - .map_err(|err| DeserializationError::InvalidValue(err.to_string()))?; - - Ok(Self { pk, sig }) + let pk_polynomial = Polynomial::from_pub_key(&pk) + .map_err(|err| DeserializationError::InvalidValue(err.to_string()))? + .into(); + let sig_polynomial = Polynomial::from_signature(&sig[41..]) + .map_err(|err| DeserializationError::InvalidValue(err.to_string()))? + .into(); + + Ok(Self { pk, sig, pk_polynomial, sig_polynomial }) } } diff --git a/src/hash/mod.rs b/src/hash/mod.rs index 8c87562c..ea068339 100644 --- a/src/hash/mod.rs +++ b/src/hash/mod.rs @@ -1,9 +1,17 @@ //! Cryptographic hash functions used by the Miden VM and the Miden rollup. -use super::{Felt, FieldElement, StarkField, ONE, ZERO}; +use super::{CubeExtension, Felt, FieldElement, StarkField, ONE, ZERO}; pub mod blake; -pub mod rpo; + +mod rescue; +pub mod rpo { + pub use super::rescue::{Rpo256, RpoDigest}; +} + +pub mod rpx { + pub use super::rescue::{Rpx256, RpxDigest}; +} // RE-EXPORTS // ================================================================================================ diff --git a/src/hash/rpo/mds_freq.rs b/src/hash/rescue/mds/freq.rs similarity index 96% rename from src/hash/rpo/mds_freq.rs rename to src/hash/rescue/mds/freq.rs index 6d1f1fdf..17ef67f9 100644 --- a/src/hash/rpo/mds_freq.rs +++ b/src/hash/rescue/mds/freq.rs @@ -11,7 +11,8 @@ /// divisions by 2 and repeated modular reductions. This is because of our explicit choice of /// an MDS matrix that has small powers of 2 entries in frequency domain. /// The following implementation has benefited greatly from the discussions and insights of -/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero. +/// Hamish Ivey-Law and Jacqueline Nabaglo of Polygon Zero and is base on Nabaglo's Plonky2 +/// implementation. // Rescue MDS matrix in frequency domain. // More precisely, this is the output of the three 4-point (real) FFTs of the first column of @@ -26,7 +27,7 @@ const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-8, 1, 1]; // We use split 3 x 4 FFT transform in order to transform our vectors into the frequency domain. #[inline(always)] -pub(crate) const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] { +pub const fn mds_multiply_freq(state: [u64; 12]) -> [u64; 12] { let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = state; let (u0, u1, u2) = fft4_real([s0, s3, s6, s9]); @@ -156,7 +157,7 @@ const fn block3(x: [i64; 3], y: [i64; 3]) -> [i64; 3] { #[cfg(test)] mod tests { - use super::super::{Felt, Rpo256, MDS, ZERO}; + use super::super::{apply_mds, Felt, MDS, ZERO}; use proptest::prelude::*; const STATE_WIDTH: usize = 12; @@ -185,7 +186,7 @@ mod tests { v2 = v1; apply_mds_naive(&mut v1); - Rpo256::apply_mds(&mut v2); + apply_mds(&mut v2); prop_assert_eq!(v1, v2); } diff --git a/src/hash/rescue/mds/mod.rs b/src/hash/rescue/mds/mod.rs new file mode 100644 index 00000000..11b1972a --- /dev/null +++ b/src/hash/rescue/mds/mod.rs @@ -0,0 +1,214 @@ +use super::{Felt, STATE_WIDTH, ZERO}; + +mod freq; +pub use freq::mds_multiply_freq; + +// MDS MULTIPLICATION +// ================================================================================================ + +#[inline(always)] +pub fn apply_mds(state: &mut [Felt; STATE_WIDTH]) { + let mut result = [ZERO; STATE_WIDTH]; + + // Using the linearity of the operations we can split the state into a low||high decomposition + // and operate on each with no overflow and then combine/reduce the result to a field element. + // The no overflow is guaranteed by the fact that the MDS matrix is a small powers of two in + // frequency domain. + let mut state_l = [0u64; STATE_WIDTH]; + let mut state_h = [0u64; STATE_WIDTH]; + + for r in 0..STATE_WIDTH { + let s = state[r].inner(); + state_h[r] = s >> 32; + state_l[r] = (s as u32) as u64; + } + + let state_h = mds_multiply_freq(state_h); + let state_l = mds_multiply_freq(state_l); + + for r in 0..STATE_WIDTH { + let s = state_l[r] as u128 + ((state_h[r] as u128) << 32); + let s_hi = (s >> 64) as u64; + let s_lo = s as u64; + let z = (s_hi << 32) - s_hi; + let (res, over) = s_lo.overflowing_add(z); + + result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64)); + } + *state = result; +} + +// MDS MATRIX +// ================================================================================================ + +/// RPO MDS matrix +pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [ + [ + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + ], + [ + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + ], + [ + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + ], + [ + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + ], + [ + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + ], + [ + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + ], + [ + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + ], + [ + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + ], + [ + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + Felt::new(26), + ], + [ + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + Felt::new(8), + ], + [ + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + Felt::new(23), + ], + [ + Felt::new(23), + Felt::new(8), + Felt::new(26), + Felt::new(13), + Felt::new(10), + Felt::new(9), + Felt::new(7), + Felt::new(6), + Felt::new(22), + Felt::new(21), + Felt::new(8), + Felt::new(7), + ], +]; diff --git a/src/hash/rescue/mod.rs b/src/hash/rescue/mod.rs new file mode 100644 index 00000000..2fa942ac --- /dev/null +++ b/src/hash/rescue/mod.rs @@ -0,0 +1,401 @@ +use super::{ + CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO, +}; +use core::ops::Range; + +mod mds; +use mds::{apply_mds, MDS}; + +mod rpo; +pub use rpo::{Rpo256, RpoDigest}; + +mod rpx; +pub use rpx::{Rpx256, RpxDigest}; + +#[cfg(test)] +mod tests; + +// CONSTANTS +// ================================================================================================ + +/// The number of rounds is set to 7. For the RPO hash functions all rounds are uniform. For the +/// RPX hash function, there are 3 different types of rounds. +const NUM_ROUNDS: usize = 7; + +/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and +/// the remaining 4 elements are reserved for capacity. +const STATE_WIDTH: usize = 12; + +/// The rate portion of the state is located in elements 4 through 11. +const RATE_RANGE: Range = 4..12; +const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start; + +const INPUT1_RANGE: Range = 4..8; +const INPUT2_RANGE: Range = 8..12; + +/// The capacity portion of the state is located in elements 0, 1, 2, and 3. +const CAPACITY_RANGE: Range = 0..4; + +/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes. +/// +/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the +/// rate portion). +const DIGEST_RANGE: Range = 4..8; +const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start; + +/// The number of bytes needed to encoded a digest +const DIGEST_BYTES: usize = 32; + +/// The number of byte chunks defining a field element when hashing a sequence of bytes +const BINARY_CHUNK_SIZE: usize = 7; + +/// S-Box and Inverse S-Box powers; +/// +/// The constants are defined for tests only because the exponentiations in the code are unrolled +/// for efficiency reasons. +#[cfg(test)] +const ALPHA: u64 = 7; +#[cfg(test)] +const INV_ALPHA: u64 = 10540996611094048183; + +// SBOX FUNCTION +// ================================================================================================ + +#[inline(always)] +fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) { + state[0] = state[0].exp7(); + state[1] = state[1].exp7(); + state[2] = state[2].exp7(); + state[3] = state[3].exp7(); + state[4] = state[4].exp7(); + state[5] = state[5].exp7(); + state[6] = state[6].exp7(); + state[7] = state[7].exp7(); + state[8] = state[8].exp7(); + state[9] = state[9].exp7(); + state[10] = state[10].exp7(); + state[11] = state[11].exp7(); +} + +// INVERSE SBOX FUNCTION +// ================================================================================================ + +#[inline(always)] +fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) { + // compute base^10540996611094048183 using 72 multiplications per array element + // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111 + + // compute base^10 + let mut t1 = *state; + t1.iter_mut().for_each(|t| *t = t.square()); + + // compute base^100 + let mut t2 = t1; + t2.iter_mut().for_each(|t| *t = t.square()); + + // compute base^100100 + let t3 = exp_acc::(t2, t2); + + // compute base^100100100100 + let t4 = exp_acc::(t3, t3); + + // compute base^100100100100100100100100 + let t5 = exp_acc::(t4, t4); + + // compute base^100100100100100100100100100100 + let t6 = exp_acc::(t5, t3); + + // compute base^1001001001001001001001001001000100100100100100100100100100100 + let t7 = exp_acc::(t6, t6); + + // compute base^1001001001001001001001001001000110110110110110110110110110110111 + for (i, s) in state.iter_mut().enumerate() { + let a = (t7[i].square() * t6[i]).square().square(); + let b = t1[i] * t2[i] * *s; + *s = a * b; + } + + #[inline(always)] + fn exp_acc( + base: [B; N], + tail: [B; N], + ) -> [B; N] { + let mut result = base; + for _ in 0..M { + result.iter_mut().for_each(|r| *r = r.square()); + } + result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t); + result + } +} + +// OPTIMIZATIONS +// ================================================================================================ + +#[cfg(all(target_feature = "sve", feature = "sve"))] +#[link(name = "rpo_sve", kind = "static")] +extern "C" { + fn add_constants_and_apply_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; + fn add_constants_and_apply_inv_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; +} + +#[inline(always)] +#[cfg(all(target_feature = "sve", feature = "sve"))] +fn optimized_add_constants_and_apply_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], +) -> bool { + unsafe { + add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64) + } +} + +#[inline(always)] +#[cfg(not(all(target_feature = "sve", feature = "sve")))] +fn optimized_add_constants_and_apply_sbox( + _state: &mut [Felt; STATE_WIDTH], + _ark: &[Felt; STATE_WIDTH], +) -> bool { + false +} + +#[inline(always)] +#[cfg(all(target_feature = "sve", feature = "sve"))] +fn optimized_add_constants_and_apply_inv_sbox( + state: &mut [Felt; STATE_WIDTH], + ark: &[Felt; STATE_WIDTH], +) -> bool { + unsafe { + add_constants_and_apply_inv_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64) + } +} + +#[inline(always)] +#[cfg(not(all(target_feature = "sve", feature = "sve")))] +fn optimized_add_constants_and_apply_inv_sbox( + _state: &mut [Felt; STATE_WIDTH], + _ark: &[Felt; STATE_WIDTH], +) -> bool { + false +} + +#[inline(always)] +fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) { + state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k); +} + +// ROUND CONSTANTS +// ================================================================================================ + +/// Rescue round constants; +/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo) +/// +/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the +/// first half of RPO round, and ARK2 contains constants for the second half of RPO round. +const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [ + [ + Felt::new(5789762306288267392), + Felt::new(6522564764413701783), + Felt::new(17809893479458208203), + Felt::new(107145243989736508), + Felt::new(6388978042437517382), + Felt::new(15844067734406016715), + Felt::new(9975000513555218239), + Felt::new(3344984123768313364), + Felt::new(9959189626657347191), + Felt::new(12960773468763563665), + Felt::new(9602914297752488475), + Felt::new(16657542370200465908), + ], + [ + Felt::new(12987190162843096997), + Felt::new(653957632802705281), + Felt::new(4441654670647621225), + Felt::new(4038207883745915761), + Felt::new(5613464648874830118), + Felt::new(13222989726778338773), + Felt::new(3037761201230264149), + Felt::new(16683759727265180203), + Felt::new(8337364536491240715), + Felt::new(3227397518293416448), + Felt::new(8110510111539674682), + Felt::new(2872078294163232137), + ], + [ + Felt::new(18072785500942327487), + Felt::new(6200974112677013481), + Felt::new(17682092219085884187), + Felt::new(10599526828986756440), + Felt::new(975003873302957338), + Felt::new(8264241093196931281), + Felt::new(10065763900435475170), + Felt::new(2181131744534710197), + Felt::new(6317303992309418647), + Felt::new(1401440938888741532), + Felt::new(8884468225181997494), + Felt::new(13066900325715521532), + ], + [ + Felt::new(5674685213610121970), + Felt::new(5759084860419474071), + Felt::new(13943282657648897737), + Felt::new(1352748651966375394), + Felt::new(17110913224029905221), + Felt::new(1003883795902368422), + Felt::new(4141870621881018291), + Felt::new(8121410972417424656), + Felt::new(14300518605864919529), + Felt::new(13712227150607670181), + Felt::new(17021852944633065291), + Felt::new(6252096473787587650), + ], + [ + Felt::new(4887609836208846458), + Felt::new(3027115137917284492), + Felt::new(9595098600469470675), + Felt::new(10528569829048484079), + Felt::new(7864689113198939815), + Felt::new(17533723827845969040), + Felt::new(5781638039037710951), + Felt::new(17024078752430719006), + Felt::new(109659393484013511), + Felt::new(7158933660534805869), + Felt::new(2955076958026921730), + Felt::new(7433723648458773977), + ], + [ + Felt::new(16308865189192447297), + Felt::new(11977192855656444890), + Felt::new(12532242556065780287), + Felt::new(14594890931430968898), + Felt::new(7291784239689209784), + Felt::new(5514718540551361949), + Felt::new(10025733853830934803), + Felt::new(7293794580341021693), + Felt::new(6728552937464861756), + Felt::new(6332385040983343262), + Felt::new(13277683694236792804), + Felt::new(2600778905124452676), + ], + [ + Felt::new(7123075680859040534), + Felt::new(1034205548717903090), + Felt::new(7717824418247931797), + Felt::new(3019070937878604058), + Felt::new(11403792746066867460), + Felt::new(10280580802233112374), + Felt::new(337153209462421218), + Felt::new(13333398568519923717), + Felt::new(3596153696935337464), + Felt::new(8104208463525993784), + Felt::new(14345062289456085693), + Felt::new(17036731477169661256), + ], +]; + +const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [ + [ + Felt::new(6077062762357204287), + Felt::new(15277620170502011191), + Felt::new(5358738125714196705), + Felt::new(14233283787297595718), + Felt::new(13792579614346651365), + Felt::new(11614812331536767105), + Felt::new(14871063686742261166), + Felt::new(10148237148793043499), + Felt::new(4457428952329675767), + Felt::new(15590786458219172475), + Felt::new(10063319113072092615), + Felt::new(14200078843431360086), + ], + [ + Felt::new(6202948458916099932), + Felt::new(17690140365333231091), + Felt::new(3595001575307484651), + Felt::new(373995945117666487), + Felt::new(1235734395091296013), + Felt::new(14172757457833931602), + Felt::new(707573103686350224), + Felt::new(15453217512188187135), + Felt::new(219777875004506018), + Felt::new(17876696346199469008), + Felt::new(17731621626449383378), + Felt::new(2897136237748376248), + ], + [ + Felt::new(8023374565629191455), + Felt::new(15013690343205953430), + Felt::new(4485500052507912973), + Felt::new(12489737547229155153), + Felt::new(9500452585969030576), + Felt::new(2054001340201038870), + Felt::new(12420704059284934186), + Felt::new(355990932618543755), + Felt::new(9071225051243523860), + Felt::new(12766199826003448536), + Felt::new(9045979173463556963), + Felt::new(12934431667190679898), + ], + [ + Felt::new(18389244934624494276), + Felt::new(16731736864863925227), + Felt::new(4440209734760478192), + Felt::new(17208448209698888938), + Felt::new(8739495587021565984), + Felt::new(17000774922218161967), + Felt::new(13533282547195532087), + Felt::new(525402848358706231), + Felt::new(16987541523062161972), + Felt::new(5466806524462797102), + Felt::new(14512769585918244983), + Felt::new(10973956031244051118), + ], + [ + Felt::new(6982293561042362913), + Felt::new(14065426295947720331), + Felt::new(16451845770444974180), + Felt::new(7139138592091306727), + Felt::new(9012006439959783127), + Felt::new(14619614108529063361), + Felt::new(1394813199588124371), + Felt::new(4635111139507788575), + Felt::new(16217473952264203365), + Felt::new(10782018226466330683), + Felt::new(6844229992533662050), + Felt::new(7446486531695178711), + ], + [ + Felt::new(3736792340494631448), + Felt::new(577852220195055341), + Felt::new(6689998335515779805), + Felt::new(13886063479078013492), + Felt::new(14358505101923202168), + Felt::new(7744142531772274164), + Felt::new(16135070735728404443), + Felt::new(12290902521256031137), + Felt::new(12059913662657709804), + Felt::new(16456018495793751911), + Felt::new(4571485474751953524), + Felt::new(17200392109565783176), + ], + [ + Felt::new(17130398059294018733), + Felt::new(519782857322261988), + Felt::new(9625384390925085478), + Felt::new(1664893052631119222), + Felt::new(7629576092524553570), + Felt::new(3485239601103661425), + Felt::new(9755891797164033838), + Felt::new(15218148195153269027), + Felt::new(16460604813734957368), + Felt::new(9643968136937729763), + Felt::new(3611348709641382851), + Felt::new(18256379591337759196), + ], +]; diff --git a/src/hash/rpo/digest.rs b/src/hash/rescue/rpo/digest.rs similarity index 71% rename from src/hash/rpo/digest.rs rename to src/hash/rescue/rpo/digest.rs index 2a269d6d..a4cfa174 100644 --- a/src/hash/rpo/digest.rs +++ b/src/hash/rescue/rpo/digest.rs @@ -1,4 +1,4 @@ -use super::{Digest, Felt, StarkField, DIGEST_SIZE, ZERO}; +use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO}; use crate::utils::{ bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable, DeserializationError, HexParseError, Serializable, @@ -6,9 +6,6 @@ use crate::utils::{ use core::{cmp::Ordering, fmt::Display, ops::Deref}; use winter_utils::Randomizable; -/// The number of bytes needed to encoded a digest -pub const DIGEST_BYTES: usize = 32; - // DIGEST TRAIT IMPLEMENTATIONS // ================================================================================================ @@ -172,9 +169,21 @@ impl From<&RpoDigest> for String { } } -// CONVERSIONS: TO DIGEST +// CONVERSIONS: TO RPO DIGEST // ================================================================================================ +#[derive(Copy, Clone, Debug)] +pub enum RpoDigestError { + /// The provided u64 integer does not fit in the field's moduli. + InvalidInteger, +} + +impl From<&[Felt; DIGEST_SIZE]> for RpoDigest { + fn from(value: &[Felt; DIGEST_SIZE]) -> Self { + Self(*value) + } +} + impl From<[Felt; DIGEST_SIZE]> for RpoDigest { fn from(value: [Felt; DIGEST_SIZE]) -> Self { Self(value) @@ -200,6 +209,46 @@ impl TryFrom<[u8; DIGEST_BYTES]> for RpoDigest { } } +impl TryFrom<&[u8; DIGEST_BYTES]> for RpoDigest { + type Error = HexParseError; + + fn try_from(value: &[u8; DIGEST_BYTES]) -> Result { + (*value).try_into() + } +} + +impl TryFrom<&[u8]> for RpoDigest { + type Error = HexParseError; + + fn try_from(value: &[u8]) -> Result { + (*value).try_into() + } +} + +impl TryFrom<[u64; DIGEST_SIZE]> for RpoDigest { + type Error = RpoDigestError; + + fn try_from(value: [u64; DIGEST_SIZE]) -> Result { + if value[0] >= Felt::MODULUS + || value[1] >= Felt::MODULUS + || value[2] >= Felt::MODULUS + || value[3] >= Felt::MODULUS + { + return Err(RpoDigestError::InvalidInteger); + } + + Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])) + } +} + +impl TryFrom<&[u64; DIGEST_SIZE]> for RpoDigest { + type Error = RpoDigestError; + + fn try_from(value: &[u64; DIGEST_SIZE]) -> Result { + (*value).try_into() + } +} + impl TryFrom<&str> for RpoDigest { type Error = HexParseError; @@ -253,13 +302,24 @@ impl Deserializable for RpoDigest { } } +// ITERATORS +// ================================================================================================ +impl IntoIterator for RpoDigest { + type Item = Felt; + type IntoIter = <[Felt; 4] as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + // TESTS // ================================================================================================ #[cfg(test)] mod tests { - use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES}; - use crate::utils::SliceReader; + use super::{Deserializable, Felt, RpoDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE}; + use crate::utils::{string::String, SliceReader}; use rand_utils::rand_value; #[test] @@ -281,7 +341,6 @@ mod tests { assert_eq!(d1, d2); } - #[cfg(feature = "std")] #[test] fn digest_encoding() { let digest = RpoDigest([ @@ -296,4 +355,54 @@ mod tests { assert_eq!(digest, round_trip); } + + #[test] + fn test_conversions() { + let digest = RpoDigest([ + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + ]); + + let v: [Felt; DIGEST_SIZE] = digest.into(); + let v2: RpoDigest = v.into(); + assert_eq!(digest, v2); + + let v: [Felt; DIGEST_SIZE] = (&digest).into(); + let v2: RpoDigest = v.into(); + assert_eq!(digest, v2); + + let v: [u64; DIGEST_SIZE] = digest.into(); + let v2: RpoDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u64; DIGEST_SIZE] = (&digest).into(); + let v2: RpoDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = digest.into(); + let v2: RpoDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = (&digest).into(); + let v2: RpoDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: String = digest.into(); + let v2: RpoDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: String = (&digest).into(); + let v2: RpoDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = digest.into(); + let v2: RpoDigest = (&v).try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = (&digest).into(); + let v2: RpoDigest = (&v).try_into().unwrap(); + assert_eq!(digest, v2); + } } diff --git a/src/hash/rescue/rpo/mod.rs b/src/hash/rescue/rpo/mod.rs new file mode 100644 index 00000000..c28f87d1 --- /dev/null +++ b/src/hash/rescue/rpo/mod.rs @@ -0,0 +1,324 @@ +use super::{ + add_constants, apply_inv_sbox, apply_mds, apply_sbox, + optimized_add_constants_and_apply_inv_sbox, optimized_add_constants_and_apply_sbox, Digest, + ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1, ARK2, BINARY_CHUNK_SIZE, + CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE, INPUT2_RANGE, MDS, + NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO, +}; +use core::{convert::TryInto, ops::Range}; + +mod digest; +pub use digest::RpoDigest; + +#[cfg(test)] +mod tests; + +// HASHER IMPLEMENTATION +// ================================================================================================ + +/// Implementation of the Rescue Prime Optimized hash function with 256-bit output. +/// +/// The hash function is implemented according to the Rescue Prime Optimized +/// [specifications](https://eprint.iacr.org/2022/1577) +/// +/// The parameters used to instantiate the function are: +/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1. +/// * State width: 12 field elements. +/// * Capacity size: 4 field elements. +/// * Number of founds: 7. +/// * S-Box degree: 7. +/// +/// The above parameters target 128-bit security level. The digest consists of four field elements +/// and it can be serialized into 32 bytes (256 bits). +/// +/// ## Hash output consistency +/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and +/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing +/// a hash for the same set of elements using these functions will always produce the same +/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the +/// same result as hashing 8 elements which make up these digests using +/// [hash_elements()](Rpo256::hash_elements) function. +/// +/// However, [hash()](Rpo256::hash) function is not consistent with functions mentioned above. +/// For example, if we take two field elements, serialize them to bytes and hash them using +/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these +/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for +/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle +/// arbitrary binary strings, which may or may not encode valid field elements - and thus, +/// deserialization procedure used by this function is different from the procedure used to +/// deserialize valid field elements. +/// +/// Thus, if the underlying data consists of valid field elements, it might make more sense +/// to deserialize them into field elements and then hash them using +/// [hash_elements()](Rpo256::hash_elements) function rather then hashing the serialized bytes +/// using [hash()](Rpo256::hash) function. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct Rpo256(); + +impl Hasher for Rpo256 { + /// Rpo256 collision resistance is the same as the security level, that is 128-bits. + /// + /// #### Collision resistance + /// + /// However, our setup of the capacity registers might drop it to 126. + /// + /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69) + const COLLISION_RESISTANCE: u32 = 128; + + type Digest = RpoDigest; + + fn hash(bytes: &[u8]) -> Self::Digest { + // initialize the state with zeroes + let mut state = [ZERO; STATE_WIDTH]; + + // set the capacity (first element) to a flag on whether or not the input length is evenly + // divided by the rate. this will prevent collisions between padded and non-padded inputs, + // and will rule out the need to perform an extra permutation in case of evenly divided + // inputs. + let is_rate_multiple = bytes.len() % RATE_WIDTH == 0; + if !is_rate_multiple { + state[CAPACITY_RANGE.start] = ONE; + } + + // initialize a buffer to receive the little-endian elements. + let mut buf = [0_u8; 8]; + + // iterate the chunks of bytes, creating a field element from each chunk and copying it + // into the state. + // + // every time the rate range is filled, a permutation is performed. if the final value of + // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an + // additional permutation must be performed. + let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| { + // the last element of the iteration may or may not be a full chunk. if it's not, then + // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`. + // this will avoid collisions. + if chunk.len() == BINARY_CHUNK_SIZE { + buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk); + } else { + buf.fill(0); + buf[..chunk.len()].copy_from_slice(chunk); + buf[chunk.len()] = 1; + } + + // set the current rate element to the input. since we take at most 7 bytes, we are + // guaranteed that the inputs data will fit into a single field element. + state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf)); + + // proceed filling the range. if it's full, then we apply a permutation and reset the + // counter to the beginning of the range. + if i == RATE_WIDTH - 1 { + Self::apply_permutation(&mut state); + 0 + } else { + i + 1 + } + }); + + // if we absorbed some elements but didn't apply a permutation to them (would happen when + // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we + // don't need to apply any extra padding because the first capacity element containts a + // flag indicating whether the input is evenly divisible by the rate. + if i != 0 { + state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO); + state[RATE_RANGE.start + i] = ONE; + Self::apply_permutation(&mut state); + } + + // return the first 4 elements of the rate as hash result. + RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } + + fn merge(values: &[Self::Digest; 2]) -> Self::Digest { + // initialize the state by copying the digest elements into the rate portion of the state + // (8 total elements), and set the capacity elements to 0. + let mut state = [ZERO; STATE_WIDTH]; + let it = Self::Digest::digests_as_elements(values.iter()); + for (i, v) in it.enumerate() { + state[RATE_RANGE.start + i] = *v; + } + + // apply the RPO permutation and return the first four elements of the state + Self::apply_permutation(&mut state); + RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } + + fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest { + // initialize the state as follows: + // - seed is copied into the first 4 elements of the rate portion of the state. + // - if the value fits into a single field element, copy it into the fifth rate element + // and set the sixth rate element to 1. + // - if the value doesn't fit into a single field element, split it into two field + // elements, copy them into rate elements 5 and 6, and set the seventh rate element + // to 1. + // - set the first capacity element to 1 + let mut state = [ZERO; STATE_WIDTH]; + state[INPUT1_RANGE].copy_from_slice(seed.as_elements()); + state[INPUT2_RANGE.start] = Felt::new(value); + if value < Felt::MODULUS { + state[INPUT2_RANGE.start + 1] = ONE; + } else { + state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS); + state[INPUT2_RANGE.start + 2] = ONE; + } + + // common padding for both cases + state[CAPACITY_RANGE.start] = ONE; + + // apply the RPO permutation and return the first four elements of the state + Self::apply_permutation(&mut state); + RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } +} + +impl ElementHasher for Rpo256 { + type BaseField = Felt; + + fn hash_elements>(elements: &[E]) -> Self::Digest { + // convert the elements into a list of base field elements + let elements = E::slice_as_base_elements(elements); + + // initialize state to all zeros, except for the first element of the capacity part, which + // is set to 1 if the number of elements is not a multiple of RATE_WIDTH. + let mut state = [ZERO; STATE_WIDTH]; + if elements.len() % RATE_WIDTH != 0 { + state[CAPACITY_RANGE.start] = ONE; + } + + // absorb elements into the state one by one until the rate portion of the state is filled + // up; then apply the Rescue permutation and start absorbing again; repeat until all + // elements have been absorbed + let mut i = 0; + for &element in elements.iter() { + state[RATE_RANGE.start + i] = element; + i += 1; + if i % RATE_WIDTH == 0 { + Self::apply_permutation(&mut state); + i = 0; + } + } + + // if we absorbed some elements but didn't apply a permutation to them (would happen when + // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after + // padding by appending a 1 followed by as many 0 as necessary to make the input length a + // multiple of the RATE_WIDTH. + if i > 0 { + state[RATE_RANGE.start + i] = ONE; + i += 1; + while i != RATE_WIDTH { + state[RATE_RANGE.start + i] = ZERO; + i += 1; + } + Self::apply_permutation(&mut state); + } + + // return the first 4 elements of the state as hash result + RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } +} + +// HASH FUNCTION IMPLEMENTATION +// ================================================================================================ + +impl Rpo256 { + // CONSTANTS + // -------------------------------------------------------------------------------------------- + + /// The number of rounds is set to 7 to target 128-bit security level. + pub const NUM_ROUNDS: usize = NUM_ROUNDS; + + /// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and + /// the remaining 4 elements are reserved for capacity. + pub const STATE_WIDTH: usize = STATE_WIDTH; + + /// The rate portion of the state is located in elements 4 through 11 (inclusive). + pub const RATE_RANGE: Range = RATE_RANGE; + + /// The capacity portion of the state is located in elements 0, 1, 2, and 3. + pub const CAPACITY_RANGE: Range = CAPACITY_RANGE; + + /// The output of the hash function can be read from state elements 4, 5, 6, and 7. + pub const DIGEST_RANGE: Range = DIGEST_RANGE; + + /// MDS matrix used for computing the linear layer in a RPO round. + pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS; + + /// Round constants added to the hasher state in the first half of the RPO round. + pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1; + + /// Round constants added to the hasher state in the second half of the RPO round. + pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2; + + // TRAIT PASS-THROUGH FUNCTIONS + // -------------------------------------------------------------------------------------------- + + /// Returns a hash of the provided sequence of bytes. + #[inline(always)] + pub fn hash(bytes: &[u8]) -> RpoDigest { + ::hash(bytes) + } + + /// Returns a hash of two digests. This method is intended for use in construction of + /// Merkle trees and verification of Merkle paths. + #[inline(always)] + pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest { + ::merge(values) + } + + /// Returns a hash of the provided field elements. + #[inline(always)] + pub fn hash_elements>(elements: &[E]) -> RpoDigest { + ::hash_elements(elements) + } + + // DOMAIN IDENTIFIER + // -------------------------------------------------------------------------------------------- + + /// Returns a hash of two digests and a domain identifier. + pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest { + // initialize the state by copying the digest elements into the rate portion of the state + // (8 total elements), and set the capacity elements to 0. + let mut state = [ZERO; STATE_WIDTH]; + let it = RpoDigest::digests_as_elements(values.iter()); + for (i, v) in it.enumerate() { + state[RATE_RANGE.start + i] = *v; + } + + // set the second capacity element to the domain value. The first capacity element is used + // for padding purposes. + state[CAPACITY_RANGE.start + 1] = domain; + + // apply the RPO permutation and return the first four elements of the state + Self::apply_permutation(&mut state); + RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } + + // RESCUE PERMUTATION + // -------------------------------------------------------------------------------------------- + + /// Applies RPO permutation to the provided state. + #[inline(always)] + pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) { + for i in 0..NUM_ROUNDS { + Self::apply_round(state, i); + } + } + + /// RPO round function. + #[inline(always)] + pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) { + // apply first half of RPO round + apply_mds(state); + if !optimized_add_constants_and_apply_sbox(state, &ARK1[round]) { + add_constants(state, &ARK1[round]); + apply_sbox(state); + } + + // apply second half of RPO round + apply_mds(state); + if !optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) { + add_constants(state, &ARK2[round]); + apply_inv_sbox(state); + } + } +} diff --git a/src/hash/rpo/tests.rs b/src/hash/rescue/rpo/tests.rs similarity index 96% rename from src/hash/rpo/tests.rs rename to src/hash/rescue/rpo/tests.rs index 3ca5a33c..3dbdf4d9 100644 --- a/src/hash/rpo/tests.rs +++ b/src/hash/rescue/rpo/tests.rs @@ -1,6 +1,6 @@ use super::{ - Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ALPHA, INV_ALPHA, ONE, STATE_WIDTH, - ZERO, + super::{apply_inv_sbox, apply_sbox, ALPHA, INV_ALPHA}, + Felt, FieldElement, Hasher, Rpo256, RpoDigest, StarkField, ONE, STATE_WIDTH, ZERO, }; use crate::{ utils::collections::{BTreeSet, Vec}, @@ -10,13 +10,6 @@ use core::convert::TryInto; use proptest::prelude::*; use rand_utils::rand_value; -#[test] -fn test_alphas() { - let e: Felt = Felt::new(rand_value()); - let e_exp = e.exp(ALPHA); - assert_eq!(e, e_exp.exp(INV_ALPHA)); -} - #[test] fn test_sbox() { let state = [Felt::new(rand_value()); STATE_WIDTH]; @@ -25,7 +18,7 @@ fn test_sbox() { expected.iter_mut().for_each(|v| *v = v.exp(ALPHA)); let mut actual = state; - Rpo256::apply_sbox(&mut actual); + apply_sbox(&mut actual); assert_eq!(expected, actual); } @@ -38,7 +31,7 @@ fn test_inv_sbox() { expected.iter_mut().for_each(|v| *v = v.exp(INV_ALPHA)); let mut actual = state; - Rpo256::apply_inv_sbox(&mut actual); + apply_inv_sbox(&mut actual); assert_eq!(expected, actual); } diff --git a/src/hash/rescue/rpx/digest.rs b/src/hash/rescue/rpx/digest.rs new file mode 100644 index 00000000..a9a236a7 --- /dev/null +++ b/src/hash/rescue/rpx/digest.rs @@ -0,0 +1,398 @@ +use super::{Digest, Felt, StarkField, DIGEST_BYTES, DIGEST_SIZE, ZERO}; +use crate::utils::{ + bytes_to_hex_string, hex_to_bytes, string::String, ByteReader, ByteWriter, Deserializable, + DeserializationError, HexParseError, Serializable, +}; +use core::{cmp::Ordering, fmt::Display, ops::Deref}; +use winter_utils::Randomizable; + +// DIGEST TRAIT IMPLEMENTATIONS +// ================================================================================================ + +#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr(feature = "serde", serde(into = "String", try_from = "&str"))] +pub struct RpxDigest([Felt; DIGEST_SIZE]); + +impl RpxDigest { + pub const fn new(value: [Felt; DIGEST_SIZE]) -> Self { + Self(value) + } + + pub fn as_elements(&self) -> &[Felt] { + self.as_ref() + } + + pub fn as_bytes(&self) -> [u8; DIGEST_BYTES] { + ::as_bytes(self) + } + + pub fn digests_as_elements<'a, I>(digests: I) -> impl Iterator + where + I: Iterator, + { + digests.flat_map(|d| d.0.iter()) + } +} + +impl Digest for RpxDigest { + fn as_bytes(&self) -> [u8; DIGEST_BYTES] { + let mut result = [0; DIGEST_BYTES]; + + result[..8].copy_from_slice(&self.0[0].as_int().to_le_bytes()); + result[8..16].copy_from_slice(&self.0[1].as_int().to_le_bytes()); + result[16..24].copy_from_slice(&self.0[2].as_int().to_le_bytes()); + result[24..].copy_from_slice(&self.0[3].as_int().to_le_bytes()); + + result + } +} + +impl Deref for RpxDigest { + type Target = [Felt; DIGEST_SIZE]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Ord for RpxDigest { + fn cmp(&self, other: &Self) -> Ordering { + // compare the inner u64 of both elements. + // + // it will iterate the elements and will return the first computation different than + // `Equal`. Otherwise, the ordering is equal. + // + // the endianness is irrelevant here because since, this being a cryptographically secure + // hash computation, the digest shouldn't have any ordered property of its input. + // + // finally, we use `Felt::inner` instead of `Felt::as_int` so we avoid performing a + // montgomery reduction for every limb. that is safe because every inner element of the + // digest is guaranteed to be in its canonical form (that is, `x in [0,p)`). + self.0.iter().map(Felt::inner).zip(other.0.iter().map(Felt::inner)).fold( + Ordering::Equal, + |ord, (a, b)| match ord { + Ordering::Equal => a.cmp(&b), + _ => ord, + }, + ) + } +} + +impl PartialOrd for RpxDigest { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Display for RpxDigest { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let encoded: String = self.into(); + write!(f, "{}", encoded)?; + Ok(()) + } +} + +impl Randomizable for RpxDigest { + const VALUE_SIZE: usize = DIGEST_BYTES; + + fn from_random_bytes(bytes: &[u8]) -> Option { + let bytes_array: Option<[u8; 32]> = bytes.try_into().ok(); + if let Some(bytes_array) = bytes_array { + Self::try_from(bytes_array).ok() + } else { + None + } + } +} + +// CONVERSIONS: FROM RPX DIGEST +// ================================================================================================ + +impl From<&RpxDigest> for [Felt; DIGEST_SIZE] { + fn from(value: &RpxDigest) -> Self { + value.0 + } +} + +impl From for [Felt; DIGEST_SIZE] { + fn from(value: RpxDigest) -> Self { + value.0 + } +} + +impl From<&RpxDigest> for [u64; DIGEST_SIZE] { + fn from(value: &RpxDigest) -> Self { + [ + value.0[0].as_int(), + value.0[1].as_int(), + value.0[2].as_int(), + value.0[3].as_int(), + ] + } +} + +impl From for [u64; DIGEST_SIZE] { + fn from(value: RpxDigest) -> Self { + [ + value.0[0].as_int(), + value.0[1].as_int(), + value.0[2].as_int(), + value.0[3].as_int(), + ] + } +} + +impl From<&RpxDigest> for [u8; DIGEST_BYTES] { + fn from(value: &RpxDigest) -> Self { + value.as_bytes() + } +} + +impl From for [u8; DIGEST_BYTES] { + fn from(value: RpxDigest) -> Self { + value.as_bytes() + } +} + +impl From for String { + /// The returned string starts with `0x`. + fn from(value: RpxDigest) -> Self { + bytes_to_hex_string(value.as_bytes()) + } +} + +impl From<&RpxDigest> for String { + /// The returned string starts with `0x`. + fn from(value: &RpxDigest) -> Self { + (*value).into() + } +} + +// CONVERSIONS: TO RPX DIGEST +// ================================================================================================ + +#[derive(Copy, Clone, Debug)] +pub enum RpxDigestError { + /// The provided u64 integer does not fit in the field's moduli. + InvalidInteger, +} + +impl From<&[Felt; DIGEST_SIZE]> for RpxDigest { + fn from(value: &[Felt; DIGEST_SIZE]) -> Self { + Self(*value) + } +} + +impl From<[Felt; DIGEST_SIZE]> for RpxDigest { + fn from(value: [Felt; DIGEST_SIZE]) -> Self { + Self(value) + } +} + +impl TryFrom<[u8; DIGEST_BYTES]> for RpxDigest { + type Error = HexParseError; + + fn try_from(value: [u8; DIGEST_BYTES]) -> Result { + // Note: the input length is known, the conversion from slice to array must succeed so the + // `unwrap`s below are safe + let a = u64::from_le_bytes(value[0..8].try_into().unwrap()); + let b = u64::from_le_bytes(value[8..16].try_into().unwrap()); + let c = u64::from_le_bytes(value[16..24].try_into().unwrap()); + let d = u64::from_le_bytes(value[24..32].try_into().unwrap()); + + if [a, b, c, d].iter().any(|v| *v >= Felt::MODULUS) { + return Err(HexParseError::OutOfRange); + } + + Ok(RpxDigest([Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)])) + } +} + +impl TryFrom<&[u8; DIGEST_BYTES]> for RpxDigest { + type Error = HexParseError; + + fn try_from(value: &[u8; DIGEST_BYTES]) -> Result { + (*value).try_into() + } +} + +impl TryFrom<&[u8]> for RpxDigest { + type Error = HexParseError; + + fn try_from(value: &[u8]) -> Result { + (*value).try_into() + } +} + +impl TryFrom<[u64; DIGEST_SIZE]> for RpxDigest { + type Error = RpxDigestError; + + fn try_from(value: [u64; DIGEST_SIZE]) -> Result { + if value[0] >= Felt::MODULUS + || value[1] >= Felt::MODULUS + || value[2] >= Felt::MODULUS + || value[3] >= Felt::MODULUS + { + return Err(RpxDigestError::InvalidInteger); + } + + Ok(Self([value[0].into(), value[1].into(), value[2].into(), value[3].into()])) + } +} + +impl TryFrom<&[u64; DIGEST_SIZE]> for RpxDigest { + type Error = RpxDigestError; + + fn try_from(value: &[u64; DIGEST_SIZE]) -> Result { + (*value).try_into() + } +} + +impl TryFrom<&str> for RpxDigest { + type Error = HexParseError; + + /// Expects the string to start with `0x`. + fn try_from(value: &str) -> Result { + hex_to_bytes(value).and_then(|v| v.try_into()) + } +} + +impl TryFrom for RpxDigest { + type Error = HexParseError; + + /// Expects the string to start with `0x`. + fn try_from(value: String) -> Result { + value.as_str().try_into() + } +} + +impl TryFrom<&String> for RpxDigest { + type Error = HexParseError; + + /// Expects the string to start with `0x`. + fn try_from(value: &String) -> Result { + value.as_str().try_into() + } +} + +// SERIALIZATION / DESERIALIZATION +// ================================================================================================ + +impl Serializable for RpxDigest { + fn write_into(&self, target: &mut W) { + target.write_bytes(&self.as_bytes()); + } +} + +impl Deserializable for RpxDigest { + fn read_from(source: &mut R) -> Result { + let mut inner: [Felt; DIGEST_SIZE] = [ZERO; DIGEST_SIZE]; + for inner in inner.iter_mut() { + let e = source.read_u64()?; + if e >= Felt::MODULUS { + return Err(DeserializationError::InvalidValue(String::from( + "Value not in the appropriate range", + ))); + } + *inner = Felt::new(e); + } + + Ok(Self(inner)) + } +} + +// TESTS +// ================================================================================================ + +#[cfg(test)] +mod tests { + use super::{Deserializable, Felt, RpxDigest, Serializable, DIGEST_BYTES, DIGEST_SIZE}; + use crate::utils::{string::String, SliceReader}; + use rand_utils::rand_value; + + #[test] + fn digest_serialization() { + let e1 = Felt::new(rand_value()); + let e2 = Felt::new(rand_value()); + let e3 = Felt::new(rand_value()); + let e4 = Felt::new(rand_value()); + + let d1 = RpxDigest([e1, e2, e3, e4]); + + let mut bytes = vec![]; + d1.write_into(&mut bytes); + assert_eq!(DIGEST_BYTES, bytes.len()); + + let mut reader = SliceReader::new(&bytes); + let d2 = RpxDigest::read_from(&mut reader).unwrap(); + + assert_eq!(d1, d2); + } + + #[cfg(feature = "std")] + #[test] + fn digest_encoding() { + let digest = RpxDigest([ + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + ]); + + let string: String = digest.into(); + let round_trip: RpxDigest = string.try_into().expect("decoding failed"); + + assert_eq!(digest, round_trip); + } + + #[test] + fn test_conversions() { + let digest = RpxDigest([ + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + Felt::new(rand_value()), + ]); + + let v: [Felt; DIGEST_SIZE] = digest.into(); + let v2: RpxDigest = v.into(); + assert_eq!(digest, v2); + + let v: [Felt; DIGEST_SIZE] = (&digest).into(); + let v2: RpxDigest = v.into(); + assert_eq!(digest, v2); + + let v: [u64; DIGEST_SIZE] = digest.into(); + let v2: RpxDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u64; DIGEST_SIZE] = (&digest).into(); + let v2: RpxDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = digest.into(); + let v2: RpxDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = (&digest).into(); + let v2: RpxDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: String = digest.into(); + let v2: RpxDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: String = (&digest).into(); + let v2: RpxDigest = v.try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = digest.into(); + let v2: RpxDigest = (&v).try_into().unwrap(); + assert_eq!(digest, v2); + + let v: [u8; DIGEST_BYTES] = (&digest).into(); + let v2: RpxDigest = (&v).try_into().unwrap(); + assert_eq!(digest, v2); + } +} diff --git a/src/hash/rescue/rpx/mod.rs b/src/hash/rescue/rpx/mod.rs new file mode 100644 index 00000000..541310e0 --- /dev/null +++ b/src/hash/rescue/rpx/mod.rs @@ -0,0 +1,379 @@ +use super::{ + add_constants, apply_inv_sbox, apply_mds, apply_sbox, + optimized_add_constants_and_apply_inv_sbox, optimized_add_constants_and_apply_sbox, + CubeExtension, Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ARK1, ARK2, + BINARY_CHUNK_SIZE, CAPACITY_RANGE, DIGEST_BYTES, DIGEST_RANGE, DIGEST_SIZE, INPUT1_RANGE, + INPUT2_RANGE, MDS, NUM_ROUNDS, ONE, RATE_RANGE, RATE_WIDTH, STATE_WIDTH, ZERO, +}; +use core::{convert::TryInto, ops::Range}; + +mod digest; +pub use digest::RpxDigest; + +#[cfg(all(target_feature = "sve", feature = "sve"))] +#[link(name = "rpo_sve", kind = "static")] +extern "C" { + fn add_constants_and_apply_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; + fn add_constants_and_apply_inv_sbox( + state: *mut std::ffi::c_ulong, + constants: *const std::ffi::c_ulong, + ) -> bool; +} + +pub type CubicExtElement = CubeExtension; + +// HASHER IMPLEMENTATION +// ================================================================================================ + +/// Implementation of the Rescue Prime eXtension hash function with 256-bit output. +/// +/// The hash function is based on the XHash12 construction in [specifications](https://eprint.iacr.org/2023/1045) +/// +/// The parameters used to instantiate the function are: +/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1. +/// * State width: 12 field elements. +/// * Capacity size: 4 field elements. +/// * S-Box degree: 7. +/// * Rounds: There are 3 different types of rounds: +/// - (FB): `apply_mds` → `add_constants` → `apply_sbox` → `apply_mds` → `add_constants` → `apply_inv_sbox`. +/// - (E): `add_constants` → `ext_sbox` (which is raising to power 7 in the degree 3 extension field). +/// - (M): `apply_mds` → `add_constants`. +/// * Permutation: (FB) (E) (FB) (E) (FB) (E) (M). +/// +/// The above parameters target 128-bit security level. The digest consists of four field elements +/// and it can be serialized into 32 bytes (256 bits). +/// +/// ## Hash output consistency +/// Functions [hash_elements()](Rpx256::hash_elements), [merge()](Rpx256::merge), and +/// [merge_with_int()](Rpx256::merge_with_int) are internally consistent. That is, computing +/// a hash for the same set of elements using these functions will always produce the same +/// result. For example, merging two digests using [merge()](Rpx256::merge) will produce the +/// same result as hashing 8 elements which make up these digests using +/// [hash_elements()](Rpx256::hash_elements) function. +/// +/// However, [hash()](Rpx256::hash) function is not consistent with functions mentioned above. +/// For example, if we take two field elements, serialize them to bytes and hash them using +/// [hash()](Rpx256::hash), the result will differ from the result obtained by hashing these +/// elements directly using [hash_elements()](Rpx256::hash_elements) function. The reason for +/// this difference is that [hash()](Rpx256::hash) function needs to be able to handle +/// arbitrary binary strings, which may or may not encode valid field elements - and thus, +/// deserialization procedure used by this function is different from the procedure used to +/// deserialize valid field elements. +/// +/// Thus, if the underlying data consists of valid field elements, it might make more sense +/// to deserialize them into field elements and then hash them using +/// [hash_elements()](Rpx256::hash_elements) function rather then hashing the serialized bytes +/// using [hash()](Rpx256::hash) function. +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct Rpx256(); + +impl Hasher for Rpx256 { + /// Rpx256 collision resistance is the same as the security level, that is 128-bits. + /// + /// #### Collision resistance + /// + /// However, our setup of the capacity registers might drop it to 126. + /// + /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69) + const COLLISION_RESISTANCE: u32 = 128; + + type Digest = RpxDigest; + + fn hash(bytes: &[u8]) -> Self::Digest { + // initialize the state with zeroes + let mut state = [ZERO; STATE_WIDTH]; + + // set the capacity (first element) to a flag on whether or not the input length is evenly + // divided by the rate. this will prevent collisions between padded and non-padded inputs, + // and will rule out the need to perform an extra permutation in case of evenly divided + // inputs. + let is_rate_multiple = bytes.len() % RATE_WIDTH == 0; + if !is_rate_multiple { + state[CAPACITY_RANGE.start] = ONE; + } + + // initialize a buffer to receive the little-endian elements. + let mut buf = [0_u8; 8]; + + // iterate the chunks of bytes, creating a field element from each chunk and copying it + // into the state. + // + // every time the rate range is filled, a permutation is performed. if the final value of + // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an + // additional permutation must be performed. + let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| { + // the last element of the iteration may or may not be a full chunk. if it's not, then + // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`. + // this will avoid collisions. + if chunk.len() == BINARY_CHUNK_SIZE { + buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk); + } else { + buf.fill(0); + buf[..chunk.len()].copy_from_slice(chunk); + buf[chunk.len()] = 1; + } + + // set the current rate element to the input. since we take at most 7 bytes, we are + // guaranteed that the inputs data will fit into a single field element. + state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf)); + + // proceed filling the range. if it's full, then we apply a permutation and reset the + // counter to the beginning of the range. + if i == RATE_WIDTH - 1 { + Self::apply_permutation(&mut state); + 0 + } else { + i + 1 + } + }); + + // if we absorbed some elements but didn't apply a permutation to them (would happen when + // the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation. we + // don't need to apply any extra padding because the first capacity element containts a + // flag indicating whether the input is evenly divisible by the rate. + if i != 0 { + state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO); + state[RATE_RANGE.start + i] = ONE; + Self::apply_permutation(&mut state); + } + + // return the first 4 elements of the rate as hash result. + RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } + + fn merge(values: &[Self::Digest; 2]) -> Self::Digest { + // initialize the state by copying the digest elements into the rate portion of the state + // (8 total elements), and set the capacity elements to 0. + let mut state = [ZERO; STATE_WIDTH]; + let it = Self::Digest::digests_as_elements(values.iter()); + for (i, v) in it.enumerate() { + state[RATE_RANGE.start + i] = *v; + } + + // apply the RPX permutation and return the first four elements of the state + Self::apply_permutation(&mut state); + RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } + + fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest { + // initialize the state as follows: + // - seed is copied into the first 4 elements of the rate portion of the state. + // - if the value fits into a single field element, copy it into the fifth rate element + // and set the sixth rate element to 1. + // - if the value doesn't fit into a single field element, split it into two field + // elements, copy them into rate elements 5 and 6, and set the seventh rate element + // to 1. + // - set the first capacity element to 1 + let mut state = [ZERO; STATE_WIDTH]; + state[INPUT1_RANGE].copy_from_slice(seed.as_elements()); + state[INPUT2_RANGE.start] = Felt::new(value); + if value < Felt::MODULUS { + state[INPUT2_RANGE.start + 1] = ONE; + } else { + state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS); + state[INPUT2_RANGE.start + 2] = ONE; + } + + // common padding for both cases + state[CAPACITY_RANGE.start] = ONE; + + // apply the RPX permutation and return the first four elements of the state + Self::apply_permutation(&mut state); + RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } +} + +impl ElementHasher for Rpx256 { + type BaseField = Felt; + + fn hash_elements>(elements: &[E]) -> Self::Digest { + // convert the elements into a list of base field elements + let elements = E::slice_as_base_elements(elements); + + // initialize state to all zeros, except for the first element of the capacity part, which + // is set to 1 if the number of elements is not a multiple of RATE_WIDTH. + let mut state = [ZERO; STATE_WIDTH]; + if elements.len() % RATE_WIDTH != 0 { + state[CAPACITY_RANGE.start] = ONE; + } + + // absorb elements into the state one by one until the rate portion of the state is filled + // up; then apply the Rescue permutation and start absorbing again; repeat until all + // elements have been absorbed + let mut i = 0; + for &element in elements.iter() { + state[RATE_RANGE.start + i] = element; + i += 1; + if i % RATE_WIDTH == 0 { + Self::apply_permutation(&mut state); + i = 0; + } + } + + // if we absorbed some elements but didn't apply a permutation to them (would happen when + // the number of elements is not a multiple of RATE_WIDTH), apply the RPX permutation after + // padding by appending a 1 followed by as many 0 as necessary to make the input length a + // multiple of the RATE_WIDTH. + if i > 0 { + state[RATE_RANGE.start + i] = ONE; + i += 1; + while i != RATE_WIDTH { + state[RATE_RANGE.start + i] = ZERO; + i += 1; + } + Self::apply_permutation(&mut state); + } + + // return the first 4 elements of the state as hash result + RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } +} + +// HASH FUNCTION IMPLEMENTATION +// ================================================================================================ + +impl Rpx256 { + // CONSTANTS + // -------------------------------------------------------------------------------------------- + + /// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and + /// the remaining 4 elements are reserved for capacity. + pub const STATE_WIDTH: usize = STATE_WIDTH; + + /// The rate portion of the state is located in elements 4 through 11 (inclusive). + pub const RATE_RANGE: Range = RATE_RANGE; + + /// The capacity portion of the state is located in elements 0, 1, 2, and 3. + pub const CAPACITY_RANGE: Range = CAPACITY_RANGE; + + /// The output of the hash function can be read from state elements 4, 5, 6, and 7. + pub const DIGEST_RANGE: Range = DIGEST_RANGE; + + /// MDS matrix used for computing the linear layer in the (FB) and (E) rounds. + pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS; + + /// Round constants added to the hasher state in the first half of the round. + pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1; + + /// Round constants added to the hasher state in the second half of the round. + pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2; + + // TRAIT PASS-THROUGH FUNCTIONS + // -------------------------------------------------------------------------------------------- + + /// Returns a hash of the provided sequence of bytes. + #[inline(always)] + pub fn hash(bytes: &[u8]) -> RpxDigest { + ::hash(bytes) + } + + /// Returns a hash of two digests. This method is intended for use in construction of + /// Merkle trees and verification of Merkle paths. + #[inline(always)] + pub fn merge(values: &[RpxDigest; 2]) -> RpxDigest { + ::merge(values) + } + + /// Returns a hash of the provided field elements. + #[inline(always)] + pub fn hash_elements>(elements: &[E]) -> RpxDigest { + ::hash_elements(elements) + } + + // DOMAIN IDENTIFIER + // -------------------------------------------------------------------------------------------- + + /// Returns a hash of two digests and a domain identifier. + pub fn merge_in_domain(values: &[RpxDigest; 2], domain: Felt) -> RpxDigest { + // initialize the state by copying the digest elements into the rate portion of the state + // (8 total elements), and set the capacity elements to 0. + let mut state = [ZERO; STATE_WIDTH]; + let it = RpxDigest::digests_as_elements(values.iter()); + for (i, v) in it.enumerate() { + state[RATE_RANGE.start + i] = *v; + } + + // set the second capacity element to the domain value. The first capacity element is used + // for padding purposes. + state[CAPACITY_RANGE.start + 1] = domain; + + // apply the RPX permutation and return the first four elements of the state + Self::apply_permutation(&mut state); + RpxDigest::new(state[DIGEST_RANGE].try_into().unwrap()) + } + + // RPX PERMUTATION + // -------------------------------------------------------------------------------------------- + + /// Applies RPX permutation to the provided state. + #[inline(always)] + pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) { + Self::apply_fb_round(state, 0); + Self::apply_ext_round(state, 1); + Self::apply_fb_round(state, 2); + Self::apply_ext_round(state, 3); + Self::apply_fb_round(state, 4); + Self::apply_ext_round(state, 5); + Self::apply_final_round(state, 6); + } + + // RPX PERMUTATION ROUND FUNCTIONS + // -------------------------------------------------------------------------------------------- + + /// (FB) round function. + #[inline(always)] + pub fn apply_fb_round(state: &mut [Felt; STATE_WIDTH], round: usize) { + apply_mds(state); + if !optimized_add_constants_and_apply_sbox(state, &ARK1[round]) { + add_constants(state, &ARK1[round]); + apply_sbox(state); + } + + apply_mds(state); + if !optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) { + add_constants(state, &ARK2[round]); + apply_inv_sbox(state); + } + } + + /// (E) round function. + #[inline(always)] + pub fn apply_ext_round(state: &mut [Felt; STATE_WIDTH], round: usize) { + // add constants + add_constants(state, &ARK1[round]); + + // decompose the state into 4 elements in the cubic extension field and apply the power 7 + // map to each of the elements + let [s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] = *state; + let ext0 = Self::exp7(CubicExtElement::new(s0, s1, s2)); + let ext1 = Self::exp7(CubicExtElement::new(s3, s4, s5)); + let ext2 = Self::exp7(CubicExtElement::new(s6, s7, s8)); + let ext3 = Self::exp7(CubicExtElement::new(s9, s10, s11)); + + // decompose the state back into 12 base field elements + let arr_ext = [ext0, ext1, ext2, ext3]; + *state = CubicExtElement::slice_as_base_elements(&arr_ext) + .try_into() + .expect("shouldn't fail"); + } + + /// (M) round function. + #[inline(always)] + pub fn apply_final_round(state: &mut [Felt; STATE_WIDTH], round: usize) { + apply_mds(state); + add_constants(state, &ARK1[round]); + } + + /// Computes an exponentiation to the power 7 in cubic extension field + #[inline(always)] + pub fn exp7(x: CubeExtension) -> CubeExtension { + let x2 = x.square(); + let x4 = x2.square(); + + let x3 = x2 * x; + x3 * x4 + } +} diff --git a/src/hash/rescue/tests.rs b/src/hash/rescue/tests.rs new file mode 100644 index 00000000..f0669e96 --- /dev/null +++ b/src/hash/rescue/tests.rs @@ -0,0 +1,9 @@ +use super::{Felt, FieldElement, ALPHA, INV_ALPHA}; +use rand_utils::rand_value; + +#[test] +fn test_alphas() { + let e: Felt = Felt::new(rand_value()); + let e_exp = e.exp(ALPHA); + assert_eq!(e, e_exp.exp(INV_ALPHA)); +} diff --git a/src/hash/rpo/mod.rs b/src/hash/rpo/mod.rs deleted file mode 100644 index fafce892..00000000 --- a/src/hash/rpo/mod.rs +++ /dev/null @@ -1,905 +0,0 @@ -use super::{Digest, ElementHasher, Felt, FieldElement, Hasher, StarkField, ONE, ZERO}; -use core::{convert::TryInto, ops::Range}; - -mod digest; -pub use digest::RpoDigest; - -mod mds_freq; -use mds_freq::mds_multiply_freq; - -#[cfg(test)] -mod tests; - -#[cfg(all(target_feature = "sve", feature = "sve"))] -#[link(name = "rpo_sve", kind = "static")] -extern "C" { - fn add_constants_and_apply_sbox( - state: *mut std::ffi::c_ulong, - constants: *const std::ffi::c_ulong, - ) -> bool; - fn add_constants_and_apply_inv_sbox( - state: *mut std::ffi::c_ulong, - constants: *const std::ffi::c_ulong, - ) -> bool; -} - -// CONSTANTS -// ================================================================================================ - -/// Sponge state is set to 12 field elements or 96 bytes; 8 elements are reserved for rate and -/// the remaining 4 elements are reserved for capacity. -const STATE_WIDTH: usize = 12; - -/// The rate portion of the state is located in elements 4 through 11. -const RATE_RANGE: Range = 4..12; -const RATE_WIDTH: usize = RATE_RANGE.end - RATE_RANGE.start; - -const INPUT1_RANGE: Range = 4..8; -const INPUT2_RANGE: Range = 8..12; - -/// The capacity portion of the state is located in elements 0, 1, 2, and 3. -const CAPACITY_RANGE: Range = 0..4; - -/// The output of the hash function is a digest which consists of 4 field elements or 32 bytes. -/// -/// The digest is returned from state elements 4, 5, 6, and 7 (the first four elements of the -/// rate portion). -const DIGEST_RANGE: Range = 4..8; -const DIGEST_SIZE: usize = DIGEST_RANGE.end - DIGEST_RANGE.start; - -/// The number of rounds is set to 7 to target 128-bit security level -const NUM_ROUNDS: usize = 7; - -/// The number of byte chunks defining a field element when hashing a sequence of bytes -const BINARY_CHUNK_SIZE: usize = 7; - -/// S-Box and Inverse S-Box powers; -/// -/// The constants are defined for tests only because the exponentiations in the code are unrolled -/// for efficiency reasons. -#[cfg(test)] -const ALPHA: u64 = 7; -#[cfg(test)] -const INV_ALPHA: u64 = 10540996611094048183; - -// HASHER IMPLEMENTATION -// ================================================================================================ - -/// Implementation of the Rescue Prime Optimized hash function with 256-bit output. -/// -/// The hash function is implemented according to the Rescue Prime Optimized -/// [specifications](https://eprint.iacr.org/2022/1577) -/// -/// The parameters used to instantiate the function are: -/// * Field: 64-bit prime field with modulus 2^64 - 2^32 + 1. -/// * State width: 12 field elements. -/// * Capacity size: 4 field elements. -/// * Number of founds: 7. -/// * S-Box degree: 7. -/// -/// The above parameters target 128-bit security level. The digest consists of four field elements -/// and it can be serialized into 32 bytes (256 bits). -/// -/// ## Hash output consistency -/// Functions [hash_elements()](Rpo256::hash_elements), [merge()](Rpo256::merge), and -/// [merge_with_int()](Rpo256::merge_with_int) are internally consistent. That is, computing -/// a hash for the same set of elements using these functions will always produce the same -/// result. For example, merging two digests using [merge()](Rpo256::merge) will produce the -/// same result as hashing 8 elements which make up these digests using -/// [hash_elements()](Rpo256::hash_elements) function. -/// -/// However, [hash()](Rpo256::hash) function is not consistent with functions mentioned above. -/// For example, if we take two field elements, serialize them to bytes and hash them using -/// [hash()](Rpo256::hash), the result will differ from the result obtained by hashing these -/// elements directly using [hash_elements()](Rpo256::hash_elements) function. The reason for -/// this difference is that [hash()](Rpo256::hash) function needs to be able to handle -/// arbitrary binary strings, which may or may not encode valid field elements - and thus, -/// deserialization procedure used by this function is different from the procedure used to -/// deserialize valid field elements. -/// -/// Thus, if the underlying data consists of valid field elements, it might make more sense -/// to deserialize them into field elements and then hash them using -/// [hash_elements()](Rpo256::hash_elements) function rather then hashing the serialized bytes -/// using [hash()](Rpo256::hash) function. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -pub struct Rpo256(); - -impl Hasher for Rpo256 { - /// Rpo256 collision resistance is the same as the security level, that is 128-bits. - /// - /// #### Collision resistance - /// - /// However, our setup of the capacity registers might drop it to 126. - /// - /// Related issue: [#69](https://github.com/0xPolygonMiden/crypto/issues/69) - const COLLISION_RESISTANCE: u32 = 128; - - type Digest = RpoDigest; - - fn hash(bytes: &[u8]) -> Self::Digest { - // initialize the state with zeroes - let mut state = [ZERO; STATE_WIDTH]; - - // set the capacity (first element) to a flag on whether or not the input length is evenly - // divided by the rate. this will prevent collisions between padded and non-padded inputs, - // and will rule out the need to perform an extra permutation in case of evenly divided - // inputs. - let is_rate_multiple = bytes.len() % RATE_WIDTH == 0; - if !is_rate_multiple { - state[CAPACITY_RANGE.start] = ONE; - } - - // initialize a buffer to receive the little-endian elements. - let mut buf = [0_u8; 8]; - - // iterate the chunks of bytes, creating a field element from each chunk and copying it - // into the state. - // - // every time the rate range is filled, a permutation is performed. if the final value of - // `i` is not zero, then the chunks count wasn't enough to fill the state range, and an - // additional permutation must be performed. - let i = bytes.chunks(BINARY_CHUNK_SIZE).fold(0, |i, chunk| { - // the last element of the iteration may or may not be a full chunk. if it's not, then - // we need to pad the remainder bytes of the chunk with zeroes, separated by a `1`. - // this will avoid collisions. - if chunk.len() == BINARY_CHUNK_SIZE { - buf[..BINARY_CHUNK_SIZE].copy_from_slice(chunk); - } else { - buf.fill(0); - buf[..chunk.len()].copy_from_slice(chunk); - buf[chunk.len()] = 1; - } - - // set the current rate element to the input. since we take at most 7 bytes, we are - // guaranteed that the inputs data will fit into a single field element. - state[RATE_RANGE.start + i] = Felt::new(u64::from_le_bytes(buf)); - - // proceed filling the range. if it's full, then we apply a permutation and reset the - // counter to the beginning of the range. - if i == RATE_WIDTH - 1 { - Self::apply_permutation(&mut state); - 0 - } else { - i + 1 - } - }); - - // if we absorbed some elements but didn't apply a permutation to them (would happen when - // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation. we - // don't need to apply any extra padding because the first capacity element containts a - // flag indicating whether the input is evenly divisible by the rate. - if i != 0 { - state[RATE_RANGE.start + i..RATE_RANGE.end].fill(ZERO); - state[RATE_RANGE.start + i] = ONE; - Self::apply_permutation(&mut state); - } - - // return the first 4 elements of the rate as hash result. - RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) - } - - fn merge(values: &[Self::Digest; 2]) -> Self::Digest { - // initialize the state by copying the digest elements into the rate portion of the state - // (8 total elements), and set the capacity elements to 0. - let mut state = [ZERO; STATE_WIDTH]; - let it = Self::Digest::digests_as_elements(values.iter()); - for (i, v) in it.enumerate() { - state[RATE_RANGE.start + i] = *v; - } - - // apply the RPO permutation and return the first four elements of the state - Self::apply_permutation(&mut state); - RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) - } - - fn merge_with_int(seed: Self::Digest, value: u64) -> Self::Digest { - // initialize the state as follows: - // - seed is copied into the first 4 elements of the rate portion of the state. - // - if the value fits into a single field element, copy it into the fifth rate element - // and set the sixth rate element to 1. - // - if the value doesn't fit into a single field element, split it into two field - // elements, copy them into rate elements 5 and 6, and set the seventh rate element - // to 1. - // - set the first capacity element to 1 - let mut state = [ZERO; STATE_WIDTH]; - state[INPUT1_RANGE].copy_from_slice(seed.as_elements()); - state[INPUT2_RANGE.start] = Felt::new(value); - if value < Felt::MODULUS { - state[INPUT2_RANGE.start + 1] = ONE; - } else { - state[INPUT2_RANGE.start + 1] = Felt::new(value / Felt::MODULUS); - state[INPUT2_RANGE.start + 2] = ONE; - } - - // common padding for both cases - state[CAPACITY_RANGE.start] = ONE; - - // apply the RPO permutation and return the first four elements of the state - Self::apply_permutation(&mut state); - RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) - } -} - -impl ElementHasher for Rpo256 { - type BaseField = Felt; - - fn hash_elements>(elements: &[E]) -> Self::Digest { - // convert the elements into a list of base field elements - let elements = E::slice_as_base_elements(elements); - - // initialize state to all zeros, except for the first element of the capacity part, which - // is set to 1 if the number of elements is not a multiple of RATE_WIDTH. - let mut state = [ZERO; STATE_WIDTH]; - if elements.len() % RATE_WIDTH != 0 { - state[CAPACITY_RANGE.start] = ONE; - } - - // absorb elements into the state one by one until the rate portion of the state is filled - // up; then apply the Rescue permutation and start absorbing again; repeat until all - // elements have been absorbed - let mut i = 0; - for &element in elements.iter() { - state[RATE_RANGE.start + i] = element; - i += 1; - if i % RATE_WIDTH == 0 { - Self::apply_permutation(&mut state); - i = 0; - } - } - - // if we absorbed some elements but didn't apply a permutation to them (would happen when - // the number of elements is not a multiple of RATE_WIDTH), apply the RPO permutation after - // padding by appending a 1 followed by as many 0 as necessary to make the input length a - // multiple of the RATE_WIDTH. - if i > 0 { - state[RATE_RANGE.start + i] = ONE; - i += 1; - while i != RATE_WIDTH { - state[RATE_RANGE.start + i] = ZERO; - i += 1; - } - Self::apply_permutation(&mut state); - } - - // return the first 4 elements of the state as hash result - RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) - } -} - -// HASH FUNCTION IMPLEMENTATION -// ================================================================================================ - -impl Rpo256 { - // CONSTANTS - // -------------------------------------------------------------------------------------------- - - /// The number of rounds is set to 7 to target 128-bit security level. - pub const NUM_ROUNDS: usize = NUM_ROUNDS; - - /// Sponge state is set to 12 field elements or 768 bytes; 8 elements are reserved for rate and - /// the remaining 4 elements are reserved for capacity. - pub const STATE_WIDTH: usize = STATE_WIDTH; - - /// The rate portion of the state is located in elements 4 through 11 (inclusive). - pub const RATE_RANGE: Range = RATE_RANGE; - - /// The capacity portion of the state is located in elements 0, 1, 2, and 3. - pub const CAPACITY_RANGE: Range = CAPACITY_RANGE; - - /// The output of the hash function can be read from state elements 4, 5, 6, and 7. - pub const DIGEST_RANGE: Range = DIGEST_RANGE; - - /// MDS matrix used for computing the linear layer in a RPO round. - pub const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = MDS; - - /// Round constants added to the hasher state in the first half of the RPO round. - pub const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK1; - - /// Round constants added to the hasher state in the second half of the RPO round. - pub const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = ARK2; - - // TRAIT PASS-THROUGH FUNCTIONS - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of the provided sequence of bytes. - #[inline(always)] - pub fn hash(bytes: &[u8]) -> RpoDigest { - ::hash(bytes) - } - - /// Returns a hash of two digests. This method is intended for use in construction of - /// Merkle trees and verification of Merkle paths. - #[inline(always)] - pub fn merge(values: &[RpoDigest; 2]) -> RpoDigest { - ::merge(values) - } - - /// Returns a hash of the provided field elements. - #[inline(always)] - pub fn hash_elements>(elements: &[E]) -> RpoDigest { - ::hash_elements(elements) - } - - // DOMAIN IDENTIFIER - // -------------------------------------------------------------------------------------------- - - /// Returns a hash of two digests and a domain identifier. - pub fn merge_in_domain(values: &[RpoDigest; 2], domain: Felt) -> RpoDigest { - // initialize the state by copying the digest elements into the rate portion of the state - // (8 total elements), and set the capacity elements to 0. - let mut state = [ZERO; STATE_WIDTH]; - let it = RpoDigest::digests_as_elements(values.iter()); - for (i, v) in it.enumerate() { - state[RATE_RANGE.start + i] = *v; - } - - // set the second capacity element to the domain value. The first capacity element is used - // for padding purposes. - state[CAPACITY_RANGE.start + 1] = domain; - - // apply the RPO permutation and return the first four elements of the state - Self::apply_permutation(&mut state); - RpoDigest::new(state[DIGEST_RANGE].try_into().unwrap()) - } - - // RESCUE PERMUTATION - // -------------------------------------------------------------------------------------------- - - /// Applies RPO permutation to the provided state. - #[inline(always)] - pub fn apply_permutation(state: &mut [Felt; STATE_WIDTH]) { - for i in 0..NUM_ROUNDS { - Self::apply_round(state, i); - } - } - - /// RPO round function. - #[inline(always)] - pub fn apply_round(state: &mut [Felt; STATE_WIDTH], round: usize) { - // apply first half of RPO round - Self::apply_mds(state); - if !Self::optimized_add_constants_and_apply_sbox(state, &ARK1[round]) { - Self::add_constants(state, &ARK1[round]); - Self::apply_sbox(state); - } - - // apply second half of RPO round - Self::apply_mds(state); - if !Self::optimized_add_constants_and_apply_inv_sbox(state, &ARK2[round]) { - Self::add_constants(state, &ARK2[round]); - Self::apply_inv_sbox(state); - } - } - - // HELPER FUNCTIONS - // -------------------------------------------------------------------------------------------- - - #[inline(always)] - #[cfg(all(target_feature = "sve", feature = "sve"))] - fn optimized_add_constants_and_apply_sbox( - state: &mut [Felt; STATE_WIDTH], - ark: &[Felt; STATE_WIDTH], - ) -> bool { - unsafe { - add_constants_and_apply_sbox(state.as_mut_ptr() as *mut u64, ark.as_ptr() as *const u64) - } - } - - #[inline(always)] - #[cfg(not(all(target_feature = "sve", feature = "sve")))] - fn optimized_add_constants_and_apply_sbox( - _state: &mut [Felt; STATE_WIDTH], - _ark: &[Felt; STATE_WIDTH], - ) -> bool { - false - } - - #[inline(always)] - #[cfg(all(target_feature = "sve", feature = "sve"))] - fn optimized_add_constants_and_apply_inv_sbox( - state: &mut [Felt; STATE_WIDTH], - ark: &[Felt; STATE_WIDTH], - ) -> bool { - unsafe { - add_constants_and_apply_inv_sbox( - state.as_mut_ptr() as *mut u64, - ark.as_ptr() as *const u64, - ) - } - } - - #[inline(always)] - #[cfg(not(all(target_feature = "sve", feature = "sve")))] - fn optimized_add_constants_and_apply_inv_sbox( - _state: &mut [Felt; STATE_WIDTH], - _ark: &[Felt; STATE_WIDTH], - ) -> bool { - false - } - - #[inline(always)] - fn apply_mds(state: &mut [Felt; STATE_WIDTH]) { - let mut result = [ZERO; STATE_WIDTH]; - - // Using the linearity of the operations we can split the state into a low||high decomposition - // and operate on each with no overflow and then combine/reduce the result to a field element. - // The no overflow is guaranteed by the fact that the MDS matrix is a small powers of two in - // frequency domain. - let mut state_l = [0u64; STATE_WIDTH]; - let mut state_h = [0u64; STATE_WIDTH]; - - for r in 0..STATE_WIDTH { - let s = state[r].inner(); - state_h[r] = s >> 32; - state_l[r] = (s as u32) as u64; - } - - let state_h = mds_multiply_freq(state_h); - let state_l = mds_multiply_freq(state_l); - - for r in 0..STATE_WIDTH { - let s = state_l[r] as u128 + ((state_h[r] as u128) << 32); - let s_hi = (s >> 64) as u64; - let s_lo = s as u64; - let z = (s_hi << 32) - s_hi; - let (res, over) = s_lo.overflowing_add(z); - - result[r] = Felt::from_mont(res.wrapping_add(0u32.wrapping_sub(over as u32) as u64)); - } - *state = result; - } - - #[inline(always)] - fn add_constants(state: &mut [Felt; STATE_WIDTH], ark: &[Felt; STATE_WIDTH]) { - state.iter_mut().zip(ark).for_each(|(s, &k)| *s += k); - } - - #[inline(always)] - fn apply_sbox(state: &mut [Felt; STATE_WIDTH]) { - state[0] = state[0].exp7(); - state[1] = state[1].exp7(); - state[2] = state[2].exp7(); - state[3] = state[3].exp7(); - state[4] = state[4].exp7(); - state[5] = state[5].exp7(); - state[6] = state[6].exp7(); - state[7] = state[7].exp7(); - state[8] = state[8].exp7(); - state[9] = state[9].exp7(); - state[10] = state[10].exp7(); - state[11] = state[11].exp7(); - } - - #[inline(always)] - fn apply_inv_sbox(state: &mut [Felt; STATE_WIDTH]) { - // compute base^10540996611094048183 using 72 multiplications per array element - // 10540996611094048183 = b1001001001001001001001001001000110110110110110110110110110110111 - - // compute base^10 - let mut t1 = *state; - t1.iter_mut().for_each(|t| *t = t.square()); - - // compute base^100 - let mut t2 = t1; - t2.iter_mut().for_each(|t| *t = t.square()); - - // compute base^100100 - let t3 = Self::exp_acc::(t2, t2); - - // compute base^100100100100 - let t4 = Self::exp_acc::(t3, t3); - - // compute base^100100100100100100100100 - let t5 = Self::exp_acc::(t4, t4); - - // compute base^100100100100100100100100100100 - let t6 = Self::exp_acc::(t5, t3); - - // compute base^1001001001001001001001001001000100100100100100100100100100100 - let t7 = Self::exp_acc::(t6, t6); - - // compute base^1001001001001001001001001001000110110110110110110110110110110111 - for (i, s) in state.iter_mut().enumerate() { - let a = (t7[i].square() * t6[i]).square().square(); - let b = t1[i] * t2[i] * *s; - *s = a * b; - } - } - - #[inline(always)] - fn exp_acc( - base: [B; N], - tail: [B; N], - ) -> [B; N] { - let mut result = base; - for _ in 0..M { - result.iter_mut().for_each(|r| *r = r.square()); - } - result.iter_mut().zip(tail).for_each(|(r, t)| *r *= t); - result - } -} - -// MDS -// ================================================================================================ -/// RPO MDS matrix -const MDS: [[Felt; STATE_WIDTH]; STATE_WIDTH] = [ - [ - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - ], - [ - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - ], - [ - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - ], - [ - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - ], - [ - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - ], - [ - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - ], - [ - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - ], - [ - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - ], - [ - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - Felt::new(26), - ], - [ - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - Felt::new(8), - ], - [ - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - Felt::new(23), - ], - [ - Felt::new(23), - Felt::new(8), - Felt::new(26), - Felt::new(13), - Felt::new(10), - Felt::new(9), - Felt::new(7), - Felt::new(6), - Felt::new(22), - Felt::new(21), - Felt::new(8), - Felt::new(7), - ], -]; - -// ROUND CONSTANTS -// ================================================================================================ - -/// Rescue round constants; -/// computed as in [specifications](https://github.com/ASDiscreteMathematics/rpo) -/// -/// The constants are broken up into two arrays ARK1 and ARK2; ARK1 contains the constants for the -/// first half of RPO round, and ARK2 contains constants for the second half of RPO round. -const ARK1: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [ - [ - Felt::new(5789762306288267392), - Felt::new(6522564764413701783), - Felt::new(17809893479458208203), - Felt::new(107145243989736508), - Felt::new(6388978042437517382), - Felt::new(15844067734406016715), - Felt::new(9975000513555218239), - Felt::new(3344984123768313364), - Felt::new(9959189626657347191), - Felt::new(12960773468763563665), - Felt::new(9602914297752488475), - Felt::new(16657542370200465908), - ], - [ - Felt::new(12987190162843096997), - Felt::new(653957632802705281), - Felt::new(4441654670647621225), - Felt::new(4038207883745915761), - Felt::new(5613464648874830118), - Felt::new(13222989726778338773), - Felt::new(3037761201230264149), - Felt::new(16683759727265180203), - Felt::new(8337364536491240715), - Felt::new(3227397518293416448), - Felt::new(8110510111539674682), - Felt::new(2872078294163232137), - ], - [ - Felt::new(18072785500942327487), - Felt::new(6200974112677013481), - Felt::new(17682092219085884187), - Felt::new(10599526828986756440), - Felt::new(975003873302957338), - Felt::new(8264241093196931281), - Felt::new(10065763900435475170), - Felt::new(2181131744534710197), - Felt::new(6317303992309418647), - Felt::new(1401440938888741532), - Felt::new(8884468225181997494), - Felt::new(13066900325715521532), - ], - [ - Felt::new(5674685213610121970), - Felt::new(5759084860419474071), - Felt::new(13943282657648897737), - Felt::new(1352748651966375394), - Felt::new(17110913224029905221), - Felt::new(1003883795902368422), - Felt::new(4141870621881018291), - Felt::new(8121410972417424656), - Felt::new(14300518605864919529), - Felt::new(13712227150607670181), - Felt::new(17021852944633065291), - Felt::new(6252096473787587650), - ], - [ - Felt::new(4887609836208846458), - Felt::new(3027115137917284492), - Felt::new(9595098600469470675), - Felt::new(10528569829048484079), - Felt::new(7864689113198939815), - Felt::new(17533723827845969040), - Felt::new(5781638039037710951), - Felt::new(17024078752430719006), - Felt::new(109659393484013511), - Felt::new(7158933660534805869), - Felt::new(2955076958026921730), - Felt::new(7433723648458773977), - ], - [ - Felt::new(16308865189192447297), - Felt::new(11977192855656444890), - Felt::new(12532242556065780287), - Felt::new(14594890931430968898), - Felt::new(7291784239689209784), - Felt::new(5514718540551361949), - Felt::new(10025733853830934803), - Felt::new(7293794580341021693), - Felt::new(6728552937464861756), - Felt::new(6332385040983343262), - Felt::new(13277683694236792804), - Felt::new(2600778905124452676), - ], - [ - Felt::new(7123075680859040534), - Felt::new(1034205548717903090), - Felt::new(7717824418247931797), - Felt::new(3019070937878604058), - Felt::new(11403792746066867460), - Felt::new(10280580802233112374), - Felt::new(337153209462421218), - Felt::new(13333398568519923717), - Felt::new(3596153696935337464), - Felt::new(8104208463525993784), - Felt::new(14345062289456085693), - Felt::new(17036731477169661256), - ], -]; - -const ARK2: [[Felt; STATE_WIDTH]; NUM_ROUNDS] = [ - [ - Felt::new(6077062762357204287), - Felt::new(15277620170502011191), - Felt::new(5358738125714196705), - Felt::new(14233283787297595718), - Felt::new(13792579614346651365), - Felt::new(11614812331536767105), - Felt::new(14871063686742261166), - Felt::new(10148237148793043499), - Felt::new(4457428952329675767), - Felt::new(15590786458219172475), - Felt::new(10063319113072092615), - Felt::new(14200078843431360086), - ], - [ - Felt::new(6202948458916099932), - Felt::new(17690140365333231091), - Felt::new(3595001575307484651), - Felt::new(373995945117666487), - Felt::new(1235734395091296013), - Felt::new(14172757457833931602), - Felt::new(707573103686350224), - Felt::new(15453217512188187135), - Felt::new(219777875004506018), - Felt::new(17876696346199469008), - Felt::new(17731621626449383378), - Felt::new(2897136237748376248), - ], - [ - Felt::new(8023374565629191455), - Felt::new(15013690343205953430), - Felt::new(4485500052507912973), - Felt::new(12489737547229155153), - Felt::new(9500452585969030576), - Felt::new(2054001340201038870), - Felt::new(12420704059284934186), - Felt::new(355990932618543755), - Felt::new(9071225051243523860), - Felt::new(12766199826003448536), - Felt::new(9045979173463556963), - Felt::new(12934431667190679898), - ], - [ - Felt::new(18389244934624494276), - Felt::new(16731736864863925227), - Felt::new(4440209734760478192), - Felt::new(17208448209698888938), - Felt::new(8739495587021565984), - Felt::new(17000774922218161967), - Felt::new(13533282547195532087), - Felt::new(525402848358706231), - Felt::new(16987541523062161972), - Felt::new(5466806524462797102), - Felt::new(14512769585918244983), - Felt::new(10973956031244051118), - ], - [ - Felt::new(6982293561042362913), - Felt::new(14065426295947720331), - Felt::new(16451845770444974180), - Felt::new(7139138592091306727), - Felt::new(9012006439959783127), - Felt::new(14619614108529063361), - Felt::new(1394813199588124371), - Felt::new(4635111139507788575), - Felt::new(16217473952264203365), - Felt::new(10782018226466330683), - Felt::new(6844229992533662050), - Felt::new(7446486531695178711), - ], - [ - Felt::new(3736792340494631448), - Felt::new(577852220195055341), - Felt::new(6689998335515779805), - Felt::new(13886063479078013492), - Felt::new(14358505101923202168), - Felt::new(7744142531772274164), - Felt::new(16135070735728404443), - Felt::new(12290902521256031137), - Felt::new(12059913662657709804), - Felt::new(16456018495793751911), - Felt::new(4571485474751953524), - Felt::new(17200392109565783176), - ], - [ - Felt::new(17130398059294018733), - Felt::new(519782857322261988), - Felt::new(9625384390925085478), - Felt::new(1664893052631119222), - Felt::new(7629576092524553570), - Felt::new(3485239601103661425), - Felt::new(9755891797164033838), - Felt::new(15218148195153269027), - Felt::new(16460604813734957368), - Felt::new(9643968136937729763), - Felt::new(3611348709641382851), - Felt::new(18256379591337759196), - ], -]; diff --git a/src/lib.rs b/src/lib.rs index 0eca3e3e..26fb3436 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,10 @@ pub mod utils; // RE-EXPORTS // ================================================================================================ -pub use winter_math::{fields::f64::BaseElement as Felt, FieldElement, StarkField}; +pub use winter_math::{ + fields::{f64::BaseElement as Felt, CubeExtension, QuadExtension}, + FieldElement, StarkField, +}; // TYPE ALIASES // ================================================================================================ diff --git a/src/main.rs b/src/main.rs index e9f8299a..31264d87 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,8 @@ use clap::Parser; use miden_crypto::{ - hash::rpo::RpoDigest, - merkle::MerkleError, + hash::rpo::{Rpo256, RpoDigest}, + merkle::{MerkleError, TieredSmt}, Felt, Word, ONE, - {hash::rpo::Rpo256, merkle::TieredSmt}, }; use rand_utils::rand_value; use std::time::Instant; diff --git a/src/merkle/empty_roots.rs b/src/merkle/empty_roots.rs index 17cd7819..724d8ec0 100644 --- a/src/merkle/empty_roots.rs +++ b/src/merkle/empty_roots.rs @@ -10,12 +10,19 @@ pub struct EmptySubtreeRoots; impl EmptySubtreeRoots { /// Returns a static slice with roots of empty subtrees of a Merkle tree starting at the /// specified depth. - pub const fn empty_hashes(depth: u8) -> &'static [RpoDigest] { - let ptr = &EMPTY_SUBTREES[255 - depth as usize] as *const RpoDigest; + pub const fn empty_hashes(tree_depth: u8) -> &'static [RpoDigest] { + let ptr = &EMPTY_SUBTREES[255 - tree_depth as usize] as *const RpoDigest; // Safety: this is a static/constant array, so it will never be outlived. If we attempt to // use regular slices, this wouldn't be a `const` function, meaning we won't be able to use // the returned value for static/constant definitions. - unsafe { slice::from_raw_parts(ptr, depth as usize + 1) } + unsafe { slice::from_raw_parts(ptr, tree_depth as usize + 1) } + } + + /// Returns the node's digest for a sub-tree with all its leaves set to the empty word. + pub const fn entry(tree_depth: u8, node_depth: u8) -> &'static RpoDigest { + assert!(node_depth <= tree_depth); + let pos = 255 - tree_depth + node_depth; + &EMPTY_SUBTREES[pos as usize] } } @@ -1583,3 +1590,16 @@ fn all_depths_opens_to_zero() { .for_each(|(x, computed)| assert_eq!(x, computed)); } } + +#[test] +fn test_entry() { + // check the leaf is always the empty work + for depth in 0..255 { + assert_eq!(EmptySubtreeRoots::entry(depth, depth), &RpoDigest::new(EMPTY_WORD)); + } + + // check the root matches the first element of empty_hashes + for depth in 0..255 { + assert_eq!(EmptySubtreeRoots::entry(depth, 0), &EmptySubtreeRoots::empty_hashes(depth)[0]); + } +} diff --git a/src/merkle/error.rs b/src/merkle/error.rs index 5012b75c..b513212f 100644 --- a/src/merkle/error.rs +++ b/src/merkle/error.rs @@ -13,8 +13,9 @@ pub enum MerkleError { DuplicateValuesForKey(RpoDigest), InvalidIndex { depth: u8, value: u64 }, InvalidDepth { expected: u8, provided: u8 }, + InvalidSubtreeDepth { subtree_depth: u8, tree_depth: u8 }, InvalidPath(MerklePath), - InvalidNumEntries(usize, usize), + InvalidNumEntries(usize), NodeNotInSet(NodeIndex), NodeNotInStore(RpoDigest, NodeIndex), NumLeavesNotPowerOfTwo(usize), @@ -30,18 +31,21 @@ impl fmt::Display for MerkleError { DepthTooBig(depth) => write!(f, "the provided depth {depth} is too big"), DuplicateValuesForIndex(key) => write!(f, "multiple values provided for key {key}"), DuplicateValuesForKey(key) => write!(f, "multiple values provided for key {key}"), - InvalidIndex{ depth, value} => write!( - f, - "the index value {value} is not valid for the depth {depth}" - ), - InvalidDepth { expected, provided } => write!( - f, - "the provided depth {provided} is not valid for {expected}" - ), + InvalidIndex { depth, value } => { + write!(f, "the index value {value} is not valid for the depth {depth}") + } + InvalidDepth { expected, provided } => { + write!(f, "the provided depth {provided} is not valid for {expected}") + } + InvalidSubtreeDepth { subtree_depth, tree_depth } => { + write!(f, "tried inserting a subtree of depth {subtree_depth} into a tree of depth {tree_depth}") + } InvalidPath(_path) => write!(f, "the provided path is not valid"), - InvalidNumEntries(max, provided) => write!(f, "the provided number of entries is {provided}, but the maximum for the given depth is {max}"), + InvalidNumEntries(max) => write!(f, "number of entries exceeded the maximum: {max}"), NodeNotInSet(index) => write!(f, "the node with index ({index}) is not in the set"), - NodeNotInStore(hash, index) => write!(f, "the node {hash:?} with index ({index}) is not in the store"), + NodeNotInStore(hash, index) => { + write!(f, "the node {hash:?} with index ({index}) is not in the store") + } NumLeavesNotPowerOfTwo(leaves) => { write!(f, "the leaves count {leaves} is not a power of 2") } diff --git a/src/merkle/index.rs b/src/merkle/index.rs index 25c9282d..c533fa0d 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -187,13 +187,20 @@ mod tests { #[test] fn test_node_index_value_too_high() { assert_eq!(NodeIndex::new(0, 0).unwrap(), NodeIndex { depth: 0, value: 0 }); - match NodeIndex::new(0, 1) { - Err(MerkleError::InvalidIndex { depth, value }) => { - assert_eq!(depth, 0); - assert_eq!(value, 1); - } - _ => unreachable!(), - } + let err = NodeIndex::new(0, 1).unwrap_err(); + assert_eq!(err, MerkleError::InvalidIndex { depth: 0, value: 1 }); + + assert_eq!(NodeIndex::new(1, 1).unwrap(), NodeIndex { depth: 1, value: 1 }); + let err = NodeIndex::new(1, 2).unwrap_err(); + assert_eq!(err, MerkleError::InvalidIndex { depth: 1, value: 2 }); + + assert_eq!(NodeIndex::new(2, 3).unwrap(), NodeIndex { depth: 2, value: 3 }); + let err = NodeIndex::new(2, 4).unwrap_err(); + assert_eq!(err, MerkleError::InvalidIndex { depth: 2, value: 4 }); + + assert_eq!(NodeIndex::new(3, 7).unwrap(), NodeIndex { depth: 3, value: 7 }); + let err = NodeIndex::new(3, 8).unwrap_err(); + assert_eq!(err, MerkleError::InvalidIndex { depth: 3, value: 8 }); } #[test] diff --git a/src/merkle/mmr/full.rs b/src/merkle/mmr/full.rs index 6b397dce..b47437d7 100644 --- a/src/merkle/mmr/full.rs +++ b/src/merkle/mmr/full.rs @@ -9,11 +9,12 @@ //! least number of leaves. The structure preserves the invariant that each tree has different //! depths, i.e. as part of adding adding a new element to the forest the trees with same depth are //! merged, creating a new tree with depth d+1, this process is continued until the property is -//! restabilished. +//! reestablished. use super::{ - super::{InnerNodeInfo, MerklePath, RpoDigest, Vec}, + super::{InnerNodeInfo, MerklePath, Vec}, bit::TrueBitPositionIterator, leaf_to_corresponding_tree, nodes_in_forest, MmrDelta, MmrError, MmrPeaks, MmrProof, Rpo256, + RpoDigest, }; // MMR @@ -76,13 +77,13 @@ impl Mmr { /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were /// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element /// has position 0, the second position 1, and so on. - pub fn open(&self, pos: usize) -> Result { + pub fn open(&self, pos: usize, target_forest: usize) -> Result { // find the target tree responsible for the MMR position let tree_bit = - leaf_to_corresponding_tree(pos, self.forest).ok_or(MmrError::InvalidPosition(pos))?; + leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?; // isolate the trees before the target - let forest_before = self.forest & high_bitmask(tree_bit + 1); + let forest_before = target_forest & high_bitmask(tree_bit + 1); let index_offset = nodes_in_forest(forest_before); // update the value position from global to the target tree @@ -92,7 +93,7 @@ impl Mmr { let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset); Ok(MmrProof { - forest: self.forest, + forest: target_forest, position: pos, merkle_path: MerklePath::new(path), }) @@ -143,9 +144,13 @@ impl Mmr { self.forest += 1; } - /// Returns an accumulator representing the current state of the MMR. - pub fn accumulator(&self) -> MmrPeaks { - let peaks: Vec = TrueBitPositionIterator::new(self.forest) + /// Returns an peaks of the MMR for the version specified by `forest`. + pub fn peaks(&self, forest: usize) -> Result { + if forest > self.forest { + return Err(MmrError::InvalidPeaks); + } + + let peaks: Vec = TrueBitPositionIterator::new(forest) .rev() .map(|bit| nodes_in_forest(1 << bit)) .scan(0, |offset, el| { @@ -156,39 +161,41 @@ impl Mmr { .collect(); // Safety: the invariant is maintained by the [Mmr] - MmrPeaks::new(self.forest, peaks).unwrap() + let peaks = MmrPeaks::new(forest, peaks).unwrap(); + + Ok(peaks) } /// Compute the required update to `original_forest`. /// /// The result is a packed sequence of the authentication elements required to update the trees /// that have been merged together, followed by the new peaks of the [Mmr]. - pub fn get_delta(&self, original_forest: usize) -> Result { - if original_forest > self.forest { + pub fn get_delta(&self, from_forest: usize, to_forest: usize) -> Result { + if to_forest > self.forest || from_forest > to_forest { return Err(MmrError::InvalidPeaks); } - if original_forest == self.forest { - return Ok(MmrDelta { forest: self.forest, data: Vec::new() }); + if from_forest == to_forest { + return Ok(MmrDelta { forest: to_forest, data: Vec::new() }); } let mut result = Vec::new(); - // Find the largest tree in this [Mmr] which is new to `original_forest`. - let candidate_trees = self.forest ^ original_forest; + // Find the largest tree in this [Mmr] which is new to `from_forest`. + let candidate_trees = to_forest ^ from_forest; let mut new_high = 1 << candidate_trees.ilog2(); // Collect authentication nodes used for tree merges // ---------------------------------------------------------------------------------------- - // Find the trees from `original_forest` that have been merged into `new_high`. - let mut merges = original_forest & (new_high - 1); + // Find the trees from `from_forest` that have been merged into `new_high`. + let mut merges = from_forest & (new_high - 1); - // Find the peaks that are common to `original_forest` and this [Mmr] - let common_trees = original_forest ^ merges; + // Find the peaks that are common to `from_forest` and this [Mmr] + let common_trees = from_forest ^ merges; if merges != 0 { - // Skip the smallest trees unknown to `original_forest`. + // Skip the smallest trees unknown to `from_forest`. let mut target = 1 << merges.trailing_zeros(); // Collect siblings required to computed the merged tree's peak @@ -213,15 +220,15 @@ impl Mmr { } } else { // The new high tree may not be the result of any merges, if it is smaller than all the - // trees of `original_forest`. + // trees of `from_forest`. new_high = 0; } // Collect the new [Mmr] peaks // ---------------------------------------------------------------------------------------- - let mut new_peaks = self.forest ^ common_trees ^ new_high; - let old_peaks = self.forest ^ new_peaks; + let mut new_peaks = to_forest ^ common_trees ^ new_high; + let old_peaks = to_forest ^ new_peaks; let mut offset = nodes_in_forest(old_peaks); while new_peaks != 0 { let target = 1 << new_peaks.ilog2(); @@ -230,7 +237,7 @@ impl Mmr { new_peaks ^= target; } - Ok(MmrDelta { forest: self.forest, data: result }) + Ok(MmrDelta { forest: to_forest, data: result }) } /// An iterator over inner nodes in the MMR. The order of iteration is unspecified. diff --git a/src/merkle/mmr/mod.rs b/src/merkle/mmr/mod.rs index 5e72d2d5..a28a8434 100644 --- a/src/merkle/mmr/mod.rs +++ b/src/merkle/mmr/mod.rs @@ -10,7 +10,7 @@ mod proof; #[cfg(test)] mod tests; -use super::{Felt, Rpo256, Word}; +use super::{Felt, Rpo256, RpoDigest, Word}; // REEXPORTS // ================================================================================================ diff --git a/src/merkle/mmr/partial.rs b/src/merkle/mmr/partial.rs index d90cdf37..be3e75d9 100644 --- a/src/merkle/mmr/partial.rs +++ b/src/merkle/mmr/partial.rs @@ -1,5 +1,5 @@ +use super::{MmrDelta, MmrProof, Rpo256, RpoDigest}; use crate::{ - hash::rpo::{Rpo256, RpoDigest}, merkle::{ mmr::{leaf_to_corresponding_tree, nodes_in_forest}, InOrderIndex, MerklePath, MmrError, MmrPeaks, @@ -7,8 +7,6 @@ use crate::{ utils::collections::{BTreeMap, Vec}, }; -use super::{MmrDelta, MmrProof}; - /// Partially materialized [Mmr], used to efficiently store and update the authentication paths for /// a subset of the elements in a full [Mmr]. /// diff --git a/src/merkle/mmr/peaks.rs b/src/merkle/mmr/peaks.rs index a3613fe8..e0ec3f24 100644 --- a/src/merkle/mmr/peaks.rs +++ b/src/merkle/mmr/peaks.rs @@ -54,13 +54,18 @@ impl MmrPeaks { &self.peaks } + /// Returns the current num_leaves and peaks of the [Mmr]. + pub fn into_parts(self) -> (usize, Vec) { + (self.num_leaves, self.peaks) + } + /// Hashes the peaks. /// /// The procedure will: /// - Flatten and pad the peaks to a vector of Felts. /// - Hash the vector of Felts. - pub fn hash_peaks(&self) -> Word { - Rpo256::hash_elements(&self.flatten_and_pad_peaks()).into() + pub fn hash_peaks(&self) -> RpoDigest { + Rpo256::hash_elements(&self.flatten_and_pad_peaks()) } pub fn verify(&self, value: RpoDigest, opening: MmrProof) -> bool { diff --git a/src/merkle/mmr/tests.rs b/src/merkle/mmr/tests.rs index d829c965..a92f4c64 100644 --- a/src/merkle/mmr/tests.rs +++ b/src/merkle/mmr/tests.rs @@ -1,11 +1,10 @@ use super::{ - super::{InnerNodeInfo, Vec}, + super::{InnerNodeInfo, Rpo256, RpoDigest, Vec}, bit::TrueBitPositionIterator, full::high_bitmask, - leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr, Rpo256, + leaf_to_corresponding_tree, nodes_in_forest, Mmr, MmrPeaks, PartialMmr, }; use crate::{ - hash::rpo::RpoDigest, merkle::{int_to_node, InOrderIndex, MerklePath, MerkleTree, MmrProof, NodeIndex}, Felt, Word, }; @@ -137,7 +136,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 1); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(acc.num_leaves(), 1); assert_eq!(acc.peaks(), &[postorder[0]]); @@ -146,7 +145,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 3); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(acc.num_leaves(), 2); assert_eq!(acc.peaks(), &[postorder[2]]); @@ -155,7 +154,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 4); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(acc.num_leaves(), 3); assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]); @@ -164,7 +163,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 7); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(acc.num_leaves(), 4); assert_eq!(acc.peaks(), &[postorder[6]]); @@ -173,7 +172,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 8); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(acc.num_leaves(), 5); assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]); @@ -182,7 +181,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 10); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(acc.num_leaves(), 6); assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]); @@ -191,7 +190,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 11); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(acc.num_leaves(), 7); assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]); } @@ -203,96 +202,139 @@ fn test_mmr_open() { let h23 = merge(LEAVES[2], LEAVES[3]); // node at pos 7 is the root - assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None"); + assert!( + mmr.open(7, mmr.forest()).is_err(), + "Element 7 is not in the tree, result should be None" + ); // node at pos 6 is the root let empty: MerklePath = MerklePath::new(vec![]); let opening = mmr - .open(6) + .open(6, mmr.forest()) .expect("Element 6 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, empty); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 6); assert!( - mmr.accumulator().verify(LEAVES[6], opening), + mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening), "MmrProof should be valid for the current accumulator." ); // nodes 4,5 are depth 1 let root_to_path = MerklePath::new(vec![LEAVES[4]]); let opening = mmr - .open(5) + .open(5, mmr.forest()) .expect("Element 5 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 5); assert!( - mmr.accumulator().verify(LEAVES[5], opening), + mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[5]]); let opening = mmr - .open(4) + .open(4, mmr.forest()) .expect("Element 4 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 4); assert!( - mmr.accumulator().verify(LEAVES[4], opening), + mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening), "MmrProof should be valid for the current accumulator." ); // nodes 0,1,2,3 are detph 2 let root_to_path = MerklePath::new(vec![LEAVES[2], h01]); let opening = mmr - .open(3) + .open(3, mmr.forest()) .expect("Element 3 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 3); assert!( - mmr.accumulator().verify(LEAVES[3], opening), + mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[3], h01]); let opening = mmr - .open(2) + .open(2, mmr.forest()) .expect("Element 2 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 2); assert!( - mmr.accumulator().verify(LEAVES[2], opening), + mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[0], h23]); let opening = mmr - .open(1) + .open(1, mmr.forest()) .expect("Element 1 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 1); assert!( - mmr.accumulator().verify(LEAVES[1], opening), + mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[1], h23]); let opening = mmr - .open(0) + .open(0, mmr.forest()) .expect("Element 0 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 0); assert!( - mmr.accumulator().verify(LEAVES[0], opening), + mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening), "MmrProof should be valid for the current accumulator." ); } +#[test] +fn test_mmr_open_older_version() { + let mmr: Mmr = LEAVES.into(); + + fn is_even(v: &usize) -> bool { + v & 1 == 0 + } + + // merkle path of a node is empty if there are no elements to pair with it + for pos in (0..mmr.forest()).filter(is_even) { + let forest = pos + 1; + let proof = mmr.open(pos, forest).unwrap(); + assert_eq!(proof.forest, forest); + assert_eq!(proof.merkle_path.nodes(), []); + assert_eq!(proof.position, pos); + } + + // openings match that of a merkle tree + let mtree: MerkleTree = LEAVES[..4].try_into().unwrap(); + for forest in 4..=LEAVES.len() { + for pos in 0..4 { + let idx = NodeIndex::new(2, pos).unwrap(); + let path = mtree.get_path(idx).unwrap(); + let proof = mmr.open(pos as usize, forest).unwrap(); + assert_eq!(path, proof.merkle_path); + } + } + let mtree: MerkleTree = LEAVES[4..6].try_into().unwrap(); + for forest in 6..=LEAVES.len() { + for pos in 0..2 { + let idx = NodeIndex::new(1, pos).unwrap(); + let path = mtree.get_path(idx).unwrap(); + // account for the bigger tree with 4 elements + let mmr_pos = (pos + 4) as usize; + let proof = mmr.open(mmr_pos, forest).unwrap(); + assert_eq!(path, proof.merkle_path); + } + } +} + /// Tests the openings of a simple Mmr with a single tree of depth 8. #[test] fn test_mmr_open_eight() { @@ -313,49 +355,49 @@ fn test_mmr_open_eight() { let root = mtree.root(); let position = 0; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 1; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 2; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 3; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 4; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 5; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 6; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 7; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); @@ -371,47 +413,47 @@ fn test_mmr_open_seven() { let mmr: Mmr = LEAVES.into(); let position = 0; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root()); let position = 1; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root()); let position = 2; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root()); let position = 3; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root()); let position = 4; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root()); let position = 5; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root()); let position = 6; - let proof = mmr.open(position).unwrap(); + let proof = mmr.open(position, mmr.forest()).unwrap(); let merkle_path: MerklePath = [].as_ref().into(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]); @@ -435,7 +477,7 @@ fn test_mmr_invariants() { let mut mmr = Mmr::new(); for v in 1..=1028 { mmr.add(int_to_node(v)); - let accumulator = mmr.accumulator(); + let accumulator = mmr.peaks(mmr.forest()).unwrap(); assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add"); assert_eq!( v as usize, @@ -516,10 +558,50 @@ fn test_mmr_inner_nodes() { assert_eq!(postorder, nodes); } +#[test] +fn test_mmr_peaks() { + let mmr: Mmr = LEAVES.into(); + + let forest = 0b0001; + let acc = mmr.peaks(forest).unwrap(); + assert_eq!(acc.num_leaves(), forest); + assert_eq!(acc.peaks(), &[mmr.nodes[0]]); + + let forest = 0b0010; + let acc = mmr.peaks(forest).unwrap(); + assert_eq!(acc.num_leaves(), forest); + assert_eq!(acc.peaks(), &[mmr.nodes[2]]); + + let forest = 0b0011; + let acc = mmr.peaks(forest).unwrap(); + assert_eq!(acc.num_leaves(), forest); + assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]); + + let forest = 0b0100; + let acc = mmr.peaks(forest).unwrap(); + assert_eq!(acc.num_leaves(), forest); + assert_eq!(acc.peaks(), &[mmr.nodes[6]]); + + let forest = 0b0101; + let acc = mmr.peaks(forest).unwrap(); + assert_eq!(acc.num_leaves(), forest); + assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]); + + let forest = 0b0110; + let acc = mmr.peaks(forest).unwrap(); + assert_eq!(acc.num_leaves(), forest); + assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]); + + let forest = 0b0111; + let acc = mmr.peaks(forest).unwrap(); + assert_eq!(acc.num_leaves(), forest); + assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]); +} + #[test] fn test_mmr_hash_peaks() { let mmr: Mmr = LEAVES.into(); - let peaks = mmr.accumulator(); + let peaks = mmr.peaks(mmr.forest()).unwrap(); let first_peak = Rpo256::merge(&[ Rpo256::merge(&[LEAVES[0], LEAVES[1]]), @@ -531,10 +613,7 @@ fn test_mmr_hash_peaks() { // minimum length is 16 let mut expected_peaks = [first_peak, second_peak, third_peak].to_vec(); expected_peaks.resize(16, RpoDigest::default()); - assert_eq!( - peaks.hash_peaks(), - *Rpo256::hash_elements(&digests_to_elements(&expected_peaks)) - ); + assert_eq!(peaks.hash_peaks(), Rpo256::hash_elements(&digests_to_elements(&expected_peaks))); } #[test] @@ -552,7 +631,7 @@ fn test_mmr_peaks_hash_less_than_16() { expected_peaks.resize(16, RpoDigest::default()); assert_eq!( accumulator.hash_peaks(), - *Rpo256::hash_elements(&digests_to_elements(&expected_peaks)) + Rpo256::hash_elements(&digests_to_elements(&expected_peaks)) ); } } @@ -569,47 +648,47 @@ fn test_mmr_peaks_hash_odd() { expected_peaks.resize(18, RpoDigest::default()); assert_eq!( accumulator.hash_peaks(), - *Rpo256::hash_elements(&digests_to_elements(&expected_peaks)) + Rpo256::hash_elements(&digests_to_elements(&expected_peaks)) ); } #[test] -fn test_mmr_updates() { +fn test_mmr_delta() { let mmr: Mmr = LEAVES.into(); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); // original_forest can't have more elements assert!( - mmr.get_delta(LEAVES.len() + 1).is_err(), + mmr.get_delta(LEAVES.len() + 1, mmr.forest()).is_err(), "Can not provide updates for a newer Mmr" ); // if the number of elements is the same there is no change assert!( - mmr.get_delta(LEAVES.len()).unwrap().data.is_empty(), + mmr.get_delta(LEAVES.len(), mmr.forest()).unwrap().data.is_empty(), "There are no updates for the same Mmr version" ); // missing the last element added, which is itself a tree peak - assert_eq!(mmr.get_delta(6).unwrap().data, vec![acc.peaks()[2]], "one peak"); + assert_eq!(mmr.get_delta(6, mmr.forest()).unwrap().data, vec![acc.peaks()[2]], "one peak"); // missing the sibling to complete the tree of depth 2, and the last element assert_eq!( - mmr.get_delta(5).unwrap().data, + mmr.get_delta(5, mmr.forest()).unwrap().data, vec![LEAVES[5], acc.peaks()[2]], "one sibling, one peak" ); // missing the whole last two trees, only send the peaks assert_eq!( - mmr.get_delta(4).unwrap().data, + mmr.get_delta(4, mmr.forest()).unwrap().data, vec![acc.peaks()[1], acc.peaks()[2]], "two peaks" ); // missing the sibling to complete the first tree, and the two last trees assert_eq!( - mmr.get_delta(3).unwrap().data, + mmr.get_delta(3, mmr.forest()).unwrap().data, vec![LEAVES[3], acc.peaks()[1], acc.peaks()[2]], "one sibling, two peaks" ); @@ -617,24 +696,66 @@ fn test_mmr_updates() { // missing half of the first tree, only send the computed element (not the leaves), and the new // peaks assert_eq!( - mmr.get_delta(2).unwrap().data, + mmr.get_delta(2, mmr.forest()).unwrap().data, vec![mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]], "one sibling, two peaks" ); assert_eq!( - mmr.get_delta(1).unwrap().data, + mmr.get_delta(1, mmr.forest()).unwrap().data, vec![LEAVES[1], mmr.nodes[5], acc.peaks()[1], acc.peaks()[2]], "one sibling, two peaks" ); - assert_eq!(&mmr.get_delta(0).unwrap().data, acc.peaks(), "all peaks"); + assert_eq!(&mmr.get_delta(0, mmr.forest()).unwrap().data, acc.peaks(), "all peaks"); +} + +#[test] +fn test_mmr_delta_old_forest() { + let mmr: Mmr = LEAVES.into(); + + // from_forest must be smaller-or-equal to to_forest + for version in 1..=mmr.forest() { + assert!(mmr.get_delta(version + 1, version).is_err()); + } + + // when from_forest and to_forest are equal, there are no updates + for version in 1..=mmr.forest() { + let delta = mmr.get_delta(version, version).unwrap(); + assert!(delta.data.is_empty()); + assert_eq!(delta.forest, version); + } + + // test update which merges the odd peak to the right + for count in 0..(mmr.forest() / 2) { + // *2 because every iteration tests a pair + // +1 because the Mmr is 1-indexed + let from_forest = (count * 2) + 1; + let to_forest = (count * 2) + 2; + let delta = mmr.get_delta(from_forest, to_forest).unwrap(); + + // *2 because every iteration tests a pair + // +1 because sibling is the odd element + let sibling = (count * 2) + 1; + assert_eq!(delta.data, [LEAVES[sibling]]); + assert_eq!(delta.forest, to_forest); + } + + let version = 4; + let delta = mmr.get_delta(1, version).unwrap(); + assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5]]); + assert_eq!(delta.forest, version); + + let version = 5; + let delta = mmr.get_delta(1, version).unwrap(); + assert_eq!(delta.data, [mmr.nodes[1], mmr.nodes[5], mmr.nodes[7]]); + assert_eq!(delta.forest, version); } #[test] fn test_partial_mmr_simple() { let mmr: Mmr = LEAVES.into(); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); let mut partial: PartialMmr = acc.clone().into(); // check initial state of the partial mmr @@ -645,7 +766,7 @@ fn test_partial_mmr_simple() { assert_eq!(partial.nodes.len(), 0); // check state after adding tracking one element - let proof1 = mmr.open(0).unwrap(); + let proof1 = mmr.open(0, mmr.forest()).unwrap(); let el1 = mmr.get(proof1.position).unwrap(); partial.add(proof1.position, el1, &proof1.merkle_path).unwrap(); @@ -657,7 +778,7 @@ fn test_partial_mmr_simple() { let idx = idx.parent(); assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]); - let proof2 = mmr.open(1).unwrap(); + let proof2 = mmr.open(1, mmr.forest()).unwrap(); let el2 = mmr.get(proof2.position).unwrap(); partial.add(proof2.position, el2, &proof2.merkle_path).unwrap(); @@ -675,21 +796,21 @@ fn test_partial_mmr_update_single() { let mut full = Mmr::new(); let zero = int_to_node(0); full.add(zero); - let mut partial: PartialMmr = full.accumulator().into(); + let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into(); - let proof = full.open(0).unwrap(); + let proof = full.open(0, full.forest()).unwrap(); partial.add(proof.position, zero, &proof.merkle_path).unwrap(); for i in 1..100 { let node = int_to_node(i); full.add(node); - let delta = full.get_delta(partial.forest()).unwrap(); + let delta = full.get_delta(partial.forest(), full.forest()).unwrap(); partial.apply(delta).unwrap(); assert_eq!(partial.forest(), full.forest()); - assert_eq!(partial.peaks(), full.accumulator().peaks()); + assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap().peaks()); - let proof1 = full.open(i as usize).unwrap(); + let proof1 = full.open(i as usize, full.forest()).unwrap(); partial.add(proof1.position, node, &proof1.merkle_path).unwrap(); let proof2 = partial.open(proof1.position).unwrap().unwrap(); assert_eq!(proof1.merkle_path, proof2.merkle_path); @@ -699,7 +820,7 @@ fn test_partial_mmr_update_single() { #[test] fn test_mmr_add_invalid_odd_leaf() { let mmr: Mmr = LEAVES.into(); - let acc = mmr.accumulator(); + let acc = mmr.peaks(mmr.forest()).unwrap(); let mut partial: PartialMmr = acc.clone().into(); let empty = MerklePath::new(Vec::new()); diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index 4e55cb7b..9c6c3c77 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -31,7 +31,7 @@ mod tiered_smt; pub use tiered_smt::{TieredSmt, TieredSmtProof, TieredSmtProofError}; mod mmr; -pub use mmr::{InOrderIndex, Mmr, MmrError, MmrPeaks, MmrProof, PartialMmr}; +pub use mmr::{InOrderIndex, Mmr, MmrDelta, MmrError, MmrPeaks, MmrProof, PartialMmr}; mod store; pub use store::{DefaultMerkleStore, MerkleStore, RecordingMerkleStore, StoreNode}; diff --git a/src/merkle/node.rs b/src/merkle/node.rs index 4305e7f7..bf18d386 100644 --- a/src/merkle/node.rs +++ b/src/merkle/node.rs @@ -1,4 +1,4 @@ -use crate::hash::rpo::RpoDigest; +use super::RpoDigest; /// Representation of a node with two children used for iterating over containers. #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/src/merkle/partial_mt/mod.rs b/src/merkle/partial_mt/mod.rs index c29b6e8b..35af8cd7 100644 --- a/src/merkle/partial_mt/mod.rs +++ b/src/merkle/partial_mt/mod.rs @@ -109,9 +109,9 @@ impl PartialMerkleTree { // check if the number of leaves can be accommodated by the tree's depth; we use a min // depth of 63 because we consider passing in a vector of size 2^64 infeasible. - let max = (1_u64 << 63) as usize; + let max = 2usize.pow(63); if layers.len() > max { - return Err(MerkleError::InvalidNumEntries(max, layers.len())); + return Err(MerkleError::InvalidNumEntries(max)); } // Get maximum depth diff --git a/src/merkle/path.rs b/src/merkle/path.rs index 745931c9..16a092f7 100644 --- a/src/merkle/path.rs +++ b/src/merkle/path.rs @@ -1,5 +1,6 @@ use super::{vec, InnerNodeInfo, MerkleError, NodeIndex, Rpo256, RpoDigest, Vec}; use core::ops::{Deref, DerefMut}; +use winter_utils::{ByteReader, Deserializable, DeserializationError, Serializable}; // MERKLE PATH // ================================================================================================ @@ -17,6 +18,7 @@ impl MerklePath { /// Creates a new Merkle path from a list of nodes. pub fn new(nodes: Vec) -> Self { + assert!(nodes.len() <= u8::MAX.into(), "MerklePath may have at most 256 items"); Self { nodes } } @@ -189,6 +191,54 @@ pub struct RootPath { pub path: MerklePath, } +// SERILIZATION +// ================================================================================================ +impl Serializable for MerklePath { + fn write_into(&self, target: &mut W) { + assert!(self.nodes.len() <= u8::MAX.into(), "Length enforced in the construtor"); + target.write_u8(self.nodes.len() as u8); + self.nodes.write_into(target); + } +} + +impl Deserializable for MerklePath { + fn read_from(source: &mut R) -> Result { + let count = source.read_u8()?.into(); + let nodes = RpoDigest::read_batch_from(source, count)?; + Ok(Self { nodes }) + } +} + +impl Serializable for ValuePath { + fn write_into(&self, target: &mut W) { + self.value.write_into(target); + self.path.write_into(target); + } +} + +impl Deserializable for ValuePath { + fn read_from(source: &mut R) -> Result { + let value = RpoDigest::read_from(source)?; + let path = MerklePath::read_from(source)?; + Ok(Self { value, path }) + } +} + +impl Serializable for RootPath { + fn write_into(&self, target: &mut W) { + self.root.write_into(target); + self.path.write_into(target); + } +} + +impl Deserializable for RootPath { + fn read_from(source: &mut R) -> Result { + let root = RpoDigest::read_from(source)?; + let path = MerklePath::read_from(source)?; + Ok(Self { root, path }) + } +} + // TESTS // ================================================================================================ diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index a9511223..4ce94a41 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -19,7 +19,6 @@ pub struct SimpleSmt { root: RpoDigest, leaves: BTreeMap, branches: BTreeMap, - empty_hashes: Vec, } impl SimpleSmt { @@ -52,13 +51,11 @@ impl SimpleSmt { return Err(MerkleError::DepthTooBig(depth as u64)); } - let empty_hashes = EmptySubtreeRoots::empty_hashes(depth).to_vec(); - let root = empty_hashes[0]; + let root = *EmptySubtreeRoots::entry(depth, 0); Ok(Self { root, depth, - empty_hashes, leaves: BTreeMap::new(), branches: BTreeMap::new(), }) @@ -74,39 +71,54 @@ impl SimpleSmt { /// - If the depth is 0 or is greater than 64. /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}. /// - The provided entries contain multiple values for the same key. - pub fn with_leaves(depth: u8, entries: R) -> Result - where - R: IntoIterator, - I: Iterator + ExactSizeIterator, - { + pub fn with_leaves( + depth: u8, + entries: impl IntoIterator, + ) -> Result { // create an empty tree let mut tree = Self::new(depth)?; - // check if the number of leaves can be accommodated by the tree's depth; we use a min - // depth of 63 because we consider passing in a vector of size 2^64 infeasible. - let entries = entries.into_iter(); - let max = 1 << tree.depth.min(63); - if entries.len() > max { - return Err(MerkleError::InvalidNumEntries(max, entries.len())); - } + // compute the max number of entries. We use an upper bound of depth 63 because we consider + // passing in a vector of size 2^64 infeasible. + let max_num_entries = 2_usize.pow(tree.depth.min(63).into()); - // append leaves to the tree returning an error if a duplicate entry for the same key - // is found - let mut empty_entries = BTreeSet::new(); - for (key, value) in entries { - let old_value = tree.update_leaf(key, value)?; - if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) { - return Err(MerkleError::DuplicateValuesForIndex(key)); + // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so + // entries with the empty value need additional tracking. + let mut key_set_to_zero = BTreeSet::new(); + + for (idx, (key, value)) in entries.into_iter().enumerate() { + if idx >= max_num_entries { + return Err(MerkleError::InvalidNumEntries(max_num_entries)); } - // if we've processed an empty entry, add the key to the set of empty entry keys, and - // if this key was already in the set, return an error - if value == Self::EMPTY_VALUE && !empty_entries.insert(key) { + + let old_value = tree.update_leaf(key, value)?; + + if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) { return Err(MerkleError::DuplicateValuesForIndex(key)); } + + if value == Self::EMPTY_VALUE { + key_set_to_zero.insert(key); + }; } Ok(tree) } + /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices + /// starting at index 0. + pub fn with_contiguous_leaves( + depth: u8, + entries: impl IntoIterator, + ) -> Result { + Self::with_leaves( + depth, + entries + .into_iter() + .enumerate() + .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)), + ) + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -133,10 +145,12 @@ impl SimpleSmt { } else if index.depth() == self.depth() { // the lookup in empty_hashes could fail only if empty_hashes were not built correctly // by the constructor as we check the depth of the lookup above. - Ok(RpoDigest::from( - self.get_leaf_node(index.value()) - .unwrap_or_else(|| *self.empty_hashes[index.depth() as usize]), - )) + let leaf_pos = index.value(); + let leaf = match self.get_leaf_node(leaf_pos) { + Some(word) => word.into(), + None => *EmptySubtreeRoots::entry(self.depth, index.depth()), + }; + Ok(leaf) } else { Ok(self.get_branch_node(&index).parent()) } @@ -214,6 +228,9 @@ impl SimpleSmt { /// # Errors /// Returns an error if the index is greater than the maximum tree capacity, that is 2^{depth}. pub fn update_leaf(&mut self, index: u64, value: Word) -> Result { + // validate the index before modifying the structure + let idx = NodeIndex::new(self.depth(), index)?; + let old_value = self.insert_leaf_node(index, value).unwrap_or(Self::EMPTY_VALUE); // if the old value and new value are the same, there is nothing to update @@ -221,8 +238,82 @@ impl SimpleSmt { return Ok(value); } - let mut index = NodeIndex::new(self.depth(), index)?; - let mut value = RpoDigest::from(value); + self.recompute_nodes_from_index_to_root(idx, RpoDigest::from(value)); + + Ok(old_value) + } + + /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is + /// computed as `self.depth() - subtree.depth()`. + /// + /// Returns the new root. + pub fn set_subtree( + &mut self, + subtree_insertion_index: u64, + subtree: SimpleSmt, + ) -> Result { + if subtree.depth() > self.depth() { + return Err(MerkleError::InvalidSubtreeDepth { + subtree_depth: subtree.depth(), + tree_depth: self.depth(), + }); + } + + // Verify that `subtree_insertion_index` is valid. + let subtree_root_insertion_depth = self.depth() - subtree.depth(); + let subtree_root_index = + NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?; + + // add leaves + // -------------- + + // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we + // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are + // valid as they are. However, consider what happens when we insert at + // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`; + // you can see it as there's a full subtree sitting on its left. In general, for + // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want + // to insert, so we need to adjust all its leaves by `i * 2^d`. + let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(subtree.depth().into()); + for (subtree_leaf_idx, leaf_value) in subtree.leaves() { + let new_leaf_idx = leaf_index_shift + subtree_leaf_idx; + debug_assert!(new_leaf_idx < 2_u64.pow(self.depth().into())); + + self.insert_leaf_node(new_leaf_idx, *leaf_value); + } + + // add subtree's branch nodes (which includes the root) + // -------------- + for (branch_idx, branch_node) in subtree.branches { + let new_branch_idx = { + let new_depth = subtree_root_insertion_depth + branch_idx.depth(); + let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into()) + + branch_idx.value(); + + NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid") + }; + + self.branches.insert(new_branch_idx, branch_node); + } + + // recompute nodes starting from subtree root + // -------------- + self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root); + + Ok(self.root) + } + + // HELPER METHODS + // -------------------------------------------------------------------------------------------- + + /// Recomputes the branch nodes (including the root) from `index` all the way to the root. + /// `node_hash_at_index` is the hash of the node stored at index. + fn recompute_nodes_from_index_to_root( + &mut self, + mut index: NodeIndex, + node_hash_at_index: RpoDigest, + ) { + let mut value = node_hash_at_index; for _ in 0..index.depth() { let is_right = index.is_value_odd(); index.move_up(); @@ -232,12 +323,8 @@ impl SimpleSmt { value = Rpo256::merge(&[left, right]); } self.root = value; - Ok(old_value) } - // HELPER METHODS - // -------------------------------------------------------------------------------------------- - fn get_leaf_node(&self, key: u64) -> Option { self.leaves.get(&key).copied() } @@ -248,8 +335,8 @@ impl SimpleSmt { fn get_branch_node(&self, index: &NodeIndex) -> BranchNode { self.branches.get(index).cloned().unwrap_or_else(|| { - let node = self.empty_hashes[index.depth() as usize + 1]; - BranchNode { left: node, right: node } + let node = EmptySubtreeRoots::entry(self.depth, index.depth() + 1); + BranchNode { left: *node, right: *node } }) } diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index f2c55d19..3d270be7 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -3,7 +3,7 @@ use super::{ NodeIndex, Rpo256, Vec, }; use crate::{ - merkle::{digests_to_words, int_to_leaf, int_to_node}, + merkle::{digests_to_words, int_to_leaf, int_to_node, EmptySubtreeRoots}, Word, }; @@ -71,6 +71,21 @@ fn build_sparse_tree() { assert_eq!(old_value, EMPTY_WORD); } +/// Tests that [`SimpleSmt::with_contiguous_leaves`] works as expected +#[test] +fn build_contiguous_tree() { + let tree_with_leaves = SimpleSmt::with_leaves( + 2, + [0, 1, 2, 3].into_iter().zip(digests_to_words(&VALUES4).into_iter()), + ) + .unwrap(); + + let tree_with_contiguous_leaves = + SimpleSmt::with_contiguous_leaves(2, digests_to_words(&VALUES4).into_iter()).unwrap(); + + assert_eq!(tree_with_leaves, tree_with_contiguous_leaves); +} + #[test] fn test_depth2_tree() { let tree = @@ -214,22 +229,31 @@ fn small_tree_opening_is_consistent() { } #[test] -fn fail_on_duplicates() { - let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(3))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); - - let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); +fn test_simplesmt_fail_on_duplicates() { + let values = [ + // same key, same value + (int_to_leaf(1), int_to_leaf(1)), + // same key, different values + (int_to_leaf(1), int_to_leaf(2)), + // same key, set to zero + (EMPTY_WORD, int_to_leaf(1)), + // same key, re-set to zero + (int_to_leaf(1), EMPTY_WORD), + // same key, set to zero twice + (EMPTY_WORD, EMPTY_WORD), + ]; - let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(1))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); + for (first, second) in values.iter() { + // consecutive + let entries = [(1, *first), (1, *second)]; + let smt = SimpleSmt::with_leaves(64, entries); + assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1)); - let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); + // not consecutive + let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)]; + let smt = SimpleSmt::with_leaves(64, entries); + assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1)); + } } #[test] @@ -239,6 +263,227 @@ fn with_no_duplicates_empty_node() { assert!(smt.is_ok()); } +#[test] +fn test_simplesmt_update_nonexisting_leaf_with_zero() { + // TESTING WITH EMPTY WORD + // -------------------------------------------------------------------------------------------- + + // Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist. + let mut smt = SimpleSmt::new(1).unwrap(); + let result = smt.update_leaf(2, EMPTY_WORD); + assert!(!smt.leaves.contains_key(&2)); + assert!(result.is_err()); + + // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist. + let mut smt = SimpleSmt::new(2).unwrap(); + let result = smt.update_leaf(4, EMPTY_WORD); + assert!(!smt.leaves.contains_key(&4)); + assert!(result.is_err()); + + // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist. + let mut smt = SimpleSmt::new(3).unwrap(); + let result = smt.update_leaf(8, EMPTY_WORD); + assert!(!smt.leaves.contains_key(&8)); + assert!(result.is_err()); + + // TESTING WITH A VALUE + // -------------------------------------------------------------------------------------------- + let value = int_to_node(1); + + // Depth 1 has 2 leaves. Position is 0-indexed, position 1 doesn't exist. + let mut smt = SimpleSmt::new(1).unwrap(); + let result = smt.update_leaf(2, *value); + assert!(!smt.leaves.contains_key(&2)); + assert!(result.is_err()); + + // Depth 2 has 4 leaves. Position is 0-indexed, position 2 doesn't exist. + let mut smt = SimpleSmt::new(2).unwrap(); + let result = smt.update_leaf(4, *value); + assert!(!smt.leaves.contains_key(&4)); + assert!(result.is_err()); + + // Depth 3 has 8 leaves. Position is 0-indexed, position 4 doesn't exist. + let mut smt = SimpleSmt::new(3).unwrap(); + let result = smt.update_leaf(8, *value); + assert!(!smt.leaves.contains_key(&8)); + assert!(result.is_err()); +} + +#[test] +fn test_simplesmt_with_leaves_nonexisting_leaf() { + // TESTING WITH EMPTY WORD + // -------------------------------------------------------------------------------------------- + + // Depth 1 has 2 leaf. Position is 0-indexed, position 2 doesn't exist. + let leaves = [(2, EMPTY_WORD)]; + let result = SimpleSmt::with_leaves(1, leaves); + assert!(result.is_err()); + + // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist. + let leaves = [(4, EMPTY_WORD)]; + let result = SimpleSmt::with_leaves(2, leaves); + assert!(result.is_err()); + + // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist. + let leaves = [(8, EMPTY_WORD)]; + let result = SimpleSmt::with_leaves(3, leaves); + assert!(result.is_err()); + + // TESTING WITH A VALUE + // -------------------------------------------------------------------------------------------- + let value = int_to_node(1); + + // Depth 1 has 2 leaves. Position is 0-indexed, position 2 doesn't exist. + let leaves = [(2, *value)]; + let result = SimpleSmt::with_leaves(1, leaves); + assert!(result.is_err()); + + // Depth 2 has 4 leaves. Position is 0-indexed, position 4 doesn't exist. + let leaves = [(4, *value)]; + let result = SimpleSmt::with_leaves(2, leaves); + assert!(result.is_err()); + + // Depth 3 has 8 leaves. Position is 0-indexed, position 8 doesn't exist. + let leaves = [(8, *value)]; + let result = SimpleSmt::with_leaves(3, leaves); + assert!(result.is_err()); +} + +#[test] +fn test_simplesmt_set_subtree() { + // Final Tree: + // + // ____k____ + // / \ + // _i_ _j_ + // / \ / \ + // e f g h + // / \ / \ / \ / \ + // a b 0 0 c 0 0 d + + let z = EMPTY_WORD; + + let a = Word::from(Rpo256::merge(&[z.into(); 2])); + let b = Word::from(Rpo256::merge(&[a.into(); 2])); + let c = Word::from(Rpo256::merge(&[b.into(); 2])); + let d = Word::from(Rpo256::merge(&[c.into(); 2])); + + let e = Rpo256::merge(&[a.into(), b.into()]); + let f = Rpo256::merge(&[z.into(), z.into()]); + let g = Rpo256::merge(&[c.into(), z.into()]); + let h = Rpo256::merge(&[z.into(), d.into()]); + + let i = Rpo256::merge(&[e, f]); + let j = Rpo256::merge(&[g, h]); + + let k = Rpo256::merge(&[i, j]); + + // subtree: + // g + // / \ + // c 0 + let subtree = { + let depth = 1; + let entries = vec![(0, c)]; + SimpleSmt::with_leaves(depth, entries).unwrap() + }; + + // insert subtree + let tree = { + let depth = 3; + let entries = vec![(0, a), (1, b), (7, d)]; + let mut tree = SimpleSmt::with_leaves(depth, entries).unwrap(); + + tree.set_subtree(2, subtree).unwrap(); + + tree + }; + + assert_eq!(tree.root(), k); + assert_eq!(tree.get_leaf(4).unwrap(), c); + assert_eq!(tree.get_branch_node(&NodeIndex::new_unchecked(2, 2)).parent(), g); +} + +/// Ensures that an invalid input node index into `set_subtree()` incurs no mutation of the tree +#[test] +fn test_simplesmt_set_subtree_unchanged_for_wrong_index() { + // Final Tree: + // + // ____k____ + // / \ + // _i_ _j_ + // / \ / \ + // e f g h + // / \ / \ / \ / \ + // a b 0 0 c 0 0 d + + let z = EMPTY_WORD; + + let a = Word::from(Rpo256::merge(&[z.into(); 2])); + let b = Word::from(Rpo256::merge(&[a.into(); 2])); + let c = Word::from(Rpo256::merge(&[b.into(); 2])); + let d = Word::from(Rpo256::merge(&[c.into(); 2])); + + // subtree: + // g + // / \ + // c 0 + let subtree = { + let depth = 1; + let entries = vec![(0, c)]; + SimpleSmt::with_leaves(depth, entries).unwrap() + }; + + let mut tree = { + let depth = 3; + let entries = vec![(0, a), (1, b), (7, d)]; + SimpleSmt::with_leaves(depth, entries).unwrap() + }; + let tree_root_before_insertion = tree.root(); + + // insert subtree + assert!(tree.set_subtree(500, subtree).is_err()); + + assert_eq!(tree.root(), tree_root_before_insertion); +} + +/// We insert an empty subtree that has the same depth as the original tree +#[test] +fn test_simplesmt_set_subtree_entire_tree() { + // Initial Tree: + // + // ____k____ + // / \ + // _i_ _j_ + // / \ / \ + // e f g h + // / \ / \ / \ / \ + // a b 0 0 c 0 0 d + + let z = EMPTY_WORD; + + let a = Word::from(Rpo256::merge(&[z.into(); 2])); + let b = Word::from(Rpo256::merge(&[a.into(); 2])); + let c = Word::from(Rpo256::merge(&[b.into(); 2])); + let d = Word::from(Rpo256::merge(&[c.into(); 2])); + + let depth = 3; + + // subtree: E3 + let subtree = { SimpleSmt::with_leaves(depth, Vec::new()).unwrap() }; + assert_eq!(subtree.root(), *EmptySubtreeRoots::entry(depth, 0)); + + // insert subtree + let mut tree = { + let entries = vec![(0, a), (1, b), (4, c), (7, d)]; + SimpleSmt::with_leaves(depth, entries).unwrap() + }; + + tree.set_subtree(0, subtree).unwrap(); + + assert_eq!(tree.root(), *EmptySubtreeRoots::entry(depth, 0)); +} + // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/store/tests.rs b/src/merkle/store/tests.rs index dc32ffd6..70bab7a9 100644 --- a/src/merkle/store/tests.rs +++ b/src/merkle/store/tests.rs @@ -1,9 +1,8 @@ use super::{ DefaultMerkleStore as MerkleStore, EmptySubtreeRoots, MerkleError, MerklePath, NodeIndex, - PartialMerkleTree, RecordingMerkleStore, RpoDigest, + PartialMerkleTree, RecordingMerkleStore, Rpo256, RpoDigest, }; use crate::{ - hash::rpo::Rpo256, merkle::{digests_to_words, int_to_leaf, int_to_node, MerkleTree, SimpleSmt}, Felt, Word, ONE, WORD_SIZE, ZERO, }; diff --git a/src/merkle/tiered_smt/proof.rs b/src/merkle/tiered_smt/proof.rs index 28ac2880..1de17d52 100644 --- a/src/merkle/tiered_smt/proof.rs +++ b/src/merkle/tiered_smt/proof.rs @@ -85,18 +85,26 @@ impl TieredSmtProof { /// Note: this method cannot be used to assert non-membership. That is, if false is returned, /// it does not mean that the provided key-value pair is not in the tree. pub fn verify_membership(&self, key: &RpoDigest, value: &Word, root: &RpoDigest) -> bool { - if self.is_value_empty() { - if value != &EMPTY_VALUE { - return false; - } - // if the proof is for an empty value, we can verify it against any key which has a - // common prefix with the key storied in entries, but the prefix must be greater than - // the path length - let common_prefix_tier = get_common_prefix_tier_depth(key, &self.entries[0].0); - if common_prefix_tier < self.path.depth() { - return false; - } - } else if !self.entries.contains(&(*key, *value)) { + // Handles the following scenarios: + // - the value is set + // - empty leaf, there is an explicit entry for the key with the empty value + // - shared 64-bit prefix, the target key is not included in the entries list, the value is implicitly the empty word + let v = match self.entries.iter().find(|(k, _)| k == key) { + Some((_, v)) => v, + None => &EMPTY_VALUE, + }; + + // The value must match for the proof to be valid + if v != value { + return false; + } + + // If the proof is for an empty value, we can verify it against any key which has a common + // prefix with the key storied in entries, but the prefix must be greater than the path + // length + if self.is_value_empty() + && get_common_prefix_tier_depth(key, &self.entries[0].0) < self.path.depth() + { return false; } diff --git a/src/merkle/tiered_smt/tests.rs b/src/merkle/tiered_smt/tests.rs index 560db47c..788ba288 100644 --- a/src/merkle/tiered_smt/tests.rs +++ b/src/merkle/tiered_smt/tests.rs @@ -715,6 +715,38 @@ fn tsmt_bottom_tier_two() { // GET PROOF TESTS // ================================================================================================ +/// Tests the membership and non-membership proof for a single at depth 64 +#[test] +fn tsmt_get_proof_single_element_64() { + let mut smt = TieredSmt::default(); + + let raw_a = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000001_u64; + let key_a = [ONE, ONE, ONE, raw_a.into()].into(); + let value_a = [ONE, ONE, ONE, ONE]; + smt.insert(key_a, value_a); + + // push element `a` to depth 64, by inserting another value that shares the 48-bit prefix + let raw_b = 0b_00000000_00000001_00000000_00000001_00000000_00000001_00000000_00000000_u64; + let key_b = [ONE, ONE, ONE, raw_b.into()].into(); + smt.insert(key_b, [ONE, ONE, ONE, ONE]); + + // verify the proof for element `a` + let proof = smt.prove(key_a); + assert!(proof.verify_membership(&key_a, &value_a, &smt.root())); + + // check that a value that is not inserted in the tree produces a valid membership proof for the + // empty word + let key = [ZERO, ZERO, ZERO, ZERO].into(); + let proof = smt.prove(key); + assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root())); + + // check that a key that shared the 64-bit prefix with `a`, but is not inserted, also has a + // valid membership proof for the empty word + let key = [ONE, ONE, ZERO, raw_a.into()].into(); + let proof = smt.prove(key); + assert!(proof.verify_membership(&key, &EMPTY_WORD, &smt.root())); +} + #[test] fn tsmt_get_proof() { let mut smt = TieredSmt::default();