diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java
new file mode 100644
index 000000000..9cfbf9db5
--- /dev/null
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SortFieldConverter.java
@@ -0,0 +1,47 @@
+package io.substrait.isthmus.expression;
+
+import io.substrait.expression.Expression;
+import org.apache.calcite.rel.RelFieldCollation;
+import org.apache.calcite.rex.RexFieldCollation;
+
+/** Static helper that converts Calcite field collations into Substrait sort fields. */
+public final class SortFieldConverter {
+
+  /** Not instantiable: this class only hosts static helpers. */
+  private SortFieldConverter() {}
+
+  /**
+   * Converts a {@link RexFieldCollation} to a {@link Expression.SortField}.
+   *
+   * <p>The collation's {@code left} operand is converted with the supplied converter and becomes
+   * the sort expression. Nulls sort last only when the collation reports {@code
+   * NullDirection.LAST}; every other null direction is mapped to the {@code NULLS_FIRST} variant.
+   *
+   * @param rexFieldCollation the Calcite collation to convert
+   * @param rexExpressionConverter visitor applied to the collation's sort expression
+   * @return the sort field combining the converted expression and the sort direction
+   * @throws IllegalArgumentException if the direction is neither ASCENDING nor DESCENDING
+   */
+  public static Expression.SortField toSortField(
+      RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) {
+    var expr = rexFieldCollation.left.accept(rexExpressionConverter);
+    var rexDirection = rexFieldCollation.getDirection();
+    Expression.SortDirection direction =
+        switch (rexDirection) {
+          case ASCENDING -> rexFieldCollation.getNullDirection()
+                  == RelFieldCollation.NullDirection.LAST
+              ? Expression.SortDirection.ASC_NULLS_LAST
+              : Expression.SortDirection.ASC_NULLS_FIRST;
+          case DESCENDING -> rexFieldCollation.getNullDirection()
+                  == RelFieldCollation.NullDirection.LAST
+              ? Expression.SortDirection.DESC_NULLS_LAST
+              : Expression.SortDirection.DESC_NULLS_FIRST;
+          default -> throw new IllegalArgumentException(
+              String.format(
+                  "Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!",
+                  rexDirection));
+        };
+
+    return Expression.SortField.builder().expr(expr).direction(direction).build();
+  }
+}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java
new file mode 100644
index 000000000..3208905ca
--- /dev/null
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowBoundConverter.java
@@ -0,0 +1,62 @@
+package io.substrait.isthmus.expression;
+
+import io.substrait.expression.WindowBound;
+import java.math.BigDecimal;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexWindowBound;
+import org.apache.calcite.sql.type.SqlTypeName;
+
+/** Static helper that converts Calcite window frame bounds into Substrait window bounds. */
+public final class WindowBoundConverter {
+
+  /** Not instantiable: this class only hosts static helpers. */
+  private WindowBoundConverter() {}
+
+  /**
+   * Converts a {@link RexWindowBound} to a {@link WindowBound}.
+   *
+   * <p>An offset is only accepted when it is an exact-numeric literal whose value is a whole
+   * number that fits in a {@code long}; it is never silently truncated.
+   *
+   * @param rexWindowBound the Calcite window bound to convert
+   * @return the equivalent Substrait window bound
+   * @throws IllegalArgumentException if the offset is not an exact-numeric literal, or cannot be
+   *     represented as a {@code long} without loss
+   * @throws IllegalStateException if an offset bound is neither PRECEDING nor FOLLOWING
+   */
+  public static WindowBound toWindowBound(RexWindowBound rexWindowBound) {
+    if (rexWindowBound.isCurrentRow()) {
+      return WindowBound.CURRENT_ROW;
+    }
+    if (rexWindowBound.isUnbounded()) {
+      return WindowBound.UNBOUNDED;
+    }
+    if (rexWindowBound.getOffset() instanceof RexLiteral literal
+        && SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) {
+      long offset = toExactLong((BigDecimal) literal.getValue4());
+      if (rexWindowBound.isPreceding()) {
+        return WindowBound.Preceding.of(offset);
+      }
+      if (rexWindowBound.isFollowing()) {
+        return WindowBound.Following.of(offset);
+      }
+      throw new IllegalStateException(
+          "window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING");
+    }
+    throw new IllegalArgumentException(
+        String.format(
+            "substrait only supports integer window offsets. Received: %s",
+            rexWindowBound.getOffset().getKind()));
+  }
+
+  /** Narrows a literal offset to {@code long}, rejecting fractional or out-of-range values. */
+  private static long toExactLong(BigDecimal offset) {
+    try {
+      return offset.longValueExact();
+    } catch (ArithmeticException e) {
+      throw new IllegalArgumentException(
+          String.format("substrait only supports integer window offsets. Received: %s", offset),
+          e);
+    }
+  }
+}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java
index 784db3876..97ebf202a 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowFunctionConverter.java
@@ -1,28 +1,27 @@
 package io.substrait.isthmus.expression;
 
+import static io.substrait.isthmus.expression.SortFieldConverter.toSortField;
+import static io.substrait.isthmus.expression.WindowBoundConverter.toWindowBound;
+
 import com.google.common.collect.ImmutableList;
 import io.substrait.expression.Expression;
 import io.substrait.expression.ExpressionCreator;
 import io.substrait.expression.FunctionArg;
 import io.substrait.expression.WindowBound;
 import io.substrait.extension.SimpleExtension;
+import io.substrait.isthmus.AggregateFunctions;
 import io.substrait.type.Type;
-import java.math.BigDecimal;
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
 import java.util.function.Function;
 import java.util.stream.Stream;
-import org.apache.calcite.rel.RelFieldCollation;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
-import org.apache.calcite.rex.RexFieldCollation;
-import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexOver;
 import org.apache.calcite.rex.RexWindow;
-import org.apache.calcite.rex.RexWindowBound;
-import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.SqlAggFunction;
 
 public class WindowFunctionConverter
     extends FunctionConverter<
@@ -89,7 +88,10 @@ public Optional<Expression.WindowFunctionInvocation> convert(
       Function<RexNode, Expression> topLevelConverter,
       RexExpressionConverter rexExpressionConverter) {
     var aggFunction = over.getAggOperator();
-    FunctionFinder m = signatures.get(aggFunction);
+
+    SqlAggFunction lookupFunction =
+        AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction);
+    FunctionFinder m = signatures.get(lookupFunction);
     if (m == null) {
       return Optional.empty();
     }
@@ -101,55 +103,6 @@ public Optional<Expression.WindowFunctionInvocation> convert(
     return m.attemptMatch(wrapped, topLevelConverter);
   }
 
-  private WindowBound toWindowBound(RexWindowBound rexWindowBound) {
-    if (rexWindowBound.isCurrentRow()) {
-      return WindowBound.CURRENT_ROW;
-    }
-    if (rexWindowBound.isUnbounded()) {
-      return WindowBound.UNBOUNDED;
-    } else {
-      if (rexWindowBound.getOffset() instanceof RexLiteral literal
-          && SqlTypeName.EXACT_TYPES.contains(literal.getTypeName())) {
-        BigDecimal offset = (BigDecimal) literal.getValue4();
-        if (rexWindowBound.isPreceding()) {
-          return WindowBound.Preceding.of(offset.longValue());
-        }
-        if (rexWindowBound.isFollowing()) {
-          return WindowBound.Following.of(offset.longValue());
-        }
-        throw new IllegalStateException(
-            "window bound was none of CURRENT ROW, UNBOUNDED, PRECEDING or FOLLOWING");
-      }
-      throw new IllegalArgumentException(
-          String.format(
-              "substrait only supports integer window offsets. Received: %s",
-              rexWindowBound.getOffset().getKind()));
-    }
-  }
-
-  private Expression.SortField toSortField(
-      RexFieldCollation rexFieldCollation, RexExpressionConverter rexExpressionConverter) {
-    var expr = rexFieldCollation.left.accept(rexExpressionConverter);
-    var rexDirection = rexFieldCollation.getDirection();
-    Expression.SortDirection direction =
-        switch (rexDirection) {
-          case ASCENDING -> rexFieldCollation.getNullDirection()
-                  == RelFieldCollation.NullDirection.LAST
-              ? Expression.SortDirection.ASC_NULLS_LAST
-              : Expression.SortDirection.ASC_NULLS_FIRST;
-          case DESCENDING -> rexFieldCollation.getNullDirection()
-                  == RelFieldCollation.NullDirection.LAST
-              ? Expression.SortDirection.DESC_NULLS_LAST
-              : Expression.SortDirection.DESC_NULLS_FIRST;
-          default -> throw new IllegalArgumentException(
-              String.format(
-                  "Unexpected RelFieldCollation.Direction:%s enum at the RexFieldCollation!",
-                  rexDirection));
-        };
-
-    return Expression.SortField.builder().expr(expr).direction(direction).build();
-  }
-
   static class WrappedWindowCall implements FunctionConverter.GenericCall {
     private final RexOver over;
     private final RexExpressionConverter rexExpressionConverter;
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java
new file mode 100644
index 000000000..1c00362f0
--- /dev/null
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/WindowRelFunctionConverter.java
@@ -0,0 +1,167 @@
+package io.substrait.isthmus.expression;
+
+import static io.substrait.isthmus.expression.WindowBoundConverter.toWindowBound;
+
+import com.google.common.collect.ImmutableList;
+import io.substrait.expression.Expression;
+import io.substrait.expression.ExpressionCreator;
+import io.substrait.expression.FunctionArg;
+import io.substrait.expression.WindowBound;
+import io.substrait.extension.SimpleExtension;
+import io.substrait.isthmus.AggregateFunctions;
+import io.substrait.relation.ConsistentPartitionWindow;
+import io.substrait.type.Type;
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Stream;
+import org.apache.calcite.rel.core.Window;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexWindowBound;
+import org.apache.calcite.sql.SqlAggFunction;
+
+/**
+ * Converts the aggregate calls of a Calcite {@link Window} relation into Substrait {@link
+ * ConsistentPartitionWindow.WindowRelFunctionInvocation}s, matching them against the window
+ * function signatures in {@link FunctionMappings#WINDOW_SIGS}.
+ */
+public class WindowRelFunctionConverter
+    extends FunctionConverter<
+        SimpleExtension.WindowFunctionVariant,
+        ConsistentPartitionWindow.WindowRelFunctionInvocation,
+        WindowRelFunctionConverter.WrappedWindowRelCall> {
+
+  @Override
+  protected ImmutableList<FunctionMappings.Sig> getSigs() {
+    return FunctionMappings.WINDOW_SIGS;
+  }
+
+  /**
+   * Creates a converter over the given window function variants.
+   *
+   * @param functions the window function variants passed to the base {@link FunctionConverter}
+   * @param typeFactory the Calcite type factory passed to the base {@link FunctionConverter}
+   */
+  public WindowRelFunctionConverter(
+      List<SimpleExtension.WindowFunctionVariant> functions, RelDataTypeFactory typeFactory) {
+    super(functions, typeFactory);
+  }
+
+  /**
+   * Builds the Substrait invocation for a matched window function, translating the call's
+   * DISTINCT flag, its ROWS/RANGE frame mode and both frame bounds.
+   */
+  @Override
+  protected ConsistentPartitionWindow.WindowRelFunctionInvocation generateBinding(
+      WrappedWindowRelCall call,
+      SimpleExtension.WindowFunctionVariant function,
+      List<FunctionArg> arguments,
+      Type outputType) {
+    Window.RexWinAggCall over = call.getWinAggCall();
+
+    Expression.AggregationInvocation invocation =
+        over.distinct
+            ? Expression.AggregationInvocation.DISTINCT
+            : Expression.AggregationInvocation.ALL;
+
+    // Calcite only supports ROW or RANGE mode
+    Expression.WindowBoundsType boundsType =
+        call.isRows() ? Expression.WindowBoundsType.ROWS : Expression.WindowBoundsType.RANGE;
+    WindowBound lowerBound = toWindowBound(call.getLowerBound());
+    WindowBound upperBound = toWindowBound(call.getUpperBound());
+
+    return ExpressionCreator.windowRelFunction(
+        function,
+        outputType,
+        Expression.AggregationPhase.INITIAL_TO_RESULT,
+        invocation,
+        boundsType,
+        lowerBound,
+        upperBound,
+        arguments);
+  }
+
+  /**
+   * Attempts to convert a window aggregate call into a Substrait window function invocation.
+   *
+   * <p>Signatures are looked up by the operator's Substrait variant when {@link
+   * AggregateFunctions#toSubstraitAggVariant} returns one, and by the operator itself otherwise.
+   *
+   * @param winAggCall the window aggregate call to convert
+   * @param lowerBound lower bound of the window frame
+   * @param upperBound upper bound of the window frame
+   * @param isRows {@code true} for a ROWS frame, {@code false} for a RANGE frame
+   * @param topLevelConverter function from {@link RexNode} to {@link Expression}, passed on to
+   *     the signature matcher
+   * @return the result of matching the call against the registered signatures; empty if none is
+   *     registered for the function or the operand count is not allowed
+   */
+  public Optional<ConsistentPartitionWindow.WindowRelFunctionInvocation> convert(
+      Window.RexWinAggCall winAggCall,
+      RexWindowBound lowerBound,
+      RexWindowBound upperBound,
+      boolean isRows,
+      Function<RexNode, Expression> topLevelConverter) {
+    var aggFunction = (SqlAggFunction) winAggCall.getOperator();
+
+    SqlAggFunction lookupFunction =
+        AggregateFunctions.toSubstraitAggVariant(aggFunction).orElse(aggFunction);
+    FunctionFinder m = signatures.get(lookupFunction);
+    if (m == null) {
+      return Optional.empty();
+    }
+    if (!m.allowedArgCount(winAggCall.getOperands().size())) {
+      return Optional.empty();
+    }
+
+    var wrapped = new WrappedWindowRelCall(winAggCall, lowerBound, upperBound, isRows);
+    return m.attemptMatch(wrapped, topLevelConverter);
+  }
+
+  /** Bundles a window aggregate call with its frame bounds and ROWS/RANGE flag for matching. */
+  static class WrappedWindowRelCall implements GenericCall {
+    private final Window.RexWinAggCall winAggCall;
+    private final RexWindowBound lowerBound;
+    private final RexWindowBound upperBound;
+    private final boolean isRows;
+
+    private WrappedWindowRelCall(
+        Window.RexWinAggCall winAggCall,
+        RexWindowBound lowerBound,
+        RexWindowBound upperBound,
+        boolean isRows) {
+      this.winAggCall = winAggCall;
+      this.lowerBound = lowerBound;
+      this.upperBound = upperBound;
+      this.isRows = isRows;
+    }
+
+    @Override
+    public Stream<RexNode> getOperands() {
+      return winAggCall.getOperands().stream();
+    }
+
+    @Override
+    public RelDataType getType() {
+      return winAggCall.getType();
+    }
+
+    public Window.RexWinAggCall getWinAggCall() {
+      return winAggCall;
+    }
+
+    public RexWindowBound getLowerBound() {
+      return lowerBound;
+    }
+
+    public RexWindowBound getUpperBound() {
+      return upperBound;
+    }
+
+    public boolean isRows() {
+      return isRows;
+    }
+  }
+}