From c73c8e7d92c539601ac5dc8800cfcb4961fa79b6 Mon Sep 17 00:00:00 2001 From: Caio Sym Date: Fri, 15 Mar 2024 10:17:07 +0000 Subject: [PATCH] Make size, decode and encode derives work with non-suffixed literals --- mls-rs-codec-derive/src/lib.rs | 93 +++++++++++++++++++++++++------ mls-rs-codec/tests/macro_usage.rs | 8 +++ 2 files changed, 85 insertions(+), 16 deletions(-) diff --git a/mls-rs-codec-derive/src/lib.rs b/mls-rs-codec-derive/src/lib.rs index 3c7603f9..593bd712 100644 --- a/mls-rs-codec-derive/src/lib.rs +++ b/mls-rs-codec-derive/src/lib.rs @@ -2,13 +2,17 @@ // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) +use std::str::FromStr; + use darling::{ ast::{self, Fields}, FromDeriveInput, FromField, FromVariant, }; -use proc_macro2::TokenStream; +use proc_macro2::{Literal, TokenStream}; use quote::quote; -use syn::{parse_macro_input, parse_quote, DeriveInput, Expr, Generics, Ident, Index, Path}; +use syn::{ + parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path, +}; enum Operation { Size, @@ -84,8 +88,9 @@ struct MlsVariantReceiver { } #[derive(FromDeriveInput)] -#[darling(attributes(mls_codec))] +#[darling(attributes(mls_codec), forward_attrs(repr))] struct MlsInputReceiver { + attrs: Vec, ident: Ident, generics: Generics, data: ast::Data, @@ -95,28 +100,86 @@ impl MlsInputReceiver { fn handle_input(&self, operation: Operation) -> TokenStream { match self.data { ast::Data::Struct(ref s) => struct_impl(s, operation), - ast::Data::Enum(ref e) => enum_impl(&self.ident, e, operation), + ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation), } } } -fn enum_impl(ident: &Ident, variants: &[MlsVariantReceiver], operation: Operation) -> TokenStream { +fn repr_ident(attrs: &[Attribute]) -> Option { + let repr_path = attrs + .iter() + .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer)) + .find(|attr| attr.path().is_ident("repr")) + .map(|repr| repr.parse_args()) + .transpose() + .ok() + .flatten(); + + let Some(Expr::Path(path)) = repr_path else { + return None; + }; + + path.path + .segments + .iter() + .find(|s| s.ident != "C") + .map(|path| path.ident.clone()) +} + +/// Provides the discriminant for a given variant. If the variant does not specify a suffix +/// and a `repr_ident` is provided, it will be appended to number. +fn discriminant_for_variant( + variant: &MlsVariantReceiver, + repr_ident: &Option, +) -> TokenStream { + let discriminant = variant + .discriminant + .clone() + .expect("Enum discriminants must be explicitly defined"); + + let Expr::Lit(lit_expr) = &discriminant else { + return quote! {#discriminant}; + }; + + let Lit::Int(lit_int) = &lit_expr.lit else { + return quote! {#discriminant}; + }; + + if lit_int.suffix().is_empty() { + // This is dirty and there is probably a better way of doing this but I'm way too much of a noob at + // proc macros to pull it off... + // TODO: Add proper support for correctly ignoring transparent, packed and modifiers + let str = format!( + "{}{}", + lit_int.base10_digits(), + &repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.") + ); + Literal::from_str(&str) + .map(|l| quote! {#l}) + .ok() + .unwrap_or_else(|| quote! {#discriminant}) + } else { + quote! {#discriminant} + } +} + +fn enum_impl( + ident: &Ident, + attrs: &[Attribute], + variants: &[MlsVariantReceiver], + operation: Operation, +) -> TokenStream { let handle_error = operation.is_result().then_some(quote! { ? }); let path = operation.path(); let call = operation.call(); let extras = operation.extras(); let enum_name = &ident; - + let repr_ident = repr_ident(attrs); if matches!(operation, Operation::Decode) { let cases = variants.iter().map(|variant| { let variant_name = &variant.ident; - // TODO: Calculate discriminant and support integers that are assumed types like u16, - // u32 etc based on repr() - let discriminant = &variant - .discriminant - .clone() - .expect("Enum discriminants must be explicitly defined"); + let discriminant = discriminant_for_variant(variant, &repr_ident); // TODO: Support more than 1 field match variant.fields.len() { @@ -142,10 +205,7 @@ fn enum_impl(ident: &Ident, variants: &[MlsVariantReceiver], operation: Operatio let cases = variants.iter().map(|variant| { let variant_name = &variant.ident; - let discriminant = &variant - .discriminant - .clone() - .expect("Enum discriminants must be explicitly defined"); + let discriminant = discriminant_for_variant(variant, &repr_ident); let (parameter, field) = if variant.fields.is_empty() { (None, None) @@ -220,6 +280,7 @@ where F: FnOnce(&MlsInputReceiver) -> TokenStream, { let input = parse_macro_input!(input as DeriveInput); + let input = MlsInputReceiver::from_derive_input(&input).unwrap(); let name = &input.ident; diff --git a/mls-rs-codec/tests/macro_usage.rs b/mls-rs-codec/tests/macro_usage.rs index 72e17db6..40b9feee 100644 --- a/mls-rs-codec/tests/macro_usage.rs +++ b/mls-rs-codec/tests/macro_usage.rs @@ -40,6 +40,14 @@ enum TestEnum { Case3(TestTupleStruct) = 42u16, } +#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] +#[repr(u8)] +enum TestEnumWithoutSuffixedLiterals { + Case1 = 1, + Case2(TestFieldStruct) = 200, + Case3(TestTupleStruct) = 42, +} + #[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] struct TestGeneric(T);