From be49a7cf4f8d1aeb016696d9f186947d9ee50db6 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Tue, 26 Apr 2022 08:48:16 +0000 Subject: [PATCH] Add `Field::sum_of_products` method Closes zkcrypto/ff#79. --- CHANGELOG.md | 2 ++ src/lib.rs | 15 +++++++++++++++ tests/derive.rs | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 91390e9..45dd09e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ and this library adheres to Rust's notion of [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `ff::Field::sum_of_products` ## [0.11.0] - 2021-09-02 ### Added diff --git a/src/lib.rs b/src/lib.rs index 57978aa..044d174 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,6 +124,21 @@ pub trait Field: res } + + /// Returns `pairs.into_iter().fold(Self::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)`. + /// + /// This computes the "dot product" or "inner product" `a ⋅ b` of two equal-length + /// sequences of elements `a` and `b`, such that `pairs = a.zip(b)`. + /// + /// The provided implementation of this trait method uses the direct calculation given + /// above. Implementations of `Field` should override this to use more efficient + /// methods that take advantage of their internal representation, such as interleaving + /// or sharing modular reductions. + fn sum_of_products<'a, I: IntoIterator + Clone>(pairs: I) -> Self { + pairs + .into_iter() + .fold(Self::zero(), |acc, (a_i, b_i)| acc + (*a_i * b_i)) + } } /// This represents an element of a prime field. diff --git a/tests/derive.rs b/tests/derive.rs index 1fd5b25..e9ba2c8 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -21,6 +21,47 @@ mod fermat { struct Fermat65537Field([u64; 1]); } +#[test] +fn sum_of_products() { + use ff::{Field, PrimeField}; + + let one = Bls381K12Scalar::one(); + + // [1, 2, 3, 4] + let values: Vec<_> = (0..4) + .scan(one, |acc, _| { + let ret = *acc; + *acc += &one; + Some(ret) + }) + .collect(); + + // We'll pair each value with itself. + let expected = Bls381K12Scalar::from_str_vartime("30").unwrap(); + + // Check that we can produce the necessary input from two iterators. + assert_eq!( + // Directly produces (&v, &v) + Bls381K12Scalar::sum_of_products(values.iter().zip(values.iter())), + expected, + ); + + // Check that we can produce the necessary input from an iterator of values. + assert_eq!( + // Maps &v to (&v, &v) + Bls381K12Scalar::sum_of_products(values.iter().map(|v| (v, v))), + expected, + ); + + // Check that we can produce the necessary input from an iterator of tuples. + let tuples: Vec<_> = values.into_iter().map(|v| (v, v)).collect(); + assert_eq!( + // Maps &(a, b) to (&a, &b) + Bls381K12Scalar::sum_of_products(tuples.iter().map(|(a, b)| (a, b))), + expected, + ); +} + #[test] fn batch_inversion() { use ff::{BatchInverter, Field};