From 4f562cc51167e72f58637ec21c51d7a2c0c581af Mon Sep 17 00:00:00 2001 From: Jaehun Kim Date: Fri, 19 Jul 2024 19:53:02 +0900 Subject: [PATCH] Removed all oop-related implementations, Implement hash_chain --- Cargo.toml | 4 +- src/channel/channel_states.rs | 26 +++++ src/channel/fs_prover_channel.rs | 13 --- src/channel/mod.rs | 76 ++++--------- src/lib.rs | 3 +- src/randomness/hash_chain.rs | 176 +++++++++++++++++++++++++++++++ src/randomness/mod.rs | 1 + 7 files changed, 230 insertions(+), 69 deletions(-) create mode 100644 src/randomness/hash_chain.rs create mode 100644 src/randomness/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 7f5cdd9..1de26e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,6 @@ edition = "2021" anyhow = "1.0.86" ark-ff = "0.4.2" ark-poly = "0.4.2" -rand_chacha = "0.3.1" +sha3 = "0.10.8" +ethnum = "1.5.0" +lazy_static = "1.5.0" diff --git a/src/channel/channel_states.rs b/src/channel/channel_states.rs index 7a9b8df..50841c8 100644 --- a/src/channel/channel_states.rs +++ b/src/channel/channel_states.rs @@ -10,6 +10,32 @@ pub struct ChannelStates { pub data_count: usize, } +impl ChannelStates { + pub fn increment_byte_count(&mut self, n: usize) { + self.byte_count += n; + } + + pub fn increment_commitment_count(&mut self) { + self.commitment_count += 1; + } + + fn is_query_phase(&self) -> bool { + self.is_query_phase + } + + fn begin_query_phase(&mut self) { + self.is_query_phase = true; + } + + fn increment_hash_count(&mut self) { + self.hash_count += 1; + } + + fn increment_field_element_count(&mut self, n: usize) { + self.field_element_count += n; + } +} + impl fmt::Display for ChannelStates { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( diff --git a/src/channel/fs_prover_channel.rs b/src/channel/fs_prover_channel.rs index 0024d28..932c962 100644 --- a/src/channel/fs_prover_channel.rs +++ b/src/channel/fs_prover_channel.rs @@ -24,19 +24,6 @@ impl FSProverChannel { } } -impl AsMut for FSProverChannel { - fn 
as_mut(&mut self) -> &mut ChannelStates { - &mut self.states - } -} - -impl AsRef for FSProverChannel { - fn as_ref(&self) -> &ChannelStates { - &self.states - } -} - - impl Channel for FSProverChannel { type Field = F; diff --git a/src/channel/mod.rs b/src/channel/mod.rs index a1d4404..7a58e9c 100644 --- a/src/channel/mod.rs +++ b/src/channel/mod.rs @@ -8,11 +8,10 @@ pub mod tests; use ark_ff::Field; use channel_states::ChannelStates; -use std::convert::{AsMut, AsRef}; use crate::hashutil::TempHashContainer; #[allow(dead_code)] -trait Channel: AsRef + AsMut { +trait Channel { type Field: Field; fn recv_felem(&mut self, felem: Self::Field) -> Result; @@ -22,71 +21,40 @@ trait Channel: AsRef + AsMut { fn random_number(&mut self, bound: u64) -> u64; fn random_field(&mut self) -> Self::Field; +} - /// Only relevant for non-interactive channels. Changes the channel seed to a "safer" seed. - /// - /// This function guarantees that randomness fetched from the channel after calling this function - /// and before sending data from the prover to the channel, is "safe": A malicious - /// prover will have to perform 2^security_bits operations for each attempt to randomize the fetched - /// randomness. - /// - /// Increases the amount of work a malicious prover needs to perform, in order to fake a proof. 
- #[allow(unused_variables)] - fn apply_proof_of_work(&mut self, security_bits: usize) -> Result<(), anyhow::Error> { - Err(anyhow::Error::msg("Not a fs-channel")) - } +trait FSChannel: Channel { + fn apply_proof_of_work(&mut self, security_bits: usize) -> Result<(), anyhow::Error>; fn is_end_of_proof(&self) -> bool; - - fn is_query_phase(&self) -> bool { - AsRef::::as_ref(self).is_query_phase - } - - fn begin_query_phase(&mut self) { - AsMut::::as_mut(self).is_query_phase = true; - } - - fn increment_byte_count(&mut self, n: usize) { - AsMut::::as_mut(self).byte_count += n; - } - - fn increment_commitment_count(&mut self) { - AsMut::::as_mut(self).commitment_count += 1; - } - - fn increment_hash_count(&mut self) { - AsMut::::as_mut(self).hash_count += 1; - } - - fn increment_field_element_count(&mut self, n: usize) { - AsMut::::as_mut(self).field_element_count += n; - } } -trait VerifierChannel: Channel { +trait VerifierChannel: FSChannel { type HashT: TempHashContainer; - fn recv_commit_hash(&mut self) -> Result { - let bytes = self.recv_bytes(Self::HashT::size())?; - let mut hash = Self::HashT::init_empty(); - hash.update(&bytes); - self.increment_commitment_count(); - self.increment_hash_count(); - Ok(hash) - } + fn recv_commit_hash(&mut self) -> Result; + // { + // let bytes = self.recv_bytes(Self::HashT::size())?; + // let mut hash = Self::HashT::init_empty(); + // hash.update(&bytes); + // self.increment_commitment_count(); + // self.increment_hash_count(); + // Ok(hash) + // } } -trait ProverChannel: Channel { +trait ProverChannel: FSChannel { type HashT: TempHashContainer; fn send_felts(&mut self, felts: Vec) -> Result<(), anyhow::Error>; fn send_bytes(&mut self, bytes: Vec) -> Result<(), anyhow::Error>; - fn send_commit_hash(&mut self, hash: Self::HashT) -> Result<(), anyhow::Error> { - self.send_bytes(hash.hash())?; - self.increment_commitment_count(); - self.increment_hash_count(); - Ok(()) - } + fn send_commit_hash(&mut self, hash: Self::HashT) -> 
Result<(), anyhow::Error>; + // { + // self.send_bytes(hash.hash())?; + // self.increment_commitment_count(); + // self.increment_hash_count(); + // Ok(()) + // } } diff --git a/src/lib.rs b/src/lib.rs index 44be8b1..7d0ddb1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod felt252; -pub mod channel; +// pub mod channel; mod hashutil; mod merkle; +pub mod randomness; pub use felt252::Felt252; diff --git a/src/randomness/hash_chain.rs b/src/randomness/hash_chain.rs new file mode 100644 index 0000000..71084a1 --- /dev/null +++ b/src/randomness/hash_chain.rs @@ -0,0 +1,176 @@ +use std::ops::Add; + +use ethnum::U256; +use sha3::{Digest, Keccak256}; + +const KECCAK256_DIGEST_NUM_BYTES: usize = 32; + +pub struct HashChain { + digest: [u8; KECCAK256_DIGEST_NUM_BYTES], + spare_bytes: [u8; KECCAK256_DIGEST_NUM_BYTES * 2], + num_spare_bytes: usize, + counter: u64, +} + +impl Default for HashChain { + fn default() -> Self { + Self { + digest: [0u8; KECCAK256_DIGEST_NUM_BYTES], + spare_bytes: [0u8; KECCAK256_DIGEST_NUM_BYTES * 2], + num_spare_bytes: 0, + counter: 0, + } + } +} + +impl HashChain { + pub fn new(digest: &[u8; KECCAK256_DIGEST_NUM_BYTES]) -> Self { + Self { + digest: *digest, + spare_bytes: [0u8; KECCAK256_DIGEST_NUM_BYTES * 2], + num_spare_bytes: 0, + counter: 0, + } + } + + pub fn random_bytes(&mut self, random_bytes_out: &mut [u8]) { + let num_bytes = random_bytes_out.len(); + let num_full_blocks = num_bytes / KECCAK256_DIGEST_NUM_BYTES; + + for offset in + (0..num_full_blocks * KECCAK256_DIGEST_NUM_BYTES).step_by(KECCAK256_DIGEST_NUM_BYTES) + { + self.fill_random_bytes( + &mut random_bytes_out[offset..offset + KECCAK256_DIGEST_NUM_BYTES], + ); + } + + // If there are any bytes left, copy them from the spare bytes, otherwise get more random bytes + let num_tail_bytes = num_bytes % KECCAK256_DIGEST_NUM_BYTES; + if num_tail_bytes <= self.num_spare_bytes { + random_bytes_out[num_full_blocks * KECCAK256_DIGEST_NUM_BYTES..num_bytes] + 
.copy_from_slice(&self.spare_bytes[..num_tail_bytes]); + self.num_spare_bytes -= num_tail_bytes; + + // Shift the spare bytes to the left to remove the bytes we just copied + self.spare_bytes.copy_within(num_tail_bytes.., 0); + } else { + self.fill_random_bytes( + &mut random_bytes_out[num_full_blocks * KECCAK256_DIGEST_NUM_BYTES..num_bytes], + ); + } + } + + fn fill_random_bytes(&mut self, out: &mut [u8]) { + let num_bytes: usize = out.len(); + assert!( + num_bytes <= KECCAK256_DIGEST_NUM_BYTES, + "Asked to get more bytes than one digest size" + ); + + let prandom_bytes = self.next_hash(); + out.copy_from_slice(&prandom_bytes[..num_bytes]); + + assert!( + self.num_spare_bytes < KECCAK256_DIGEST_NUM_BYTES + num_bytes, + "Not enough room in spare bytes buffer. Have {} bytes and want to add {} bytes", + self.num_spare_bytes, + KECCAK256_DIGEST_NUM_BYTES - num_bytes + ); + + self.spare_bytes + [self.num_spare_bytes..self.num_spare_bytes + (KECCAK256_DIGEST_NUM_BYTES - num_bytes)] + .copy_from_slice(&prandom_bytes[num_bytes..]); + self.num_spare_bytes += KECCAK256_DIGEST_NUM_BYTES - num_bytes; + self.counter += 1; + } + + fn next_hash(&self) -> [u8; KECCAK256_DIGEST_NUM_BYTES] { + // TODO: below code is not efficient, but it works for now + let mut hasher = Keccak256::new(); + hasher.update(&self.digest); + hasher.update(&self.counter.to_le_bytes()); + let result = hasher.finalize(); + let mut hash_bytes = [0u8; KECCAK256_DIGEST_NUM_BYTES]; + hash_bytes.copy_from_slice(&result); + hash_bytes + } + + pub fn update_hash_chain(&mut self, raw_bytes: &[u8]) { + let seed_increment: u64 = 0; + self.mix_seed_with_bytes(raw_bytes, seed_increment); + } + + pub fn mix_seed_with_bytes(&mut self, raw_bytes: &[u8], seed_increment: u64) { + let mut mixed_bytes = vec![0u8; KECCAK256_DIGEST_NUM_BYTES + raw_bytes.len()]; + + // Deserialize the current digest into a u64 array + let big_int = U256::from_be_bytes(self.digest).add(U256::from(seed_increment)); + + // Serialize the 
incremented big_int back into the mixed_bytes + mixed_bytes[..KECCAK256_DIGEST_NUM_BYTES].copy_from_slice(&big_int.to_be_bytes()); + + // Copy the raw_bytes into the mixed_bytes + mixed_bytes[KECCAK256_DIGEST_NUM_BYTES..].copy_from_slice(raw_bytes); + + // Hash the mixed_bytes to update the digest + let mut hasher = Keccak256::new(); + hasher.update(&mixed_bytes); + let result = hasher.finalize(); + self.digest.copy_from_slice(&result); + + self.num_spare_bytes = 0; + self.counter = 0; + } + + pub fn get_hash_chain_state(&self) -> &[u8; KECCAK256_DIGEST_NUM_BYTES] { + &self.digest + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const RANDOM_BYTES_1ST_KECCAK256: [u8; 8] = [0x07, 0x7C, 0xE2, 0x30, 0x83, 0x44, 0x67, 0xE7]; + const RANDOM_BYTES_1000TH_KECCAK256: [u8; 8] = [0xD1, 0x74, 0x78, 0xD2, 0x31, 0xC2, 0xAF, 0x63]; + const RANDOM_BYTES_1001ST_KECCAK256: [u8; 8] = [0xA0, 0xDA, 0xBD, 0x71, 0xEE, 0xAB, 0x82, 0xAC]; + + lazy_static::lazy_static! { + static ref RANDOM_BYTES_KECCAK256: std::collections::HashMap> = { + let mut m = std::collections::HashMap::new(); + m.insert(1, RANDOM_BYTES_1ST_KECCAK256.to_vec()); + m.insert(1000, RANDOM_BYTES_1000TH_KECCAK256.to_vec()); + m.insert(1001, RANDOM_BYTES_1001ST_KECCAK256.to_vec()); + m + }; + + static ref EXPECTED_RANDOM_BYTE_VECTORS: std::collections::HashMap>> = { + let mut m = std::collections::HashMap::new(); + m.insert(1,RANDOM_BYTES_KECCAK256.clone()); + m + }; + } + + #[test] + fn test_hash_chain_get_randoms() { + let mut bytes_1: [u8; 8] = [0u8; 8]; + let mut bytes_2 = [0u8; 8]; + + let mut hash_ch_1 = HashChain::new(&[0u8; 32]); + let mut hash_ch_2 = HashChain::new(&[0u8; 32]); + let stat1 = hash_ch_1.get_hash_chain_state().clone(); + hash_ch_1.random_bytes(&mut bytes_1); + hash_ch_2.random_bytes(&mut bytes_2); + + for _ in 0..1000 { + hash_ch_1.random_bytes(&mut bytes_1); + hash_ch_2.random_bytes(&mut bytes_2); + } + + assert_eq!(stat1, *hash_ch_1.get_hash_chain_state()); + assert_eq!(stat1, 
*hash_ch_2.get_hash_chain_state()); + assert_eq!(bytes_1, bytes_2); + } + +} diff --git a/src/randomness/mod.rs b/src/randomness/mod.rs new file mode 100644 index 0000000..330136e --- /dev/null +++ b/src/randomness/mod.rs @@ -0,0 +1 @@ +pub mod hash_chain;