-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(isthmus): allow for conversion of plans containing Calcite aggregate functions #230
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
import io.substrait.expression.ExpressionCreator; | ||
import io.substrait.expression.FunctionArg; | ||
import io.substrait.extension.SimpleExtension; | ||
import io.substrait.isthmus.AggregateFunctions; | ||
import io.substrait.isthmus.SubstraitRelVisitor; | ||
import io.substrait.isthmus.TypeConverter; | ||
import io.substrait.type.Type; | ||
|
@@ -80,13 +81,7 @@ public Optional<AggregateFunctionInvocation> convert( | |
AggregateCall call, | ||
Function<RexNode, Expression> topLevelConverter) { | ||
|
||
// replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT | ||
// before converting into substrait function | ||
SqlAggFunction aggFunction = call.getAggregation(); | ||
if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) { | ||
aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; | ||
} | ||
FunctionFinder m = signatures.get(aggFunction); | ||
var m = getFunctionFinder(call); | ||
if (m == null) { | ||
return Optional.empty(); | ||
} | ||
|
@@ -98,6 +93,21 @@ public Optional<AggregateFunctionInvocation> convert( | |
return m.attemptMatch(wrapped, topLevelConverter); | ||
} | ||
|
||
/**
 * Resolves the {@link FunctionFinder} held in {@code signatures} for the aggregation used by the
 * given call.
 *
 * <p>Two substitutions are applied before the lookup. A {@code COUNT} call that is flagged both
 * distinct and approximate is looked up as {@code APPROX_COUNT_DISTINCT}. The resulting function
 * is then replaced by the variant returned from
 * {@code AggregateFunctions.toSubstraitAggVariant} when that call yields a value; otherwise the
 * function is used unchanged.
 *
 * @param call the Calcite aggregate call whose aggregation function is looked up
 * @return whatever {@code signatures.get(...)} yields for the resolved function; {@code convert}
 *     treats a {@code null} result as "no match" and returns {@code Optional.empty()}
 */
protected FunctionFinder getFunctionFinder(AggregateCall call) { | ||
// replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT | ||
// before converting into substrait function | ||
SqlAggFunction aggFunction = call.getAggregation(); | ||
if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) { | ||
aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; | ||
} | ||
|
||
SqlAggFunction lookupFunction = | ||
// Replace default Calcite aggregate calls with Substrait specific variants. | ||
// See toSubstraitAggVariant for more details. | ||
AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added more detailed docs to the function definition |
||
return signatures.get(lookupFunction); | ||
} | ||
|
||
static class WrappedAggregateCall implements GenericCall { | ||
private final AggregateCall call; | ||
private final RelNode input; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package io.substrait.isthmus; | ||
|
||
import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; | ||
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; | ||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import org.apache.calcite.plan.hep.HepPlanner; | ||
import org.apache.calcite.plan.hep.HepProgram; | ||
import org.apache.calcite.plan.hep.HepProgramBuilder; | ||
import org.apache.calcite.rel.RelNode; | ||
import org.apache.calcite.rel.RelRoot; | ||
import org.apache.calcite.rel.rules.CoreRules; | ||
import org.apache.calcite.sql.parser.SqlParseException; | ||
import org.junit.jupiter.api.Test; | ||
|
||
/**
 * Checks that a relational plan rewritten by a Calcite optimizer rule can still be handed to
 * {@code SubstraitRelVisitor.convert} without an exception.
 */
public class OptimizerIntegrationTest extends PlanTestBase { | ||
|
||
/**
 * Parses a query containing both {@code count(distinct ...)} and {@code count(*)} against the
 * schema from {@code tpchSchemaCreateStatements()}, rewrites the resulting plan with a
 * {@code HepPlanner} running only {@code AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN}, and
 * asserts that converting the rewritten plan does not throw.
 *
 * <p>The unmodified query is first run through {@code assertFullRoundTrip} so that a failure
 * later in the test can be attributed to the rule-rewritten plan rather than to the query
 * itself.
 */
@Test | ||
void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test captures how the default Calcite aggregate calls can be introduced by rules and verifies the conversion executes without failing even when this occur (thanks to your changes ✨ ) |
||
var query = | ||
"select O_CUSTKEY, count(distinct O_ORDERKEY), count(*) from orders group by O_CUSTKEY"; | ||
// verify that the query works generally | ||
assertFullRoundTrip(query); | ||
|
||
SqlToSubstrait sqlConverter = new SqlToSubstrait(); | ||
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(query, tpchSchemaCreateStatements()); | ||
assertEquals(1, relRoots.size()); | ||
RelRoot planRoot = relRoots.get(0); | ||
RelNode originalPlan = planRoot.rel; | ||
|
||
// Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule. | ||
// This will introduce a SqlSumEmptyIsZeroAggFunction to the plan. | ||
// This function does not have a mapping to Substrait. | ||
// SubstraitSumEmptyIsZeroAggFunction is the variant which has a mapping. | ||
// See io.substrait.isthmus.AggregateFunctions for details | ||
HepProgram program = | ||
new HepProgramBuilder() | ||
.addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN) | ||
.build(); | ||
HepPlanner planner = new HepPlanner(program); | ||
planner.setRoot(originalPlan); | ||
var newPlan = planner.findBestExp(); | ||
|
||
assertDoesNotThrow( | ||
() -> | ||
// Conversion of the new plan should succeed | ||
SubstraitRelVisitor.convert(RelRoot.of(newPlan, planRoot.kind), EXTENSION_COLLECTION)); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package io.substrait.isthmus.expression; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
import io.substrait.isthmus.AggregateFunctions; | ||
import io.substrait.isthmus.PlanTestBase; | ||
import io.substrait.isthmus.TypeConverter; | ||
import java.util.List; | ||
import org.apache.calcite.rel.core.AggregateCall; | ||
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; | ||
import org.apache.calcite.sql.type.SqlTypeName; | ||
import org.junit.jupiter.api.Test; | ||
|
||
/**
 * Unit test for the function lookup performed by
 * {@code AggregateFunctionConverter.getFunctionFinder}.
 */
public class AggregateFunctionConverterTest extends PlanTestBase { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Re-using code from |
||
|
||
/**
 * Builds an {@code AggregateCall} on Calcite's {@code SqlSumEmptyIsZeroAggFunction} and checks
 * that {@code getFunctionFinder} returns a non-null finder named {@code "sum0"} whose operator
 * is {@code AggregateFunctions.SUM0}.
 *
 * <p>The operator is asserted in addition to the name, so the test pins which operator instance
 * the finder is bound to and not just its label.
 */
@Test | ||
void testFunctionFinderMatch() { | ||
AggregateFunctionConverter converter = | ||
new AggregateFunctionConverter( | ||
extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); | ||
|
||
var functionFinder = | ||
converter.getFunctionFinder( | ||
AggregateCall.create( | ||
new SqlSumEmptyIsZeroAggFunction(), | ||
true, | ||
List.of(1), | ||
0, | ||
typeFactory.createSqlType(SqlTypeName.VARCHAR), | ||
null)); | ||
assertNotNull(functionFinder); | ||
assertEquals("sum0", functionFinder.getName()); | ||
assertEquals(AggregateFunctions.SUM0, functionFinder.getOperator()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comparing the name is insufficient as both |
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Re-ordered the checks to match the function definitions above.