diff --git a/integration_tests/juniper_tests/src/codegen/derive_union.rs b/integration_tests/juniper_tests/src/codegen/derive_union.rs index fcec2e74c..39a586ea2 100644 --- a/integration_tests/juniper_tests/src/codegen/derive_union.rs +++ b/integration_tests/juniper_tests/src/codegen/derive_union.rs @@ -28,6 +28,35 @@ pub enum Character { Two(Droid), } +#[derive(juniper::GraphQLUnion)] +#[graphql(Scalar = juniper::DefaultScalarValue)] +pub enum CharacterGeneric { + One(Human), + Two(Droid), + #[allow(dead_code)] + #[graphql(ignore)] + Hidden(T), +} + +#[derive(juniper::GraphQLUnion)] +#[graphql(on Droid = CharacterDyn::as_droid)] +pub enum CharacterDyn { + One(Human), + //#[graphql(ignore)] + #[graphql(with = CharacterDyn::as_droid)] + Two(Droid), +} + +impl CharacterDyn { + fn as_droid(&self, _: &()) -> Option<&Droid> { + match self { + Self::Two(droid) => Some(droid), + _ => None, + } + } +} + + // Context Test pub struct CustomContext { is_left: bool, diff --git a/integration_tests/juniper_tests/src/codegen/impl_union.rs b/integration_tests/juniper_tests/src/codegen/impl_union.rs index 5ed28a3f4..527230a3e 100644 --- a/integration_tests/juniper_tests/src/codegen/impl_union.rs +++ b/integration_tests/juniper_tests/src/codegen/impl_union.rs @@ -42,3 +42,63 @@ impl<'a> GraphQLUnion for &'a dyn Character { } } } + +/* +#[juniper::graphql_union] +impl GraphQLUnion for dyn Character { + fn resolve_human(&self) -> Option<&Human> { + self.as_human() + } + + fn resolve_droid(&self) -> Option<&Droid> { + self.as_droid() + } +} +*/ + +/* +#[derive(GraphQLUnion)] +#[graphql( + Human = Char::resolve_human, + Droid = Char::resolve_droid, +)] +#[graphql(with(Char::resolve_human) => Human)] +#[graphql(object = Droid, with = Char::resolve_droid)] +struct Char { + id: String, +} + +impl Char { + fn resolve_human(&self, _: &Context) -> Option<&Human> { + unimplemented!() + } + fn resolve_droid(&self, _: &Context) -> Option<&Droid> { + unimplemented!() + } +} + +#[graphq_union] +trait Charctr { + fn as_human(&self) -> Option<&Human> { None } + fn as_droid(&self, _: &Context) -> Option<&Droid> { None } +} + +#[graphq_union( + Human = Char::resolve_human, + Droid = Char::resolve_droid, +)] +#[graphql(object = Human, with = Charctr2::resolve_human)] +#[graphql(object = Droid, with = Charctr2::resolve_droid)] +trait Charctr2 { + fn id(&self) -> &str; +} + +impl dyn Charctr2 { + fn resolve_human(&self, _: &Context) -> Option<&Human> { + unimplemented!() + } + fn resolve_droid(&self, _: &Context) -> Option<&Droid> { + unimplemented!() + } +} +*/ \ No newline at end of file diff --git a/juniper_codegen/src/graphql_union/derive.rs b/juniper_codegen/src/graphql_union/derive.rs index 42c07ae9b..847220405 100644 --- a/juniper_codegen/src/graphql_union/derive.rs +++ b/juniper_codegen/src/graphql_union/derive.rs @@ -30,8 +30,6 @@ fn expand_enum(ast: syn::DeriveInput, mode: Mode) -> syn::Result syn::Result = match ast.data { + let mut variants: Vec<_> = match ast.data { Data::Enum(data) => data.variants, _ => unreachable!(), } .into_iter() .filter_map(|var| graphql_union_variant_from_enum_variant(var, &enum_ident)) .collect(); + if !meta.custom_resolvers.is_empty() { + let crate_path = mode.crate_path(); + // TODO: refactor into separate function + for (ty, rslvr) in meta.custom_resolvers { + let span = rslvr.span_joined(); + + let resolver_fn = rslvr.into_inner(); + let resolver_code = parse_quote! { + #resolver_fn(self, #crate_path::FromContext::from(context)) + }; + // Doing this may be quite an expensive, because resolving may contain some heavy + // computation, so we're preforming it twice. Unfortunately, we have no other options + // here, until the `juniper::GraphQLType` itself will allow to do it in some cleverer + // way. + let resolver_check = parse_quote! { + ({ #resolver_code } as ::std::option::Option<&#ty>).is_some() + }; + + if let Some(var) = variants.iter_mut().find(|v| v.ty == ty) { + var.resolver_code = resolver_code; + var.resolver_check = resolver_check; + var.span = span; + } else { + variants.push(UnionVariantDefinition { + ty, + resolver_code, + resolver_check, + enum_path: None, + span, + }) + } + } + } if variants.is_empty() { SCOPE.not_empty(enum_span); } @@ -97,7 +128,19 @@ fn graphql_union_variant_from_enum_variant( let var_span = var.span(); let var_ident = var.ident; - let path = quote! { #enum_ident::#var_ident }; + let enum_path = quote! { #enum_ident::#var_ident }; + + // TODO + if meta.custom_resolver.is_some() { + unimplemented!() + } + + let resolver_code = parse_quote! { + match self { #enum_ident::#var_ident(ref v) => Some(v), _ => None, } + }; + let resolver_check = parse_quote! { + matches!(self, #enum_path(_)) + }; let ty = match var.fields { Fields::Unnamed(fields) => { @@ -121,14 +164,18 @@ fn graphql_union_variant_from_enum_variant( Some(UnionVariantDefinition { ty, - path, + resolver_code, + resolver_check, + enum_path: Some(enum_path), span: var_span, }) } struct UnionVariantDefinition { pub ty: syn::Type, - pub path: TokenStream, + pub resolver_code: syn::Expr, + pub resolver_check: syn::Expr, + pub enum_path: Option, pub span: Span, } @@ -177,23 +224,16 @@ impl UnionDefinition { let match_names = self.variants.iter().map(|var| { let var_ty = &var.ty; - let var_path = &var.path; + let var_check = &var.resolver_check; quote! { - #var_path(_) => <#var_ty as #crate_path::GraphQLType<#scalar>>::name(&()) - .unwrap().to_string(), + if #var_check { + return <#var_ty as #crate_path::GraphQLType<#scalar>>::name(&()) + .unwrap().to_string(); + } } }); - let match_resolves: Vec<_> = self - .variants - .iter() - .map(|var| { - let var_path = &var.path; - quote! { - match self { #var_path(ref val) => Some(val), _ => None, } - } - }) - .collect(); + let match_resolves: Vec<_> = self.variants.iter().map(|var| &var.resolver_code).collect(); let resolve_into_type = self.variants.iter().zip(match_resolves.iter()).map(|(var, expr)| { let var_ty = &var.ty; @@ -291,12 +331,15 @@ impl UnionDefinition { fn concrete_type_name( &self, - _: &Self::Context, + context: &Self::Context, _: &Self::TypeInfo, ) -> String { - match self { - #( #match_names )* - } + #( #match_names )* + panic!( + "GraphQL union {} cannot be resolved into any of its variants in its \ + current state", + #name, + ); } fn resolve_into_type( @@ -306,9 +349,10 @@ impl UnionDefinition { _: Option<&[#crate_path::Selection<#scalar>]>, executor: &#crate_path::Executor, ) -> #crate_path::ExecutionResult<#scalar> { + let context = executor.context(); #( #resolve_into_type )* panic!( - "Concrete type {} is not handled by instance resolvers on GraphQL Union {}", + "Concrete type {} is not handled by instance resolvers on GraphQL union {}", type_name, #name, ); } @@ -327,26 +371,27 @@ impl UnionDefinition { _: Option<&'b [#crate_path::Selection<'b, #scalar>]>, executor: &'b #crate_path::Executor<'b, 'b, Self::Context, #scalar> ) -> #crate_path::BoxFuture<'b, #crate_path::ExecutionResult<#scalar>> { + let context = executor.context(); #( #resolve_into_type_async )* panic!( - "Concrete type {} is not handled by instance resolvers on GraphQL Union {}", + "Concrete type {} is not handled by instance resolvers on GraphQL union {}", type_name, #name, ); } } }; - let conversion_impls = self.variants.iter().map(|var| { + let conversion_impls = self.variants.iter().filter_map(|var| { let var_ty = &var.ty; - let var_path = &var.path; - quote! { + let var_path = var.enum_path.as_ref()?; + Some(quote! { #[automatically_derived] impl#impl_generics ::std::convert::From<#var_ty> for #ty#ty_generics { fn from(v: #var_ty) -> Self { #var_path(v) } } - } + }) }); let output_type_impl = quote! { diff --git a/juniper_codegen/src/graphql_union/mod.rs b/juniper_codegen/src/graphql_union/mod.rs index 78f2a64d2..c1b76fde0 100644 --- a/juniper_codegen/src/graphql_union/mod.rs +++ b/juniper_codegen/src/graphql_union/mod.rs @@ -1,6 +1,8 @@ pub mod attribute; pub mod derive; +use std::collections::HashMap; + use syn::{ parse::{Parse, ParseStream}, spanned::Spanned as _, @@ -47,13 +49,21 @@ struct UnionMeta { /// /// [1]: https://spec.graphql.org/June2018/#sec-Unions pub scalar: Option>, + + /// Explicitly specified custom resolver functions for [GraphQL union][1] variants. + /// + /// If absent, then macro will try to auto-infer all the possible variants from the type + /// declaration, if possible. That's why specifying a custom resolver function has sense, when + /// some custom [union][1] variant resolving logic is involved, or variants cannot be inferred. + /// + /// [1]: https://spec.graphql.org/June2018/#sec-Unions + pub custom_resolvers: HashMap>, } impl Parse for UnionMeta { fn parse(input: ParseStream) -> syn::Result { let mut output = Self::default(); - // TODO: check for duplicates? while !input.is_empty() { let ident: syn::Ident = input.parse()?; match ident.to_string().as_str() { @@ -97,6 +107,17 @@ impl Parse for UnionMeta { .replace(SpanContainer::new(ident.span(), Some(scl.span()), scl)) .none_or_else(|_| syn::Error::new(ident.span(), "duplicated attribute"))? } + "on" => { + let ty = input.parse::()?; + input.parse::()?; + let rslvr = input.parse::()?; + let rslvr_spanned = SpanContainer::new(ident.span(), Some(ty.span()), rslvr); + let rslvr_span = rslvr_spanned.span_joined(); + output + .custom_resolvers + .insert(ty, rslvr_spanned) + .none_or_else(|_| syn::Error::new(rslvr_span, "duplicated attribute"))? + } _ => { return Err(syn::Error::new(ident.span(), "unknown attribute")); } @@ -146,6 +167,19 @@ impl UnionMeta { } other.scalar }, + custom_resolvers: { + if !self.custom_resolvers.is_empty() { + for (ty, rslvr) in self.custom_resolvers { + other + .custom_resolvers + .insert(ty, rslvr) + .none_or_else(|dup| { + syn::Error::new(dup.span_joined(), "duplicated attribute") + })?; + } + } + other.custom_resolvers + }, }) } @@ -174,6 +208,15 @@ struct UnionVariantMeta { /// /// [1]: https://spec.graphql.org/June2018/#sec-Unions pub ignore: Option>, + + /// Explicitly specified custom resolver function for this [GraphQL union][1] variant. + /// + /// If absent, then macro will generate the code which just returns the variant inner value. + /// Usually, specifying a custom resolver function has sense, when some custom resolving logic + /// is involved. + /// + /// [1]: https://spec.graphql.org/June2018/#sec-Unions + pub custom_resolver: Option>, } impl Parse for UnionVariantMeta { @@ -187,6 +230,14 @@ impl Parse for UnionVariantMeta { .ignore .replace(SpanContainer::new(ident.span(), None, ident.clone())) .none_or_else(|_| syn::Error::new(ident.span(), "duplicated attribute"))?, + "with" => { + input.parse::()?; + let rslvr = input.parse::()?; + output + .custom_resolver + .replace(SpanContainer::new(ident.span(), Some(rslvr.span()), rslvr)) + .none_or_else(|_| syn::Error::new(ident.span(), "duplicated attribute"))? + } _ => { return Err(syn::Error::new(ident.span(), "unknown attribute")); } @@ -213,6 +264,14 @@ impl UnionVariantMeta { } other.ignore }, + custom_resolver: { + if let Some(v) = self.custom_resolver { + other.custom_resolver.replace(v).none_or_else(|dup| { + syn::Error::new(dup.span_ident(), "duplicated attribute") + })?; + } + other.custom_resolver + }, }) } diff --git a/juniper_codegen/src/util/span_container.rs b/juniper_codegen/src/util/span_container.rs index d449fbb71..f335da9fe 100644 --- a/juniper_codegen/src/util/span_container.rs +++ b/juniper_codegen/src/util/span_container.rs @@ -1,4 +1,7 @@ -use std::ops; +use std::{ + hash::{Hash, Hasher}, + ops, +}; use proc_macro2::{Span, TokenStream}; use quote::ToTokens; @@ -25,6 +28,19 @@ impl SpanContainer { self.ident } + pub fn span_joined(&self) -> Span { + if let Some(s) = self.expr { + // TODO: Use `Span::join` once stabilized and available on stable: + // https://github.com/rust-lang/rust/issues/54725 + // self.ident.join(s).unwrap() + + // At the moment, just return the second, more meaningful part. + s + } else { + self.ident + } + } + pub fn into_inner(self) -> T { self.val } @@ -69,3 +85,12 @@ impl PartialEq for SpanContainer { &self.val == other } } + +impl Hash for SpanContainer { + fn hash(&self, state: &mut H) + where + H: Hasher, + { + self.val.hash(state) + } +}