diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt index d0c071da84..081321f46d 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeFromSource.kt @@ -18,7 +18,7 @@ import org.partiql.ast.AstNode import org.partiql.ast.Expr import org.partiql.ast.From import org.partiql.ast.Statement -import org.partiql.ast.builder.ast +import org.partiql.ast.fromJoin import org.partiql.ast.helpers.toBinder import org.partiql.ast.util.AstRewriter @@ -31,6 +31,7 @@ internal object NormalizeFromSource : AstPass { private object Visitor : AstRewriter() { + // Each SFW starts the ctx count again. override fun visitExprSFW(node: Expr.SFW, ctx: Int): AstNode = super.visitExprSFW(node, 0) override fun visitStatementDMLBatchLegacy(node: Statement.DML.BatchLegacy, ctx: Int): AstNode = @@ -38,20 +39,25 @@ internal object NormalizeFromSource : AstPass { override fun visitFrom(node: From, ctx: Int) = super.visitFrom(node, ctx) as From - override fun visitFromJoin(node: From.Join, ctx: Int) = ast { + override fun visitFromJoin(node: From.Join, ctx: Int): From { val lhs = visitFrom(node.lhs, ctx) val rhs = visitFrom(node.rhs, ctx + 1) val condition = node.condition?.let { visitExpr(it, ctx) as Expr } - if (lhs !== node.lhs || rhs !== node.rhs || condition !== node.condition) { + return if (lhs !== node.lhs || rhs !== node.rhs || condition !== node.condition) { fromJoin(lhs, rhs, node.type, condition) } else { node } } - override fun visitFromValue(node: From.Value, ctx: Int) = when (node.asAlias) { - null -> node.copy(asAlias = node.expr.toBinder(ctx)) - else -> node + override fun visitFromValue(node: From.Value, ctx: Int): From { + val expr = visitExpr(node.expr, ctx) as Expr + val asAlias = node.asAlias ?: expr.toBinder(ctx) + return if (expr !== node.expr || asAlias !== node.asAlias) { + node.copy(expr = expr, asAlias = asAlias) + } else { + node + } } } } diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt index f131acf511..3e5c6526ac 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelect.kt @@ -80,7 +80,6 @@ import org.partiql.value.stringValue * } FROM A AS x * ``` * - * TODO: GROUP BY * TODO: LET * * Requires [NormalizeFromSource]. diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt deleted file mode 100644 index e0fe892cdc..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectList.kt +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.ast.normalize - -import org.partiql.ast.Expr -import org.partiql.ast.Select -import org.partiql.ast.Statement -import org.partiql.ast.builder.ast -import org.partiql.ast.helpers.toBinder -import org.partiql.ast.util.AstRewriter - -/** - * Adds an `as` alias to every select-list item. - * - * - [org.partiql.ast.helpers.toBinder] - * - https://partiql.org/assets/PartiQL-Specification.pdf#page=28 - * - https://web.cecs.pdx.edu/~len/sql1999.pdf#page=287 - */ -internal object NormalizeSelectList : AstPass { - - override fun apply(statement: Statement) = Visitor.visitStatement(statement, 0) as Statement - - private object Visitor : AstRewriter() { - - override fun visitSelectProject(node: Select.Project, ctx: Int) = ast { - if (node.items.isEmpty()) { - return@ast node - } - var diff = false - val transformed = ArrayList(node.items.size) - node.items.forEachIndexed { i, n -> - val item = visitSelectProjectItem(n, i) as Select.Project.Item - if (item !== n) diff = true - transformed.add(item) - } - // We don't want to create a new list unless we have to, as to not trigger further rewrites up the tree. - if (diff) selectProject(transformed) else node - } - - override fun visitSelectProjectItemAll(node: Select.Project.Item.All, ctx: Int) = node.copy() - - override fun visitSelectProjectItemExpression(node: Select.Project.Item.Expression, ctx: Int) = ast { - val expr = visitExpr(node.expr, 0) as Expr - val alias = when (node.asAlias) { - null -> expr.toBinder(ctx) - else -> node.asAlias - } - if (expr != node.expr || alias != node.asAlias) { - selectProjectItemExpression(expr, alias) - } else { - node - } - } - } -} diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt deleted file mode 100644 index 97f4272d4c..0000000000 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/normalize/NormalizeSelectStar.kt +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2022 Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.ast.normalize - -import org.partiql.ast.Expr -import org.partiql.ast.From -import org.partiql.ast.Identifier -import org.partiql.ast.Select -import org.partiql.ast.Statement -import org.partiql.ast.builder.AstBuilder -import org.partiql.ast.builder.ast -import org.partiql.ast.helpers.toBinder -import org.partiql.ast.util.AstRewriter - -/** - * Rewrites - * - `SELECT * FROM A AS x, B AS y AT i` -> `SELECT x.* AS _1, y.* as _2, i AS i FROM A AS x, B AS y AT i` - * - TODO GROUP BY - * - * Requires [NormalizeFromSource] - */ -internal object NormalizeSelectStar : AstPass { - - override fun apply(statement: Statement): Statement = Visitor.visitStatement(statement, Unit) as Statement - - private object Visitor : AstRewriter() { - - override fun visitExprSFW(node: Expr.SFW, ctx: Unit) = ast { - val sfw = super.visitExprSFW(node, ctx) as Expr.SFW - if (sfw.select !is Select.Star) { - return@ast sfw - } - val sel = selectProject { - sfw.from.aliases().forEachIndexed { i, binding -> - val asAlias = binding.first - val atAlias = binding.second - val byAlias = binding.third - items += asAlias.star(i) - if (atAlias != null) items += atAlias.simple() - if (byAlias != null) items += byAlias.simple() - } - setq = (sfw.select as Select.Star).setq - } - sfw.copy(select = sel) - } - - // Helpers - - private fun From.aliases(): List> = when (this) { - is From.Join -> lhs.aliases() + rhs.aliases() - is From.Value -> { - val asAlias = asAlias?.symbol ?: error("AST not normalized, missing asAlias on FROM source.") - val atAlias = atAlias?.symbol - val byAlias = byAlias?.symbol - listOf(Triple(asAlias, atAlias, byAlias)) - } - } - - // t -> t.* AS _i - private fun String.star(i: Int) = ast { - val expr = exprPath { - root = exprVar(id(this@star), Expr.Var.Scope.DEFAULT) - steps += exprPathStepUnpivot() - } - val alias = expr.toBinder(i) - selectProjectItemExpression(expr, alias) - } - - // t -> t AS t - private fun String.simple() = ast { - val expr = exprVar(id(this@simple), Expr.Var.Scope.DEFAULT) - val alias = id(this@simple) - selectProjectItemExpression(expr, alias) - } - - private fun AstBuilder.id(symbol: String) = identifierSymbol(symbol, Identifier.CaseSensitivity.INSENSITIVE) - } -} diff --git a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt index 516c6d44e5..c4a7816d69 100644 --- a/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt +++ b/partiql-lang/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt @@ -14,7 +14,6 @@ import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.MethodSource import org.partiql.annotations.ExperimentalPartiQLSchemaInferencer import org.partiql.errors.Problem -import org.partiql.errors.ProblemHandler import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.lang.errors.ProblemCollector import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.ErrorTestCase @@ -131,6 +130,11 @@ class PartiQLSchemaInferencerTests { @Execution(ExecutionMode.CONCURRENT) fun testCaseWhens(tc: TestCase) = runTest(tc) + @ParameterizedTest + @MethodSource("subqueryCases") + @Execution(ExecutionMode.CONCURRENT) + fun testSubqueries(tc: TestCase) = runTest(tc) + companion object { private val root = this::class.java.getResource("/catalogs/default")!!.toURI().toPath().pathString @@ -156,6 +160,10 @@ class PartiQLSchemaInferencerTests { field("connector_name", ionString("local")), field("root", ionString("$root/pql")), ), + "subqueries" to ionStructOf( + field("connector_name", ionString("local")), + field("root", ionString("$root/subqueries")), + ), ) const val CATALOG_AWS = "aws" @@ -2463,6 +2471,69 @@ class PartiQLSchemaInferencerTests { ) ), ) + + @JvmStatic + fun subqueryCases() = listOf( + SuccessTestCase( + name = "Subquery IN collection", + catalog = "subqueries", + key = PartiQLTest.Key("subquery", "subquery-00"), + expected = BagType( + StructType( + fields = mapOf( + "x" to INT4, + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), + SuccessTestCase( + name = "Subquery scalar coercion", + catalog = "subqueries", + key = PartiQLTest.Key("subquery", "subquery-01"), + expected = BagType( + StructType( + fields = mapOf( + "x" to INT4, + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), + SuccessTestCase( + name = "Subquery simple JOIN", + catalog = "subqueries", + key = PartiQLTest.Key("subquery", "subquery-02"), + expected = BagType( + StructType( + fields = mapOf( + "x" to INT4, + "y" to INT4, + "z" to INT4, + "a" to INT4, + "b" to INT4, + "c" to INT4, + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + ) + ), + ) } sealed class TestCase { @@ -2474,7 +2545,7 @@ class PartiQLSchemaInferencerTests { val catalog: String? = null, val catalogPath: List = emptyList(), val expected: StaticType, - val warnings: ProblemHandler? = null + val warnings: ProblemHandler? = null, ) : TestCase() { override fun toString(): String = "$name : $query" } diff --git a/partiql-plan/src/main/resources/partiql_plan_0_1.ion b/partiql-plan/src/main/resources/partiql_plan_0_1.ion index ce031d4ccb..b43e5bd7c4 100644 --- a/partiql-plan/src/main/resources/partiql_plan_0_1.ion +++ b/partiql-plan/src/main/resources/partiql_plan_0_1.ion @@ -151,11 +151,9 @@ rex::{ rel: rel, }, - coll_to_scalar::{ - subquery: { - select: select, - type: static_type // reify `select` type - } + subquery::{ + select: select, + coercion: [ SCALAR, ROW ], }, select::{ diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt index 44f4c0d25f..39d49b63a4 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/transforms/RexConverter.kt @@ -19,7 +19,6 @@ package org.partiql.planner.transforms import org.partiql.ast.AstNode import org.partiql.ast.DatetimeField import org.partiql.ast.Expr -import org.partiql.ast.Select import org.partiql.ast.Type import org.partiql.ast.visitor.AstBaseVisitor import org.partiql.plan.Identifier @@ -29,10 +28,6 @@ import org.partiql.plan.fnUnresolved import org.partiql.plan.identifierSymbol import org.partiql.plan.rex import org.partiql.plan.rexOpCall -import org.partiql.plan.rexOpCase -import org.partiql.plan.rexOpCaseBranch -import org.partiql.plan.rexOpCollToScalar -import org.partiql.plan.rexOpCollToScalarSubquery import org.partiql.plan.rexOpCollection import org.partiql.plan.rexOpLit import org.partiql.plan.rexOpPath @@ -42,6 +37,8 @@ import org.partiql.plan.rexOpPathStepUnpivot import org.partiql.plan.rexOpPathStepWildcard import org.partiql.plan.rexOpStruct import org.partiql.plan.rexOpStructField +import org.partiql.plan.rexOpSubquery +import org.partiql.plan.rexOpTupleUnion import org.partiql.plan.rexOpVarUnresolved import org.partiql.planner.Env import org.partiql.planner.typer.toNonNullStaticType @@ -50,7 +47,6 @@ import org.partiql.types.StaticType import org.partiql.types.TimeType import org.partiql.value.PartiQLValueExperimental import org.partiql.value.boolValue -import org.partiql.value.datetime.Timestamp import org.partiql.value.int32Value import org.partiql.value.int64Value import org.partiql.value.nullValue @@ -65,6 +61,7 @@ internal object RexConverter { @OptIn(PartiQLValueExperimental::class) @Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE") private object ToRex : AstBaseVisitor() { + override fun defaultReturn(node: AstNode, context: Env): Rex = throw IllegalArgumentException("unsupported rex $node") @@ -77,6 +74,31 @@ internal object RexConverter { return rex(type, op) } + /** + * !! IMPORTANT !! + * + * This is the top-level visit for handling subquery coercion. The default behavior is to coerce to a scalar. + * In some situations, ie comparison to complex types we may make assertions on the desired type. + * + * It is recommended that every method (except for the exceptional cases) recurse the tree from visitExprCoerce. + * + * - RHS of comparison when LHS is an array or collection expression; and visa-versa + * - It is the collection expression of a FROM clause or JOIN + * - It is the RHS of an IN predicate + * - It is an argument of an OUTER set operator. + * + * @param node + * @param ctx + * @return + */ + private fun visitExprCoerce(node: Expr, ctx: Env, coercion: Rex.Op.Subquery.Coercion = Rex.Op.Subquery.Coercion.SCALAR): Rex { + val rex = super.visitExpr(node, ctx) + return when (rex.op is Rex.Op.Select) { + true -> rex(StaticType.ANY, rexOpSubquery(rex.op as Rex.Op.Select, coercion)) + else -> rex + } + } + override fun visitExprVar(node: Expr.Var, context: Env): Rex { val type = (StaticType.ANY) val identifier = AstToPlan.convert(node.identifier) @@ -91,7 +113,7 @@ internal object RexConverter { override fun visitExprUnary(node: Expr.Unary, context: Env): Rex { val type = (StaticType.ANY) // Args - val arg = node.expr.accept(ToRex, context) + val arg = visitExprCoerce(node, context) val args = listOf(arg) // Fn val id = identifierSymbol(node.op.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) @@ -104,8 +126,8 @@ internal object RexConverter { override fun visitExprBinary(node: Expr.Binary, context: Env): Rex { val type = (StaticType.ANY) // Args - val lhs = node.lhs.accept(ToRex, context) - val rhs = node.rhs.accept(ToRex, context) + val lhs = visitExprCoerce(node.lhs, context) + val rhs = visitExprCoerce(node.rhs, context) val args = listOf(lhs, rhs) // Fn val id = identifierSymbol(node.op.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) @@ -118,11 +140,11 @@ internal object RexConverter { override fun visitExprPath(node: Expr.Path, context: Env): Rex { val type = (StaticType.ANY) // Args - val root = visitExpr(node.root, context) + val root = visitExprCoerce(node.root, context) val steps = node.steps.map { when (it) { is Expr.Path.Step.Index -> { - val key = visitExpr(it.key, context) + val key = visitExprCoerce(it.key, context) rexOpPathStepIndex(key) } is Expr.Path.Step.Symbol -> { @@ -147,22 +169,25 @@ internal object RexConverter { } val fn = fnUnresolved(id, false) // Args - val args = node.args.map { visitExpr(it, context) } + val args = node.args.map { visitExprCoerce(it, context) } // Rex val op = rexOpCall(fn, args) return rex(type, op) } - private fun visitExprCallTupleUnion(node: Expr.Call, context: Env) = plan { + private fun visitExprCallTupleUnion(node: Expr.Call, context: Env): Rex { val type = (StaticType.STRUCT) - val args = node.args.map { visitExpr(it, context) }.toMutableList() + val args = node.args.map { visitExprCoerce(it, context) }.toMutableList() val op = rexOpTupleUnion(args) - rex(type, op) + return rex(type, op) } override fun visitExprCase(node: Expr.Case, context: Env) = plan { val type = (StaticType.ANY) - val rex = node.expr?.let { visitExpr(it, context) } + val rex = when (node.expr) { + null -> null + else -> visitExprCoerce(node.expr!!, context) // match `rex + } // Converts AST CASE (x) WHEN y THEN z --> Plan CASE WHEN x = y THEN z val id = identifierSymbol(Expr.Binary.Op.EQ.name.lowercase(), Identifier.CaseSensitivity.SENSITIVE) @@ -176,14 +201,14 @@ internal object RexConverter { } val branches = node.branches.map { - val branchCondition = visitExpr(it.condition, context) - val branchRex = visitExpr(it.expr, context) + val branchCondition = visitExprCoerce(it.condition, context) + val branchRex = visitExprCoerce(it.expr, context) createBranch(branchCondition, branchRex) }.toMutableList() val defaultRex = when (val default = node.default) { null -> rex(type = StaticType.NULL, op = rexOpLit(value = nullValue())) - else -> visitExpr(default, context) + else -> visitExprCoerce(default, context) } val op = rexOpCase(branches = branches, default = defaultRex) rex(type, op) @@ -197,7 +222,7 @@ internal object RexConverter { Expr.Collection.Type.LIST -> StaticType.LIST Expr.Collection.Type.SEXP -> StaticType.SEXP } - val values = node.values.map { visitExpr(it, context) } + val values = node.values.map { visitExprCoerce(it, context) } val op = rexOpCollection(values) return rex(type, op) } @@ -205,8 +230,8 @@ internal object RexConverter { override fun visitExprStruct(node: Expr.Struct, context: Env): Rex { val type = (StaticType.STRUCT) val fields = node.fields.map { - val k = visitExpr(it.name, context) - val v = visitExpr(it.value, context) + val k = visitExprCoerce(it.name, context) + val v = visitExprCoerce(it.value, context) rexOpStructField(k, v) } val op = rexOpStruct(fields) @@ -221,9 +246,9 @@ internal object RexConverter { override fun visitExprLike(node: Expr.Like, ctx: Env): Rex { val type = StaticType.BOOL // Args - val arg0 = visitExpr(node.value, ctx) - val arg1 = visitExpr(node.pattern, ctx) - val arg2 = node.escape?.let { visitExpr(it, ctx) } + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = visitExprCoerce(node.pattern, ctx) + val arg2 = node.escape?.let { visitExprCoerce(it, ctx) } // Call Variants var call = when (arg2) { null -> call("like", arg0, arg1) @@ -242,9 +267,9 @@ internal object RexConverter { override fun visitExprBetween(node: Expr.Between, ctx: Env): Rex { val type = StaticType.BOOL // Args - val arg0 = visitExpr(node.value, ctx) - val arg1 = visitExpr(node.from, ctx) - val arg2 = visitExpr(node.to, ctx) + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = visitExprCoerce(node.from, ctx) + val arg2 = visitExprCoerce(node.to, ctx) // Call var call = call("between", arg0, arg1, arg2) // NOT? @@ -260,8 +285,8 @@ internal object RexConverter { override fun visitExprInCollection(node: Expr.InCollection, ctx: Env): Rex { val type = StaticType.BOOL // Args - val arg0 = visitExpr(node.lhs, ctx) - val arg1 = visitExpr(node.rhs, ctx) + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = visitExpr(node.rhs, ctx) // !! don't insert scalar subquery coercions // Call var call = call("in_collection", arg0, arg1) // NOT? @@ -277,7 +302,7 @@ internal object RexConverter { override fun visitExprIsType(node: Expr.IsType, ctx: Env): Rex { val type = StaticType.BOOL // arg - val arg0 = visitExpr(node.value, ctx) + val arg0 = visitExprCoerce(node.value, ctx) var call = when (val targetType = node.type) { is Type.NullType -> call("is_null", arg0) @@ -334,7 +359,6 @@ internal object RexConverter { // ELSE NULL END override fun visitExprCoalesce(node: Expr.Coalesce, ctx: Env): Rex = plan { val type = StaticType.ANY - val createBranch: (Rex) -> Rex.Op.Case.Branch = { expr: Rex -> val updatedCondition = rex(type, negate(call("is_null", expr))) rexOpCaseBranch(updatedCondition, expr) @@ -373,9 +397,9 @@ internal object RexConverter { override fun visitExprSubstring(node: Expr.Substring, ctx: Env): Rex { val type = StaticType.ANY // Args - val arg0 = visitExpr(node.value, ctx) - val arg1 = node.start?.let { visitExpr(it, ctx) } ?: rex(StaticType.INT, rexOpLit(int64Value(1))) - val arg2 = node.length?.let { visitExpr(it, ctx) } + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = node.start?.let { visitExprCoerce(it, ctx) } ?: rex(StaticType.INT, rexOpLit(int64Value(1))) + val arg2 = node.length?.let { visitExprCoerce(it, ctx) } // Call Variants val call = when (arg2) { null -> call("substring", arg0, arg1) @@ -390,8 +414,8 @@ internal object RexConverter { override fun visitExprPosition(node: Expr.Position, ctx: Env): Rex { val type = StaticType.ANY // Args - val arg0 = visitExpr(node.lhs, ctx) - val arg1 = visitExpr(node.rhs, ctx) + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = visitExprCoerce(node.rhs, ctx) // Call val call = call("position", arg0, arg1) return rex(type, call) @@ -403,8 +427,8 @@ internal object RexConverter { override fun visitExprTrim(node: Expr.Trim, ctx: Env): Rex { val type = StaticType.TEXT // Args - val arg0 = visitExpr(node.value, ctx) - val arg1 = node.chars?.let { visitExpr(it, ctx) } + val arg0 = visitExprCoerce(node.value, ctx) + val arg1 = node.chars?.let { visitExprCoerce(it, ctx) } // Call Variants val call = when (node.spec) { Expr.Trim.Spec.LEADING -> when (arg1) { @@ -435,7 +459,7 @@ internal object RexConverter { // TODO: Ignoring type parameter now override fun visitExprCast(node: Expr.Cast, ctx: Env): Rex { val type = node.asType - val arg0 = visitExpr(node.value, ctx) + val arg0 = visitExprCoerce(node.value, ctx) return when (type) { is Type.NullType -> rex(StaticType.NULL, call("cast_null", arg0)) is Type.Missing -> rex(StaticType.MISSING, call("cast_missing", arg0)) @@ -486,8 +510,8 @@ internal object RexConverter { override fun visitExprDateAdd(node: Expr.DateAdd, ctx: Env): Rex { val type = StaticType.TIMESTAMP // Args - val arg0 = visitExpr(node.lhs, ctx) - val arg1 = visitExpr(node.rhs, ctx) + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = visitExprCoerce(node.rhs, ctx) // Call Variants val call = when (node.field) { DatetimeField.TIMEZONE_HOUR -> error("Invalid call DATE_ADD(TIMEZONE_HOUR, ...)") @@ -500,8 +524,8 @@ internal object RexConverter { override fun visitExprDateDiff(node: Expr.DateDiff, ctx: Env): Rex { val type = StaticType.TIMESTAMP // Args - val arg0 = visitExpr(node.lhs, ctx) - val arg1 = visitExpr(node.rhs, ctx) + val arg0 = visitExprCoerce(node.lhs, ctx) + val arg1 = visitExprCoerce(node.rhs, ctx) // Call Variants val call = when (node.field) { DatetimeField.TIMEZONE_HOUR -> error("Invalid call DATE_DIFF(TIMEZONE_HOUR, ...)") @@ -518,31 +542,7 @@ internal object RexConverter { return rex(type, call) } - /** - * This indicates we've hit a SQL `SELECT` subquery in the context of an expression tree. - * There, coerce to scalar via COLL_TO_SCALAR: https://partiql.org/dql/subqueries.html#scalar-subquery - * - * The default behavior is to coerce, but we remove the scalar coercion in the special cases, - * - RHS of comparison when LHS is an array; and visa-versa - * - It is the collection expression of a FROM clause - * - It is the RHS of an IN predicate - */ - override fun visitExprSFW(node: Expr.SFW, context: Env): Rex { - val query = RelConverter.apply(node, context) - return when (val select = query.op) { - is Rex.Op.Select -> { - if (node.select is Select.Value) { - // SELECT VALUE does not implicitly coerce to a scalar - return query - } - // Insert the coercion - val type = select.constructor.type - val subquery = rexOpCollToScalarSubquery(select, query.type) - rex(type, rexOpCollToScalar(subquery)) - } - else -> query - } - } + override fun visitExprSFW(node: Expr.SFW, context: Env): Rex = RelConverter.apply(node, context) // Helpers @@ -564,7 +564,6 @@ internal object RexConverter { /** * Create a [Rex.Op.Call] node which has a hidden unresolved Function. - * A hidden function, will have a unicode 0xFDD0 as prefix in the key for [org.partiql.planner.FunctionMap]. * The purpose of having such hidden function is to prevent usage of generated function name in query text. */ private fun call(name: String, vararg args: Rex): Rex.Op.Call { @@ -583,7 +582,5 @@ internal object RexConverter { } private fun Int?.toRex() = rex(StaticType.INT4, rexOpLit(int32Value(this))) - - private fun Boolean?.toRex() = rex(StaticType.BOOL, rexOpLit(boolValue(this))) } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt index e66d7ad190..be04079cc4 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/typer/PlanTyper.kt @@ -635,7 +635,9 @@ internal class PlanTyper( private fun foldCaseBranch(condition: Rex, result: Rex): Rex.Op.Case.Branch { val call = condition.op as? Rex.Op.Call ?: return rexOpCaseBranch(condition, result) val fn = call.fn as? Fn.Resolved ?: return rexOpCaseBranch(condition, result) - if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { return rexOpCaseBranch(condition, result) } + if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { + return rexOpCaseBranch(condition, result) + } val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") val simplifiedCondition = when { ref.type.allTypes.all { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(true))) @@ -715,12 +717,50 @@ internal class PlanTyper( TODO("Type RexOpPivot") } - override fun visitRexOpCollToScalar(node: Rex.Op.CollToScalar, ctx: StaticType?): Rex { - TODO("Type RexOpCollToScalar") + override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: StaticType?): Rex { + val select = visitRexOpSelect(node.select, ctx).op as Rex.Op.Select + val subquery = node.copy(select = select) + return when (node.coercion) { + Rex.Op.Subquery.Coercion.SCALAR -> visitRexOpSubqueryScalar(subquery, select.constructor.type) + Rex.Op.Subquery.Coercion.ROW -> visitRexOpSubqueryRow(subquery, select.constructor.type) + } + } + + /** + * Calculate output type of a row-value subquery. + */ + private fun visitRexOpSubqueryRow(subquery: Rex.Op.Subquery, cons: StaticType): Rex { + if (cons !is StructType) { + return rexErr("Subquery with non-SQL SELECT cannot be coerced to a row-value expression. Found constructor type: $cons") + } + // Do a simple cardinality check for the moment. + // TODO we can only check cardinality if we know we are in a a comparison operator. + // val n = coercion.columns.size + // val m = cons.fields.size + // if (n != m) { + // return rexErr("Cannot coercion subquery with $m attributes to a row-value-expression with $n attributes") + // } + // If we made it this far, then we can coerce this subquery to the desired complex value + val type = StaticType.LIST + val op = subquery + return rex(type, op) } - override fun visitRexOpCollToScalarSubquery(node: Rex.Op.CollToScalar.Subquery, ctx: StaticType?): Rex { - TODO("Type RexOpCollToScalarSubquery") + /** + * Calculate output type of a scalar subquery. + */ + private fun visitRexOpSubqueryScalar(subquery: Rex.Op.Subquery, cons: StaticType): Rex { + if (cons !is StructType) { + return rexErr("Subquery with non-SQL SELECT cannot be coerced to a scalar. Found constructor type: $cons") + } + val n = cons.fields.size + if (n != 1) { + return rexErr("SELECT constructor with $n attributes cannot be coerced to a scalar. Found constructor type: $cons") + } + // If we made it this far, then we can coerce this subquery to a scalar + val type = cons.fields.first().value + val op = subquery + return rex(type, op) } override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Rex { @@ -819,8 +859,12 @@ internal class PlanTyper( ) possibleOutputTypes.add(StaticType.MISSING) } - is NullType -> { return StaticType.NULL } - else -> { return StaticType.MISSING } + is NullType -> { + return StaticType.NULL + } + else -> { + return StaticType.MISSING + } } } uniqueAttrs = when { @@ -879,8 +923,13 @@ internal class PlanTyper( return buildArgumentPermutations(flattenedArgs, accumulator = emptyList()) } - private fun buildArgumentPermutations(args: List>, accumulator: List): Sequence> { - if (args.isEmpty()) { return sequenceOf(accumulator) } + private fun buildArgumentPermutations( + args: List>, + accumulator: List, + ): Sequence> { + if (args.isEmpty()) { + return sequenceOf(accumulator) + } val first = args.first() val rest = when (args.size) { 1 -> emptyList() diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/subqueries/S.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/subqueries/S.ion new file mode 100644 index 0000000000..6491b56119 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/subqueries/S.ion @@ -0,0 +1,21 @@ +{ + type: "list", + items: { + type: "struct", + constraints: [ closed, unique ], + fields: [ + { + name: "a", + type: "int32" + }, + { + name: "b", + type: "int32" + }, + { + name: "c", + type: "int32" + }, + ] + } +} diff --git a/partiql-planner/src/testFixtures/resources/catalogs/default/subqueries/T.ion b/partiql-planner/src/testFixtures/resources/catalogs/default/subqueries/T.ion new file mode 100644 index 0000000000..9782f53166 --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/catalogs/default/subqueries/T.ion @@ -0,0 +1,21 @@ +{ + type: "list", + items: { + type: "struct", + constraints: [ closed, unique ], + fields: [ + { + name: "x", + type: "int32" + }, + { + name: "y", + type: "int32" + }, + { + name: "z", + type: "int32" + }, + ] + } +} diff --git a/partiql-planner/src/testFixtures/resources/inputs/subquery/non_correlated.sql b/partiql-planner/src/testFixtures/resources/inputs/subquery/non_correlated.sql new file mode 100644 index 0000000000..3b5e05a20f --- /dev/null +++ b/partiql-planner/src/testFixtures/resources/inputs/subquery/non_correlated.sql @@ -0,0 +1,15 @@ +--#[subquery-00] +SELECT x +FROM T +WHERE x IN (SELECT a FROM S); + +--#[subquery-01] +SELECT x +FROM T +WHERE x > (SELECT MAX(a) FROM S); + +--#[subquery-02] +SELECT t.*, s.* +FROM T AS t + JOIN (SELECT * FROM S) AS s + ON t.x = s.a;