Skip to content

Commit

Permalink
feat(isthmus): support full expressions in SortRel (#322)
Browse files Browse the repository at this point in the history
SortRels converted to Calcite are no longer limited to only containing field references
  • Loading branch information
yassram authored Feb 14, 2025
1 parent f04fc9b commit 80f8678
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,17 +333,28 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) {
@Override
public RelNode visit(Sort sort) throws RuntimeException {
RelNode child = sort.getInput().accept(this);
List<RelFieldCollation> relFieldCollations =
List<RexNode> sortExpressions =
sort.getSortFields().stream()
.map(sortField -> toRelFieldCollation(sortField))
.map(this::directedRexNode)
.collect(java.util.stream.Collectors.toList());
if (relFieldCollations.isEmpty()) {
return relBuilder.push(child).sort(Collections.EMPTY_LIST).build();
}
RelNode node = relBuilder.push(child).sort(RelCollations.of(relFieldCollations)).build();
RelNode node = relBuilder.push(child).sort(sortExpressions).build();
return applyRemap(node, sort.getRemap());
}

/**
 * Converts a Substrait sort field into a Calcite {@link RexNode} that carries its ordering.
 *
 * <p>The sort expression is converted with {@code expressionRexConverter}; the result is then
 * wrapped using {@code relBuilder}'s {@code desc}, {@code nullsFirst} and {@code nullsLast}
 * according to the sort field's direction. Ascending directions only receive the null-ordering
 * wrapper; descending directions are wrapped in {@code desc} first.
 *
 * @param sortField the sort field providing the expression and the sort direction
 * @return the converted expression wrapped with its direction and null-ordering markers
 * @throws RuntimeException if the sort direction is {@code CLUSTERED}
 */
private RexNode directedRexNode(Expression.SortField sortField) {
  var expression = sortField.expr();
  var rexNode = expression.accept(expressionRexConverter);
  var sortDirection = sortField.direction();
  return switch (sortDirection) {
    case ASC_NULLS_FIRST -> relBuilder.nullsFirst(rexNode);
    case ASC_NULLS_LAST -> relBuilder.nullsLast(rexNode);
    case DESC_NULLS_FIRST -> relBuilder.nullsFirst(relBuilder.desc(rexNode));
    case DESC_NULLS_LAST -> relBuilder.nullsLast(relBuilder.desc(rexNode));
      // The message is a constant, so no String.format call is needed to build it.
    case CLUSTERED -> throw new RuntimeException(
        "Unexpected Expression.SortDirection: Clustered!");
  };
}

@Override
public RelNode visit(Fetch fetch) throws RuntimeException {
RelNode child = fetch.getInput().accept(this);
Expand Down
210 changes: 210 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package io.substrait.isthmus;

import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.Expression;
import io.substrait.relation.Rel;
import io.substrait.type.TypeCreator;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.List;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.externalize.RelWriterImpl;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.util.Pair;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.jupiter.api.Test;

/**
 * Tests converting Substrait sort relations whose sort fields are full expressions (not only field
 * references) into Calcite plans.
 *
 * <p>Each test builds a Substrait {@link Rel}, converts it with {@link SubstraitToCalcite}, and
 * compares the collation-annotated explain output against an expected plan string.
 *
 * <p>NOTE(review): the expected plan strings below carry no nesting indentation as they appear
 * here; confirm them against the actual explain output before relying on them.
 */
public class ComplexSortTest extends PlanTestBase {

  final TypeCreator R = TypeCreator.of(false);
  final SubstraitBuilder b = new SubstraitBuilder(extensions);

  final SubstraitToCalcite substraitToCalcite =
      new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory);

  /**
   * A {@link RelWriterImpl} that annotates each {@link RelNode} with its {@link RelCollation}
   * trait information. A {@link RelNode} is only annotated if its {@link RelCollation} is present
   * and is not the default collation.
   */
  public static class CollationRelWriter extends RelWriterImpl {
    /**
     * Creates a writer that prints into {@code sw}.
     *
     * @param sw the buffer receiving the explain output
     */
    public CollationRelWriter(StringWriter sw) {
      super(new PrintWriter(sw), SqlExplainLevel.EXPPLAN_ATTRIBUTES, false);
    }

    /**
     * Prints a {@code Collation: ...} line before the node's regular explain line when the node
     * has a non-default collation, then delegates to the superclass.
     */
    @Override
    protected void explain_(RelNode rel, List<Pair<String, @Nullable Object>> values) {
      var collation = rel.getTraitSet().getCollation();
      // getCollation() is declared @Nullable in Calcite; skip the annotation instead of failing.
      if (collation != null && !collation.isDefault()) {
        var line = new StringBuilder();
        // Use the writer's current spacing so the annotation lines up with the node it precedes.
        spacer.spaces(line);
        line.append("Collation: ").append(collation);
        pw.println(line);
      }
      super.explain_(rel, values);
    }
  }

  /**
   * Converts {@code rel} to a Calcite plan and asserts that its explain output, rendered with
   * {@link CollationRelWriter}, equals {@code expected}.
   *
   * @param rel the Substrait relation to convert
   * @param expected the expected collation-annotated explain output
   */
  private void assertCollationPlan(Rel rel, String expected) {
    RelNode relReturned = substraitToCalcite.convert(rel);
    var sw = new StringWriter();
    relReturned.explain(new CollationRelWriter(sw));
    assertEquals(expected, sw.toString());
  }

  @Test
  void handleInputReferenceSort() {
    // CREATE TABLE example (a VARCHAR)
    // SELECT a FROM example ORDER BY a

    Rel rel =
        b.project(
            input -> b.fieldReferences(input, 0),
            b.remap(1),
            b.sort(
                input ->
                    List.of(
                        b.sortField(
                            b.fieldReference(input, 0), Expression.SortDirection.ASC_NULLS_LAST)),
                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));

    String expected =
        """
        Collation: [0]
        LogicalSort(sort0=[$0], dir0=[ASC])
        LogicalTableScan(table=[[example]])
        """;

    assertCollationPlan(rel, expected);
  }

  @Test
  void handleCastExpressionSort() {
    // CREATE TABLE example (a VARCHAR)
    // SELECT a FROM example ORDER BY a::INT

    Rel rel =
        b.project(
            input -> b.fieldReferences(input, 0),
            b.remap(1),
            b.sort(
                input ->
                    List.of(
                        b.sortField(
                            b.cast(b.fieldReference(input, 0), R.I32),
                            Expression.SortDirection.ASC_NULLS_LAST)),
                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));

    String expected =
        """
        LogicalProject(a0=[$0])
        Collation: [1]
        LogicalSort(sort0=[$1], dir0=[ASC])
        LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])
        LogicalTableScan(table=[[example]])
        """;

    assertCollationPlan(rel, expected);
  }

  @Test
  void handleCastProjectAndSortWithSortDirection() {
    // CREATE TABLE example (a VARCHAR)
    // SELECT a::INT FROM example ORDER BY a::INT DESC NULLS LAST

    Rel rel =
        b.project(
            input -> List.of(b.cast(b.fieldReference(input, 0), R.I32)),
            b.remap(1),
            b.sort(
                input ->
                    List.of(
                        b.sortField(
                            b.cast(b.fieldReference(input, 0), R.I32),
                            Expression.SortDirection.DESC_NULLS_LAST)),
                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));

    String expected =
        """
        LogicalProject(a0=[CAST($0):INTEGER NOT NULL])
        Collation: [1 DESC-nulls-last]
        LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])
        LogicalProject(a=[$0], a0=[CAST($0):INTEGER NOT NULL])
        LogicalTableScan(table=[[example]])
        """;

    assertCollationPlan(rel, expected);
  }

  @Test
  void handleCastSortToOriginalType() {
    // CREATE TABLE example (a VARCHAR)
    // SELECT a FROM example ORDER BY a::VARCHAR DESC NULLS LAST

    Rel rel =
        b.project(
            input -> List.of(b.fieldReference(input, 0)),
            b.remap(1),
            b.sort(
                input ->
                    List.of(
                        b.sortField(
                            b.cast(b.fieldReference(input, 0), R.STRING),
                            Expression.SortDirection.DESC_NULLS_LAST)),
                b.namedScan(List.of("example"), List.of("a"), List.of(R.STRING))));

    String expected =
        """
        LogicalProject(a0=[$0])
        Collation: [1 DESC-nulls-last]
        LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])
        LogicalProject(a=[$0], a0=[$0])
        LogicalTableScan(table=[[example]])
        """;

    assertCollationPlan(rel, expected);
  }

  @Test
  void handleComplex2ExpressionSort() {
    // CREATE TABLE example (a VARCHAR, b INT)
    // SELECT a, b FROM example ORDER BY a::INT DESC, -b + 42 ASC NULLS LAST

    Rel rel =
        b.project(
            input -> List.of(b.fieldReference(input, 0), b.fieldReference(input, 1)),
            b.remap(2, 3),
            b.sort(
                input ->
                    List.of(
                        b.sortField(
                            b.cast(b.fieldReference(input, 0), R.I32),
                            Expression.SortDirection.DESC_NULLS_FIRST),
                        b.sortField(
                            b.add(b.negate(b.fieldReference(input, 1)), b.i32(42)),
                            Expression.SortDirection.ASC_NULLS_LAST)),
                b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.STRING, R.I32))));

    String expected =
        """
        LogicalProject(a0=[$0], b0=[$1])
        Collation: [2 DESC, 3]
        LogicalSort(sort0=[$2], sort1=[$3], dir0=[DESC], dir1=[ASC])
        LogicalProject(a=[$0], b=[$1], a0=[CAST($0):INTEGER NOT NULL], $f3=[+(-($1), 42)])
        LogicalTableScan(table=[[example]])
        """;

    assertCollationPlan(rel, expected);
  }
}

0 comments on commit 80f8678

Please sign in to comment.