From 2b39b845c666305322dcceafc4fff1b0e2c483e8 Mon Sep 17 00:00:00 2001 From: niebayes Date: Thu, 20 Feb 2025 19:59:53 +0800 Subject: [PATCH] fix: Substrait serializer clippy error: not calling truncate (#14723) * specify truncate true * add error handling * Apply suggestions from code review Co-authored-by: Matthijs Brobbel * remove substrait from error messages * Apply suggestions from code review Co-authored-by: Matthijs Brobbel * simplify serialize * fix ut * Update datafusion/substrait/tests/cases/serialize.rs Co-authored-by: Matthijs Brobbel * apply part of follow-up suggestions --------- Co-authored-by: Matthijs Brobbel --- datafusion/substrait/Cargo.toml | 1 + datafusion/substrait/src/serializer.rs | 45 +++++++++++++------ datafusion/substrait/tests/cases/serialize.rs | 20 +++++++++ 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index f13d2b77a787..3e3ea7843ac9 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,6 +41,7 @@ pbjson-types = { workspace = true } prost = { workspace = true } substrait = { version = "0.53", features = ["serde"] } url = { workspace = true } +tokio = { workspace = true, features = ["fs"] } [dev-dependencies] datafusion = { workspace = true, features = ["nested_expressions"] } diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 4278671777fd..4a9e5d55ce05 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -22,42 +22,59 @@ use datafusion::error::Result; use datafusion::prelude::*; use prost::Message; +use std::path::Path; use substrait::proto::Plan; +use tokio::{ + fs::OpenOptions, + io::{AsyncReadExt, AsyncWriteExt}, +}; -use std::fs::OpenOptions; -use std::io::{Read, Write}; +/// Plans a sql and serializes the generated logical plan to bytes. +/// The bytes are then written into a file at `path`. +/// +/// Returns an error if the file already exists. +pub async fn serialize( + sql: &str, + ctx: &SessionContext, + path: impl AsRef, +) -> Result<()> { + let protobuf_out = serialize_bytes(sql, ctx).await?; -#[allow(clippy::suspicious_open_options)] -pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { - let protobuf_out = serialize_bytes(sql, ctx).await; - let mut file = OpenOptions::new().create(true).write(true).open(path)?; - file.write_all(&protobuf_out?)?; + let mut file = OpenOptions::new() + .write(true) + .create_new(true) + .open(path) + .await?; + file.write_all(&protobuf_out).await?; Ok(()) } +/// Plans a sql and serializes the generated logical plan to bytes. pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = producer::to_substrait_plan(&plan, &ctx.state())?; let mut protobuf_out = Vec::::new(); - proto.encode(&mut protobuf_out).map_err(|e| { - DataFusionError::Substrait(format!("Failed to encode substrait plan: {e}")) - })?; + proto + .encode(&mut protobuf_out) + .map_err(|e| DataFusionError::Substrait(format!("Failed to encode plan: {e}")))?; Ok(protobuf_out) } -pub async fn deserialize(path: &str) -> Result> { +/// Reads the file at `path` and deserializes a plan from the bytes. +pub async fn deserialize(path: impl AsRef) -> Result> { let mut protobuf_in = Vec::::new(); - let mut file = OpenOptions::new().read(true).open(path)?; + let mut file = OpenOptions::new().read(true).open(path).await?; + file.read_to_end(&mut protobuf_in).await?; - file.read_to_end(&mut protobuf_in)?; deserialize_bytes(protobuf_in).await } +/// Deserializes a plan from the bytes. pub async fn deserialize_bytes(proto_bytes: Vec) -> Result> { Ok(Box::new(Message::decode(&*proto_bytes).map_err(|e| { - DataFusionError::Substrait(format!("Failed to decode substrait plan: {e}")) + DataFusionError::Substrait(format!("Failed to decode plan: {e}")) })?)) } diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index e28c63312788..02089b9fa92d 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -17,6 +17,7 @@ #[cfg(test)] mod tests { + use datafusion::common::assert_contains; use datafusion::datasource::provider_as_source; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; @@ -31,6 +32,25 @@ mod tests { use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::{rel, RelCommon}; + #[tokio::test] + async fn serialize_to_file() -> Result<()> { + let ctx = create_context().await?; + let path = "tests/serialize_to_file.bin"; + let sql = "SELECT a, b FROM data"; + + // Test case 1: serializing to a non-existing file should succeed. + serializer::serialize(sql, &ctx, path).await?; + serializer::deserialize(path).await?; + + // Test case 2: serializing to an existing file should fail. + let got = serializer::serialize(sql, &ctx, path).await.unwrap_err(); + assert_contains!(got.to_string(), "File exists"); + + fs::remove_file(path)?; + + Ok(()) + } + #[tokio::test] async fn serialize_simple_select() -> Result<()> { let ctx = create_context().await?;