From 22e23cffc837503afc544731ef126acea98c9880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sava=20Vrane=C5=A1evi=C4=87?= <20240220+svranesevic@users.noreply.github.com> Date: Tue, 11 Jun 2024 13:58:26 +0200 Subject: [PATCH] Fix propagation of CSV options through protos (#245) * Fix sink output schema being passed in to `FileSinkExec` instead of sink input schema * Expose double_quote csv option, and ensure all csv_options are propagated through logical/physical plans --------- Co-authored-by: svranesevic --- datafusion/common/src/config.rs | 1 + .../common/src/file_options/csv_writer.rs | 7 +- datafusion/proto/proto/datafusion.proto | 7 ++ datafusion/proto/src/generated/pbjson.rs | 70 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 12 ++++ datafusion/proto/src/logical_plan/mod.rs | 28 +++++++- .../proto/src/physical_plan/from_proto.rs | 1 + datafusion/proto/src/physical_plan/mod.rs | 12 ++-- .../proto/src/physical_plan/to_proto.rs | 1 + 9 files changed, 131 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 968d8215ca4d..6f917ee09d7f 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1520,6 +1520,7 @@ config_namespace! { pub delimiter: u8, default = b',' pub quote: u8, default = b'"' pub escape: Option, default = None + pub double_quote: bool, default = true pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED pub schema_infer_max_rec: usize, default = 100 pub date_format: Option, default = None diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 5f1a62682f8d..d10d2a737d49 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -51,8 +51,13 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { fn try_from(value: &CsvOptions) -> Result { let mut builder = WriterBuilder::default() .with_header(value.has_header) - .with_delimiter(value.delimiter); + .with_quote(value.quote) + .with_delimiter(value.delimiter) + .with_double_quote(value.double_quote); + if let Some(v) = &value.escape { + builder = builder.with_escape(*v) + } if let Some(v) = &value.date_format { builder = builder.with_date_format(v.into()) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e9d170f30851..dc520032bf82 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1255,6 +1255,12 @@ message CsvWriterOptions { string time_format = 7; // Optional value to represent null string null_value = 8; + // Optional quote. Defaults to `b'"'` + string quote = 9; + // Optional escape. Defaults to `'\\'` + string escape = 10; + // Optional flag whether to double quote instead of escaping. Defaults to `true` + bool double_quote = 11; } // Options controlling CSV format @@ -1271,6 +1277,7 @@ message CsvOptions { string timestamp_tz_format = 10; // Optional timestamp with timezone format string time_format = 11; // Optional time format string null_value = 12; // Optional representation of null value + bool double_quote = 13; // Indicates whether to use double quotes instead of escaping } // Options controlling CSV format diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 3b9bfb9750a1..92f919e1edab 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5498,6 +5498,9 @@ impl serde::Serialize for CsvOptions { if !self.null_value.is_empty() { len += 1; } + if self.double_quote { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CsvOptions", len)?; if self.has_header { struct_ser.serialize_field("hasHeader", &self.has_header)?; @@ -5541,6 +5544,9 @@ impl serde::Serialize for CsvOptions { if !self.null_value.is_empty() { struct_ser.serialize_field("nullValue", &self.null_value)?; } + if self.double_quote { + struct_ser.serialize_field("doubleQuote", &self.double_quote)?; + } struct_ser.end() } } @@ -5571,6 +5577,8 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "timeFormat", "null_value", "nullValue", + "double_quote", + "doubleQuote", ]; #[allow(clippy::enum_variant_names)] @@ -5587,6 +5595,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { TimestampTzFormat, TimeFormat, NullValue, + DoubleQuote, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5620,6 +5629,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { "timestampTzFormat" | "timestamp_tz_format" => Ok(GeneratedField::TimestampTzFormat), "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5651,6 +5661,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { let mut timestamp_tz_format__ = None; let mut time_format__ = None; let mut null_value__ = None; + let mut double_quote__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::HasHeader => { @@ -5733,6 +5744,12 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { } null_value__ = Some(map_.next_value()?); } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = Some(map_.next_value()?); + } } } Ok(CsvOptions { @@ -5748,6 +5765,7 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { timestamp_tz_format: timestamp_tz_format__.unwrap_or_default(), time_format: time_format__.unwrap_or_default(), null_value: null_value__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), }) } } @@ -6204,6 +6222,15 @@ impl serde::Serialize for CsvWriterOptions { if !self.null_value.is_empty() { len += 1; } + if !self.quote.is_empty() { + len += 1; + } + if !self.escape.is_empty() { + len += 1; + } + if self.double_quote { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; if self.compression != 0 { let v = CompressionTypeVariant::try_from(self.compression) @@ -6231,6 +6258,15 @@ impl serde::Serialize for CsvWriterOptions { if !self.null_value.is_empty() { struct_ser.serialize_field("nullValue", &self.null_value)?; } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if !self.escape.is_empty() { + struct_ser.serialize_field("escape", &self.escape)?; + } + if self.double_quote { + struct_ser.serialize_field("doubleQuote", &self.double_quote)?; + } struct_ser.end() } } @@ -6255,6 +6291,10 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { "timeFormat", "null_value", "nullValue", + "quote", + "escape", + "double_quote", + "doubleQuote", ]; #[allow(clippy::enum_variant_names)] @@ -6267,6 +6307,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { TimestampFormat, TimeFormat, NullValue, + Quote, + Escape, + DoubleQuote, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6296,6 +6339,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6323,6 +6369,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { let mut timestamp_format__ = None; let mut time_format__ = None; let mut null_value__ = None; + let mut quote__ = None; + let mut escape__ = None; + let mut double_quote__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Compression => { @@ -6373,6 +6422,24 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { } null_value__ = Some(map_.next_value()?); } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + escape__ = Some(map_.next_value()?); + } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = Some(map_.next_value()?); + } } } Ok(CsvWriterOptions { @@ -6384,6 +6451,9 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { timestamp_format: timestamp_format__.unwrap_or_default(), time_format: time_format__.unwrap_or_default(), null_value: null_value__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + escape: escape__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 4375c3226c84..95ca468b34ad 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1699,6 +1699,15 @@ pub struct CsvWriterOptions { /// Optional value to represent null #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, + /// Optional quote. Defaults to `b'"'` + #[prost(string, tag = "9")] + pub quote: ::prost::alloc::string::String, + /// Optional escape. Defaults to `'\\'` + #[prost(string, tag = "10")] + pub escape: ::prost::alloc::string::String, + /// Optional flag whether to double quote instead of escaping. Defaults to `true` + #[prost(bool, tag = "11")] + pub double_quote: bool, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] @@ -1740,6 +1749,9 @@ pub struct CsvOptions { /// Optional representation of null value #[prost(string, tag = "12")] pub null_value: ::prost::alloc::string::String, + /// Indicates whether to use double quotes instead of escaping + #[prost(bool, tag = "13")] + pub double_quote: bool, } /// Options controlling CSV format #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9b3b677e3c0a..2ebde11bb4df 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1626,6 +1626,9 @@ pub(crate) fn csv_writer_options_to_proto( timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), time_format: csv_options.time_format().unwrap_or("").to_owned(), null_value: csv_options.null().to_owned(), + quote: (csv_options.quote() as char).to_string(), + escape: (csv_options.escape() as char).to_string(), + double_quote: csv_options.double_quote(), } } @@ -1644,11 +1647,34 @@ pub(crate) fn csv_writer_options_from_proto( return Err(proto_error("Error parsing CSV Delimiter")); } } + if !writer_options.quote.is_empty() { + if let Some(quote) = writer_options.quote.chars().next() { + if quote.is_ascii() { + builder = builder.with_quote(quote as u8); + } else { + return Err(proto_error("CSV quote is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV quote")); + } + } + if !writer_options.escape.is_empty() { + if let Some(escape) = writer_options.escape.chars().next() { + if escape.is_ascii() { + builder = builder.with_escape(escape as u8); + } else { + return Err(proto_error("CSV escape is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV escape")); + } + } Ok(builder .with_header(writer_options.has_header) .with_date_format(writer_options.date_format.clone()) .with_datetime_format(writer_options.datetime_format.clone()) .with_timestamp_format(writer_options.timestamp_format.clone()) .with_time_format(writer_options.time_format.clone()) - .with_null(writer_options.null_value.clone())) + .with_null(writer_options.null_value.clone()) + .with_double_quote(writer_options.double_quote)) } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index f2c5b4b080b2..46e6c34d8bde 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -872,6 +872,7 @@ impl TryFrom<&protobuf::CsvOptions> for CsvOptions { delimiter: proto_opts.delimiter[0], quote: proto_opts.quote[0], escape: proto_opts.escape.first().copied(), + double_quote: proto_opts.double_quote, compression: proto_opts.compression().into(), schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, date_format: (!proto_opts.date_format.is_empty()) diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 9c89f1744166..c3699c20cd22 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1020,7 +1020,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1037,7 +1037,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(FileSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } @@ -1050,7 +1050,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1067,7 +1067,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(FileSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } @@ -1080,7 +1080,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1097,7 +1097,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Ok(Arc::new(FileSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5728fe45d9bb..24c2b66b3716 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -1070,6 +1070,7 @@ impl TryFrom<&CsvOptions> for protobuf::CsvOptions { timestamp_tz_format: opts.timestamp_tz_format.clone().unwrap_or_default(), time_format: opts.time_format.clone().unwrap_or_default(), null_value: opts.null_value.clone().unwrap_or_default(), + double_quote: opts.double_quote, }) } }