This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new b1f522c [SPARK-34952][SQL] DSv2 Aggregate push down APIs
b1f522c is described below
commit b1f522cf97e78329d0c48e3bcae72bd45f2e698e
Author: Huaxin Gao <[email protected]>
AuthorDate: Mon Jul 26 16:01:22 2021 +0800
[SPARK-34952][SQL] DSv2 Aggregate push down APIs
### What changes were proposed in this pull request?
Add interfaces and APIs to push down aggregates to V2 data sources.
### Why are the changes needed?
Improve performance: pushing aggregates into the data source avoids shipping every row to Spark only to aggregate it there.
### Does this PR introduce _any_ user-facing change?
`SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED` was added. If it is set to
true, aggregates are pushed down to the data source.
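
For the JDBC source touched by this patch, aggregate push down is additionally gated by the new `pushDownAggregate` option (default `false`). A minimal sketch of enabling it, mirroring the `JDBCV2Suite` setup in this patch; the H2 URL and the `h2` catalog name are placeholders:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog

// Placeholder H2 URL; any JDBC URL reachable from the driver works the same way.
val url = "jdbc:h2:mem:testdb0"

val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
  .config("spark.sql.catalog.h2.url", url)
  .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
  // New option added by this patch; defaults to "false".
  .config("spark.sql.catalog.h2.pushDownAggregate", "true")
  .getOrCreate()

// Aggregates in queries like this may then be pushed into the database.
spark.sql("SELECT MAX(SALARY), MIN(BONUS) FROM h2.test.employee GROUP BY DEPT").show()
```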
### How was this patch tested?
New tests for aggregate push down were added in
https://github.com/apache/spark/pull/32049. The original PR was split into two
PRs; this PR doesn't contain new tests.
Closes #33352 from huaxingao/aggPushDownInterface.
Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit c561ee686551690bee689f37ae5bbd75119994d6)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../AggregateFunc.java} | 13 +-
.../Aggregation.java} | 28 ++-
.../ScanBuilder.java => expressions/Count.java} | 31 ++-
.../CountStar.java} | 22 +-
.../ScanBuilder.java => expressions/Max.java} | 26 ++-
.../ScanBuilder.java => expressions/Min.java} | 28 ++-
.../spark/sql/connector/expressions/Sum.java | 57 +++++
.../spark/sql/connector/read/ScanBuilder.java | 3 +-
.../connector/read/SupportsPushDownAggregates.java | 56 +++++
.../spark/sql/execution/DataSourceScanExec.scala | 25 +-
.../execution/datasources/DataSourceStrategy.scala | 31 +++
.../execution/datasources/jdbc/JDBCOptions.scala | 4 +
.../sql/execution/datasources/jdbc/JDBCRDD.scala | 67 +++++-
.../execution/datasources/jdbc/JDBCRelation.scala | 18 ++
.../datasources/v2/DataSourceV2Strategy.scala | 5 +-
.../execution/datasources/v2/PushDownUtils.scala | 40 ++++
.../datasources/v2/V2ScanRelationPushDown.scala | 188 +++++++++++++--
.../execution/datasources/v2/jdbc/JDBCScan.scala | 23 +-
.../datasources/v2/jdbc/JDBCScanBuilder.scala | 76 +++++-
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 259 ++++++++++++++++++++-
20 files changed, 910 insertions(+), 90 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/AggregateFunc.java
similarity index 73%
copy from
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
copy to
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/AggregateFunc.java
index cb3eea7..eea8c31 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/AggregateFunc.java
@@ -15,18 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connector.read;
+package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
+import java.io.Serializable;
+
/**
- * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}.
+ * Base class of the Aggregate Functions.
*
- * @since 3.0.0
+ * @since 3.2.0
*/
@Evolving
-public interface ScanBuilder {
- Scan build();
+public interface AggregateFunc extends Expression, Serializable {
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
similarity index 57%
copy from
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
copy to
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
index cb3eea7..fdf3031 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java
@@ -15,18 +15,32 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connector.read;
+package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
+import java.io.Serializable;
+
/**
- * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}.
+ * Aggregation in SQL statement.
*
- * @since 3.0.0
+ * @since 3.2.0
*/
@Evolving
-public interface ScanBuilder {
- Scan build();
+public final class Aggregation implements Serializable {
+ private AggregateFunc[] aggregateExpressions;
+ private FieldReference[] groupByColumns;
+
+ public Aggregation(AggregateFunc[] aggregateExpressions, FieldReference[] groupByColumns) {
+ this.aggregateExpressions = aggregateExpressions;
+ this.groupByColumns = groupByColumns;
+ }
+
+ public AggregateFunc[] aggregateExpressions() {
+ return aggregateExpressions;
+ }
+
+ public FieldReference[] groupByColumns() {
+ return groupByColumns;
+ }
}
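
A quick sketch (not part of the patch) of how the new `Aggregation` container is assembled from the expression classes added in this commit, using the same `FieldReference(name)` pattern the Spark-side code in this patch uses:

```scala
import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference, Max, Min}

// FieldReference(name) is cast the same way DataSourceStrategy/PushDownUtils do in this patch.
val salary = FieldReference("SALARY").asInstanceOf[FieldReference]
val bonus  = FieldReference("BONUS").asInstanceOf[FieldReference]
val dept   = FieldReference("DEPT").asInstanceOf[FieldReference]

// Represents: MAX(SALARY), MIN(BONUS) ... GROUP BY DEPT
val agg = new Aggregation(Array(new Max(salary), new Min(bonus)), Array(dept))

assert(agg.aggregateExpressions().length == 2)
assert(agg.groupByColumns().head.fieldNames.head == "DEPT")
```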
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
similarity index 54%
copy from
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
copy to
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
index cb3eea7..17562a1 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Count.java
@@ -15,18 +15,35 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connector.read;
+package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
- * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}.
+ * An aggregate function that returns the number of the specific row in a group.
*
- * @since 3.0.0
+ * @since 3.2.0
*/
@Evolving
-public interface ScanBuilder {
- Scan build();
+public final class Count implements AggregateFunc {
+ private FieldReference column;
+ private boolean isDistinct;
+
+ public Count(FieldReference column, boolean isDistinct) {
+ this.column = column;
+ this.isDistinct = isDistinct;
+ }
+
+ public FieldReference column() {
+ return column;
+ }
+ public boolean isDinstinct() {
+ return isDistinct;
+ }
+
+ @Override
+ public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
+
+ @Override
+ public String describe() { return this.toString(); }
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
similarity index 69%
copy from
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
copy to
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
index cb3eea7..777a99d 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/CountStar.java
@@ -15,18 +15,26 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connector.read;
+package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
- * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}.
+ * An aggregate function that returns the number of rows in a group.
*
- * @since 3.0.0
+ * @since 3.2.0
*/
@Evolving
-public interface ScanBuilder {
- Scan build();
+public final class CountStar implements AggregateFunc {
+
+ public CountStar() {
+ }
+
+ @Override
+ public String toString() {
+ return "CountStar()";
+ }
+
+ @Override
+ public String describe() { return this.toString(); }
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
similarity index 62%
copy from
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
copy to
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
index cb3eea7..fe7689c 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Max.java
@@ -15,18 +15,30 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connector.read;
+package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
- * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}.
+ * An aggregate function that returns the maximum value in a group.
*
- * @since 3.0.0
+ * @since 3.2.0
*/
@Evolving
-public interface ScanBuilder {
- Scan build();
+public final class Max implements AggregateFunc {
+ private FieldReference column;
+
+ public Max(FieldReference column) {
+ this.column = column;
+ }
+
+ public FieldReference column() { return column; }
+
+ @Override
+ public String toString() {
+ return "Max(" + column.describe() + ")";
+ }
+
+ @Override
+ public String describe() { return this.toString(); }
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
similarity index 61%
copy from
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
copy to
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
index cb3eea7..f528b0b 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Min.java
@@ -15,18 +15,32 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connector.read;
+package org.apache.spark.sql.connector.expressions;
import org.apache.spark.annotation.Evolving;
/**
- * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}.
+ * An aggregate function that returns the minimum value in a group.
*
- * @since 3.0.0
+ * @since 3.2.0
*/
@Evolving
-public interface ScanBuilder {
- Scan build();
+public final class Min implements AggregateFunc {
+ private FieldReference column;
+
+ public Min(FieldReference column) {
+ this.column = column;
+ }
+
+ public FieldReference column() {
+ return column;
+ }
+
+ @Override
+ public String toString() {
+ return "Min(" + column.describe() + ")";
+ }
+
+ @Override
+ public String describe() { return this.toString(); }
}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
new file mode 100644
index 0000000..4cb34be
--- /dev/null
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Sum.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.types.DataType;
+
+/**
+ * An aggregate function that returns the summation of all the values in a group.
+ *
+ * @since 3.2.0
+ */
+@Evolving
+public final class Sum implements AggregateFunc {
+ private FieldReference column;
+ private DataType dataType;
+ private boolean isDistinct;
+
+ public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+ this.column = column;
+ this.dataType = dataType;
+ this.isDistinct = isDistinct;
+ }
+
+ public FieldReference column() {
+ return column;
+ }
+ public DataType dataType() {
+ return dataType;
+ }
+ public boolean isDinstinct() {
+ return isDistinct;
+ }
+
+ @Override
+ public String toString() {
+ return "Sum(" + column.describe() + "," + dataType + "," + isDistinct
+ ")";
+ }
+
+ @Override
+ public String describe() { return this.toString(); }
+}
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
index cb3eea7..b46f620 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
@@ -22,7 +22,8 @@ import org.apache.spark.annotation.Evolving;
/**
* An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
* interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}.
+ * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down
+ * aggregates or applies column pruning.
*
* @since 3.0.0
*/
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
new file mode 100644
index 0000000..7efa333
--- /dev/null
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Aggregation;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down aggregates. Spark assumes that the data source can't fully complete the
+ * grouping work, and will group the data source output again. For queries like
+ * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
+ * to the data source, the data source can still output data with duplicated keys, which is OK
+ * as Spark will do GROUP BY key again. The final query plan can be something like this:
+ * {{{
+ *   Aggregate [key#1], [min(min(value)#2) AS m#3]
+ *     +- RelationV2[key#1, min(value)#2]
+ * }}}
+ *
+ * <p>
+ * Similarly, if there is no grouping expression, the data source can still output more than one
+ * rows.
+ *
+ * <p>
+ * When pushing down operators, Spark pushes down filters to the data source first, then push down
+ * aggregates or apply column pruning. Depends on data source implementation, aggregates may or
+ * may not be able to be pushed down with filters. If pushed filters still need to be evaluated
+ * after scanning, aggregates can't be pushed down.
+ *
+ * @since 3.2.0
+ */
+@Evolving
+public interface SupportsPushDownAggregates extends ScanBuilder {
+
+ /**
+ * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
+ * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
+ * the given Aggregation).
+ */
+ boolean pushAggregation(Aggregation aggregation);
+}
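
A hypothetical connector-side sketch (names such as `MyScanBuilder` are illustrative, not part of this patch) of mixing the new interface into a `ScanBuilder`. It accepts only MIN/MAX pushdown and rejects everything else, so Spark keeps the original plan in the unsupported cases:

```scala
import org.apache.spark.sql.connector.expressions.{Aggregation, Max, Min}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

class MyScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushed: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // Only accept aggregate functions this source knows how to compute.
    val supported = aggregation.aggregateExpressions().forall {
      case _: Min | _: Max => true
      case _ => false
    }
    if (supported) pushed = Some(aggregation)
    supported
  }

  override def build(): Scan = {
    // A real implementation would return a Scan whose readSchema() lists the grouping
    // columns first, followed by one column per pushed aggregate function.
    throw new UnsupportedOperationException("sketch only")
  }
}
```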
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index de991fc..603d53a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.expressions.Aggregation
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat
=> ParquetSource}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -102,6 +103,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
+ aggregation: Option[Aggregation],
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -129,12 +131,29 @@ case class RowDataSourceScanExec(
override def inputRDD: RDD[InternalRow] = rdd
override val metadata: Map[String, String] = {
- val markedFilters = for (filter <- filters) yield {
- if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
+
+ def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
+
+ val (aggString, groupByString) = if (aggregation.nonEmpty) {
+ (seqToString(aggregation.get.aggregateExpressions),
+ seqToString(aggregation.get.groupByColumns))
+ } else {
+ ("[]", "[]")
+ }
+
+ val markedFilters = if (filters.nonEmpty) {
+ for (filter <- filters) yield {
+ if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
+ }
+ } else {
+ handledFilters
}
+
Map(
"ReadSchema" -> requiredSchema.catalogString,
- "PushedFilters" -> markedFilters.mkString("[", ", ", "]"))
+ "PushedFilters" -> seqToString(markedFilters.toSeq),
+ "PushedAggregates" -> aggString,
+ "PushedGroupby" -> groupByString)
}
// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 9e33723..2f334de 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -33,12 +33,14 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir,
InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
+import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count,
CountStar, FieldReference, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec,
SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -332,6 +334,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
+ None,
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -405,6 +408,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
+ None,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -427,6 +431,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
+ None,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -692,6 +697,32 @@ object DataSourceStrategy
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters,
handledFilters)
}
+ protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
+ if (aggregates.filter.isEmpty) {
+ aggregates.aggregateFunction match {
+ case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
+ Some(new Min(FieldReference(name).asInstanceOf[FieldReference]))
+ case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
+ Some(new Max(FieldReference(name).asInstanceOf[FieldReference]))
+ case count: aggregate.Count if count.children.length == 1 =>
+ count.children.head match {
+ // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
+ case Literal(_, _) => Some(new CountStar())
+ case PushableColumnWithoutNestedColumn(name) =>
+ Some(new Count(FieldReference(name).asInstanceOf[FieldReference],
+ aggregates.isDistinct))
+ case _ => None
+ }
+ case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
+ sum.dataType, aggregates.isDistinct))
+ case _ => None
+ }
+ } else {
+ None
+ }
+ }
+
/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
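
As a rough illustration (assuming the aggregated column is a plain top-level attribute and the aggregate carries no FILTER clause), `translateAggregate` maps Catalyst aggregates onto the new connector expressions along these lines:

```scala
import org.apache.spark.sql.connector.expressions.{Count, CountStar, FieldReference, Max, Min}

val col = FieldReference("SALARY").asInstanceOf[FieldReference]

val translated = Seq(
  new Min(col),          // from min(SALARY)
  new Max(col),          // from max(SALARY)
  new Count(col, false), // from count(SALARY)
  new Count(col, true),  // from count(DISTINCT SALARY)
  new CountStar()        // from count(*); count over a literal is treated as COUNT(*)
)
// Aggregates with a FILTER clause, over nested columns, or otherwise untranslatable
// yield None and are left for Spark to evaluate.
```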
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 97d4f2d..8b2ae2b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -188,6 +188,9 @@ class JDBCOptions(
// An option to allow/disallow pushing down predicate into JDBC data source
val pushDownPredicate = parameters.getOrElse(JDBC_PUSHDOWN_PREDICATE,
"true").toBoolean
+ // An option to allow/disallow pushing down aggregate into JDBC data source
+ val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean
+
// The local path of user's keytab file, which is assumed to be pre-uploaded
to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -259,6 +262,7 @@ object JDBCOptions {
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
+ val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 87ca78d..c22ca15 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition,
SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count,
CountStar, FieldReference, Max, Min, Sum}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -133,6 +134,34 @@ object JDBCRDD extends Logging {
})
}
+ def compileAggregates(
+ aggregates: Seq[AggregateFunc],
+ dialect: JdbcDialect): Seq[String] = {
+ def quote(colName: String): String = dialect.quoteIdentifier(colName)
+
+ aggregates.map {
+ case min: Min =>
+ assert(min.column.fieldNames.length == 1)
+ s"MIN(${quote(min.column.fieldNames.head)})"
+ case max: Max =>
+ assert(max.column.fieldNames.length == 1)
+ s"MAX(${quote(max.column.fieldNames.head)})"
+ case count: Count =>
+ assert(count.column.fieldNames.length == 1)
+ val distinct = if (count.isDinstinct) "DISTINCT" else ""
+ val column = quote(count.column.fieldNames.head)
+ s"COUNT($distinct $column)"
+ case sum: Sum =>
+ assert(sum.column.fieldNames.length == 1)
+ val distinct = if (sum.isDinstinct) "DISTINCT" else ""
+ val column = quote(sum.column.fieldNames.head)
+ s"SUM($distinct $column)"
+ case _: CountStar =>
+ s"COUNT(1)"
+ case _ => ""
+ }
+ }
+
/**
* Build and return JDBCRDD from the given information.
*
@@ -143,6 +172,8 @@ object JDBCRDD extends Logging {
* @param parts - An array of JDBCPartitions specifying partition ids and
* per-partition WHERE clauses.
* @param options - JDBC options that contains url, table and other
information.
+ * @param requiredSchema - The schema of the columns to SELECT.
+ * @param aggregation - The pushed down aggregation
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
@@ -152,19 +183,27 @@ object JDBCRDD extends Logging {
requiredColumns: Array[String],
filters: Array[Filter],
parts: Array[Partition],
- options: JDBCOptions): RDD[InternalRow] = {
+ options: JDBCOptions,
+ outputSchema: Option[StructType] = None,
+ groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
- val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
+ val quotedColumns = if (groupByColumns.isEmpty) {
+ requiredColumns.map(colName => dialect.quoteIdentifier(colName))
+ } else {
+ // these are already quoted in JDBCScanBuilder
+ requiredColumns
+ }
new JDBCRDD(
sc,
JdbcUtils.createConnectionFactory(options),
- pruneSchema(schema, requiredColumns),
+ outputSchema.getOrElse(pruneSchema(schema, requiredColumns)),
quotedColumns,
filters,
parts,
url,
- options)
+ options,
+ groupByColumns)
}
}
@@ -181,7 +220,8 @@ private[jdbc] class JDBCRDD(
filters: Array[Filter],
partitions: Array[Partition],
url: String,
- options: JDBCOptions)
+ options: JDBCOptions,
+ groupByColumns: Option[Array[FieldReference]])
extends RDD[InternalRow](sc, Nil) {
/**
@@ -222,6 +262,20 @@ private[jdbc] class JDBCRDD(
}
/**
+ * A GROUP BY clause representing pushed-down grouping columns.
+ */
+ private def getGroupByClause: String = {
+ if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
+ assert(groupByColumns.get.forall(_.fieldNames.length == 1))
+ val dialect = JdbcDialects.get(url)
+ val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
+ s"GROUP BY ${quotedColumns.mkString(", ")}"
+ } else {
+ ""
+ }
+ }
+
+ /**
* Runs the SQL query against the JDBC driver.
*
*/
@@ -296,7 +350,8 @@ private[jdbc] class JDBCRDD(
val myWhereClause = getWhereClause(part)
- val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause"
+ val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
+ s" $getGroupByClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
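
A small REPL-style sketch (not part of the patch; `JDBCRDD` is an internal helper) of the SQL fragments `compileAggregates` produces. The default dialect quotes identifiers with double quotes, so together with the pushed group-by columns the generated statement ends up shaped like the SELECT shown in the JDBCScanBuilder comment further down:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, Max, Min}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb0") // placeholder URL

val salary = FieldReference("SALARY").asInstanceOf[FieldReference]
val bonus  = FieldReference("BONUS").asInstanceOf[FieldReference]

// Seq(MAX("SALARY"), MIN("BONUS")) -- spliced directly into the SELECT list.
val compiled = JDBCRDD.compileAggregates(Seq(new Max(salary), new Min(bonus)), dialect)

// With the group-by columns pushed as well, the final statement looks roughly like:
//   SELECT "DEPT", MAX("SALARY"), MIN("BONUS") FROM "test"."employee" GROUP BY "DEPT"
```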
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 4ec9a4f..5fb26d2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode,
SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils,
TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId,
stringToDate, stringToTimestamp}
+import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
@@ -288,6 +289,23 @@ private[sql] case class JDBCRelation(
jdbcOptions).asInstanceOf[RDD[Row]]
}
+ def buildScan(
+ requiredColumns: Array[String],
+ requireSchema: Option[StructType],
+ filters: Array[Filter],
+ groupByColumns: Option[Array[FieldReference]]): RDD[Row] = {
+ // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
+ JDBCRDD.scanTable(
+ sparkSession.sparkContext,
+ schema,
+ requiredColumns,
+ filters,
+ parts,
+ jdbcOptions,
+ requireSchema,
+ groupByColumns).asInstanceOf[RDD[Row]]
+ }
+
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 1ab554f..4d77674 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -87,7 +87,7 @@ class DataSourceV2Strategy(session: SparkSession) extends
Strategy with Predicat
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters,
- relation @ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, translated, pushed), output)) =>
+ DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) =>
val v1Relation = scan.toV1TableScan[BaseRelation with
TableScan](session.sqlContext)
if (v1Relation.schema != scan.readSchema()) {
throw
QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
@@ -98,8 +98,9 @@ class DataSourceV2Strategy(session: SparkSession) extends
Strategy with Predicat
val dsScan = RowDataSourceScanExec(
output,
output.toStructType,
- translated.toSet,
+ Set.empty,
pushed.toSet,
+ aggregate,
unsafeRowRDD,
v1Relation,
tableIdentifier = None)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 1f57f17..ab5c5da 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -20,9 +20,13 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.expressions.{Aggregation, FieldReference}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownFilters,
SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import
org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
@@ -71,6 +75,42 @@ object PushDownUtils extends PredicateHelper {
}
/**
+ * Pushes down aggregates to the data source reader
+ *
+ * @return pushed aggregation.
+ */
+ def pushAggregates(
+ scanBuilder: ScanBuilder,
+ aggregates: Seq[AggregateExpression],
+ groupBy: Seq[Expression]): Option[Aggregation] = {
+
+ def columnAsString(e: Expression): Option[FieldReference] = e match {
+ case PushableColumnWithoutNestedColumn(name) =>
+ Some(FieldReference(name).asInstanceOf[FieldReference])
+ case _ => None
+ }
+
+ scanBuilder match {
+ case r: SupportsPushDownAggregates =>
+ val translatedAggregates =
aggregates.map(DataSourceStrategy.translateAggregate).flatten
+ val translatedGroupBys = groupBy.map(columnAsString).flatten
+
+ if (translatedAggregates.length != aggregates.length ||
+ translatedGroupBys.length != groupBy.length) {
+ return None
+ }
+
+ val agg = new Aggregation(translatedAggregates.toArray,
translatedGroupBys.toArray)
+ if (r.pushAggregation(agg)) {
+ Some(agg)
+ } else {
+ None
+ }
+ case _ => None
+ }
+ }
+
+ /**
* Applies column pruning to the data source, w.r.t. the references of the
given expressions.
*
* @return the `Scan` instance (since column pruning is the last step of
operator pushdown),
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index d218056..445ff03 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -17,23 +17,36 @@
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.catalyst.expressions.{And, Expression,
NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute,
AttributeReference, Expression, NamedExpression, PredicateHelper,
ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan,
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter,
LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.read.{Scan, V1Scan}
+import org.apache.spark.sql.connector.expressions.Aggregation
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
-object V2ScanRelationPushDown extends Rule[LogicalPlan] {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
import DataSourceV2Implicits._
- override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
- case ScanOperation(project, filters, relation: DataSourceV2Relation) =>
- val scanBuilder =
relation.table.asReadable.newScanBuilder(relation.options)
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ applyColumnPruning(pushdownAggregate(pushDownFilters(createScanBuilder(plan))))
+ }
+
+ private def createScanBuilder(plan: LogicalPlan) = plan.transform {
+ case r: DataSourceV2Relation =>
+ ScanBuilderHolder(r.output, r,
r.table.asReadable.newScanBuilder(r.options))
+ }
- val normalizedFilters = DataSourceStrategy.normalizeExprs(filters,
relation.output)
+ private def pushDownFilters(plan: LogicalPlan) = plan.transform {
+ // update the scan builder with filter push down and return a new plan
with filter pushed
+ case Filter(condition, sHolder: ScanBuilderHolder) =>
+ val filters = splitConjunctivePredicates(condition)
+ val normalizedFilters =
+ DataSourceStrategy.normalizeExprs(filters, sHolder.relation.output)
val (normalizedFiltersWithSubquery, normalizedFiltersWithoutSubquery) =
normalizedFilters.partition(SubqueryExpression.hasSubquery)
@@ -41,37 +54,142 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
// `postScanFilters` need to be evaluated after the scan.
// `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet
row group filter.
val (pushedFilters, postScanFiltersWithoutSubquery) =
PushDownUtils.pushFilters(
- scanBuilder, normalizedFiltersWithoutSubquery)
+ sHolder.builder, normalizedFiltersWithoutSubquery)
val postScanFilters = postScanFiltersWithoutSubquery ++
normalizedFiltersWithSubquery
+ logInfo(
+ s"""
+ |Pushing operators to ${sHolder.relation.name}
+ |Pushed Filters: ${pushedFilters.mkString(", ")}
+ |Post-Scan Filters: ${postScanFilters.mkString(",")}
+ """.stripMargin)
+
+ val filterCondition = postScanFilters.reduceLeftOption(And)
+ filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
+ }
+
+ def pushdownAggregate(plan: LogicalPlan): LogicalPlan = plan.transform {
+ // update the scan builder with agg pushdown and return a new plan with
agg pushed
+ case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
+ child match {
+ case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
+ if filters.isEmpty &&
project.forall(_.isInstanceOf[AttributeReference]) =>
+ sHolder.builder match {
+ case _: SupportsPushDownAggregates =>
+ val aggregates = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression => agg
+ }
+ }
+ val pushedAggregates = PushDownUtils
+ .pushAggregates(sHolder.builder, aggregates,
groupingExpressions)
+ if (pushedAggregates.isEmpty) {
+ aggNode // return original plan node
+ } else {
+ // No need to do column pruning because only the aggregate
columns are used as
+ // DataSourceV2ScanRelation output columns. All the other
columns are not
+ // included in the output.
+ val scan = sHolder.builder.build()
+
+ // scalastyle:off
+ // use the group by columns and aggregate columns as the
output columns
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+ // Use c2, min(c1), max(c1) as output for
DataSourceV2ScanRelation
+ // We want to have the following logical plan:
+ // == Optimized Logical Plan ==
+ // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17,
max(max(c1)#22) AS max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
+ // scalastyle:on
+ val newOutput = scan.readSchema().toAttributes
+ assert(newOutput.length == groupingExpressions.length +
aggregates.length)
+ val groupAttrs = groupingExpressions.zip(newOutput).map {
+ case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
+ case (_, b) => b
+ }
+ val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
+
+ logInfo(
+ s"""
+ |Pushing operators to ${sHolder.relation.name}
+ |Pushed Aggregate Functions:
+ | ${pushedAggregates.get.aggregateExpressions.mkString(",
")}
+ |Pushed Group by:
+ | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+ |Output: ${output.mkString(", ")}
+ """.stripMargin)
+
+ val wrappedScan = getWrappedScan(scan, sHolder,
pushedAggregates)
+
+ val scanRelation = DataSourceV2ScanRelation(sHolder.relation,
wrappedScan, output)
+
+ val plan = Aggregate(
+ output.take(groupingExpressions.length), resultExpressions,
scanRelation)
+
+ // scalastyle:off
+ // Change the optimized logical plan to reflect the pushed
down aggregate
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+ // The original logical plan is
+ // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS
max(c1)#18]
+ // +- RelationV2[c1#9, c2#10] ...
+ //
+ // After change the V2ScanRelation output to [c2#10,
min(c1)#21, max(c1)#22]
+ // we have the following
+ // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS
max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+ //
+ // We want to change it to
+ // == Optimized Logical Plan ==
+ // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17,
max(max(c1)#22) AS max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+ // scalastyle:on
+ var i = 0
+ val aggOutput = output.drop(groupAttrs.length)
+ plan.transformExpressions {
+ case agg: AggregateExpression =>
+ val aggFunction: aggregate.AggregateFunction =
+ agg.aggregateFunction match {
+ case max: aggregate.Max => max.copy(child =
aggOutput(i))
+ case min: aggregate.Min => min.copy(child =
aggOutput(i))
+ case sum: aggregate.Sum => sum.copy(child =
aggOutput(i))
+ case _: aggregate.Count => aggregate.Sum(aggOutput(i))
+ case other => other
+ }
+ i += 1
+ agg.copy(aggregateFunction = aggFunction)
+ }
+ }
+ case _ => aggNode
+ }
+ case _ => aggNode
+ }
+ }
+
+ def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
+ case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
+ // column pruning
val normalizedProjects = DataSourceStrategy
- .normalizeExprs(project, relation.output)
+ .normalizeExprs(project, sHolder.output)
.asInstanceOf[Seq[NamedExpression]]
val (scan, output) = PushDownUtils.pruneColumns(
- scanBuilder, relation, normalizedProjects, postScanFilters)
+ sHolder.builder, sHolder.relation, normalizedProjects, filters)
+
logInfo(
s"""
- |Pushing operators to ${relation.name}
- |Pushed Filters: ${pushedFilters.mkString(", ")}
- |Post-Scan Filters: ${postScanFilters.mkString(",")}
|Output: ${output.mkString(", ")}
""".stripMargin)
- val wrappedScan = scan match {
- case v1: V1Scan =>
- val translated =
filters.flatMap(DataSourceStrategy.translateFilter(_, true))
- V1ScanWrapper(v1, translated, pushedFilters)
- case _ => scan
- }
+ val wrappedScan = getWrappedScan(scan, sHolder,
Option.empty[Aggregation])
- val scanRelation = DataSourceV2ScanRelation(relation, wrappedScan,
output)
+ val scanRelation = DataSourceV2ScanRelation(sHolder.relation,
wrappedScan, output)
val projectionOverSchema = ProjectionOverSchema(output.toStructType)
val projectionFunc = (expr: Expression) => expr transformDown {
case projectionOverSchema(newExpr) => newExpr
}
- val filterCondition = postScanFilters.reduceLeftOption(And)
+ val filterCondition = filters.reduceLeftOption(And)
val newFilterCondition = filterCondition.map(projectionFunc)
val withFilter = newFilterCondition.map(Filter(_,
scanRelation)).getOrElse(scanRelation)
@@ -83,16 +201,36 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
} else {
withFilter
}
-
withProjection
}
+
+ private def getWrappedScan(
+ scan: Scan,
+ sHolder: ScanBuilderHolder,
+ aggregation: Option[Aggregation]): Scan = {
+ scan match {
+ case v1: V1Scan =>
+ val pushedFilters = sHolder.builder match {
+ case f: SupportsPushDownFilters =>
+ f.pushedFilters()
+ case _ => Array.empty[sources.Filter]
+ }
+ V1ScanWrapper(v1, pushedFilters, aggregation)
+ case _ => scan
+ }
+ }
}
+case class ScanBuilderHolder(
+ output: Seq[AttributeReference],
+ relation: DataSourceV2Relation,
+ builder: ScanBuilder) extends LeafNode
+
// A wrapper for v1 scan to carry the translated filters and the handled ones.
This is required by
// the physical v1 scan node.
case class V1ScanWrapper(
v1Scan: V1Scan,
- translatedFilters: Seq[sources.Filter],
- handledFilters: Seq[sources.Filter]) extends Scan {
+ handledFilters: Seq[sources.Filter],
+ pushedAggregate: Option[Aggregation]) extends Scan {
override def readSchema(): StructType = v1Scan.readSchema()
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index 860232b..d6ae7c8 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
@@ -26,7 +27,9 @@ import org.apache.spark.sql.types.StructType
case class JDBCScan(
relation: JDBCRelation,
prunedSchema: StructType,
- pushedFilters: Array[Filter]) extends V1Scan {
+ pushedFilters: Array[Filter],
+ pushedAggregateColumn: Array[String] = Array(),
+ groupByColumns: Option[Array[FieldReference]]) extends V1Scan {
override def readSchema(): StructType = prunedSchema
@@ -36,14 +39,28 @@ case class JDBCScan(
override def schema: StructType = prunedSchema
override def needConversion: Boolean = relation.needConversion
override def buildScan(): RDD[Row] = {
- relation.buildScan(prunedSchema.map(_.name).toArray, pushedFilters)
+ if (groupByColumns.isEmpty) {
+ relation.buildScan(
+ prunedSchema.map(_.name).toArray, Some(prunedSchema),
pushedFilters, groupByColumns)
+ } else {
+ relation.buildScan(
+ pushedAggregateColumn, Some(prunedSchema), pushedFilters,
groupByColumns)
+ }
}
}.asInstanceOf[T]
}
override def description(): String = {
+ val (aggString, groupByString) = if (groupByColumns.nonEmpty) {
+ val groupByColumnsLength = groupByColumns.get.length
+ (seqToString(pushedAggregateColumn.drop(groupByColumnsLength)),
+ seqToString(pushedAggregateColumn.take(groupByColumnsLength)))
+ } else {
+ ("[]", "[]")
+ }
super.description() + ", prunedSchema: " + seqToString(prunedSchema) +
- ", PushedFilters: " + seqToString(pushedFilters)
+ ", PushedFilters: " + seqToString(pushedFilters) +
+ ", PushedAggregates: " + aggString + ", PushedGroupBy: " + groupByString
}
private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 270c5b6..7442eda 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -17,18 +17,20 @@
package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.expressions.{Aggregation, Count,
CountStar, FieldReference, Max, Min, Sum}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownFilters,
SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD,
JDBCRelation}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
case class JDBCScanBuilder(
session: SparkSession,
schema: StructType,
jdbcOptions: JDBCOptions)
- extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns {
+ extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns
+ with SupportsPushDownAggregates{
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
@@ -49,6 +51,58 @@ case class JDBCScanBuilder(
override def pushedFilters(): Array[Filter] = pushedFilter
+ private var pushedAggregations = Option.empty[Aggregation]
+
+ private var pushedAggregateColumn: Array[String] = Array()
+
+ private def getStructFieldForCol(col: FieldReference): StructField =
+ schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head))
+
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+ if (!jdbcOptions.pushDownAggregate) return false
+
+ val dialect = JdbcDialects.get(jdbcOptions.url)
+ val compiledAgg =
JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
+
+ var outputSchema = new StructType()
+ aggregation.groupByColumns.foreach { col =>
+ val structField = getStructFieldForCol(col)
+ outputSchema = outputSchema.add(structField)
+ pushedAggregateColumn = pushedAggregateColumn :+
dialect.quoteIdentifier(structField.name)
+ }
+
+ // The column names here are already quoted and can be used to build sql
string directly.
+ // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
+ // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
+ // GROUP BY "DEPT", "NAME"
+ pushedAggregateColumn = pushedAggregateColumn ++ compiledAgg
+
+ aggregation.aggregateExpressions.foreach {
+ case max: Max =>
+ val structField = getStructFieldForCol(max.column)
+ outputSchema = outputSchema.add(structField.copy("max(" +
structField.name + ")"))
+ case min: Min =>
+ val structField = getStructFieldForCol(min.column)
+ outputSchema = outputSchema.add(structField.copy("min(" +
structField.name + ")"))
+ case count: Count =>
+ val distinct = if (count.isDinstinct) "DISTINCT " else ""
+ val structField = getStructFieldForCol(count.column)
+ outputSchema =
+ outputSchema.add(StructField(s"count($distinct" + structField.name +
")", LongType))
+ case _: CountStar =>
+ outputSchema = outputSchema.add(StructField("count(*)", LongType))
+ case sum: Sum =>
+ val distinct = if (sum.isDinstinct) "DISTINCT " else ""
+ val structField = getStructFieldForCol(sum.column)
+ outputSchema =
+ outputSchema.add(StructField(s"sum($distinct" + structField.name +
")", sum.dataType))
+ case _ => return false
+ }
+ this.pushedAggregations = Some(aggregation)
+ prunedSchema = outputSchema
+ true
+ }
+
override def pruneColumns(requiredSchema: StructType): Unit = {
// JDBC doesn't support nested column pruning.
// TODO (SPARK-32593): JDBC support nested column and nested column
pruning.
@@ -65,6 +119,20 @@ case class JDBCScanBuilder(
val resolver = session.sessionState.conf.resolver
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId,
jdbcOptions)
- JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema,
pushedFilter)
+
+ // in prunedSchema, the schema is either pruned in pushAggregation (if
aggregates are
+ // pushed down), or pruned in pruneColumns (in regular column pruning).
These
+ // two are mutual exclusive.
+ // For aggregate push down case, we want to pass down the quoted column
lists such as
+ // "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column
names from
+ // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)"
and can't
+ // be used in sql string.
+ val groupByColumns = if (pushedAggregations.nonEmpty) {
+ Some(pushedAggregations.get.groupByColumns)
+ } else {
+ Option.empty[Array[FieldReference]]
+ }
+ JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), prunedSchema,
pushedFilter,
+ pushedAggregateColumn, groupByColumns)
}
}
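
To make the comment above concrete, a sketch (values are illustrative) of the two parallel views the builder holds after `pushAggregation` for `SELECT MAX(SALARY), MIN(BONUS) FROM h2.test.employee GROUP BY DEPT`:

```scala
// Quoted/compiled column list, used verbatim to build the SQL text:
val pushedAggregateColumn = Array("\"DEPT\"", "MAX(\"SALARY\")", "MIN(\"BONUS\")")

// Field names of the pruned read schema returned to Spark (grouping columns first,
// then one column per aggregate function):
val prunedSchemaFieldNames = Seq("DEPT", "max(SALARY)", "min(BONUS)")
```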
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index a3a3f47..c1f8f5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -21,16 +21,16 @@ import java.sql.{Connection, DriverManager}
import java.util.Properties
import org.apache.spark.SparkConf
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
import
org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{lit, sum, udf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
-class JDBCV2Suite extends QueryTest with SharedSparkSession {
+class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
import testImplicits._
val tempDir = Utils.createTempDir()
@@ -41,6 +41,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
.set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.h2.url", url)
.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
+ .set("spark.sql.catalog.h2.pushDownAggregate", "true")
private def withConnection[T](f: Connection => T): T = {
val conn = DriverManager.getConnection(url, new Properties())
@@ -64,6 +65,19 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
.executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred',
1)").executeUpdate()
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary',
2)").executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32),
salary NUMERIC(20, 2)," +
+ " bonus DOUBLE)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1,
'amy', 10000, 1000)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2,
'alex', 12000, 1200)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1,
'cathy', 9000, 1200)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2,
'david', 10000, 1300)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6,
'jen', 12000, 1200)")
+ .executeUpdate()
}
}
@@ -84,6 +98,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession {
case f: Filter => f
}
assert(filters.isEmpty)
+
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+
checkAnswer(df, Row("mary", 2))
}
@@ -145,7 +167,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession
{
test("show tables") {
checkAnswer(sql("SHOW TABLES IN h2.test"),
- Seq(Row("test", "people", false), Row("test", "empty_table", false)))
+ Seq(Row("test", "people", false), Row("test", "empty_table", false),
+ Row("test", "employee", false)))
}
test("SQL API: create table as select") {
@@ -214,4 +237,232 @@ class JDBCV2Suite extends QueryTest with
SharedSparkSession {
checkAnswer(sql("SELECT name, id FROM h2.test.abc"), Row("bob", 4))
}
}
+
+ test("scan with aggregate push-down: MAX MIN with filter and group by") {
+ val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where
dept > 0" +
+ " group by DEPT")
+ val filters = df.queryExecution.optimizedPlan.collect {
+ case f: Filter => f
+ }
+ assert(filters.isEmpty)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+ "PushedGroupby: [DEPT]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200)))
+ }
+
+ test("scan with aggregate push-down: MAX MIN with filter without group by") {
+ val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0")
+ val filters = df.queryExecution.optimizedPlan.collect {
+ case f: Filter => f
+ }
+ assert(filters.isEmpty)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Max(ID), Min(ID)], " +
+ "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " +
+ "PushedGroupby: []"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(2, 1)))
+ }
+
+ test("scan with aggregate push-down: aggregate + number") {
+ val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Max(SALARY)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(12001)))
+ }
+
+ test("scan with aggregate push-down: COUNT(*)") {
+ val df = sql("select COUNT(*) FROM h2.test.employee")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [CountStar()]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(5)))
+ }
+
+ test("scan with aggregate push-down: COUNT(col)") {
+ val df = sql("select COUNT(DEPT) FROM h2.test.employee")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Count(DEPT,false)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(5)))
+ }
+
+ test("scan with aggregate push-down: COUNT(DISTINCT col)") {
+ val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Count(DEPT,true)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(3)))
+ }
+
+ test("scan with aggregate push-down: SUM without filer and group by") {
+ val df = sql("SELECT SUM(SALARY) FROM h2.test.employee")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(53000)))
+ }
+
+ test("scan with aggregate push-down: DISTINCT SUM without filer and group
by") {
+ val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(31000)))
+ }
+
+ test("scan with aggregate push-down: SUM with group by") {
+ val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+ "PushedFilters: [], " +
+ "PushedGroupby: [DEPT]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
+ }
+
+ test("scan with aggregate push-down: DISTINCT SUM with group by") {
+ val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY
DEPT")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Sum(SALARY,DecimalType(30,2),true)], " +
+ "PushedFilters: [], " +
+ "PushedGroupby: [DEPT]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
+ }
+
+ test("scan with aggregate push-down: with multiple group by columns") {
+ val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where
dept > 0" +
+ " group by DEPT, NAME")
+ val filters11 = df.queryExecution.optimizedPlan.collect {
+ case f: Filter => f
+ }
+ assert(filters11.isEmpty)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+ "PushedGroupby: [DEPT, NAME]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300),
+ Row(10000, 1000), Row(12000, 1200)))
+ }
+
+ test("scan with aggregate push-down: with having clause") {
+ val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where
dept > 0" +
+ " group by DEPT having MIN(BONUS) > 1000")
+ val filters = df.queryExecution.optimizedPlan.collect {
+ case f: Filter => f // filter over aggregate is not pushed down
+ }
+ assert(filters.nonEmpty)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Max(SALARY), Min(BONUS)], " +
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+ "PushedGroupby: [DEPT]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200)))
+ }
+
+ test("scan with aggregate push-down: alias over aggregate") {
+ val df = sql("select * from h2.test.employee")
+ .groupBy($"DEPT")
+ .min("SALARY").as("total")
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Min(SALARY)], " +
+ "PushedFilters: [], " +
+ "PushedGroupby: [DEPT]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000)))
+ }
+
+ test("scan with aggregate push-down: order by alias over aggregate") {
+ val df = spark.table("h2.test.employee")
+ val query = df.select($"DEPT", $"SALARY")
+ .filter($"DEPT" > 0)
+ .groupBy($"DEPT")
+ .agg(sum($"SALARY").as("total"))
+ .filter($"total" > 1000)
+ .orderBy($"total")
+ val filters = query.queryExecution.optimizedPlan.collect {
+ case f: Filter => f
+ }
+ assert(filters.nonEmpty) // filter over aggregate not pushed down
+ query.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false)], " +
+ "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+ "PushedGroupby: [DEPT]"
+ checkKeywordsExistsInExplain(query, expected_plan_fragment)
+ }
+ checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000)))
+ }
+
+ test("scan with aggregate push-down: udf over aggregate") {
+ val df = spark.table("h2.test.employee")
+ val decrease = udf { (x: Double, y: Double) => x - y }
+ val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value"))
+ query.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [Sum(SALARY,DecimalType(30,2),false),
Sum(BONUS,DoubleType,false)"
+ checkKeywordsExistsInExplain(query, expected_plan_fragment)
+ }
+ checkAnswer(query, Seq(Row(47100.0)))
+ }
+
+ test("scan with aggregate push-down: aggregate over alias") {
+ val cols = Seq("a", "b", "c", "d")
+ val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
+ val df2 = df1.groupBy().sum("c")
+ df2.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: []" // aggregate over alias not push down
+ checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+ }
+ checkAnswer(df2, Seq(Row(53000.00)))
+ }
}
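
For context, the following is a minimal standalone sketch (not taken from the commit) of how the new JDBC aggregate push-down is exercised from user code. It assumes a local H2 database; the in-memory URL, sample table, and rows below are illustrative stand-ins for the suite's fixture, while the catalog options mirror the test configuration added above.

import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.sql.SparkSession

object AggregatePushDownSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative in-memory H2 database (the suite uses a temp-dir file URL instead).
    val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"

    // Create and populate a small employee table, mirroring the fixture above.
    val conn = DriverManager.getConnection(url, new Properties())
    try {
      conn.prepareStatement("CREATE SCHEMA \"test\"").executeUpdate()
      conn.prepareStatement(
        "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2), bonus DOUBLE)")
        .executeUpdate()
      conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)").executeUpdate()
      conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)").executeUpdate()
    } finally {
      conn.close()
    }

    // Register the H2 catalog with aggregate push-down enabled, as in the test setup.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("agg-push-down-sketch")
      .config("spark.sql.catalog.h2", "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.h2.url", url)
      .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
      .config("spark.sql.catalog.h2.pushDownAggregate", "true")
      .getOrCreate()

    // With push-down enabled, MAX/MIN and the GROUP BY column are handed to the JDBC
    // source; the explain output should contain PushedAggregates / PushedGroupby entries.
    val df = spark.sql(
      "SELECT MAX(salary), MIN(bonus) FROM h2.test.employee WHERE dept > 0 GROUP BY dept")
    df.explain()
    df.show()

    spark.stop()
  }
}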
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]