From 2c51ed5f10476e2442dbf23b734d782c6e0b4496 Mon Sep 17 00:00:00 2001 From: Andrey Khmuro Date: Tue, 6 Feb 2024 22:13:11 +0300 Subject: [PATCH] feat: add TryFrom and TryFrom bounds to the FieldElement --- air/src/proof/context.rs | 2 +- math/src/field/extensions/cubic.rs | 32 ++++++++++++++++++++++++-- math/src/field/extensions/quadratic.rs | 32 ++++++++++++++++++++++++-- math/src/field/f128/mod.rs | 14 +++++++++++ math/src/field/f62/mod.rs | 14 +++++++++++ math/src/field/f64/mod.rs | 14 +++++++++++ math/src/field/traits.rs | 2 ++ 7 files changed, 105 insertions(+), 5 deletions(-) diff --git a/air/src/proof/context.rs b/air/src/proof/context.rs index d60f54b04..a6e4bf0af 100644 --- a/air/src/proof/context.rs +++ b/air/src/proof/context.rs @@ -224,7 +224,7 @@ fn bytes_to_element(bytes: &[u8]) -> B { let mut buf = bytes.to_vec(); buf.resize(B::ELEMENT_BYTES, 0); - let element = match B::try_from(&buf) { + let element = match B::try_from(buf.as_slice()) { Ok(element) => element, Err(_) => panic!("element deserialization failed"), }; diff --git a/math/src/field/extensions/cubic.rs b/math/src/field/extensions/cubic.rs index dfb9a584a..5dd18be72 100644 --- a/math/src/field/extensions/cubic.rs +++ b/math/src/field/extensions/cubic.rs @@ -11,8 +11,10 @@ use core::{ slice, }; use utils::{ - collections::Vec, string::ToString, AsBytes, ByteReader, ByteWriter, Deserializable, - DeserializationError, Randomizable, Serializable, SliceReader, + collections::Vec, + string::{String, ToString}, + AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable, + Serializable, SliceReader, }; #[cfg(feature = "serde")] @@ -319,6 +321,32 @@ impl> From for CubeExtension { } } +impl> TryFrom for CubeExtension { + type Error = String; + + fn try_from(value: u64) -> Result { + match B::try_from(value) { + Ok(elem) => Ok(Self::from(elem)), + Err(_) => Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )), + } + } +} + +impl> TryFrom for CubeExtension { + type Error = String; + + fn try_from(value: u128) -> Result { + match B::try_from(value) { + Ok(elem) => Ok(Self::from(elem)), + Err(_) => Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )), + } + } +} + impl<'a, B: ExtensibleField<3>> TryFrom<&'a [u8]> for CubeExtension { type Error = DeserializationError; diff --git a/math/src/field/extensions/quadratic.rs b/math/src/field/extensions/quadratic.rs index a8683b617..906829adc 100644 --- a/math/src/field/extensions/quadratic.rs +++ b/math/src/field/extensions/quadratic.rs @@ -11,8 +11,10 @@ use core::{ slice, }; use utils::{ - collections::Vec, string::ToString, AsBytes, ByteReader, ByteWriter, Deserializable, - DeserializationError, Randomizable, Serializable, SliceReader, + collections::Vec, + string::{String, ToString}, + AsBytes, ByteReader, ByteWriter, Deserializable, DeserializationError, Randomizable, + Serializable, SliceReader, }; #[cfg(feature = "serde")] @@ -313,6 +315,32 @@ impl> From for QuadExtension { } } +impl> TryFrom for QuadExtension { + type Error = String; + + fn try_from(value: u64) -> Result { + match B::try_from(value) { + Ok(elem) => Ok(Self::from(elem)), + Err(_) => Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )), + } + } +} + +impl> TryFrom for QuadExtension { + type Error = String; + + fn try_from(value: u128) -> Result { + match B::try_from(value) { + Ok(elem) => Ok(Self::from(elem)), + Err(_) => Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )), + } + } +} + impl<'a, B: ExtensibleField<2>> TryFrom<&'a [u8]> for QuadExtension { type Error = DeserializationError; diff --git a/math/src/field/f128/mod.rs b/math/src/field/f128/mod.rs index ffe63de56..41a56fe9f 100644 --- a/math/src/field/f128/mod.rs +++ b/math/src/field/f128/mod.rs @@ -355,6 +355,20 @@ impl From for BaseElement { } } +impl TryFrom for BaseElement { + type Error = String; + + fn try_from(value: u128) -> Result { + if value >= M { + Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )) + } else { + Ok(Self::new(value)) + } + } +} + impl<'a> TryFrom<&'a [u8]> for BaseElement { type Error = String; diff --git a/math/src/field/f62/mod.rs b/math/src/field/f62/mod.rs index dec435b85..b134f12aa 100644 --- a/math/src/field/f62/mod.rs +++ b/math/src/field/f62/mod.rs @@ -455,6 +455,20 @@ impl TryFrom for BaseElement { } } +impl TryFrom for BaseElement { + type Error = String; + + fn try_from(value: u128) -> Result { + if value >= M as u128 { + Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )) + } else { + Ok(Self::new(value as u64)) + } + } +} + impl TryFrom<[u8; 8]> for BaseElement { type Error = String; diff --git a/math/src/field/f64/mod.rs b/math/src/field/f64/mod.rs index 86d1a9a12..fd2b3dab4 100644 --- a/math/src/field/f64/mod.rs +++ b/math/src/field/f64/mod.rs @@ -550,6 +550,20 @@ impl TryFrom for BaseElement { } } +impl TryFrom for BaseElement { + type Error = String; + + fn try_from(value: u128) -> Result { + if value >= M as u128 { + Err(format!( + "invalid field element: value {value} is greater than or equal to the field modulus" + )) + } else { + Ok(Self::new(value as u64)) + } + } +} + impl TryFrom<[u8; 8]> for BaseElement { type Error = String; diff --git a/math/src/field/traits.rs b/math/src/field/traits.rs index 124e2980b..e6dd22094 100644 --- a/math/src/field/traits.rs +++ b/math/src/field/traits.rs @@ -49,6 +49,8 @@ pub trait FieldElement: + From + From + From + + TryFrom + + TryFrom + for<'a> TryFrom<&'a [u8]> + ExtensionOf<::BaseField> + AsBytes