From d899994897873269b977c3179afc3e92c71ac337 Mon Sep 17 00:00:00 2001 From: Bruno Volpato Date: Tue, 13 Feb 2024 16:38:28 -0500 Subject: [PATCH 1/2] fix: patch AggregateFunctionConverter to accept Calcite function instances --- .../substrait/isthmus/AggregateFunctions.java | 20 +++++++ .../AggregateFunctionConverter.java | 30 +++++++--- .../isthmus/expression/FunctionConverter.java | 4 ++ .../AggregateFunctionConverterTest.java | 59 +++++++++++++++++++ 4 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 39c4f3cb7..16c6e6340 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import java.util.Optional; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; @@ -22,6 +23,25 @@ public class AggregateFunctions { public static SqlAggFunction SUM = new SubstraitSumAggFunction(); public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction(); + /** + * Utility class to possibly convert the SqlAggFunction from the native Calcite implementation to + * the Substrait subclasses present here, in case they have definitions. + */ + public static Optional getSubstraitAggVariant(SqlAggFunction aggFunction) { + if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) { + return Optional.of(AggregateFunctions.SUM0); + } else if (aggFunction instanceof SqlMinMaxAggFunction fun) { + return Optional.of( + fun.getKind() == SqlKind.MIN ? 
AggregateFunctions.MIN : AggregateFunctions.MAX); + } else if (aggFunction instanceof SqlAvgAggFunction) { + return Optional.of(AggregateFunctions.AVG); + } else if (aggFunction instanceof SqlSumAggFunction) { + return Optional.of(AggregateFunctions.SUM); + } else { + return Optional.empty(); + } + } + /** Extension of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type */ private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction { public SubstraitSqlMinMaxAggFunction(SqlKind kind) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 59068010a..38b4510a7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.AggregateFunctions; import io.substrait.isthmus.SubstraitRelVisitor; import io.substrait.isthmus.TypeConverter; import io.substrait.type.Type; @@ -80,13 +81,7 @@ public Optional convert( AggregateCall call, Function topLevelConverter) { - // replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT - // before converting into substrait function - SqlAggFunction aggFunction = call.getAggregation(); - if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) { - aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; - } - FunctionFinder m = signatures.get(aggFunction); + var m = getFunctionFinder(call); if (m == null) { return Optional.empty(); } @@ -98,6 +93,27 @@ public Optional convert( return m.attemptMatch(wrapped, topLevelConverter); } + protected FunctionConverter< + 
SimpleExtension.AggregateFunctionVariant, + AggregateFunctionInvocation, + WrappedAggregateCall> + .FunctionFinder + getFunctionFinder(AggregateCall call) { + // replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT + // before converting into substrait function + SqlAggFunction aggFunction = call.getAggregation(); + if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) { + aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; + } + + // Substrait has replaced those function classes with their own counterparts (which are + // subclasses of the Calcite ones), but some Calcite rules might still use the original + // functions during optimization (for example, at AggregateExpandDistinctAggregatesRule) + SqlAggFunction lookupFunction = + AggregateFunctions.getSubstraitAggVariant(aggFunction).orElse(aggFunction); + return signatures.get(lookupFunction); + } + static class WrappedAggregateCall implements GenericCall { private final AggregateCall call; private final RelNode input; diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index fcf504128..e0be22ffc 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -354,6 +354,10 @@ public Optional attemptMatch(C call, Function topLevelCo } return Optional.empty(); } + + protected String getName() { + return name; + } } public interface GenericCall { diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java new file mode 100644 index 000000000..b565e7bf4 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java @@ -0,0 +1,59 @@ 
+package io.substrait.isthmus.expression; + +import static org.junit.jupiter.api.Assertions.*; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.RelCreator; +import io.substrait.isthmus.TypeConverter; +import java.io.IOException; +import java.util.List; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; +import org.apache.calcite.sql.type.SqlTypeName; +import org.junit.jupiter.api.Test; + +public class AggregateFunctionConverterTest { + + protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION; + + static { + SimpleExtension.ExtensionCollection defaults; + try { + defaults = SimpleExtension.loadDefaults(); + } catch (IOException e) { + throw new RuntimeException("Failure while loading defaults.", e); + } + + EXTENSION_COLLECTION = defaults; + } + + final SubstraitBuilder b = new SubstraitBuilder(EXTENSION_COLLECTION); + + @Test + public void testFunctionFinderMatch() { + + RelCreator relCreator = new RelCreator(); + RelDataTypeFactory typeFactory = relCreator.typeFactory(); + + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + EXTENSION_COLLECTION.aggregateFunctions(), + List.of(), + typeFactory, + TypeConverter.DEFAULT); + + var functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlSumEmptyIsZeroAggFunction(), + true, + List.of(1), + 0, + typeFactory.createSqlType(SqlTypeName.VARCHAR), + null)); + assertNotNull(functionFinder); + assertEquals("sum0", functionFinder.getName()); + } +} From 09b6b2f433fdef122460b693176a3c439ae17b81 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Wed, 21 Feb 2024 15:56:32 -0800 Subject: [PATCH 2/2] refactor: pr suggestions --- .../substrait/isthmus/AggregateFunctions.java | 22 +++++--- .../AggregateFunctionConverter.java | 14 ++--- 
.../isthmus/expression/FunctionConverter.java | 4 ++ .../isthmus/OptimizerIntegrationTest.java | 51 +++++++++++++++++++ .../io/substrait/isthmus/PlanTestBase.java | 2 +- .../AggregateFunctionConverterTest.java | 36 +++---------- 6 files changed, 82 insertions(+), 47 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 16c6e6340..6a8d84d88 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -24,19 +24,29 @@ public class AggregateFunctions { public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction(); /** - * Utility class to possibly convert the SqlAggFunction from the native Calcite implementation to - * the Substrait subclasses present here, in case they have definitions. + * Some Calcite rules, like {@link + * org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default + * Calcite aggregate functions into plans. + * + *

<p>When converting these Calcite plans to Substrait, we need to convert the default Calcite + * aggregate calls to the Substrait specific variants. + * + * <p>
This function attempts to convert the given {@code aggFunction} to its Substrait equivalent + * + * @param aggFunction the {@link SqlAggFunction} to convert to a Substrait specific variant + * @return an optional containing the Substrait equivalent of the given {@code aggFunction} if + * conversion was needed, empty otherwise. */ - public static Optional getSubstraitAggVariant(SqlAggFunction aggFunction) { - if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) { - return Optional.of(AggregateFunctions.SUM0); - } else if (aggFunction instanceof SqlMinMaxAggFunction fun) { + public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) { + if (aggFunction instanceof SqlMinMaxAggFunction fun) { return Optional.of( fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX); } else if (aggFunction instanceof SqlAvgAggFunction) { return Optional.of(AggregateFunctions.AVG); } else if (aggFunction instanceof SqlSumAggFunction) { return Optional.of(AggregateFunctions.SUM); + } else if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) { + return Optional.of(AggregateFunctions.SUM0); } else { return Optional.empty(); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 38b4510a7..a62f7c0e7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -93,12 +93,7 @@ public Optional convert( return m.attemptMatch(wrapped, topLevelConverter); } - protected FunctionConverter< - SimpleExtension.AggregateFunctionVariant, - AggregateFunctionInvocation, - WrappedAggregateCall> - .FunctionFinder - getFunctionFinder(AggregateCall call) { + protected FunctionFinder getFunctionFinder(AggregateCall call) { // replace COUNT() + distinct == true and approximate == true with 
APPROX_COUNT_DISTINCT // before converting into substrait function SqlAggFunction aggFunction = call.getAggregation(); @@ -106,11 +101,10 @@ public Optional convert( aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; } - // Substrait has replaced those function classes with their own counterparts (which are - // subclasses of the Calcite ones), but some Calcite rules might still use the original - // functions during optimization (for example, at AggregateExpandDistinctAggregatesRule) SqlAggFunction lookupFunction = - AggregateFunctions.getSubstraitAggVariant(aggFunction).orElse(aggFunction); + // Replace default Calcite aggregate calls with Substrait specific variants. + // See toSubstraitAggVariant for more details. + AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction); return signatures.get(lookupFunction); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index e0be22ffc..10b3dd1df 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -358,6 +358,10 @@ public Optional attemptMatch(C call, Function topLevelCo protected String getName() { return name; } + + public SqlOperator getOperator() { + return operator; + } } public interface GenericCall { diff --git a/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java new file mode 100644 index 000000000..8545bf21a --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/OptimizerIntegrationTest.java @@ -0,0 +1,51 @@ +package io.substrait.isthmus; + +import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; + 
+import java.io.IOException; +import java.util.List; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgram; +import org.apache.calcite.plan.hep.HepProgramBuilder; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; + +public class OptimizerIntegrationTest extends PlanTestBase { + + @Test + void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException { + var query = + "select O_CUSTKEY, count(distinct O_ORDERKEY), count(*) from orders group by O_CUSTKEY"; + // verify that the query works generally + assertFullRoundTrip(query); + + SqlToSubstrait sqlConverter = new SqlToSubstrait(); + List relRoots = sqlConverter.sqlToRelNode(query, tpchSchemaCreateStatements()); + assertEquals(1, relRoots.size()); + RelRoot planRoot = relRoots.get(0); + RelNode originalPlan = planRoot.rel; + + // Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule. + // This will introduce a SqlSumEmptyIsZeroAggFunction to the plan. + // This function does not have a mapping to Substrait. + // SubstraitSumEmptyIsZeroAggFunction is the variant which has a mapping. 
+ // See io.substrait.isthmus.AggregateFunctions for details + HepProgram program = + new HepProgramBuilder() + .addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) + .build(); + HepPlanner planner = new HepPlanner(program); + planner.setRoot(originalPlan); + var newPlan = planner.findBestExp(); + + assertDoesNotThrow( + () -> + // Conversion of the new plan should succeed + SubstraitRelVisitor.convert(RelRoot.of(newPlan, planRoot.kind), EXTENSION_COLLECTION)); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index c0b725645..657c9eeea 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -32,7 +32,7 @@ import org.junit.jupiter.api.Assertions; public class PlanTestBase { - final SimpleExtension.ExtensionCollection extensions; + protected final SimpleExtension.ExtensionCollection extensions; { try { diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java index b565e7bf4..f187f9ba6 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java @@ -2,47 +2,22 @@ import static org.junit.jupiter.api.Assertions.*; -import io.substrait.dsl.SubstraitBuilder; -import io.substrait.extension.SimpleExtension; -import io.substrait.isthmus.RelCreator; +import io.substrait.isthmus.AggregateFunctions; +import io.substrait.isthmus.PlanTestBase; import io.substrait.isthmus.TypeConverter; -import java.io.IOException; import java.util.List; import org.apache.calcite.rel.core.AggregateCall; -import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; import 
org.apache.calcite.sql.type.SqlTypeName; import org.junit.jupiter.api.Test; -public class AggregateFunctionConverterTest { - - protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION; - - static { - SimpleExtension.ExtensionCollection defaults; - try { - defaults = SimpleExtension.loadDefaults(); - } catch (IOException e) { - throw new RuntimeException("Failure while loading defaults.", e); - } - - EXTENSION_COLLECTION = defaults; - } - - final SubstraitBuilder b = new SubstraitBuilder(EXTENSION_COLLECTION); +public class AggregateFunctionConverterTest extends PlanTestBase { @Test - public void testFunctionFinderMatch() { - - RelCreator relCreator = new RelCreator(); - RelDataTypeFactory typeFactory = relCreator.typeFactory(); - + void testFunctionFinderMatch() { AggregateFunctionConverter converter = new AggregateFunctionConverter( - EXTENSION_COLLECTION.aggregateFunctions(), - List.of(), - typeFactory, - TypeConverter.DEFAULT); + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); var functionFinder = converter.getFunctionFinder( @@ -55,5 +30,6 @@ public void testFunctionFinderMatch() { null)); assertNotNull(functionFinder); assertEquals("sum0", functionFinder.getName()); + assertEquals(AggregateFunctions.SUM0, functionFinder.getOperator()); } }