diff --git a/Cargo.lock b/Cargo.lock index 21abbb3015b5e..26764b0036d11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -124,7 +124,7 @@ dependencies = [ [[package]] name = "arrow2" version = "0.8.1" -source = "git+https://github.com/datafuse-extras/arrow2?rev=f07cc2c#f07cc2ce06f3964ec755aa026cf62eef5e2c94b3" +source = "git+https://github.com/datafuse-extras/arrow2?rev=d14ae86#d14ae86c69cd76957adec3b14bb62d93732b43c9" dependencies = [ "ahash", "arrow-format", diff --git a/common/arrow/Cargo.toml b/common/arrow/Cargo.toml index 783b3ad9555cb..de1e5f10452d4 100644 --- a/common/arrow/Cargo.toml +++ b/common/arrow/Cargo.toml @@ -20,7 +20,7 @@ simd = ["arrow/simd"] # Workspace dependencies # Github dependencies -arrow = { package = "arrow2", git = "https://github.com/datafuse-extras/arrow2", default-features = false, rev = "f07cc2c"} +arrow = { package = "arrow2", git = "https://github.com/datafuse-extras/arrow2", default-features = false, rev = "d14ae86"} arrow-format = { version = "0.3.0", features = ["flight-data", "flight-service"] } parquet2 = { version = "0.8.1", default_features = false } # Crates.io dependencies diff --git a/query/src/api/rpc/flight_client_stream.rs b/query/src/api/rpc/flight_client_stream.rs index 228c7745444ba..569409b212ebc 100644 --- a/query/src/api/rpc/flight_client_stream.rs +++ b/query/src/api/rpc/flight_client_stream.rs @@ -53,10 +53,21 @@ impl FlightDataStream { } let arrow_schema = Arc::new(schema.to_arrow()); - Ok( - deserialize_batch(&flight_data, arrow_schema, true, &Default::default()) - .map(create_data_block)?, + let ipc_fields = common_arrow::arrow::io::ipc::write::default_ipc_fields( + &arrow_schema.fields, + ); + let ipc_schema = common_arrow::arrow::io::ipc::IpcSchema { + fields: ipc_fields, + is_little_endian: true, + }; + + Ok(deserialize_batch( + &flight_data, + arrow_schema, + &ipc_schema, + &Default::default(), ) + .map(create_data_block)?) } } }) @@ -66,7 +77,7 @@ impl FlightDataStream { #[inline] #[allow(dead_code)] pub fn from_receiver( - schema_ref: DataSchemaRef, + schema: DataSchemaRef, inner: Receiver>, ) -> impl Stream> { ReceiverStream::new(inner).map(move |flight_data| match flight_data { @@ -83,13 +94,18 @@ impl FlightDataStream { DataBlock::create(Arc::new(schema), columns) } - Ok(deserialize_batch( - &flight_data, - Arc::new(schema_ref.to_arrow()), - true, - &Default::default(), + let arrow_schema = Arc::new(schema.to_arrow()); + let ipc_fields = + common_arrow::arrow::io::ipc::write::default_ipc_fields(&arrow_schema.fields); + let ipc_schema = common_arrow::arrow::io::ipc::IpcSchema { + fields: ipc_fields, + is_little_endian: true, + }; + + Ok( + deserialize_batch(&flight_data, arrow_schema, &ipc_schema, &Default::default()) + .map(create_data_block)?, ) - .map(create_data_block)?) } }) } diff --git a/query/src/api/rpc/flight_dispatcher.rs b/query/src/api/rpc/flight_dispatcher.rs index 5f2f56c3a726c..9e3cd063e67b5 100644 --- a/query/src/api/rpc/flight_dispatcher.rs +++ b/query/src/api/rpc/flight_dispatcher.rs @@ -74,7 +74,10 @@ impl DatabendQueryFlightDispatcher { } #[tracing::instrument(level = "debug", skip_all)] - pub fn get_stream(&self, ticket: &StreamTicket) -> Result>> { + pub fn get_stream( + &self, + ticket: &StreamTicket, + ) -> Result<(mpsc::Receiver>, DataSchemaRef)> { let stage_name = format!("{}/{}", ticket.query_id, ticket.stage_id); if let Some(notify) = self.stages_notify.write().remove(&stage_name) { notify.notify_waiters(); @@ -82,7 +85,7 @@ impl DatabendQueryFlightDispatcher { let stream_name = format!("{}/{}", stage_name, ticket.stream); match self.streams.write().remove(&stream_name) { - Some(stream_info) => Ok(stream_info.rx), + Some(stream_info) => Ok((stream_info.rx, stream_info.schema)), None => Err(ErrorCode::NotFoundStream("Stream is not found")), } } diff --git a/query/src/api/rpc/flight_service.rs b/query/src/api/rpc/flight_service.rs index 19d1f9d71dcea..083500f684bc7 100644 --- a/query/src/api/rpc/flight_service.rs +++ b/query/src/api/rpc/flight_service.rs @@ -16,6 +16,8 @@ use std::convert::TryInto; use std::pin::Pin; use std::sync::Arc; +use common_arrow::arrow::io::flight::serialize_schema; +use common_arrow::arrow::io::ipc::write::default_ipc_fields; use common_arrow::arrow_format::flight::data::Action; use common_arrow::arrow_format::flight::data::ActionType; use common_arrow::arrow_format::flight::data::Criteria; @@ -106,10 +108,15 @@ impl FlightService for DatabendQueryFlightService { match ticket { FlightTicket::StreamTicket(steam_ticket) => { - let receiver = self.dispatcher.get_stream(&steam_ticket)?; + let (receiver, data_schema) = self.dispatcher.get_stream(&steam_ticket)?; + let arrow_schema = data_schema.to_arrow(); + let ipc_fields = default_ipc_fields(arrow_schema.fields()); + + serialize_schema(&arrow_schema, &ipc_fields); Ok(RawResponse::new( - Box::pin(FlightDataStream::create(receiver)) as FlightStream, + Box::pin(FlightDataStream::create(receiver, ipc_fields)) + as FlightStream, )) } } diff --git a/query/src/api/rpc/flight_service_stream.rs b/query/src/api/rpc/flight_service_stream.rs index 99dafa22537f3..433b74e206a3b 100644 --- a/query/src/api/rpc/flight_service_stream.rs +++ b/query/src/api/rpc/flight_service_stream.rs @@ -16,6 +16,7 @@ use std::convert::TryInto; use common_arrow::arrow::io::flight::serialize_batch; use common_arrow::arrow::io::ipc::write::WriteOptions; +use common_arrow::arrow::io::ipc::IpcField; use common_arrow::arrow_format::flight::data::FlightData; use common_base::tokio::macros::support::Pin; use common_base::tokio::macros::support::Poll; @@ -27,13 +28,18 @@ use tonic::Status; pub struct FlightDataStream { input: Receiver>, + ipc_fields: Vec, options: WriteOptions, } impl FlightDataStream { - pub fn create(input: Receiver>) -> FlightDataStream { + pub fn create( + input: Receiver>, + ipc_fields: Vec, + ) -> FlightDataStream { FlightDataStream { input, + ipc_fields, options: WriteOptions { compression: None }, } } @@ -49,7 +55,8 @@ impl Stream for FlightDataStream { Some(Ok(block)) => match block.try_into() { Err(error) => Some(Err(Status::from(error))), Ok(record_batch) => { - let (dicts, values) = serialize_batch(&record_batch, &self.options); + let (dicts, values) = + serialize_batch(&record_batch, &self.ipc_fields, &self.options); match dicts.is_empty() { true => Some(Ok(values)), diff --git a/query/tests/it/api/rpc/flight_dispatcher.rs b/query/tests/it/api/rpc/flight_dispatcher.rs index c5b5eb6dabdc9..fac4f457170cd 100644 --- a/query/tests/it/api/rpc/flight_dispatcher.rs +++ b/query/tests/it/api/rpc/flight_dispatcher.rs @@ -69,7 +69,7 @@ async fn test_run_shuffle_action_with_no_scatters() -> Result<()> { .await?; let stream = stream_ticket(&query_id, &stage_id, &stream_id); - let receiver = flight_dispatcher.get_stream(&stream)?; + let (receiver, _data_scheme) = flight_dispatcher.get_stream(&stream)?; let receiver_stream = ReceiverStream::new(receiver); let collect_data_blocks = receiver_stream.collect::>>(); @@ -114,7 +114,7 @@ async fn test_run_shuffle_action_with_scatter() -> Result<()> { .await?; let stream_1 = stream_ticket(&query_id, &stage_id, "stream_1"); - let receiver = flight_dispatcher.get_stream(&stream_1)?; + let (receiver, _data_scheme) = flight_dispatcher.get_stream(&stream_1)?; let receiver_stream = ReceiverStream::new(receiver); let collect_data_blocks = receiver_stream.collect::>>(); @@ -131,7 +131,7 @@ async fn test_run_shuffle_action_with_scatter() -> Result<()> { assert_blocks_eq(expect, &collect_data_blocks.await?); let stream_2 = stream_ticket(&query_id, &stage_id, "stream_2"); - let receiver = flight_dispatcher.get_stream(&stream_2)?; + let (receiver, _data_scheme) = flight_dispatcher.get_stream(&stream_2)?; let receiver_stream = ReceiverStream::new(receiver); let collect_data_blocks = receiver_stream.collect::>>();