Skip to content

Commit

Permalink
fix: code review - no conversions for precision, fix IntervalDay form…
Browse files Browse the repository at this point in the history
…ulas
  • Loading branch information
Blizzara committed Oct 25, 2024
1 parent 410475f commit fd01c2f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 36 deletions.
15 changes: 3 additions & 12 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,16 @@ private class ToSparkType
override def visit(expr: Type.Bool): DataType = BooleanType

override def visit(expr: Type.PrecisionTimestamp): DataType = {
if (expr.precision() != Util.MICROSECOND_PRECISION) {
throw new UnsupportedOperationException(
s"Unsupported precision for timestamp: ${expr.precision()}")
}
Util.assertMicroseconds(expr.precision())
TimestampNTZType
}
override def visit(expr: Type.PrecisionTimestampTZ): DataType = {
if (expr.precision() != Util.MICROSECOND_PRECISION) {
throw new UnsupportedOperationException(
s"Unsupported precision for timestamp: ${expr.precision()}")
}
Util.assertMicroseconds(expr.precision())
TimestampType
}

override def visit(expr: Type.IntervalDay): DataType = {
if (expr.precision() != Util.MICROSECOND_PRECISION) {
throw new UnsupportedOperationException(
s"Unsupported precision for intervalDay: ${expr.precision()}")
}
Util.assertMicroseconds(expr.precision())
DayTimeIntervalType.DEFAULT
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,26 @@ class ToSparkExpression(
}

override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = {
Literal(
Util.toMicroseconds(expr.value(), expr.precision()),
ToSubstraitType.convert(expr.getType))

// Spark timestamps are stored as a microseconds Long
Util.assertMicroseconds(expr.precision())
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = {
Literal(
Util.toMicroseconds(expr.value(), expr.precision()),
ToSubstraitType.convert(expr.getType))
// Spark timestamps are stored as a microseconds Long
Util.assertMicroseconds(expr.precision())
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.IntervalDayLiteral): Literal = {
val micros =
(expr.days() * Util.SECONDS_PER_DAY + expr.seconds()) * Util.MICROSECOND_PRECISION +
Util.toMicroseconds(expr.subseconds(), expr.precision())
Util.assertMicroseconds(expr.precision())
// Spark uses a single microseconds Long as the "physical" type for DayTimeInterval
val micros = (expr.days() * Util.SECONDS_PER_DAY + expr.seconds()) * Util.MICROS_PER_SECOND + expr.subseconds()
Literal(micros, ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.IntervalYearLiteral): Literal = {
// Spark uses a single months Int as the "physical" type for YearMonthInterval
val months = expr.years() * 12 + expr.months()
Literal(months, ToSubstraitType.convert(expr.getType))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ class ToSubstraitLiteral {
precisionTimestamp(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _timestampTz: Long => SExpression.Literal =
precisionTimestampTZ(false, _, Util.MICROSECOND_PRECISION) // Spark ts is in microseconds
val _intervalDay: Long => SExpression.Literal = (ms: Long) =>
intervalDay(false, 0, 0, ms, Util.MICROSECOND_PRECISION)
val _intervalDay: Long => SExpression.Literal = (ms: Long) => {
val days = (ms / Util.MICROS_PER_SECOND / Util.SECONDS_PER_DAY).toInt
val seconds = (ms / Util.MICROS_PER_SECOND % Util.SECONDS_PER_DAY).toInt
val micros = ms % Util.MICROS_PER_SECOND
intervalDay(false, days, seconds, micros, Util.MICROSECOND_PRECISION)
}
val _intervalYear: Int => SExpression.Literal = (m: Int) => intervalYear(false, m / 12, m % 12)
val _string: String => SExpression.Literal = string(false, _)
val _binary: Array[Byte] => SExpression.Literal = binary(false, _)
Expand Down
20 changes: 8 additions & 12 deletions spark/src/main/scala/io/substrait/utils/Util.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,15 @@ import scala.collection.mutable.ArrayBuffer

object Util {

val SECONDS_PER_DAY: Long = 24 * 60 * 60;
val MICROSECOND_PRECISION = 6; // for PrecisionTimestamp(TZ) types
val SECONDS_PER_DAY: Long = 24 * 60 * 60
val MICROS_PER_SECOND: Long = 1000 * 1000
val MICROSECOND_PRECISION = 6 // for PrecisionTimestamp(TZ) and IntervalDay types

def toMicroseconds(value: Long, precision: Int): Long = {
// Spark uses microseconds as a Long value for most time things
val factor = MICROSECOND_PRECISION - precision
// Doing this in a way that avoids floating point math
if (factor == 0) {
value
} else if (factor > 0) {
value * math.pow(10, factor).toLong
} else {
value / math.pow(10, -factor).toLong
def assertMicroseconds(precision: Int): Unit = {
// Spark uses microseconds as a Long value as the "physical" type for most time things
if (precision != MICROSECOND_PRECISION) {
throw new UnsupportedOperationException(
s"Unsupported precision: $precision. Only microsecond precision ($MICROSECOND_PRECISION) is supported")
}
}

Expand Down

0 comments on commit fd01c2f

Please sign in to comment.