From 2fd107ccd78bab24799892ce95f9a943f4a9d5bd Mon Sep 17 00:00:00 2001
From: qwshen
Date: Thu, 22 Sep 2022 20:54:35 -0400
Subject: [PATCH] added aggregation pushdown

---
 README.md                                        | 34 ++++++-
 pom.xml                                          |  2 +-
 .../com/qwshen/flight/PushAggregation.java       | 49 ++++++++++
 .../com/qwshen/flight/QueryStatement.java        | 13 ++-
 src/main/java/com/qwshen/flight/Table.java       | 63 +++++++++----
 .../flight/spark/read/FlightScanBuilder.java     | 50 +++++++++--
 .../qwshen/flight/spark/test/DremioTest.scala    | 90 +++++++++++++++++--
 7 files changed, 267 insertions(+), 34 deletions(-)
 create mode 100644 src/main/java/com/qwshen/flight/PushAggregation.java

diff --git a/README.md b/README.md
index 06d15af..a300401 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@
 select id, "departure-time", "arrival-time" from flights where "flight-no" = 'ABC-21';
 ```
 
-The connector supports optimized reads with filters, required columns and count pushing-down, and parallel reads when partitioning is enalbed.
+The connector supports optimized reads with filter, required-column and aggregation push-down, and parallel reads when partitioning is enabled.
 
 ### 1. Load data
 ```scala
@@ -113,7 +113,7 @@ spark.read
 Note: when lowerBound & upperBound with byColumn or predicates are used, they are eventually filters applied on the queries to fetch data which may impact the final result-set. Please make sure these partitioning options do not affect the final output, but rather only apply for partitioning the output.
 
 #### - Pushing filter & columns down
-The filters and required-columns are pushed down when they are provided. This limits the data at the source which greatly decreases the amount of data being transferred and processed.
+Filters and required-columns are pushed down when they are provided. This limits the data at the source, which greatly decreases the amount of data being transferred and processed.
 ```scala
 spark.read
   .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123")
@@ -123,6 +123,36 @@ spark.read
   .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123")
   .options(options) //other options
   .flight(""""e-commerce".orders""")
   .filter("order_date > '2020-01-01' and order_amount > 100") //filter is pushed down
   .select("order_id", "customer_id", "payment_method", "order_amount", "order_date") //required-columns are pushed down
 ```
+#### - Pushing aggregation down
+Aggregations are pushed down when they are provided. Only the following aggregate functions are pushed down:
+- max
+- min
+- count
+- count distinct
+- sum
+- sum distinct
+
+The avg aggregation is not pushed down; it can be computed by combining the pushed-down count and sum. Any other aggregation is calculated at the Spark level after the data is fetched.
+
+```scala
+val df = spark.read
+  .option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123")
+  .options(options) //other options
+  .flight(""""e-commerce".orders""")
+  .filter("order_date > '2020-01-01' and order_amount > 100") //filter is pushed down
+
+df.agg(count(col("order_id")).as("num_orders"), sum(col("amount")).as("total_amount")).show() //aggregation pushed down
+
+df.groupBy(col("gender"))
+  .agg(
+    countDistinct(col("order_id")).as("num_orders"),
+    max(col("amount")).as("max_amount"),
+    min(col("amount")).as("min_amount"),
+    sum(col("amount")).as("total_amount")
+  ) //aggregation pushed down
+  .show()
+```
+
 ### 2. 
Write data (tables being written must be iceberg tables in case of Dremio Flight)
 ```scala
 df.write.format("flight")
diff --git a/pom.xml b/pom.xml
index b9cc23f..2e78516 100644
--- a/pom.xml
+++ b/pom.xml
@@ -23,7 +23,7 @@
     <groupId>com.qwshen</groupId>
     <artifactId>spark-flight-connector_${spark.version}</artifactId>
-    <version>1.0.0</version>
+    <version>1.0.1</version>
     <packaging>jar</packaging>
diff --git a/src/main/java/com/qwshen/flight/PushAggregation.java b/src/main/java/com/qwshen/flight/PushAggregation.java
new file mode 100644
index 0000000..e577c93
--- /dev/null
+++ b/src/main/java/com/qwshen/flight/PushAggregation.java
@@ -0,0 +1,49 @@
+package com.qwshen.flight;
+
+import java.io.Serializable;
+
+/**
+ * Describes the data structure for pushed-down aggregation
+ */
+public class PushAggregation implements Serializable {
+    //pushed-down Aggregate-Columns (expressions)
+    private String[] _columnExpressions = null;
+    //pushed-down GroupBy-Columns
+    private String[] _groupByColumns = null;
+
+    /**
+     * Push down aggregation of columns
+     *   select max(age), sum(distinct amount) from table where ...
+     * @param columnExpressions - the collection of aggregation expressions
+     */
+    public PushAggregation(String[] columnExpressions) {
+        this._columnExpressions = columnExpressions;
+    }
+
+    /**
+     * Push down aggregation with group by columns
+     *   select max(age), sum(amount) from table where ... group by gender
+     * @param columnExpressions - the collection of aggregation expressions
+     * @param groupByColumns - the columns in group by
+     */
+    public PushAggregation(String[] columnExpressions, String[] groupByColumns) {
+        this(columnExpressions);
+        this._groupByColumns = groupByColumns;
+    }
+
+    /**
+     * Return the collection of aggregation expressions
+     * @return - the expressions
+     */
+    public String[] getColumnExpressions() {
+        return this._columnExpressions;
+    }
+
+    /**
+     * The columns for group-by
+     * @return - columns
+     */
+    public String[] getGroupByColumns() {
+        return this._groupByColumns;
+    }
+}
diff --git a/src/main/java/com/qwshen/flight/QueryStatement.java b/src/main/java/com/qwshen/flight/QueryStatement.java
index 7f9252d..d926c8f 100644
--- a/src/main/java/com/qwshen/flight/QueryStatement.java
+++ b/src/main/java/com/qwshen/flight/QueryStatement.java
@@ -8,15 +8,18 @@ public class QueryStatement implements Serializable {
     private final String _stmt;
     private final String _where;
+    private final String _groupBy;
 
     /**
      * Construct a ReadStatement
      * @param stmt - the select portion of a select-statement
      * @param where - the where portion of a select-statement
+     * @param groupBy - the groupBy portion of a select-statement
      */
-    public QueryStatement(String stmt, String where) {
+    public QueryStatement(String stmt, String where, String groupBy) {
         this._stmt = stmt;
         this._where = where;
+        this._groupBy = groupBy;
     }
 
     /**
@@ -29,6 +32,9 @@ public boolean different(QueryStatement rs) {
         if (!changed) {
             changed = (rs._where != null) ? !rs._where.equalsIgnoreCase(this._where) : this._where != null;
         }
+        if (!changed) {
+            changed = (rs._groupBy != null) ? !rs._groupBy.equalsIgnoreCase(this._groupBy) : this._groupBy != null;
+        }
         return changed;
     }
 
     /**
@@ -37,6 +43,9 @@ public boolean different(QueryStatement rs) {
      * @return - the select-statement
      */
     public String getStatement() {
-        return (this._where != null && this._where.length() > 0) ? String.format("%s where %s", this._stmt, this._where) : this._stmt;
+        return String.format("%s %s %s", this._stmt,
+            (this._where != null && this._where.length() > 0) ? String.format("where %s", this._where) : "",
+            (this._groupBy != null && this._groupBy.length() > 0) ? 
String.format("group by %s", this._groupBy) : "" + ); } } diff --git a/src/main/java/com/qwshen/flight/Table.java b/src/main/java/com/qwshen/flight/Table.java index 87bf1be..f62904a 100644 --- a/src/main/java/com/qwshen/flight/Table.java +++ b/src/main/java/com/qwshen/flight/Table.java @@ -1,13 +1,19 @@ package com.qwshen.flight; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.parquet.Strings; import org.apache.spark.sql.sources.*; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.slf4j.LoggerFactory; +import scala.Array; +import scala.collection.mutable.HashTable; + import java.io.Serializable; import java.util.Arrays; +import java.util.Hashtable; +import java.util.function.BiFunction; import java.util.function.Function; /** @@ -41,7 +47,7 @@ private Table(String name, String columnQuote) { this._name = name; this._columnQuote = columnQuote; - this.prepareQueryStatement(false, null, null, null); + this.prepareQueryStatement(null, null, null, null); } /** @@ -122,23 +128,50 @@ public void initialize(Configuration config) { } //Prepare the query for submitting to remote flight service - private boolean prepareQueryStatement(Boolean forCount, StructField[] fields, String filter, PartitionBehavior partitionBehavior) { - String selectStmt = forCount ? String.format("select count(*) from %s", this._name) - : (fields == null || fields.length == 0) ? String.format("select * from %s", this._name) - : String.format("select %s from %s", String.join(",", Arrays.stream(fields).map(column -> String.format("%s%s%s", this._columnQuote, column.name(), this._columnQuote)).toArray(String[]::new)), this._name); - QueryStatement stmt = new QueryStatement(selectStmt, filter); + private boolean prepareQueryStatement(PushAggregation aggregation, StructField[] fields, String filter, PartitionBehavior partitionBehavior) { + //aggregation mode: 0 -> no aggregation; 1 -> aggregation without group-by; 2 -> aggregation with group-by + int aggMode = 0; + String select = "", groupBy = ""; + if (aggregation != null) { + String[] groupByFields = aggregation.getGroupByColumns(); + if (groupByFields != null && groupByFields.length > 0) { + aggMode = 2; + groupBy = String.join(",", groupByFields); + } else { + aggMode = 1; + } + select = String.format("select %s from %s", String.join(",", aggregation.getColumnExpressions()), this._name); + } else if (fields != null && fields.length > 0) { + select = String.format("select %s from %s", String.join(",", Arrays.stream(fields).map(column -> String.format("%s%s%s", this._columnQuote, column.name(), this._columnQuote)).toArray(String[]::new)), this._name); + } else { + select = String.format("select * from %s", this._name); + } + QueryStatement stmt = new QueryStatement(select, filter, groupBy); boolean changed = stmt.different(this._stmt); if (changed) { this._stmt = stmt; } - if (forCount) { + if (aggMode == 1) { this._partitionStmts.clear(); } else if (partitionBehavior != null && partitionBehavior.enabled()) { - String baseWhere = (filter != null && !filter.isEmpty()) ? String.format("(%s) and ", filter) : ""; - Function find = (name) -> (fields != null) ? Arrays.stream(fields).filter(field -> field.name().equalsIgnoreCase(name)).findFirst().orElse(null) : null; - String[] predicates = partitionBehavior.predicateDefined() ? 
partitionBehavior.getPredicates() : partitionBehavior.calculatePredicates(fields);
-            Arrays.stream(predicates).forEach(predicate -> this._partitionStmts.add(String.format("%s where %s(%s)", selectStmt, baseWhere, predicate)));
+            String where = (filter != null && !filter.isEmpty()) ? String.format("(%s) and ", filter) : "";
+            BiFunction<StructField[], StructField[], StructField[]> merge = (s1, s2) -> {
+                Hashtable<String, StructField> s = new Hashtable<>();
+                for (StructField sf : s1) {
+                    s.put(sf.name(), sf);
+                }
+                for (StructField sf : s2) {
+                    s.put(sf.name(), sf);
+                }
+                return s.values().toArray(new StructField[0]);
+            };
+            String[] predicates = partitionBehavior.predicateDefined() ? partitionBehavior.getPredicates()
+                : partitionBehavior.calculatePredicates(this._sparkSchema == null ? fields : merge.apply(fields, this._sparkSchema.fields()));
+            for (String predicate : predicates) {
+                QueryStatement s = new QueryStatement(select, String.format("%s(%s)", where, predicate), groupBy);
+                this._partitionStmts.add(s.getStatement());
+            }
         }
         return changed;
     }
 
@@ -202,15 +235,15 @@ public String toWhereClause(Filter filter) {
      * Probe if the pushed filter, fields and aggregation would affect the existing schema & end-points
      * @param pushedFilter - the pushed filter
      * @param pushedFields - the pushed fields
-     * @param pushedCount - the pushed count aggregation
+     * @param pushedAggregation - the pushed aggregation
      * @param partitionBehavior - the partitioning behavior
      * @return - true if initialization is required
      */
-    public Boolean probe(String pushedFilter, StructField[] pushedFields, boolean pushedCount, PartitionBehavior partitionBehavior) {
-        if ((pushedFilter == null || pushedFilter.isEmpty()) && (pushedFields == null || pushedFields.length == 0) && !pushedCount) {
+    public Boolean probe(String pushedFilter, StructField[] pushedFields, PushAggregation pushedAggregation, PartitionBehavior partitionBehavior) {
+        if ((pushedFilter == null || pushedFilter.isEmpty()) && (pushedFields == null || pushedFields.length == 0) && pushedAggregation == null) {
             return false;
         }
-        return this.prepareQueryStatement(pushedCount, pushedFields, pushedFilter, partitionBehavior);
+        return this.prepareQueryStatement(pushedAggregation, pushedFields, pushedFilter, partitionBehavior);
     }
 
     /**
diff --git a/src/main/java/com/qwshen/flight/spark/read/FlightScanBuilder.java b/src/main/java/com/qwshen/flight/spark/read/FlightScanBuilder.java
index 89ae24b..5bd33f1 100644
--- a/src/main/java/com/qwshen/flight/spark/read/FlightScanBuilder.java
+++ b/src/main/java/com/qwshen/flight/spark/read/FlightScanBuilder.java
@@ -2,13 +2,17 @@
 import com.qwshen.flight.Configuration;
 import com.qwshen.flight.PartitionBehavior;
+import com.qwshen.flight.PushAggregation;
 import com.qwshen.flight.Table;
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
+import org.apache.spark.sql.connector.expressions.aggregate.*;
 import org.apache.spark.sql.connector.read.*;
 import org.apache.spark.sql.sources.*;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Objects;
 import java.util.function.Function;
 
@@ -24,8 +28,9 @@ public final class FlightScanBuilder implements ScanBuilder, SupportsPushDownFil
     private Filter[] _pdFilters = new Filter[0];
     //the pushed-down columns
     private StructField[] _pdColumns = new StructField[0];
-    //only supports pushed-down COUNT
-    private boolean _pdCount = false;
+
+    //pushed-down aggregation
+    private PushAggregation _pdAggregation = 
null;
 
     /**
      * Construct a flight-scan builder
@@ -40,14 +45,45 @@ public FlightScanBuilder(Configuration configuration, Table table, PartitionBeha
     }
 
     /**
-     * Only COUNT is supported for now
+     * Collect the aggregations that will be pushed down
      * @param aggregation - the pushed aggregation
-     * @return - true if COUNT is pushed
+     * @return - true if the aggregation is pushed down
      */
     @Override
     public boolean pushAggregation(Aggregation aggregation) {
-        this._pdCount = Arrays.stream(aggregation.aggregateExpressions()).map(e -> e.toString().equalsIgnoreCase("count(*)")).findFirst().orElse(false);
-        return this._pdCount;
+        Function<String, String> quote = (s) -> String.format("%s%s%s", this._table.getColumnQuote(), s, this._table.getColumnQuote());
+        Function<String[], String> mks = (ss) -> String.join(",", Arrays.stream(ss).map(quote).toArray(String[]::new));
+
+        List<String> pdAggregateColumns = new ArrayList<>();
+        boolean push = true;
+        for (AggregateFunc agg : aggregation.aggregateExpressions()) {
+            if (agg instanceof CountStar) {
+                pdAggregateColumns.add(agg.toString().toLowerCase());
+            } else if (agg instanceof Count) {
+                Count c = (Count)agg;
+                pdAggregateColumns.add(c.isDistinct() ? String.format("count(distinct(%s))", mks.apply(c.column().fieldNames())) : String.format("count(%s)", mks.apply(c.column().fieldNames())));
+            } else if (agg instanceof Min) {
+                Min m = (Min)agg;
+                pdAggregateColumns.add(String.format("min(%s)", mks.apply(m.column().fieldNames())));
+            } else if (agg instanceof Max) {
+                Max m = (Max)agg;
+                pdAggregateColumns.add(String.format("max(%s)", mks.apply(m.column().fieldNames())));
+            } else if (agg instanceof Sum) {
+                Sum s = (Sum)agg;
+                pdAggregateColumns.add(s.isDistinct() ? String.format("sum(distinct(%s))", mks.apply(s.column().fieldNames())) : String.format("sum(%s)", mks.apply(s.column().fieldNames())));
+            } else {
+                push = false;
+                break;
+            }
+        };
+        if (push) {
+            String[] pdGroupByColumns = Arrays.stream(aggregation.groupByColumns()).flatMap(gbc -> Arrays.stream(gbc.fieldNames()).map(quote)).toArray(String[]::new);
+            pdAggregateColumns.addAll(0, Arrays.asList(pdGroupByColumns));
+            this._pdAggregation = pdGroupByColumns.length > 0 ? 
new PushAggregation(pdAggregateColumns.toArray(new String[0]), pdGroupByColumns) : new PushAggregation(pdAggregateColumns.toArray(new String[0]));
+        } else {
+            this._pdAggregation = null;
+        }
+        return this._pdAggregation != null;
     }
 
     /**
@@ -105,7 +141,7 @@ public void pruneColumns(StructType columns) {
     public Scan build() {
         //adjust flight-table upon pushed filters & columns
         String where = String.join(" and ", Arrays.stream(this._pdFilters).map(this._table::toWhereClause).toArray(String[]::new));
-        if (this._table.probe(where, this._pdColumns, this._pdCount, this._partitionBehavior)) {
+        if (this._table.probe(where, this._pdColumns, this._pdAggregation, this._partitionBehavior)) {
             this._table.initialize(this._configuration);
         }
         return new FlightScan(this._configuration, this._table);
diff --git a/src/test/scala/com/qwshen/flight/spark/test/DremioTest.scala b/src/test/scala/com/qwshen/flight/spark/test/DremioTest.scala
index 0816906..c07ed58 100644
--- a/src/test/scala/com/qwshen/flight/spark/test/DremioTest.scala
+++ b/src/test/scala/com/qwshen/flight/spark/test/DremioTest.scala
@@ -1,13 +1,13 @@
 package com.qwshen.flight.spark.test
 
 import org.apache.spark.sql.{DataFrame, SparkSession}
-import org.apache.spark.sql.functions.{col, lit, struct, when, map, array}
+import org.apache.spark.sql.functions.{array, col, count, countDistinct, lit, map, max, min, struct, sum, sum_distinct, when}
 import org.scalatest.{BeforeAndAfterEach, FunSuite}
 
 class DremioTest extends FunSuite with BeforeAndAfterEach {
-  private val dremioHost = "192.168.0.27"
+  private val dremioHost = "192.168.0.22"
   private val dremioPort = "32010"
-  private val dremioTlsEnabled = false;
+  private val dremioTlsEnabled = false
   private val user = "test"
   private val password = "Password@12345"
 
@@ -170,6 +170,81 @@ class DremioTest extends FunSuite with BeforeAndAfterEach {
     df.show()
   }
 
+  test("Query a table with aggregation") {
+    val table = """"local-iceberg".iceberg_db.log_events_iceberg_table_events"""
+    val run: SparkSession => DataFrame = this.load(Map("table" -> table), None, Nil)
+    val df = this.execute(run)
+
+    df.agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show()
+    df.filter(col("float_amount") >= lit(2.34f)).agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show()
+
+    //df.limit(10).show() //not supported
+    //df.distinct().show() //not supported
+
+    df.filter(col("float_amount") >= lit(2.34f))
+      .groupBy(col("gender"), col("birthyear"))
+      .agg(
+        countDistinct(col("event_id")).as("distinct_count"),
+        count(col("event_id")).as("count"),
+        max(col("float_amount")).as("max_float_amount"),
+        min(col("float_amount")).as("min_float_amount"),
+        //avg(col("decimal_amount")).as("avg_decimal_amount"), //not supported
+        sum_distinct(col("double_amount")).as("distinct_sum_double_amount"),
+        sum(col("double_amount")).as("sum_double_amount")
+      )
+      .show()
+  }
+
+  test("Query a table with aggregation with partitioning by hashing") {
+    val table = """"local-iceberg".iceberg_db.log_events_iceberg_table_events"""
+    val run: SparkSession => DataFrame = this.load(Map("table" -> table, "partition.size" -> "3", "partition.byColumn" -> "event_id"), None, Nil)
+    val df = this.execute(run)
+
+    df.agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show()
+    df.filter(col("float_amount") >= lit(2.34f)).agg(max(col("float_amount")).as("max_float_amount"), 
sum(col("double_amount")).as("sum_double_amount")).show() + + //df.limit(10).show() //not supported + //df.distinct().show() //not supported + + df.filter(col("float_amount") >= lit(2.34f)) + .groupBy(col("gender"), col("birthyear")) + .agg( + countDistinct(col("event_id")).as("distinct_count"), + count(col("event_id")).as("count"), + max(col("float_amount")).as("max_float_amount"), + min(col("float_amount")).as("min_float_amount"), + //avg(col("decimal_amount")).as("avg_decimal_amount"), //not supported + sum_distinct(col("double_amount")).as("distinct_sum_double_amount"), + sum(col("double_amount")).as("sum_double_amount") + ) + .show() + } + + test("Query a table with aggregation with partitioning by lower/upper bounds") { + val table = """"local-iceberg".iceberg_db.log_events_iceberg_table_events""" + val run: SparkSession => DataFrame = this.load(Map("table" -> table, "partition.size" -> "3", "partition.byColumn" -> "start_date", "partition.lowerBound" -> "2012-10-01", "partition.upperBound" -> "2012-12-31"), None, Nil) + val df = this.execute(run) + + df.agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() + df.filter(col("float_amount") >= lit(2.34f)).agg(max(col("float_amount")).as("max_float_amount"), sum(col("double_amount")).as("sum_double_amount")).show() + + //df.limit(10).show() //not supported + //df.distinct().show() //not supported + + df.filter(col("float_amount") >= lit(2.34f)) + .groupBy(col("gender"), col("start_date")) + .agg( + countDistinct(col("event_id")).as("distinct_count"), + count(col("event_id")).as("count"), + max(col("float_amount")).as("max_float_amount"), + min(col("float_amount")).as("min_float_amount"), + //avg(col("decimal_amount")).as("avg_decimal_amount"), //not supported + sum_distinct(col("double_amount")).as("distinct_sum_double_amount"), + sum(col("double_amount")).as("sum_double_amount") + ) + .show() + } + test("Write a simple table") { val query = """ |select @@ -216,7 +291,7 @@ class DremioTest extends FunSuite with BeforeAndAfterEach { val df = this.execute(runLoad).cache //append - var dstTable = """"local-iceberg"."iceberg_db"."iceberg_events"""" + val dstTable = """"local-iceberg"."iceberg_db"."iceberg_events"""" val dfAppend = df.limit(10000) .withColumn("event_id", col("event_id") + 1000000000L) .withColumn("remark", lit("new")) @@ -230,7 +305,7 @@ class DremioTest extends FunSuite with BeforeAndAfterEach { val df = this.execute(runLoad).filter(col("user_id").isin("781622845", "1519813515", "1733137333", "3709565024")).cache //the target table - var dstTable = """"local-iceberg"."iceberg_db"."iceberg_events"""" + val dstTable = """"local-iceberg"."iceberg_db"."iceberg_events"""" //append for struct val dfStruct = df.filter(col("user_id") === lit("781622845") || col("user_id") === lit("3709565024")) @@ -312,10 +387,11 @@ class DremioTest extends FunSuite with BeforeAndAfterEach { .mode("append").save } - private var spark: SparkSession = _ //create spark-session - override def beforeEach(): Unit = spark = SparkSession.builder.master("local[*]").config("spark.executor.memory", "24g").config("spark.driver.memory", "24g").appName("test").getOrCreate + private var spark: SparkSession = _ //execute a job private def execute[T](run: SparkSession => T): T = run(spark) + + override def beforeEach(): Unit = spark = SparkSession.builder.master("local[*]").config("spark.executor.memory", "24g").config("spark.driver.memory", "24g").appName("test").getOrCreate override def afterEach(): Unit = 
spark.stop() }
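
A usage note on the aggregation push-down introduced by this patch: avg is not in the push-down list, but the same effect can be obtained by pushing count and sum and finishing the division in Spark. The sketch below is illustrative only — it assumes the `"e-commerce".orders` table and column names from the README examples, and it assumes the connector can be addressed with `format("flight")` plus a `table` read option (the options the tests pass through their load helper); host, port and credentials are placeholders.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, count, sum}

object AvgViaPushdown {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("avg-via-pushdown").getOrCreate()

    // Read through the flight connector; the "table" and connection options mirror
    // those used in DremioTest (values here are placeholders, not a real endpoint).
    val orders = spark.read.format("flight")
      .option("host", "192.168.0.22").option("port", 32010)
      .option("user", "test").option("password", "Password@12345")
      .option("table", """"e-commerce".orders""")
      .load()

    // count and sum are both push-down-capable (see the README list), so the heavy
    // aggregation can run at the flight end-point; the division is done by Spark.
    val avgAmount = orders
      .filter("order_date > '2020-01-01'")                  // filter pushed down
      .groupBy(col("payment_method"))
      .agg(count(col("order_amount")).as("cnt"), sum(col("order_amount")).as("total"))
      .withColumn("avg_amount", col("total") / col("cnt"))  // computed in Spark

    avgAmount.show()
    spark.stop()
  }
}
```

Because both aggregates in the `agg(...)` call are in the supported list, only the aggregated result crosses the wire, and the final division is a cheap projection over that small result set.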