added aggregation pushdown
qwshen committed Sep 23, 2022
1 parent d4893ba commit 2fd107c
Showing 7 changed files with 267 additions and 34 deletions.
34 changes: 32 additions & 2 deletions README.md
@@ -34,7 +34,7 @@ The following properties are optional:
select id, "departure-time", "arrival-time" from flights where "flight-no" = 'ABC-21';
```

-The connector supports optimized reads with filters, required columns and count pushing-down, and parallel reads when partitioning is enalbed.
+The connector supports optimized reads with filter, required-column, and aggregation push-down, and parallel reads when partitioning is enabled.

### 1. Load data
```scala
@@ -113,7 +113,7 @@ spark.read
Note: when lowerBound & upperBound with byColumn, or predicates, are used, they are ultimately applied as filters on the queries that fetch the data, so they can affect the final result-set. Make sure these partitioning options only partition the output without changing its content.
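As a sketch, a partitioned read might be configured as follows; the `partition.*` option keys here are illustrative assumptions, not necessarily the connector's exact names, so check the option list above:

```scala
//a hypothetical partitioned read - the "partition.*" keys are assumed names
spark.read
  .option("host", "192.168.0.26").option("port", 32010)
  .option("partition.byColumn", "order_amount") //column to split the read by
  .option("partition.lowerBound", 0)            //with upperBound, becomes range filters
  .option("partition.upperBound", 100000)       //on the queries that fetch each partition
  .flight(""""e-commerce".orders""")
```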

#### - Pushing filter & columns down
-The filters and required-columns are pushed down when they are provided. This limits the data at the source which greatly decreases the amount of data being transferred and processed.
+Filters and required-columns are pushed down when they are provided. This limits the data at the source, which greatly decreases the amount of data being transferred and processed.
```scala
spark.read
.option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123")
@@ -123,6 +123,36 @@ spark.read
.select("order_id", "customer_id", "payment_method", "order_amount", "order_date") //required-columns are pushed down
```

#### - Pushing aggregation down
Aggregations are pushed down when they are provided. Only the following aggregations are supported:
- max
- min
- count
- count distinct
- sum
- sum distinct

Avg is not pushed down directly, but it can be achieved by combining count & sum (see the sketch after the example below); any other aggregations are calculated at the Spark level.

```scala
val df = spark.read
.option("host", "192.168.0.26").option("port", 32010).option("tls.enabled", true).option("tls.verifyServer", false).option("user", "test").option("password", "Password@123")
.options(options) //other options
.flight(""""e-commerce".orders""")
.filter("order_date > '2020-01-01' and order_amount > 100") //filter is pushed down

df.agg(count(col("order_id")).as("num_orders"), sum(col("amount")).as("total_amount")).show() //aggregation pushed down

df.groupBy(col("gender"))
.agg(
countDistinct(col("order_id")).as("num_orders"),
max(col("amount")).as("max_amount"),
min(col("amount")).as("min_amount"),
sum(col("amount")).as("total_amount")
) //aggregation pushed down
.show()
```
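As noted above, avg can be derived from the pushed-down count and sum; a minimal sketch:

```scala
//avg per gender, derived from pushed-down count & sum
df.groupBy(col("gender"))
  .agg(sum(col("amount")).as("total_amount"), count(col("amount")).as("num_orders"))
  .withColumn("avg_amount", col("total_amount") / col("num_orders"))
  .show()
```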

### 2. Write data (tables being written must be Iceberg tables in the case of Dremio Flight)
```scala
df.write.format("flight")
2 changes: 1 addition & 1 deletion pom.xml
@@ -23,7 +23,7 @@

<groupId>com.qwshen</groupId>
<artifactId>spark-flight-connector_${spark.version}</artifactId>
-<version>1.0.0</version>
+<version>1.0.1</version>
<packaging>jar</packaging>

<build>
49 changes: 49 additions & 0 deletions src/main/java/com/qwshen/flight/PushAggregation.java
@@ -0,0 +1,49 @@
package com.qwshen.flight;

import java.io.Serializable;

/**
* Describes the data structure for pushed-down aggregation
*/
public class PushAggregation implements Serializable {
//pushed-down Aggregate-Columns (expressions)
private String[] _columnExpressions = null;
//pushed-down GroupBy-Columns
private String[] _groupByColumns = null;

/**
* Push down aggregation of columns
* select max(age), sum(distinct amount) from table where ...
* @param columnExpressions - the collection of aggregation expressions
*/
public PushAggregation(String[] columnExpressions) {
this._columnExpressions = columnExpressions;
}

/**
* Push down aggregation with group by columns
* select max(age), sum(amount) from table where ... group by gender
* @param columnExpressions - the collection of aggregation expressions
* @param groupByColumns - the columns in group by
*/
public PushAggregation(String[] columnExpressions, String[] groupByColumns) {
this(columnExpressions);
this._groupByColumns = groupByColumns;
}

/**
* Return the collection of aggregation expressions
* @return - the expressions
*/
public String[] getColumnExpressions() {
return this._columnExpressions;
}

/**
* The columns for group-by
* @return - columns
*/
public String[] getGroupByColumns() {
return this._groupByColumns;
}
}
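For a query like `select gender, max(amount) from orders group by gender`, the scan builder (further below) would populate this structure roughly as follows; a sketch assuming the double-quote as the column quote:

```scala
//illustrative only: the expressions arrive already rendered as SQL snippets,
//and the group-by columns are also prepended to the select list
val agg = new PushAggregation(
  Array("\"gender\"", "max(\"amount\")"), //aggregate/select expressions
  Array("\"gender\"")                     //group-by columns
)
```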
13 changes: 11 additions & 2 deletions src/main/java/com/qwshen/flight/QueryStatement.java
@@ -8,15 +8,18 @@
public class QueryStatement implements Serializable {
private final String _stmt;
private final String _where;
private final String _groupBy;

/**
* Construct a QueryStatement
* @param stmt - the select portion of a select-statement
* @param where - the where portion of a select-statement
* @param groupBy - the groupBy portion of a select-statement
*/
-public QueryStatement(String stmt, String where) {
+public QueryStatement(String stmt, String where, String groupBy) {
this._stmt = stmt;
this._where = where;
this._groupBy = groupBy;
}

/**
@@ -29,6 +32,9 @@ public boolean different(QueryStatement rs) {
if (!changed) {
changed = (rs._where != null) ? !rs._where.equalsIgnoreCase(this._where) : this._where != null;
}
if (!changed) {
changed = (rs._groupBy != null) ? !rs._groupBy.equalsIgnoreCase(this._groupBy) : this._groupBy != null;
}
return changed;
}

@@ -37,6 +43,9 @@ public boolean different(QueryStatement rs) {
* @return - the select-statement
*/
public String getStatement() {
-return (this._where != null && this._where.length() > 0) ? String.format("%s where %s", this._stmt, this._where) : this._stmt;
+return String.format("%s %s %s", this._stmt,
+(this._where != null && this._where.length() > 0) ? String.format("where %s", this._where) : "",
+(this._groupBy != null && this._groupBy.length() > 0) ? String.format("group by %s", this._groupBy) : ""
+);
}
}
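With the group-by portion added, the rendered statement composes all three parts; a quick sketch (names illustrative):

```scala
val qs = new QueryStatement(
  """select "gender", max("amount") from "e-commerce".orders""", //select portion
  "order_date > '2020-01-01'",                                   //where portion
  "\"gender\""                                                   //group-by portion
)
//qs.getStatement() then yields, modulo spacing:
//  select "gender", max("amount") from "e-commerce".orders where order_date > '2020-01-01' group by "gender"
```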
63 changes: 48 additions & 15 deletions src/main/java/com/qwshen/flight/Table.java
@@ -1,13 +1,19 @@
package com.qwshen.flight;

import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.parquet.Strings;
import org.apache.spark.sql.sources.*;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.LoggerFactory;
import scala.Array;
import scala.collection.mutable.HashTable;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
@@ -41,7 +47,7 @@ private Table(String name, String columnQuote) {
this._name = name;
this._columnQuote = columnQuote;

-this.prepareQueryStatement(false, null, null, null);
+this.prepareQueryStatement(null, null, null, null);
}

/**
@@ -122,23 +128,50 @@ public void initialize(Configuration config) {
}

//Prepare the query for submitting to remote flight service
-private boolean prepareQueryStatement(Boolean forCount, StructField[] fields, String filter, PartitionBehavior partitionBehavior) {
-String selectStmt = forCount ? String.format("select count(*) from %s", this._name)
-: (fields == null || fields.length == 0) ? String.format("select * from %s", this._name)
-: String.format("select %s from %s", String.join(",", Arrays.stream(fields).map(column -> String.format("%s%s%s", this._columnQuote, column.name(), this._columnQuote)).toArray(String[]::new)), this._name);
-QueryStatement stmt = new QueryStatement(selectStmt, filter);
+private boolean prepareQueryStatement(PushAggregation aggregation, StructField[] fields, String filter, PartitionBehavior partitionBehavior) {
+//aggregation mode: 0 -> no aggregation; 1 -> aggregation without group-by; 2 -> aggregation with group-by
+int aggMode = 0;
+String select = "", groupBy = "";
+if (aggregation != null) {
+String[] groupByFields = aggregation.getGroupByColumns();
+if (groupByFields != null && groupByFields.length > 0) {
+aggMode = 2;
+groupBy = String.join(",", groupByFields);
+} else {
+aggMode = 1;
+}
+select = String.format("select %s from %s", String.join(",", aggregation.getColumnExpressions()), this._name);
+} else if (fields != null && fields.length > 0) {
+select = String.format("select %s from %s", String.join(",", Arrays.stream(fields).map(column -> String.format("%s%s%s", this._columnQuote, column.name(), this._columnQuote)).toArray(String[]::new)), this._name);
+} else {
+select = String.format("select * from %s", this._name);
+}
+QueryStatement stmt = new QueryStatement(select, filter, groupBy);
boolean changed = stmt.different(this._stmt);
if (changed) {
this._stmt = stmt;
}

-if (forCount) {
+if (aggMode == 1) {
this._partitionStmts.clear();
} else if (partitionBehavior != null && partitionBehavior.enabled()) {
-String baseWhere = (filter != null && !filter.isEmpty()) ? String.format("(%s) and ", filter) : "";
-Function<String, StructField> find = (name) -> (fields != null) ? Arrays.stream(fields).filter(field -> field.name().equalsIgnoreCase(name)).findFirst().orElse(null) : null;
-String[] predicates = partitionBehavior.predicateDefined() ? partitionBehavior.getPredicates() : partitionBehavior.calculatePredicates(fields);
-Arrays.stream(predicates).forEach(predicate -> this._partitionStmts.add(String.format("%s where %s(%s)", selectStmt, baseWhere, predicate)));
+String where = (filter != null && !filter.isEmpty()) ? String.format("(%s) and ", filter) : "";
+BiFunction<StructField[], StructField[], StructField[]> merge = (s1, s2) -> {
+Hashtable<String, StructField> s = new Hashtable<String, StructField>();
+for (StructField sf : s1) {
+s.put(sf.name(), sf);
+}
+for (StructField sf : s2) {
+s.put(sf.name(), sf);
+}
+return s.values().toArray(new StructField[0]);
+};
+String[] predicates = partitionBehavior.predicateDefined() ? partitionBehavior.getPredicates()
+: partitionBehavior.calculatePredicates(this._sparkSchema == null ? fields : merge.apply(fields, this._sparkSchema.fields()));
+for (String predicate : predicates) {
+QueryStatement s = new QueryStatement(select, String.format("%s(%s)", where, predicate), groupBy);
+this._partitionStmts.add(s.getStatement());
+}
}
return changed;
}
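The three aggregation modes translate into statements of roughly these shapes (a sketch with made-up table/column names):

```scala
//aggMode 0 - no aggregation, columns only
val mode0 = """select "order_id","amount" from "e-commerce".orders"""
//aggMode 1 - global aggregate; the partition statements are cleared,
//so the aggregate is fetched as one non-partitioned read
val mode1 = """select count(*), sum("amount") from "e-commerce".orders"""
//aggMode 2 - aggregate with group-by; partitioned reads remain possible
val mode2 = """select "gender", max("amount") from "e-commerce".orders group by "gender""""
```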
@@ -202,15 +235,15 @@ public String toWhereClause(Filter filter) {
* Probe if the pushed filter, fields and aggregation would affect the existing schema & end-points
* @param pushedFilter - the pushed filter
* @param pushedFields - the pushed fields
-* @param pushedCount - the pushed count aggregation
+* @param pushedAggregation - the pushed aggregation
* @param partitionBehavior - the partitioning behavior
* @return - true if initialization is required
*/
-public Boolean probe(String pushedFilter, StructField[] pushedFields, boolean pushedCount, PartitionBehavior partitionBehavior) {
-if ((pushedFilter == null || pushedFilter.isEmpty()) && (pushedFields == null || pushedFields.length == 0) && !pushedCount) {
+public Boolean probe(String pushedFilter, StructField[] pushedFields, PushAggregation pushedAggregation, PartitionBehavior partitionBehavior) {
+if ((pushedFilter == null || pushedFilter.isEmpty()) && (pushedFields == null || pushedFields.length == 0) && pushedAggregation == null) {
return false;
}
-return this.prepareQueryStatement(pushedCount, pushedFields, pushedFilter, partitionBehavior);
+return this.prepareQueryStatement(pushedAggregation, pushedFields, pushedFilter, partitionBehavior);
}

/**
50 changes: 43 additions & 7 deletions src/main/java/com/qwshen/flight/spark/read/FlightScanBuilder.java
@@ -2,13 +2,17 @@

import com.qwshen.flight.Configuration;
import com.qwshen.flight.PartitionBehavior;
import com.qwshen.flight.PushAggregation;
import com.qwshen.flight.Table;
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation;
+import org.apache.spark.sql.connector.expressions.aggregate.*;
import org.apache.spark.sql.connector.read.*;
import org.apache.spark.sql.sources.*;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

@@ -24,8 +28,9 @@ public final class FlightScanBuilder implements ScanBuilder, SupportsPushDownFil
private Filter[] _pdFilters = new Filter[0];
//the pushed-down columns
private StructField[] _pdColumns = new StructField[0];
-//only supports pushed-down COUNT
-private boolean _pdCount = false;

+//pushed-down aggregation
+private PushAggregation _pdAggregation = null;

/**
* Construct a flight-scan builder
@@ -40,14 +45,45 @@ public FlightScanBuilder(Configuration configuration, Table table, PartitionBeha
}

/**
-* Only COUNT is supported for now
+* Collect the aggregations to be pushed down
* @param aggregation - the pushed aggregation
-* @return - true if COUNT is pushed
+* @return - true if the aggregation is pushed down
*/
@Override
public boolean pushAggregation(Aggregation aggregation) {
-this._pdCount = Arrays.stream(aggregation.aggregateExpressions()).map(e -> e.toString().equalsIgnoreCase("count(*)")).findFirst().orElse(false);
-return this._pdCount;
+Function<String, String> quote = (s) -> String.format("%s%s%s", this._table.getColumnQuote(), s, this._table.getColumnQuote());
+Function<String[], String> mks = (ss) -> String.join(",", Arrays.stream(ss).map(quote).toArray(String[]::new));
+
+List<String> pdAggregateColumns = new ArrayList<>();
+boolean push = true;
+for (AggregateFunc agg : aggregation.aggregateExpressions()) {
+if (agg instanceof CountStar) {
+pdAggregateColumns.add(agg.toString().toLowerCase());
+} else if (agg instanceof Count) {
+Count c = (Count)agg;
+pdAggregateColumns.add(c.isDistinct() ? String.format("count(distinct(%s))", mks.apply(c.column().fieldNames())) : String.format("count(%s)", mks.apply(c.column().fieldNames())));
+} else if (agg instanceof Min) {
+Min m = (Min)agg;
+pdAggregateColumns.add(String.format("min(%s)", mks.apply(m.column().fieldNames())));
+} else if (agg instanceof Max) {
+Max m = (Max)agg;
+pdAggregateColumns.add(String.format("max(%s)", mks.apply(m.column().fieldNames())));
+} else if (agg instanceof Sum) {
+Sum s = (Sum)agg;
+pdAggregateColumns.add(s.isDistinct() ? String.format("sum(distinct(%s))", mks.apply(s.column().fieldNames())) : String.format("sum(%s)", mks.apply(s.column().fieldNames())));
+} else {
+push = false;
+break;
+}
+};
+if (push) {
+String[] pdGroupByColumns = Arrays.stream(aggregation.groupByColumns()).flatMap(gbc -> Arrays.stream(gbc.fieldNames()).map(quote)).toArray(String[]::new);
+pdAggregateColumns.addAll(0, Arrays.asList(pdGroupByColumns));
+this._pdAggregation = pdGroupByColumns.length > 0 ? new PushAggregation(pdAggregateColumns.toArray(new String[0]), pdGroupByColumns) : new PushAggregation(pdAggregateColumns.toArray(new String[0]));
+} else {
+this._pdAggregation = null;
+}
+return this._pdAggregation != null;
}
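Under this translation, a Spark query such as the one below is pushed down as the listed expressions (assuming the double-quote as the column quote):

```scala
//given this Spark query ...
df.groupBy(col("gender"))
  .agg(max(col("amount")), countDistinct(col("order_id")))
//... pushAggregation reports it as pushable and builds, with group-by
//columns prepended to the select expressions:
//  column expressions: "gender", max("amount"), count(distinct("order_id"))
//  group-by columns:   "gender"
```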

/**
@@ -105,7 +141,7 @@ public void pruneColumns(StructType columns) {
public Scan build() {
//adjust flight-table upon pushed filters & columns
String where = String.join(" and ", Arrays.stream(this._pdFilters).map(this._table::toWhereClause).toArray(String[]::new));
-if (this._table.probe(where, this._pdColumns, this._pdCount, this._partitionBehavior)) {
+if (this._table.probe(where, this._pdColumns, this._pdAggregation, this._partitionBehavior)) {
this._table.initialize(this._configuration);
}
return new FlightScan(this._configuration, this._table);