Skip to content

Commit

Permalink
feat: add support for empty list literals
Browse files Browse the repository at this point in the history
BREAKING CHANGE:
* ExpressionVisitor has a new method to visit Expression.EmptyListLiteral.
  Implementors which don’t extend AbstractExpressionVisitor will have to
  add their own implementation.
* Public class LiteralConstructorConverter now requires a TypeConverter to
  be passed to its constructor.
* ExpressionCopyOnWriteVisitor#visitLiteral() may now be called with
  Expression.EmptyListLiteral.
  • Loading branch information
patientstreetlight committed Feb 12, 2024
1 parent aad2739 commit f0c3bf8
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.EmptyListLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.StructLiteral expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,25 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract class EmptyListLiteral implements Literal {
public abstract Type elementType();

@Override
public Type.ListType getType() {
return Type.withNullability(nullable()).list(elementType());
}

public static ImmutableExpression.EmptyListLiteral.Builder builder() {
return ImmutableExpression.EmptyListLiteral.builder();
}

@Override
public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class StructLiteral implements Literal {
public abstract List<Literal> fields();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,13 @@ public static Expression.ListLiteral list(
return Expression.ListLiteral.builder().nullable(nullable).addAllValues(values).build();
}

public static Expression.EmptyListLiteral emptyList(boolean listNullable, Type elementType) {
return Expression.EmptyListLiteral.builder()
.elementType(elementType)
.nullable(listNullable)
.build();
}

public static Expression.StructLiteral struct(boolean nullable, Expression.Literal... values) {
return Expression.StructLiteral.builder().nullable(nullable).addFields(values).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.ListLiteral expr) throws E;

R visit(Expression.EmptyListLiteral expr) throws E;

R visit(Expression.StructLiteral expr) throws E;

R visit(Expression.Switch expr) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,19 @@ public Expression visit(io.substrait.expression.Expression.ListLiteral expr) {
});
}

@Override
public Expression visit(io.substrait.expression.Expression.EmptyListLiteral expr)
throws RuntimeException {
return lit(
builder -> {
var protoListType = expr.getType().accept(typeProtoConverter);
// For empty lists, the literal's own nullable field is ignored in favor
// of the nullability of the enclosed Type.List.
var ignoredLiteralNullable = false;
builder.setEmptyList(protoListType.getList()).setNullable(ignoredLiteralNullable);
});
}

@Override
public Expression visit(io.substrait.expression.Expression.StructLiteral expr) {
return lit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
literal.getList().getValuesList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList()));
case EMPTY_LIST -> {
// literal.getNullable() is intentionally ignored in favor of the nullability
// specified in the literal.getEmptyList() type.
var listType = protoTypeConverter.fromList(literal.getEmptyList());
yield ExpressionCreator.emptyList(listType.nullable(), listType.elementType());
}
default -> throw new IllegalStateException(
"Unexpected value: " + literal.getLiteralTypeCase());
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ public Optional<Expression> visit(Expression.ListLiteral expr) throws EXCEPTION
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.EmptyListLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.StructLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/type/TypeCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public Type.Struct struct(Stream<? extends Type> types) {
.build();
}

public Type list(Type type) {
public Type.ListType list(Type type) {
return Type.ListType.builder().nullable(nullable).elementType(type).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Type from(io.substrait.proto.Type type) {
type.getStruct().getTypesList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList()));
case LIST -> n(type.getList().getNullability()).list(from(type.getList().getType()));
case LIST -> fromList(type.getList());
case MAP -> n(type.getMap().getNullability())
.map(from(type.getMap().getKey()), from(type.getMap().getValue()));
case USER_DEFINED -> {
Expand All @@ -61,6 +61,10 @@ public Type from(io.substrait.proto.Type type) {
};
}

public Type.ListType fromList(io.substrait.proto.Type.List list) {
return n(list.getNullability()).list(from(list.getType()));
}

public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) {
return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
new FieldSelectionConverter(typeConverter),
CallConverters.CASE,
CallConverters.CAST.apply(typeConverter),
new LiteralConstructorConverter());
new LiteralConstructorConverter(typeConverter));
}

public interface SimpleCallConverter extends CallConverter {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.substrait.type.Type;
import io.substrait.util.DecimalUtil;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
Expand All @@ -37,6 +38,8 @@
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.fun.SqlArrayValueConstructor;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.util.TimeString;
Expand Down Expand Up @@ -248,6 +251,34 @@ public RexNode visit(Expression.ListLiteral expr) throws RuntimeException {
return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args);
}

/**
* A custom Calcite array value constructor for empty arrays. The default Calcite {@link
* SqlArrayValueConstructor} does not allow for an empty element list, in particular because its
* {@code inferReturnType()} method works by inspecting the types of its elements.
*
* <p>See the <a href="https://issues.apache.org/jira/browse/CALCITE-3504">CALCITE-3504</a> for
* more details.
*/
private static class EmptySqlArrayValueConstructor extends SqlArrayValueConstructor {
private final RelDataType type;

private EmptySqlArrayValueConstructor(RelDataType type) {
this.type = type;
}

@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
return type;
}
}

@Override
public RexNode visit(Expression.EmptyListLiteral expr) throws RuntimeException {
var calciteType = typeConverter.toCalcite(typeFactory, expr.getType());
return rexBuilder.makeCall(
new EmptySqlArrayValueConstructor(calciteType), Collections.emptyList());
}

@Override
public RexNode visit(Expression.MapLiteral expr) throws RuntimeException {
var args =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.isthmus.CallConverter;
import io.substrait.isthmus.TypeConverter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -18,29 +19,53 @@ public class LiteralConstructorConverter implements CallConverter {
static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(LiteralConstructorConverter.class);

private final TypeConverter typeConverter;

public LiteralConstructorConverter(TypeConverter typeConverter) {
this.typeConverter = typeConverter;
}

@Override
public Optional<Expression> convert(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
SqlOperator operator = call.getOperator();
if (operator instanceof SqlArrayValueConstructor) {
return Optional.of(
ExpressionCreator.list(
false,
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList())));
return call.getOperands().isEmpty()
? toEmptyListLiteral(call)
: toNonEmptyListLiteral(call, topLevelConverter);
} else if (operator instanceof SqlMapValueConstructor) {
List<Expression.Literal> literals =
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList());
Map<Expression.Literal, Expression.Literal> items = new HashMap<>();
assert literals.size() % 2 == 0;
for (int i = 0; i < literals.size(); i += 2) {
items.put(literals.get(i), literals.get(i + 1));
}
return Optional.of(ExpressionCreator.map(false, items));
return toMapLiteral(call, topLevelConverter);
}
return Optional.empty();
}

private Optional<Expression> toMapLiteral(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
List<Expression.Literal> literals =
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList());
Map<Expression.Literal, Expression.Literal> items = new HashMap<>();
assert literals.size() % 2 == 0;
for (int i = 0; i < literals.size(); i += 2) {
items.put(literals.get(i), literals.get(i + 1));
}
return Optional.of(ExpressionCreator.map(false, items));
}

private Optional<Expression> toNonEmptyListLiteral(
RexCall call, Function<RexNode, Expression> topLevelConverter) {
return Optional.of(
ExpressionCreator.list(
false,
call.operands.stream()
.map(t -> ((Expression.Literal) topLevelConverter.apply(t)))
.collect(java.util.stream.Collectors.toList())));
}

private Optional<Expression> toEmptyListLiteral(RexCall call) {
var calciteElementType = call.getType().getComponentType();
var substraitElementType = typeConverter.toSubstrait(calciteElementType);
return Optional.of(ExpressionCreator.emptyList(false, substraitElementType));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package io.substrait.isthmus;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.ExpressionCreator;
import io.substrait.relation.Rel;
import io.substrait.type.TypeCreator;
import java.util.List;
import org.junit.jupiter.api.Test;

public class EmptyArrayLiteralTest extends PlanTestBase {
private static final TypeCreator N = TypeCreator.of(true);

private final SubstraitBuilder b = new SubstraitBuilder(extensions);

@Test
void emptyArrayLiteral() {
var colType = N.I8;
var emptyListLiteral = ExpressionCreator.emptyList(false, N.I8);
var rel =
b.project(
input -> List.of(emptyListLiteral),
Rel.Remap.offset(1, 1),
b.namedScan(List.of("t"), List.of("col"), List.of(colType)));
assertFullRoundTrip(rel);
}
}

0 comments on commit f0c3bf8

Please sign in to comment.