From c1e3763e9192303346abb64f5b6fc400b4af3f04 Mon Sep 17 00:00:00 2001 From: Fernando Otero Date: Tue, 30 Jul 2024 14:56:39 +0100 Subject: [PATCH] [libraries/pod]: improve `PodOption` type (#7076) * Improve PodOption type * Use associated constant * Add exhaustive matching --- libraries/pod/src/option.rs | 89 +++++++++++++++++++++++++++---------- 1 file changed, 66 insertions(+), 23 deletions(-) diff --git a/libraries/pod/src/option.rs b/libraries/pod/src/option.rs index 3b15bdd1bc4..60a4b584c79 100644 --- a/libraries/pod/src/option.rs +++ b/libraries/pod/src/option.rs @@ -8,16 +8,25 @@ use { bytemuck::{Pod, Zeroable}, - solana_program::{program_option::COption, pubkey::Pubkey}, + solana_program::{ + program_error::ProgramError, + program_option::COption, + pubkey::{Pubkey, PUBKEY_BYTES}, + }, }; /// Trait for types that can be `None`. /// /// This trait is used to indicate that a type can be `None` according to a /// specific value. -pub trait Nullable: Default + Pod { +pub trait Nullable: PartialEq + Pod + Sized { + /// Value that represents `None` for the type. + const NONE: Self; + /// Indicates whether the value is `None` or not. - fn is_none(&self) -> bool; + fn is_none(&self) -> bool { + self == &Self::NONE + } /// Indicates whether the value is `Some`` value of type `T`` or not. fn is_some(&self) -> bool { @@ -66,8 +75,16 @@ impl PodOption { } } +/// ## Safety +/// +/// `PodOption` is a transparent wrapper around a `Pod` type `T` with identical +/// data representation. unsafe impl Pod for PodOption {} +/// ## Safety +/// +/// `PodOption` is a transparent wrapper around a `Pod` type `T` with identical +/// data representation. unsafe impl Zeroable for PodOption {} impl From for PodOption { @@ -76,32 +93,33 @@ impl From for PodOption { } } -impl From> for PodOption { - fn from(from: Option) -> Self { - match from { - Some(value) => PodOption(value), - None => PodOption(T::default()), +impl TryFrom> for PodOption { + type Error = ProgramError; + + fn try_from(value: Option) -> Result { + match value { + Some(value) if value.is_none() => Err(ProgramError::InvalidArgument), + Some(value) => Ok(PodOption(value)), + None => Ok(PodOption(T::NONE)), } } } -impl From> for PodOption { - fn from(from: COption) -> Self { - match from { - COption::Some(value) => PodOption(value), - COption::None => PodOption(T::default()), +impl TryFrom> for PodOption { + type Error = ProgramError; + + fn try_from(value: COption) -> Result { + match value { + COption::Some(value) if value.is_none() => Err(ProgramError::InvalidArgument), + COption::Some(value) => Ok(PodOption(value)), + COption::None => Ok(PodOption(T::NONE)), } } } /// Implementation of `Nullable` for `Pubkey`. -/// -/// The implementation assumes that the default value of `Pubkey` represents -/// the `None` value. impl Nullable for Pubkey { - fn is_none(&self) -> bool { - self == &Pubkey::default() - } + const NONE: Self = Pubkey::new_from_array([0u8; PUBKEY_BYTES]); } #[cfg(test)] @@ -126,13 +144,38 @@ mod tests { assert_eq!(values[1], PodOption::from(Pubkey::default())); let option_pubkey = Some(sysvar::ID); - let pod_option_pubkey: PodOption = option_pubkey.into(); + let pod_option_pubkey: PodOption = option_pubkey.try_into().unwrap(); assert_eq!(pod_option_pubkey, PodOption::from(sysvar::ID)); - assert_eq!(pod_option_pubkey, PodOption::from(option_pubkey)); + assert_eq!( + pod_option_pubkey, + PodOption::try_from(option_pubkey).unwrap() + ); let coption_pubkey = COption::Some(sysvar::ID); - let pod_option_pubkey: PodOption = coption_pubkey.into(); + let pod_option_pubkey: PodOption = coption_pubkey.try_into().unwrap(); assert_eq!(pod_option_pubkey, PodOption::from(sysvar::ID)); - assert_eq!(pod_option_pubkey, PodOption::from(coption_pubkey)); + assert_eq!( + pod_option_pubkey, + PodOption::try_from(coption_pubkey).unwrap() + ); + } + + #[test] + fn test_try_from_option() { + let some_pubkey = Some(sysvar::ID); + assert_eq!( + PodOption::try_from(some_pubkey).unwrap(), + PodOption(sysvar::ID) + ); + + let none_pubkey = None; + assert_eq!( + PodOption::try_from(none_pubkey).unwrap(), + PodOption::from(Pubkey::NONE) + ); + + let invalid_option = Some(Pubkey::NONE); + let err = PodOption::try_from(invalid_option).unwrap_err(); + assert_eq!(err, ProgramError::InvalidArgument); } }