Skip to content

Commit

Permalink
Add full support for generics
Browse files Browse the repository at this point in the history
  • Loading branch information
sosthene-nitrokey committed Feb 17, 2025
1 parent e5ee233 commit d8f4b69
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 21 deletions.
90 changes: 69 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,33 @@ extern crate proc_macro;
mod parse;

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Lifetime, LifetimeParam};
use proc_macro2::{Ident, Span};
use quote::{format_ident, quote, ToTokens};
use syn::{parse_macro_input, Lifetime, LifetimeParam, TypeParamBound};

use crate::parse::Input;

/// Wrapper around syn structs that don't implement `Copy` but we want to use at multiple places
#[derive(Clone, Copy)]
struct CopyWrapper<'a, T>(&'a T);

impl<'a, T: ToTokens> ToTokens for CopyWrapper<'a, T> {
fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
self.0.to_tokens(tokens)
}

fn to_token_stream(&self) -> proc_macro2::TokenStream {
self.0.to_token_stream()
}

fn into_token_stream(self) -> proc_macro2::TokenStream
where
Self: Sized,
{
self.0.to_token_stream()
}
}

fn serialize_fields(fields: &[parse::Field], offset: usize) -> Vec<proc_macro2::TokenStream> {
fields
.iter()
Expand Down Expand Up @@ -85,11 +107,17 @@ pub fn derive_serialize(input: TokenStream) -> TokenStream {
let ident = input.ident;
let num_fields = count_serialized_fields(&input.fields);
let serialize_fields = serialize_fields(&input.fields, input.attrs.offset);
let lifetimes_imp = input.generics.lifetimes();
let lifetimes_ty = input.generics.lifetimes();
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let mut generics_cl = input.generics.clone();
generics_cl.type_params_mut().for_each(|t| {
t.bounds
.push_value(TypeParamBound::Verbatim(quote!(serde::Serialize)));
});
let (impl_generics, _, _) = generics_cl.split_for_impl();

TokenStream::from(quote! {
impl<#(#lifetimes_imp),*> serde::Serialize for #ident<#(#lifetimes_ty),*> {
#[automatically_derived]
impl #impl_generics serde::Serialize for #ident #ty_generics #where_clause {
fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
where
S: serde::Serializer
Expand Down Expand Up @@ -169,13 +197,6 @@ fn all_fields(fields: &[parse::Field]) -> Vec<proc_macro2::TokenStream> {
.collect()
}

fn de_lifetime<'a>(lifetimes: impl Iterator<Item = &'a LifetimeParam>) -> proc_macro2::TokenStream {
let lifetimes = lifetimes.map(|l| &l.lifetime);
quote! {
'de: #(#lifetimes)+*
}
}

#[proc_macro_derive(DeserializeIndexed, attributes(serde, serde_indexed))]
pub fn derive_deserialize(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as Input);
Expand All @@ -184,9 +205,34 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
let unwrap_expected_fields = unwrap_expected_fields(&input.fields);
let match_fields = match_fields(&input.fields, input.attrs.offset);
let all_fields = all_fields(&input.fields);
let de_lifetime = de_lifetime(input.generics.lifetimes());
let lifetimes: Vec<_> = input.generics.lifetimes().collect();
let lifetimes = &*lifetimes;

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let ty_generics = CopyWrapper(&ty_generics);

let mut generics_cl = input.generics.clone();
generics_cl.params.insert(
0,
syn::GenericParam::Lifetime(LifetimeParam {
attrs: Vec::new(),
lifetime: Lifetime {
apostrophe: Span::call_site(),
ident: Ident::new("de", Span::call_site()),
},
colon_token: None,
bounds: input
.generics
.lifetimes()
.map(|l| l.lifetime.clone())
.collect(),
}),
);
generics_cl.type_params_mut().for_each(|t| {
t.bounds
.push_value(TypeParamBound::Verbatim(quote!(serde::Deserialize<'de>)));
});

let (impl_generics_with_de, _, _) = generics_cl.split_for_impl();
let impl_generics_with_de = CopyWrapper(&impl_generics_with_de);

let the_loop = if !input.fields.is_empty() {
// NB: In the previous "none_fields", we use the actual struct's
Expand All @@ -207,16 +253,17 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
quote! {}
};

TokenStream::from(quote! {
impl<#de_lifetime, #(#lifetimes),*> serde::Deserialize<'de> for #ident<#(#lifetimes),*> {
let res = quote! {
#[automatically_derived]
impl #impl_generics_with_de serde::Deserialize<'de> for #ident #ty_generics #where_clause {
fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct IndexedVisitor<#(#lifetimes),*>(core::marker::PhantomData<#(&#lifetimes)* ()>);
struct IndexedVisitor #impl_generics (core::marker::PhantomData<#ident #ty_generics>);

impl<#de_lifetime, #(#lifetimes),*> serde::de::Visitor<'de> for IndexedVisitor<#(#lifetimes),*> {
type Value = #ident<#(#lifetimes),*>;
impl #impl_generics_with_de serde::de::Visitor<'de> for IndexedVisitor #ty_generics {
type Value = #ident #ty_generics;

fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
formatter.write_str(stringify!(#ident))
Expand All @@ -239,5 +286,6 @@ pub fn derive_deserialize(input: TokenStream) -> TokenStream {
deserializer.deserialize_map(IndexedVisitor(Default::default()))
}
}
})
};
TokenStream::from(res)
}
111 changes: 111 additions & 0 deletions tests/basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,114 @@ mod cow {
};
}
}

mod generics {
use super::*;
use heapless::String;
use serde_byte_array::ByteArray;
use serde_bytes::Bytes;

const SERIALIZED_GENERIC_EXAMPLE: &'static [u8] = b"\xa1\x01\x43\x01\x02\x03";

#[derive(PartialEq, Debug, SerializeIndexed, DeserializeIndexed)]
#[serde_indexed(offset = 1)]
struct WithGeneric<T> {
data: T,
#[serde(skip_serializing_if = "Option::is_none")]
option: Option<u8>,
}

fn generics_example<'a>() -> WithGeneric<&'a Bytes> {
WithGeneric {
data: Bytes::new(&[1, 2, 3]),
option: None,
}
}

#[derive(PartialEq, Debug, SerializeIndexed, DeserializeIndexed)]
#[serde_indexed(offset = 1)]
struct WithConstGeneric<const N: usize> {
data: ByteArray<N>,
#[serde(skip_serializing_if = "Option::is_none")]
option: Option<u8>,
}

fn const_generics_example<'a>() -> WithConstGeneric<3> {
WithConstGeneric {
data: ByteArray::new([1, 2, 3]),
option: None,
}
}

#[test]
fn serialize() {
let data = generics_example();
let mut buf = [0u8; 64];
let size = cbor_serialize(&data, &mut buf).unwrap();

assert_eq!(&buf[..size], SERIALIZED_GENERIC_EXAMPLE);

let data = const_generics_example();
let mut buf = [0u8; 64];
let size = cbor_serialize(&data, &mut buf).unwrap();

assert_eq!(&buf[..size], SERIALIZED_GENERIC_EXAMPLE);
}

#[test]
fn deserialize() {
let example = generics_example();

let deserialized: WithGeneric<&'_ Bytes> =
cbor_deserialize_with_scratch(SERIALIZED_GENERIC_EXAMPLE, &mut []).unwrap();

assert_eq!(deserialized, example);

let example = const_generics_example();

let deserialized: WithConstGeneric<3> =
cbor_deserialize_with_scratch(SERIALIZED_GENERIC_EXAMPLE, &mut []).unwrap();

assert_eq!(deserialized, example);
}

#[derive(PartialEq, Debug, SerializeIndexed, DeserializeIndexed)]
#[serde_indexed(offset = 1)]
struct WithAllGenerics<'a, 'b, T, I, const N: usize, const Z: usize> {
data1: heapless::Vec<T, N>,
data2: heapless::Vec<I, Z>,
data3: &'a Bytes,
data4: &'b ByteArray<Z>,
}

fn all_generics_example<'a, 'b>() -> WithAllGenerics<'a, 'b, String<5>, u8, 10, 3> {
let data1 = heapless::Vec::from_slice(&["abc".into(), "acdef".into()]).unwrap();
let data2 = heapless::Vec::from_slice(&[1, 2]).unwrap();

const BYTES: ByteArray<3> = ByteArray::new(*b"123");
WithAllGenerics {
data1,
data2,
data3: Bytes::new(b"bytes"),
data4: &BYTES,
}
}

#[test]
fn all_generics() {
const SERIALIZED_ALL_GENERIC_EXAMPLE: &'static [u8] = b"\xa4\x01\x82\x63\x61\x62\x63\x65\x61\x63\x64\x65\x66\x02\x82\x01\x02\x03\x45\x62\x79\x74\x65\x73\x04\x43\x31\x32\x33";
let data = all_generics_example();
let mut buf = [0u8; 64];
let size = cbor_serialize(&data, &mut buf).unwrap();

println!("{buf:02x?}");
assert_eq!(&buf[..size], SERIALIZED_ALL_GENERIC_EXAMPLE);

let example = all_generics_example();

let deserialized: WithAllGenerics<'_, '_, String<5>, u8, 10, 3> =
cbor_deserialize_with_scratch(SERIALIZED_ALL_GENERIC_EXAMPLE, &mut []).unwrap();

assert_eq!(deserialized, example);
}
}

0 comments on commit d8f4b69

Please sign in to comment.