Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make size, decode and encode derives work with non-suffixed literals #115

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 77 additions & 16 deletions mls-rs-codec-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use std::str::FromStr;

use darling::{
ast::{self, Fields},
FromDeriveInput, FromField, FromVariant,
};
use proc_macro2::TokenStream;
use proc_macro2::{Literal, TokenStream};
use quote::quote;
use syn::{parse_macro_input, parse_quote, DeriveInput, Expr, Generics, Ident, Index, Path};
use syn::{
parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path,
};

enum Operation {
Size,
Expand Down Expand Up @@ -84,8 +88,9 @@ struct MlsVariantReceiver {
}

#[derive(FromDeriveInput)]
#[darling(attributes(mls_codec))]
#[darling(attributes(mls_codec), forward_attrs(repr))]
struct MlsInputReceiver {
attrs: Vec<Attribute>,
ident: Ident,
generics: Generics,
data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
Expand All @@ -95,28 +100,86 @@ impl MlsInputReceiver {
fn handle_input(&self, operation: Operation) -> TokenStream {
match self.data {
ast::Data::Struct(ref s) => struct_impl(s, operation),
ast::Data::Enum(ref e) => enum_impl(&self.ident, e, operation),
ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation),
}
}
}

fn enum_impl(ident: &Ident, variants: &[MlsVariantReceiver], operation: Operation) -> TokenStream {
fn repr_ident(attrs: &[Attribute]) -> Option<Ident> {
let repr_path = attrs
.iter()
.filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
.find(|attr| attr.path().is_ident("repr"))
.map(|repr| repr.parse_args())
.transpose()
.ok()
.flatten();

let Some(Expr::Path(path)) = repr_path else {
return None;
};

path.path
.segments
.iter()
.find(|s| s.ident != "C")
.map(|path| path.ident.clone())
}

/// Provides the discriminant for a given variant. If the variant does not specify a suffix
/// and a `repr_ident` is provided, it will be appended to number.
fn discriminant_for_variant(
variant: &MlsVariantReceiver,
repr_ident: &Option<Ident>,
) -> TokenStream {
let discriminant = variant
.discriminant
.clone()
.expect("Enum discriminants must be explicitly defined");

let Expr::Lit(lit_expr) = &discriminant else {
return quote! {#discriminant};
};

let Lit::Int(lit_int) = &lit_expr.lit else {
return quote! {#discriminant};
};

if lit_int.suffix().is_empty() {
// This is dirty and there is probably a better way of doing this but I'm way too much of a noob at
// proc macros to pull it off...
// TODO: Add proper support for correctly ignoring transparent, packed and modifiers
let str = format!(
"{}{}",
lit_int.base10_digits(),
&repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.")
);
Literal::from_str(&str)
.map(|l| quote! {#l})
.ok()
.unwrap_or_else(|| quote! {#discriminant})
} else {
quote! {#discriminant}
}
}

fn enum_impl(
ident: &Ident,
attrs: &[Attribute],
variants: &[MlsVariantReceiver],
operation: Operation,
) -> TokenStream {
let handle_error = operation.is_result().then_some(quote! { ? });
let path = operation.path();
let call = operation.call();
let extras = operation.extras();
let enum_name = &ident;

let repr_ident = repr_ident(attrs);
if matches!(operation, Operation::Decode) {
let cases = variants.iter().map(|variant| {
let variant_name = &variant.ident;

// TODO: Calculate discriminant and support integers that are assumed types like u16,
// u32 etc based on repr()
let discriminant = &variant
.discriminant
.clone()
.expect("Enum discriminants must be explicitly defined");
let discriminant = discriminant_for_variant(variant, &repr_ident);

// TODO: Support more than 1 field
match variant.fields.len() {
Expand All @@ -142,10 +205,7 @@ fn enum_impl(ident: &Ident, variants: &[MlsVariantReceiver], operation: Operatio
let cases = variants.iter().map(|variant| {
let variant_name = &variant.ident;

let discriminant = &variant
.discriminant
.clone()
.expect("Enum discriminants must be explicitly defined");
let discriminant = discriminant_for_variant(variant, &repr_ident);

let (parameter, field) = if variant.fields.is_empty() {
(None, None)
Expand Down Expand Up @@ -220,6 +280,7 @@ where
F: FnOnce(&MlsInputReceiver) -> TokenStream,
{
let input = parse_macro_input!(input as DeriveInput);

let input = MlsInputReceiver::from_derive_input(&input).unwrap();

let name = &input.ident;
Expand Down
8 changes: 8 additions & 0 deletions mls-rs-codec/tests/macro_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ enum TestEnum {
Case3(TestTupleStruct) = 42u16,
}

#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[repr(u8)]
enum TestEnumWithoutSuffixedLiterals {
Case1 = 1,
Case2(TestFieldStruct) = 200,
Case3(TestTupleStruct) = 42,
}

#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
struct TestGeneric<T: MlsSize + MlsEncode + MlsDecode>(T);

Expand Down
Loading