refactor: enable Spotless for Scala code #304

Merged · 1 commit · Oct 16, 2024
70 changes: 70 additions & 0 deletions spark/.scalafmt.conf
@@ -0,0 +1,70 @@
runner.dialect = scala212

# Version is required to make sure IntelliJ picks the right version
version = 3.7.3
preset = default

# Max column
maxColumn = 100

# This parameter simply says the .stripMargin method was not redefined by the user to assign
# special meaning to indentation preceding the | character. Hence, that indentation can be modified.
assumeStandardLibraryStripMargin = true
align.stripMargin = true

# Align settings
align.preset = none
align.closeParenSite = false
align.openParenCallSite = false
danglingParentheses.defnSite = false
danglingParentheses.callSite = false
danglingParentheses.ctrlSite = true
danglingParentheses.tupleSite = false
align.openParenCallSite = false
align.openParenDefnSite = false
align.openParenTupleSite = false

# Newlines
newlines.alwaysBeforeElseAfterCurlyIf = false
newlines.beforeCurlyLambdaParams = multiline # Newline before lambda params
newlines.afterCurlyLambdaParams = squash # No newline after lambda params
newlines.inInterpolation = "avoid"
newlines.avoidInResultType = true
optIn.annotationNewlines = true

# Scaladoc
docstrings.style = Asterisk # Javadoc style
docstrings.removeEmpty = true
docstrings.oneline = fold
docstrings.forceBlankLineBefore = true

# Indentation
indent.extendSite = 2 # This makes sure extend is not indented as the ctor parameters

# Rewrites
rewrite.rules = [AvoidInfix, Imports, RedundantBraces, SortModifiers]

# Imports
rewrite.imports.sort = scalastyle
rewrite.imports.groups = [
["io.substrait.spark\\..*"],
["org.apache.spark\\..*"],
[".*"],
["javax\\..*"],
["java\\..*"],
["scala\\..*"]
]
rewrite.imports.contiguousGroups = no
importSelectors = singleline # Imports in a single line, like IntelliJ

# Remove redundant braces in string interpolation.
rewrite.redundantBraces.stringInterpolation = true
rewrite.redundantBraces.defnBodies = false
rewrite.redundantBraces.generalExpressions = false
rewrite.redundantBraces.ifElseExpressions = false
rewrite.redundantBraces.methodBodies = false
rewrite.redundantBraces.includeUnitMethods = false
rewrite.redundantBraces.maxBreaks = 1

# Remove trailing commas
rewrite.trailingCommas.style = "never"
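Taken together, the import settings above account for most of the churn in the Scala files below: rewrite.imports.groups orders imports as io.substrait.spark, then org.apache.spark, then everything else (which is where the remaining io.substrait packages land), then javax, java, and scala, and importSelectors = singleline keeps each selector list on one line, IntelliJ-style. A minimal sketch of the resulting layout (not part of this PR; the package name and the exact import set are hypothetical, and the group comments are annotations rather than scalafmt output):

package io.substrait.spark.example // hypothetical package, for illustration only

// group 1: io.substrait.spark.*
import io.substrait.spark.ToSubstraitType

// group 2: org.apache.spark.*
import org.apache.spark.sql.types.{StructField, StructType}

// group 3: everything else, including other io.substrait packages
import io.substrait.`type`.Type
import io.substrait.function.TypeExpression

// trailing groups: javax.*, then java.*, then scala.* last
import java.{util => ju}

import scala.collection.JavaConverters.asScalaBufferConverter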
7 changes: 7 additions & 0 deletions spark/build.gradle.kts
@@ -105,6 +105,13 @@ dependencies {
testImplementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}:tests")
}

spotless {
scala {
scalafmt().configFile(".scalafmt.conf")
toggleOffOn()
}
}

tasks {
test {
dependsOn(":core:shadowJar")
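On the Gradle side, the spotless block wires the Spotless plugin's Scala step to the .scalafmt.conf added above, and toggleOffOn() lets individual regions opt out of formatting via Spotless's toggle comments (by default spotless:off / spotless:on). A rough sketch of how such a toggle might look in a Scala source file (hypothetical code, not taken from this PR):

// Manual alignment kept below; formatting is toggled off around it
// spotless:off
val unitsPerInterval = Map(
  "year"  -> 365,
  "month" -> 30,
  "day"   -> 1
)
// spotless:on

With the block in place, the standard Spotless Gradle tasks apply and verify the formatting, e.g. ./gradlew :spark:spotlessApply and ./gradlew :spark:spotlessCheck.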
5 changes: 3 additions & 2 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -16,11 +16,12 @@
*/
package io.substrait.spark

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types._

import io.substrait.`type`.{NamedStruct, Type, TypeVisitor}
import io.substrait.function.TypeExpression
import io.substrait.utils.Util
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types._

import scala.collection.JavaConverters
import scala.collection.JavaConverters.asScalaBufferConverter
@@ -16,6 +16,8 @@
*/
package io.substrait.spark.expression

import io.substrait.spark.ToSubstraitType

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
@@ -29,7 +31,6 @@ import io.substrait.expression.{Expression => SExpression, ExpressionCreator, Fu
import io.substrait.expression.Expression.FailureBehavior
import io.substrait.extension.SimpleExtension
import io.substrait.function.{ParameterizedType, ToTypeString}
import io.substrait.spark.ToSubstraitType
import io.substrait.utils.Util

import java.{util => ju}
@@ -56,10 +56,12 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType)
typeToMatch.isInstanceOf[Type.IntervalYear]

override def visit(`type`: Type.IntervalDay): Boolean =
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch.isInstanceOf[ParameterizedType.IntervalDay]
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalDay]

override def visit(`type`: Type.IntervalCompound): Boolean =
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch.isInstanceOf[ParameterizedType.IntervalCompound]
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalCompound]

override def visit(`type`: Type.UUID): Boolean = typeToMatch.isInstanceOf[Type.UUID]

@@ -109,11 +111,13 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType)

@throws[RuntimeException]
override def visit(expr: ParameterizedType.IntervalDay): Boolean =
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch.isInstanceOf[ParameterizedType.IntervalDay]
typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalDay]

@throws[RuntimeException]
override def visit(expr: ParameterizedType.IntervalCompound): Boolean =
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch.isInstanceOf[ParameterizedType.IntervalCompound]
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch
.isInstanceOf[ParameterizedType.IntervalCompound]

@throws[RuntimeException]
override def visit(expr: ParameterizedType.Struct): Boolean =
@@ -18,14 +18,16 @@ package io.substrait.spark.expression

import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
import io.substrait.spark.logical.ToLogicalPlan

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
import io.substrait.util.DecimalUtil
import org.apache.spark.substrait.SparkTypeUtil

import scala.collection.JavaConverters.asScalaBufferConverter

@@ -132,31 +134,34 @@ class ToSparkExpression(
}

expr.declaration.name match {
case "make_decimal" if expr.declaration.uri == SparkExtension.uri => expr.outputType match {
// Need special case handing of this internal function.
// Because the precision and scale arguments are extracted from the output type,
// we can't use the generic scalar function conversion mechanism here.
case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
case _ => throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
}
case _ => scalarFunctionConverter
.getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
.flatMap(sig => Option(sig.makeCall(args)))
.getOrElse({
val msg = String.format(
"Unable to convert scalar function %s(%s).",
expr.declaration.name,
expr.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
case "make_decimal" if expr.declaration.uri == SparkExtension.uri =>
expr.outputType match {
// Need special case handing of this internal function.
// Because the precision and scale arguments are extracted from the output type,
// we can't use the generic scalar function conversion mechanism here.
case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale)
case _ =>
throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type")
}
case _ =>
scalarFunctionConverter
.getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType())
.flatMap(sig => Option(sig.makeCall(args)))
.getOrElse({
val msg = String.format(
"Unable to convert scalar function %s(%s).",
expr.declaration.name,
expr.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
}
}
}
@@ -19,11 +19,11 @@ package io.substrait.spark.expression
import io.substrait.spark.{HasOutputStack, ToSubstraitType}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.substrait.SparkTypeUtil

import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FieldReference, ImmutableExpression}
import io.substrait.expression.Expression.FailureBehavior
import io.substrait.utils.Util
import org.apache.spark.substrait.SparkTypeUtil

import scala.collection.JavaConverters.asJavaIterableConverter

@@ -16,13 +16,14 @@
*/
package io.substrait.spark.expression

import io.substrait.spark.ToSubstraitType

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.substrait.expression.{Expression => SExpression}
import io.substrait.expression.ExpressionCreator._
import io.substrait.spark.ToSubstraitType

class ToSubstraitLiteral {

26 changes: 16 additions & 10 deletions spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -18,6 +18,7 @@ package io.substrait.spark.logical

import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType}
import io.substrait.spark.expression._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
@@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType}

import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
@@ -167,14 +169,14 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
if (limit >= 0) {
val limitExpr = toLiteral(limit)
if (offset > 0) {
GlobalLimit(limitExpr,
Offset(toLiteral(offset),
LocalLimit(toLiteral(offset + limit), child)))
GlobalLimit(
limitExpr,
Offset(toLiteral(offset), LocalLimit(toLiteral(offset + limit), child)))
} else {
GlobalLimit(limitExpr, LocalLimit(limitExpr, child))
}
} else {
Offset(toLiteral(offset), child)
Offset(toLiteral(offset), child)
}
}
override def visit(sort: relation.Sort): LogicalPlan = {
@@ -213,13 +215,16 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
withChild(child) {
val projections = expand.getFields.asScala
.map {
case sf: SwitchingField => sf.getDuplicates.asScala
.map(expr => expr.accept(expressionConverter))
.map(toNamedExpression)
case _: ConsistentField => throw new UnsupportedOperationException("ConsistentField not currently supported")
case sf: SwitchingField =>
sf.getDuplicates.asScala
.map(expr => expr.accept(expressionConverter))
.map(toNamedExpression)
case _: ConsistentField =>
throw new UnsupportedOperationException("ConsistentField not currently supported")
}

val output = projections.head.zip(names)
val output = projections.head
.zip(names)
.map { case (t, name) => StructField(name, t.dataType, t.nullable) }
.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

@@ -240,7 +245,8 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
withOutput(children.flatMap(_.output)) {
set.getSetOp match {
case SetOp.UNION_ALL => Union(children, byName = false, allowMissingCol = false)
case op => throw new UnsupportedOperationException(s"Operation not currently supported: $op")
case op =>
throw new UnsupportedOperationException(s"Operation not currently supported: $op")
}
}
}