This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8235f1d56bf2 [SPARK-46219][SQL] Unwrap cast in join predicates
8235f1d56bf2 is described below
commit 8235f1d56bf232bb713fe24ff6f2ffdaf49d2fcc
Author: Yuming Wang <[email protected]>
AuthorDate: Tue Dec 5 08:37:34 2023 -0800
[SPARK-46219][SQL] Unwrap cast in join predicates
### What changes were proposed in this pull request?
In a large data platform, it is very common to join columns of different data types.
Similar to
[`reorderJoinPredicates`](https://github.com/apache/spark/blob/b03afa7bde5a050eb95284b275eae0aac2257f63/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala#L321-L338),
this PR adds a function in `EnsureRequirements` that unwraps casts in join
predicates to reduce shuffles when the join keys are integral types.
The key idea is that, when the join keys are integral types, casting the keys to
either of the two types does not affect the join result. For example, in `a.intCol
= try_cast(b.bigIntCol AS int)`, if the value of `bigIntCol` exceeds the range
of int, `try_cast(b.bigIntCol AS int)` returns `null`, so the join condition
evaluates to `null` and the rows do not match. This is consistent with
`cast(a.intCol AS bigint) = b.bigIntCol`, which is `false` for such a value.
### Why are the changes needed?
Reduce shuffles to improve query performance.
Case 1: Shuffle before join
```sql
CREATE TABLE t1(id int) USING parquet;
CREATE TABLE t2(id int) USING parquet;
CREATE TABLE t3(id bigint) USING parquet;
SET spark.sql.autoBroadcastJoinThreshold=-1;
explain SELECT * FROM t1 JOIN t2 ON t1.id = t2.id JOIN t3 ON t1.id = t3.id;
explain SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY id ORDER
BY id) AS rn FROM t1) t JOIN t2 ON t.id = t2.id WHERE rn = 1;
```
The plan differences after this PR:
```diff
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
-+- SortMergeJoin [cast(id#10 as bigint)], [id#12L], Inner
- :- Sort [cast(id#10 as bigint) ASC NULLS FIRST], false, 0
- : +- Exchange hashpartitioning(cast(id#10 as bigint), 5),
ENSURE_REQUIREMENTS, [plan_id=54]
- : +- SortMergeJoin [id#10], [id#11], Inner
- : :- Sort [id#10 ASC NULLS FIRST], false, 0
- : : +- Exchange hashpartitioning(id#10, 5),
ENSURE_REQUIREMENTS, [plan_id=47]
- : : +- Filter isnotnull(id#10)
- : : +- FileScan parquet spark_catalog.default.t1[id#10]
- : +- Sort [id#11 ASC NULLS FIRST], false, 0
- : +- Exchange hashpartitioning(id#11, 5),
ENSURE_REQUIREMENTS, [plan_id=48]
- : +- Filter isnotnull(id#11)
- : +- FileScan parquet spark_catalog.default.t2[id#11]
- +- Sort [id#12L ASC NULLS FIRST], false, 0
- +- Exchange hashpartitioning(id#12L, 5), ENSURE_REQUIREMENTS,
[plan_id=55]
- +- Filter isnotnull(id#12L)
- +- FileScan parquet spark_catalog.default.t3[id#12L]
++- SortMergeJoin [id#20], [try_cast(id#22L as int)], Inner
+ :- SortMergeJoin [id#20], [id#21], Inner
+ : :- Sort [id#20 ASC NULLS FIRST], false, 0
+ : : +- Exchange hashpartitioning(id#20, 5), ENSURE_REQUIREMENTS,
[plan_id=50]
+ : : +- Filter isnotnull(id#20)
+ : : +- FileScan parquet spark_catalog.default.t1[id#20]
+ : +- Sort [id#21 ASC NULLS FIRST], false, 0
+ : +- Exchange hashpartitioning(id#21, 5), ENSURE_REQUIREMENTS,
[plan_id=51]
+ : +- Filter isnotnull(id#21)
+ : +- FileScan parquet spark_catalog.default.t2[id#21]
+ +- Sort [try_cast(id#22L as int) ASC NULLS FIRST], false, 0
+ +- Exchange hashpartitioning(try_cast(id#22L as int), 5),
ENSURE_REQUIREMENTS, [plan_id=58]
+ +- Filter isnotnull(id#22L)
+ +- FileScan parquet spark_catalog.default.t3[id#22L]
```
```diff
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
-+- SortMergeJoin [cast(id#22 as bigint)], [id#23L], Inner
- :- Sort [cast(id#22 as bigint) ASC NULLS FIRST], false, 0
- : +- Exchange hashpartitioning(cast(id#22 as bigint), 5),
ENSURE_REQUIREMENTS, [plan_id=62]
- : +- Filter (rn#20 = 1)
- : +- Window [row_number() windowspecdefinition(id#22, id#22 ASC
NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(),
currentrow$())) AS rn#20], [id#22], [id#22 ASC NULLS FIRST]
- : +- WindowGroupLimit [id#22], [id#22 ASC NULLS FIRST],
row_number(), 1, Final
- : +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS FIRST],
false, 0
- : +- Exchange hashpartitioning(id#22, 5),
ENSURE_REQUIREMENTS, [plan_id=55]
- : +- WindowGroupLimit [id#22], [id#22 ASC NULLS
FIRST], row_number(), 1, Partial
- : +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS
FIRST], false, 0
- : +- Filter isnotnull(id#22)
- : +- FileScan parquet
spark_catalog.default.t1[id#22]
- +- Sort [id#23L ASC NULLS FIRST], false, 0
- +- Exchange hashpartitioning(id#23L, 5), ENSURE_REQUIREMENTS,
[plan_id=63]
++- SortMergeJoin [id#22], [try_cast(id#23L as int)], Inner
+ :- Filter (rn#20 = 1)
+ : +- Window [row_number() windowspecdefinition(id#22, id#22 ASC NULLS
FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS
rn#20], [id#22], [id#22 ASC NULLS FIRST]
+ : +- WindowGroupLimit [id#22], [id#22 ASC NULLS FIRST],
row_number(), 1, Final
+ : +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS FIRST], false, 0
+ : +- Exchange hashpartitioning(id#22, 5),
ENSURE_REQUIREMENTS, [plan_id=55]
+ : +- WindowGroupLimit [id#22], [id#22 ASC NULLS FIRST],
row_number(), 1, Partial
+ : +- Sort [id#22 ASC NULLS FIRST, id#22 ASC NULLS
FIRST], false, 0
+ : +- Filter isnotnull(id#22)
+ : +- FileScan parquet
spark_catalog.default.t1[id#22]
+ +- Sort [try_cast(id#23L as int) ASC NULLS FIRST], false, 0
+ +- Exchange hashpartitioning(try_cast(id#23L as int), 5),
ENSURE_REQUIREMENTS, [plan_id=63]
+- Filter isnotnull(id#23L)
+- FileScan parquet spark_catalog.default.t2[id#23L]
```
Case 2: Bucketed tables
```sql
CREATE TABLE t1(id bigint) USING parquet CLUSTERED BY (id) INTO 200 buckets;
CREATE TABLE t2(id decimal(18, 0)) USING parquet CLUSTERED BY (id) INTO 200
buckets;
SET spark.sql.autoBroadcastJoinThreshold=-1;
explain SELECT * FROM t1 JOIN t2 ON t1.id = t2.id;
```
The plan differences after this PR:
```diff
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
-+- SortMergeJoin [cast(id#10L as decimal(20,0))], [cast(id#11 as
decimal(20,0))], Inner
- :- Sort [cast(id#10L as decimal(20,0)) ASC NULLS FIRST], false, 0
- : +- Exchange hashpartitioning(cast(id#10L as decimal(20,0)), 5),
ENSURE_REQUIREMENTS, [plan_id=38]
- : +- Filter isnotnull(id#10L)
- : +- FileScan parquet spark_catalog.default.t1[id#10L]
- +- Sort [cast(id#11 as decimal(20,0)) ASC NULLS FIRST], false, 0
- +- Exchange hashpartitioning(cast(id#11 as decimal(20,0)), 5),
ENSURE_REQUIREMENTS, [plan_id=42]
- +- Filter isnotnull(id#11)
- +- FileScan parquet spark_catalog.default.t2[id#11]
++- SortMergeJoin [id#20L], [try_cast(id#21 as bigint)], Inner
+ :- Sort [id#20L ASC NULLS FIRST], false, 0
+ : +- Filter isnotnull(id#20L)
+ : +- FileScan parquet spark_catalog.default.t1[id#20L] Bucketed:
true, SelectedBucketsCount: 200 out of 200
+ +- Sort [try_cast(id#21 as bigint) ASC NULLS FIRST], false, 0
+ +- Exchange hashpartitioning(try_cast(id#21 as bigint), 200),
ENSURE_REQUIREMENTS, [plan_id=42]
+ +- Filter isnotnull(id#21)
+ +- FileScan parquet spark_catalog.default.t2[id#21] Bucketed:
false (disabled by query planner)
```
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit test.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44133 from wangyum/SPARK-46219.
Authored-by: Yuming Wang <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../org/apache/spark/sql/internal/SQLConf.scala | 10 ++
.../bucketing/CoalesceBucketsInJoin.scala | 22 +---
.../execution/exchange/EnsureRequirements.scala | 25 ++++-
...ractJoinWithUnwrappedCastInJoinPredicates.scala | 114 +++++++++++++++++++++
.../spark/sql/execution/joins/ShuffledJoin.scala | 14 ++-
.../apache/spark/sql/execution/PlannerSuite.scala | 74 +++++++++++++
.../spark/sql/sources/BucketedReadSuite.scala | 65 ++++++++++++
7 files changed, 301 insertions(+), 23 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 080928baf8a9..9918d583d49e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -564,6 +564,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+  val UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED =
+    buildConf("spark.sql.unwrapCastInJoinCondition.enabled")
+      .doc("When true, unwrap the cast in the join keys of shuffled joins to reduce shuffle " +
+        "if the join keys are integral types.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
val MAX_SINGLE_PARTITION_BYTES =
buildConf("spark.sql.maxSinglePartitionBytes")
.doc("The maximum number of bytes allowed for a single partition.
Otherwise, The planner " +
"will introduce shuffle to improve parallelism.")
@@ -5043,6 +5051,8 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
+  def unwrapCastInJoinConditionEnabled: Boolean =
+    getConf(UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED)
+
def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED)
def isParquetSchemaMergingEnabled: Boolean =
getConf(PARQUET_SCHEMA_MERGING_ENABLED)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
index d1464b4ac4ee..ab0eaa044dea 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/CoalesceBucketsInJoin.scala
@@ -20,9 +20,7 @@ package org.apache.spark.sql.execution.bucketing
import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{FileSourceScanExec, FilterExec,
ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec,
BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin,
SortMergeJoinExec}
@@ -131,27 +129,11 @@ object ExtractJoinWithBuckets {
}
}
- /**
- * The join keys should match with expressions for output partitioning. Note
that
- * the ordering does not matter because it will be handled in
`EnsureRequirements`.
- */
- private def satisfiesOutputPartitioning(
- keys: Seq[Expression],
- partitioning: Partitioning): Boolean = {
- partitioning match {
- case HashPartitioning(exprs, _) if exprs.length == keys.length =>
- exprs.forall(e => keys.exists(_.semanticEquals(e)))
- case PartitioningCollection(partitionings) =>
- partitionings.exists(satisfiesOutputPartitioning(keys, _))
- case _ => false
- }
- }
-
+ // A join qualifies only if both children pass `hasScanOperation` and each side's join
+ // keys satisfy that side's output partitioning (checked via the `ShuffledJoin` helper).
private def isApplicable(j: ShuffledJoin): Boolean = {
hasScanOperation(j.left) &&
hasScanOperation(j.right) &&
- satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
- satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning)
+ j.satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
+ j.satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning)
}
private def isDivisible(numBuckets1: Int, numBuckets2: Int): Boolean = {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 2a7c1206bb41..38a8b5db2695 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -337,6 +337,28 @@ case class EnsureRequirements(
}
}
+  /**
+   * Replaces the join keys of a `SortMergeJoinExec` or `ShuffledHashJoinExec` with the
+   * cast-unwrapped key pairs produced by `ExtractJoinWithUnwrappedCastInJoinPredicates`,
+   * to reduce shuffle. Any other plan is returned unchanged, as is every plan when
+   * `spark.sql.unwrapCastInJoinCondition.enabled` is false.
+   */
+  private def unwrapCastInJoinPredicates(plan: SparkPlan): SparkPlan = plan match {
+    // Check the flag before running the extractor at all.
+    case _ if !conf.unwrapCastInJoinConditionEnabled =>
+      plan
+    case ExtractJoinWithUnwrappedCastInJoinPredicates(shuffledJoin, keyPairs) =>
+      val newLeftKeys = keyPairs.map(_._1)
+      val newRightKeys = keyPairs.map(_._2)
+      shuffledJoin match {
+        case smj: SortMergeJoinExec =>
+          smj.copy(leftKeys = newLeftKeys, rightKeys = newRightKeys)
+        case shj: ShuffledHashJoinExec =>
+          shj.copy(leftKeys = newLeftKeys, rightKeys = newRightKeys)
+        case otherJoin =>
+          otherJoin
+      }
+    case _ =>
+      plan
+  }
+
/**
* Checks whether two children, `left` and `right`, of a join operator have
compatible
* `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
@@ -605,7 +627,8 @@ case class EnsureRequirements(
}
case operator: SparkPlan =>
- val reordered = reorderJoinPredicates(operator)
+ val unwrapped = unwrapCastInJoinPredicates(operator)
+ val reordered = reorderJoinPredicates(unwrapped)
val newChildren = ensureDistributionAndOrdering(
Some(reordered),
reordered.children,
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala
new file mode 100644
index 000000000000..5d46fac90985
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExtractJoinWithUnwrappedCastInJoinPredicates.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import
org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion.findWiderTypeForTwo
+import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Expression}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning,
PartitioningCollection}
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.joins.ShuffledJoin
+import org.apache.spark.sql.types.{DataType, DecimalType, IntegralType}
+
+/**
+ * An extractor that matches a `ShuffledJoin` (`SortMergeJoinExec` or `ShuffledHashJoinExec`)
+ * whose join keys contain casts and whose children do not already satisfy those keys. It
+ * returns new join key pairs with the casts unwrapped, so that at least one side's existing
+ * hash output partitioning (e.g. from a bucketed read or a previous shuffle) can be reused.
+ *
+ * A narrowing `EvalMode.TRY` cast is added to the other side only when both keys of the pair
+ * are integral types (`IntegralType` or a decimal with scale 0). Out-of-range values then
+ * become `null` and never match, so the join result is unchanged. For non-integral keys
+ * (e.g. `double`, or decimals with a non-zero scale) the narrowing cast can drop the
+ * fractional part (`try_cast(1.5 AS int)` is `1`) and produce matches the original predicate
+ * rejects, so such joins are left untouched.
+ */
+object ExtractJoinWithUnwrappedCastInJoinPredicates {
+  /** Returns true for integral types, including decimals with scale 0. */
+  private def isIntegralType(dt: DataType): Boolean = dt match {
+    case _: IntegralType => true
+    case DecimalType.Fixed(_, 0) => true
+    case _ => false
+  }
+
+  /** Strips a top-level `Cast` from every key whose child is an integral type. */
+  private def unwrapCastInJoinKeys(joinKeys: Seq[Expression]): Seq[Expression] = {
+    joinKeys.map {
+      case c: Cast if isIntegralType(c.child.dataType) => c.child
+      case e => e
+    }
+  }
+
+  /**
+   * Casts the left or right key of each pair whose data types differ to the data type of the
+   * other key, so both keys of every pair have the same type.
+   *
+   * @param isAddCastToLeftSide if true the left keys are cast, otherwise the right keys are.
+   */
+  private def coerceJoinKeyType(
+      unwrapLeftKeys: Seq[Expression],
+      unwrapRightKeys: Seq[Expression],
+      isAddCastToLeftSide: Boolean): Seq[(Expression, Expression)] = {
+    unwrapLeftKeys.zip(unwrapRightKeys).map {
+      case (l, r) if l.dataType != r.dataType =>
+        // Use TRY mode to avoid runtime exception in ANSI mode or data issue in non-ANSI mode.
+        if (isAddCastToLeftSide) {
+          Cast(l, r.dataType, evalMode = EvalMode.TRY) -> r
+        } else {
+          l -> Cast(r, l.dataType, evalMode = EvalMode.TRY)
+        }
+      case (l, r) => l -> r
+    }
+  }
+
+  /**
+   * Returns the unwrapped key pairs for `j`. Returns `None` if unwrapping is unsafe, or if it
+   * would not let either side reuse its current output partitioning.
+   */
+  private def unwrapCastInJoinPredicates(
+      j: ShuffledJoin): Option[Seq[(Expression, Expression)]] = {
+    val leftKeys = unwrapCastInJoinKeys(j.leftKeys)
+    val rightKeys = unwrapCastInJoinKeys(j.rightKeys)
+    // Make sure each pair was cast to the wider type of its unwrapped keys. For example, we do
+    // not support: cast(longCol as int) = cast(decimalCol as int).
+    // Pairs whose unwrapped types differ get a narrowing TRY cast below. That is only
+    // equivalent to the original comparison when both keys are integral types.
+    val canUnwrap = leftKeys.zip(rightKeys).zip(j.leftKeys).forall {
+      case ((l, r), originalKey) =>
+        val sameType = l.dataType == r.dataType
+        val bothIntegral = isIntegralType(l.dataType) && isIntegralType(r.dataType)
+        findWiderTypeForTwo(l.dataType, r.dataType).contains(originalKey.dataType) &&
+          (sameType || bothIntegral)
+    }
+    if (canUnwrap) {
+      val leftSatisfies = j.satisfiesOutputPartitioning(leftKeys, j.left.outputPartitioning)
+      val rightSatisfies = j.satisfiesOutputPartitioning(rightKeys, j.right.outputPartitioning)
+      if (leftSatisfies && rightSatisfies) {
+        // If there is a bucketed read, their number of partitions may be inconsistent.
+        // If the number of partitions on the left side is less than the number of partitions
+        // on the right side, cast the left side keys to the data type of the right side keys.
+        // Otherwise, cast the right side keys to the data type of the left side keys.
+        val castLeft =
+          j.left.outputPartitioning.numPartitions < j.right.outputPartitioning.numPartitions
+        Some(coerceJoinKeyType(leftKeys, rightKeys, isAddCastToLeftSide = castLeft))
+      } else if (leftSatisfies) {
+        // Keep the already-partitioned left keys and cast the right keys to match them.
+        Some(coerceJoinKeyType(leftKeys, rightKeys, isAddCastToLeftSide = false))
+      } else if (rightSatisfies) {
+        // Keep the already-partitioned right keys and cast the left keys to match them.
+        Some(coerceJoinKeyType(leftKeys, rightKeys, isAddCastToLeftSide = true))
+      } else {
+        None
+      }
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Returns true if all three conditions hold:
+   *  - `j` has a `Cast` among its join keys;
+   *  - neither child's output partitioning already satisfies that side's original keys;
+   *  - at least one child reports a `HashPartitioning` or `PartitioningCollection`.
+   */
+  private def isTryToUnwrapCastInJoinPredicates(j: ShuffledJoin): Boolean = {
+    val hasCastKey =
+      j.leftKeys.exists(_.isInstanceOf[Cast]) || j.rightKeys.exists(_.isInstanceOf[Cast])
+    hasCastKey &&
+      !j.satisfiesOutputPartitioning(j.leftKeys, j.left.outputPartitioning) &&
+      !j.satisfiesOutputPartitioning(j.rightKeys, j.right.outputPartitioning) &&
+      j.children.map(_.outputPartitioning).exists {
+        case _: PartitioningCollection | _: HashPartitioning => true
+        case _ => false
+      }
+  }
+
+  def unapply(plan: SparkPlan): Option[(ShuffledJoin, Seq[(Expression, Expression)])] = {
+    plan match {
+      case j: ShuffledJoin if isTryToUnwrapCastInJoinPredicates(j) =>
+        unwrapCastInJoinPredicates(j).map(joinKeys => (j, joinKeys))
+      case _ => None
+    }
+  }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
index 7c4628c8576c..9591218b099b 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter,
InnerLike, LeftExistence, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution,
Distribution, Partitioning, PartitioningCollection, UnknownPartitioning,
UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution,
Distribution, HashPartitioning, Partitioning, PartitioningCollection,
UnknownPartitioning, UnspecifiedDistribution}
/**
* Holds common logic for join operators by shuffling two child relations
@@ -56,6 +56,16 @@ trait ShuffledJoin extends JoinCodegenSupport {
s"ShuffledJoin should not take $x as the JoinType")
}
+  /**
+   * Returns true if the join `keys` match the expressions of `partitioning`. That is either:
+   *  - a `HashPartitioning` with as many expressions as there are keys, each expression
+   *    semantically equal to some key; or
+   *  - a `PartitioningCollection` in which any member matches.
+   * The key order does not matter because it will be handled in `EnsureRequirements`.
+   */
+  def satisfiesOutputPartitioning(keys: Seq[Expression], partitioning: Partitioning): Boolean =
+    partitioning match {
+      case PartitioningCollection(members) =>
+        members.exists(member => satisfiesOutputPartitioning(keys, member))
+      case HashPartitioning(hashExprs, _) =>
+        hashExprs.length == keys.length &&
+          hashExprs.forall(hashExpr => keys.exists(_.semanticEquals(hashExpr)))
+      case _ =>
+        false
+    }
+
override def output: Seq[Attribute] = {
joinType match {
case _: InnerLike =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c5b1e68fb912..8565e06ba9fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -1372,6 +1372,80 @@ class PlannerSuite extends SharedSparkSession with
AdaptiveSparkPlanHelper {
assert(numOutputPartitioning.size == 8)
}
}
+
+ test("SPARK-46219: Unwrap cast in join condition") {
+ // The join key is cast(int AS bigint) = bigint. The left child is already hash partitioned
+ // by the uncast int key; the right child is a `DummySparkPlan` with default partitioning.
+ val intExpr = Literal(1)
+ val longExpr = Literal(1L)
+ val smjExec = SortMergeJoinExec(
+ leftKeys = Cast(intExpr, LongType) :: Nil,
+ rightKeys = longExpr :: Nil,
+ joinType = Inner,
+ condition = None,
+ left = DummySparkPlan(outputPartitioning = HashPartitioning(intExpr::
Nil, 5)),
+ right = DummySparkPlan())
+
+ Seq(true, false).foreach { unwrapCast =>
+ withSQLConf(SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key ->
unwrapCast.toString) {
+ val outputPlan = EnsureRequirements.apply(smjExec)
+ if (unwrapCast) {
+ // Enabled: the left side keeps its partitioning (no exchange), and only the right
+ // side is shuffled, on try_cast(bigint AS int).
+ outputPlan match {
+ case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+ SortExec(_, _, _: DummySparkPlan, _),
+ SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _),
_), _) =>
+ assert(leftKeys === Seq(intExpr))
+ assert(rightKeys === Seq(Cast(longExpr, IntegerType, evalMode =
EvalMode.TRY)))
+ case _ => fail(outputPlan.toString)
+ }
+ } else {
+ // Disabled: both sides are shuffled and the original keys are kept.
+ outputPlan match {
+ case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _),
_),
+ SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _),
_), _) =>
+ assert(leftKeys === smjExec.leftKeys)
+ assert(rightKeys === smjExec.rightKeys)
+ case _ => fail(outputPlan.toString)
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-46219: Number of partitions may be inconsistent") {
+ // Both children are hash partitioned by their uncast keys: 10 partitions on the left
+ // (bigint) and 5 on the right (decimal(18, 0)). Both keys are cast to decimal(20, 0).
+ val longExpr = Literal(1L)
+ val decimalExpr = Literal(Decimal(1L, 18, 0))
+ val smjExec = SortMergeJoinExec(
+ leftKeys = Cast(longExpr, DecimalType(20, 0)) :: Nil,
+ rightKeys = Cast(decimalExpr, DecimalType(20, 0)) :: Nil,
+ joinType = Inner,
+ condition = None,
+ left = DummySparkPlan(outputPartitioning = HashPartitioning(longExpr ::
Nil, 10)),
+ right = DummySparkPlan(outputPartitioning = HashPartitioning(decimalExpr
:: Nil, 5)))
+
+ Seq(true, false).foreach { unwrapCast =>
+ withSQLConf(SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key ->
unwrapCast.toString) {
+ val outputPlan = EnsureRequirements.apply(smjExec)
+ if (unwrapCast) {
+ // Enabled: the side with more partitions (left) keeps its partitioning, and the
+ // right side is shuffled on try_cast(decimal AS bigint).
+ outputPlan match {
+ case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+ SortExec(_, _, _: DummySparkPlan, _),
+ SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _),
_), _) =>
+ assert(leftKeys === Seq(longExpr))
+ assert(rightKeys === Seq(Cast(decimalExpr, LongType, evalMode =
EvalMode.TRY)))
+ case _ => fail(outputPlan.toString)
+ }
+ } else {
+ // Disabled: both sides are shuffled on the original decimal(20, 0) keys.
+ outputPlan match {
+ case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+ SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _),
_),
+ SortExec(_, _, ShuffleExchangeExec(_, _: DummySparkPlan, _, _),
_), _) =>
+ assert(leftKeys === smjExec.leftKeys)
+ assert(rightKeys === smjExec.rightKeys)
+ case _ => fail(outputPlan.toString)
+ }
+ }
+ }
+ }
+ }
}
// Used for unit-testing EnsureRequirements
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 3573bafe482c..52a316e63a81 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
+import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType,
LongType}
import org.apache.spark.tags.SlowSQLTest
import org.apache.spark.util.collection.BitSet
@@ -1088,4 +1089,68 @@ abstract class BucketedReadSuite extends QueryTest with
SQLTestUtils with Adapti
}
}
}
+
+ test("SPARK-46219: Unwrap cast in join condition") {
+ // Runs `query` under every combination of ANSI and AQE and checks the number of shuffles.
+ // When exactly one shuffle is expected, it also checks:
+ //  - that shuffle's partition count and key types;
+ //  - that every Cast in the sort merge join expressions uses EvalMode.TRY;
+ //  - the query result.
+ // NOTE(review): checkAnswer is only reached in the single-shuffle branch, so the
+ // two-shuffle query below is not checked for correctness.
+ def verify(
+ query: String,
+ expectedNumShuffles: Int,
+ numPartitions: Option[Int] = None,
+ partitioningKeyTypes: Option[Seq[DataType]] = None): Unit = {
+ Seq(true, false).foreach { ansiEnabled =>
+ Seq(true, false).foreach { aqeEnabled =>
+ withSQLConf(
+ SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString,
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
+ val df = sql(query)
+ val plan = df.queryExecution.executedPlan
+ val shuffles = collect(plan) {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(shuffles.size === expectedNumShuffles)
+ if (shuffles.size == 1) {
+ val outputPartitioning = shuffles.head.outputPartitioning
+ assert(outputPartitioning.numPartitions === numPartitions.get)
+ assert(outputPartitioning.asInstanceOf[HashPartitioning]
+ .expressions.map(_.dataType) === partitioningKeyTypes.get)
+
+ collect(plan) { case s: SortMergeJoinExec => s
}.flatMap(_.expressions).foreach {
+ case c: Cast => assert(c.evalMode === EvalMode.TRY) // The
EvalMode should be try.
+ case _ =>
+ }
+
+ checkAnswer(df, Row(1, 1) :: Nil)
+ }
+ }
+ }
+ }
+ }
+
+ withTable("t1", "t2", "t3", "t4") {
+ // Tables: t1 (int, 8 buckets), t2 (bigint, 8 buckets), t3 (decimal(18, 0), 4 buckets),
+ // and t4, a non-bucketed copy of t2. Apart from 1, each table holds one distinct large
+ // value (Int.MaxValue, Long.MaxValue, eighteen 9s), so every join matches only on 1.
+ sql(
+ s"""
+ |CREATE TABLE t1 USING parquet CLUSTERED BY (i) INTO 8 buckets AS
+ |SELECT CAST(v AS int) AS i FROM values(1), (${Int.MaxValue}) AS
data(v)
+ |""".stripMargin)
+ sql(
+ s"""
+ |CREATE TABLE t2 USING parquet CLUSTERED BY (i) INTO 8 buckets AS
+ |SELECT CAST(v AS bigint) AS i FROM values(1), (${Long.MaxValue})
AS data(v)
+ |""".stripMargin)
+ sql(
+ s"""
+ |CREATE TABLE t3 USING parquet CLUSTERED BY (i) INTO 4 buckets AS
+ |SELECT CAST(v AS decimal(18, 0)) AS i FROM values(1), (${"9" *
18}) AS data(v)
+ |""".stripMargin)
+ spark.table("t2").write.saveAsTable("t4")
+
+ // Expected: only the side that is not bucketed, or has fewer buckets, is shuffled. It is
+ // shuffled into the other side's bucket count, using the other side's key type.
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+ SQLConf.UNWRAP_CAST_IN_JOIN_CONDITION_ENABLED.key -> "true") {
+ verify("SELECT * FROM t2 JOIN t3 ON t2.i = t3.i", 1, Some(8),
Some(Seq(LongType)))
+ verify("SELECT * FROM t1 JOIN t4 ON t1.i = t4.i", 1, Some(8),
Some(Seq(IntegerType)))
+ verify("SELECT * FROM t3 JOIN t4 ON t3.i = t4.i", 1, Some(4),
Some(Seq(DecimalType(18, 0))))
+ // Do not unwrap cast if it is added by user.
+ verify("SELECT * FROM t2 JOIN t3 ON cast(t2.i as int) = cast(t3.i as
int)", 2)
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]