From a3d070ca20eeba3280933b725cb5034e1dfec3c8 Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Thu, 3 Oct 2024 14:12:48 +0100 Subject: [PATCH] chore(spark): enable spotless scala Add lint checking to the Scala code in the Spark module Signed-off-by: Andrew Coleman --- spark/.scalafmt.conf | 70 +++++++++++++++++++ spark/build.gradle.kts | 7 ++ .../io/substrait/spark/ToSubstraitType.scala | 5 +- .../spark/expression/FunctionConverter.scala | 3 +- .../IgnoreNullableAndParameters.scala | 12 ++-- .../spark/expression/ToSparkExpression.scala | 57 ++++++++------- .../expression/ToSubstraitExpression.scala | 2 +- .../spark/expression/ToSubstraitLiteral.scala | 3 +- .../spark/logical/ToLogicalPlan.scala | 26 ++++--- .../spark/logical/ToSubstraitRel.scala | 31 +++++--- .../scala/io/substrait/spark/TPCDSPlan.scala | 3 +- .../scala/io/substrait/spark/TPCHPlan.scala | 1 + 12 files changed, 164 insertions(+), 56 deletions(-) create mode 100644 spark/.scalafmt.conf diff --git a/spark/.scalafmt.conf b/spark/.scalafmt.conf new file mode 100644 index 000000000..fc6b14ef6 --- /dev/null +++ b/spark/.scalafmt.conf @@ -0,0 +1,70 @@ +runner.dialect = scala212 + +# Version is required to make sure IntelliJ picks the right version +version = 3.7.3 +preset = default + +# Max column +maxColumn = 100 + +# This parameter simply says the .stripMargin method was not redefined by the user to assign +# special meaning to indentation preceding the | character. Hence, that indentation can be modified. +assumeStandardLibraryStripMargin = true +align.stripMargin = true + +# Align settings +align.preset = none +align.closeParenSite = false +align.openParenCallSite = false +danglingParentheses.defnSite = false +danglingParentheses.callSite = false +danglingParentheses.ctrlSite = true +danglingParentheses.tupleSite = false +align.openParenCallSite = false +align.openParenDefnSite = false +align.openParenTupleSite = false + +# Newlines +newlines.alwaysBeforeElseAfterCurlyIf = false +newlines.beforeCurlyLambdaParams = multiline # Newline before lambda params +newlines.afterCurlyLambdaParams = squash # No newline after lambda params +newlines.inInterpolation = "avoid" +newlines.avoidInResultType = true +optIn.annotationNewlines = true + +# Scaladoc +docstrings.style = Asterisk # Javadoc style +docstrings.removeEmpty = true +docstrings.oneline = fold +docstrings.forceBlankLineBefore = true + +# Indentation +indent.extendSite = 2 # This makes sure extend is not indented as the ctor parameters + +# Rewrites +rewrite.rules = [AvoidInfix, Imports, RedundantBraces, SortModifiers] + +# Imports +rewrite.imports.sort = scalastyle +rewrite.imports.groups = [ + ["io.substrait.spark\\..*"], + ["org.apache.spark\\..*"], + [".*"], + ["javax\\..*"], + ["java\\..*"], + ["scala\\..*"] +] +rewrite.imports.contiguousGroups = no +importSelectors = singleline # Imports in a single line, like IntelliJ + +# Remove redundant braces in string interpolation. 
+rewrite.redundantBraces.stringInterpolation = true +rewrite.redundantBraces.defnBodies = false +rewrite.redundantBraces.generalExpressions = false +rewrite.redundantBraces.ifElseExpressions = false +rewrite.redundantBraces.methodBodies = false +rewrite.redundantBraces.includeUnitMethods = false +rewrite.redundantBraces.maxBreaks = 1 + +# Remove trailing commas +rewrite.trailingCommas.style = "never" diff --git a/spark/build.gradle.kts b/spark/build.gradle.kts index 54c320bc9..25cd28430 100644 --- a/spark/build.gradle.kts +++ b/spark/build.gradle.kts @@ -105,6 +105,13 @@ dependencies { testImplementation("org.apache.spark:spark-catalyst_2.12:${SPARK_VERSION}:tests") } +spotless { + scala { + scalafmt().configFile(".scalafmt.conf") + toggleOffOn() + } +} + tasks { test { dependsOn(":core:shadowJar") diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala index 9522042ee..536facc68 100644 --- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala +++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala @@ -16,11 +16,12 @@ */ package io.substrait.spark +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.types._ + import io.substrait.`type`.{NamedStruct, Type, TypeVisitor} import io.substrait.function.TypeExpression import io.substrait.utils.Util -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.types._ import scala.collection.JavaConverters import scala.collection.JavaConverters.asScalaBufferConverter diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala index 7b0c2b473..e32e5e583 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala @@ -16,6 +16,8 @@ */ package io.substrait.spark.expression +import io.substrait.spark.ToSubstraitType + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} @@ -29,7 +31,6 @@ import io.substrait.expression.{Expression => SExpression, ExpressionCreator, Fu import io.substrait.expression.Expression.FailureBehavior import io.substrait.extension.SimpleExtension import io.substrait.function.{ParameterizedType, ToTypeString} -import io.substrait.spark.ToSubstraitType import io.substrait.utils.Util import java.{util => ju} diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala index 00076ce63..03c8252fb 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -56,10 +56,12 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) typeToMatch.isInstanceOf[Type.IntervalYear] override def visit(`type`: Type.IntervalDay): Boolean = - typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch.isInstanceOf[ParameterizedType.IntervalDay] + typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch + .isInstanceOf[ParameterizedType.IntervalDay] override def visit(`type`: Type.IntervalCompound): Boolean = - 
typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch.isInstanceOf[ParameterizedType.IntervalCompound] + typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch + .isInstanceOf[ParameterizedType.IntervalCompound] override def visit(`type`: Type.UUID): Boolean = typeToMatch.isInstanceOf[Type.UUID] @@ -109,11 +111,13 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) @throws[RuntimeException] override def visit(expr: ParameterizedType.IntervalDay): Boolean = - typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch.isInstanceOf[ParameterizedType.IntervalDay] + typeToMatch.isInstanceOf[Type.IntervalDay] || typeToMatch + .isInstanceOf[ParameterizedType.IntervalDay] @throws[RuntimeException] override def visit(expr: ParameterizedType.IntervalCompound): Boolean = - typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch.isInstanceOf[ParameterizedType.IntervalCompound] + typeToMatch.isInstanceOf[Type.IntervalCompound] || typeToMatch + .isInstanceOf[ParameterizedType.IntervalCompound] @throws[RuntimeException] override def visit(expr: ParameterizedType.Struct): Boolean = diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index 2279d5496..54b472d3b 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -18,14 +18,16 @@ package io.substrait.spark.expression import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType} import io.substrait.spark.logical.ToLogicalPlan + import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery} import org.apache.spark.sql.types.Decimal +import org.apache.spark.substrait.SparkTypeUtil import org.apache.spark.unsafe.types.UTF8String + import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} import io.substrait.util.DecimalUtil -import org.apache.spark.substrait.SparkTypeUtil import scala.collection.JavaConverters.asScalaBufferConverter @@ -132,31 +134,34 @@ class ToSparkExpression( } expr.declaration.name match { - case "make_decimal" if expr.declaration.uri == SparkExtension.uri => expr.outputType match { - // Need special case handing of this internal function. - // Because the precision and scale arguments are extracted from the output type, - // we can't use the generic scalar function conversion mechanism here. 
- case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale) - case _ => throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type") - } - case _ => scalarFunctionConverter - .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType()) - .flatMap(sig => Option(sig.makeCall(args))) - .getOrElse({ - val msg = String.format( - "Unable to convert scalar function %s(%s).", - expr.declaration.name, - expr.arguments.asScala - .map { - case ea: exp.EnumArg => ea.value.toString - case e: SExpression => e.getType.accept(new StringTypeVisitor) - case t: Type => t.accept(new StringTypeVisitor) - case a => throw new IllegalStateException("Unexpected value: " + a) - } - .mkString(", ") - ) - throw new IllegalArgumentException(msg) - }) + case "make_decimal" if expr.declaration.uri == SparkExtension.uri => + expr.outputType match { + // Need special case handing of this internal function. + // Because the precision and scale arguments are extracted from the output type, + // we can't use the generic scalar function conversion mechanism here. + case d: Type.Decimal => MakeDecimal(args.head, d.precision, d.scale) + case _ => + throw new IllegalArgumentException("Output type of MakeDecimal must be a decimal type") + } + case _ => + scalarFunctionConverter + .getSparkExpressionFromSubstraitFunc(expr.declaration().key(), expr.outputType()) + .flatMap(sig => Option(sig.makeCall(args))) + .getOrElse({ + val msg = String.format( + "Unable to convert scalar function %s(%s).", + expr.declaration.name, + expr.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) } } } diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala index 4048829b3..609c1f0a8 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala @@ -19,11 +19,11 @@ package io.substrait.spark.expression import io.substrait.spark.{HasOutputStack, ToSubstraitType} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.substrait.SparkTypeUtil import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FieldReference, ImmutableExpression} import io.substrait.expression.Expression.FailureBehavior import io.substrait.utils.Util -import org.apache.spark.substrait.SparkTypeUtil import scala.collection.JavaConverters.asJavaIterableConverter diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala index 95633c15b..73362e982 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitLiteral.scala @@ -16,13 +16,14 @@ */ package io.substrait.spark.expression +import io.substrait.spark.ToSubstraitType + import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import io.substrait.expression.{Expression => SExpression} import io.substrait.expression.ExpressionCreator._ -import 
io.substrait.spark.ToSubstraitType class ToSubstraitLiteral { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 3740babee..cdc54b2ec 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -18,6 +18,7 @@ package io.substrait.spark.logical import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType} import io.substrait.spark.expression._ + import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ @@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType} + import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} @@ -167,14 +169,14 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] if (limit >= 0) { val limitExpr = toLiteral(limit) if (offset > 0) { - GlobalLimit(limitExpr, - Offset(toLiteral(offset), - LocalLimit(toLiteral(offset + limit), child))) + GlobalLimit( + limitExpr, + Offset(toLiteral(offset), LocalLimit(toLiteral(offset + limit), child))) } else { GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) } } else { - Offset(toLiteral(offset), child) + Offset(toLiteral(offset), child) } } override def visit(sort: relation.Sort): LogicalPlan = { @@ -213,13 +215,16 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] withChild(child) { val projections = expand.getFields.asScala .map { - case sf: SwitchingField => sf.getDuplicates.asScala - .map(expr => expr.accept(expressionConverter)) - .map(toNamedExpression) - case _: ConsistentField => throw new UnsupportedOperationException("ConsistentField not currently supported") + case sf: SwitchingField => + sf.getDuplicates.asScala + .map(expr => expr.accept(expressionConverter)) + .map(toNamedExpression) + case _: ConsistentField => + throw new UnsupportedOperationException("ConsistentField not currently supported") } - val output = projections.head.zip(names) + val output = projections.head + .zip(names) .map { case (t, name) => StructField(name, t.dataType, t.nullable) } .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) @@ -240,7 +245,8 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] withOutput(children.flatMap(_.output)) { set.getSetOp match { case SetOp.UNION_ALL => Union(children, byName = false, allowMissingCol = false) - case op => throw new UnsupportedOperationException(s"Operation not currently supported: $op") + case op => + throw new UnsupportedOperationException(s"Operation not currently supported: $op") } } } diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 46d00f8cd..e6ce3a90c 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -18,6 +18,7 @@ package io.substrait.spark.logical import 
io.substrait.spark.{SparkExtension, ToSubstraitType} import io.substrait.spark.expression._ + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ @@ -27,10 +28,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.types.StructType + import ToSubstraitType.toNamedStruct import io.substrait.{proto, relation} import io.substrait.debug.TreePrinter -import io.substrait.expression.{ExpressionCreator, Expression => SExpression} +import io.substrait.expression.{Expression => SExpression, ExpressionCreator} import io.substrait.extension.ExtensionCollector import io.substrait.hint.Hint import io.substrait.plan.{ImmutablePlan, ImmutableRoot, Plan} @@ -40,6 +42,7 @@ import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles} import io.substrait.relation.files.FileOrFiles.PathType import java.util.{Collections, Optional} + import scala.collection.JavaConverters.asJavaIterableConverter import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -71,8 +74,11 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { val substraitExps = expression.aggregateFunction.children.map(toExpression(output)) val invocation = SparkExtension.toAggregateFunction.apply(expression, substraitExps) - val filter = expression.filter map toExpression(output) - relation.Aggregate.Measure.builder.function(invocation).preMeasureFilter(Optional.ofNullable(filter.orNull)).build() + val filter = expression.filter.map(toExpression(output)) + relation.Aggregate.Measure.builder + .function(invocation) + .preMeasureFilter(Optional.ofNullable(filter.orNull)) + .build() } private def collectAggregates( @@ -172,7 +178,8 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = { - relation.Fetch.builder() + relation.Fetch + .builder() .input(visit(child)) .offset(offset) .count(limit) @@ -183,7 +190,8 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { p match { case OffsetAndLimit((offset, limit, child)) => fetch(child, offset, limit) case GlobalLimit(IntegerLiteral(globalLimit), LocalLimit(IntegerLiteral(localLimit), child)) - if globalLimit == localLimit => fetch(child, 0, localLimit) + if globalLimit == localLimit => + fetch(child, 0, localLimit) case _ => throw new UnsupportedOperationException(s"Unable to convert the limit expression: $p") } @@ -251,11 +259,14 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } override def visitExpand(p: Expand): relation.Rel = { - val fields = p.projections.map(proj => { - relation.Expand.SwitchingField.builder.duplicates( - proj.map(toExpression(p.child.output)).asJava - ).build() - }) + val fields = p.projections.map( + proj => { + relation.Expand.SwitchingField.builder + .duplicates( + proj.map(toExpression(p.child.output)).asJava + ) + .build() + }) val names = p.output.map(_.name) diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index 826d7200c..c5bea8783 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -31,7 +31,7 @@ 
class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { spark.conf.set("spark.sql.readSideCharPadding", "false") } - // "q9" failed in spark 3.3 + // spotless:off val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", "q11", "q13", "q15", "q16", "q18", "q19", "q22", "q23a", "q23b", "q25", "q26", "q28", "q29", @@ -42,6 +42,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { "q71", "q76", "q79", "q81", "q82", "q85", "q88", "q90", "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q99") + // spotless:on tpcdsQueries.foreach { q => diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala index 76c5b9a6a..c96cd5531 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -17,6 +17,7 @@ package io.substrait.spark import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} + import org.apache.spark.sql.TPCHBase class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
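
Note: the toggleOffOn() call added to the spotless block in build.gradle.kts is what enables the "// spotless:off" / "// spotless:on" markers used in TPCDSPlan.scala, so scalafmt leaves any region between those markers untouched and hand-laid-out literals such as the query list keep their formatting. A minimal sketch of the idea follows; the object and member names (ToggleSketch, queries, contains) are illustrative only and do not appear in the patch.

object ToggleSketch {
  // spotless:off
  // Everything between the markers keeps its manual layout when Spotless
  // rewrites the rest of the file with scalafmt.
  val queries: Set[String] = Set(
    "q1", "q3", "q4",
    "q7", "q11", "q13")
  // spotless:on

  // Outside the markers, code is reflowed according to spark/.scalafmt.conf.
  def contains(name: String): Boolean = queries.contains(name)
}

Formatting can then be checked or applied with the standard Spotless Gradle tasks, for example ./gradlew :spark:spotlessCheck and ./gradlew :spark:spotlessApply.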