This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a66c74fd1d6f [SPARK-52187][SQL] Introduce Join pushdown for DSv2
a66c74fd1d6f is described below
commit a66c74fd1d6f33d5c70a18e3bd2f47850717afdb
Author: Petar Vasiljevic <[email protected]>
AuthorDate: Tue Jul 15 21:40:54 2025 +0800
[SPARK-52187][SQL] Introduce Join pushdown for DSv2
### What changes were proposed in this pull request?
With this PR I am introducing the Join pushdown interface for DSv2
connectors and its implementation for JDBC connectors.
The interface itself, `SupportsPushDownJoin` has the following API:
```
public interface SupportsPushDownJoin extends ScanBuilder {
  /**
   * Returns true if the other side of the join is compatible with the
   * current {@code SupportsPushDownJoin} for a join push down, meaning both sides can be
   * processed together within the same underlying data source.
   *
   * <p>For example, JDBC connectors are compatible if they use the same
   * host, port, username, and password.</p>
   */
  boolean isOtherSideCompatibleForJoin(SupportsPushDownJoin other);

  /**
   * Pushes down the join of the current {@code SupportsPushDownJoin} and the other side of
   * the join {@code SupportsPushDownJoin}.
   *
   * @param other {@code SupportsPushDownJoin} that this {@code SupportsPushDownJoin}
   *              gets joined with.
   * @param joinType the type of join.
   * @param leftSideRequiredColumnsWithAliases required output of the
   *                                           left side {@code SupportsPushDownJoin}
   * @param rightSideRequiredColumnsWithAliases required output of the
   *                                            right side {@code SupportsPushDownJoin}
   * @param condition join condition. Columns are named after the specified aliases in
   *                  {@code leftSideRequiredColumnsWithAliases} and
   *                  {@code rightSideRequiredColumnsWithAliases}.
   * @return true if the join has been successfully pushed down.
   */
  boolean pushDownJoin(
      SupportsPushDownJoin other,
      JoinType joinType,
      ColumnWithAlias[] leftSideRequiredColumnsWithAliases,
      ColumnWithAlias[] rightSideRequiredColumnsWithAliases,
      Predicate condition
  );

  /**
   * A helper class used when there are duplicated names coming from the 2 sides of the join
   * operator.
   * <br>
   * Holds information about the original output name and the alias of the new output.
   */
  record ColumnWithAlias(String colName, String alias) {}
}
```
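To illustrate the contract, here is a minimal, hypothetical connector-side sketch in Scala. `MyScanBuilder`, `MyConnectionInfo`, and the behaviour shown are assumptions for illustration only; they are not part of this patch, and a real connector would rewrite its internal query instead of returning false:
```
// Hypothetical sketch, not part of this patch. Class and field names are illustrative.
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.join.JoinType
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownJoin}

case class MyConnectionInfo(host: String, port: Int, user: String)

class MyScanBuilder(val conn: MyConnectionInfo, var query: String)
    extends ScanBuilder with SupportsPushDownJoin {

  // Both sides must point at the same endpoint so the join can run inside one data source.
  override def isOtherSideCompatibleForJoin(other: SupportsPushDownJoin): Boolean =
    other match {
      case o: MyScanBuilder => o.conn == conn
      case _ => false
    }

  override def pushDownJoin(
      other: SupportsPushDownJoin,
      joinType: JoinType,
      leftSideRequiredColumnsWithAliases: Array[SupportsPushDownJoin.ColumnWithAlias],
      rightSideRequiredColumnsWithAliases: Array[SupportsPushDownJoin.ColumnWithAlias],
      condition: Predicate): Boolean = {
    // Only INNER_JOIN is currently produced by the optimizer rule.
    if (joinType != JoinType.INNER_JOIN) return false
    // A real connector would rewrite `query` into a join of both sides here, honoring the
    // requested aliases; returning false keeps the join in Spark.
    false
  }

  override def build(): Scan = throw new UnsupportedOperationException("sketch only")
}
```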
With this implementation, only inner joins are supported. Left and right
joins can be added in the future. Cross joins won't be supported since
they can increase the amount of data that is being read (although another join
on top could actually decrease it in the end).
Also, none of the JDBC dialects supports the join pushdown yet; it is
only enabled for the H2 dialect (used for testing). The join pushdown capability is guarded by
the SQLConf `spark.sql.optimizer.datasourceV2JoinPushdown`, the JDBC option
`pushDownJoin`, and the JDBC dialect method `supportsJoin`.
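As a rough usage sketch (not part of the patch), all three guards need to be satisfied before the optimizer rule attempts the pushdown. The catalog name `h2`, the H2 URL, and the `test.employee` table below are assumptions that mirror the new test suite:
```
// Illustrative configuration: a JDBC catalog named "h2" with join pushdown allowed,
// plus the optimizer flag introduced by this patch. URL and table names are assumptions.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.catalog.h2",
    "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
  .config("spark.sql.catalog.h2.url", "jdbc:h2:/tmp/testdb;user=testUser;password=testPass")
  .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
  .config("spark.sql.catalog.h2.pushDownJoin", "true")
  .config("spark.sql.optimizer.datasourceV2JoinPushdown", "true")
  .getOrCreate()

// With a dialect where supportsJoin = true (currently only H2), an inner join between two
// scans of the same source can be pushed down; the optimized plan then shows a single scan
// that reports "PushedJoins" instead of a Join node.
val df = spark.sql(
  """SELECT a.dept, b.salary
    |FROM h2.test.employee a JOIN h2.test.employee b ON a.dept = b.dept + 1""".stripMargin)
df.explain()
```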
### Why are the changes needed?
DSv2 connectors can't push down the join operator.
### Does this PR introduce _any_ user-facing change?
No for this PR itself, since the behaviour is not implemented for any of the
connectors (besides H2, which is a JDBC dialect used for testing).
### How was this patch tested?
New tests and some local testing with TPCDS queries.
### Was this patch authored or co-authored using generative AI tooling?
Closes #50921 from PetarVasiljevic-DB/support_join_for_dsv2.
Authored-by: Petar Vasiljevic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../apache/spark/sql/connector/join/JoinType.java} | 21 +-
.../sql/connector/read/SupportsPushDownJoin.java | 73 ++++
.../org/apache/spark/sql/internal/SQLConf.scala | 10 +
.../spark/sql/execution/DataSourceScanExec.scala | 9 +-
.../execution/datasources/DataSourceStrategy.scala | 15 +-
.../execution/datasources/jdbc/JDBCOptions.scala | 5 +
.../datasources/v2/PushedDownOperators.scala | 3 +-
.../datasources/v2/V2ScanRelationPushDown.scala | 199 +++++++++-
.../datasources/v2/jdbc/JDBCScanBuilder.scala | 156 +++++++-
.../org/apache/spark/sql/jdbc/H2Dialect.scala | 2 +
.../org/apache/spark/sql/jdbc/JdbcDialects.scala | 5 +
.../spark/sql/jdbc/JdbcSQLQueryBuilder.scala | 48 ++-
.../spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala | 413 +++++++++++++++++++++
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 1 +
14 files changed, 928 insertions(+), 32 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/join/JoinType.java
similarity index 58%
copy from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/join/JoinType.java
index 49044c6e24db..23ef609201eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/join/JoinType.java
@@ -15,21 +15,16 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.datasources.v2
+package org.apache.spark.sql.connector.join;
-import org.apache.spark.sql.connector.expressions.SortOrder
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.annotation.Evolving;
/**
- * Pushed down operators
+ * Enum representing the join type in public API.
+ *
+ * @since 4.1.0
*/
-case class PushedDownOperators(
- aggregation: Option[Aggregation],
- sample: Option[TableSampleInfo],
- limit: Option[Int],
- offset: Option[Int],
- sortValues: Seq[SortOrder],
- pushedPredicates: Seq[Predicate]) {
- assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
+@Evolving
+public enum JoinType {
+ INNER_JOIN,
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownJoin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownJoin.java
new file mode 100644
index 000000000000..01939f91926a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownJoin.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.connector.join.JoinType;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down join operators.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public interface SupportsPushDownJoin extends ScanBuilder {
+  /**
+   * Returns true if the other side of the join is compatible with the
+   * current {@code SupportsPushDownJoin} for a join push down, meaning both sides can be
+   * processed together within the same underlying data source.
+   * <br>
+   * <br>
+   * For example, JDBC connectors are compatible if they use the same
+   * host, port, username, and password.
+   */
+  boolean isOtherSideCompatibleForJoin(SupportsPushDownJoin other);
+
+  /**
+   * Pushes down the join of the current {@code SupportsPushDownJoin} and the other side of
+   * the join {@code SupportsPushDownJoin}.
+   *
+   * @param other {@code SupportsPushDownJoin} that this {@code SupportsPushDownJoin}
+   *              gets joined with.
+   * @param joinType the type of join.
+   * @param leftSideRequiredColumnsWithAliases required output of the
+   *                                           left side {@code SupportsPushDownJoin}
+   * @param rightSideRequiredColumnsWithAliases required output of the
+   *                                            right side {@code SupportsPushDownJoin}
+   * @param condition join condition. Columns are named after the specified aliases in
+   *                  {@code leftSideRequiredColumnsWithAliases} and {@code rightSideRequiredColumnsWithAliases}
+   * @return true if the join has been successfully pushed down.
+   */
+  boolean pushDownJoin(
+      SupportsPushDownJoin other,
+      JoinType joinType,
+      ColumnWithAlias[] leftSideRequiredColumnsWithAliases,
+      ColumnWithAlias[] rightSideRequiredColumnsWithAliases,
+      Predicate condition
+  );
+
+  /**
+   * A helper class used when there are duplicated names coming from the 2 sides of the join
+   * operator.
+   * <br>
+   * Holds information about the original output name and the alias of the new output.
+   */
+  record ColumnWithAlias(String colName, String alias) {}
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 69a90f87cb0e..b17c4147b951 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1774,6 +1774,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val DATA_SOURCE_V2_JOIN_PUSHDOWN =
+ buildConf("spark.sql.optimizer.datasourceV2JoinPushdown")
+ .internal()
+ .doc("When this config is set to true, join is tried to be pushed down" +
+ "for DSv2 data sources in V2ScanRelationPushdown optimization rule.")
+ .booleanConf
+ .createWithDefault(false)
+
// This is used to set the default data source
val DEFAULT_DATA_SOURCE_NAME = buildConf("spark.sql.sources.default")
.doc("The default data source to use in input/output.")
@@ -6278,6 +6286,8 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def nameResolutionLogLevel: Level = getConf(NAME_RESOLUTION_LOG_LEVEL)
+ def dataSourceV2JoinPushdown: Boolean = getConf(DATA_SOURCE_V2_JOIN_PUSHDOWN)
+
def dynamicPartitionPruningEnabled: Boolean =
getConf(DYNAMIC_PARTITION_PRUNING_ENABLED)
def dynamicPartitionPruningUseStats: Boolean =
getConf(DYNAMIC_PARTITION_PRUNING_USE_STATS)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 31ab367c2d00..14c62a7992ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -189,6 +189,12 @@ case class RowDataSourceScanExec(
seqToString(markedFilters.toSeq)
}
+ val pushedJoins = if (pushedDownOperators.joinedRelations.length > 1) {
+ Map("PushedJoins" -> seqToString(pushedDownOperators.joinedRelations))
+ } else {
+ Map()
+ }
+
Map("ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> pushedFilters) ++
pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
@@ -200,7 +206,8 @@ case class RowDataSourceScanExec(
offsetInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement}
SEED(${v.seed})"
- )
+ ) ++
+ pushedJoins
}
// Don't care about `rdd` and `tableIdentifier`, and `stream` when
canonicalizing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 3b55a294b21b..882012c968e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -37,6 +37,7 @@ import
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoDir,
InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
@@ -47,6 +48,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead,
V1Table}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression,
NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc,
Aggregation}
+import org.apache.spark.sql.connector.join.{JoinType => V2JoinType}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
@@ -399,7 +401,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
- PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty),
+ PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty,
Seq.empty),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
l.stream,
@@ -474,7 +476,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
- PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty),
+ PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty,
Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.stream,
@@ -498,7 +500,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
- PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty),
+ PushedDownOperators(None, None, None, None, Seq.empty, Seq.empty,
Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.stream,
@@ -508,6 +510,13 @@ object DataSourceStrategy
}
}
+ def translateJoinType(joinType: JoinType): Option[V2JoinType] = {
+ joinType match {
+ case Inner => Some(V2JoinType.INNER_JOIN)
+ case _ => None
+ }
+ }
+
/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index f0c638b7d07c..280a586fdb04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -215,6 +215,10 @@ class JDBCOptions(
// This only applies to Data Source V2 JDBC
val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE,
"true").toBoolean
+ // An option to allow/disallow pushing down JOIN into JDBC data source
+ // This only applies to Data Source V2 JDBC
+ val pushDownJoin = parameters.getOrElse(JDBC_PUSHDOWN_JOIN, "true").toBoolean
+
// The local path of user's keytab file, which is assumed to be pre-uploaded
to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -321,6 +325,7 @@ object JDBCOptions {
val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
val JDBC_PUSHDOWN_OFFSET = newOption("pushDownOffset")
val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample")
+ val JDBC_PUSHDOWN_JOIN = newOption("pushDownJoin")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
index 49044c6e24db..668bfac0c452 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
@@ -30,6 +30,7 @@ case class PushedDownOperators(
limit: Option[Int],
offset: Option[Int],
sortValues: Seq[SortOrder],
- pushedPredicates: Seq[Predicate]) {
+ pushedPredicates: Seq[Predicate],
+ joinedRelations: Seq[String]) {
assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 5f7e86cab524..f7f1c4f522c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -25,13 +25,13 @@ import
org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribu
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation,
ScanOperation}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter,
LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset,
OffsetAndLimit, Project, Sample, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join,
LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset,
OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg,
Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin,
V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType,
StructType}
@@ -46,9 +46,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with
PredicateHelper {
createScanBuilder,
pushDownSample,
pushDownFilters,
+ pushDownJoin,
pushDownAggregates,
pushDownLimitAndOffset,
buildScanWithPushedAggregate,
+ buildScanWithPushedJoin,
pruneColumns)
pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
@@ -98,6 +100,146 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan]
with PredicateHelper {
filterCondition.map(Filter(_, sHolder)).getOrElse(sHolder)
}
+  def pushDownJoin(plan: LogicalPlan): LogicalPlan = plan.transformUp {
+    // Join can be attempted to be pushed down only if left and right side of the join are
+    // compatible (same data source, for example). Also, another requirement is that if
+    // there are projections between Join and ScanBuilderHolder, these projections need to be
+    // AttributeReferences. We could probably support Alias as well, but this should be on
+    // the TODO list.
+    // Aliases can exist between the Join and sHolder nodes because the query below is not valid:
+    // SELECT * FROM
+    // (SELECT * FROM tbl t1 JOIN tbl2 t2) p
+    // JOIN
+    // (SELECT * FROM tbl t3 JOIN tbl3 t4) q
+    // ON p.t1.col = q.t3.col (this is not possible)
+    // It's because there are duplicated columns in both sides of the top level join and it's not
+    // possible to fully qualify the column names in the condition. Therefore, the query should be
+    // rewritten so that each of the outputs of the child joins is aliased, so there would be a
+    // projection with aliases between the top level join and scanBuilderHolder (that has pushed
+    // the child joins).
+ case node @ Join(
+ PhysicalOperation(
+ leftProjections,
+ Nil,
+ leftHolder @ ScanBuilderHolder(_, _, lBuilder: SupportsPushDownJoin)
+ ),
+ PhysicalOperation(
+ rightProjections,
+ Nil,
+ rightHolder @ ScanBuilderHolder(_, _, rBuilder: SupportsPushDownJoin)
+ ),
+ joinType,
+ condition,
+ _) if conf.dataSourceV2JoinPushdown &&
+ // We do not support pushing down anything besides AttributeReference.
+ leftProjections.forall(_.isInstanceOf[AttributeReference]) &&
+ rightProjections.forall(_.isInstanceOf[AttributeReference]) &&
+ // Cross joins are not supported because they increase the amount of
data.
+ condition.isDefined &&
+ lBuilder.isOtherSideCompatibleForJoin(rBuilder) =>
+ val leftSideRequiredColumnNames =
getRequiredColumnNames(leftProjections, leftHolder)
+ val rightSideRequiredColumnNames =
getRequiredColumnNames(rightProjections, rightHolder)
+
+      // Alias the duplicated columns from the left side of the join. We are creating the
+      // Map[String, Int] to tell how many times each column name has occurred within one side.
+ val leftSideNameCounts: Map[String, Int] =
+
leftSideRequiredColumnNames.groupBy(identity).view.mapValues(_.size).toMap
+ val rightSideNameCounts: Map[String, Int] =
+
rightSideRequiredColumnNames.groupBy(identity).view.mapValues(_.size).toMap
+ // It's more performant to call contains on Set than on Seq
+ val rightSideColumnNamesSet = rightSideRequiredColumnNames.toSet
+
+ val leftSideRequiredColumnsWithAliases = leftSideRequiredColumnNames.map
{ name =>
+ val aliasName =
+ if (leftSideNameCounts(name) > 1 ||
rightSideColumnNamesSet.contains(name)) {
+ generateJoinOutputAlias(name)
+ } else {
+ null
+ }
+
+ new SupportsPushDownJoin.ColumnWithAlias(name, aliasName)
+ }
+
+ // Aliasing of duplicated columns in right side is done only if there
are duplicates in
+ // right side only. There won't be a conflict with left side columns
because they are
+ // already aliased.
+ val rightSideRequiredColumnsWithAliases =
rightSideRequiredColumnNames.map { name =>
+ val aliasName =
+ if (rightSideNameCounts(name) > 1) {
+ generateJoinOutputAlias(name)
+ } else {
+ null
+ }
+
+ new SupportsPushDownJoin.ColumnWithAlias(name, aliasName)
+ }
+
+ // Create the AttributeMap that holds (Attribute -> Attribute with up to
date name) mapping.
+ val pushedJoinOutputMap = AttributeMap[Expression](
+ node.output
+ .zip(leftSideRequiredColumnsWithAliases ++
rightSideRequiredColumnsWithAliases)
+ .collect {
+ case (attr, columnWithAlias) if columnWithAlias.alias() != null =>
+ (attr, attr.withName(columnWithAlias.alias()))
+ }
+ .toMap
+ )
+
+ // Reuse the previously calculated map to update the condition with
attributes
+ // with up-to-date names
+ val normalizedCondition = condition.map { e =>
+ DataSourceStrategy.normalizeExprs(
+ Seq(e),
+ (leftHolder.output ++ rightHolder.output).map { a =>
+ pushedJoinOutputMap.getOrElse(a,
a).asInstanceOf[AttributeReference]
+ }
+ ).head
+ }
+
+ val translatedCondition =
+ normalizedCondition.flatMap(DataSourceV2Strategy.translateFilterV2(_))
+ val translatedJoinType = DataSourceStrategy.translateJoinType(joinType)
+
+ if (translatedJoinType.isDefined &&
+ translatedCondition.isDefined &&
+ lBuilder.pushDownJoin(
+ rBuilder,
+ translatedJoinType.get,
+ leftSideRequiredColumnsWithAliases,
+ rightSideRequiredColumnsWithAliases,
+ translatedCondition.get)
+ ) {
+ leftHolder.joinedRelations = leftHolder.joinedRelations ++
rightHolder.joinedRelations
+ leftHolder.pushedPredicates = leftHolder.pushedPredicates ++
+ rightHolder.pushedPredicates :+ translatedCondition.get
+
+ leftHolder.output = node.output.asInstanceOf[Seq[AttributeReference]]
+ leftHolder.pushedJoinOutputMap = pushedJoinOutputMap
+
+ leftHolder
+ } else {
+ node
+ }
+ }
+
+ def generateJoinOutputAlias(name: String): String =
+ s"${name}_${java.util.UUID.randomUUID().toString.replace("-", "_")}"
+
+  // Projections' names may not be up to date if the joins have previously been pushed down.
+  // For this reason, we need to use pushedJoinOutputMap to get up-to-date names.
+ def getRequiredColumnNames(
+ projections: Seq[NamedExpression],
+ sHolder: ScanBuilderHolder): Array[String] = {
+ val normalizedProjections = DataSourceStrategy.normalizeExprs(
+ projections,
+ sHolder.output.map { a =>
+ sHolder.pushedJoinOutputMap.getOrElse(a,
a).asInstanceOf[AttributeReference]
+ }
+ ).asInstanceOf[Seq[AttributeReference]]
+
+ normalizedProjections.map(_.name).toArray
+ }
+
def pushDownAggregates(plan: LogicalPlan): LogicalPlan = plan.transform {
// update the scan builder with agg pushdown and return a new plan with
agg pushed
case agg: Aggregate => rewriteAggregate(agg)
@@ -113,10 +255,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan]
with PredicateHelper {
val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
val aggregates = collectAggregates(actualResultExprs,
aggExprToOutputOrdinal)
- val normalizedAggExprs = DataSourceStrategy.normalizeExprs(
- aggregates,
holder.relation.output).asInstanceOf[Seq[AggregateExpression]]
- val normalizedGroupingExpr = DataSourceStrategy.normalizeExprs(
- actualGroupExprs, holder.relation.output)
+ val normalizedAggExprs =
+ normalizeExpressions(aggregates,
holder).asInstanceOf[Seq[AggregateExpression]]
+ val normalizedGroupingExpr = normalizeExpressions(actualGroupExprs,
holder)
val translatedAggOpt = DataSourceStrategy.translateAggregation(
normalizedAggExprs, normalizedGroupingExpr)
if (translatedAggOpt.isEmpty) {
@@ -356,6 +497,25 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan]
with PredicateHelper {
Project(projectList, scanRelation)
}
+ def buildScanWithPushedJoin(plan: LogicalPlan): LogicalPlan = plan.transform
{
+ case holder: ScanBuilderHolder if holder.joinedRelations.length > 1 =>
+ val scan = holder.builder.build()
+ val realOutput = toAttributes(scan.readSchema())
+ assert(realOutput.length == holder.output.length,
+ "The data source returns unexpected number of columns")
+ val wrappedScan = getWrappedScan(scan, holder)
+ val scanRelation = DataSourceV2ScanRelation(holder.relation,
wrappedScan, realOutput)
+
+ // When join is pushed down, the real output is going to be, for example,
+ // SALARY_01234#0, NAME_ab123#1, DEPT_cd123#2.
+ // We should revert these names back to original names. For example,
+ // SALARY#0, NAME#1, DEPT#1. This is done by adding projection with
appropriate aliases.
+ val projectList = realOutput.zip(holder.output).map { case (a1, a2) =>
+ Alias(a1, a2.name)(a2.exprId)
+ }
+ Project(projectList, scanRelation)
+ }
+
def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder:
ScanBuilderHolder) =>
// column pruning
@@ -441,8 +601,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan]
with PredicateHelper {
} else {
aliasReplacedOrder.asInstanceOf[Seq[SortOrder]]
}
- val normalizedOrders = DataSourceStrategy.normalizeExprs(
- newOrder, sHolder.relation.output).asInstanceOf[Seq[SortOrder]]
+ val normalizedOrders = normalizeExpressions(newOrder,
sHolder).asInstanceOf[Seq[SortOrder]]
val orders = DataSourceStrategy.translateSortOrders(normalizedOrders)
if (orders.length == order.length) {
val (isPushed, isPartiallyPushed) =
@@ -540,6 +699,23 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan]
with PredicateHelper {
}
}
+ private def normalizeExpressions(
+ expressions: Seq[Expression],
+ sHolder: ScanBuilderHolder): Seq[Expression] = {
+ val output = if (sHolder.joinedRelations.length == 1) {
+ // Join is not pushed down
+ sHolder.relation.output
+ } else {
+      // sHolder.output's names can be out of date if the joins have previously been pushed down.
+      // For this reason, we need to use pushedJoinOutputMap to get up-to-date names.
+ sHolder.output.map { a =>
+ sHolder.pushedJoinOutputMap.getOrElse(a,
a).asInstanceOf[AttributeReference]
+ }
+ }
+
+ DataSourceStrategy.normalizeExprs(expressions, output)
+ }
+
private def getWrappedScan(scan: Scan, sHolder: ScanBuilderHolder): Scan = {
scan match {
case v1: V1Scan =>
@@ -549,7 +725,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan]
with PredicateHelper {
case _ => Array.empty[sources.Filter]
}
val pushedDownOperators = PushedDownOperators(sHolder.pushedAggregate,
sHolder.pushedSample,
- sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders,
sHolder.pushedPredicates)
+ sHolder.pushedLimit, sHolder.pushedOffset, sHolder.sortOrders,
sHolder.pushedPredicates,
+ sHolder.joinedRelations.map(_.name))
V1ScanWrapper(v1, pushedFilters.toImmutableArraySeq,
pushedDownOperators)
case _ => scan
}
@@ -573,6 +750,10 @@ case class ScanBuilderHolder(
var pushedAggregate: Option[Aggregation] = None
var pushedAggOutputMap: AttributeMap[Expression] =
AttributeMap.empty[Expression]
+
+ var joinedRelations: Seq[DataSourceV2RelationBase] = Seq(relation)
+
+ var pushedJoinOutputMap: AttributeMap[Expression] =
AttributeMap.empty[Expression]
}
// A wrapper for v1 scan to carry the translated filters and the handled ones,
along with
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 230f30fb1d06..68df0e6a9726 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -23,17 +23,18 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownOffset,
SupportsPushDownRequiredColumns, SupportsPushDownTableSample,
SupportsPushDownTopN, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.join.JoinType
+import org.apache.spark.sql.connector.read.{ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownJoin, SupportsPushDownLimit,
SupportsPushDownOffset, SupportsPushDownRequiredColumns,
SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD,
JDBCRelation}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions,
JDBCPartition, JDBCRDD, JDBCRelation}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialects, JdbcSQLQueryBuilder}
import org.apache.spark.sql.types.StructType
case class JDBCScanBuilder(
session: SparkSession,
schema: StructType,
- jdbcOptions: JDBCOptions)
+ var jdbcOptions: JDBCOptions)
extends ScanBuilder
with SupportsPushDownV2Filters
with SupportsPushDownRequiredColumns
@@ -42,6 +43,7 @@ case class JDBCScanBuilder(
with SupportsPushDownOffset
with SupportsPushDownTableSample
with SupportsPushDownTopN
+ with SupportsPushDownJoin
with Logging {
private val dialect = JdbcDialects.get(jdbcOptions.url)
@@ -121,6 +123,143 @@ case class JDBCScanBuilder(
}
}
+  // TODO: currently we check that all the options are the same (besides dbtable and query
+  // options). That is too strict, so in the future we should relax this check by asserting
+  // only that specific options are the same (e.g. host, port, username, password, database...).
+  // Also, we need to check if the join is done on 2 tables from 2 different databases within
+  // the same host. That shouldn't be allowed.
+ override def isOtherSideCompatibleForJoin(other: SupportsPushDownJoin):
Boolean = {
+ if (!jdbcOptions.pushDownJoin ||
+ !dialect.supportsJoin ||
+ !other.isInstanceOf[JDBCScanBuilder]) {
+ return false
+ }
+
+ val filteredJDBCOptions = jdbcOptions.parameters -
+ JDBCOptions.JDBC_TABLE_NAME -
+ JDBCOptions.JDBC_QUERY_STRING
+
+ val otherSideFilteredJDBCOptions =
other.asInstanceOf[JDBCScanBuilder].jdbcOptions.parameters -
+ JDBCOptions.JDBC_TABLE_NAME -
+ JDBCOptions.JDBC_QUERY_STRING
+
+ filteredJDBCOptions == otherSideFilteredJDBCOptions
+ };
+
+ /**
+ * Helper method to calculate StructType based on the
SupportsPushDownJoin.ColumnWithAlias and
+ * the given schema.
+ *
+ * If ColumnWithAlias object has defined alias, new field with new name
being equal to alias
+ * should be returned. Otherwise, original field is returned.
+ */
+ private def calculateJoinOutputSchema(
+ columnsWithAliases: Array[SupportsPushDownJoin.ColumnWithAlias],
+ schema: StructType): StructType = {
+ var newSchema = StructType(Seq())
+ columnsWithAliases.foreach { columnWithAlias =>
+ val colName = columnWithAlias.colName()
+ val alias = columnWithAlias.alias()
+ val field = schema(colName)
+
+ if (alias == null) {
+ newSchema = newSchema.add(field)
+ } else {
+ newSchema = newSchema.add(alias, field.dataType, field.nullable,
field.metadata)
+ }
+ }
+
+ newSchema
+ }
+
+ override def pushDownJoin(
+ other: SupportsPushDownJoin,
+ joinType: JoinType,
+ leftSideRequiredColumnsWithAliases:
Array[SupportsPushDownJoin.ColumnWithAlias],
+ rightSideRequiredColumnsWithAliases:
Array[SupportsPushDownJoin.ColumnWithAlias],
+ condition: Predicate ): Boolean = {
+ if (!jdbcOptions.pushDownJoin || !dialect.supportsJoin) {
+ return false
+ }
+
+ val joinTypeStringOption = joinType match {
+ case JoinType.INNER_JOIN => Some("INNER JOIN")
+ case _ => None
+ }
+ if (!joinTypeStringOption.isDefined) {
+ return false
+ }
+
+ val compiledCondition = dialect.compileExpression(condition)
+ if (!compiledCondition.isDefined) {
+ return false
+ }
+
+ val otherJdbcScanBuilder = other.asInstanceOf[JDBCScanBuilder]
+
+ // requiredSchema will become the finalSchema of this JDBCScanBuilder
+ var requiredSchema = StructType(Seq())
+ requiredSchema =
calculateJoinOutputSchema(leftSideRequiredColumnsWithAliases, finalSchema)
+ requiredSchema = requiredSchema.merge(
+ calculateJoinOutputSchema(
+ rightSideRequiredColumnsWithAliases,
+ otherJdbcScanBuilder.finalSchema
+ )
+ )
+
+ val joinOutputColumns = requiredSchema.fields.map(f =>
dialect.quoteIdentifier(f.name))
+ val conditionString = compiledCondition.get
+
+ // Get left side and right side of join sql query builders and recursively
build them when
+ // crafting join sql query.
+ val leftSideJdbcSQLBuilder =
getJoinPushdownJdbcSQLBuilder(leftSideRequiredColumnsWithAliases)
+ val otherSideJdbcSQLBuilder = otherJdbcScanBuilder
+ .getJoinPushdownJdbcSQLBuilder(rightSideRequiredColumnsWithAliases)
+
+ val joinQuery = dialect
+ .getJdbcSQLQueryBuilder(jdbcOptions)
+ .withJoin(
+ leftSideJdbcSQLBuilder,
+ otherSideJdbcSQLBuilder,
+ JoinPushdownAliasGenerator.getSubqueryQualifier,
+ JoinPushdownAliasGenerator.getSubqueryQualifier,
+ joinOutputColumns,
+ joinTypeStringOption.get,
+ conditionString
+ )
+ .build()
+
+ val newJdbcOptionsMap = jdbcOptions.parameters.originalMap +
+ (JDBCOptions.JDBC_QUERY_STRING -> joinQuery) -
JDBCOptions.JDBC_TABLE_NAME
+
+ jdbcOptions = new JDBCOptions(newJdbcOptionsMap)
+ finalSchema = requiredSchema
+
+ // We need to reset the pushedPredicate because it has already been
consumed in previously
+ // crafted SQL query.
+ pushedPredicate = Array.empty[Predicate]
+ // Table sample is pushed down already as well, so we need to reset it to
None to not push it
+ // down again when join pushdown is triggered again on this
JDBCScanBuilder.
+ tableSample = None
+
+ true
+ }
+
+ def getJoinPushdownJdbcSQLBuilder(
+ columnsWithAliases: Array[SupportsPushDownJoin.ColumnWithAlias]):
JdbcSQLQueryBuilder = {
+ val quotedColumns = columnsWithAliases.map(col =>
dialect.quoteIdentifier(col.colName()))
+ val quotedAliases = columnsWithAliases
+ .map(col => Option(col.alias()).map(dialect.quoteIdentifier))
+
+ // Only filters can be pushed down before join pushdown, so we need to
craft SQL query
+ // that contains filters as well.
+ // Joins on top of samples are not supported so we don't need to provide
tableSample here.
+ dialect
+ .getJdbcSQLQueryBuilder(jdbcOptions)
+ .withPredicates(pushedPredicate, JDBCPartition(whereClause = null, idx =
1))
+ .withAliasedColumns(quotedColumns, quotedAliases)
+ }
+
override def pushTableSample(
lowerBound: Double,
upperBound: Double,
@@ -194,4 +333,13 @@ case class JDBCScanBuilder(
JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema,
pushedPredicate,
pushedAggregateList, pushedGroupBys, tableSample, pushedLimit,
sortOrders, pushedOffset)
}
+
+}
+
+object JoinPushdownAliasGenerator {
+ private val subQueryId = new java.util.concurrent.atomic.AtomicLong()
+
+ def getSubqueryQualifier: String = {
+ "join_subquery_" + subQueryId.getAndIncrement()
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index b5ee88aebd7d..fab859bee81c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -307,4 +307,6 @@ private[sql] case class H2Dialect() extends JdbcDialect
with NoLegacyJDBCError {
override def supportsLimit: Boolean = true
override def supportsOffset: Boolean = true
+
+ override def supportsJoin: Boolean = true
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index da0df734bbec..8afec9cead07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -853,6 +853,11 @@ abstract class JdbcDialect extends Serializable with
Logging {
def supportsHint: Boolean = false
+ /**
+ * Returns true if dialect supports JOIN operator.
+ */
+ def supportsJoin: Boolean = false
+
/**
* Return the DB-specific quoted and fully qualified table name
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
index 95be14f816a7..97ec093a0e29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcSQLQueryBuilder.scala
@@ -91,6 +91,20 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options:
JDBCOptions) {
this
}
+ def withAliasedColumns(
+ columns: Array[String],
+ aliases: Array[Option[String]]): JdbcSQLQueryBuilder = {
+ if (columns.nonEmpty) {
+ assert(columns.length == aliases.length,
+ "Number of columns does not match the number of provided aliases")
+
+ columnList = columns.zip(aliases).map {
+ case (column, alias) => if (alias.isDefined) s"$column AS
${alias.get}" else column
+ }.mkString(",")
+ }
+ this
+ }
+
/**
* Constructs the WHERE clause that following dialect's SQL syntax.
*/
@@ -164,6 +178,38 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options:
JDBCOptions) {
this
}
+ /**
+ * Represents JOIN subquery in case Join has been pushed down. This value
should be used
+ * instead of options.tableOrQuery if join has been pushed down.
+ */
+ private var joinQuery: Option[String] = None
+
+ def withJoin(
+ left: JdbcSQLQueryBuilder,
+ right: JdbcSQLQueryBuilder,
+ leftSideQualifier: String,
+ rightSideQualifier: String,
+ columns: Array[String],
+ joinType: String,
+ joinCondition: String): JdbcSQLQueryBuilder = {
+ columnList = columns.mkString(",")
+ joinQuery = Some(
+ s"""(
+ |SELECT ${columns.mkString(",")} FROM
+ |(${left.build()}) $leftSideQualifier
+ |$joinType
+ |(${right.build()}) $rightSideQualifier
+ |ON $joinCondition
+ |)""".stripMargin
+ )
+
+ this
+ }
+
+ // If join has been pushed down, reuse join query as a subquery. Otherwise,
fallback to
+ // what is provided in options.
+ private def tableOrQuery = joinQuery.getOrElse(options.tableOrQuery)
+
/**
* Build the final SQL query that following dialect's SQL syntax.
*/
@@ -174,7 +220,7 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options:
JDBCOptions) {
val offsetClause = dialect.getOffsetClause(offset)
options.prepareQuery +
- s"SELECT $hintClause$columnList FROM ${options.tableOrQuery}
$tableSampleClause" +
+ s"SELECT $hintClause$columnList FROM $tableOrQuery $tableSampleClause" +
s" $whereClause $groupByClause $orderByClause $limitClause $offsetClause"
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala
new file mode 100644
index 000000000000..b77e905fea5d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2JoinPushdownSuite.scala
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.{Connection, DriverManager}
+import java.util.Properties
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, GlobalLimit,
Join, Sort}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
+
+class JDBCV2JoinPushdownSuite extends QueryTest with SharedSparkSession with
ExplainSuiteHelper {
+ val tempDir = Utils.createTempDir()
+ val url =
s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"
+
+ override def sparkConf: SparkConf = super.sparkConf
+ .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
+ .set("spark.sql.catalog.h2.url", url)
+ .set("spark.sql.catalog.h2.driver", "org.h2.Driver")
+ .set("spark.sql.catalog.h2.pushDownAggregate", "true")
+ .set("spark.sql.catalog.h2.pushDownLimit", "true")
+ .set("spark.sql.catalog.h2.pushDownOffset", "true")
+ .set("spark.sql.catalog.h2.pushDownJoin", "true")
+
+ private def withConnection[T](f: Connection => T): T = {
+ val conn = DriverManager.getConnection(url, new Properties())
+ try {
+ f(conn)
+ } finally {
+ conn.close()
+ }
+ }
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ Utils.classForName("org.h2.Driver")
+ withConnection { conn =>
+ conn.prepareStatement("CREATE SCHEMA \"test\"").executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"people\" (name TEXT(32) NOT NULL, id
INTEGER NOT NULL)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('fred',
1)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary',
2)").executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32),
salary NUMERIC(20, 2)," +
+ " bonus DOUBLE, is_manager BOOLEAN)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000,
true)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200,
false)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200,
false)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300,
true)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200,
true)").executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL,
\"dept.id\" INTEGER)")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1,
1)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2,
1)").executeUpdate()
+
+ // scalastyle:off
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT
NULL)").executeUpdate()
+ // scalastyle:on
+ conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES
(1)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES
(2)").executeUpdate()
+ conn.prepareStatement(
+ """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2"
INTEGER)""").executeUpdate()
+ conn.prepareStatement(
+ """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3"
INTEGER)""").executeUpdate()
+ }
+ }
+
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ super.afterAll()
+ }
+
+ private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*):
Unit = {
+ withSQLConf(SQLConf.MAX_METADATA_STRING_LENGTH.key -> "1000") {
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ checkKeywordsExistsInExplain(df, expectedPlanFragment: _*)
+ }
+ }
+ }
+
+ // Conditionless joins are not supported in join pushdown
+ test("Test that 2-way join without condition should not have join pushed
down") {
+ val sqlQuery = "SELECT * FROM h2.test.employee a, h2.test.employee b"
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.nonEmpty)
+ checkAnswer(df, rows)
+ }
+ }
+
+ // Conditionless joins are not supported in join pushdown
+ test("Test that multi-way join without condition should not have join pushed
down") {
+ val sqlQuery = """
+ |SELECT * FROM
+ |h2.test.employee a,
+ |h2.test.employee b,
+ |h2.test.employee c
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.nonEmpty)
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test self join with condition") {
+ val sqlQuery = "SELECT * FROM h2.test.employee a JOIN h2.test.employee b
ON a.dept = b.dept + 1"
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test multi-way self join with conditions") {
+ val sqlQuery = """
+ |SELECT * FROM
+ |h2.test.employee a
+ |JOIN h2.test.employee b ON b.dept = a.dept + 1
+ |JOIN h2.test.employee c ON c.dept = b.dept - 1
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ assert(!rows.isEmpty)
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee,
h2.test.employee]")
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test self join with column pruning") {
+ val sqlQuery = """
+ |SELECT a.dept + 2, b.dept, b.salary FROM
+ |h2.test.employee a JOIN h2.test.employee b
+ |ON a.dept = b.dept + 1
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test 2-way join with column pruning - different tables") {
+ val sqlQuery = """
+ |SELECT * FROM
+ |h2.test.employee a JOIN h2.test.people b
+ |ON a.dept = b.id
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.people]")
+ checkPushedInfo(df,
+ "PushedFilters: [DEPT IS NOT NULL, ID IS NOT NULL, DEPT = ID]")
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test multi-way self join with column pruning") {
+ val sqlQuery = """
+ |SELECT a.dept, b.*, c.dept, c.salary + a.salary
+ |FROM h2.test.employee a
+ |JOIN h2.test.employee b ON b.dept = a.dept + 1
+ |JOIN h2.test.employee c ON c.dept = b.dept - 1
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee,
h2.test.employee]")
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test aliases not supported in join pushdown") {
+ val sqlQuery = """
+ |SELECT a.dept, bc.*
+ |FROM h2.test.employee a
+ |JOIN (
+ | SELECT b.*, c.dept AS c_dept, c.salary AS c_salary
+ | FROM h2.test.employee b
+ | JOIN h2.test.employee c ON c.dept = b.dept - 1
+ |) bc ON bc.dept = a.dept + 1
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.nonEmpty)
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test join with dataframe with duplicated columns") {
+ val df1 = sql("SELECT dept FROM h2.test.employee")
+ val df2 = sql("SELECT dept, dept FROM h2.test.employee")
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ df1.join(df2, "dept").collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val joinDf = df1.join(df2, "dept")
+ val joinNodes = joinDf.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ assert(joinNodes.isEmpty)
+ checkPushedInfo(joinDf, "PushedJoins: [h2.test.employee,
h2.test.employee]")
+ checkAnswer(joinDf, rows)
+ }
+ }
+
+ test("Test aggregate on top of 2-way self join") {
+ val sqlQuery = """
+ |SELECT min(a.dept + b.dept), min(a.dept)
+ |FROM h2.test.employee a
+ |JOIN h2.test.employee b ON a.dept = b.dept + 1
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ val aggNodes = df.queryExecution.optimizedPlan.collect {
+ case a: Aggregate => a
+ }
+
+ assert(joinNodes.isEmpty)
+ assert(aggNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test aggregate on top of multi-way self join") {
+ val sqlQuery = """
+ |SELECT min(a.dept + b.dept), min(a.dept), min(c.dept - 2)
+ |FROM h2.test.employee a
+ |JOIN h2.test.employee b ON b.dept = a.dept + 1
+ |JOIN h2.test.employee c ON c.dept = b.dept - 1
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ val aggNodes = df.queryExecution.optimizedPlan.collect {
+ case a: Aggregate => a
+ }
+
+ assert(joinNodes.isEmpty)
+ assert(aggNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee,
h2.test.employee]")
+ checkAnswer(df, rows)
+ }
+ }
+
+ test("Test sort limit on top of join is pushed down") {
+ val sqlQuery = """
+ |SELECT min(a.dept + b.dept), a.dept, b.dept
+ |FROM h2.test.employee a
+ |JOIN h2.test.employee b ON b.dept = a.dept + 1
+ |GROUP BY a.dept, b.dept
+ |ORDER BY a.dept
+ |LIMIT 1
+ |""".stripMargin
+
+ val rows = withSQLConf(SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key ->
"false") {
+ sql(sqlQuery).collect().toSeq
+ }
+
+ withSQLConf(
+ SQLConf.DATA_SOURCE_V2_JOIN_PUSHDOWN.key -> "true") {
+ val df = sql(sqlQuery)
+ val joinNodes = df.queryExecution.optimizedPlan.collect {
+ case j: Join => j
+ }
+
+ val sortNodes = df.queryExecution.optimizedPlan.collect {
+ case s: Sort => s
+ }
+
+ val limitNodes = df.queryExecution.optimizedPlan.collect {
+ case l: GlobalLimit => l
+ }
+
+ assert(joinNodes.isEmpty)
+ assert(sortNodes.isEmpty)
+ assert(limitNodes.isEmpty)
+ checkPushedInfo(df, "PushedJoins: [h2.test.employee, h2.test.employee]")
+ checkAnswer(df, rows)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 3594f8d481a5..23761f684b45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -142,6 +142,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession
with ExplainSuiteHel
.set("spark.sql.catalog.h2.pushDownAggregate", "true")
.set("spark.sql.catalog.h2.pushDownLimit", "true")
.set("spark.sql.catalog.h2.pushDownOffset", "true")
+ .set("spark.sql.catalog.h2.pushDownJoin", "true")
private def withConnection[T](f: Connection => T): T = {
val conn = DriverManager.getConnection(url, new Properties())
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]