Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(isthmus): allow for conversion of plans containing Calcite aggregate functions #230

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.substrait.isthmus;

import java.util.Optional;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
Expand All @@ -22,6 +23,35 @@ public class AggregateFunctions {
public static SqlAggFunction SUM = new SubstraitSumAggFunction();
public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction();

/**
* Some Calcite rules, like {@link
* org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default
* Calcite aggregate functions into plans.
*
* <p>When converting these Calcite plans to Substrait, we need to convert the default Calcite
* aggregate calls to the Substrait specific variants.
*
* <p>This function attempts to convert the given {@code aggFunction} to its Substrait equivalent.
* Only instances of {@code SqlMinMaxAggFunction}, {@code SqlAvgAggFunction}, {@code
* SqlSumAggFunction} and {@code SqlSumEmptyIsZeroAggFunction} are mapped; the checks run in that
* order and the first matching one decides the result.
*
* @param aggFunction the {@link SqlAggFunction} to convert to a Substrait specific variant
* @return an optional containing the Substrait equivalent of the given {@code aggFunction} if
* it is an instance of one of the handled classes, empty otherwise.
*/
public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggFunction) {
if (aggFunction instanceof SqlMinMaxAggFunction fun) {
// A single Calcite class covers both MIN and MAX, so the kind selects the variant.
// Any kind other than MIN is mapped to MAX.
return Optional.of(
fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX);
} else if (aggFunction instanceof SqlAvgAggFunction) {
return Optional.of(AggregateFunctions.AVG);
} else if (aggFunction instanceof SqlSumAggFunction) {
return Optional.of(AggregateFunctions.SUM);
} else if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) {
return Optional.of(AggregateFunctions.SUM0);
} else {
// No Substrait specific variant is defined here for this function.
return Optional.empty();
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-ordered the checks to match the function definitions above.

}

/** Extension of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type */
private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction {
public SubstraitSqlMinMaxAggFunction(SqlKind kind) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.SubstraitRelVisitor;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.Type;
Expand Down Expand Up @@ -80,13 +81,7 @@ public Optional<AggregateFunctionInvocation> convert(
AggregateCall call,
Function<RexNode, Expression> topLevelConverter) {

// replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT
// before converting into substrait function
SqlAggFunction aggFunction = call.getAggregation();
if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) {
aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT;
}
FunctionFinder m = signatures.get(aggFunction);
var m = getFunctionFinder(call);
if (m == null) {
return Optional.empty();
}
Expand All @@ -98,6 +93,21 @@ public Optional<AggregateFunctionInvocation> convert(
return m.attemptMatch(wrapped, topLevelConverter);
}

/**
* Resolves the {@code FunctionFinder} registered in {@code signatures} for the aggregate
* function of the given call.
*
* <p>Two substitutions are applied before the lookup: a {@code COUNT} call that is both distinct
* and approximate is treated as {@code APPROX_COUNT_DISTINCT}, and the resulting function is
* replaced by the value returned from {@code AggregateFunctions.toSubstraitAggVariant} when that
* optional is present.
*
* @param call the aggregate call whose function should be resolved
* @return the result of {@code signatures.get} for the resolved function; {@code convert} treats
* a {@code null} result as "no match"
*/
protected FunctionFinder getFunctionFinder(AggregateCall call) {
// replace COUNT() + distinct == true and approximate == true with APPROX_COUNT_DISTINCT
// before converting into substrait function
SqlAggFunction aggFunction = call.getAggregation();
if (aggFunction == SqlStdOperatorTable.COUNT && call.isDistinct() && call.isApproximate()) {
aggFunction = SqlStdOperatorTable.APPROX_COUNT_DISTINCT;
}

SqlAggFunction lookupFunction =
// Replace default Calcite aggregate calls with Substrait specific variants.
// See toSubstraitAggVariant for more details.
AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added more detailed docs to the function definition

return signatures.get(lookupFunction);
}

static class WrappedAggregateCall implements GenericCall {
private final AggregateCall call;
private final RelNode input;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,14 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
}
return Optional.empty();
}

/**
* Returns the value of {@code name}.
*
* @return the {@code name} value as a {@link String}
*/
protected String getName() {
return name;
}

/**
* Returns the value of {@code operator}.
*
* @return the {@code operator} value as a {@link SqlOperator}
*/
public SqlOperator getOperator() {
return operator;
}
}

public interface GenericCall {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.IOException;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;

/**
* Checks that a Calcite plan which has been rewritten by an optimizer rule can still be converted
* to Substrait.
*/
public class OptimizerIntegrationTest extends PlanTestBase {

/**
* Applies {@code AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN} to a query mixing {@code
* count(distinct ...)} and {@code count(*)}, then asserts that converting the rewritten plan does
* not throw.
*/
@Test
void conversionHandlesBuiltInSum0CallAddedByRule() throws SqlParseException, IOException {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test captures how the default Calcite aggregate calls can be introduced by rules and verifies the conversion executes without failing even when this occur (thanks to your changes ✨ )

var query =
"select O_CUSTKEY, count(distinct O_ORDERKEY), count(*) from orders group by O_CUSTKEY";
// verify that the query works generally
assertFullRoundTrip(query);

SqlToSubstrait sqlConverter = new SqlToSubstrait();
List<RelRoot> relRoots = sqlConverter.sqlToRelNode(query, tpchSchemaCreateStatements());
// The single statement above must produce exactly one plan root.
assertEquals(1, relRoots.size());
RelRoot planRoot = relRoots.get(0);
RelNode originalPlan = planRoot.rel;

// Create a program to apply the AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN rule.
// This will introduce a SqlSumEmptyIsZeroAggFunction to the plan.
// This function does not have a mapping to Substrait.
// SubstraitSumEmptyIsZeroAggFunction is the variant which has a mapping.
// See io.substrait.isthmus.AggregateFunctions for details
HepProgram program =
new HepProgramBuilder()
.addRuleInstance(CoreRules.AGGREGATE_EXPAND_DISTINCT_AGGREGATES_TO_JOIN)
.build();
HepPlanner planner = new HepPlanner(program);
planner.setRoot(originalPlan);
var newPlan = planner.findBestExp();

// The rewritten plan is wrapped in a new root that reuses the kind of the original root.
assertDoesNotThrow(
() ->
// Conversion of the new plan should succeed
SubstraitRelVisitor.convert(RelRoot.of(newPlan, planRoot.kind), EXTENSION_COLLECTION));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import org.junit.jupiter.api.Assertions;

public class PlanTestBase {
final SimpleExtension.ExtensionCollection extensions;
protected final SimpleExtension.ExtensionCollection extensions;

{
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.substrait.isthmus.expression;

import static org.junit.jupiter.api.Assertions.*;

import io.substrait.isthmus.AggregateFunctions;
import io.substrait.isthmus.PlanTestBase;
import io.substrait.isthmus.TypeConverter;
import java.util.List;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.SqlTypeName;
import org.junit.jupiter.api.Test;

/**
* Checks that {@code AggregateFunctionConverter.getFunctionFinder} resolves a default Calcite
* aggregate function to the finder of its Substrait specific variant.
*/
public class AggregateFunctionConverterTest extends PlanTestBase {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-using code from PlanTestBase.


/**
* Looks up a call built on Calcite's {@code SqlSumEmptyIsZeroAggFunction} and expects the finder
* named {@code "sum0"} whose operator is {@code AggregateFunctions.SUM0}.
*/
@Test
void testFunctionFinderMatch() {
AggregateFunctionConverter converter =
new AggregateFunctionConverter(
extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT);

// The call deliberately uses the built-in Calcite function rather than
// AggregateFunctions.SUM0, so the lookup has to map it to the Substrait variant.
var functionFinder =
converter.getFunctionFinder(
AggregateCall.create(
new SqlSumEmptyIsZeroAggFunction(),
true,
List.of(1),
0,
typeFactory.createSqlType(SqlTypeName.VARCHAR),
null));
assertNotNull(functionFinder);
assertEquals("sum0", functionFinder.getName());
// The operator is checked as well as the name.
assertEquals(AggregateFunctions.SUM0, functionFinder.getOperator());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing the name is insufficient as both SqlSumEmptyIsZeroAggFunction and SubstraitSumEmptyIsZeroAggFunction have the same name.

}
}
Loading