From 22358505db6a991a29bd476bd90ba3dd893e1cb5 Mon Sep 17 00:00:00 2001 From: Lucian Jones Date: Wed, 27 Apr 2022 09:57:40 +1200 Subject: [PATCH 1/3] Add test to reproduce issue --- execution_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/execution_test.go b/execution_test.go index 709af147..1ecb47f2 100644 --- a/execution_test.go +++ b/execution_test.go @@ -984,6 +984,59 @@ func TestFederatedQueryFragmentSpreads(t *testing.T) { f.checkSuccess(t) }) + t.Run("with multiple top level fragment spreads (gadget implementation)", func(t *testing.T) { + f := &queryExecutionFixture{ + services: []testService{serviceA, serviceB}, + query: ` + query Foo { + snapshot(id: "GADGET1") { + id + name + ... GadgetFragment + ... GizmoFragment + } + } + + fragment GadgetFragment on GadgetImplementation { + gadgets { + id + name + agents { + name + ... on Agent { + country + } + } + } + } + + fragment GizmoFragment on GizmoImplementation { + gizmos { + id + name + } + }`, + expected: ` + { + "snapshot": { + "id": "100", + "name": "foo", + "gadgets": [ + { + "id": "GADGET1", + "name": "Gadget #1", + "agents": [ + {"name": "James Bond", "country": "UK"} + ] + } + ] + } + }`, + } + + f.checkSuccess(t) + }) + } func TestQueryExecutionMultipleServices(t *testing.T) { From 518ba05273238d8e6cdce5517ef26e8167497f90 Mon Sep 17 00:00:00 2001 From: Lucian Jones Date: Wed, 27 Apr 2022 10:02:21 +1200 Subject: [PATCH 2/3] Fix overwriting ObjectDefintion with wrong value --- execution.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/execution.go b/execution.go index 38b08de2..ae9163cc 100644 --- a/execution.go +++ b/execution.go @@ -690,7 +690,7 @@ func (s *ExecutableSchema) evaluateSkipAndIncludeRec(vars map[string]interface{} Name: selection.Name, Directives: removeSkipAndInclude(selection.Directives), Position: selection.Position, - ObjectDefinition: selection.Definition.Definition, + ObjectDefinition: selection.ObjectDefinition, Definition: &ast.FragmentDefinition{ Name: selection.Definition.Name, VariableDefinition: selection.Definition.VariableDefinition, From 9e0455f46af2d06e19ee126786b710f300664944 Mon Sep 17 00:00:00 2001 From: Lucian Jones Date: Wed, 27 Apr 2022 10:04:30 +1200 Subject: [PATCH 3/3] Eliminate unnecessary fragment spreads --- query_execution.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/query_execution.go b/query_execution.go index cca15a54..fdcdb2ca 100644 --- a/query_execution.go +++ b/query_execution.go @@ -843,7 +843,7 @@ func unionAndTrimSelectionSetRec(objectTypename string, schema *ast.Schema, sele fragment := selection if fragment.ObjectDefinition.IsAbstractType() && fragmentImplementsAbstractType(schema, fragment.ObjectDefinition.Name, fragment.TypeCondition) && - objectTypenameMatchesDifferentFragment(objectTypename, fragment) { + objectTypenameMatchesDifferentFragment(objectTypename, fragment.TypeCondition) { continue } @@ -853,7 +853,18 @@ func unionAndTrimSelectionSetRec(objectTypename string, schema *ast.Schema, sele filteredSelectionSet = append(filteredSelectionSet, selection) } case *ast.FragmentSpread: - filteredSelectionSet = append(filteredSelectionSet, selection) + fragment := selection + if fragment.ObjectDefinition.IsAbstractType() && + fragmentImplementsAbstractType(schema, fragment.ObjectDefinition.Name, fragment.Definition.TypeCondition) && + objectTypenameMatchesDifferentFragment(objectTypename, fragment.Definition.TypeCondition) { + continue + } + + filteredSelections := unionAndTrimSelectionSetRec(objectTypename, schema, fragment.Definition.SelectionSet, seenFields) + if len(filteredSelections) > 0 { + fragment.Definition.SelectionSet = filteredSelections + filteredSelectionSet = append(filteredSelectionSet, selection) + } } } @@ -869,6 +880,6 @@ func extractAndCastTypenameField(result map[string]interface{}) string { return typeNameInterface.(string) } -func objectTypenameMatchesDifferentFragment(typename string, fragment *ast.InlineFragment) bool { - return fragment.TypeCondition != typename +func objectTypenameMatchesDifferentFragment(typename, fragmentTypeCondition string) bool { + return fragmentTypeCondition != typename }