diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 257937a65e4b..a18455476dab 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -243,6 +243,14 @@ export RUST_MIN_STACK=30485760; PG_COMPAT=true INCLUDE_SQLITE=true cargo test --features=postgres --test sqllogictests ``` +To update the sqlite expected answers, use the `datafusion/sqllogictest/regenerate_sqlite_files.sh` script. + +Note that this must be run against an empty postgres instance. For example: + +```shell +PG_URI=postgresql://postgres@localhost:5432/postgres bash datafusion/sqllogictest/regenerate_sqlite_files.sh +``` + ## Updating tests: Completion Mode In test script completion mode, `sqllogictests` reads a prototype script and runs the statements and queries against the diff --git a/datafusion/sqllogictest/regenerate/sqllogictests.rs b/datafusion/sqllogictest/regenerate/sqllogictests.rs index a2706558814f..edad16bc84b1 100644 --- a/datafusion/sqllogictest/regenerate/sqllogictests.rs +++ b/datafusion/sqllogictest/regenerate/sqllogictests.rs @@ -16,10 +16,10 @@ // under the License. 
use clap::Parser; -use datafusion_common::instant::Instant; -use datafusion_common::utils::get_available_parallelism; -use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError, Result}; -use datafusion_common_runtime::SpawnedTask; +use datafusion::common::instant::Instant; +use datafusion::common::utils::get_available_parallelism; +use datafusion::common::{exec_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion::common::runtime::SpawnedTask; use datafusion_sqllogictest::{DataFusion, TestContext}; use futures::stream::StreamExt; use indicatif::{ @@ -378,7 +378,7 @@ async fn run_test_file_with_postgres( _mp: MultiProgress, _mp_style: ProgressStyle, ) -> Result<()> { - use datafusion_common::plan_err; + use datafusion::common::plan_err; plan_err!("Can not run with postgres as postgres feature is not enabled") } @@ -512,7 +512,7 @@ async fn run_complete_file_with_postgres( _mp: MultiProgress, _mp_style: ProgressStyle, ) -> Result<()> { - use datafusion_common::plan_err; + use datafusion::common::plan_err; plan_err!("Can not run with postgres as postgres feature is not enabled") } diff --git a/datafusion/sqllogictest/regenerate/src/engines/datafusion_engine/runner.rs b/datafusion/sqllogictest/regenerate/src/engines/datafusion_engine/runner.rs new file mode 100644 index 000000000000..e696058484a9 --- /dev/null +++ b/datafusion/sqllogictest/regenerate/src/engines/datafusion_engine/runner.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::{path::PathBuf, time::Duration}; + +use super::{error::Result, normalize, DFSqlLogicTestError}; +use arrow::record_batch::RecordBatch; +use async_trait::async_trait; +use datafusion::physical_plan::common::collect; +use datafusion::physical_plan::execute_stream; +use datafusion::prelude::SessionContext; +use indicatif::ProgressBar; +use log::Level::{Debug, Info}; +use log::{debug, log_enabled, warn}; +use sqllogictest::DBOutput; +use tokio::time::Instant; + +use crate::engines::output::{DFColumnType, DFOutput}; + +pub struct DataFusion { + ctx: SessionContext, + relative_path: PathBuf, + pb: ProgressBar, +} + +impl DataFusion { + pub fn new(ctx: SessionContext, relative_path: PathBuf, pb: ProgressBar) -> Self { + Self { + ctx, + relative_path, + pb, + } + } + + fn update_slow_count(&self) { + let msg = self.pb.message(); + let split: Vec<&str> = msg.split(" ").collect(); + let mut current_count = 0; + + if split.len() > 2 { + // once set below, the message reads "<name> - <count> took > 500 ms", so the third token (index 2) is the current slow count + current_count = split[2].parse::().unwrap(); + } + + current_count += 1; + + self.pb + .set_message(format!("{} - {} took > 500 ms", split[0], current_count)); + } +} + +#[async_trait] +impl sqllogictest::AsyncDB for DataFusion { + type Error = DFSqlLogicTestError; + type ColumnType = DFColumnType; + + async fn run(&mut self, sql: &str) -> Result { + if log_enabled!(Debug) { + debug!( + "[{}] Running query: \"{}\"", + self.relative_path.display(), + sql + ); + } + + let start = Instant::now(); + let result = run_query(&self.ctx, sql).await; + let 
duration = start.elapsed(); + + if duration.gt(&Duration::from_millis(500)) { + self.update_slow_count(); + } + + self.pb.inc(1); + + if log_enabled!(Info) && duration.gt(&Duration::from_secs(2)) { + warn!( + "[{}] Running query took more than 2 sec ({duration:?}): \"{sql}\"", + self.relative_path.display() + ); + } + + result + } + + /// Engine name of current database. + fn engine_name(&self) -> &str { + "DataFusion" + } + + /// Sleeps for `dur` by awaiting `tokio::time::sleep`. + /// + /// The default implementation is `std::thread::sleep`, which is universal to any async runtime + /// but would block the current thread. If you are running in tokio runtime, you should override + /// this by `tokio::time::sleep`. + async fn sleep(dur: Duration) { + tokio::time::sleep(dur).await; + } +} + +async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { + let df = ctx.sql(sql.into().as_str()).await?; + let task_ctx = Arc::new(df.task_ctx()); + let plan = df.create_physical_plan().await?; + + let stream = execute_stream(plan, task_ctx)?; + let types = normalize::convert_schema_to_types(stream.schema().fields()); + let results: Vec = collect(stream).await?; + let rows = normalize::convert_batches(results)?; + + if rows.is_empty() && types.is_empty() { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { types, rows }) + } +} diff --git a/datafusion/sqllogictest/regenerate/src/engines/postgres_engine/mod.rs b/datafusion/sqllogictest/regenerate/src/engines/postgres_engine/mod.rs new file mode 100644 index 000000000000..050d19449c31 --- /dev/null +++ b/datafusion/sqllogictest/regenerate/src/engines/postgres_engine/mod.rs @@ -0,0 +1,379 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use async_trait::async_trait; +use bytes::Bytes; +use datafusion::common::runtime::SpawnedTask; +use futures::{SinkExt, StreamExt}; +use log::{debug, info}; +use sqllogictest::DBOutput; +/// Postgres engine implementation for sqllogictest. +use std::path::{Path, PathBuf}; +use std::str::FromStr; +use std::time::Duration; + +use super::conversion::*; +use crate::engines::output::{DFColumnType, DFOutput}; +use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; +use indicatif::ProgressBar; +use postgres_types::Type; +use rust_decimal::Decimal; +use tokio::time::Instant; +use tokio_postgres::{Column, Row}; +use types::PgRegtype; + +mod types; + +// default connect string, can be overridden by the `PG_URI` environment variable +const PG_URI: &str = "postgresql://postgres@127.0.0.1/test"; + +/// Errors returned by the Postgres sqllogictest engine +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Postgres error: {0}")] + Postgres(#[from] tokio_postgres::error::Error), + #[error("Error handling copy command: {0}")] + Copy(String), +} + +pub type Result = std::result::Result; + +pub struct Postgres { + // None means the connection has been shutdown + client: Option, + _spawned_task: Option>, + /// Relative test file path + relative_path: PathBuf, + pb: ProgressBar, +} + +impl Postgres { + /// Connects to postgres and creates a runner for executing queries over that connection. 
+ /// `relative_path` is used for display output and to create a postgres schema. + /// + /// The database connection details can be overridden by the + /// `PG_URI` environment variable. + /// + /// This defaults to + /// + /// ```text + /// PG_URI="postgresql://postgres@127.0.0.1/test" + /// ``` + /// + /// See https://docs.rs/tokio-postgres/latest/tokio_postgres/config/struct.Config.html#url for format + pub async fn connect(relative_path: PathBuf, pb: ProgressBar) -> Result { + let uri = + std::env::var("PG_URI").map_or(PG_URI.to_string(), std::convert::identity); + + info!("Using postgres connection string: {uri}"); + + let config = tokio_postgres::Config::from_str(&uri)?; + + // hint to user what the connection string was + let res = config.connect(tokio_postgres::NoTls).await; + if res.is_err() { + eprintln!("Error connecting to postgres using PG_URI={uri}"); + }; + + let (client, connection) = res?; + + let spawned_task = SpawnedTask::spawn(async move { + if let Err(e) = connection.await { + log::error!("Postgres connection error: {:?}", e); + } + }); + + let schema = schema_name(&relative_path); + + // create a new clean schema for running the test + debug!("Creating new empty schema '{schema}'"); + client + .execute(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"), &[]) + .await?; + + client + .execute(&format!("CREATE SCHEMA {schema}"), &[]) + .await?; + + client + .execute(&format!("SET search_path TO {schema}"), &[]) + .await?; + + Ok(Self { + client: Some(client), + _spawned_task: Some(spawned_task), + relative_path, + pb, + }) + } + + fn get_client(&mut self) -> &mut tokio_postgres::Client { + self.client.as_mut().expect("client is shutdown") + } + + /// Special COPY command support. "COPY 'filename'" requires the + /// server to read the file which may not be possible (maybe it is + /// remote or running in some other docker container). + /// + /// Thus, we rewrite sql statements like + /// + /// ```sql + /// COPY ... FROM 'filename' ... 
+ /// ``` + /// + /// Into + /// + /// ```sql + /// COPY ... FROM STDIN ... + /// ``` + /// + /// And read the file locally. + async fn run_copy_command(&mut self, sql: &str) -> Result { + let canonical_sql = sql.trim_start().to_ascii_lowercase(); + + debug!("Handling COPY command: {sql}"); + + // Hacky way to find the 'filename' in the statement (NOTE(review): tokens come from the lowercased SQL, so the filename and the rewritten statement are lowercased too - confirm this is acceptable for case-sensitive paths) + let mut tokens = canonical_sql.split_whitespace().peekable(); + let mut filename = None; + + // COPY FROM '/opt/data/csv/aggregate_test_100.csv' ... + // + // into + // + // COPY FROM STDIN ... + + let mut new_sql = vec![]; + while let Some(tok) = tokens.next() { + new_sql.push(tok); + // rewrite FROM to FROM STDIN + if tok == "from" { + filename = tokens.next(); + new_sql.push("STDIN"); + } + } + + let filename = filename.map(no_quotes).ok_or_else(|| { + Error::Copy(format!("Can not find filename in COPY: {sql}")) + })?; + + let new_sql = new_sql.join(" "); + debug!("Copying data from file {filename} using sql: {new_sql}"); + + // start the COPY command and get location to write data to + let tx = self.get_client().transaction().await?; + let sink = tx.copy_in(&new_sql).await?; + let mut sink = Box::pin(sink); + + // read the input file as a string and feed it to the copy command + let data = std::fs::read_to_string(filename) + .map_err(|e| Error::Copy(format!("Error reading {filename}: {e}")))?; + + let mut data_stream = futures::stream::iter(vec![Ok(Bytes::from(data))]).boxed(); + + sink.send_all(&mut data_stream).await?; + sink.close().await?; + tx.commit().await?; + Ok(DBOutput::StatementComplete(0)) + } + + fn update_slow_count(&self) { + let msg = self.pb.message(); + let split: Vec<&str> = msg.split(" ").collect(); + let mut current_count = 0; + + if split.len() > 2 { + // once set below, the message reads "<name> - <count> took > 500 ms", so the third token (index 2) is the current slow count + current_count += split[2].parse::().unwrap(); + } + + current_count += 1; + + self.pb + .set_message(format!("{} - {} took > 500 ms", split[0], current_count)); + } +} + +/// remove single quotes from 
the start and end of the string +/// +/// 'filename' --> filename +fn no_quotes(t: &str) -> &str { + t.trim_start_matches('\'').trim_end_matches('\'') +} + +/// Given a file name like pg_compat_foo.slt +/// return a schema name built from its ASCII alphanumeric characters (NOTE(review): `_` is filtered out before the `pg_` trim, so that trim looks like a no-op - confirm) +fn schema_name(relative_path: &Path) -> String { + relative_path + .to_string_lossy() + .chars() + .filter(|ch| ch.is_ascii_alphanumeric()) + .collect::() + .trim_start_matches("pg_") + .to_string() +} + +#[async_trait] +impl sqllogictest::AsyncDB for Postgres { + type Error = Error; + type ColumnType = DFColumnType; + + async fn run( + &mut self, + sql: &str, + ) -> Result, Self::Error> { + debug!( + "[{}] Running query: \"{}\"", + self.relative_path.display(), + sql + ); + + let lower_sql = sql.trim_start().to_ascii_lowercase(); + + let is_query_sql = { + lower_sql.starts_with("select") + || lower_sql.starts_with("values") + || lower_sql.starts_with("show") + || lower_sql.starts_with("with") + || lower_sql.starts_with("describe") + || ((lower_sql.starts_with("insert") + || lower_sql.starts_with("update") + || lower_sql.starts_with("delete")) + && lower_sql.contains("returning")) + }; + + if lower_sql.starts_with("copy") { + self.pb.inc(1); + return self.run_copy_command(sql).await; + } + + if !is_query_sql { + self.get_client().execute(sql, &[]).await?; + self.pb.inc(1); + return Ok(DBOutput::StatementComplete(0)); + } + let start = Instant::now(); + let rows = self.get_client().query(sql, &[]).await?; + let duration = start.elapsed(); + + if duration.gt(&Duration::from_millis(500)) { + self.update_slow_count(); + } + + self.pb.inc(1); + + let types: Vec = if rows.is_empty() { + self.get_client() + .prepare(sql) + .await? 
+ .columns() + .iter() + .map(|c| c.type_().clone()) + .collect() + } else { + rows[0] + .columns() + .iter() + .map(|c| c.type_().clone()) + .collect() + }; + + if rows.is_empty() && types.is_empty() { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { + types: convert_types(types), + rows: convert_rows(rows), + }) + } + } + + fn engine_name(&self) -> &str { + "postgres" + } + +} + +fn convert_rows(rows: Vec) -> Vec> { + rows.iter() + .map(|row| { + row.columns() + .iter() + .enumerate() + .map(|(idx, column)| cell_to_string(row, column, idx)) + .collect::>() + }) + .collect::>() +} + +macro_rules! make_string { + ($row:ident, $idx:ident, $t:ty) => {{ + let value: Option<$t> = $row.get($idx); + match value { + Some(value) => value.to_string(), + None => NULL_STR.to_string(), + } + }}; + ($row:ident, $idx:ident, $t:ty, $convert:ident) => {{ + let value: Option<$t> = $row.get($idx); + match value { + Some(value) => $convert(value).to_string(), + None => NULL_STR.to_string(), + } + }}; +} + +fn cell_to_string(row: &Row, column: &Column, idx: usize) -> String { + match column.type_().clone() { + Type::CHAR => make_string!(row, idx, i8), + Type::INT2 => make_string!(row, idx, i16), + Type::INT4 => make_string!(row, idx, i32), + Type::INT8 => make_string!(row, idx, i64), + Type::NUMERIC => make_string!(row, idx, Decimal, decimal_to_str), + Type::DATE => make_string!(row, idx, NaiveDate), + Type::TIME => make_string!(row, idx, NaiveTime), + Type::TIMESTAMP => { + let value: Option = row.get(idx); + value + .map(|d| format!("{d:?}")) + .unwrap_or_else(|| "NULL".to_string()) + } + Type::BOOL => make_string!(row, idx, bool, bool_to_str), + Type::BPCHAR | Type::VARCHAR | Type::TEXT => { + make_string!(row, idx, &str, varchar_to_str) + } + Type::FLOAT4 => make_string!(row, idx, f32, f32_to_str), + Type::FLOAT8 => make_string!(row, idx, f64, f64_to_str), + Type::REGTYPE => make_string!(row, idx, PgRegtype), + _ => unimplemented!("Unsupported type: {}", 
column.type_().name()), + } +} + +fn convert_types(types: Vec) -> Vec { + types + .into_iter() + .map(|t| match t { + Type::BOOL => DFColumnType::Boolean, + Type::INT2 | Type::INT4 | Type::INT8 => DFColumnType::Integer, + Type::BPCHAR | Type::VARCHAR | Type::TEXT => DFColumnType::Text, + Type::FLOAT4 | Type::FLOAT8 | Type::NUMERIC => DFColumnType::Float, + Type::DATE | Type::TIME => DFColumnType::DateTime, + Type::TIMESTAMP => DFColumnType::Timestamp, + _ => DFColumnType::Another, + }) + .collect() +} diff --git a/datafusion/sqllogictest/regenerate_sqlite_files.sh b/datafusion/sqllogictest/regenerate_sqlite_files.sh index 0c1b26b1a9d3..981891c05b2f 100755 --- a/datafusion/sqllogictest/regenerate_sqlite_files.sh +++ b/datafusion/sqllogictest/regenerate_sqlite_files.sh @@ -168,8 +168,10 @@ sd -f i '^sqllogictest.*' 'sqllogictest = { git = "https://github.com/Omega359/s echo "Replacing the datafusion/sqllogictest/bin/sqllogictests.rs file with a custom version required for running completion" -# replace the sqllogictest.rs with a customized version. -cp datafusion/sqllogictest/regenerate/sqllogictests.rs datafusion/sqllogictest/bin/sqllogictests.rs +# replace sqllogictests.rs and the postgres/datafusion engine sources with customized versions. +cp datafusion/sqllogictest/regenerate/sqllogictests.rs datafusion/sqllogictest/bin/sqllogictests.rs +cp datafusion/sqllogictest/regenerate/src/engines/postgres_engine/mod.rs datafusion/sqllogictest/src/engines/postgres_engine/mod.rs +cp datafusion/sqllogictest/regenerate/src/engines/datafusion_engine/runner.rs datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs echo "Running the sqllogictests with sqlite completion. 
This will take approximately an hour to run" @@ -200,6 +202,5 @@ echo "Cleaning up source code changes and temporary files and directories" cd "$DF_HOME" || exit; find ./datafusion-testing/data/sqlite/ -type f -name "*.bak" -exec rm {} \; find ./datafusion/sqllogictest/test_files/pg_compat/ -type f -name "*.bak" -exec rm {} \; -git checkout datafusion/sqllogictest/Cargo.toml -git checkout datafusion/sqllogictest/bin/sqllogictests.rs +git checkout datafusion/sqllogictest rm -rf /tmp/sqlitetesting