Skip to content

Commit

Permalink
Make size, decode and encode derives work with non-suffixed literals
Browse files Browse the repository at this point in the history
  • Loading branch information
CaioSym committed Mar 15, 2024
1 parent 2d92710 commit 15c8156
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 16 deletions.
90 changes: 74 additions & 16 deletions mls-rs-codec-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use std::str::FromStr;

use darling::{
ast::{self, Fields},
FromDeriveInput, FromField, FromVariant,
};
use proc_macro2::TokenStream;
use proc_macro2::{Literal, TokenStream};
use quote::quote;
use syn::{parse_macro_input, parse_quote, DeriveInput, Expr, Generics, Ident, Index, Path};
use syn::{
parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path,
};

enum Operation {
Size,
Expand Down Expand Up @@ -84,8 +88,9 @@ struct MlsVariantReceiver {
}

#[derive(FromDeriveInput)]
#[darling(attributes(mls_codec))]
#[darling(attributes(mls_codec), forward_attrs(repr))]
struct MlsInputReceiver {
attrs: Vec<Attribute>,
ident: Ident,
generics: Generics,
data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
Expand All @@ -95,28 +100,83 @@ impl MlsInputReceiver {
fn handle_input(&self, operation: Operation) -> TokenStream {
match self.data {
ast::Data::Struct(ref s) => struct_impl(s, operation),
ast::Data::Enum(ref e) => enum_impl(&self.ident, e, operation),
ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation),
}
}
}

fn enum_impl(ident: &Ident, variants: &[MlsVariantReceiver], operation: Operation) -> TokenStream {
fn repr_type(attrs: &Vec<Attribute>) -> Option<Expr> {
attrs
.iter()
.filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
.filter(|attr| attr.path().is_ident("repr"))
.next()
.map(|repr| repr.parse_args())
.transpose()
.ok()
.flatten()
}

/// Provides the discriminant for a given variant. If the variant does not specify a suffix
/// and a `repr_type` is provided, it will be appended to number.
fn discriminant_for_variant(variant: &MlsVariantReceiver, repr_type: &Option<Expr>) -> TokenStream {
let discriminant = variant
.discriminant
.clone()
.expect("Enum discriminants must be explicitly defined");

let Expr::Lit(lit_expr) = &discriminant else {
return quote! {#discriminant};
};

let Lit::Int(lit_int) = &lit_expr.lit else {
return quote! {#discriminant};
};

if lit_int.suffix().is_empty() {
let Some(Expr::Path(path)) = repr_type else {
return quote! {#discriminant};
};

let repr_ident = path.path.segments.iter()
.filter(|s| s.ident != "C")
.next()
.expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.");

// This is dirty and there is probably a better way of doing this but I'm way too much of a noob at
// proc macros to pull it off...
// TODO: Add proper support for correctly ignoring transparent, packed and modifiers
let str = format!(
"{}{}",
lit_int.base10_digits(),
repr_ident.ident.to_string()
);
return Literal::from_str(&str)
.map(|l| quote! {#l})
.ok()
.unwrap_or_else(|| quote! {#discriminant});
} else {
quote! {#discriminant}
}
}

fn enum_impl(
ident: &Ident,
attrs: &Vec<Attribute>,
variants: &[MlsVariantReceiver],
operation: Operation,
) -> TokenStream {
let handle_error = operation.is_result().then_some(quote! { ? });
let path = operation.path();
let call = operation.call();
let extras = operation.extras();
let enum_name = &ident;

let repr_type = repr_type(attrs);
if matches!(operation, Operation::Decode) {
let cases = variants.iter().map(|variant| {
let variant_name = &variant.ident;

// TODO: Calculate discriminant and support integers that are assumed types like u16,
// u32 etc based on repr()
let discriminant = &variant
.discriminant
.clone()
.expect("Enum discriminants must be explicitly defined");
let discriminant = discriminant_for_variant(variant, &repr_type);

// TODO: Support more than 1 field
match variant.fields.len() {
Expand All @@ -142,10 +202,7 @@ fn enum_impl(ident: &Ident, variants: &[MlsVariantReceiver], operation: Operatio
let cases = variants.iter().map(|variant| {
let variant_name = &variant.ident;

let discriminant = &variant
.discriminant
.clone()
.expect("Enum discriminants must be explicitly defined");
let discriminant = discriminant_for_variant(variant, &repr_type);

let (parameter, field) = if variant.fields.is_empty() {
(None, None)
Expand Down Expand Up @@ -220,6 +277,7 @@ where
F: FnOnce(&MlsInputReceiver) -> TokenStream,
{
let input = parse_macro_input!(input as DeriveInput);

let input = MlsInputReceiver::from_derive_input(&input).unwrap();

let name = &input.ident;
Expand Down
8 changes: 8 additions & 0 deletions mls-rs-codec/tests/macro_usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ enum TestEnum {
Case3(TestTupleStruct) = 42u16,
}

#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[repr(u8)]
enum TestEnumWithoutSuffixedLiterals {
Case1 = 1,
Case2(TestFieldStruct) = 200,
Case3(TestTupleStruct) = 42,
}

#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
struct TestGeneric<T: MlsSize + MlsEncode + MlsDecode>(T);

Expand Down

0 comments on commit 15c8156

Please sign in to comment.