Skip to content

Commit

Permalink
Merge pull request #550 from typelevel/topic/field-merging
Browse files Browse the repository at this point in the history
Fully implement field selection merging and collection rules
  • Loading branch information
milessabin authored Jan 29, 2024
2 parents 63a6f10 + f2d2da2 commit 5f52e39
Show file tree
Hide file tree
Showing 9 changed files with 2,241 additions and 115 deletions.
186 changes: 169 additions & 17 deletions modules/core/src/main/scala/compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) {
parser.parseText(text).flatMap { case (ops, frags) =>
for {
_ <- Result.fromProblems(validateVariablesAndFragments(ops, frags, reportUnused))
_ <- Result.fromProblems(validateFieldMergeability(ops, frags))
ops0 <- ops.traverse(op => compileOperation(op, untypedVars, frags, introspectionLevel, env).map(op0 => (op.name, op0)))
res <- (ops0, name) match {
case (List((_, op)), None) =>
Expand Down Expand Up @@ -250,7 +251,7 @@ class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) {
*/
def compileOperation(op: UntypedOperation, untypedVars: Option[Json], frags: List[UntypedFragment], introspectionLevel: IntrospectionLevel = Full, env: Env = Env.empty): Result[Operation] = {
val allPhases =
IntrospectionElaborator(introspectionLevel).toList ++ (VariablesSkipAndFragmentElaborator :: phases)
IntrospectionElaborator(introspectionLevel).toList ++ (VariablesSkipAndFragmentElaborator :: MergeFields :: phases)

for {
varDefs <- compileVarDefs(op.variables)
Expand Down Expand Up @@ -475,6 +476,146 @@ class QueryCompiler(parser: QueryParser, schema: Schema, phases: List[Phase]) {
}
}
}

/**
* Validates that field mergeability rules are satisfied for the supplied operations and fragments.
*
* Returns a list of problems encountered.
*/
def validateFieldMergeability(ops: List[UntypedOperation], frags: List[UntypedFragment]): List[Problem] = {
// Validates field mergeability for a single operation
def validateOp(op: UntypedOperation): List[Problem] = {
// Collects all top level selects of the supplied queries, ungrouping and inlining fragments where necessary
def collectSelects(queries: List[(NamedType, Query)]): List[(NamedType, UntypedSelect)] = {
queries.flatMap {
case (tpe, g: Group) => collectSelects(g.queries.map(q => (tpe, q)))
case (tpe, s: UntypedSelect) => List((tpe, s))
case (tpe, i: UntypedInlineFragment) =>
(for {
ntpe <- i.tpnme.map(nme => schema.definition(nme))
sels <- Some(collectSelects(List((ntpe.getOrElse(tpe), i.child))))
} yield sels).getOrElse(Nil) // Undefined types will be reported later
case (_, f: UntypedFragment) =>
(for {
frag <- frags.find(_.name == f.name)
ntpe <- schema.definition(frag.tpnme)
sels <- Some(collectSelects(List((ntpe, frag.child))))
} yield sels).getOrElse(Nil) // Undefined types will be reported later
case _ =>
Nil
}
}

// Checks that the supplied field types are compatible
def checkShapes(tpes: List[Type], resultName: String): Either[List[Problem], List[NamedType]] = {
if(tpes.sizeCompare(1) <= 0) Right(tpes.map(_.underlyingNamed))
else {
def stripNull(tpes: List[Type]): Either[List[Problem], List[NamedType]] = {
if (tpes.forall(!_.isNullable)) stripList(tpes)
else if(tpes.forall(_.isNullable)) stripList(tpes.map(_.nonNull))
else Left(List(Problem(s"Cannot merge fields named '$resultName' of both nullable and non-nullable types")))
}

def stripList(tpes: List[Type]): Either[List[Problem], List[NamedType]] = {
if (tpes.forall(!_.isList)) Right(tpes.map(_.underlyingNamed))
else if(tpes.forall(_.isList)) stripNull(tpes.collect { case ListType(elem) => elem })
else Left(List(Problem(s"Cannot merge fields named '$resultName' of both list and non-list types")))
}

stripNull(tpes) match {
case l@Left(_) => l
case r@Right(tpes) =>
if (tpes.forall(!_.isLeaf)) r
else {
val first = tpes.head
if (tpes.tail.forall(_ =:= first)) r
else {
val (leaf0, nonLeaf0) = tpes.partition(_.isLeaf)
val leafNames = leaf0.map(_.name).distinct
val nonLeafNames = nonLeaf0.map(_.name).distinct
val leafErrors =
if (leafNames.sizeCompare(1) <= 0) Nil
else List(Problem(s"Cannot merge fields named '$resultName' of distinct leaf types ${leafNames.mkString(", ")}"))
val nonLeafErrors =
if (nonLeafNames.isEmpty) Nil
else List(Problem(s"Cannot merge fields named '$resultName' of leaf types ${leafNames.mkString(", ")} and non-leaf types ${nonLeafNames.mkString(", ")}"))
Left(leafErrors ::: nonLeafErrors)
}
}
}
}
}

// Checks that the supplied queries are mergeable
def validateQueries(queries: List[(NamedType, Query)]): List[Problem] = {
val sels = collectSelects(queries)
val grouped = sels.groupBy(_._2.resultName)
val mergeProblems =
if (sels.sizeCompare(1) <= 0) Nil
else
grouped.toList.flatMap {
case (resultName, sels) =>
if (sels.sizeCompare(1) <= 0) Nil
else {
val allTypes = sels.map(_._1.dealias).distinct
allTypes.flatMap { tpe =>
val conflictSet = sels.collect { case (ntpe, sel) if ntpe <:< tpe || tpe <:< ntpe => sel }
if (conflictSet.sizeCompare(1) <= 0) Nil
else {
val first = conflictSet.head
val noNameConflicts = conflictSet.forall(_.name == first.name)
val noArgConflicts = conflictSet.forall(_.args == first.args)
if (noNameConflicts && noArgConflicts) Nil
else {
val nameProblems =
if (noNameConflicts) Nil
else List(Problem(s"Cannot merge fields with alias '$resultName' and names ${conflictSet.map(s => s"'${s.name}'").distinct.mkString(", ")}"))
val argProblems =
if (noArgConflicts) Nil
else List(Problem(s"Cannot merge fields named '$resultName' with different arguments"))
nameProblems ::: argProblems
}
}
}
}
}

if (mergeProblems.nonEmpty) mergeProblems.toList.distinct
else
grouped.flatMap {
case (resultName, sels) =>
val children = sels.flatMap {
case ((tpe, q)) =>
q.name match {
case "__typename" => Nil
case "__type" => List((Introspection.__TypeType, q.child))
case "__schema" => List((Introspection.__SchemaType, q.child))
case _ =>
(for {
ctpe <- tpe.field(q.name)
} yield (ctpe, q.child)).toList
} // Undefined fields and bogus subselection sets will be reported later
}
val (ctpes, cqs) = children.unzip
checkShapes(ctpes, resultName) match {
case Left(ps) => ps
case Right(Nil) => Nil
case Right(ctpes) if ctpes.head.isLeaf => Nil
case Right(ctpes) => validateQueries(ctpes.zip(cqs))
}
}.toList
}

op.rootTpe(schema) match {
case Result.Success(tpe) => validateQueries(List((tpe, op.query)))
case Result.Warning(ps, tpe) => ps.toList ++ validateQueries(List((tpe, op.query)))
case Result.Failure(ps) => ps.toList
case Result.InternalError(_) => Nil // This will be reported elsewhere
}
}

ops.flatMap(validateOp)
}
}

object QueryCompiler {
Expand Down Expand Up @@ -718,11 +859,11 @@ object QueryCompiler {
def validateSubselection(fieldName: String, child: Query): Elab[Unit] =
for {
c <- Elab.context
obj <- Elab.liftR(c.tpe.underlyingObject.toResultOrError(s"Expected object type, found ${c.tpe}"))
childCtx <- Elab.liftR(c.forField(fieldName, None))
tpe = childCtx.tpe
_ <- {
val isLeaf = tpe.isUnderlyingLeaf
def obj = c.tpe.underlyingNamed
if (isLeaf && child != Empty)
Elab.failure(s"Leaf field '$fieldName' of $obj must have an empty subselection set")
else if (!isLeaf && child == Empty)
Expand Down Expand Up @@ -801,14 +942,6 @@ object QueryCompiler {
object VariablesSkipAndFragmentElaborator extends Phase {
override def transform(query: Query): Elab[Query] =
query match {
case Group(children) =>
children.traverse(q => transform(q)).map { eqs =>
eqs.filterNot(_ == Empty) match {
case Nil => Empty
case eq :: Nil => eq
case eqs => Group(eqs)
}
}
case sel@UntypedSelect(fieldName, alias, args, dirs, child) =>
isSkipped(dirs).ifM(
Elab.pure(Empty),
Expand All @@ -819,7 +952,7 @@ object QueryCompiler {
childCtx <- Elab.liftR(c.forField(fieldName, alias))
vars <- Elab.vars
eArgs <- args.traverse(elaborateBinding(_, vars))
eDirs <- Elab.liftR(Directive.elaborateDirectives(s, dirs, vars))
eDirs <- Elab.liftR(Directive.elaborateDirectives(s, dirs.filterNot(dir => dir.name == "skip" || dir.name == "include"), vars))
_ <- Elab.push(childCtx, child)
ec <- transform(child)
_ <- Elab.pop
Expand All @@ -833,9 +966,9 @@ object QueryCompiler {
s <- Elab.schema
c <- Elab.context
f <- Elab.fragment(nme)
ctpe <- Elab.liftR(c.tpe.underlyingObject.toResultOrError(s"Expected object type, found ${c.tpe}"))
ctpe = c.tpe.underlyingNamed
subtpe <- Elab.liftR(s.definition(f.tpnme).toResult(s"Unknown type '${f.tpnme}' in type condition of fragment '$nme'"))
_ <- Elab.failure(s"Fragment '$nme' is not compatible with type '${c.tpe}'").whenA(!(subtpe <:< ctpe) && !(ctpe <:< subtpe))
_ <- Elab.failure(s"Fragment '$nme' is not compatible with type '${c.tpe}'").whenA(!fragmentApplies(subtpe, ctpe))
_ <- Elab.push(c.asType(subtpe), f.child)
ec <- transform(f.child)
_ <- Elab.pop
Expand All @@ -850,14 +983,14 @@ object QueryCompiler {
for {
s <- Elab.schema
c <- Elab.context
ctpe <- Elab.liftR(c.tpe.underlyingObject.toResultOrError(s"Expected object type, found ${c.tpe}"))
ctpe = c.tpe.underlyingNamed
subtpe <- tpnme0 match {
case None =>
Elab.pure(ctpe)
case Some(tpnme) =>
Elab.liftR(s.definition(tpnme).toResult(s"Unknown type '$tpnme' in type condition inline fragment"))
}
_ <- Elab.failure(s"Inline fragment with type condition '$subtpe' is not compatible with type '$ctpe'").whenA(!(subtpe <:< ctpe) && !(ctpe <:< subtpe))
_ <- Elab.failure(s"Inline fragment with type condition '$subtpe' is not compatible with type '$ctpe'").whenA(!fragmentApplies(subtpe, ctpe))
_ <- Elab.push(c.asType(subtpe), child)
ec <- transform(child)
_ <- Elab.pop
Expand All @@ -869,6 +1002,16 @@ object QueryCompiler {
case _ => super.transform(query)
}

/**
* Tests the supplied type condition is satisfied by the supplied context type
*/
def fragmentApplies(typeCond: NamedType, ctpe: NamedType): Boolean =
(typeCond.dealias, ctpe.dealias) match {
case (_: InterfaceType, u: UnionType) => u.members.forall(_ <:< typeCond)
case (_, u: UnionType) => u.members.exists(typeCond <:< _)
case _ => typeCond <:< ctpe || ctpe <:< typeCond
}

def elaborateBinding(b: Binding, vars: Vars): Elab[Binding] =
Elab.liftR(Value.elaborateValue(b.value, vars).map(ev => b.copy(value = ev)))

Expand Down Expand Up @@ -899,6 +1042,15 @@ object QueryCompiler {
}
}

/**
* A compiler phase which applies GraphQL field merge rules to an
* untyped query.
*/
object MergeFields extends Phase {
override def transform(query: Query): Elab[Query] =
Elab.pure(mergeUntypedQueries(List(query)))
}

/**
* A compiler phase which translates `Select` nodes to be directly
* interpretable.
Expand Down Expand Up @@ -937,7 +1089,7 @@ object QueryCompiler {
c <- Elab.context
s <- Elab.schema
childCtx <- Elab.liftR(c.forField(fieldName, resultName))
obj <- Elab.liftR(c.tpe.underlyingObject.toResultOrError(s"Expected object type, found ${c.tpe}"))
obj = c.tpe.underlyingNamed.dealias
field <- obj match {
case twf: TypeWithFields =>
Elab.liftR(twf.fieldInfo(fieldName).toResult(s"No field '$fieldName' for type ${obj.underlying}"))
Expand All @@ -958,7 +1110,7 @@ object QueryCompiler {
val e1 = Select(sel.name, sel.alias, e2)
val e0 =
if(attrs.isEmpty) e1
else Group((e1 :: attrs.map { case (nme, child) => Select(nme, child) }).flatMap(Query.ungroup))
else mergeQueries(e1 :: attrs.map { case (nme, child) => Select(nme, child) })

if (env.isEmpty) e0
else Environment(env, e0)
Expand Down
Loading

0 comments on commit 5f52e39

Please sign in to comment.