diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index 729d11be..acd9567e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -1,6 +1,7 @@ package io.substrait.isthmus; import io.substrait.extension.SimpleExtension; +import io.substrait.plan.Plan; import io.substrait.relation.AbstractRelVisitor; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -12,8 +13,13 @@ import org.apache.calcite.jdbc.CalciteSchema; import org.apache.calcite.jdbc.LookupCalciteSchema; import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rel.type.RelDataTypeFieldImpl; import org.apache.calcite.schema.Table; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; @@ -99,6 +105,32 @@ public RelNode convert(Rel rel) { return rel.accept(converter); } + /** + * Converts a Substrait {@link Plan.Root} to a Calcite {@link RelRoot} + * + *

Generates a {@link RelDataType} row type with the final field names of the {@link Plan.Root} + * and creates a Calcite {@link RelRoot} with it. + * + *

TODO: revisit this code when support for WriteRel is added to substrait-java + * + * @param root {@link Plan.Root} to convert + * @return {@link RelRoot} + */ + public RelRoot convert(Plan.Root root) { + RelNode input = convert(root.getInput()); + RelDataType inputRowType = input.getRowType(); + RelDataTypeFactory.Builder builder = new RelDataTypeFactory.Builder(typeFactory); + for (RelDataTypeField field : inputRowType.getFieldList()) { + RelDataTypeFieldImpl renamedField = + new RelDataTypeFieldImpl( + root.getNames().get(field.getIndex()), field.getIndex(), field.getType()); + builder.add(renamedField); + } + + RelRoot calciteRoot = RelRoot.of(input, builder.build(), SqlKind.SELECT); + return calciteRoot; + } + private static class NamedStructGatherer extends AbstractRelVisitor { Map, NamedStruct> tableMap; diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java new file mode 100644 index 00000000..111af6ba --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java @@ -0,0 +1,59 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.plan.Plan; +import io.substrait.plan.ProtoPlanConverter; +import java.util.Map.Entry; +import org.apache.calcite.adapter.tpcds.TpcdsSchema; +import org.apache.calcite.rel.RelRoot; +import org.junit.jupiter.api.Test; + +public class SubstraitToCalciteTest extends PlanTestBase { + @Test + void testConvertRoot() throws Exception { + SqlToSubstrait s = new SqlToSubstrait(); + TpcdsSchema schema = new TpcdsSchema(1.0); + + // single column + String newColName = "store_name"; + String sql = "select s_store_name as " + newColName + " from tpcds.store"; + + io.substrait.proto.Plan protoPlan = s.execute(sql, "tpcds", schema); + + ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensions); + Plan plan = protoPlanConverter.from(protoPlan); + Plan.Root root = plan.getRoots().get(0); + + assertEquals(1, root.getNames().size()); + assertEquals(newColName.toUpperCase(), root.getNames().get(0)); + + SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); + RelRoot relRoot = converter.convert(root); + + assertEquals(root.getNames().size(), relRoot.fields.size()); + for (Entry field : relRoot.fields) { + assertEquals(root.getNames().get(field.getKey()), field.getValue()); + } + + // multiple columns + String storeIdColumnName = "s_store_id"; + sql = "select " + storeIdColumnName + ", s_store_name as " + newColName + " from tpcds.store"; + protoPlan = s.execute(sql, "tpcds", schema); + + protoPlanConverter = new ProtoPlanConverter(extensions); + plan = protoPlanConverter.from(protoPlan); + root = plan.getRoots().get(0); + + assertEquals(2, root.getNames().size()); + assertEquals(storeIdColumnName.toUpperCase(), root.getNames().get(0)); + assertEquals(newColName.toUpperCase(), root.getNames().get(1)); + + relRoot = converter.convert(root); + + assertEquals(root.getNames().size(), relRoot.fields.size()); + for (Entry field : relRoot.fields) { + assertEquals(root.getNames().get(field.getKey()), field.getValue()); + } + } +}