diff --git a/auth.go b/auth.go index 1da99813..b85bae17 100644 --- a/auth.go +++ b/auth.go @@ -180,7 +180,7 @@ func filterDefinition(sourceSchema *ast.Schema, visited map[string]bool, types m // Node interface is not defined in the merged schema continue } - if typ.Kind == ast.Interface { + if typ.IsAbstractType() { for _, pt := range sourceSchema.PossibleTypes[typ.Name] { types[pt.Name] = pt _ = filterDefinition(sourceSchema, visited, types, pt, AllowedFields{AllowAll: true}) @@ -194,12 +194,6 @@ func filterDefinition(sourceSchema *ast.Schema, visited map[string]bool, types m _ = filterDefinition(sourceSchema, visited, types, sourceSchema.Types[typeName], AllowedFields{AllowAll: true}) } - // unions - for _, t := range def.Types { - types[t] = sourceSchema.Types[t] - _ = filterDefinition(sourceSchema, visited, types, sourceSchema.Types[t], AllowedFields{AllowAll: true}) - } - return &resDef } @@ -212,8 +206,8 @@ func filterDefinition(sourceSchema *ast.Schema, visited map[string]bool, types m // Node interface is not defined in the merged schema continue } - // if the type is an interface we filter all the possible types - if typ.Kind == ast.Interface { + // if the type is abstract we filter all the possible types + if typ.IsAbstractType() { for _, pt := range sourceSchema.PossibleTypes[typ.Name] { newTypeDef := filterDefinition(sourceSchema, visited, types, pt, allowedSubFields) if typeDef, ok := types[pt.Name]; ok { diff --git a/auth_test.go b/auth_test.go index 6428b10a..2b9b367b 100644 --- a/auth_test.go +++ b/auth_test.go @@ -421,7 +421,7 @@ func TestFilterSchema(t *testing.T) { `), formatSchema(filteredSchema)) }) - t.Run(`union`, func(t *testing.T) { + t.Run(`union, allow all`, func(t *testing.T) { perms := OperationPermissions{ AllowedRootQueryFields: AllowedFields{AllowedSubfields: map[string]AllowedFields{ "somethingRandom": { @@ -454,6 +454,35 @@ func TestFilterSchema(t *testing.T) { `), formatSchema(filteredSchema)) }) + t.Run(`union`, func(t *testing.T) { + perms := OperationPermissions{ + AllowedRootQueryFields: AllowedFields{AllowedSubfields: map[string]AllowedFields{ + "somethingRandom": { + AllowedSubfields: map[string]AllowedFields{ + "id": {}, + }, + }, + }, + }, + } + filteredSchema := perms.FilterSchema(schema) + assert.Equal(t, loadAndFormatSchema(` + union MovieOrCinema = Movie | Cinema + + type Cinema { + id: ID! + } + + type Movie { + id: ID! + } + + type Query { + somethingRandom: MovieOrCinema! + } + `), formatSchema(filteredSchema)) + }) + t.Run(`interface`, func(t *testing.T) { perms := OperationPermissions{ AllowedRootQueryFields: AllowedFields{AllowedSubfields: map[string]AllowedFields{