imback82 commented on a change in pull request #29074:
URL: https://github.com/apache/spark/pull/29074#discussion_r466776005
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
##########
@@ -994,6 +994,88 @@ class PlannerSuite extends SharedSparkSession with
AdaptiveSparkPlanHelper {
}
}
}
+
+ test("EnsureRequirements.reorder should fallback to the right side
HashPartitioning") {
+ val plan1 = DummySparkPlan(
+ outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5))
+ val plan2 = DummySparkPlan(
+ outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5))
+ // The left keys cannot be reordered to match the left partitioning, and
it should
+ // fall back to reorder the right side.
+ val smjExec = SortMergeJoinExec(
+ exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2)
+ val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+ outputPlan match {
+ case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+ SortExec(_, _,
+ ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions,
_), _, _), _),
+ SortExec(_, _,
+ DummySparkPlan(_, _, HashPartitioning(rightPartitioningExpressions,
_), _, _), _), _) =>
+ assert(leftKeys !== smjExec.leftKeys)
+ assert(rightKeys !== smjExec.rightKeys)
+ assert(leftKeys === leftPartitioningExpressions)
+ assert(rightKeys === rightPartitioningExpressions)
+ case _ => fail(outputPlan.toString)
+ }
+ }
+
+ test("EnsureRequirements.reorder should handle PartitioningCollection") {
+ // PartitioningCollection on the left side of join.
+ val plan1 = DummySparkPlan(
+ outputPartitioning = PartitioningCollection(Seq(
+ HashPartitioning(exprA :: exprB :: Nil, 5),
+ HashPartitioning(exprA :: Nil, 5))))
+ val plan2 = DummySparkPlan()
+ val smjExec1 = SortMergeJoinExec(
+ exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2)
+ val outputPlan =
EnsureRequirements(spark.sessionState.conf).apply(smjExec1)
+ outputPlan match {
+ case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+ SortExec(_, _,
+ DummySparkPlan(_, _, PartitioningCollection(leftPartitionings), _,
_), _),
+ SortExec(_, _,
+ ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions,
_), _, _), _), _) =>
+ assert(leftKeys !== smjExec1.leftKeys)
+ assert(rightKeys !== smjExec1.rightKeys)
+ assert(leftKeys ===
leftPartitionings(0).asInstanceOf[HashPartitioning].expressions)
+ assert(rightKeys === rightPartitioningExpressions)
+ case _ => fail(outputPlan.toString)
+ }
+
+ // PartitioningCollection on the right side of join.
+ val smjExec2 = SortMergeJoinExec(
+ exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1)
+ val outputPlan2 =
EnsureRequirements(spark.sessionState.conf).apply(smjExec2)
+ outputPlan2 match {
+ case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+ SortExec(_, _,
+ ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions,
_), _, _), _),
+ SortExec(_, _,
+ DummySparkPlan(_, _, PartitioningCollection(rightPartitionings), _,
_), _), _) =>
+ assert(leftKeys !== smjExec2.leftKeys)
+ assert(rightKeys !== smjExec2.rightKeys)
+ assert(leftKeys === leftPartitioningExpressions)
+ assert(rightKeys ===
rightPartitionings(0).asInstanceOf[HashPartitioning].expressions)
+ case _ => fail(outputPlan2.toString)
+ }
+
+ // Both sides are PartitioningCollection and falls back to the right side.
+ val smjExec3 = SortMergeJoinExec(
+ exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1)
+ val outputPlan3 =
EnsureRequirements(spark.sessionState.conf).apply(smjExec2)
Review comment:
Thanks for the catch!
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]