Skip to content

Commit

Permalink
feat: add support for Expression.EmptyMapLiteral
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed Oct 24, 2024
1 parent ac0b7d1 commit 45d9387
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ public OUTPUT visit(Expression.MapLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.EmptyMapLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down
19 changes: 19 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,25 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract static class EmptyMapLiteral implements Literal {
public abstract Type keyType();

public abstract Type valueType();

public Type getType() {
return Type.withNullability(nullable()).map(keyType(), valueType());
}

public static ImmutableExpression.EmptyMapLiteral.Builder builder() {
return ImmutableExpression.EmptyMapLiteral.builder();
}

public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class ListLiteral implements Literal {
public abstract List<Literal> values();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,15 @@ public static Expression.MapLiteral map(
return Expression.MapLiteral.builder().nullable(nullable).putAllValues(values).build();
}

public static Expression.EmptyMapLiteral emptyMap(
boolean nullable, Type keyType, Type valueType) {
return Expression.EmptyMapLiteral.builder()
.keyType(keyType)
.valueType(valueType)
.nullable(nullable)
.build();
}

public static Expression.ListLiteral list(boolean nullable, Expression.Literal... values) {
return Expression.ListLiteral.builder().nullable(nullable).addValues(values).build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.MapLiteral expr) throws E;

R visit(Expression.EmptyMapLiteral expr) throws E;

R visit(Expression.ListLiteral expr) throws E;

R visit(Expression.EmptyListLiteral expr) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,21 @@ public Expression visit(io.substrait.expression.Expression.MapLiteral expr) {
});
}

@Override
public Expression visit(io.substrait.expression.Expression.EmptyMapLiteral expr) {
return lit(
bldr -> {
var protoMapType = expr.getType().accept(typeProtoConverter);
bldr.setEmptyMap(protoMapType.getMap())
// For empty maps, the Literal message's own nullable field should be ignored
// in favor of the nullability of the Type.Map in the literal's
// empty_map field. But for safety we set the literal's nullable field
// to match in case any readers either look in the wrong location
// or want to verify that they are consistent.
.setNullable(expr.nullable());
});
}

@Override
public Expression visit(io.substrait.expression.Expression.ListLiteral expr) {
return lit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,12 @@ public Expression.Literal from(io.substrait.proto.Expression.Literal literal) {
literal.getNullable(),
literal.getMap().getKeyValuesList().stream()
.collect(Collectors.toMap(kv -> from(kv.getKey()), kv -> from(kv.getValue()))));
case EMPTY_MAP -> {
// literal.getNullable() is intentionally ignored in favor of the nullability
// specified in the literal.getEmptyMap() type.
var mapType = protoTypeConverter.fromMap(literal.getEmptyMap());
yield ExpressionCreator.emptyMap(mapType.nullable(), mapType.key(), mapType.value());
}
case UUID -> ExpressionCreator.uuid(literal.getNullable(), literal.getUuid());
case NULL -> ExpressionCreator.typedNull(protoTypeConverter.from(literal.getNull()));
case LIST -> ExpressionCreator.list(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ public Optional<Expression> visit(Expression.MapLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.EmptyMapLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.ListLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/substrait/type/TypeCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public Type.ListType list(Type type) {
return Type.ListType.builder().nullable(nullable).elementType(type).build();
}

public Type map(Type key, Type value) {
public Type.Map map(Type key, Type value) {
return Type.Map.builder().nullable(nullable).key(key).value(value).build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ public Type from(io.substrait.proto.Type type) {
.map(this::from)
.collect(java.util.stream.Collectors.toList()));
case LIST -> fromList(type.getList());
case MAP -> n(type.getMap().getNullability())
.map(from(type.getMap().getKey()), from(type.getMap().getValue()));
case MAP -> fromMap(type.getMap());
case USER_DEFINED -> {
var userDefined = type.getUserDefined();
var t = lookup.getType(userDefined.getTypeReference(), extensions);
Expand All @@ -74,6 +73,10 @@ public Type.ListType fromList(io.substrait.proto.Type.List list) {
return n(list.getNullability()).list(from(list.getType()));
}

public Type.Map fromMap(io.substrait.proto.Type.Map map) {
return n(map.getNullability()).map(from(map.getKey()), from(map.getValue()));
}

public static boolean isNullable(io.substrait.proto.Type.Nullability nullability) {
return io.substrait.proto.Type.Nullability.NULLABILITY_NULLABLE == nullability;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,8 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
override def visit(expr: Expression.UserDefinedLiteral): String = {
expr.toString
}

override def visit(expr: Expression.EmptyMapLiteral): String = {
expr.toString
}
}

0 comments on commit 45d9387

Please sign in to comment.