Skip to content

Commit

Permalink
feat(isthmus): support for safe casting (#236)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: method ExpressionCreator.cast(Type, Expression) has been removed

---------

Signed-off-by: mbwhite <whitemat@uk.ibm.com>
Co-authored-by: Victor Barua <victor.barua@datadoghq.com>
  • Loading branch information
mbwhite and vbarua authored Mar 15, 2024
1 parent 51406f7 commit 72785ad
Show file tree
Hide file tree
Showing 8 changed files with 43 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ gen
*.iml
out/**
*.iws
.vscode
Original file line number Diff line number Diff line change
Expand Up @@ -407,10 +407,6 @@ public static Expression.WindowFunctionInvocation windowFunction(
.build();
}

public static Expression cast(Type type, Expression expression) {
return cast(type, expression, Expression.FailureBehavior.UNSPECIFIED);
}

public static Expression cast(
Type type, Expression expression, Expression.FailureBehavior failureBehavior) {
return Expression.Cast.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,9 @@ public Expression from(io.substrait.proto.Expression expr) {
.build();
}
case CAST -> ExpressionCreator.cast(
protoTypeConverter.from(expr.getCast().getType()), from(expr.getCast().getInput()));
protoTypeConverter.from(expr.getCast().getType()),
from(expr.getCast().getInput()),
Expression.FailureBehavior.fromProto(expr.getCast().getFailureBehavior()));
case SUBQUERY -> {
switch (expr.getSubquery().getSubqueryTypeCase()) {
case SET_PREDICATE -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,22 @@ public class CallConverters {
public static Function<TypeConverter, SimpleCallConverter> CAST =
typeConverter ->
(call, visitor) -> {
if (call.getKind() != SqlKind.CAST) {
return null;
Expression.FailureBehavior failureBehavior;
switch (call.getKind()) {
case CAST:
failureBehavior = Expression.FailureBehavior.THROW_EXCEPTION;
break;
case SAFE_CAST:
failureBehavior = Expression.FailureBehavior.RETURN_NULL;
break;
default:
return null;
}

return ExpressionCreator.cast(
typeConverter.toSubstrait(call.getType()),
visitor.apply(call.getOperands().get(0)));
visitor.apply(call.getOperands().get(0)),
failureBehavior);
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io.substrait.expression.AbstractExpressionVisitor;
import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression;
import io.substrait.expression.Expression.FailureBehavior;
import io.substrait.expression.Expression.SingleOrList;
import io.substrait.expression.Expression.Switch;
import io.substrait.expression.FieldReference;
Expand Down Expand Up @@ -478,8 +479,9 @@ private String convert(FunctionArg a) {

@Override
public RexNode visit(Expression.Cast expr) throws RuntimeException {
var safeCast = expr.failureBehavior() == FailureBehavior.RETURN_NULL;
return rexBuilder.makeAbstractCast(
typeConverter.toCalcite(typeFactory, expr.getType()), expr.input().accept(this));
typeConverter.toCalcite(typeFactory, expr.getType()), expr.input().accept(this), safeCast);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ private static List<Expression> coerceArguments(List<Expression> arguments, Type
private static Expression coerceArgument(Expression argument, Type type) {
var typeMatches = isMatch(type, argument.getType());
if (!typeMatches) {
return ExpressionCreator.cast(type, argument);
return ExpressionCreator.cast(type, argument, Expression.FailureBehavior.THROW_EXCEPTION);
}
return argument;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ public void coerceNumericOp() {
func -> {
// check that there is a cast for the incorrect argument type.
assertEquals(
ExpressionCreator.cast(TypeCreator.REQUIRED.I64, ExpressionCreator.i32(false, 20)),
ExpressionCreator.cast(
TypeCreator.REQUIRED.I64,
ExpressionCreator.i32(false, 20),
Expression.FailureBehavior.THROW_EXCEPTION),
func.arguments().get(0));
},
false); // TODO: implicit calcite cast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.isthmus.expression.ExpressionRexConverter;
Expand Down Expand Up @@ -105,6 +106,24 @@ public void switchExpression() {
expression);
}

@Test
public void castFailureCondition() {
Rel rel =
b.project(
input ->
List.of(
ExpressionCreator.cast(
R.I64,
b.fieldReference(input, 0),
Expression.FailureBehavior.THROW_EXCEPTION),
ExpressionCreator.cast(
R.I32, b.fieldReference(input, 0), Expression.FailureBehavior.RETURN_NULL)),
b.remap(1, 2),
b.namedScan(List.of("test"), List.of("col1"), List.of(R.STRING)));

assertFullRoundTrip(rel);
}

void assertExpressionEquality(Expression expected, Expression actual) {
// go the extra mile and convert both inputs to protobuf
// helps verify that the protobuf conversion is not broken
Expand Down

0 comments on commit 72785ad

Please sign in to comment.