feat(spark): add support for DelimiterSeparatedTextReadOptions
The handling of CSV in the spark module works only in a very
limited way, with many hard-coded assumptions.
This commit adds full support for delimiter-separated text files as
defined in the `FileOrFiles` proto message.

Signed-off-by: Andrew Coleman <andrew_coleman@uk.ibm.com>
andrew-coleman committed Jan 30, 2025
1 parent a0ef1dd commit fb4bc88
Showing 15 changed files with 320 additions and 8 deletions.
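
For orientation (this block is not part of the diff), here is a minimal sketch of how the new format options can be built through the Immutables builders added in this commit. The option values are illustrative, not required defaults; the resulting object can be attached to a file entry via `ImmutableFileOrFiles.builder().fileFormat(...)`, as the Spark converter below does.

```scala
import io.substrait.relation.files.ImmutableFileFormat

// Hedged sketch: describe a pipe-delimited text file with an optional NULL marker.
// All values below are illustrative.
val textOptions = ImmutableFileFormat.DelimiterSeparatedTextReadOptions
  .builder()
  .fieldDelimiter("|")        // column separator
  .maxLineSize(1000)          // maximum line size
  .quote("'")                 // quoting character
  .headerLinesToSkip(1)       // treat the first line as a header
  .escape("\\")               // escape character
  .valueTreatedAsNull("NULL") // optional: this string is read back as null
  .build()
```
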
2 changes: 1 addition & 1 deletion .editorconfig
@@ -10,7 +10,7 @@ trim_trailing_whitespace = true
[*.{yaml,yml}]
indent_size = 2

- [{**/*.sql,**/OuterReferenceResolver.md,**gradlew.bat}]
[{**/*.sql,**/OuterReferenceResolver.md,**gradlew.bat,**/*.parquet,**/*.orc}]
charset = unset
end_of_line = unset
insert_final_newline = unset
12 changes: 11 additions & 1 deletion core/src/main/java/io/substrait/relation/ProtoRelConverter.java
@@ -311,7 +311,17 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) {
} else if (file.hasDwrf()) {
builder.fileFormat(ImmutableFileFormat.DwrfReadOptions.builder().build());
} else if (file.hasText()) {
- throw new RuntimeException("Delimiter separated text files not supported yet"); // TODO
var ffBuilder =
ImmutableFileFormat.DelimiterSeparatedTextReadOptions.builder()
.fieldDelimiter(file.getText().getFieldDelimiter())
.maxLineSize(file.getText().getMaxLineSize())
.quote(file.getText().getQuote())
.headerLinesToSkip(file.getText().getHeaderLinesToSkip())
.escape(file.getText().getEscape());
if (file.getText().hasValueTreatedAsNull()) {
ffBuilder.valueTreatedAsNull(file.getText().getValueTreatedAsNull());
}
builder.fileFormat(ffBuilder.build());
} else if (file.hasExtension()) {
builder.fileFormat(
ImmutableFileFormat.Extension.builder().extension(file.getExtension()).build());
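
For reference, a hedged sketch of the proto payload that the converter above now handles instead of throwing. The setter names mirror the fields read in the new branch; the values are illustrative.

```scala
import io.substrait.proto.ReadRel

// Hedged sketch: a FileOrFiles message carrying the `text` options that
// ProtoRelConverter now maps onto DelimiterSeparatedTextReadOptions.
val textProto = ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions
  .newBuilder()
  .setFieldDelimiter(",")
  .setMaxLineSize(20480L)
  .setQuote("\"")
  .setHeaderLinesToSkip(1L)
  .setEscape("\\")
  .setValueTreatedAsNull("NULL") // optional; guarded by hasValueTreatedAsNull()
  .build()

val fileProto = ReadRel.LocalFiles.FileOrFiles
  .newBuilder()
  .setText(textProto)
  .build()
```
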
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/relation/files/FileFormat.java
@@ -1,5 +1,6 @@
package io.substrait.relation.files;

import java.util.Optional;
import org.immutables.value.Value;

@Value.Enclosing
@@ -17,6 +18,21 @@ abstract static class OrcReadOptions implements FileFormat {}
@Value.Immutable
abstract static class DwrfReadOptions implements FileFormat {}

@Value.Immutable
abstract static class DelimiterSeparatedTextReadOptions implements FileFormat {
public abstract String getFieldDelimiter();

public abstract long getMaxLineSize();

public abstract String getQuote();

public abstract long getHeaderLinesToSkip();

public abstract String getEscape();

public abstract Optional<String> getValueTreatedAsNull();
}

@Value.Immutable
abstract static class Extension implements FileFormat {
public abstract com.google.protobuf.Any getExtension();
11 changes: 11 additions & 0 deletions core/src/main/java/io/substrait/relation/files/FileOrFiles.java
@@ -43,6 +43,17 @@ default ReadRel.LocalFiles.FileOrFiles toProto() {
} else if (fileFormat instanceof FileFormat.DwrfReadOptions options) {
builder.setDwrf(
ReadRel.LocalFiles.FileOrFiles.DwrfReadOptions.newBuilder().build());
} else if (fileFormat
instanceof FileFormat.DelimiterSeparatedTextReadOptions options) {
var optionsBuilder =
ReadRel.LocalFiles.FileOrFiles.DelimiterSeparatedTextReadOptions.newBuilder()
.setFieldDelimiter(options.getFieldDelimiter())
.setMaxLineSize(options.getMaxLineSize())
.setQuote(options.getQuote())
.setHeaderLinesToSkip(options.getHeaderLinesToSkip())
.setEscape(options.getEscape());
options.getValueTreatedAsNull().ifPresent(optionsBuilder::setValueTreatedAsNull);
builder.setText(optionsBuilder.build());
} else if (fileFormat instanceof FileFormat.Extension options) {
builder.setExtension(options.getExtension());
} else {
@@ -75,7 +75,14 @@ private ImmutableFileOrFiles.Builder setFileFormat(
case ARROW -> builder.fileFormat(ImmutableFileFormat.ArrowReadOptions.builder().build());
case ORC -> builder.fileFormat(ImmutableFileFormat.OrcReadOptions.builder().build());
case DWRF -> builder.fileFormat(ImmutableFileFormat.DwrfReadOptions.builder().build());
- case TEXT -> builder; // TODO
case TEXT -> builder.fileFormat(
ImmutableFileFormat.DelimiterSeparatedTextReadOptions.builder()
.fieldDelimiter("|")
.maxLineSize(1000)
.quote("\"")
.headerLinesToSkip(1)
.escape("\\")
.build());
case EXTENSION -> builder.fileFormat(
ImmutableFileFormat.Extension.builder()
.extension(com.google.protobuf.Any.newBuilder().build())
@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType}

import io.substrait.`type`.{StringTypeVisitor, Type}
@@ -40,6 +42,7 @@ import io.substrait.relation
import io.substrait.relation.Expand.{ConsistentField, SwitchingField}
import io.substrait.relation.LocalFiles
import io.substrait.relation.Set.SetOp
import io.substrait.relation.files.FileFormat
import org.apache.hadoop.fs.Path

import scala.collection.JavaConverters.asScalaBufferConverter
@@ -331,6 +334,34 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
override def visit(localFiles: LocalFiles): LogicalPlan = {
val schema = ToSubstraitType.toStructType(localFiles.getInitialSchema)
val output = schema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

// Spark requires that all files have the same format
val formats = localFiles.getItems.asScala.map(i => i.getFileFormat.orElse(null)).distinct
if (formats.length != 1) {
throw new UnsupportedOperationException("All files must have the same format")
}
val (format, options) = formats.head match {
case csv: FileFormat.DelimiterSeparatedTextReadOptions =>
val opts = scala.collection.mutable.Map[String, String](
"delimiter" -> csv.getFieldDelimiter,
"quote" -> csv.getQuote,
"header" -> (csv.getHeaderLinesToSkip match {
case 0 => "false"
case 1 => "true"
case _ =>
throw new UnsupportedOperationException(
s"Cannot configure CSV reader to skip ${csv.getHeaderLinesToSkip} rows")
}),
"escape" -> csv.getEscape,
"maxColumns" -> csv.getMaxLineSize.toString
)
csv.getValueTreatedAsNull.ifPresent(nullValue => opts("nullValue") = nullValue)
(new CSVFileFormat, opts.toMap)
case _: FileFormat.ParquetReadOptions => (new ParquetFileFormat(), Map.empty[String, String])
case _: FileFormat.OrcReadOptions => (new OrcFileFormat(), Map.empty[String, String])
case format =>
throw new UnsupportedOperationException(s"File format not currently supported: $format")
}
new LogicalRelation(
relation = HadoopFsRelation(
location = new InMemoryFileIndex(
@@ -341,8 +372,8 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
partitionSchema = new StructType(),
dataSchema = schema,
bucketSpec = None,
- fileFormat = new CSVFileFormat(),
- options = Map()
fileFormat = format,
options = options
)(spark),
output = output,
catalogTable = None,
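
Roughly speaking, the mapping above makes a Substrait text file behave like the following direct Spark read. This is a hedged sketch: the path and option values are illustrative, and the option names (`delimiter`, `quote`, `header`, `escape`, `maxColumns`, `nullValue`) are exactly those populated by the match block above.

```scala
import org.apache.spark.sql.SparkSession

// Hedged sketch of the DataFrameReader call equivalent to a CSV LocalFiles entry.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.read
  .option("delimiter", "|")     // fieldDelimiter
  .option("quote", "'")         // quote
  .option("header", "true")     // headerLinesToSkip == 1
  .option("escape", "\\")       // escape
  .option("maxColumns", "1000") // maxLineSize, as mapped above
  .option("nullValue", "NULL")  // valueTreatedAsNull, when present
  .csv("/path/to/dataset.txt")  // illustrative path
```
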
@@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.types.{NullType, StructType}

@@ -378,8 +381,25 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
private def buildLocalFileScan(fsRelation: HadoopFsRelation): relation.AbstractReadRel = {
val namedStruct = toNamedStruct(fsRelation.schema)

- val ff = new FileFormat.ParquetReadOptions {
- override def toString: String = "csv" // TODO this is hardcoded at the moment
val format = fsRelation.fileFormat match {
case _: CSVFileFormat =>
new FileFormat.DelimiterSeparatedTextReadOptions {
// default values for options specified here:
// https://spark.apache.org/docs/latest/sql-data-sources-csv.html#data-source-option
override def getFieldDelimiter: String = fsRelation.options.getOrElse("delimiter", ",")
override def getMaxLineSize: Long =
fsRelation.options.getOrElse("maxColumns", "20480").toLong
override def getQuote: String = fsRelation.options.getOrElse("quote", "\"")
override def getHeaderLinesToSkip: Long =
if (fsRelation.options.getOrElse("header", "false").toBoolean) 1 else 0
override def getEscape: String = fsRelation.options.getOrElse("escape", "\\")
override def getValueTreatedAsNull: Optional[String] =
Optional.ofNullable(fsRelation.options.get("nullValue").orNull)
}
case _: ParquetFileFormat => new FileFormat.ParquetReadOptions {}
case _: OrcFileFormat => new FileFormat.OrcReadOptions {}
case format =>
throw new UnsupportedOperationException(s"File format not currently supported: $format")
}

relation.LocalFiles
@@ -391,7 +411,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
file => {
ImmutableFileOrFiles
.builder()
- .fileFormat(ff)
.fileFormat(format)
.partitionIndex(0)
.start(0)
.length(fsRelation.sizeInBytes)
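
To tie this to usage, a hedged sketch of the reverse direction: converting an optimized Spark CSV scan into a Substrait plan with `ToSubstraitRel.convert`, the same entry point the tests below use. With no reader options set, the overrides above fall back to Spark's documented CSV defaults (delimiter ",", maxColumns 20480, quote the double-quote character, no header, escape backslash), and valueTreatedAsNull stays empty. The path is illustrative.

```scala
import io.substrait.spark.logical.ToSubstraitRel
import org.apache.spark.sql.SparkSession

// Hedged sketch: a CSV scan with default options becomes a Substrait LocalFiles
// relation carrying DelimiterSeparatedTextReadOptions with Spark's defaults.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df = spark.read.csv("/path/to/data.csv") // illustrative path
val substraitPlan = new ToSubstraitRel().convert(df.queryExecution.optimizedPlan)
```
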
11 changes: 11 additions & 0 deletions spark/src/test/resources/csv/dataset-a.csv
@@ -0,0 +1,11 @@
ID,VALUE
1,one
2,two
3,three
4,four
5,five
6,six
7,seven
8,eight
9,nine
10,ten
11 changes: 11 additions & 0 deletions spark/src/test/resources/csv/dataset-b.csv
@@ -0,0 +1,11 @@
ID,VALUE
11,eleven
12,twelve
13,thirteen
14,fourteen
15,fifteen
16,sixteen
17,seventeen
18,eighteen
19,nineteen
20,twenty
11 changes: 11 additions & 0 deletions spark/src/test/resources/dataset-a.csv
@@ -0,0 +1,11 @@
ID,VALUE
1,one
2,"two"
3,"three"
4,"fo,ur"
5,five
6,six
7,seven
8,eight
9,nine
10,ten
Binary file added spark/src/test/resources/dataset-a.orc
Binary file not shown.
Binary file added spark/src/test/resources/dataset-a.parquet
Binary file not shown.
10 changes: 10 additions & 0 deletions spark/src/test/resources/dataset-a.txt
@@ -0,0 +1,10 @@
1|one
2|two
3|three
4|'fo|ur'
5|five
6|'six'
7|seven
8|eight
9|nine
10|ten
143 changes: 143 additions & 0 deletions spark/src/test/scala/io/substrait/spark/LocalFiles.scala
@@ -0,0 +1,143 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.substrait.spark

import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}

import org.apache.spark.sql.{Dataset, DatasetUtil, Row}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

import io.substrait.plan.{PlanProtoConverter, ProtoPlanConverter}

import java.nio.file.Paths

class LocalFiles extends SharedSparkSession {
override def beforeAll(): Unit = {
super.beforeAll()
sparkContext.setLogLevel("WARN")

conf.setConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED, false)
// 'spark.sql.readSideCharPadding' was introduced in Spark 3.4
spark.conf.set("spark.sql.readSideCharPadding", "false")
}

def assertRoundTrip(data: Dataset[Row]): Dataset[Row] = {
val toSubstrait = new ToSubstraitRel
val sparkPlan = data.queryExecution.optimizedPlan
val substraitPlan = toSubstrait.convert(sparkPlan)

// Serialize to proto buffer
val bytes = new PlanProtoConverter()
.toProto(substraitPlan)
.toByteArray

// Read it back
val protoPlan = io.substrait.proto.Plan
.parseFrom(bytes)
val substraitPlan2 = new ProtoPlanConverter().from(protoPlan)

val sparkPlan2 = new ToLogicalPlan(spark).convert(substraitPlan2)
val result = DatasetUtil.fromLogicalPlan(spark, sparkPlan2)

assertResult(data.columns)(result.columns)
assertResult(data.count)(result.count)
data.collect().zip(result.collect()).foreach {
case (before, after) => assertResult(before)(after)
}
result
}

test("CSV with header") {
val table = spark.read
.option("header", true)
.option("inferSchema", true)
.csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString)

assertRoundTrip(table)
}

test("CSV null value") {
val table = spark.read
.option("header", true)
.option("inferSchema", true)
.option("nullValue", "seven")
.csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString)

val result = assertRoundTrip(table)
val id = result.filter("isnull(VALUE)").head().get(0)

assertResult(7)(id)
}

test("Pipe delimited values") {
val schema = StructType(
StructField("ID", IntegerType, false) ::
StructField("VALUE", StringType, true) :: Nil)

val table: Dataset[Row] = spark.read
.schema(schema)
.option("delimiter", "|")
.option("quote", "'")
.csv(Paths.get("src/test/resources/dataset-a.txt").toAbsolutePath.toString)

assertRoundTrip(table)
}

test("Read csv folder") {
val table = spark.read
.option("header", true)
.option("inferSchema", true)
.csv(Paths.get("src/test/resources/csv/").toAbsolutePath.toString)

assertRoundTrip(table)
}

test("Read parquet file") {
val table = spark.read
.parquet(Paths.get("src/test/resources/dataset-a.parquet").toAbsolutePath.toString)

assertRoundTrip(table)
}

test("Read orc file") {
val table = spark.read
.orc(Paths.get("src/test/resources/dataset-a.orc").toAbsolutePath.toString)

assertRoundTrip(table)
}

test("Join tables from different formats") {
val csv = spark.read
.option("header", true)
.option("inferSchema", true)
.csv(Paths.get("src/test/resources/dataset-a.csv").toAbsolutePath.toString)

val orc = spark.read
.orc(Paths.get("src/test/resources/dataset-a.orc").toAbsolutePath.toString)
.withColumnRenamed("ID", "ID_B")
.withColumnRenamed("VALUE", "VALUE_B");

val both = csv
.join(orc, csv.col("ID").equalTo(orc.col("ID_B")))
.select("ID", "VALUE", "VALUE_B")

assertRoundTrip(both)
}
}