Skip to content

Commit

Permalink
Add mul(forward_and_scalar)
Browse files Browse the repository at this point in the history
It currently isn't possible to use `#[derive(Mul)]` to implement multiplication with itself AND with a scalar at the same time.

Example:
```rust
fn without_forward() {
    #[derive(Clone, Copy, Mul)]
    pub struct Vec2 {
        pub x: f32,
        pub y: f32,
    }

    let a = Vec2 { x: 1.0, y: 2.0 };
    let c = a * Vec2 { x: 1.0, y: 1.0 }; // ❌ Doesn't work
    let d = a * 2.0; // ✔️ Works
}

fn with_forward() {
    #[derive(Clone, Copy, Mul)]
    #[mul(forward)]
    pub struct Vec2 {
        pub x: f32,
        pub y: f32,
    }

    let a = Vec2 { x: 1.0, y: 2.0 };
    let c = a * Vec2 { x: 1.0, y: 1.0 }; // ✔️ Works
    let d = a * 2.0; // ❌ Doesn't work
}
```

With this commit, you can have both:
```rust
fn with_forward_and_scalar() {
    #[derive(Clone, Copy, Mul)]
    #[mul(forward_and_scalar)]
    pub struct Vec2 {
        pub x: f32,
        pub y: f32,
    }

    let a = Vec2 { x: 1.0, y: 2.0 };
    let c = a * Vec2 { x: 1.0, y: 1.0 }; // ✔️ Works
    let d = a * 2.0; // ✔️ Works
}
```
  • Loading branch information
Cannedfood committed Mar 9, 2025
1 parent 1b0e166 commit 53f4749
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 8 deletions.
21 changes: 17 additions & 4 deletions impl/src/mul_assign_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,28 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
.to_string()
+ "_assign";

let mut state = State::with_attr_params(
let state = State::with_attr_params(
input,
trait_name,
method_name,
AttrParams::struct_(vec!["forward"]),
AttrParams::struct_(vec!["forward", "forward_and_scalar"]),
)?;

if state.default_info.forward {
return Ok(add_assign_like::expand(input, trait_name));
}

if state.default_info.forward_and_scalar {
return Ok(quote! {
{add_assign_like::expand(input, trait_name)}
{expand_mul_assign_like(state)}
});
}

Ok(expand_mul_assign_like(state))
}

fn expand_mul_assign_like(mut state: State<'_>) -> TokenStream {
let scalar_ident = format_ident!("__RhsT");
state.add_trait_path_type_param(quote! { #scalar_ident });
let multi_field_data = state.enabled_fields_data();
Expand Down Expand Up @@ -51,7 +64,7 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
);
let (impl_generics, _, where_clause) = generics.split_for_impl();

Ok(quote! {
quote! {
#[automatically_derived]
impl #impl_generics #trait_path<#scalar_ident> for #input_type #ty_generics #where_clause {
#[inline]
Expand All @@ -60,5 +73,5 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
#( #exprs; )*
}
}
})
}
}
17 changes: 13 additions & 4 deletions impl/src/mul_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,25 @@ use std::iter;
use syn::{DeriveInput, Result};

pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStream> {
let mut state = State::with_attr_params(
let state = State::with_attr_params(
input,
trait_name,
trait_name.to_lowercase(),
AttrParams::struct_(vec!["forward"]),
AttrParams::struct_(vec!["forward", "forward_and_scalar"]),
)?;
if state.default_info.forward {
return Ok(add_like::expand(input, trait_name));
}
if state.default_info.forward_and_scalar {
return Ok(quote! {
{add_like::expand(input, trait_name)}
{expand_mul_like(state)}
});
}
Ok(expand_mul_like(state))
}

fn expand_mul_like(mut state: State<'_>) -> TokenStream {
let scalar_ident = format_ident!("__RhsT");
state.add_trait_path_type_param(quote! { #scalar_ident });
let multi_field_data = state.enabled_fields_data();
Expand Down Expand Up @@ -47,7 +56,7 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
);
let body = multi_field_data.initializer(&initializers);
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote! {
quote! {
#[automatically_derived]
impl #impl_generics #trait_path_with_params for #input_type #ty_generics #where_clause {
type Output = #input_type #ty_generics;
Expand All @@ -58,5 +67,5 @@ pub fn expand(input: &DeriveInput, trait_name: &'static str) -> Result<TokenStre
#body
}
}
})
}
}
20 changes: 20 additions & 0 deletions impl/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ impl<'input> State<'input> {
let defaults = struct_meta_info.into_full(FullMetaInfo {
enabled: default_enabled,
forward: false,
forward_and_scalar: false,
// Default to owned true, except when first attribute has one of owned,
// ref or ref_mut
// - not a single attribute means default true
Expand Down Expand Up @@ -873,6 +874,13 @@ fn get_meta_info(
None,
)?;

if info.forward.is_some() && info.forward_and_scalar.is_some() {
return Err(Error::new(
list.span(),
"Attributes `forward` and `forward_and_scalar` are mutually exclusive",
));
}

Ok(info)
}

Expand Down Expand Up @@ -1016,7 +1024,13 @@ fn parse_punctuated_nested_meta(
match (wrapper_name, attr_name.as_str()) {
(None, "ignore") => info.enabled = Some(false),
(None, "forward") => info.forward = Some(true),
(None, "forward_and_scalar") => {
info.forward_and_scalar = Some(true)
}
(Some("not"), "forward") => info.forward = Some(false),
(Some("not"), "forward_and_scalar") => {
info.forward_and_scalar = Some(false)
}
(None, "owned") => info.owned = Some(true),
(None, "ref") => info.ref_ = Some(true),
(None, "ref_mut") => info.ref_mut = Some(true),
Expand Down Expand Up @@ -1204,6 +1218,7 @@ pub(crate) mod polyfill {
pub struct FullMetaInfo {
pub enabled: bool,
pub forward: bool,
pub forward_and_scalar: bool,
pub owned: bool,
pub ref_: bool,
pub ref_mut: bool,
Expand All @@ -1214,6 +1229,7 @@ pub struct FullMetaInfo {
pub struct MetaInfo {
pub enabled: Option<bool>,
pub forward: Option<bool>,
pub forward_and_scalar: Option<bool>,
pub owned: Option<bool>,
pub ref_: Option<bool>,
pub ref_mut: Option<bool>,
Expand All @@ -1228,6 +1244,9 @@ impl MetaInfo {
FullMetaInfo {
enabled: self.enabled.unwrap_or(defaults.enabled),
forward: self.forward.unwrap_or(defaults.forward),
forward_and_scalar: self
.forward_and_scalar
.unwrap_or(defaults.forward_and_scalar),
owned: self.owned.unwrap_or(defaults.owned),
ref_: self.ref_.unwrap_or(defaults.ref_),
ref_mut: self.ref_mut.unwrap_or(defaults.ref_mut),
Expand Down Expand Up @@ -1727,6 +1746,7 @@ pub(crate) mod attr {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
match input.parse::<syn::Path>()? {
p if p.is_ident("forward") => Ok(Self),
p if p.is_ident("forward_and_scalar") => Ok(Self),
p => Err(syn::Error::new(p.span(), "only `forward` allowed here")),
}
}
Expand Down

0 comments on commit 53f4749

Please sign in to comment.