diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 178af5b680..f68eb45629 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -30,6 +30,11 @@ snarkVM is a big project, so (non-)adherence to best practices related to perfor
 - if possible, reuse collections; an example would be a loop that needs a clean vector on each iteration: instead of creating and allocating it over and over, create it _before_ the loop and use `.clear()` on every iteration instead
 - try to keep the sizes of `enum` variants uniform; use `Box` on ones that are large
 
+### Cross-platform consistency
+- First and foremost, types that contain consensus-critical or cryptographic logic should have a consistent size across platforms, and their serialized output should not contain `usize`. For defense in depth, we serialize `usize` as `u64`.
+- For clarity, use `u32` and `u64` as much and as long as possible, especially in type definitions.
+- Given that we only target 32- and 64-bit systems, casting `usize` to `u64` and casting `u32` to `usize` is always safe and does not require `try_from`. In serialization code, using `try_from` is still encouraged for defense in depth (see the sketch below).
+
 ### Misc. performance
 - avoid the `format!()` macro; if it is used only to convert a single value to a `String`, use `.to_string()` instead, which is also available to all the implementors of `Display`
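As a quick illustration of the casting rules above (a standalone sketch written for this guide, not code from the diff below):

```rust
fn main() {
    let n: usize = 1024;

    // On the 32- and 64-bit targets snarkVM supports, widening `usize` to
    // `u64` (and `u32` to `usize`) can never truncate, so a plain `as` cast
    // is safe in ordinary logic.
    let widened = n as u64;
    let index = 7u32 as usize;

    // In serialization code, prefer the checked conversion anyway, as a
    // defense-in-depth measure: a failure becomes an error instead of a
    // silently wrapped value.
    let checked = u64::try_from(n).expect("usize -> u64 cannot fail on 32-/64-bit targets");
    assert_eq!(widened, checked);
    assert_eq!(index, 7);
}
```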
diff --git a/algorithms/examples/msm.rs b/algorithms/examples/msm.rs
index c3abbff5ab..5eab47ebae 100644
--- a/algorithms/examples/msm.rs
+++ b/algorithms/examples/msm.rs
@@ -90,8 +90,8 @@ pub fn main() -> Result<()> {
     // Parse the variant.
     match args[1].as_str() {
-        "batched" => batched::msm(bases.as_slice(), scalars.as_slice()),
-        "standard" => standard::msm(bases.as_slice(), scalars.as_slice()),
+        "batched" => batched::msm(bases.as_slice(), scalars.as_slice())?,
+        "standard" => standard::msm(bases.as_slice(), scalars.as_slice())?,
         _ => panic!("Invalid variant: use 'batched' or 'standard'"),
     };
diff --git a/algorithms/src/crypto_hash/poseidon.rs b/algorithms/src/crypto_hash/poseidon.rs
index 1d67be957a..5131110da9 100644
--- a/algorithms/src/crypto_hash/poseidon.rs
+++ b/algorithms/src/crypto_hash/poseidon.rs
@@ -322,7 +322,7 @@ impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> PoseidonSponge<F, RATE, CAPACITY> {
         let capacity = F::size_in_bits() - 1;
         let mut dest_limbs = Vec::<F>::new();
-        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);
+        let params = get_params(TargetField::size_in_bits_u32(), F::size_in_bits_u32(), ty);
 
         let adjustment_factor_lookup_table = {
             let mut table = Vec::<F>::new();
@@ -342,9 +342,9 @@ impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> PoseidonSponge<F, RATE, CAPACITY> {
             let first = &src_limbs[i];
             let second = if i + 1 < src_len { Some(&src_limbs[i + 1]) } else { None };
 
-            let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one());
+            let first_max_bits_per_limb = params.bits_per_limb as usize + crate::overhead!(first.1 + F::one());
             let second_max_bits_per_limb = if let Some(second) = second {
-                params.bits_per_limb + crate::overhead!(second.1 + F::one())
+                params.bits_per_limb as usize + crate::overhead!(second.1 + F::one())
             } else {
                 0
             };
@@ -382,18 +382,19 @@ impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> PoseidonSponge<F, RATE, CAPACITY> {
         elem: &<TargetField as PrimeField>::BigInteger,
         optimization_type: OptimizationType,
     ) -> SmallVec<[F; 10]> {
-        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);
+        let params = get_params(TargetField::size_in_bits_u32(), F::size_in_bits_u32(), optimization_type);
 
         // Push the lower limbs first
         let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
         let mut cur = *elem;
         for _ in 0..params.num_limbs {
             let cur_bits = cur.to_bits_be(); // `to_bits` is big endian
-            let cur_mod_r =
-                <TargetField as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
-                    .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
+            let cur_mod_r = <TargetField as PrimeField>::BigInteger::from_bits_be(
+                &cur_bits[cur_bits.len() - params.bits_per_limb as usize..],
+            )
+            .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
             limbs.push(F::from_bigint(cur_mod_r).unwrap());
-            cur.divn(params.bits_per_limb as u32);
+            cur.divn(params.bits_per_limb);
         }
 
         // then we reserve, so that the limbs are ``big limb first''
diff --git a/algorithms/src/errors.rs b/algorithms/src/errors.rs
index 5efb2da697..d217e59069 100644
--- a/algorithms/src/errors.rs
+++ b/algorithms/src/errors.rs
@@ -20,30 +20,36 @@ pub enum SNARKError {
     #[error("{}", _0)]
     AnyhowError(#[from] anyhow::Error),
 
+    #[error("Batch size was different between public input and proof")]
+    BatchSizeMismatch,
+
+    #[error("Circuit not found")]
+    CircuitNotFound,
+
     #[error("{}", _0)]
     ConstraintFieldError(#[from] ConstraintFieldError),
 
     #[error("{}: {}", _0, _1)]
     Crate(&'static str, String),
 
+    #[error("Batch size was zero; must be at least 1")]
+    EmptyBatch,
+
     #[error("Expected a circuit-specific SRS in SNARK")]
     ExpectedCircuitSpecificSRS,
 
+    #[error(transparent)]
+    IntError(#[from] std::num::TryFromIntError),
+
     #[error("{}", _0)]
     Message(String),
 
+    #[error(transparent)]
+    ParseIntError(#[from] std::num::ParseIntError),
+
     #[error("{}", _0)]
     SynthesisError(SynthesisError),
 
-    #[error("Batch size was zero; must be at least 1")]
-    EmptyBatch,
-
-    #[error("Batch size was different between public input and proof")]
-    BatchSizeMismatch,
-
-    #[error("Circuit not found")]
-    CircuitNotFound,
-
     #[error("terminated")]
     Terminated,
 }
diff --git a/algorithms/src/fft/domain.rs b/algorithms/src/fft/domain.rs
index d2d4cbb49b..c73d920022 100644
--- a/algorithms/src/fft/domain.rs
+++ b/algorithms/src/fft/domain.rs
@@ -79,7 +79,7 @@ const MIN_PARALLEL_CHUNK_SIZE: usize = 1 << 7;
 #[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
 pub struct EvaluationDomain<F: PrimeField> {
     /// The size of the domain.
-    pub size: u64,
+    pub size: usize,
     /// `log_2(self.size)`.
     pub log_size_of_group: u32,
     /// Size of the domain as a field element.
@@ -114,7 +114,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
     /// having `num_coeffs` coefficients.
     pub fn new(num_coeffs: usize) -> Option<Self> {
         // Compute the size of our evaluation domain
-        let size = num_coeffs.checked_next_power_of_two()? as u64;
+        let size = num_coeffs.checked_next_power_of_two()?;
         let log_size_of_group = size.trailing_zeros();
 
         // libfqfft uses > https://github.com/scipr-lab/libfqfft/blob/e0183b2cef7d4c5deb21a6eaf3fe3b586d738fe0/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc#L33
@@ -124,12 +124,13 @@ impl<F: PrimeField> EvaluationDomain<F> {
 
         // Compute the generator for the multiplicative subgroup.
         // It should be the 2^(log_size_of_group) root of unity.
-        let group_gen = F::get_root_of_unity(size as usize)?;
+        let group_gen = F::get_root_of_unity(size)?;
+        let size_u64 = size as u64;
 
         // Check that it is indeed the 2^(log_size_of_group) root of unity.
-        debug_assert_eq!(group_gen.pow([size]), F::one());
+        debug_assert_eq!(group_gen.pow([size_u64]), F::one());
 
-        let size_as_field_element = F::from(size);
+        let size_as_field_element = F::from(size_u64);
         let size_inv = size_as_field_element.inverse()?;
 
         Some(EvaluationDomain {
@@ -152,7 +153,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
 
     /// Return the size of `self`.
     pub fn size(&self) -> usize {
-        self.size as usize
+        self.size
     }
 
     /// Compute an FFT.
@@ -254,8 +255,8 @@ impl<F: PrimeField> EvaluationDomain<F> {
     /// `tau`.
     pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
         // Evaluate all Lagrange polynomials
-        let size = self.size as usize;
-        let t_size = tau.pow([self.size]);
+        let size = self.size();
+        let t_size = tau.pow([self.size as u64]);
         let one = F::one();
         if t_size.is_one() {
             let mut u = vec![F::zero(); size];
@@ -297,7 +298,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
     /// This evaluates the vanishing polynomial for this domain at tau.
     /// For multiplicative subgroups, this polynomial is `z(X) = X^self.size - 1`.
     pub fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
-        tau.pow([self.size]) - F::one()
+        tau.pow([self.size as u64]) - F::one()
     }
 
     /// Return an iterator over the elements of the domain.
@@ -373,7 +374,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
         // SNP TODO: how to set threshold and check that the type is Fr
         if self.size >= 32 && std::mem::size_of::<F>() == 32 {
             let result = snarkvm_algorithms_cuda::NTT(
-                self.size as usize,
+                self.size(),
                 x_s,
                 snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                 snarkvm_algorithms_cuda::NTTDirection::Forward,
@@ -402,7 +403,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
         // SNP TODO: how to set threshold
         if self.size >= 32 && std::mem::size_of::<F>() == 32 {
             let result = snarkvm_algorithms_cuda::NTT(
-                self.size as usize,
+                self.size(),
                 x_s,
                 snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                 snarkvm_algorithms_cuda::NTTDirection::Inverse,
@@ -423,7 +424,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
         // SNP TODO: how to set threshold
         if self.size >= 32 && std::mem::size_of::<F>() == 32 {
             let result = snarkvm_algorithms_cuda::NTT(
-                self.size as usize,
+                self.size(),
                 x_s,
                 snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                 snarkvm_algorithms_cuda::NTTDirection::Inverse,
@@ -450,7 +451,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
         // SNP TODO: how to set threshold
         if self.size >= 32 && std::mem::size_of::<F>() == 32 {
             let result = snarkvm_algorithms_cuda::NTT(
-                self.size as usize,
+                self.size(),
                 x_s,
                 snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                 snarkvm_algorithms_cuda::NTTDirection::Forward,
@@ -481,7 +482,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
         // SNP TODO: how to set threshold
         if self.size >= 32 && std::mem::size_of::<F>() == 32 {
             let result = snarkvm_algorithms_cuda::NTT(
-                self.size as usize,
+                self.size(),
                 x_s,
                 snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                 snarkvm_algorithms_cuda::NTTDirection::Inverse,
@@ -515,7 +516,7 @@ impl<F: PrimeField> EvaluationDomain<F> {
         // SNP TODO: how to set threshold
         if self.size >= 32 && std::mem::size_of::<F>() == 32 {
             let result = snarkvm_algorithms_cuda::NTT(
-                self.size as usize,
+                self.size(),
                 x_s,
                 snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
                 snarkvm_algorithms_cuda::NTTDirection::Inverse,
@@ -583,17 +584,17 @@ impl<F: PrimeField> EvaluationDomain<F> {
     // [1, g, g^2, ..., g^{(n/2) - 1}]
     #[cfg(feature = "serial")]
     pub fn roots_of_unity(&self, root: F) -> Vec<F> {
-        compute_powers_serial((self.size as usize) / 2, root)
+        compute_powers_serial((self.size()) / 2, root)
     }
 
     /// Computes the first `self.size / 2` roots of unity.
     #[cfg(not(feature = "serial"))]
     pub fn roots_of_unity(&self, root: F) -> Vec<F> {
         // TODO: check if this method can replace parallel compute powers.
-        let log_size = log2(self.size as usize);
+        let log_size = log2(self.size());
         // early exit for short inputs
         if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
-            compute_powers_serial((self.size as usize) / 2, root)
+            compute_powers_serial((self.size()) / 2, root)
         } else {
             let mut temp = root;
             // w, w^2, w^4, w^8, ..., w^(2^(log_size - 1))
@@ -783,8 +784,8 @@ const MIN_GAP_SIZE_FOR_PARALLELISATION: usize = 1 << 10;
 const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: u32 = 7;
 
 #[inline]
-pub(super) fn bitrev(a: u64, log_len: u32) -> u64 {
-    a.reverse_bits() >> (64 - log_len)
+pub(super) fn bitrev(a: usize, log_len: usize) -> usize {
+    a.reverse_bits() >> (std::mem::size_of::<usize>() * 8 - log_len)
 }
 
 pub(crate) fn derange<T>(xi: &mut [T]) {
@@ -792,10 +793,10 @@
 }
 
 fn derange_helper<T>(xi: &mut [T], log_len: u32) {
-    for idx in 1..(xi.len() as u64 - 1) {
-        let ridx = bitrev(idx, log_len);
+    for idx in 1..(xi.len() - 1) {
+        let ridx = bitrev(idx, log_len as usize);
         if idx < ridx {
-            xi.swap(idx as usize, ridx as usize);
+            xi.swap(idx, ridx);
         }
     }
 }
@@ -864,7 +865,7 @@ impl<F: PrimeField> Iterator for Elements<F> {
     type Item = F;
 
     fn next(&mut self) -> Option<F> {
-        if self.cur_pow == self.domain.size {
+        if self.cur_pow == self.domain.size as u64 {
             None
         } else {
             let cur_elem = self.cur_elem;
diff --git a/algorithms/src/fft/polynomial/sparse.rs b/algorithms/src/fft/polynomial/sparse.rs
index 23b2b1f0df..a3e9dddde1 100644
--- a/algorithms/src/fft/polynomial/sparse.rs
+++ b/algorithms/src/fft/polynomial/sparse.rs
@@ -16,12 +16,11 @@
 
 use crate::fft::{EvaluationDomain, Evaluations, Polynomial};
 use snarkvm_fields::{Field, PrimeField};
-use snarkvm_utilities::serialize::*;
 
 use std::{collections::BTreeMap, fmt};
 
 /// Stores a sparse polynomial in coefficient form.
-#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
+#[derive(Clone, PartialEq, Eq, Hash, Default)]
 #[must_use]
 pub struct SparsePolynomial<F: Field> {
     /// The coefficient a_i of `x^i` is stored as (i, a_i) in `self.coeffs`.
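A quick standalone check (not part of the diff) of the width-generic `bitrev` above: shifting the fully reversed word right by `size_of::<usize>() * 8 - log_len` keeps exactly the reversed low `log_len` bits, matching the old `u64`-only behavior on both 32- and 64-bit targets.

```rust
fn bitrev(a: usize, log_len: usize) -> usize {
    // Reverse the whole word, then drop everything but the top `log_len`
    // bits of the reversed value (i.e. the reversed low `log_len` bits).
    a.reverse_bits() >> (std::mem::size_of::<usize>() * 8 - log_len)
}

fn main() {
    // In a domain of size 8 (log_len = 3): 0b001 <-> 0b100, 0b110 <-> 0b011.
    assert_eq!(bitrev(0b001, 3), 0b100);
    assert_eq!(bitrev(0b110, 3), 0b011);
    // bitrev is an involution, which is what `derange` relies on when
    // swapping only the pairs with `idx < ridx`.
    assert_eq!(bitrev(bitrev(5, 3), 3), 5);
}
```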
diff --git a/algorithms/src/lib.rs b/algorithms/src/lib.rs
index 440e1c6ead..b3ea5878e2 100644
--- a/algorithms/src/lib.rs
+++ b/algorithms/src/lib.rs
@@ -15,6 +15,7 @@
 #![allow(clippy::module_inception)]
 #![allow(clippy::type_complexity)]
 #![cfg_attr(test, allow(clippy::assertions_on_result_states))]
+#![warn(clippy::cast_possible_truncation)]
 
 #[cfg(feature = "wasm")]
 #[macro_use]
diff --git a/algorithms/src/msm/fixed_base.rs b/algorithms/src/msm/fixed_base.rs
index 5b167a6d96..81264f5d32 100644
--- a/algorithms/src/msm/fixed_base.rs
+++ b/algorithms/src/msm/fixed_base.rs
@@ -22,7 +22,7 @@ use rayon::prelude::*;
 pub struct FixedBase;
 
 impl FixedBase {
-    pub fn get_mul_window_size(num_scalars: usize) -> usize {
+    pub fn get_mul_window_size(num_scalars: usize) -> u32 {
         match num_scalars < 32 {
             true => 3,
             false => super::ln_without_floats(num_scalars),
diff --git a/algorithms/src/msm/mod.rs b/algorithms/src/msm/mod.rs
index 97fe46858e..b7a1a84b02 100644
--- a/algorithms/src/msm/mod.rs
+++ b/algorithms/src/msm/mod.rs
@@ -25,7 +25,7 @@ pub use variable_base::*;
 /// [`Explanation of usage`]
 ///
 /// [`Explanation of usage`]: https://github.com/scipr-lab/zexe/issues/79#issue-556220473
-fn ln_without_floats(a: usize) -> usize {
+fn ln_without_floats(a: usize) -> u32 {
     // log2(a) * ln(2)
-    (crate::fft::domain::log2(a) * 69 / 100) as usize
+    crate::fft::domain::log2(a) * 69 / 100
 }
diff --git a/algorithms/src/msm/tests.rs b/algorithms/src/msm/tests.rs
index 8a2cbe5c08..62137b4f12 100644
--- a/algorithms/src/msm/tests.rs
+++ b/algorithms/src/msm/tests.rs
@@ -45,7 +45,7 @@ fn variable_base_test_with_bls12() {
     let g = (0..SAMPLES).map(|_| G1Projective::rand(&mut rng).to_affine()).collect::<Vec<_>>();
 
     let naive = naive_variable_base_msm(g.as_slice(), v.as_slice());
-    let fast = VariableBase::msm(g.as_slice(), v.as_slice());
+    let fast = VariableBase::msm(g.as_slice(), v.as_slice()).unwrap();
 
     assert_eq!(naive.to_affine(), fast.to_affine());
 }
@@ -60,7 +60,7 @@ fn variable_base_test_with_bls12_unequal_numbers() {
     let g = (0..SAMPLES).map(|_| G1Projective::rand(&mut rng).to_affine()).collect::<Vec<_>>();
 
     let naive = naive_variable_base_msm(g.as_slice(), v.as_slice());
-    let fast = VariableBase::msm(g.as_slice(), v.as_slice());
+    let fast = VariableBase::msm(g.as_slice(), v.as_slice()).unwrap();
 
     assert_eq!(naive.to_affine(), fast.to_affine());
 }
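For intuition on the `ln_without_floats` change above: since ln(2) ≈ 0.69, `log2(a) * 69 / 100` is an integer approximation of ln(a), and returning `u32` now matches the `u32` output of `log2`. A standalone sketch (the local `log2` here is an assumed stand-in for `crate::fft::domain::log2`):

```rust
/// Ceiling log2, standing in for `crate::fft::domain::log2`.
fn log2(a: usize) -> u32 {
    if a == 0 { 0 } else { usize::BITS - (a - 1).leading_zeros() }
}

/// ln(a) ≈ log2(a) * ln(2), with ln(2) approximated by 69/100.
fn ln_without_floats(a: usize) -> u32 {
    log2(a) * 69 / 100
}

fn main() {
    // ln(1024) ≈ 6.93; the approximation yields 10 * 69 / 100 = 6.
    assert_eq!(ln_without_floats(1024), 6);
    // The MSM window size is then derived as `c = ln_without_floats(n) + 2`.
    assert_eq!(ln_without_floats(1 << 20) + 2, 15);
}
```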
diff --git a/algorithms/src/msm/variable_base/batched.rs b/algorithms/src/msm/variable_base/batched.rs
index 9acf7ace67..4f46e5315d 100644
--- a/algorithms/src/msm/variable_base/batched.rs
+++ b/algorithms/src/msm/variable_base/batched.rs
@@ -171,10 +171,10 @@ fn batch_add_write(
 
 #[inline]
 pub(super) fn batch_add<G: AffineCurve>(
-    num_buckets: usize,
+    num_buckets: u32,
     bases: &[G],
     bucket_positions: &mut [BucketPosition],
-) -> Vec<G> {
+) -> Result<Vec<G>, anyhow::Error> {
     // assert_eq!(bases.len(), bucket_positions.len());
     assert!(!bases.is_empty());
@@ -202,7 +202,7 @@ pub(super) fn batch_add<G: AffineCurve>(
             global_counter += 1;
             local_counter += 1;
         }
-        if current_bucket >= num_buckets as u32 {
+        if current_bucket >= num_buckets {
             local_counter = 1;
         } else if local_counter > 1 {
             // all ones is false if next len is not 1
@@ -216,13 +216,17 @@ pub(super) fn batch_add<G: AffineCurve>(
                     bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                     bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                 ));
-                bucket_positions[new_scalar_length + i] =
-                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
+                bucket_positions[new_scalar_length + i] = BucketPosition {
+                    bucket_index: current_bucket,
+                    scalar_index: u32::try_from(new_scalar_length + i)?,
+                };
             }
             if is_odd {
                 instr.push((bucket_positions[global_counter].scalar_index, !0u32));
-                bucket_positions[new_scalar_length + half] =
-                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
+                bucket_positions[new_scalar_length + half] = BucketPosition {
+                    bucket_index: current_bucket,
+                    scalar_index: u32::try_from(new_scalar_length + half)?,
+                };
             }
             // Reset the local_counter and update state
             new_scalar_length += half + (local_counter % 2);
@@ -241,7 +245,7 @@ pub(super) fn batch_add<G: AffineCurve>(
         } else {
             instr.push((bucket_positions[global_counter].scalar_index, !0u32));
             bucket_positions[new_scalar_length] =
-                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
+                BucketPosition { bucket_index: current_bucket, scalar_index: u32::try_from(new_scalar_length)? };
             new_scalar_length += 1;
         }
         global_counter += 1;
@@ -267,7 +271,7 @@ pub(super) fn batch_add<G: AffineCurve>(
             global_counter += 1;
             local_counter += 1;
         }
-        if current_bucket >= num_buckets as u32 {
+        if current_bucket >= num_buckets {
             local_counter = 1;
         } else if local_counter > 1 {
             // all ones is false if next len is not 1
@@ -315,20 +319,20 @@ pub(super) fn batch_add<G: AffineCurve>(
         new_scalar_length = 0;
     }
 
-    let mut res = vec![Zero::zero(); num_buckets];
+    let mut res = vec![Zero::zero(); num_buckets as usize];
     for bucket_position in bucket_positions.iter().take(num_scalars) {
         res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
     }
-    res
+    Ok(res)
 }
 
 #[inline]
 fn batched_window<G: AffineCurve>(
     bases: &[G],
     scalars: &[<G::ScalarField as PrimeField>::BigInteger],
-    w_start: usize,
-    c: usize,
-) -> (G::Projective, usize) {
+    w_start: u32,
+    c: u32,
+) -> Result<(G::Projective, u32), anyhow::Error> {
     // We don't need the "zero" bucket, so we only have 2^c - 1 buckets
     let window_size = if (w_start % c) != 0 { w_start % c } else { c };
     let num_buckets = (1 << window_size) - 1;
@@ -340,16 +344,19 @@ fn batched_window<G: AffineCurve>(
             let mut scalar = scalar;
 
             // We right-shift by w_start, thus getting rid of the lower bits.
-            scalar.divn(w_start as u32);
+            scalar.divn(w_start);
 
             // We mod the remaining bits by the window size.
-            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;
+            let scalar = i32::try_from(scalar.as_ref()[0] % (1 << c))?;
 
-            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
+            Ok::<_, anyhow::Error>(BucketPosition {
+                bucket_index: (scalar - 1) as u32,
+                scalar_index: u32::try_from(scalar_index)?,
+            })
         })
-        .collect();
+        .collect::<Result<Vec<_>, _>>()?;
 
-    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);
+    let buckets = batch_add(num_buckets, bases, &mut bucket_positions)?;
 
     let mut res = G::Projective::zero();
     let mut running_sum = G::Projective::zero();
@@ -358,10 +365,13 @@ fn batched_window<G: AffineCurve>(
         res += &running_sum;
     }
 
-    (res, window_size)
+    Ok((res, window_size))
 }
 
-pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
+pub fn msm<G: AffineCurve>(
+    bases: &[G],
+    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
+) -> Result<G::Projective, anyhow::Error> {
     if bases.len() < 15 {
         let num_bits = G::ScalarField::size_in_bits();
         let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
@@ -382,7 +392,7 @@ pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeFiel
         false => crate::msm::ln_without_floats(scalars.len()) + 2,
     };
 
-    let num_bits = <G::ScalarField as PrimeField>::size_in_bits();
+    let num_bits = <G::ScalarField as PrimeField>::size_in_bits_u32();
 
     // Each window is of size `c`.
     // We divide up the bits 0..num_bits into windows of size `c`, and
     // in parallel process each such window.
-    let window_sums: Vec<_> =
-        cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();
+    let window_sums: Vec<_> = cfg_into_iter!(0..num_bits)
+        .step_by(c as usize)
+        .map(|w_start| batched_window(bases, scalars, w_start, c))
+        .collect::<Result<Vec<_>, _>>()?;
 
     // We store the sum for the lowest window.
     let (lowest, window_sums) = window_sums.split_first().unwrap();
 
     // We're traversing windows from high to low.
-    window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
+    Ok(window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
         total += sum_i;
         for _ in 0..*window_size {
             total.double_in_place();
         }
         total
-    }) + lowest.0
+    }) + lowest.0)
     }
 }
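The new `Result`-returning `msm` above collects the per-window `Result`s (so a failed `try_from` aborts the whole MSM) and then recombines window sums from high to low. The recombination shape can be sanity-checked with plain integers (a standalone sketch; `u64` values stand in for group elements and `*= 2` for `double_in_place`):

```rust
/// Recombines per-window sums (lowest window first, each paired with its
/// bit width) by folding from the highest window down, doubling
/// `window_size` times between additions — the same shape as the MSM fold.
fn recombine(windows: &[(u64, u32)]) -> u64 {
    let (lowest, rest) = windows.split_first().unwrap();
    rest.iter().rev().fold(0u64, |mut total, (sum_i, window_size)| {
        total += sum_i;
        for _ in 0..*window_size {
            total *= 2; // stand-in for `double_in_place`
        }
        total
    }) + lowest.0
}

fn main() {
    // 54 = 0b11_01_10 split into 2-bit windows, lowest digit first:
    // digits 0b10, 0b01, 0b11 recombine to (0b11 << 4) + (0b01 << 2) + 0b10.
    assert_eq!(recombine(&[(0b10, 2), (0b01, 2), (0b11, 2)]), 54);
}
```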
diff --git a/algorithms/src/msm/variable_base/mod.rs b/algorithms/src/msm/variable_base/mod.rs
index 11d4910bf3..6f68897625 100644
--- a/algorithms/src/msm/variable_base/mod.rs
+++ b/algorithms/src/msm/variable_base/mod.rs
@@ -26,7 +26,10 @@ use core::any::TypeId;
 pub struct VariableBase;
 
 impl VariableBase {
-    pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
+    pub fn msm<G: AffineCurve>(
+        bases: &[G],
+        scalars: &[<G::ScalarField as PrimeField>::BigInteger],
+    ) -> Result<G::Projective, anyhow::Error> {
         // For BLS12-377, we perform variable base MSM using a batched addition technique.
         if TypeId::of::<G>() == TypeId::of::<G1Affine>() {
             #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
@@ -35,8 +38,9 @@ impl VariableBase {
                 let result = snarkvm_algorithms_cuda::msm::<G, <G::ScalarField as PrimeField>::BigInteger>(
                     bases, scalars,
                 );
+                // Remove any cuda::Error
                 if let Ok(result) = result {
-                    return result;
+                    return Ok(result);
                 }
             }
             batched::msm(bases, scalars)
@@ -97,10 +101,10 @@ mod tests {
             let naive_b = VariableBase::msm_naive_parallel(bases.as_slice(), scalars.as_slice()).to_affine();
             assert_eq!(naive_a, naive_b, "MSM size: {msm_size}");
 
-            let candidate = standard::msm(bases.as_slice(), scalars.as_slice()).to_affine();
+            let candidate = standard::msm(bases.as_slice(), scalars.as_slice()).unwrap().to_affine();
             assert_eq!(naive_a, candidate, "MSM size: {msm_size}");
 
-            let candidate = batched::msm(bases.as_slice(), scalars.as_slice()).to_affine();
+            let candidate = batched::msm(bases.as_slice(), scalars.as_slice()).unwrap().to_affine();
             assert_eq!(naive_a, candidate, "MSM size: {msm_size}");
         }
     }
@@ -111,8 +115,8 @@ mod tests {
         let mut rng = TestRng::default();
         for i in 2..17 {
             let (bases, scalars) = create_scalar_bases::<G1Affine, Fr>(&mut rng, 1 << i);
-            let rust = standard::msm(bases.as_slice(), scalars.as_slice());
-            let cuda = VariableBase::msm::<G1Affine>(bases.as_slice(), scalars.as_slice());
+            let rust = standard::msm(bases.as_slice(), scalars.as_slice()).unwrap();
+            let cuda = VariableBase::msm::<G1Affine>(bases.as_slice(), scalars.as_slice()).unwrap();
             assert_eq!(rust.to_affine(), cuda.to_affine());
         }
     }
diff --git a/algorithms/src/msm/variable_base/standard.rs b/algorithms/src/msm/variable_base/standard.rs
index ddba995b4f..565eb51e3d 100644
--- a/algorithms/src/msm/variable_base/standard.rs
+++ b/algorithms/src/msm/variable_base/standard.rs
@@ -22,12 +22,12 @@ use rayon::prelude::*;
 fn update_buckets<G: AffineCurve>(
     base: &G,
     mut scalar: <G::ScalarField as PrimeField>::BigInteger,
-    w_start: usize,
-    c: usize,
+    w_start: u32,
+    c: u32,
     buckets: &mut [G::Projective],
-) {
+) -> Result<(), anyhow::Error> {
     // We right-shift by w_start, thus getting rid of the lower bits.
-    scalar.divn(w_start as u32);
+    scalar.divn(w_start);
 
     // We mod the remaining bits by the window size.
     let scalar = scalar.as_ref()[0] % (1 << c);
 
     // If the scalar is non-zero, we update the corresponding bucket.
     // (Recall that `buckets` doesn't have a zero bucket.)
     if scalar != 0 {
-        buckets[(scalar - 1) as usize].add_assign_mixed(base);
+        buckets[usize::try_from(scalar - 1)?].add_assign_mixed(base);
     }
+    Ok(())
 }
 
 fn standard_window<G: AffineCurve>(
     bases: &[G],
     scalars: &[<G::ScalarField as PrimeField>::BigInteger],
-    w_start: usize,
-    c: usize,
-) -> (G::Projective, usize) {
+    w_start: u32,
+    c: u32,
+) -> Result<(G::Projective, u32), anyhow::Error> {
     let mut res = G::Projective::zero();
     let fr_one = G::ScalarField::one().to_bigint();
 
@@ -62,7 +63,7 @@ fn standard_window<G: AffineCurve>(
         .iter()
         .zip(bases)
         .filter(|(&s, _)| s > fr_one)
-        .for_each(|(&scalar, base)| update_buckets(base, scalar, w_start, c, &mut buckets));
+        .try_for_each(|(&scalar, base)| update_buckets(base, scalar, w_start, c, &mut buckets))?;
     // G::Projective::batch_normalization(&mut buckets);
 
     for running_sum in buckets.into_iter().rev().scan(G::Projective::zero(), |sum, b| {
@@ -72,33 +73,38 @@ fn standard_window<G: AffineCurve>(
         res += running_sum;
     }
 
-    (res, window_size)
+    Ok((res, window_size))
 }
 
-pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
+pub fn msm<G: AffineCurve>(
+    bases: &[G],
+    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
+) -> Result<G::Projective, anyhow::Error> {
     // Determine the bucket size `c` (chosen empirically).
     let c = match scalars.len() < 32 {
         true => 1,
         false => crate::msm::ln_without_floats(scalars.len()) + 2,
     };
 
-    let num_bits = <G::ScalarField as PrimeField>::size_in_bits();
+    let num_bits = <G::ScalarField as PrimeField>::size_in_bits_u32();
 
     // Each window is of size `c`.
     // We divide up the bits 0..num_bits into windows of size `c`, and
     // in parallel process each such window.
-    let window_sums: Vec<_> =
-        cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| standard_window(bases, scalars, w_start, c)).collect();
+    let window_sums: Vec<_> = cfg_into_iter!(0..num_bits)
+        .step_by(c as usize)
+        .map(|w_start| standard_window(bases, scalars, w_start, c))
+        .collect::<Result<Vec<_>, _>>()?;
 
     // We store the sum for the lowest window.
     let (lowest, window_sums) = window_sums.split_first().unwrap();
 
     // We're traversing windows from high to low.
-    window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
+    Ok(window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
         total += sum_i;
         for _ in 0..*window_size {
             total.double_in_place();
         }
         total
-    }) + lowest.0
+    }) + lowest.0)
 }
diff --git a/algorithms/src/polycommit/error.rs b/algorithms/src/polycommit/error.rs
index 6113ca9ffb..4c886733c1 100644
--- a/algorithms/src/polycommit/error.rs
+++ b/algorithms/src/polycommit/error.rs
@@ -92,6 +92,9 @@ pub enum PCError {
         label: String,
     },
 
+    /// Could not convert from int.
+    IntError(std::num::TryFromIntError),
+
     Terminated,
 }
 
@@ -103,6 +106,12 @@
     }
 }
 
+impl From<std::num::TryFromIntError> for PCError {
+    fn from(other: std::num::TryFromIntError) -> Self {
+        Self::IntError(other)
+    }
+}
+
 impl core::fmt::Display for PCError {
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
         match self {
@@ -148,6 +157,7 @@ impl core::fmt::Display for PCError {
                 (having degree {poly_degree:?}) is greater than the maximum \
                 supported degree ({supported_degree:?})"
             ),
+            Self::IntError(error) => write!(f, "{error}"),
             Self::Terminated => write!(f, "terminated"),
         }
     }
 }
diff --git a/algorithms/src/polycommit/kzg10/data_structures.rs b/algorithms/src/polycommit/kzg10/data_structures.rs
index ac35c6383a..b0059db6b5 100644
--- a/algorithms/src/polycommit/kzg10/data_structures.rs
+++ b/algorithms/src/polycommit/kzg10/data_structures.rs
@@ -24,6 +24,7 @@ use snarkvm_utilities::{
     error,
     io::{Read, Write},
     serialize::{CanonicalDeserialize, CanonicalSerialize, Compress, SerializationError, Valid, Validate},
+    try_write_as,
     FromBytes,
     ToBytes,
     ToMinimalBits,
@@ -138,9 +139,9 @@ impl<E: PairingEngine> ToBytes for UniversalParams<E> {
         self.h.write_le(&mut writer)?;
 
         // Serialize `supported_degree_bounds`.
-        (self.supported_degree_bounds.len() as u32).write_le(&mut writer)?;
+        try_write_as::<u32>(self.supported_degree_bounds.len(), &mut writer)?;
         for degree_bound in &self.supported_degree_bounds {
-            (*degree_bound as u32).write_le(&mut writer)?;
+            try_write_as::<u32>(*degree_bound, &mut writer)?;
         }
 
         // Serialize `prepared_h`.
diff --git a/algorithms/src/polycommit/kzg10/mod.rs b/algorithms/src/polycommit/kzg10/mod.rs
index 0d7b30f2a0..6e7e3a2b57 100644
--- a/algorithms/src/polycommit/kzg10/mod.rs
+++ b/algorithms/src/polycommit/kzg10/mod.rs
@@ -118,7 +118,7 @@ impl<E: PairingEngine> KZG10<E> {
         let (num_leading_zeros, plain_coeffs) = skip_leading_zeros_and_convert_to_bigints(polynomial);
 
         let msm_time = start_timer!(|| "MSM to compute commitment to plaintext poly");
-        let commitment = VariableBase::msm(&powers.powers_of_beta_g[num_leading_zeros..], &plain_coeffs);
+        let commitment = VariableBase::msm(&powers.powers_of_beta_g[num_leading_zeros..], &plain_coeffs)?;
         end_timer!(msm_time);
 
         if terminator.load(Ordering::Relaxed) {
@@ -151,7 +151,7 @@ impl<E: PairingEngine> KZG10<E> {
             let random_ints = convert_to_bigints(&randomness.blinding_polynomial.coeffs);
             let msm_time = start_timer!(|| "MSM to compute commitment to random poly");
             let random_commitment =
-                VariableBase::msm(&powers.powers_of_beta_times_gamma_g, random_ints.as_slice()).to_affine();
+                VariableBase::msm(&powers.powers_of_beta_times_gamma_g, random_ints.as_slice())?.to_affine();
             end_timer!(msm_time);
 
             if terminator.load(Ordering::Relaxed) {
@@ -186,7 +186,7 @@ impl<E: PairingEngine> KZG10<E> {
         let evaluations = evaluations.iter().map(|e| e.to_bigint()).collect::<Vec<_>>();
         let msm_time = start_timer!(|| "MSM to compute commitment to plaintext poly");
-        let mut commitment = VariableBase::msm(&lagrange_basis.lagrange_basis_at_beta_g, &evaluations);
+        let mut commitment = VariableBase::msm(&lagrange_basis.lagrange_basis_at_beta_g, &evaluations)?;
         end_timer!(msm_time);
 
         if terminator.load(Ordering::Relaxed) {
@@ -210,7 +210,7 @@ impl<E: PairingEngine> KZG10<E> {
             let random_ints = convert_to_bigints(&randomness.blinding_polynomial.coeffs);
             let msm_time = start_timer!(|| "MSM to compute commitment to random poly");
             let random_commitment =
-                VariableBase::msm(&lagrange_basis.powers_of_beta_times_gamma_g, random_ints.as_slice()).to_affine();
+                VariableBase::msm(&lagrange_basis.powers_of_beta_times_gamma_g, random_ints.as_slice())?.to_affine();
end_timer!(msm_time); if terminator.load(Ordering::Relaxed) { @@ -265,7 +265,7 @@ impl KZG10 { let (num_leading_zeros, witness_coeffs) = skip_leading_zeros_and_convert_to_bigints(witness_polynomial); let witness_comm_time = start_timer!(|| "Computing commitment to witness polynomial"); - let mut w = VariableBase::msm(&powers.powers_of_beta_g[num_leading_zeros..], &witness_coeffs); + let mut w = VariableBase::msm(&powers.powers_of_beta_g[num_leading_zeros..], &witness_coeffs)?; end_timer!(witness_comm_time); let random_v = if let Some(hiding_witness_polynomial) = hiding_witness_polynomial { @@ -276,7 +276,7 @@ impl KZG10 { let random_witness_coeffs = convert_to_bigints(&hiding_witness_polynomial.coeffs); let witness_comm_time = start_timer!(|| "Computing commitment to random witness polynomial"); - w += &VariableBase::msm(&powers.powers_of_beta_times_gamma_g, &random_witness_coeffs); + w += &VariableBase::msm(&powers.powers_of_beta_times_gamma_g, &random_witness_coeffs)?; end_timer!(witness_comm_time); Some(blinding_evaluation) } else { diff --git a/algorithms/src/polycommit/sonic_pc/data_structures.rs b/algorithms/src/polycommit/sonic_pc/data_structures.rs index 446dcdd84c..8230be4d78 100644 --- a/algorithms/src/polycommit/sonic_pc/data_structures.rs +++ b/algorithms/src/polycommit/sonic_pc/data_structures.rs @@ -16,7 +16,7 @@ use crate::{crypto_hash::sha256::sha256, fft::EvaluationDomain, polycommit::kzg1 use hashbrown::{HashMap, HashSet}; use snarkvm_curves::{PairingCurve, PairingEngine, ProjectiveCurve}; use snarkvm_fields::{ConstraintFieldError, Field, PrimeField, ToConstraintField}; -use snarkvm_utilities::{error, serialize::*, FromBytes, ToBytes}; +use snarkvm_utilities::{error, serialize::*, try_write_as, FromBytes, ToBytes}; use std::{ borrow::{Borrow, Cow}, @@ -81,7 +81,7 @@ pub struct CommitterKey { pub enforced_degree_bounds: Option>, /// The maximum degree supported by the `UniversalParams` from which `self` was derived - pub max_degree: usize, + pub max_degree: u32, } impl FromBytes for CommitterKey { @@ -215,7 +215,7 @@ impl FromBytes for CommitterKey { shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g, enforced_degree_bounds, - max_degree: max_degree as usize, + max_degree, }) } } @@ -223,22 +223,22 @@ impl FromBytes for CommitterKey { impl ToBytes for CommitterKey { fn write_le(&self, mut writer: W) -> io::Result<()> { // Serialize `powers`. - (self.powers_of_beta_g.len() as u32).write_le(&mut writer)?; + try_write_as::(self.powers_of_beta_g.len(), &mut writer)?; for power in &self.powers_of_beta_g { power.write_le(&mut writer)?; } // Serialize `powers`. - (self.lagrange_bases_at_beta_g.len() as u32).write_le(&mut writer)?; + try_write_as::(self.lagrange_bases_at_beta_g.len(), &mut writer)?; for (size, powers) in &self.lagrange_bases_at_beta_g { - (*size as u32).write_le(&mut writer)?; + try_write_as::(*size, &mut writer)?; for power in powers { power.write_le(&mut writer)?; } } // Serialize `powers_of_beta_times_gamma_g`. - (self.powers_of_beta_times_gamma_g.len() as u32).write_le(&mut writer)?; + try_write_as::(self.powers_of_beta_times_gamma_g.len(), &mut writer)?; for power_of_gamma_g in &self.powers_of_beta_times_gamma_g { power_of_gamma_g.write_le(&mut writer)?; } @@ -246,7 +246,7 @@ impl ToBytes for CommitterKey { // Serialize `shifted_powers_of_beta_g`. 
self.shifted_powers_of_beta_g.is_some().write_le(&mut writer)?; if let Some(shifted_powers_of_beta_g) = &self.shifted_powers_of_beta_g { - (shifted_powers_of_beta_g.len() as u32).write_le(&mut writer)?; + try_write_as::(shifted_powers_of_beta_g.len(), &mut writer)?; for shifted_power in shifted_powers_of_beta_g { shifted_power.write_le(&mut writer)?; } @@ -255,11 +255,11 @@ impl ToBytes for CommitterKey { // Serialize `shifted_powers_of_beta_times_gamma_g`. self.shifted_powers_of_beta_times_gamma_g.is_some().write_le(&mut writer)?; if let Some(shifted_powers_of_beta_times_gamma_g) = &self.shifted_powers_of_beta_times_gamma_g { - (shifted_powers_of_beta_times_gamma_g.len() as u32).write_le(&mut writer)?; - for (key, shifted_powers_of_beta_g) in shifted_powers_of_beta_times_gamma_g { - (*key as u32).write_le(&mut writer)?; - (shifted_powers_of_beta_g.len() as u32).write_le(&mut writer)?; - for shifted_power in shifted_powers_of_beta_g { + try_write_as::(shifted_powers_of_beta_times_gamma_g.len(), &mut writer)?; + for (key, shifted_powers_of_beta_times_gamma_g) in shifted_powers_of_beta_times_gamma_g { + try_write_as::(*key, &mut writer)?; + try_write_as::(shifted_powers_of_beta_times_gamma_g.len(), &mut writer)?; + for shifted_power in shifted_powers_of_beta_times_gamma_g { shifted_power.write_le(&mut writer)?; } } @@ -268,14 +268,14 @@ impl ToBytes for CommitterKey { // Serialize `enforced_degree_bounds`. self.enforced_degree_bounds.is_some().write_le(&mut writer)?; if let Some(enforced_degree_bounds) = &self.enforced_degree_bounds { - (enforced_degree_bounds.len() as u32).write_le(&mut writer)?; + try_write_as::(enforced_degree_bounds.len(), &mut writer)?; for enforced_degree_bound in enforced_degree_bounds { - (*enforced_degree_bound as u32).write_le(&mut writer)?; + try_write_as::(*enforced_degree_bound, &mut writer)?; } } // Serialize `max_degree`. - (self.max_degree as u32).write_le(&mut writer)?; + self.max_degree.write_le(&mut writer)?; // Construct the hash of the group elements. 
let mut hash_input = self.powers_of_beta_g.to_bytes_le().map_err(|_| error("Could not serialize powers"))?; @@ -341,7 +341,7 @@ pub struct CommitterUnionKey<'a, E: PairingEngine> { pub enforced_degree_bounds: Option>, /// The maximum degree supported by the `UniversalParams` from which `self` was derived - pub max_degree: usize, + pub max_degree: u32, } impl<'a, E: PairingEngine> CommitterUnionKey<'a, E> { @@ -387,7 +387,7 @@ impl<'a, E: PairingEngine> CommitterUnionKey<'a, E> { }) } - pub fn max_degree(&self) -> usize { + pub fn max_degree(&self) -> u32 { self.max_degree } diff --git a/algorithms/src/polycommit/sonic_pc/mod.rs b/algorithms/src/polycommit/sonic_pc/mod.rs index aac5a134d8..69d5d04a9e 100644 --- a/algorithms/src/polycommit/sonic_pc/mod.rs +++ b/algorithms/src/polycommit/sonic_pc/mod.rs @@ -156,7 +156,7 @@ impl> SonicKZG10 { shifted_powers_of_beta_g, shifted_powers_of_beta_times_gamma_g, enforced_degree_bounds, - max_degree, + max_degree: u32::try_from(max_degree)?, }; let g = pp.power_of_beta_g(0)?; @@ -244,7 +244,7 @@ impl> SonicKZG10 { kzg10::KZG10::::check_degrees_and_bounds( ck.supported_degree(), - ck.max_degree, + usize::try_from(ck.max_degree)?, ck.enforced_degree_bounds.as_deref(), p.clone(), )?; @@ -327,16 +327,12 @@ impl> SonicKZG10 { Randomness: 'a, Commitment: 'a, { + let max_degree = usize::try_from(ck.max_degree)?; Ok(Self::combine_polynomials(labeled_polynomials.into_iter().zip_eq(rands).map(|(p, r)| { let enforced_degree_bounds: Option<&[usize]> = ck.enforced_degree_bounds.as_deref(); - kzg10::KZG10::::check_degrees_and_bounds( - ck.supported_degree(), - ck.max_degree, - enforced_degree_bounds, - p, - ) - .unwrap(); + kzg10::KZG10::::check_degrees_and_bounds(ck.supported_degree(), max_degree, enforced_degree_bounds, p) + .unwrap(); let challenge = fs_rng.squeeze_short_nonnative_field_element::(); (challenge, p.polynomial().to_dense(), r) }))) @@ -464,7 +460,7 @@ impl> SonicKZG10 { p, Some(randomizer), fs_rng, - ); + )?; randomizer = fs_rng.squeeze_short_nonnative_field_element::(); } @@ -529,7 +525,7 @@ impl> SonicKZG10 { let lc_poly = LabeledPolynomial::new(lc_label.clone(), poly, degree_bound, hiding_bound); lc_polynomials.push(lc_poly); lc_randomness.push(randomness); - lc_commitments.push(Self::combine_commitments(coeffs_and_comms)); + lc_commitments.push(Self::combine_commitments(coeffs_and_comms)?); lc_info.push((lc_label, degree_bound)); } @@ -604,7 +600,7 @@ impl> SonicKZG10 { } } let lc_time = start_timer!(|| format!("Combining {num_polys} commitments for {lc_label}")); - lc_commitments.push(Self::combine_commitments(coeffs_and_comms)); + lc_commitments.push(Self::combine_commitments(coeffs_and_comms)?); end_timer!(lc_time); lc_info.push((lc_label, degree_bound)); } @@ -644,7 +640,7 @@ impl> SonicKZG10 { /// MSM for `commitments` and `coeffs` fn combine_commitments<'a>( coeffs_and_comms: impl IntoIterator)>, - ) -> E::G1Projective { + ) -> Result { let (scalars, bases): (Vec<_>, Vec<_>) = coeffs_and_comms.into_iter().map(|(f, c)| (f.into(), c.0)).unzip(); VariableBase::msm(&bases, &scalars) } @@ -668,7 +664,7 @@ impl> SonicKZG10 { proof: &kzg10::KZGProof, randomizer: Option, fs_rng: &mut S, - ) { + ) -> Result<(), PCError> { let acc_time = start_timer!(|| "Accumulating elements"); // Keeps track of running combination of values let mut combined_values = E::Fr::zero(); @@ -706,8 +702,9 @@ impl> SonicKZG10 { proof.w.to_projective() }; let coeffs = coeffs.into_iter().map(|c| c.into()).collect::>(); - *combined_adjusted_witness += 
VariableBase::msm(&bases, &coeffs); + *combined_adjusted_witness += VariableBase::msm(&bases, &coeffs)?; end_timer!(acc_time); + Ok(()) } #[allow(clippy::type_complexity)] diff --git a/algorithms/src/snark/marlin/ahp/ahp.rs b/algorithms/src/snark/marlin/ahp/ahp.rs index 9b7e42939a..39270bfb5c 100644 --- a/algorithms/src/snark/marlin/ahp/ahp.rs +++ b/algorithms/src/snark/marlin/ahp/ahp.rs @@ -29,7 +29,7 @@ use snarkvm_fields::{Field, PrimeField}; use snarkvm_r1cs::SynthesisError; use core::{borrow::Borrow, marker::PhantomData}; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, num::TryFromIntError}; /// The algebraic holographic proof defined in [CHMMVW19](https://eprint.iacr.org/2019/1047). /// Currently, this AHP only supports inputs of size one @@ -46,7 +46,7 @@ struct VerifierChallenges { gamma: F, } -pub(crate) fn witness_label(circuit_id: CircuitId, poly: &str, i: usize) -> String { +pub(crate) fn witness_label(circuit_id: CircuitId, poly: &str, i: u32) -> String { format!("circuit_{circuit_id}_{poly}_{i:0>8}") } @@ -89,7 +89,7 @@ impl AHPForR1CS { /// of this protocol. /// The number of the variables must include the "one" variable. That is, it /// must be with respect to the number of formatted public inputs. - pub fn max_degree(num_constraints: usize, num_variables: usize, num_non_zero: usize) -> Result { + pub fn max_degree(num_constraints: usize, num_variables: usize, num_non_zero: usize) -> Result { let padded_matrix_dim = matrices::padded_matrix_dim(num_variables, num_constraints); let zk_bound = Self::zk_bound().unwrap_or(0); let constraint_domain_size = EvaluationDomain::::compute_size_of_domain(padded_matrix_dim) @@ -97,16 +97,18 @@ impl AHPForR1CS { let non_zero_domain_size = EvaluationDomain::::compute_size_of_domain(num_non_zero).ok_or(AHPError::PolynomialDegreeTooLarge)?; - Ok(*[ - 2 * constraint_domain_size + zk_bound - 2, - if MM::ZK { constraint_domain_size + 3 } else { 0 }, // mask_poly - constraint_domain_size, - constraint_domain_size, - non_zero_domain_size - 1, // non-zero polynomials - ] - .iter() - .max() - .unwrap()) + Ok(u32::try_from( + *[ + 2 * constraint_domain_size + zk_bound - 2, + if MM::ZK { constraint_domain_size + 3 } else { 0 }, // mask_poly + constraint_domain_size, + constraint_domain_size, + non_zero_domain_size - 1, // non-zero polynomials + ] + .iter() + .max() + .unwrap(), + )?) } /// Get all the strict degree bounds enforced in the AHP. 
@@ -242,12 +244,12 @@ impl AHPForR1CS { let z_b_i = (0..circuit_state.batch_size) .map(|i| { let z_b = witness_label(circuit_id, "z_b", i); - LinearCombination::new(z_b.clone(), [(F::one(), z_b)]) + Ok::<_, TryFromIntError>(LinearCombination::new(z_b.clone(), [(F::one(), z_b)])) }) - .collect::>(); - (circuit_id, z_b_i) + .collect::, _>>()?; + Ok((circuit_id, z_b_i)) }) - .collect::>(); + .collect::, TryFromIntError>>()?; let g_1 = LinearCombination::new("g_1", [(F::one(), "g_1")]); @@ -279,10 +281,11 @@ impl AHPForR1CS { let z_b_s_at_beta = z_b_s .iter() .map(|(circuit_id, z_b_i)| { - let z_b_i_s = z_b_i.iter().map(|z_b| evals.get_lc_eval(z_b, beta).unwrap()).collect::>(); - (*circuit_id, z_b_i_s) + let z_b_i_s = z_b_i.iter().map(|z_b| evals.get_lc_eval(z_b, beta)).try_collect()?; + Ok((*circuit_id, z_b_i_s)) }) - .collect::>(); + .collect::>, AHPError>>()?; + let batch_z_b_s_at_beta = z_b_s_at_beta .iter() .zip_eq(batch_combiners.values()) @@ -325,8 +328,8 @@ impl AHPForR1CS { for (id, c) in batch_combiners.iter() { let mut circuit_term = LinearCombination::empty(format!("lincheck_sumcheck term {id}")); for (j, instance_combiner) in c.instance_combiners.iter().enumerate() { - let z_a_j = witness_label(*id, "z_a", j); - let w_j = witness_label(*id, "w", j); + let z_a_j = witness_label(*id, "z_a", u32::try_from(j)?); + let w_j = witness_label(*id, "w", u32::try_from(j)?); circuit_term .add(r_alpha_at_beta_s[id] * instance_combiner * (eta_a + eta_c * z_b_s_at_beta[id][j]), z_a_j) .add(-t_at_beta_s[id] * v_X_at_beta[id] * instance_combiner, w_j); @@ -464,7 +467,7 @@ fn selector_evals( challenge: F, ) -> F { *cached_selector_evaluations - .entry((target_domain.size, largest_domain.size, challenge)) + .entry((target_domain.size as u64, largest_domain.size as u64, challenge)) .or_insert_with(|| largest_domain.evaluate_selector_polynomial(*target_domain, challenge)) } diff --git a/algorithms/src/snark/marlin/ahp/errors.rs b/algorithms/src/snark/marlin/ahp/errors.rs index 7fa9b5fc34..b9c332db9b 100644 --- a/algorithms/src/snark/marlin/ahp/errors.rs +++ b/algorithms/src/snark/marlin/ahp/errors.rs @@ -21,6 +21,8 @@ pub enum AHPError { ConstraintSystemError(snarkvm_r1cs::errors::SynthesisError), /// The instance generated during proving does not match that in the index. InstanceDoesNotMatchIndex, + /// Could not convert from int. + IntError(std::num::TryFromIntError), /// The number of public inputs is incorrect. InvalidPublicInputLength, /// During verification, a required evaluation is missing @@ -36,3 +38,9 @@ impl From for AHPError { AHPError::ConstraintSystemError(other) } } + +impl From for AHPError { + fn from(other: std::num::TryFromIntError) -> Self { + Self::IntError(other) + } +} diff --git a/algorithms/src/snark/marlin/ahp/indexer/circuit.rs b/algorithms/src/snark/marlin/ahp/indexer/circuit.rs index f14b540a29..3e65dd094b 100644 --- a/algorithms/src/snark/marlin/ahp/indexer/circuit.rs +++ b/algorithms/src/snark/marlin/ahp/indexer/circuit.rs @@ -113,7 +113,7 @@ impl Circuit { } /// The maximum degree required to represent polynomials of this index. 
- pub fn max_degree(&self) -> usize { + pub fn max_degree(&self) -> u32 { self.index_info.max_degree::() } diff --git a/algorithms/src/snark/marlin/ahp/indexer/circuit_info.rs b/algorithms/src/snark/marlin/ahp/indexer/circuit_info.rs index 9354b4d57d..dd0665f7be 100644 --- a/algorithms/src/snark/marlin/ahp/indexer/circuit_info.rs +++ b/algorithms/src/snark/marlin/ahp/indexer/circuit_info.rs @@ -43,7 +43,7 @@ pub struct CircuitInfo { impl CircuitInfo { /// The maximum degree of polynomial required to represent this index in the AHP. - pub fn max_degree(&self) -> usize { + pub fn max_degree(&self) -> u32 { let max_non_zero = self.num_non_zero_a.max(self.num_non_zero_b).max(self.num_non_zero_c); AHPForR1CS::::max_degree(self.num_constraints, self.num_variables, max_non_zero).unwrap() } diff --git a/algorithms/src/snark/marlin/ahp/prover/round_functions/first.rs b/algorithms/src/snark/marlin/ahp/prover/round_functions/first.rs index 1883e49e90..b178250ff7 100644 --- a/algorithms/src/snark/marlin/ahp/prover/round_functions/first.rs +++ b/algorithms/src/snark/marlin/ahp/prover/round_functions/first.rs @@ -49,9 +49,9 @@ impl AHPForR1CS { /// Output the degree bounds of oracles in the first round. pub fn first_round_polynomial_info<'a>( - circuits: impl Iterator, + batch_sizes: impl Iterator, ) -> BTreeMap { - let mut polynomials = circuits + let mut polynomials = batch_sizes .flat_map(|(&circuit_id, &batch_size)| { (0..batch_size).flat_map(move |i| { [ @@ -76,9 +76,9 @@ impl AHPForR1CS { ) -> Result, AHPError> { let round_time = start_timer!(|| "AHP::Prover::FirstRound"); let mut r_b_s = Vec::with_capacity(state.circuit_specific_states.len()); - let mut job_pool = snarkvm_utilities::ExecutionPool::with_capacity(3 * state.total_instances); + let mut job_pool = snarkvm_utilities::ExecutionPool::with_capacity((3 * state.total_instances).try_into()?); for (circuit, circuit_state) in state.circuit_specific_states.iter_mut() { - let batch_size = circuit_state.batch_size; + let batch_size = usize::try_from(circuit_state.batch_size)?; let z_a = circuit_state.z_a.take().unwrap(); let z_b = circuit_state.z_b.take().unwrap(); @@ -96,9 +96,9 @@ impl AHPForR1CS { for (j, (z_a, z_b, private_vars, x_poly)) in itertools::izip!(z_a, z_b, private_variables, x_polys).enumerate() { - let w_label = witness_label(circuit.id, "w", j); - let za_label = witness_label(circuit.id, "z_a", j); - let zb_label = witness_label(circuit.id, "z_b", j); + let w_label = witness_label(circuit.id, "w", u32::try_from(j)?); + let za_label = witness_label(circuit.id, "z_a", u32::try_from(j)?); + let zb_label = witness_label(circuit.id, "z_b", u32::try_from(j)?); job_pool.add_job(move || Self::calc_w(w_label, private_vars, x_poly, c_domain, i_domain, circuit)); job_pool.add_job(move || Self::calc_z_m(za_label, z_a, c_domain, circuit, None)); let r_b = F::rand(rng); @@ -121,11 +121,11 @@ impl AHPForR1CS { prover::SingleEntry { z_a, z_b, w_poly, z_a_poly, z_b_poly } }) .collect::>(); - assert_eq!(batches.len(), state.total_instances); + assert_eq!(batches.len(), usize::try_from(state.total_instances)?); let mut circuit_specific_batches = BTreeMap::new(); for ((circuit, state), r_b_s) in state.circuit_specific_states.iter_mut().zip(r_b_s) { - let batches = batches.drain(0..state.batch_size).collect_vec(); + let batches = batches.drain(0..state.batch_size.try_into()?).collect_vec(); circuit_specific_batches.insert(circuit.id, batches); state.mz_poly_randomizer = MM::ZK.then_some(r_b_s); end_timer!(round_time); diff --git 
a/algorithms/src/snark/marlin/ahp/prover/state.rs b/algorithms/src/snark/marlin/ahp/prover/state.rs index 8eaf4b2e8a..1e7e54df93 100644 --- a/algorithms/src/snark/marlin/ahp/prover/state.rs +++ b/algorithms/src/snark/marlin/ahp/prover/state.rs @@ -30,7 +30,7 @@ pub struct CircuitSpecificState { pub(super) non_zero_c_domain: EvaluationDomain, /// The number of instances being proved in this batch. - pub(in crate::snark) batch_size: usize, + pub(in crate::snark) batch_size: u32, /// The list of public inputs for each instance in the batch. /// The length of this list must be equal to the batch size. @@ -74,7 +74,7 @@ pub struct State<'a, F: PrimeField, MM: MarlinMode> { /// The largest constraint domain of all circuits in the batch. pub(in crate::snark) max_constraint_domain: EvaluationDomain, /// The total number of instances we're proving in the batch. - pub(in crate::snark) total_instances: usize, + pub(in crate::snark) total_instances: u32, } /// The public inputs for a single instance. @@ -112,13 +112,14 @@ impl<'a, F: PrimeField, MM: MarlinMode> State<'a, F, MM> { let first_padded_public_inputs = &variable_assignments[0].0; let input_domain = EvaluationDomain::new(first_padded_public_inputs.len()).unwrap(); - let batch_size = variable_assignments.len(); + let batch_size = variable_assignments.len().try_into()?; total_instances += batch_size; - let mut z_as = Vec::with_capacity(batch_size); - let mut z_bs = Vec::with_capacity(batch_size); - let mut x_polys = Vec::with_capacity(batch_size); - let mut padded_public_variables = Vec::with_capacity(batch_size); - let mut private_variables = Vec::with_capacity(batch_size); + let batch_size_usize = batch_size as usize; + let mut z_as = Vec::with_capacity(batch_size_usize); + let mut z_bs = Vec::with_capacity(batch_size_usize); + let mut x_polys = Vec::with_capacity(batch_size_usize); + let mut padded_public_variables = Vec::with_capacity(batch_size_usize); + let mut private_variables = Vec::with_capacity(batch_size_usize); for Assignments(padded_public_input, private_input, z_a, z_b) in variable_assignments { z_as.push(z_a); @@ -163,7 +164,7 @@ impl<'a, F: PrimeField, MM: MarlinMode> State<'a, F, MM> { } /// Get the batch size for a given circuit. 
- pub fn batch_size(&self, circuit: &Circuit) -> Option { + pub fn batch_size(&self, circuit: &Circuit) -> Option { self.circuit_specific_states.get(circuit).map(|s| s.batch_size) } diff --git a/algorithms/src/snark/marlin/ahp/verifier/messages.rs b/algorithms/src/snark/marlin/ahp/verifier/messages.rs index 2a9fe317a2..e9e44e67b3 100644 --- a/algorithms/src/snark/marlin/ahp/verifier/messages.rs +++ b/algorithms/src/snark/marlin/ahp/verifier/messages.rs @@ -16,7 +16,7 @@ use snarkvm_fields::PrimeField; use crate::snark::marlin::{witness_label, CircuitId, MarlinMode}; use itertools::Itertools; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, num::TryFromIntError}; /// Randomizers used to combine circuit-specific and instance-specific elements in the AHP sumchecks #[derive(Clone, Debug)] @@ -80,7 +80,7 @@ pub struct QuerySet { } impl QuerySet { - pub fn new(state: &super::State) -> Self { + pub fn new(state: &super::State) -> Result { let beta = state.second_round_message.unwrap().beta; let gamma = state.gamma.unwrap(); // For the first linear combination @@ -91,8 +91,12 @@ impl QuerySet { // Note that z is the interpolation of x || w, so it equals x + v_X * w // We also use an optimization: instead of explicitly calculating z_c, we // use the "virtual oracle" z_a * z_b - Self { - batch_sizes: state.circuit_specific_states.iter().map(|(c, s)| (*c, s.batch_size)).collect(), + Ok(Self { + batch_sizes: state + .circuit_specific_states + .iter() + .map(|(c, s)| Ok::<_, TryFromIntError>((*c, usize::try_from(s.batch_size)?))) + .try_collect()?, g_1_query: ("beta".into(), beta), z_b_query: ("beta".into(), beta), lincheck_sumcheck_query: ("beta".into(), beta), @@ -101,16 +105,16 @@ impl QuerySet { g_b_query: ("gamma".into(), gamma), g_c_query: ("gamma".into(), gamma), matrix_sumcheck_query: ("gamma".into(), gamma), - } + }) } /// Returns a `BTreeSet` containing elements of the form /// `(polynomial_label, (query_label, query))`. - pub fn to_set(&self) -> crate::polycommit::sonic_pc::QuerySet { + pub fn to_set(&self) -> Result, TryFromIntError> { let mut query_set = crate::polycommit::sonic_pc::QuerySet::new(); for (&circuit_id, &batch_size) in self.batch_sizes.iter() { for j in 0..batch_size { - query_set.insert((witness_label(circuit_id, "z_b", j), self.z_b_query.clone())); + query_set.insert((witness_label(circuit_id, "z_b", u32::try_from(j)?), self.z_b_query.clone())); } query_set.insert((witness_label(circuit_id, "g_a", 0), self.g_a_query.clone())); query_set.insert((witness_label(circuit_id, "g_b", 0), self.g_b_query.clone())); @@ -119,6 +123,6 @@ impl QuerySet { query_set.insert(("g_1".into(), self.g_1_query.clone())); query_set.insert(("lincheck_sumcheck".into(), self.lincheck_sumcheck_query.clone())); query_set.insert(("matrix_sumcheck".into(), self.matrix_sumcheck_query.clone())); - query_set + Ok(query_set) } } diff --git a/algorithms/src/snark/marlin/ahp/verifier/state.rs b/algorithms/src/snark/marlin/ahp/verifier/state.rs index 60f05b6627..c452ed1651 100644 --- a/algorithms/src/snark/marlin/ahp/verifier/state.rs +++ b/algorithms/src/snark/marlin/ahp/verifier/state.rs @@ -35,7 +35,7 @@ pub struct CircuitSpecificState { pub(crate) non_zero_c_domain: EvaluationDomain, /// The number of instances being proved in this batch. - pub(in crate::snark::marlin) batch_size: usize, + pub(in crate::snark::marlin) batch_size: u32, } /// State of the AHP verifier. 
#[derive(Debug)] diff --git a/algorithms/src/snark/marlin/ahp/verifier/verifier.rs b/algorithms/src/snark/marlin/ahp/verifier/verifier.rs index c9d0aa1537..5175826928 100644 --- a/algorithms/src/snark/marlin/ahp/verifier/verifier.rs +++ b/algorithms/src/snark/marlin/ahp/verifier/verifier.rs @@ -30,12 +30,12 @@ use crate::{ }; use smallvec::SmallVec; use snarkvm_fields::PrimeField; -use std::collections::BTreeMap; +use std::{collections::BTreeMap, num::TryFromIntError}; impl AHPForR1CS { /// Output the first message and next round state. pub fn verifier_first_round>( - batch_sizes: &BTreeMap, + batch_sizes: &BTreeMap, circuit_infos: &BTreeMap>, max_constraint_domain: EvaluationDomain, largest_non_zero_domain: EvaluationDomain, @@ -49,14 +49,14 @@ impl AHPForR1CS { let mut num_circuit_combiners = vec![1; batch_sizes.len()]; num_circuit_combiners[0] = 0; // the first circuit_combiner is TargetField::one() and needs no random sampling - for ((batch_size, (circuit_id, circuit_info)), num_c_combiner) in + for ((&batch_size, (circuit_id, circuit_info)), num_c_combiner) in batch_sizes.values().zip(circuit_infos).zip(num_circuit_combiners) { let squeeze_time = start_timer!(|| format!("Squeezing challenges for {circuit_id}")); - let elems = fs_rng.squeeze_nonnative_field_elements(*batch_size - 1 + num_c_combiner); + let elems = fs_rng.squeeze_nonnative_field_elements(batch_size as usize - 1 + num_c_combiner); end_timer!(squeeze_time); - let (instance_combiners, circuit_combiner) = elems.split_at(*batch_size - 1); + let (instance_combiners, circuit_combiner) = elems.split_at(batch_size as usize - 1); assert_eq!(circuit_combiner.len(), num_c_combiner); let mut combiners = BatchCombiners { circuit_combiner: TargetField::one(), instance_combiners: vec![TargetField::one()] }; @@ -102,7 +102,7 @@ impl AHPForR1CS { non_zero_a_domain, non_zero_b_domain, non_zero_c_domain, - batch_size: *batch_size, + batch_size, }; circuit_specific_states.insert(*circuit_id, circuit_specific_state); } @@ -182,7 +182,9 @@ impl AHPForR1CS { } /// Output the query state and next round state. 
/// Output the query state and next round state. - pub fn verifier_query_set(state: State<TargetField, MM>) -> (QuerySet<TargetField>, State<TargetField, MM>) { - (QuerySet::new(&state), state) + pub fn verifier_query_set( + state: State<TargetField, MM>, + ) -> Result<(QuerySet<TargetField>, State<TargetField, MM>), TryFromIntError> { + Ok((QuerySet::new(&state)?, state)) } } diff --git a/algorithms/src/snark/marlin/data_structures/circuit_verifying_key.rs b/algorithms/src/snark/marlin/data_structures/circuit_verifying_key.rs index 90973534ad..62bcf48600 100644 --- a/algorithms/src/snark/marlin/data_structures/circuit_verifying_key.rs +++ b/algorithms/src/snark/marlin/data_structures/circuit_verifying_key.rs @@ -109,11 +109,6 @@ impl<E: PairingEngine> ToMinimalBits for CircuitVerifyingKey<E> { @@ -80,14 +83,14 @@ impl<E: PairingEngine> Commitments<E> { } fn deserialize_with_mode<R: Read>( - batch_sizes: &[usize], + batch_sizes: &[u32], mut reader: R, compress: Compress, validate: Validate, ) -> Result<Self, SerializationError> { - let mut w = Vec::with_capacity(batch_sizes.iter().sum()); - for batch_size in batch_sizes { - w.extend(deserialize_vec_without_len(&mut reader, compress, validate, *batch_size)?); + let mut w = Vec::with_capacity(batch_sizes.iter().sum::<u32>().try_into()?); + for &batch_size in batch_sizes { + w.extend(deserialize_vec_without_len(&mut reader, compress, validate, batch_size.try_into()?)?); } Ok(Commitments { witness_commitments: w, @@ -155,16 +158,17 @@ impl<F: PrimeField> Evaluations<F> { } fn deserialize_with_mode<R: Read>( - batch_sizes: &[usize], + batch_sizes: &[u32], mut reader: R, compress: Compress, validate: Validate, ) -> Result<Self, SerializationError> { + let mut z_b_evals = Vec::with_capacity(batch_sizes.len()); + for &batch_size in batch_sizes { + z_b_evals.push(deserialize_vec_without_len(&mut reader, compress, validate, batch_size.try_into()?)?); + } Ok(Evaluations { - z_b_evals: batch_sizes - .iter() - .map(|batch_size| deserialize_vec_without_len(&mut reader, compress, validate, *batch_size)) - .collect::<Result<_, _>>()?, + z_b_evals, g_1_eval: CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?, g_a_evals: deserialize_vec_without_len(&mut reader, compress, validate, batch_sizes.len())?, g_b_evals: deserialize_vec_without_len(&mut reader, compress, validate, batch_sizes.len())?, @@ -176,8 +180,8 @@ impl<F: PrimeField> Evaluations<F> { impl<F: PrimeField> Evaluations<F> { pub(crate) fn from_map( map: &std::collections::BTreeMap<String, F>, - batch_sizes: BTreeMap<CircuitId, usize>, - ) -> Self { + batch_sizes: BTreeMap<CircuitId, u32>, + ) -> Result<Self, TryFromIntError> { let mut z_b_evals_collect: BTreeMap<CircuitId, Vec<F>> = BTreeMap::new(); let mut g_a_evals = Vec::with_capacity(batch_sizes.len()); let mut g_b_evals = Vec::with_capacity(batch_sizes.len()); @@ -193,7 +197,7 @@ impl<F: PrimeField> Evaluations<F> { if let Some(z_b_i) = z_b_evals_collect.get_mut(&circuit_id) { z_b_i.push(*value); } else { - let mut values = Vec::with_capacity(batch_sizes[&circuit_id]); + let mut values = Vec::with_capacity(batch_sizes[&circuit_id] as usize); values.push(*value); z_b_evals_collect.insert(circuit_id, values); } @@ -206,26 +210,26 @@ impl<F: PrimeField> Evaluations<F> { } } let z_b_evals = z_b_evals_collect.into_values().collect(); - Self { z_b_evals, g_1_eval: map["g_1"], g_a_evals, g_b_evals, g_c_evals } + Ok(Self { z_b_evals, g_1_eval: map["g_1"], g_a_evals, g_b_evals, g_c_evals }) } - pub(crate) fn get(&self, circuit_index: usize, label: &str) -> Option<F> { + pub(crate) fn get(&self, circuit_index: usize, label: &str) -> Result<Option<F>, ParseIntError> { if label == "g_1" { - return Some(self.g_1_eval); + return Ok(Some(self.g_1_eval)); } if let Some(index) = label.find("z_b_") { let z_b_eval_circuit = &self.z_b_evals[circuit_index]; - let instance_index = label[index + 4..].parse::<usize>().unwrap(); - z_b_eval_circuit.get(instance_index).copied() + let instance_index = label[index + 4..].parse::<usize>()?; + Ok(z_b_eval_circuit.get(instance_index).copied()) } else if label.contains("g_a") { - self.g_a_evals.get(circuit_index).copied() + Ok(self.g_a_evals.get(circuit_index).copied()) } else if label.contains("g_b") { - self.g_b_evals.get(circuit_index).copied() + Ok(self.g_b_evals.get(circuit_index).copied()) } else if label.contains("g_c") { - self.g_c_evals.get(circuit_index).copied() + Ok(self.g_c_evals.get(circuit_index).copied()) } else { - None + Ok(None) } } @@ -253,7 +257,7 @@ impl<F: PrimeField> Valid for Evaluations<F> { #[derive(Clone, Debug, PartialEq, Eq)] pub struct Proof<E: PairingEngine> { /// The number of instances being proven in this proof. - batch_sizes: Vec<usize>, + batch_sizes: Vec<u32>, /// Commitments to prover polynomials. pub commitments: Commitments<E>, @@ -271,35 +275,36 @@ pub struct Proof<E: PairingEngine> { impl<E: PairingEngine> Proof<E> { /// Construct a new proof. pub fn new( - batch_sizes: BTreeMap<CircuitId, usize>, + batch_sizes_in: BTreeMap<CircuitId, u32>, commitments: Commitments<E>, evaluations: Evaluations<E::Fr>, msg: ahp::prover::ThirdMessage<E::Fr>, pc_proof: sonic_pc::BatchLCProof<E>, ) -> Result<Self, SNARKError> { let mut total_instances = 0; - let batch_sizes: Vec<usize> = batch_sizes.into_values().collect(); - for (z_b_evals, batch_size) in evaluations.z_b_evals.iter().zip(&batch_sizes) { + let mut batch_sizes = Vec::with_capacity(batch_sizes_in.len()); + for (z_b_evals, &batch_size) in evaluations.z_b_evals.iter().zip(batch_sizes_in.values()) { total_instances += batch_size; - if z_b_evals.len() != *batch_size { + batch_sizes.push(batch_size); + if z_b_evals.len() != batch_size as usize { return Err(SNARKError::BatchSizeMismatch); } } - if commitments.witness_commitments.len() != total_instances { + if commitments.witness_commitments.len() != total_instances as usize { return Err(SNARKError::BatchSizeMismatch); } Ok(Self { batch_sizes, commitments, evaluations, msg, pc_proof }) } - pub fn batch_sizes(&self) -> Result<&[usize], SNARKError> { + pub fn batch_sizes(&self) -> Result<&[u32], SNARKError> { let mut total_instances = 0; for (z_b_evals_i, &batch_size) in self.evaluations.z_b_evals.iter().zip(self.batch_sizes.iter()) { total_instances += batch_size; - if z_b_evals_i.len() != batch_size { + if u32::try_from(z_b_evals_i.len())? != batch_size { return Err(SNARKError::BatchSizeMismatch); } } - if self.commitments.witness_commitments.len() != total_instances { + if u32::try_from(self.commitments.witness_commitments.len())?
!= total_instances { return Err(SNARKError::BatchSizeMismatch); } Ok(&self.batch_sizes) @@ -308,8 +313,7 @@ impl Proof { impl CanonicalSerialize for Proof { fn serialize_with_mode(&self, mut writer: W, compress: Compress) -> Result<(), SerializationError> { - let batch_sizes: Vec = self.batch_sizes.iter().map(|x| u64::try_from(*x)).collect::>()?; - CanonicalSerialize::serialize_with_mode(&batch_sizes, &mut writer, compress)?; + CanonicalSerialize::serialize_with_mode(&self.batch_sizes, &mut writer, compress)?; Commitments::serialize_with_mode(&self.commitments, &mut writer, compress)?; Evaluations::serialize_with_mode(&self.evaluations, &mut writer, compress)?; CanonicalSerialize::serialize_with_mode(&self.msg, &mut writer, compress)?; @@ -344,8 +348,7 @@ impl CanonicalDeserialize for Proof { compress: Compress, validate: Validate, ) -> Result { - let batch_sizes: Vec = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; - let batch_sizes: Vec = batch_sizes.into_iter().map(|x| x as usize).collect(); + let batch_sizes: Vec = CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)?; Ok(Proof { commitments: Commitments::deserialize_with_mode(&batch_sizes, &mut reader, compress, validate)?, evaluations: Evaluations::deserialize_with_mode(&batch_sizes, &mut reader, compress, validate)?, @@ -369,6 +372,7 @@ impl FromBytes for Proof { } #[cfg(test)] +#[allow(clippy::cast_possible_truncation)] mod test { #![allow(non_camel_case_types)] @@ -446,7 +450,7 @@ mod test { for j in 1..11 { let test_with_none = i * j % 2 == 0; let commitments = rand_commitments(i, j, test_with_none); - let batch_sizes = vec![i; j]; + let batch_sizes = vec![i as u32; j]; let combinations = modes(); for (compress, validate) in combinations { let size = Commitments::serialized_size(&commitments, compress); @@ -467,7 +471,7 @@ mod test { for i in 1..11 { for j in 1..11 { let evaluations: Evaluations = rand_evaluations(rng, i, j); - let batch_sizes = vec![i; j]; + let batch_sizes = vec![i as u32; j]; let combinations = modes(); for (compress, validate) in combinations { let size = Evaluations::serialized_size(&evaluations, compress); @@ -488,7 +492,7 @@ mod test { for i in 1..11 { for j in 1..11 { let test_with_none = i * j % 2 == 0; - let batch_sizes = vec![i; j]; + let batch_sizes = vec![i as u32; j]; let commitments = rand_commitments(i, j, test_with_none); let evaluations: Evaluations = rand_evaluations(rng, i, j); let msg = ahp::prover::ThirdMessage:: { sums: vec![rand_sums(rng); j] }; diff --git a/algorithms/src/snark/marlin/marlin.rs b/algorithms/src/snark/marlin/marlin.rs index 23906c3801..a8a0600552 100644 --- a/algorithms/src/snark/marlin/marlin.rs +++ b/algorithms/src/snark/marlin/marlin.rs @@ -49,7 +49,7 @@ use snarkvm_fields::{One, PrimeField, ToConstraintField, Zero}; use snarkvm_r1cs::ConstraintSynthesizer; use snarkvm_utilities::{to_bytes_le, ToBytes}; -use std::{borrow::Borrow, collections::BTreeMap, ops::Deref, sync::Arc}; +use std::{borrow::Borrow, collections::BTreeMap, num::TryFromIntError, ops::Deref, sync::Arc}; #[cfg(not(feature = "std"))] use snarkvm_utilities::println; @@ -102,10 +102,11 @@ impl, MM: MarlinMode> MarlinSNAR let indexed_circuit = AHPForR1CS::<_, MM>::index(*circuit)?; // TODO: Add check that c is in the correct mode. // Increase the universal SRS size to support the circuit size. 
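In the hunk below, `indexed_circuit.max_degree()` now yields a fixed-width integer rather than a `usize`, so it is converted once with `usize::try_from` and the result is reused; a failed conversion surfaces as a `TryFromIntError` instead of a silent wrap, which the new `SNARKError::IntError(#[from] TryFromIntError)` variant can absorb. A self-contained sketch of that pattern, with a hypothetical `SetupError` standing in for the real error enum:

```rust
use std::num::TryFromIntError;

// Hypothetical stand-in for an error enum with a `#[from] TryFromIntError`
// variant, like the `SNARKError::IntError` variant added in this change.
#[derive(Debug)]
enum SetupError {
    Int(TryFromIntError),
}

impl From<TryFromIntError> for SetupError {
    fn from(e: TryFromIntError) -> Self {
        SetupError::Int(e)
    }
}

/// Converts a serialized `u64` degree bound into the in-memory `usize`
/// index; `?` propagates the failure instead of silently truncating.
fn degree_as_index(max_degree: u64) -> Result<usize, SetupError> {
    Ok(usize::try_from(max_degree)?)
}

fn main() -> Result<(), SetupError> {
    assert_eq!(degree_as_index(1 << 16)?, 65536);
    // On a 32-bit target, a degree above `u32::MAX` would return
    // `Err(SetupError::Int(_))` here rather than wrapping around.
    Ok(())
}
```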
- if universal_srs.max_degree() < indexed_circuit.max_degree() { - universal_srs.download_powers_for(0..indexed_circuit.max_degree()).map_err(|_| { - MarlinError::IndexTooLarge(universal_srs.max_degree(), indexed_circuit.max_degree()) - })?; + let max_degree = usize::try_from(indexed_circuit.max_degree())?; + if universal_srs.max_degree() < max_degree { + universal_srs + .download_powers_for(0..max_degree) + .map_err(|_| MarlinError::IndexTooLarge(universal_srs.max_degree(), max_degree))?; } let coefficient_support = AHPForR1CS::<_, MM>::get_degree_bounds(&indexed_circuit.index_info); @@ -113,7 +114,7 @@ impl, MM: MarlinMode> MarlinSNAR let supported_hiding_bound = 1; let (committer_key, verifier_key) = SonicKZG10::::trim( universal_srs, - indexed_circuit.max_degree(), + max_degree, [indexed_circuit.constraint_domain_size()], supported_hiding_bound, Some(coefficient_support.as_slice()), @@ -166,7 +167,7 @@ impl, MM: MarlinMode> MarlinSNAR fn init_sponge<'a>( fs_parameters: &FS::Parameters, - inputs_and_batch_sizes: &BTreeMap])>, + inputs_and_batch_sizes: &BTreeMap])>, circuit_commitments: impl Iterator]>, ) -> FS { let mut sponge = FS::new_with_parameters(fs_parameters); @@ -237,7 +238,7 @@ where type Proof = Proof; type ProvingKey = CircuitProvingKey; type ScalarField = E::Fr; - type UniversalSetupConfig = usize; + type UniversalSetupConfig = u32; type UniversalSetupParameters = UniversalSRS; type VerifierInput = [E::Fr]; type VerifyingKey = CircuitVerifyingKey; @@ -245,7 +246,7 @@ where fn universal_setup(max_degree: &Self::UniversalSetupConfig) -> Result { let setup_time = start_timer!(|| { format!("Marlin::UniversalSetup with max_degree {max_degree}",) }); - let srs = SonicKZG10::::load_srs(*max_degree).map_err(Into::into); + let srs = SonicKZG10::::load_srs((*max_degree).try_into()?).map_err(Into::into); end_timer!(setup_time); srs } @@ -397,12 +398,12 @@ where let padded_public_input = prover_state.padded_public_inputs(&pk.circuit).ok_or(SNARKError::CircuitNotFound)?; let circuit_id = pk.circuit.id; - batch_sizes.insert(circuit_id, batch_size); + circuit_ids.push(circuit_id); circuit_infos.insert(circuit_id, &pk.circuit_verifying_key.circuit_info); inputs_and_batch_sizes.insert(circuit_id, (batch_size, padded_public_input)); - total_instances += batch_size; public_inputs.insert(circuit_id, public_input); - circuit_ids.push(circuit_id); + total_instances += batch_size; + batch_sizes.insert(circuit_id, batch_size); } assert_eq!(prover_state.total_instances, total_instances); @@ -524,7 +525,7 @@ where assert!( polynomials.len() == keys_to_constraints.len() * 12 + // polys for row, col, rowcol, val - AHPForR1CS::::num_first_round_oracles(total_instances) + + AHPForR1CS::::num_first_round_oracles(total_instances.try_into()?) + AHPForR1CS::::num_second_round_oracles() + AHPForR1CS::::num_third_round_oracles(keys_to_constraints.len()) + AHPForR1CS::::num_fourth_round_oracles() @@ -586,7 +587,7 @@ where } // Compute the AHP verifier's query set. - let (query_set, verifier_state) = AHPForR1CS::<_, MM>::verifier_query_set(verifier_state); + let (query_set, verifier_state) = AHPForR1CS::<_, MM>::verifier_query_set(verifier_state)?; let lc_s = AHPForR1CS::<_, MM>::construct_linear_combinations( &public_inputs, &polynomials, @@ -598,7 +599,7 @@ where let eval_time = start_timer!(|| "Evaluating linear combinations over query set"); let mut evaluations = std::collections::BTreeMap::new(); - for (label, (_, point)) in query_set.to_set() { + for (label, (_, point)) in query_set.to_set()? 
{ if !AHPForR1CS::::LC_WITH_ZERO_EVAL.contains(&label.as_str()) { let lc = lc_s.get(&label).ok_or_else(|| AHPError::MissingEval(label.to_string()))?; let evaluation = polynomials.get_lc_eval(lc, point)?; @@ -606,7 +607,7 @@ where } } - let evaluations = proof::Evaluations::from_map(&evaluations, batch_sizes.clone()); + let evaluations = proof::Evaluations::from_map(&evaluations, batch_sizes.clone())?; end_timer!(eval_time); Self::terminate(terminator)?; @@ -618,7 +619,7 @@ where lc_s.values(), polynomials, &labeled_commitments, - &query_set.to_set(), + &query_set.to_set()?, &commitment_randomnesses, &mut sponge, )?; @@ -664,7 +665,7 @@ where return Err(SNARKError::EmptyBatch); } - if public_inputs_i.len() != batch_sizes_vec[i] { + if u32::try_from(public_inputs_i.len())? != batch_sizes_vec[i] { return Err(SNARKError::BatchSizeMismatch); } } @@ -714,7 +715,7 @@ where circuit_infos.insert(circuit_id, &vk.orig_vk.circuit_info); circuit_ids.push(circuit_id); } - for (i, (vk, &batch_size)) in keys_to_inputs.keys().zip(batch_sizes.values()).enumerate() { + for (i, (vk, &batch_size)) in keys_to_inputs.keys().zip(batch_sizes_vec).enumerate() { inputs_and_batch_sizes.insert(vk.orig_vk.id, (batch_size, padded_public_vec[i].as_slice())); } @@ -739,33 +740,33 @@ where let first_round_info = AHPForR1CS::::first_round_polynomial_info(batch_sizes.iter()); let mut first_comms_consumed = 0; - let mut first_commitments = batch_sizes - .iter() - .flat_map(|(&circuit_id, &batch_size)| { - let first_comms = comms.witness_commitments[first_comms_consumed..][..batch_size] - .iter() - .enumerate() - .flat_map(|(j, w_comm)| { - [ - LabeledCommitment::new_with_info( - &first_round_info[&witness_label(circuit_id, "w", j)], - w_comm.w, - ), - LabeledCommitment::new_with_info( - &first_round_info[&witness_label(circuit_id, "z_a", j)], - w_comm.z_a, - ), - LabeledCommitment::new_with_info( - &first_round_info[&witness_label(circuit_id, "z_b", j)], - w_comm.z_b, - ), - ] - }) - .collect_vec(); - first_comms_consumed += batch_size; - first_comms - }) - .collect_vec(); + let total_instances = batch_sizes.values().sum::(); + let mut first_commitments = Vec::with_capacity(total_instances as usize); + for (&circuit_id, &batch_size) in batch_sizes.iter() { + let first_comms = comms.witness_commitments[first_comms_consumed as usize..][..batch_size as usize] + .iter() + .enumerate() + .map(|(j, w_comm)| { + Ok::<_, TryFromIntError>([ + LabeledCommitment::new_with_info( + &first_round_info[&witness_label(circuit_id, "w", u32::try_from(j)?)], + w_comm.w, + ), + LabeledCommitment::new_with_info( + &first_round_info[&witness_label(circuit_id, "z_a", u32::try_from(j)?)], + w_comm.z_a, + ), + LabeledCommitment::new_with_info( + &first_round_info[&witness_label(circuit_id, "z_b", u32::try_from(j)?)], + w_comm.z_b, + ), + ]) + }); + first_comms_consumed += batch_size; + for comms in first_comms { + first_commitments.extend(comms?); + } + } if MM::ZK { first_commitments.push(LabeledCommitment::new_with_info( @@ -860,7 +861,7 @@ where .collect(); let query_set_time = start_timer!(|| "Constructing query set"); - let (query_set, verifier_state) = AHPForR1CS::<_, MM>::verifier_query_set(verifier_state); + let (query_set, verifier_state) = AHPForR1CS::<_, MM>::verifier_query_set(verifier_state)?; end_timer!(query_set_time); sponge.absorb_nonnative_field_elements(proof.evaluations.to_field_elements()); @@ -869,7 +870,7 @@ where let mut current_circuit_id = "".to_string(); let mut circuit_index: i64 = -1; - for (label, (_point_name, q)) in 
query_set.to_set() { + for (label, (_point_name, q)) in query_set.to_set()? { if AHPForR1CS::::LC_WITH_ZERO_EVAL.contains(&label.as_ref()) { evaluations.insert((label, q), E::Fr::zero()); } else { @@ -882,7 +883,7 @@ where } let eval = proof .evaluations - .get(circuit_index as usize, &label) + .get(usize::try_from(circuit_index)?, &label)? .ok_or_else(|| AHPError::MissingEval(label.clone()))?; evaluations.insert((label, q), eval); } @@ -902,7 +903,7 @@ where &verifier_key, lc_s.values(), &commitments, - &query_set.to_set(), + &query_set.to_set()?, &evaluations, &proof.pc_proof, &mut sponge, diff --git a/algorithms/src/traits/algebraic_sponge.rs b/algorithms/src/traits/algebraic_sponge.rs index dbda007ef3..b7700bc28b 100644 --- a/algorithms/src/traits/algebraic_sponge.rs +++ b/algorithms/src/traits/algebraic_sponge.rs @@ -66,7 +66,7 @@ pub trait AlgebraicSponge: Clone + Debug { self.absorb_native_field_elements(&elements); } - /// Takes in field elements. + /// Takes out field elements. fn squeeze_native_field_elements(&mut self, num: usize) -> SmallVec<[F; 10]>; /// Takes out field elements. @@ -128,17 +128,17 @@ pub(crate) mod nonnative_params { #[derive(Clone, Debug)] pub struct NonNativeFieldParams { /// The number of limbs (`BaseField` elements) used to represent a `TargetField` element. Highest limb first. - pub num_limbs: usize, + pub num_limbs: u32, /// The number of bits of the limb - pub bits_per_limb: usize, + pub bits_per_limb: u32, } /// Obtain the parameters from a `ConstraintSystem`'s cache or generate a new one #[must_use] pub const fn get_params( - target_field_size: usize, - base_field_size: usize, + target_field_size: u32, + base_field_size: u32, optimization_type: OptimizationType, ) -> NonNativeFieldParams { let (num_of_limbs, limb_size) = find_parameters(base_field_size, target_field_size, optimization_type); @@ -156,21 +156,21 @@ pub(crate) mod nonnative_params { /// A function to search for parameters for nonnative field gadgets pub const fn find_parameters( - base_field_prime_length: usize, - target_field_prime_bit_length: usize, + base_field_prime_length: u32, + target_field_prime_bit_length: u32, optimization_type: OptimizationType, - ) -> (usize, usize) { + ) -> (u32, u32) { let mut found = false; - let mut min_cost = 0usize; - let mut min_cost_limb_size = 0usize; - let mut min_cost_num_of_limbs = 0usize; + let mut min_cost = 0u32; + let mut min_cost_limb_size = 0u32; + let mut min_cost_num_of_limbs = 0u32; let surfeit = 10; let mut max_limb_size = (base_field_prime_length - 1 - surfeit - 1) / 2 - 1; if max_limb_size > target_field_prime_bit_length { max_limb_size = target_field_prime_bit_length; } - let mut limb_size = 1; + let mut limb_size = 1u32; while limb_size <= max_limb_size { let num_of_limbs = (target_field_prime_bit_length + limb_size - 1) / limb_size; diff --git a/fields/src/traits/prime_field.rs b/fields/src/traits/prime_field.rs index 35f5614148..c662b4ba25 100644 --- a/fields/src/traits/prime_field.rs +++ b/fields/src/traits/prime_field.rs @@ -49,6 +49,11 @@ pub trait PrimeField: Self::Parameters::MODULUS_BITS as usize } + /// Returns the field size in bits as u32 + fn size_in_bits_u32() -> u32 { + Self::Parameters::MODULUS_BITS + } + /// Returns the capacity size for data bits. 
fn size_in_data_bits() -> usize { Self::Parameters::CAPACITY as usize diff --git a/r1cs/src/errors.rs b/r1cs/src/errors.rs index f1ab0a42b6..f3163cb60a 100644 --- a/r1cs/src/errors.rs +++ b/r1cs/src/errors.rs @@ -29,24 +29,27 @@ pub enum SynthesisError { /// During synthesis, we divided by zero. #[error("Division by zero during synthesis")] DivisionByZero, - /// During synthesis, we constructed an unsatisfiable constraint system. - #[error("Unsatisfiable constraint system")] - Unsatisfiable, - /// During synthesis, our polynomials ended up being too high of degree - #[error("Polynomial degree is too large")] - PolynomialDegreeTooLarge, - /// During proof generation, we encountered an identity in the CRS - #[error("Encountered an identity element in the CRS")] - UnexpectedIdentity, + /// During synthesis, we could not recover our desired int. + #[error(transparent)] + IntError(#[from] std::num::TryFromIntError), /// During proof generation, we encountered an I/O error with the CRS #[error("Encountered an I/O error")] IoError(std::io::Error), /// During verification, our verifying key was malformed. #[error("Malformed verifying key, public input count was {} but expected {}", _0, _1)] MalformedVerifyingKey(usize, usize), + /// During synthesis, our polynomials ended up being too high of degree + #[error("Polynomial degree is too large")] + PolynomialDegreeTooLarge, /// During CRS generation, we observed an unconstrained auxiliary variable #[error("Auxiliary variable was unconstrained")] UnconstrainedVariable, + /// During proof generation, we encountered an identity in the CRS + #[error("Encountered an identity element in the CRS")] + UnexpectedIdentity, + /// During synthesis, we constructed an unsatisfiable constraint system. + #[error("Unsatisfiable constraint system")] + Unsatisfiable, } impl From for SynthesisError { diff --git a/synthesizer/coinbase/src/helpers/coinbase_solution/bytes.rs b/synthesizer/coinbase/src/helpers/coinbase_solution/bytes.rs index bbc973432b..f0fe95acb2 100644 --- a/synthesizer/coinbase/src/helpers/coinbase_solution/bytes.rs +++ b/synthesizer/coinbase/src/helpers/coinbase_solution/bytes.rs @@ -13,6 +13,7 @@ // limitations under the License. use super::*; +use snarkvm_utilities::try_write_as; impl FromBytes for CoinbaseSolution { /// Reads the coinbase solution from the buffer. @@ -34,7 +35,7 @@ impl FromBytes for CoinbaseSolution { impl ToBytes for CoinbaseSolution { /// Writes the coinbase solution to the buffer. fn write_le(&self, mut writer: W) -> IoResult<()> { - (u32::try_from(self.partial_solutions.len()).map_err(|e| error(e.to_string()))?).write_le(&mut writer)?; + try_write_as::(self.partial_solutions.len(), &mut writer)?; for individual_puzzle_solution in &self.partial_solutions { individual_puzzle_solution.write_le(&mut writer)?; diff --git a/synthesizer/coinbase/src/lib.rs b/synthesizer/coinbase/src/lib.rs index f6a2325006..97d5f46a2e 100644 --- a/synthesizer/coinbase/src/lib.rs +++ b/synthesizer/coinbase/src/lib.rs @@ -335,7 +335,7 @@ impl CoinbasePuzzle { cfg_iter!(coinbase_solution.partial_solutions()).map(|solution| solution.commitment().0).collect(); let fs_challenges = challenge_points.into_iter().map(|f| f.to_bigint()).collect::>(); let accumulator_commitment = - KZGCommitment::(VariableBase::msm(&commitments, &fs_challenges).into()); + KZGCommitment::(VariableBase::msm(&commitments, &fs_challenges)?.into()); // Retrieve the coinbase verifying key. 
let coinbase_verifying_key = match self { diff --git a/synthesizer/src/block/transaction/deployment/bytes.rs b/synthesizer/src/block/transaction/deployment/bytes.rs index 88df477b4e..9a964e3339 100644 --- a/synthesizer/src/block/transaction/deployment/bytes.rs +++ b/synthesizer/src/block/transaction/deployment/bytes.rs @@ -13,6 +13,7 @@ // limitations under the License. use super::*; +use snarkvm_utilities::try_write_as; impl FromBytes for Deployment { /// Reads the deployment from a buffer. @@ -59,7 +60,7 @@ impl ToBytes for Deployment { // Write the program. self.program.write_le(&mut writer)?; // Write the number of entries in the bundle. - (u16::try_from(self.verifying_keys.len()).map_err(|e| error(e.to_string()))?).write_le(&mut writer)?; + try_write_as::(self.verifying_keys.len(), &mut writer)?; // Write each entry. for (function_name, (verifying_key, certificate)) in &self.verifying_keys { // Write the function name. diff --git a/synthesizer/src/block/transaction/execution/bytes.rs b/synthesizer/src/block/transaction/execution/bytes.rs index e2475d7ef1..5a66486266 100644 --- a/synthesizer/src/block/transaction/execution/bytes.rs +++ b/synthesizer/src/block/transaction/execution/bytes.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use snarkvm_utilities::try_write_as; + use super::*; impl FromBytes for Execution { @@ -54,7 +56,7 @@ impl ToBytes for Execution { // Write the version. 0u8.write_le(&mut writer)?; // Write the number of transitions. - (u8::try_from(self.transitions.len()).map_err(|e| error(e.to_string()))?).write_le(&mut writer)?; + try_write_as::(self.transitions.len(), &mut writer)?; // Write the transitions. for transition in self.transitions.values() { transition.write_le(&mut writer)?; diff --git a/synthesizer/src/block/transactions/confirmed/bytes.rs b/synthesizer/src/block/transactions/confirmed/bytes.rs index 4959a7ced6..6991ccc177 100644 --- a/synthesizer/src/block/transactions/confirmed/bytes.rs +++ b/synthesizer/src/block/transactions/confirmed/bytes.rs @@ -13,6 +13,7 @@ // limitations under the License. use super::*; +use snarkvm_utilities::try_write_as; impl FromBytes for ConfirmedTransaction { /// Reads the confirmed transaction from a buffer. @@ -82,7 +83,7 @@ impl ToBytes for ConfirmedTransaction { // Write the transaction. transaction.write_le(&mut writer)?; // Write the number of finalize operations. - NumFinalizeSize::try_from(finalize.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?; + try_write_as::(finalize.len(), &mut writer)?; // Write the finalize operations. finalize.iter().try_for_each(|finalize| finalize.write_le(&mut writer)) } @@ -94,7 +95,7 @@ impl ToBytes for ConfirmedTransaction { // Write the transaction. transaction.write_le(&mut writer)?; // Write the number of finalize operations. - NumFinalizeSize::try_from(finalize.len()).map_err(|e| error(e.to_string()))?.write_le(&mut writer)?; + try_write_as::(finalize.len(), &mut writer)?; // Write the finalize operations. finalize.iter().try_for_each(|finalize| finalize.write_le(&mut writer)) } diff --git a/synthesizer/src/block/transition/bytes.rs b/synthesizer/src/block/transition/bytes.rs index 385fb84d98..d1fed62d61 100644 --- a/synthesizer/src/block/transition/bytes.rs +++ b/synthesizer/src/block/transition/bytes.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
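The transition serializer below follows the same recipe as the deployment and execution serializers above: every collection length is funneled through `try_write_as` with an explicit fixed-width prefix type. A self-contained sketch of the underlying pattern; `write_len_prefixed` is illustrative only, not part of snarkVM:

```rust
use std::io::{self, Write};

/// Writes `items` as a `u8` length prefix followed by the raw bytes.
/// Erroring out on more than 255 items beats silently truncating `len()`.
fn write_len_prefixed<W: Write>(items: &[u8], writer: &mut W) -> io::Result<()> {
    let len = u8::try_from(items.len())
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
    writer.write_all(&[len])?;
    writer.write_all(items)
}

fn main() -> io::Result<()> {
    let mut buf = Vec::new();
    write_len_prefixed(&[1, 2, 3], &mut buf)?;
    assert_eq!(buf, vec![3, 1, 2, 3]);
    // 300 items cannot be described by a `u8` prefix: this fails loudly.
    assert!(write_len_prefixed(&vec![0u8; 300], &mut Vec::new()).is_err());
    Ok(())
}
```

Keeping the prefix type explicit at each call site (`u8` for a transition's inputs and outputs, `u16` for a deployment's verifying keys, `u32` for coinbase partial solutions) makes the wire format independent of the platform's pointer width.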
+use snarkvm_utilities::try_write_as; + use super::*; impl FromBytes for Transition { @@ -98,12 +100,12 @@ impl ToBytes for Transition { self.function_name.write_le(&mut writer)?; // Write the number of inputs. - (u8::try_from(self.inputs.len()).map_err(|e| error(e.to_string()))?).write_le(&mut writer)?; + try_write_as::(self.inputs.len(), &mut writer)?; // Write the inputs. self.inputs.write_le(&mut writer)?; // Write the number of outputs. - (u8::try_from(self.outputs.len()).map_err(|e| error(e.to_string()))?).write_le(&mut writer)?; + try_write_as::(self.outputs.len(), &mut writer)?; // Write the outputs. self.outputs.write_le(&mut writer)?; @@ -117,7 +119,7 @@ impl ToBytes for Transition { // Write the finalize variant. 1u8.write_le(&mut writer)?; // Write the number of inputs to finalize. - (u8::try_from(finalize.len()).map_err(|e| error(e.to_string()))?).write_le(&mut writer)?; + try_write_as::(finalize.len(), &mut writer)?; // Write the inputs to finalize. finalize.write_le(&mut writer)?; } diff --git a/synthesizer/src/vm/execute.rs b/synthesizer/src/vm/execute.rs index 824bbbc8c7..7ccc343835 100644 --- a/synthesizer/src/vm/execute.rs +++ b/synthesizer/src/vm/execute.rs @@ -185,13 +185,13 @@ mod tests { // Assert the size of the transaction. let transaction_size_in_bytes = transaction.to_bytes_le().unwrap().len(); - assert_eq!(1387, transaction_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(1383, transaction_size_in_bytes, "Update me if serialization has changed"); // Assert the size of the execution. assert!(matches!(transaction, Transaction::Execute(_, _, _))); if let Transaction::Execute(_, execution, _) = &transaction { let execution_size_in_bytes = execution.to_bytes_le().unwrap().len(); - assert_eq!(1352, execution_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(1348, execution_size_in_bytes, "Update me if serialization has changed"); } } @@ -224,13 +224,13 @@ mod tests { // Assert the size of the transaction. let transaction_size_in_bytes = transaction.to_bytes_le().unwrap().len(); - assert_eq!(2222, transaction_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(2214, transaction_size_in_bytes, "Update me if serialization has changed"); // Assert the size of the execution. assert!(matches!(transaction, Transaction::Execute(_, _, _))); if let Transaction::Execute(_, execution, _) = &transaction { let execution_size_in_bytes = execution.to_bytes_le().unwrap().len(); - assert_eq!(2187, execution_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(2179, execution_size_in_bytes, "Update me if serialization has changed"); } } @@ -258,13 +258,13 @@ mod tests { // Assert the size of the transaction. let transaction_size_in_bytes = transaction.to_bytes_le().unwrap().len(); - assert_eq!(2099, transaction_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(2091, transaction_size_in_bytes, "Update me if serialization has changed"); // Assert the size of the execution. assert!(matches!(transaction, Transaction::Execute(_, _, _))); if let Transaction::Execute(_, execution, _) = &transaction { let execution_size_in_bytes = execution.to_bytes_le().unwrap().len(); - assert_eq!(2064, execution_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(2056, execution_size_in_bytes, "Update me if serialization has changed"); } } @@ -291,13 +291,13 @@ mod tests { // Assert the size of the transaction. 
let transaction_size_in_bytes = transaction.to_bytes_le().unwrap().len(); - assert_eq!(2111, transaction_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(2103, transaction_size_in_bytes, "Update me if serialization has changed"); // Assert the size of the execution. assert!(matches!(transaction, Transaction::Execute(_, _, _))); if let Transaction::Execute(_, execution, _) = &transaction { let execution_size_in_bytes = execution.to_bytes_le().unwrap().len(); - assert_eq!(2076, execution_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(2068, execution_size_in_bytes, "Update me if serialization has changed"); } } } diff --git a/synthesizer/src/vm/execute_fee.rs b/synthesizer/src/vm/execute_fee.rs index 0993e1de42..ecf31aa8b1 100644 --- a/synthesizer/src/vm/execute_fee.rs +++ b/synthesizer/src/vm/execute_fee.rs @@ -130,6 +130,6 @@ mod tests { }; // Assert the size of the transition. let fee_size_in_bytes = fee.to_bytes_le().unwrap().len(); - assert_eq!(1935, fee_size_in_bytes, "Update me if serialization has changed"); + assert_eq!(1927, fee_size_in_bytes, "Update me if serialization has changed"); } } diff --git a/utilities/src/serialize/helpers.rs b/utilities/src/serialize/helpers.rs index 53600b66f8..de3d7de2c1 100644 --- a/utilities/src/serialize/helpers.rs +++ b/utilities/src/serialize/helpers.rs @@ -13,12 +13,15 @@ // limitations under the License. pub use crate::{ + error, io::{self, Read, Write}, + serialize::traits::*, FromBytes, + SerializationError, ToBytes, Vec, }; -use crate::{serialize::traits::*, SerializationError}; +use std::fmt::Display; /// Serialize a Vector's elements without serializing the Vector's length /// If you want to serialize the full Vector, use `CanonicalSerialize for Vec<T>` @@ -49,3 +52,11 @@ pub fn deserialize_vec_without_len<R: Read, T: CanonicalDeserialize>( ) -> Result<Vec<T>, SerializationError> { (0..len).map(|_| CanonicalDeserialize::deserialize_with_mode(&mut reader, compress, validate)).collect() } + +/// Tries to convert a `usize` into `T` and writes the result in little-endian form. +pub fn try_write_as<T: TryFrom<usize> + ToBytes, W: Write>(input: usize, writer: &mut W) -> io::Result<()> +where + <T as TryFrom<usize>>::Error: Display, +{ + T::try_from(input).map_err(|e| error(e.to_string()))?.write_le(writer) +}
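To close, a sketch of how a caller exercises the new `try_write_as` helper. It is approximated here with std traits only, since `ToBytes` and the `error` constructor live inside snarkvm_utilities; `WriteLe` is a hypothetical stand-in for `ToBytes::write_le`:

```rust
use std::fmt::Display;
use std::io::{self, Write};

/// Hypothetical stand-in for snarkVM's `ToBytes::write_le`.
trait WriteLe {
    fn write_le<W: Write>(&self, writer: &mut W) -> io::Result<()>;
}

impl WriteLe for u16 {
    fn write_le<W: Write>(&self, writer: &mut W) -> io::Result<()> {
        writer.write_all(&self.to_le_bytes())
    }
}

/// Mirrors `try_write_as`: convert a `usize` into `T`, then write it
/// little-endian; a failed conversion becomes an I/O error.
fn try_write_as<T: TryFrom<usize> + WriteLe, W: Write>(input: usize, writer: &mut W) -> io::Result<()>
where
    <T as TryFrom<usize>>::Error: Display,
{
    let value = T::try_from(input)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
    value.write_le(writer)
}

fn main() -> io::Result<()> {
    let mut buf = Vec::new();
    // A 2-entry bundle serializes to the same two bytes on every platform.
    try_write_as::<u16, _>(2, &mut buf)?;
    assert_eq!(buf, vec![2, 0]);
    // `usize::MAX` does not fit in a `u16`, so the write is rejected.
    assert!(try_write_as::<u16, _>(usize::MAX, &mut Vec::new()).is_err());
    Ok(())
}
```

On the read side, the corresponding `FromBytes` implementations recover the prefix as the same fixed-width type and convert it back with a checked cast, which is what keeps serialized output byte-identical across 32- and 64-bit platforms.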