Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(isthmus): support converting Substrait Plan.Root to Calcite RelRoot #339

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.isthmus;

import io.substrait.extension.SimpleExtension;
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractRelVisitor;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Rel;
Expand All @@ -12,8 +13,13 @@
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.LookupCalciteSchema;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.schema.Table;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.RelBuilder;

Expand Down Expand Up @@ -99,6 +105,32 @@ public RelNode convert(Rel rel) {
return rel.accept(converter);
}

/**
* Converts a Substrait {@link Plan.Root} to a Calcite {@link RelRoot}
*
* <p>Generates a {@link RelDataType} row type with the final field names of the {@link Plan.Root}
* and creates a Calcite {@link RelRoot} with it.
*
* <p>TODO: revisit this code when support for WriteRel is added to substrait-java
*
* @param root {@link Plan.Root} to convert
* @return {@link RelRoot}
*/
public RelRoot convert(Plan.Root root) {
RelNode input = convert(root.getInput());
RelDataType inputRowType = input.getRowType();
RelDataTypeFactory.Builder builder = new RelDataTypeFactory.Builder(typeFactory);
for (RelDataTypeField field : inputRowType.getFieldList()) {
RelDataTypeFieldImpl renamedField =
new RelDataTypeFieldImpl(
root.getNames().get(field.getIndex()), field.getIndex(), field.getType());
builder.add(renamedField);
}

RelRoot calciteRoot = RelRoot.of(input, builder.build(), SqlKind.SELECT);
return calciteRoot;
}

private static class NamedStructGatherer extends AbstractRelVisitor<Void, RuntimeException> {
Map<List<String>, NamedStruct> tableMap;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package io.substrait.isthmus;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.plan.Plan;
import io.substrait.plan.ProtoPlanConverter;
import java.util.Map.Entry;
import org.apache.calcite.adapter.tpcds.TpcdsSchema;
import org.apache.calcite.rel.RelRoot;
import org.junit.jupiter.api.Test;

public class SubstraitToCalciteTest extends PlanTestBase {
@Test
void testConvertRoot() throws Exception {
SqlToSubstrait s = new SqlToSubstrait();
TpcdsSchema schema = new TpcdsSchema(1.0);

// single column
String newColName = "store_name";
String sql = "select s_store_name as " + newColName + " from tpcds.store";

io.substrait.proto.Plan protoPlan = s.execute(sql, "tpcds", schema);

ProtoPlanConverter protoPlanConverter = new ProtoPlanConverter(extensions);
Plan plan = protoPlanConverter.from(protoPlan);
Plan.Root root = plan.getRoots().get(0);

assertEquals(1, root.getNames().size());
assertEquals(newColName.toUpperCase(), root.getNames().get(0));

SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory);
RelRoot relRoot = converter.convert(root);

assertEquals(root.getNames().size(), relRoot.fields.size());
for (Entry<Integer, String> field : relRoot.fields) {
assertEquals(root.getNames().get(field.getKey()), field.getValue());
}

// multiple columns
String storeIdColumnName = "s_store_id";
sql = "select " + storeIdColumnName + ", s_store_name as " + newColName + " from tpcds.store";
protoPlan = s.execute(sql, "tpcds", schema);

protoPlanConverter = new ProtoPlanConverter(extensions);
plan = protoPlanConverter.from(protoPlan);
root = plan.getRoots().get(0);

assertEquals(2, root.getNames().size());
assertEquals(storeIdColumnName.toUpperCase(), root.getNames().get(0));
assertEquals(newColName.toUpperCase(), root.getNames().get(1));

relRoot = converter.convert(root);

assertEquals(root.getNames().size(), relRoot.fields.size());
for (Entry<Integer, String> field : relRoot.fields) {
assertEquals(root.getNames().get(field.getKey()), field.getValue());
}
}
}
Loading