diff --git a/aws-lc-rs/src/cbb.rs b/aws-lc-rs/src/cbb.rs index 8ba195ced82..ec80a3b4c77 100644 --- a/aws-lc-rs/src/cbb.rs +++ b/aws-lc-rs/src/cbb.rs @@ -7,6 +7,14 @@ use std::mem::MaybeUninit; pub(crate) struct LcCBB(CBB); impl LcCBB { + pub(crate) fn new(initial_capacity: usize) -> LcCBB { + let mut cbb = MaybeUninit::::uninit(); + unsafe { + CBB_init(cbb.as_mut_ptr(), initial_capacity); + } + Self(unsafe { cbb.assume_init() }) + } + pub(crate) fn as_mut_ptr(&mut self) -> *mut CBB { &mut self.0 } @@ -19,11 +27,3 @@ impl Drop for LcCBB { } } } - -#[inline] -#[allow(non_snake_case)] -pub(crate) unsafe fn build_CBB(initial_capacity: usize) -> LcCBB { - let mut cbb = MaybeUninit::::uninit(); - CBB_init(cbb.as_mut_ptr(), initial_capacity); - LcCBB(cbb.assume_init()) -} diff --git a/aws-lc-rs/src/evp_pkey.rs b/aws-lc-rs/src/evp_pkey.rs index ee71169d6d7..8a4d7348c74 100644 --- a/aws-lc-rs/src/evp_pkey.rs +++ b/aws-lc-rs/src/evp_pkey.rs @@ -1,17 +1,17 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 OR ISC +use crate::cbb::LcCBB; +use crate::cbs; use crate::ec::PKCS8_DOCUMENT_MAX_LEN; use crate::error::{KeyRejected, Unspecified}; use crate::pkcs8::{Document, Version}; use crate::ptr::LcPtr; -use crate::{cbb, cbs}; use aws_lc::{ CBB_finish, EVP_PKEY_bits, EVP_PKEY_get1_EC_KEY, EVP_PKEY_get1_RSA, EVP_PKEY_id, EVP_marshal_private_key, EVP_marshal_private_key_v2, EVP_parse_private_key, EC_KEY, EVP_PKEY, RSA, }; -use std::mem::MaybeUninit; use std::os::raw::c_int; use std::ptr::null_mut; @@ -85,38 +85,30 @@ impl LcPtr { } pub(crate) fn marshall_private_key(&self, version: Version) -> Result { - unsafe { - let mut cbb = cbb::build_CBB(PKCS8_DOCUMENT_MAX_LEN); + let mut cbb = LcCBB::new(PKCS8_DOCUMENT_MAX_LEN); - match version { - Version::V1 => { - if 1 != EVP_marshal_private_key(cbb.as_mut_ptr(), **self) { - return Err(Unspecified); - } + match version { + Version::V1 => { + if 1 != unsafe { EVP_marshal_private_key(cbb.as_mut_ptr(), **self) } { + return Err(Unspecified); } - Version::V2 => { - if 1 != EVP_marshal_private_key_v2(cbb.as_mut_ptr(), **self) { - return Err(Unspecified); - } + } + Version::V2 => { + if 1 != unsafe { EVP_marshal_private_key_v2(cbb.as_mut_ptr(), **self) } { + return Err(Unspecified); } } + } - let mut pkcs8_bytes_ptr = null_mut::(); - let mut out_len = MaybeUninit::::uninit(); - if 1 != CBB_finish(cbb.as_mut_ptr(), &mut pkcs8_bytes_ptr, out_len.as_mut_ptr()) { - return Err(Unspecified); - } - let pkcs8_bytes_ptr = LcPtr::new(pkcs8_bytes_ptr)?; - let out_len = out_len.assume_init(); + let mut pkcs8_bytes_ptr = null_mut::(); + let mut out_len: usize = 0; + if 1 != unsafe { CBB_finish(cbb.as_mut_ptr(), &mut pkcs8_bytes_ptr, &mut out_len) } { + return Err(Unspecified); + } - let bytes_slice = pkcs8_bytes_ptr.as_slice(out_len); - let mut pkcs8_bytes = [0u8; PKCS8_DOCUMENT_MAX_LEN]; - pkcs8_bytes[0..out_len].copy_from_slice(bytes_slice); + let pkcs8_bytes_ptr = LcPtr::new(pkcs8_bytes_ptr)?; + let bytes = Vec::from(unsafe { pkcs8_bytes_ptr.as_slice(out_len) }).into_boxed_slice(); - Ok(Document { - bytes: pkcs8_bytes, - len: out_len, - }) - } + Ok(Document::new(bytes)) } } diff --git a/aws-lc-rs/src/pkcs8.rs b/aws-lc-rs/src/pkcs8.rs index 095cb6da19f..6ee0160fd1d 100644 --- a/aws-lc-rs/src/pkcs8.rs +++ b/aws-lc-rs/src/pkcs8.rs @@ -7,25 +7,29 @@ //! //! [RFC 5208]: https://tools.ietf.org/html/rfc5208. -use crate::ec; use zeroize::Zeroize; /// A generated PKCS#8 document. pub struct Document { - pub(crate) bytes: [u8; ec::PKCS8_DOCUMENT_MAX_LEN], - pub(crate) len: usize, + bytes: Box<[u8]>, +} + +impl Document { + pub(crate) fn new(bytes: Box<[u8]>) -> Self { + Self { bytes } + } } impl AsRef<[u8]> for Document { #[inline] fn as_ref(&self) -> &[u8] { - &self.bytes[..self.len] + &self.bytes } } impl Drop for Document { fn drop(&mut self) { - self.bytes.zeroize(); + self.bytes.as_mut().zeroize(); } } diff --git a/aws-lc-rs/src/ptr.rs b/aws-lc-rs/src/ptr.rs index 5103cc5f146..6f8dc6b850a 100644 --- a/aws-lc-rs/src/ptr.rs +++ b/aws-lc-rs/src/ptr.rs @@ -40,6 +40,11 @@ impl ManagedPointer

{ pub unsafe fn as_slice(&self, len: usize) -> &[P::T] { std::slice::from_raw_parts(self.pointer.as_const_ptr(), len) } + + #[allow(clippy::mut_from_ref)] + pub unsafe fn as_slice_mut(&self, len: usize) -> &mut [P::T] { + std::slice::from_raw_parts_mut(self.pointer.as_mut_ptr(), len) + } } impl Drop for ManagedPointer

{ @@ -160,6 +165,7 @@ pub(crate) trait Pointer { fn free(&mut self); fn as_const_ptr(&self) -> *const Self::T; + fn as_mut_ptr(&self) -> *mut Self::T; } pub(crate) trait IntoPointer

{ @@ -190,9 +196,15 @@ macro_rules! create_pointer { } } + #[inline] fn as_const_ptr(&self) -> *const Self::T { self.cast() } + + #[inline] + fn as_mut_ptr(&self) -> *mut Self::T { + *self + } } }; } diff --git a/aws-lc-rs/src/rsa.rs b/aws-lc-rs/src/rsa.rs index 0a073e0f28f..6422674cef8 100644 --- a/aws-lc-rs/src/rsa.rs +++ b/aws-lc-rs/src/rsa.rs @@ -11,11 +11,18 @@ // components. pub(crate) mod key; +mod oaep; pub(crate) mod signature; -pub use self::key::{KeyPair, PublicKey, PublicKeyComponents}; #[allow(clippy::module_name_repetitions)] pub use self::signature::RsaParameters; +pub use self::{ + key::{KeyPair, PublicKey, PublicKeyComponents}, + oaep::{ + EncryptionAlgorithm, EncryptionAlgorithmId, PrivateDecryptingKey, PublicEncryptingKey, + OAEP_SHA1_MGF1SHA1, OAEP_SHA256_MGF1SHA256, OAEP_SHA384_MGF1SHA384, OAEP_SHA512_MGF1SHA512, + }, +}; pub(crate) use self::signature::RsaVerificationAlgorithmId; diff --git a/aws-lc-rs/src/rsa/key.rs b/aws-lc-rs/src/rsa/key.rs index 085b36592b7..39cb515bfba 100644 --- a/aws-lc-rs/src/rsa/key.rs +++ b/aws-lc-rs/src/rsa/key.rs @@ -11,9 +11,11 @@ use core::{ #[cfg(feature = "fips")] use aws_lc::RSA_check_fips; use aws_lc::{ - EVP_DigestSignInit, EVP_PKEY_assign_RSA, EVP_PKEY_new, RSA_get0_e, RSA_get0_n, RSA_get0_p, - RSA_get0_q, RSA_new, RSA_parse_private_key, RSA_parse_public_key, RSA_public_key_to_bytes, - RSA_set0_key, RSA_size, EVP_PKEY, EVP_PKEY_CTX, RSA, + CBB_finish, EVP_DigestSignInit, EVP_PKEY_CTX_new_id, EVP_PKEY_CTX_set_rsa_keygen_bits, + EVP_PKEY_assign_RSA, EVP_PKEY_keygen, EVP_PKEY_keygen_init, EVP_PKEY_new, EVP_PKEY_size, + EVP_marshal_private_key, RSA_get0_e, RSA_get0_n, RSA_get0_p, RSA_get0_q, RSA_new, + RSA_parse_private_key, RSA_parse_public_key, RSA_public_key_to_bytes, RSA_set0_key, RSA_size, + EVP_PKEY, EVP_PKEY_CTX, EVP_PKEY_RSA, RSA, }; use mirai_annotations::verify_unreachable; @@ -31,14 +33,65 @@ use super::{ #[cfg(feature = "ring-io")] use crate::io; use crate::{ + cbb::LcCBB, cbs, digest, error::{KeyRejected, Unspecified}, + fips::indicator_check, hex, + pkcs8::Document, ptr::{ConstPointer, DetachableLcPtr, LcPtr}, rand, sealed::Sealed, }; +// Based on a meassurement of a PKCS#8 document containing an RSA-2048 key with a 5% additional buffer. +pub(super) const PKCS8_CAPACITY_BUFFER: usize = 1252; + +/// RSA key-size. +#[allow(clippy::module_name_repetitions)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum KeySize { + /// 2048-bit key + Rsa2048, + + /// 4096-bit key + Rsa4096, + + /// 8192-bit key + Rsa8192, +} + +impl KeySize { + /// Returns the size of the key in bytes [`KeySize`]. + #[inline] + pub fn len(self) -> usize { + match self { + KeySize::Rsa2048 => 256, + KeySize::Rsa4096 => 512, + KeySize::Rsa8192 => 1024, + } + } + + /// Returns the bits of this [`KeySize`]. + #[inline] + fn bit_len(self) -> i32 { + match self { + KeySize::Rsa2048 => 2048, + KeySize::Rsa4096 => 4096, + KeySize::Rsa8192 => 8192, + } + } + + pub(super) fn from_evp_pkey(evp_pkey: &LcPtr) -> Result { + Ok(match unsafe { EVP_PKEY_size(**evp_pkey) } { + 256 => KeySize::Rsa2048, + 512 => KeySize::Rsa4096, + 1024 => KeySize::Rsa8192, + _ => return Err(Unspecified), + }) + } +} + /// An RSA key pair, used for signing. #[allow(clippy::module_name_repetitions)] pub struct KeyPair { @@ -48,8 +101,8 @@ pub struct KeyPair { // other thread is concurrently calling a mutating function. Unless otherwise // documented, functions which take a |const| pointer are non-mutating and // functions which take a non-|const| pointer are mutating. - evp_pkey: LcPtr, - serialized_public_key: PublicKey, + pub(super) evp_pkey: LcPtr, + pub(super) serialized_public_key: PublicKey, } impl Sealed for KeyPair {} @@ -68,6 +121,16 @@ impl KeyPair { } } + /// Generate an RSA `KeyPair` of the specified key-strength. + /// + /// # Errors + /// * `Unspecified` + pub fn generate(size: KeySize) -> Result { + let private_key = generate_rsa_evp_pkey(size)?; + + Self::new(private_key).map_err(|_| Unspecified) + } + /// Parses an unencrypted PKCS#8-encoded RSA private key. /// /// Only two-prime (not multi-prime) keys are supported. The public modulus @@ -118,7 +181,7 @@ impl KeyPair { pub fn from_pkcs8(pkcs8: &[u8]) -> Result { unsafe { let evp_pkey = LcPtr::try_from(pkcs8)?; - Self::validate_rsa_pkey(&evp_pkey)?; + validate_rsa_pkey(&evp_pkey)?; Self::new(evp_pkey) } } @@ -130,58 +193,34 @@ impl KeyPair { pub fn from_der(input: &[u8]) -> Result { unsafe { let pkey = build_private_RSA_PKEY(input)?; - Self::validate_rsa_pkey(&pkey)?; + validate_rsa_pkey(&pkey)?; Self::new(pkey) } } - const MIN_RSA_PRIME_BITS: u32 = 1024; - const MAX_RSA_PRIME_BITS: u32 = 4096; - - /// ⚠️ Function assumes that `aws_lc::RSA_check_key` / `aws_lc::RSA_validate_key` has already been invoked beforehand. - /// `aws_lc::RSA_validate_key` is already invoked by `aws_lc::EVP_parse_private_key` / `aws_lc::RSA_parse_private_key`. - /// If the `EVP_PKEY` was constructed through another mechanism, then the key should be validated through the use of - /// one those verifier functions first. - unsafe fn validate_rsa_pkey(rsa: &LcPtr) -> Result<(), KeyRejected> { - let rsa = rsa.get_rsa()?.as_const(); - - let p = ConstPointer::new(RSA_get0_p(*rsa))?; - let q = ConstPointer::new(RSA_get0_q(*rsa))?; - let p_bits = p.num_bits(); - let q_bits = q.num_bits(); + /// Serializes this key-pair to a PKCS#8 (v1) document. + /// + /// # Errors + /// * `Unspecified`: any error encountered while serializing the key. + pub fn to_pkcs8v1(&self) -> Result { + let mut cbb = LcCBB::new(PKCS8_CAPACITY_BUFFER); - if p_bits != q_bits { - return Err(KeyRejected::inconsistent_components()); + if 1 != unsafe { EVP_marshal_private_key(cbb.as_mut_ptr(), *self.evp_pkey.as_const()) } { + return Err(Unspecified); } - if p_bits < Self::MIN_RSA_PRIME_BITS { - return Err(KeyRejected::too_small()); - } - if p_bits > Self::MAX_RSA_PRIME_BITS { - return Err(KeyRejected::too_large()); - } + let mut pkcs8_bytes_ptr = null_mut::(); + let mut out_len: usize = 0; - if p_bits % 512 != 0 { - return Err(KeyRejected::private_modulus_len_not_multiple_of_512_bits()); + if 1 != unsafe { CBB_finish(cbb.as_mut_ptr(), &mut pkcs8_bytes_ptr, &mut out_len) } { + return Err(Unspecified); } - let e = ConstPointer::new(RSA_get0_e(*rsa))?; - let min_exponent = DetachableLcPtr::try_from(65537)?; - match e.compare(&min_exponent.as_const()) { - Ordering::Less => Err(KeyRejected::too_small()), - Ordering::Equal | Ordering::Greater => Ok(()), - }?; - - // For the FIPS feature this will perform the necessary public-key validaiton steps and pairwise consistency tests. - // TODO: This also result in another call to `aws_lc::RSA_validate_key`, meaning duplicate effort is performed - // even after having already performing this operation during key parsing. Ideally the FIPS specific checks - // could be pulled out and invoked seperatly from the standard checks. - #[cfg(feature = "fips")] - if 1 != RSA_check_fips(*rsa as *mut RSA) { - return Err(KeyRejected::inconsistent_components()); - } + let pkcs8_bytes_ptr = LcPtr::new(pkcs8_bytes_ptr)?; - Ok(()) + let bytes = Vec::from(unsafe { pkcs8_bytes_ptr.as_slice(out_len) }).into_boxed_slice(); + + Ok(Document::new(bytes)) } /// Sign `msg`. `msg` is digested using the digest algorithm from @@ -497,3 +536,84 @@ unsafe fn serialize_RSA_pubkey(pubkey: &ConstPointer) -> Result, let pubkey_vec = Vec::from(pubkey_slice); Ok(pubkey_vec.into_boxed_slice()) } + +pub(super) fn generate_rsa_evp_pkey(size: KeySize) -> Result, Unspecified> { + let evp_pkey_ctx = LcPtr::new(unsafe { EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, null_mut()) })?; + + if 1 != unsafe { EVP_PKEY_keygen_init(*evp_pkey_ctx) } { + return Err(Unspecified); + }; + + if 1 != unsafe { EVP_PKEY_CTX_set_rsa_keygen_bits(*evp_pkey_ctx, size.bit_len()) } { + return Err(Unspecified); + }; + + let mut pkey: *mut EVP_PKEY = null_mut(); + + if 1 != indicator_check!(unsafe { EVP_PKEY_keygen(*evp_pkey_ctx, &mut pkey) }) { + return Err(Unspecified); + }; + + Ok(LcPtr::new(pkey)?) +} + +/// ⚠️ Function assumes that `aws_lc::RSA_check_key` / `aws_lc::RSA_validate_key` has already been invoked beforehand. +/// `aws_lc::RSA_validate_key` is already invoked by `aws_lc::EVP_parse_private_key` / `aws_lc::RSA_parse_private_key`. +/// If the `EVP_PKEY` was constructed through another mechanism, then the key should be validated through the use of +/// one those verifier functions first. +pub(super) unsafe fn validate_rsa_pkey(rsa: &LcPtr) -> Result<(), KeyRejected> { + const MIN_RSA_PRIME_BITS: u32 = 1024; + const MAX_RSA_PRIME_BITS: u32 = 4096; + + let rsa = rsa.get_rsa()?.as_const(); + + let p = ConstPointer::new(RSA_get0_p(*rsa))?; + let q = ConstPointer::new(RSA_get0_q(*rsa))?; + let p_bits = p.num_bits(); + let q_bits = q.num_bits(); + + if p_bits != q_bits { + return Err(KeyRejected::inconsistent_components()); + } + + if p_bits < MIN_RSA_PRIME_BITS { + return Err(KeyRejected::too_small()); + } + if p_bits > MAX_RSA_PRIME_BITS { + return Err(KeyRejected::too_large()); + } + + if p_bits % 512 != 0 { + return Err(KeyRejected::private_modulus_len_not_multiple_of_512_bits()); + } + + let e = ConstPointer::new(RSA_get0_e(*rsa))?; + let min_exponent = DetachableLcPtr::try_from(65537)?; + match e.compare(&min_exponent.as_const()) { + Ordering::Less => Err(KeyRejected::too_small()), + Ordering::Equal | Ordering::Greater => Ok(()), + }?; + + // For the FIPS feature this will perform the necessary public-key validaiton steps and pairwise consistency tests. + // TODO: This also result in another call to `aws_lc::RSA_validate_key`, meaning duplicate effort is performed + // even after having already performing this operation during key parsing. Ideally the FIPS specific checks + // could be pulled out and invoked seperatly from the standard checks. + #[cfg(feature = "fips")] + if 1 != RSA_check_fips(*rsa as *mut RSA) { + return Err(KeyRejected::inconsistent_components()); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::KeyPair; + + #[test] + fn generate_key() { + let keypair = KeyPair::generate(super::KeySize::Rsa2048).expect("generate successful"); + let document = keypair.to_pkcs8v1().expect("serialize keypair"); + let _ = KeyPair::from_pkcs8(document.as_ref()).expect("deserialize key"); + } +} diff --git a/aws-lc-rs/src/rsa/oaep.rs b/aws-lc-rs/src/rsa/oaep.rs new file mode 100644 index 00000000000..0d1b8feabb9 --- /dev/null +++ b/aws-lc-rs/src/rsa/oaep.rs @@ -0,0 +1,501 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 OR ISC + +use core::{fmt::Debug, ptr::null_mut}; +use std::ops::Deref; + +use aws_lc::{ + CBB_finish, EVP_PKEY_CTX_new, EVP_PKEY_CTX_set_rsa_mgf1_md, EVP_PKEY_CTX_set_rsa_oaep_md, + EVP_PKEY_CTX_set_rsa_padding, EVP_PKEY_decrypt, EVP_PKEY_decrypt_init, EVP_PKEY_encrypt, + EVP_PKEY_encrypt_init, EVP_PKEY_up_ref, EVP_marshal_private_key, EVP_marshal_public_key, + EVP_parse_public_key, EVP_sha1, EVP_sha256, EVP_sha384, EVP_sha512, EVP_MD, EVP_PKEY, + EVP_PKEY_CTX, RSA_PKCS1_OAEP_PADDING, +}; + +use crate::{ + buffer::Buffer, + cbb::{self, LcCBB}, + cbs, + encoding::AsDer, + error::{KeyRejected, Unspecified}, + fips::indicator_check, + pkcs8::Document, + ptr::LcPtr, +}; + +use super::key::{generate_rsa_evp_pkey, KeySize, PKCS8_CAPACITY_BUFFER}; + +/// RSA-OAEP with SHA1 Hash and SHA1 MGF1 +pub const OAEP_SHA1_MGF1SHA1: EncryptionAlgorithm = EncryptionAlgorithm { + id: EncryptionAlgorithmId::OaepSha1Mgf1sha1, + hash: EVP_sha1, + mgf: EVP_sha1, +}; + +/// RSA-OAEP with SHA256 Hash and SHA256 MGF1 +pub const OAEP_SHA256_MGF1SHA256: EncryptionAlgorithm = EncryptionAlgorithm { + id: EncryptionAlgorithmId::OaepSha256Mgf1sha256, + hash: EVP_sha256, + mgf: EVP_sha256, +}; + +/// RSA-OAEP with SHA384 Hash and SHA384 MGF1 +pub const OAEP_SHA384_MGF1SHA384: EncryptionAlgorithm = EncryptionAlgorithm { + id: EncryptionAlgorithmId::OaepSha384Mgf1sha384, + hash: EVP_sha384, + mgf: EVP_sha384, +}; + +/// RSA-OAEP with SHA512 Hash and SHA512 MGF1 +pub const OAEP_SHA512_MGF1SHA512: EncryptionAlgorithm = EncryptionAlgorithm { + id: EncryptionAlgorithmId::OaepSha512Mgf1sha512, + hash: EVP_sha512, + mgf: EVP_sha512, +}; + +/// RSA Encryption Algorithm Identifier +#[non_exhaustive] +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum EncryptionAlgorithmId { + /// RSA-OAEP with SHA1 Hash and SHA1 MGF1 + OaepSha1Mgf1sha1, + + /// RSA-OAEP with SHA256 Hash and SHA256 MGF1 + OaepSha256Mgf1sha256, + + /// RSA-OAEP with SHA384 Hash and SHA384 MGF1 + OaepSha384Mgf1sha384, + + /// RSA-OAEP with SHA512 Hash and SHA512 MGF1 + OaepSha512Mgf1sha512, +} + +type HashFn = unsafe extern "C" fn() -> *const EVP_MD; +type MgfFn = unsafe extern "C" fn() -> *const EVP_MD; + +/// An RSA Encryption Algorithm. +pub struct EncryptionAlgorithm { + id: EncryptionAlgorithmId, + hash: HashFn, + mgf: MgfFn, +} + +impl EncryptionAlgorithm { + /// Returns the algorithm's associated identifier. + #[must_use] + pub fn id(&self) -> EncryptionAlgorithmId { + self.id + } + + #[inline] + fn hash(&self) -> HashFn { + self.hash + } + + #[inline] + fn mgf(&self) -> MgfFn { + self.mgf + } +} + +impl Debug for EncryptionAlgorithm { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.id, f) + } +} + +/// A PKCS#8 V1 (RFC 5208) document encoded using DER. +pub struct Pkcs8V1Der(Document); + +// We already had a `Document` type that is running double-duty for v1 and v2 documents +// in the ec module. Rather then use a `Buffer<...>` alias type with `AsDer` opt for using +// the existing Document type with `Deref` so we can still benefit from supporting different +// document serialization versions in the future with `AsDer`. +impl Deref for Pkcs8V1Der { + type Target = Document; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +/// An RSA Private Key used for decrypting ciphertext encrypted by [`PublicEncryptingKey`]. +pub struct PrivateDecryptingKey { + key: LcPtr, + size: KeySize, +} + +impl PrivateDecryptingKey { + fn new(key: LcPtr) -> Result { + let size = KeySize::from_evp_pkey(&key)?; + Ok(Self { key, size }) + } + + /// Generate a new RSA private key for use with asymmetrical encryption. + /// + /// # Errors + /// * `Unspeicifed` for any error that occurs during the generation of the RSA keypair. + pub fn generate(size: KeySize) -> Result { + Self::new(generate_rsa_evp_pkey(size)?) + } + + /// Construct a `PrivateDecryptingKey` from the pvoided PKCS#8 (v1) document. + /// + /// # Errors + /// * `Unspeicifed` for any error that occurs during deserialization of this key from PKCS#8. + pub fn from_pkcs8(pkcs8: &[u8]) -> Result { + unsafe { + let evp_pkey = LcPtr::try_from(pkcs8)?; + super::key::validate_rsa_pkey(&evp_pkey)?; + Self::new(evp_pkey).map_err(|_| KeyRejected::unexpected_error()) + } + } + + /// Returns the corresponding [`KeySize`]. + #[must_use] + pub fn key_size(&self) -> KeySize { + self.size + } + + /// Retrieves the `PublicEncryptingKey` corresponding with this `PrivateDecryptingKey`. + /// + /// # Errors + /// * `Unspeicifed` for any error that occurs computing the public key. + pub fn public_key(&self) -> Result { + if 1 != unsafe { EVP_PKEY_up_ref(*self.key) } { + return Err(Unspecified); + }; + PublicEncryptingKey::new(LcPtr::new(*self.key)?) + } + + /// Decrypts the contents in `ciphertext` and writes the corresponding plaintext to `output`. + /// + /// # Errors + /// * `Unspeicifed` for any error that occurs while decrypting `ciphertext`. + pub fn decrypt<'output>( + &self, + algorithm: &'static EncryptionAlgorithm, + ciphertext: &[u8], + output: &'output mut [u8], + ) -> Result<&'output mut [u8], Unspecified> { + let pkey_ctx = LcPtr::new(unsafe { EVP_PKEY_CTX_new(*self.key, null_mut()) })?; + + if 1 != unsafe { EVP_PKEY_decrypt_init(*pkey_ctx) } { + return Err(Unspecified); + } + + configure_oaep_crypto_operation(&pkey_ctx, algorithm.hash(), algorithm.mgf())?; + + let mut out_len = output.len(); + + if 1 != indicator_check!(unsafe { + EVP_PKEY_decrypt( + *pkey_ctx, + output.as_mut_ptr(), + &mut out_len, + ciphertext.as_ptr(), + ciphertext.len(), + ) + }) { + return Err(Unspecified); + }; + + Ok(&mut output[..out_len]) + } +} + +impl AsDer for PrivateDecryptingKey { + fn as_der(&self) -> Result { + let mut cbb = LcCBB::new(PKCS8_CAPACITY_BUFFER); + + if 1 != unsafe { EVP_marshal_private_key(cbb.as_mut_ptr(), *self.key.as_const()) } { + return Err(Unspecified); + } + + let mut pkcs8_bytes_ptr = null_mut::(); + let mut out_len = 0; + + if 1 != unsafe { CBB_finish(cbb.as_mut_ptr(), &mut pkcs8_bytes_ptr, &mut out_len) } { + return Err(Unspecified); + } + + let pkcs8_bytes_ptr = LcPtr::new(pkcs8_bytes_ptr)?; + + let bytes = Vec::from(unsafe { pkcs8_bytes_ptr.as_slice(out_len) }).into_boxed_slice(); + + Ok(Pkcs8V1Der(Document::new(bytes))) + } +} + +pub struct PublicKeyX509DerType { + _priv: (), +} + +pub type PublicKeyX509Der = Buffer<'static, PublicKeyX509DerType>; + +/// An RSA Public Key used for decrypting ciphertext encrypted by [`PublicEncryptingKey`]. +pub struct PublicEncryptingKey { + key: LcPtr, + size: KeySize, +} + +impl PublicEncryptingKey { + fn new(key: LcPtr) -> Result { + let size = KeySize::from_evp_pkey(&key)?; + Ok(Self { key, size }) + } + + /// Construct a `PublicEncryptingKey` from X.509 `SubjectPublicKeyInfo` DER encoded bytes. + /// + /// # Errors + /// * `Unspeicifed` for any error that occurs deserializing from bytes. + pub fn from_der(value: &[u8]) -> Result { + let mut der = unsafe { cbs::build_CBS(value) }; + let key = LcPtr::new(unsafe { EVP_parse_public_key(&mut der) })?; + Self::new(key) + } + + /// Returns the corresponding [`KeySize`]. + #[must_use] + pub fn key_size(&self) -> KeySize { + self.size + } + + /// Encrypts the contents in `plaintext` and writes the corresponding ciphertext to `output`. + /// + /// # Errors + /// * `Unspeicifed` for any error that occurs while decrypting `ciphertext`. + pub fn encrypt<'output>( + &self, + algorithm: &'static EncryptionAlgorithm, + plaintext: &[u8], + output: &'output mut [u8], + ) -> Result<&'output mut [u8], Unspecified> { + let pkey_ctx = LcPtr::new(unsafe { EVP_PKEY_CTX_new(*self.key, null_mut()) })?; + + if 1 != unsafe { EVP_PKEY_encrypt_init(*pkey_ctx) } { + return Err(Unspecified); + } + + configure_oaep_crypto_operation(&pkey_ctx, algorithm.hash(), algorithm.mgf())?; + + let mut out_len = output.len(); + + if 1 != indicator_check!(unsafe { + EVP_PKEY_encrypt( + *pkey_ctx, + output.as_mut_ptr(), + &mut out_len, + plaintext.as_ptr(), + plaintext.len(), + ) + }) { + return Err(Unspecified); + }; + + Ok(&mut output[..out_len]) + } +} + +fn configure_oaep_crypto_operation( + evp_pkey_ctx: &LcPtr, + hash: HashFn, + mgf: MgfFn, +) -> Result<(), Unspecified> { + if 1 != unsafe { EVP_PKEY_CTX_set_rsa_padding(**evp_pkey_ctx, RSA_PKCS1_OAEP_PADDING) } { + return Err(Unspecified); + }; + + if 1 != unsafe { EVP_PKEY_CTX_set_rsa_oaep_md(**evp_pkey_ctx, hash()) } { + return Err(Unspecified); + }; + + if 1 != unsafe { EVP_PKEY_CTX_set_rsa_mgf1_md(**evp_pkey_ctx, mgf()) } { + return Err(Unspecified); + }; + + Ok(()) +} + +impl AsDer for PublicEncryptingKey { + /// Serialize this `PublicEncryptingKey` to a X.509 `SubjectPublicKeyInfo` structure as DER encoded bytes. + /// + /// # Errors + /// * `Unspeicifed` for any error that occurs serializing to bytes. + fn as_der(&self) -> Result { + // TODO: Determine proper initial_capacity + let mut der = cbb::LcCBB::new(1024); + + if 1 != unsafe { EVP_marshal_public_key(der.as_mut_ptr(), *self.key) } { + return Err(Unspecified); + }; + + let mut out_data = null_mut::(); + let mut out_len: usize = 0; + + if 1 != unsafe { CBB_finish(der.as_mut_ptr(), &mut out_data, &mut out_len) } { + return Err(Unspecified); + }; + + let out_data = LcPtr::new(out_data)?; + + // TODO: Need a type to just hold the owned pointer from CBB rather then copying + Ok(Buffer::take_from_slice(unsafe { + out_data.as_slice_mut(out_len) + })) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + encoding::AsDer, + rsa::{ + key::KeySize, + oaep::{ + OAEP_SHA1_MGF1SHA1, OAEP_SHA256_MGF1SHA256, OAEP_SHA384_MGF1SHA384, + OAEP_SHA512_MGF1SHA512, + }, + EncryptionAlgorithmId, + }, + }; + + use super::{PrivateDecryptingKey, PublicEncryptingKey}; + + #[test] + fn encryption_algorithm_id() { + assert_eq!( + OAEP_SHA1_MGF1SHA1.id(), + EncryptionAlgorithmId::OaepSha1Mgf1sha1 + ); + assert_eq!( + OAEP_SHA256_MGF1SHA256.id(), + EncryptionAlgorithmId::OaepSha256Mgf1sha256 + ); + assert_eq!( + OAEP_SHA384_MGF1SHA384.id(), + EncryptionAlgorithmId::OaepSha384Mgf1sha384 + ); + assert_eq!( + OAEP_SHA512_MGF1SHA512.id(), + EncryptionAlgorithmId::OaepSha512Mgf1sha512 + ); + } + + #[test] + fn encryption_algorithm_debug() { + assert_eq!("OaepSha1Mgf1sha1", format!("{OAEP_SHA1_MGF1SHA1:?}")); + } + + #[test] + fn generate() { + let private_key = PrivateDecryptingKey::generate(KeySize::Rsa2048).expect("generation"); + + let pkcs8v1 = private_key.as_der().expect("encoded"); + + let private_key = PrivateDecryptingKey::from_pkcs8(pkcs8v1.as_ref()).expect("decoded"); + + let public_key = private_key.public_key().expect("public key"); + + drop(private_key); + + let public_key_der = public_key.as_der().expect("encoded"); + + let _public_key = PublicEncryptingKey::from_der(public_key_der.as_ref()).expect("decoded"); + } + + macro_rules! round_trip_algorithm { + ($name:ident, $alg:expr, $keysize:expr) => { + #[test] + fn $name() { + const MESSAGE: &[u8] = b"Hello World!"; + + let private_key = PrivateDecryptingKey::generate($keysize).expect("generation"); + + assert_eq!(private_key.key_size(), $keysize); + + let public_key = private_key.public_key().expect("public key"); + + assert_eq!(public_key.key_size(), $keysize); + + let mut ciphertext = vec![0u8; private_key.key_size().len()]; + + let ciphertext = public_key + .encrypt($alg, MESSAGE, ciphertext.as_mut()) + .expect("encrypted"); + + let mut plaintext = vec![0u8; private_key.key_size().len()]; + + let plaintext = private_key + .decrypt($alg, ciphertext, &mut plaintext) + .expect("decryption"); + + assert_eq!(MESSAGE, plaintext); + } + }; + } + + round_trip_algorithm!( + rsa2048_oaep_sha1_mgf1sha1, + &OAEP_SHA1_MGF1SHA1, + KeySize::Rsa2048 + ); + round_trip_algorithm!( + rsa4096_oaep_sha1_mgf1sha1, + &OAEP_SHA1_MGF1SHA1, + KeySize::Rsa4096 + ); + round_trip_algorithm!( + rsa8192_oaep_sha1_mgf1sha1, + &OAEP_SHA1_MGF1SHA1, + KeySize::Rsa8192 + ); + + round_trip_algorithm!( + rsa2048_oaep_sha256_mgf1sha256, + &OAEP_SHA256_MGF1SHA256, + KeySize::Rsa2048 + ); + round_trip_algorithm!( + rsa4096_oaep_sha256_mgf1sha256, + &OAEP_SHA256_MGF1SHA256, + KeySize::Rsa4096 + ); + round_trip_algorithm!( + rsa8192_oaep_sha256_mgf1sha256, + &OAEP_SHA256_MGF1SHA256, + KeySize::Rsa8192 + ); + + round_trip_algorithm!( + rsa2048_oaep_sha384_mgf1sha384, + &OAEP_SHA384_MGF1SHA384, + KeySize::Rsa2048 + ); + round_trip_algorithm!( + rsa4096_oaep_sha384_mgf1sha384, + &OAEP_SHA384_MGF1SHA384, + KeySize::Rsa4096 + ); + round_trip_algorithm!( + rsa8192_oaep_sha384_mgf1sha384, + &OAEP_SHA384_MGF1SHA384, + KeySize::Rsa8192 + ); + + round_trip_algorithm!( + rsa2048_oaep_sha512_mgf1sha512, + &OAEP_SHA512_MGF1SHA512, + KeySize::Rsa2048 + ); + round_trip_algorithm!( + rsa4096_oaep_sha512_mgf1sha512, + &OAEP_SHA512_MGF1SHA512, + KeySize::Rsa4096 + ); + round_trip_algorithm!( + rsa8192_oaep_sha512_mgf1sha512, + &OAEP_SHA512_MGF1SHA512, + KeySize::Rsa8192 + ); +}