diff --git a/.editorconfig b/.editorconfig
index ce1567087..998c28f6c 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -10,7 +10,7 @@ trim_trailing_whitespace = true
 [*.{yaml,yml}]
 indent_size = 2
 
-[{**/*.sql,**/OuterReferenceResolver.md,**gradlew.bat}]
+[{**/*.sql,**/OuterReferenceResolver.md,**gradlew.bat,**/*.parquet,**/*.orc}]
 charset = unset
 end_of_line = unset
 insert_final_newline = unset
diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java
index c9bc60705..130e3aa4d 100644
--- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java
+++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java
@@ -311,7 +311,17 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) {
     } else if (file.hasDwrf()) {
       builder.fileFormat(ImmutableFileFormat.DwrfReadOptions.builder().build());
     } else if (file.hasText()) {
-      throw new RuntimeException("Delimiter separated text files not supported yet"); // TODO
+      var ffBuilder =
+          ImmutableFileFormat.DelimiterSeparatedTextReadOptions.builder()
+              .fieldDelimiter(file.getText().getFieldDelimiter())
+              .maxLineSize(file.getText().getMaxLineSize())
+              .quote(file.getText().getQuote())
+              .headerLinesToSkip(file.getText().getHeaderLinesToSkip())
+              .escape(file.getText().getEscape());
+      if (file.getText().hasValueTreatedAsNull()) {
+        ffBuilder.valueTreatedAsNull(file.getText().getValueTreatedAsNull());
+      }
+      builder.fileFormat(ffBuilder.build());
     } else if (file.hasExtension()) {
       builder.fileFormat(
           ImmutableFileFormat.Extension.builder().extension(file.getExtension()).build());
diff --git a/core/src/main/java/io/substrait/relation/files/FileFormat.java b/core/src/main/java/io/substrait/relation/files/FileFormat.java
index 0f3524da6..b8632867a 100644
--- a/core/src/main/java/io/substrait/relation/files/FileFormat.java
+++ b/core/src/main/java/io/substrait/relation/files/FileFormat.java
@@ -1,5 +1,6 @@
 package io.substrait.relation.files;
 
+import java.util.Optional;
 import org.immutables.value.Value;
 
 @Value.Enclosing
@@ -17,6 +18,21 @@ abstract static class OrcReadOptions implements FileFormat {}
   @Value.Immutable
   abstract static class DwrfReadOptions implements FileFormat {}
 
+  @Value.Immutable
+  abstract static class DelimiterSeparatedTextReadOptions implements FileFormat {
+    public abstract String getFieldDelimiter();
+
+    public abstract long getMaxLineSize();
+
+    public abstract String getQuote();
+
+    public abstract long getHeaderLinesToSkip();
+
+    public abstract String getEscape();
+
+    public abstract Optional<String> getValueTreatedAsNull();
+  }
+
   @Value.Immutable
   abstract static class Extension implements FileFormat {
     public abstract com.google.protobuf.Any getExtension();
diff --git a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java
index 87dab1353..2dbb54ac5 100644
--- a/core/src/main/java/io/substrait/relation/files/FileOrFiles.java
+++ b/core/src/main/java/io/substrait/relation/files/FileOrFiles.java
@@ -43,6 +43,17 @@ default ReadRel.LocalFiles.FileOrFiles toProto() {
           } else if (fileFormat instanceof FileFormat.DwrfReadOptions options) {
             builder.setDwrf(
                 ReadRel.LocalFiles.FileOrFiles.DwrfReadOptions.newBuilder().build());
+          } else if (fileFormat
+              instanceof FileFormat.DelimiterSeparatedTextReadOptions options) {
+            var optionsBuilder =
+                ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.newBuilder()
+                    .setFieldDelimiter(options.getFieldDelimiter())
+                    .setMaxLineSize(options.getMaxLineSize())
+                    .setQuote(options.getQuote())
+                    .setHeaderLinesToSkip(options.getHeaderLinesToSkip())
+                    .setEscape(options.getEscape());
+            options.getValueTreatedAsNull().ifPresent(optionsBuilder::setValueTreatedAsNull);
+            builder.setText(optionsBuilder.build());
           } else if (fileFormat instanceof FileFormat.Extension options) {
             builder.setExtension(options.getExtension());
           } else {
diff --git a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java
index 7166eef6c..be40d4905 100644
--- a/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/LocalFilesRoundtripTest.java
@@ -75,7 +75,14 @@ private ImmutableFileOrFiles.Builder setFileFormat(
       case ARROW -> builder.fileFormat(ImmutableFileFormat.ArrowReadOptions.builder().build());
       case ORC -> builder.fileFormat(ImmutableFileFormat.OrcReadOptions.builder().build());
       case DWRF -> builder.fileFormat(ImmutableFileFormat.DwrfReadOptions.builder().build());
-      case TEXT -> builder; // TODO
+      case TEXT -> builder.fileFormat(
+          ImmutableFileFormat.DelimiterSeparatedTextReadOptions.builder()
+              .fieldDelimiter("|")
+              .maxLineSize(1000)
+              .quote("\"")
+              .headerLinesToSkip(1)
+              .escape("\\")
+              .build());
       case EXTENSION -> builder.fileFormat(
           ImmutableFileFormat.Extension.builder()
               .extension(com.google.protobuf.Any.newBuilder().build())
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
index 68b15345a..200cdeb8d 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType}
 
 import io.substrait.`type`.{StringTypeVisitor, Type}
@@ -40,6 +42,7 @@ import io.substrait.relation
 import io.substrait.relation.Expand.{ConsistentField, SwitchingField}
 import io.substrait.relation.LocalFiles
 import io.substrait.relation.Set.SetOp
+import io.substrait.relation.files.FileFormat
 import org.apache.hadoop.fs.Path
 
 import scala.collection.JavaConverters.asScalaBufferConverter
@@ -331,6 +334,34 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
   override def visit(localFiles: LocalFiles): LogicalPlan = {
     val schema = ToSubstraitType.toStructType(localFiles.getInitialSchema)
     val output = schema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
+
+    // spark requires that all files have the same format
+    val formats = localFiles.getItems.asScala.map(i => i.getFileFormat.orElse(null)).distinct
+    if (formats.length != 1) {
+      throw new UnsupportedOperationException(s"All files must have the same format")
+    }
+    val (format, options) = formats.head match {
+      case csv: FileFormat.DelimiterSeparatedTextReadOptions =>
+        val opts = scala.collection.mutable.Map[String, String](
+          "delimiter" -> csv.getFieldDelimiter,
+          "quote" -> csv.getQuote,
+          "header" -> (csv.getHeaderLinesToSkip match {
+            case 0 => "false"
+            case 1 => "true"
+            case _ =>
+              throw new UnsupportedOperationException(
+                s"Cannot configure CSV reader to skip ${csv.getHeaderLinesToSkip} rows")
+          }),
+          "escape" -> csv.getEscape,
+          "maxColumns" -> csv.getMaxLineSize.toString
+        )
+        csv.getValueTreatedAsNull.ifPresent(nullValue => opts("nullValue") = nullValue)
+        (new CSVFileFormat, opts.toMap)
+      case _: FileFormat.ParquetReadOptions => (new ParquetFileFormat(), Map.empty[String, String])
+      case _: FileFormat.OrcReadOptions => (new OrcFileFormat(), Map.empty[String, String])
+      case format =>
+        throw new UnsupportedOperationException(s"File format not currently supported: $format")
+    }
     new LogicalRelation(
       relation = HadoopFsRelation(
         location = new InMemoryFileIndex(
@@ -341,8 +372,8 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
         partitionSchema = new StructType(),
         dataSchema = schema,
         bucketSpec = None,
-        fileFormat = new CSVFileFormat(),
-        options = Map()
+        fileFormat = format,
+        options = options
       )(spark),
       output = output,
       catalogTable = None,
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
index 2827a9c30..5770bac33 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
 import org.apache.spark.sql.types.{NullType, StructType}
 
@@ -378,8 +381,25 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
 
   private def buildLocalFileScan(fsRelation: HadoopFsRelation): relation.AbstractReadRel = {
     val namedStruct = toNamedStruct(fsRelation.schema)
-    val ff = new FileFormat.ParquetReadOptions {
-      override def toString: String = "csv" // TODO this is hardcoded at the moment
+    val format = fsRelation.fileFormat match {
+      case _: CSVFileFormat =>
+        new FileFormat.DelimiterSeparatedTextReadOptions {
+          // default values for options specified here:
+          // https://spark.apache.org/docs/latest/sql-data-sources-csv.html#data-source-option
+          override def getFieldDelimiter: String = fsRelation.options.getOrElse("delimiter", ",")
+          override def getMaxLineSize: Long =
+            fsRelation.options.getOrElse("maxColumns", "20480").toLong
+          override def getQuote: String = fsRelation.options.getOrElse("quote", "\"")
+          override def getHeaderLinesToSkip: Long =
+            if (fsRelation.options.getOrElse("header", false) == false) 0 else 1
+          override def getEscape: String = fsRelation.options.getOrElse("escape", "\\")
+          override def getValueTreatedAsNull: Optional[String] =
+            Optional.ofNullable(fsRelation.options.get("nullValue").orNull)
+        }
+      case _: ParquetFileFormat => new FileFormat.ParquetReadOptions {}
+      case _: OrcFileFormat => new FileFormat.OrcReadOptions {}
+      case format =>
+        throw new UnsupportedOperationException(s"File format not currently supported: $format")
     }
 
     relation.LocalFiles
@@ -391,7 +411,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
         file => {
           ImmutableFileOrFiles
             .builder()
-            .fileFormat(ff)
+            .fileFormat(format)
            .partitionIndex(0)
            .start(0)
            .length(fsRelation.sizeInBytes)
diff --git a/spark/src/test/resources/csv/dataset-a.csv b/spark/src/test/resources/csv/dataset-a.csv
new file mode 100644
index 000000000..79ddefed3
--- /dev/null
+++ b/spark/src/test/resources/csv/dataset-a.csv
@@ -0,0 +1,11 @@
+ID,VALUE
+1,one
+2,two
+3,three
+4,four
+5,five
+6,six
+7,seven
+8,eight
+9,nine
+10,ten
diff --git a/spark/src/test/resources/csv/dataset-b.csv b/spark/src/test/resources/csv/dataset-b.csv
new file mode 100644
index 000000000..7b94f0896
--- /dev/null
+++ b/spark/src/test/resources/csv/dataset-b.csv
@@ -0,0 +1,11 @@
+ID,VALUE
+11,eleven
+12,twelve
+13,thirteen
+14,fourteen
+15,fifteen
+16,sixteen
+17,seventeen
+18,eighteen
+19,nineteen
+20,twenty
diff --git a/spark/src/test/resources/dataset-a.csv b/spark/src/test/resources/dataset-a.csv
new file mode 100644
index 000000000..e868182ad
--- /dev/null
+++ b/spark/src/test/resources/dataset-a.csv
@@ -0,0 +1,11 @@
+ID,VALUE
+1,one
+2,"two"
+3,"three"
+4,"fo,ur"
+5,five
+6,six
+7,seven
+8,eight
+9,nine
+10,ten
diff --git a/spark/src/test/resources/dataset-a.orc b/spark/src/test/resources/dataset-a.orc
new file mode 100644
index 000000000..847557119
Binary files /dev/null and b/spark/src/test/resources/dataset-a.orc differ
diff --git a/spark/src/test/resources/dataset-a.parquet b/spark/src/test/resources/dataset-a.parquet
new file mode 100644
index 000000000..fec9fcf57
Binary files /dev/null and b/spark/src/test/resources/dataset-a.parquet differ
diff --git a/spark/src/test/resources/dataset-a.txt b/spark/src/test/resources/dataset-a.txt
new file mode 100644
index 000000000..d6dd3cf7b
--- /dev/null
+++ b/spark/src/test/resources/dataset-a.txt
@@ -0,0 +1,10 @@
+1|one
+2|two
+3|three
+4|'fo|ur'
+5|five
+6|'six'
+7|seven
+8|eight
+9|nine
+10|ten
diff --git a/spark/src/test/scala/io/substrait/spark/LocalFiles.scala b/spark/src/test/scala/io/substrait/spark/LocalFiles.scala
new file mode 100644
index 000000000..b0f8cf418
--- /dev/null
+++ b/spark/src/test/scala/io/substrait/spark/LocalFiles.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.substrait.spark
+
+import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}
+
+import org.apache.spark.sql.{Dataset, DatasetUtil, Row}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+
+import io.substrait.plan.{PlanProtoConverter, ProtoPlanConverter}
+
+import java.nio.file.Paths
+
+class LocalFiles extends SharedSparkSession {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sparkContext.setLogLevel("WARN")
+
+    conf.setConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED, false)
+    // introduced in spark 3.4
+    spark.conf.set("spark.sql.readSideCharPadding", "false")
+  }
+
+  def assertRoundTrip(data: Dataset[Row]): Dataset[Row] = {
+    val toSubstrait = new ToSubstraitRel
+    val sparkPlan = data.queryExecution.optimizedPlan
+    val substraitPlan = toSubstrait.convert(sparkPlan)
+
+    // Serialize to proto buffer
+    val bytes = new PlanProtoConverter()
+      .toProto(substraitPlan)
+      .toByteArray
+
+    // Read it back
+    val protoPlan = io.substrait.proto.Plan
+      .parseFrom(bytes)
+    val substraitPlan2 = new ProtoPlanConverter().from(protoPlan)
+
+    val sparkPlan2 = new ToLogicalPlan(spark).convert(substraitPlan2)
+    val result = DatasetUtil.fromLogicalPlan(spark, sparkPlan2)
+
+    assertResult(data.columns)(result.columns)
+    assertResult(data.count)(result.count)
+    data.collect().zip(result.collect()).foreach {
+      case (before, after) => assertResult(before)(after)
+    }
+    result
+  }
+
+  test("CSV with header") {
+    val table = spark.read
+      .option("header", true)
+      .option("inferSchema", true)
+      .csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString)
+
+    assertRoundTrip(table)
+  }
+
+  test("CSV null value") {
+    val table = spark.read
+      .option("header", true)
+      .option("inferSchema", true)
+      .option("nullValue", "seven")
+      .csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString)
+
+    val result = assertRoundTrip(table)
+    val id = result.filter("isnull(VALUE)").head().get(0)
+
+    assertResult(id)(7)
+  }
+
+  test("Pipe delimited values") {
+    val schema = StructType(
+      StructField("ID", IntegerType, false) ::
+        StructField("VALUE", StringType, true) :: Nil)
+
+    val table: Dataset[Row] = spark.read
+      .schema(schema)
+      .option("delimiter", "|")
+      .option("quote", "'")
+      .csv(Paths.get("src/test/resources/dataset-a.txt").toAbsolutePath.toString)
+
+    assertRoundTrip(table)
+  }
+
+  test("Read csv folder") {
+    val table = spark.read
+      .option("header", true)
+      .option("inferSchema", true)
+      .csv(Paths.get("src/test/resources/csv/").toAbsolutePath.toString)
+
+    assertRoundTrip(table)
+  }
+
+  test("Read parquet file") {
+    val table = spark.read
+      .parquet(Paths.get("src/test/resources/dataset-a.parquet").toAbsolutePath.toString)
+
+    assertRoundTrip(table)
+  }
+
+  test("Read orc file") {
+    val table = spark.read
+      .orc(Paths.get("src/test/resources/dataset-a.orc").toAbsolutePath.toString)
+
+    assertRoundTrip(table)
+  }
+
+  test("Join tables from different formats") {
+    val csv = spark.read
+      .option("header", true)
+      .option("inferSchema", true)
+      .csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString)
+
+    val orc = spark.read
+      .orc(Paths.get("src/test/resources/dataset-a.orc").toAbsolutePath.toString)
+      .withColumnRenamed("ID", "ID_B")
+      .withColumnRenamed("VALUE", "VALUE_B");
+
+    val both = csv
+      .join(orc, csv.col("ID").equalTo(orc.col("ID_B")))
+      .select("ID", "VALUE", "VALUE_B")
+
+    assertRoundTrip(both)
+  }
+}
diff --git a/spark/src/test/spark-3.2/org/apache/spark/sql/DatasetUtil.scala b/spark/src/test/spark-3.2/org/apache/spark/sql/DatasetUtil.scala
new file mode 100644
index 000000000..c56185b12
--- /dev/null
+++ b/spark/src/test/spark-3.2/org/apache/spark/sql/DatasetUtil.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+object DatasetUtil {
+  def fromLogicalPlan(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
+    sparkSession.withActive {
+      val qe = sparkSession.sessionState.executePlan(logicalPlan)
+      qe.assertAnalyzed()
+      new Dataset[Row](qe, RowEncoder(qe.analyzed.schema))
+    }
+  }
+}
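
For reference, a minimal usage sketch (not part of the change set) of the options type introduced in FileFormat.java. The builder calls mirror LocalFilesRoundtripTest and ProtoRelConverter above; the wrapping class name and the literal option values are illustrative assumptions only.

import io.substrait.relation.files.FileFormat;
import io.substrait.relation.files.ImmutableFileFormat;

// Hypothetical helper, shown only to illustrate the generated builder API.
public class DelimiterSeparatedTextExample {
  // Builds CSV read options for a pipe-delimited file with a single header row.
  public static FileFormat pipeDelimitedOptions() {
    return ImmutableFileFormat.DelimiterSeparatedTextReadOptions.builder()
        .fieldDelimiter("|")
        .maxLineSize(1000)
        .quote("\"")
        .headerLinesToSkip(1)
        .escape("\\")
        .valueTreatedAsNull("NULL") // optional attribute; leave unset to omit it from the proto
        .build();
  }
}

FileOrFiles.toProto() serializes these options into ReadRel.LocalFiles.FileOrFiles's text field, and ProtoRelConverter.newFileOrFiles() rebuilds the immutable form, which is the round trip the tests above exercise.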