Repository: spark Updated Branches: refs/heads/master 874350905 -> 682eb4f2e
[SPARK-22042][SQL] ReorderJoinPredicates can break when child's partitioning is not decided ## What changes were proposed in this pull request? See the JIRA description for the bug: https://issues.apache.org/jira/browse/SPARK-22042 The fix in this PR: in `EnsureRequirements`, apply `ReorderJoinPredicates` over the input tree before running its core logic. Since the tree is transformed bottom-up, we can ensure that the children's partitioning has been decided before `ReorderJoinPredicates` is applied. In theory, this is guaranteed to cover all such cases while keeping the code simple. My one small reservation is cosmetic: this PR may look odd, since, to my knowledge, we don't call rules from other rules. I could have moved all the logic for `ReorderJoinPredicates` into `EnsureRequirements`, but that would make it a bit crowded. I am happy to discuss better options, if there are any. (Note: as committed, the logic was in fact moved into `EnsureRequirements` as a `reorderJoinPredicates` method, and the standalone `ReorderJoinPredicates` rule was deleted.) ## How was this patch tested? Added a new test case to `BucketedReadSuite`. Author: Tejas Patil <tej...@fb.com> Closes #19257 from tejasapatil/SPARK-22042_ReorderJoinPredicates. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/682eb4f2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/682eb4f2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/682eb4f2 Branch: refs/heads/master Commit: 682eb4f2ea152ce1043fbe689ea95318926b91b0 Parents: 8743509 Author: Tejas Patil <tej...@fb.com> Authored: Tue Dec 12 23:30:06 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Tue Dec 12 23:30:06 2017 -0800 ---------------------------------------------------------------------- .../spark/sql/execution/QueryExecution.scala | 2 - .../execution/exchange/EnsureRequirements.scala | 76 +++++++++++++++- .../execution/joins/ReorderJoinPredicates.scala | 94 -------------------- .../spark/sql/sources/BucketedReadSuite.scala | 31 +++++++ 4 files changed, 106 insertions(+), 97 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f404621..946475a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} -import org.apache.spark.sql.execution.joins.ReorderJoinPredicates import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _} 
import org.apache.spark.util.Utils @@ -104,7 +103,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { protected def preparations: Seq[Rule[SparkPlan]] = Seq( python.ExtractPythonUDFs, PlanSubqueries(sparkSession), - new ReorderJoinPredicates, EnsureRequirements(sparkSession.sessionState.conf), CollapseCodegenStages(sparkSession.sessionState.conf), ReuseExchange(sparkSession.sessionState.conf), http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4e2ca37..82f0b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.execution.exchange +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, + SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf /** @@ -248,6 +252,75 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { operator.withNewChildren(children) } + /** + * When the physical operators are created for JOIN, the ordering of join keys is based on order + * in which the join keys appear in the user query. That might not match with the output + * partitioning of the join node's children (thus leading to extra sort / shuffle being + * introduced). 
This rule will change the ordering of the join keys to match with the + * partitioning of the join nodes' children. + */ + def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { + def reorderJoinKeys( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + + def reorder(expectedOrderOfKeys: Seq[Expression], + currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val leftKeysBuffer = ArrayBuffer[Expression]() + val rightKeysBuffer = ArrayBuffer[Expression]() + + expectedOrderOfKeys.foreach(expression => { + val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + leftKeysBuffer.append(leftKeys(index)) + rightKeysBuffer.append(rightKeys(index)) + }) + (leftKeysBuffer, rightKeysBuffer) + } + + if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { + leftPartitioning match { + case HashPartitioning(leftExpressions, _) + if leftExpressions.length == leftKeys.length && + leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => + reorder(leftExpressions, leftKeys) + + case _ => rightPartitioning match { + case HashPartitioning(rightExpressions, _) + if rightExpressions.length == rightKeys.length && + rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => + reorder(rightExpressions, rightKeys) + + case _ => (leftKeys, rightKeys) + } + } + } else { + (leftKeys, rightKeys) + } + } + + plan.transformUp { + case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, + right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => + val 
(reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, + left, right) + + case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + } + } + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator @ ShuffleExchangeExec(partitioning, child, _) => child.children match { @@ -255,6 +328,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { if (childPartitioning.guarantees(partitioning)) child else operator case _ => operator } - case operator: SparkPlan => ensureDistributionAndOrdering(operator) + case operator: SparkPlan => + ensureDistributionAndOrdering(reorderJoinPredicates(operator)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala deleted file mode 100644 index 534d8c5..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.joins - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.SparkPlan - -/** - * When the physical operators are created for JOIN, the ordering of join keys is based on order - * in which the join keys appear in the user query. That might not match with the output - * partitioning of the join node's children (thus leading to extra sort / shuffle being - * introduced). This rule will change the ordering of the join keys to match with the - * partitioning of the join nodes' children. 
- */ -class ReorderJoinPredicates extends Rule[SparkPlan] { - private def reorderJoinKeys( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - leftPartitioning: Partitioning, - rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { - - def reorder( - expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - val leftKeysBuffer = ArrayBuffer[Expression]() - val rightKeysBuffer = ArrayBuffer[Expression]() - - expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) - leftKeysBuffer.append(leftKeys(index)) - rightKeysBuffer.append(rightKeys(index)) - }) - (leftKeysBuffer, rightKeysBuffer) - } - - if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - leftPartitioning match { - case HashPartitioning(leftExpressions, _) - if leftExpressions.length == leftKeys.length && - leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => - reorder(leftExpressions, leftKeys) - - case _ => rightPartitioning match { - case HashPartitioning(rightExpressions, _) - if rightExpressions.length == rightKeys.length && - rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => - reorder(rightExpressions, rightKeys) - - case _ => (leftKeys, rightKeys) - } - } - } else { - (leftKeys, rightKeys) - } - } - - def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - - case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, 
left.outputPartitioning, right.outputPartitioning) - ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, - left, right) - - case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => - val (reorderedLeftKeys, reorderedRightKeys) = - reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/682eb4f2/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index ab18905..9025859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -602,6 +602,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { ) } + test("SPARK-22042 ReorderJoinPredicates can break when child's partitioning is not decided") { + withTable("bucketed_table", "table1", "table2") { + df.write.format("parquet").saveAsTable("table1") + df.write.format("parquet").saveAsTable("table2") + df.write.format("parquet").bucketBy(8, "j", "k").saveAsTable("bucketed_table") + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") { + checkAnswer( + sql(""" + |SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k + |FROM ( + | SELECT a.i, a.j, a.k + | FROM bucketed_table a + | JOIN table1 b + | ON a.i = b.i + |) ab + |JOIN table2 c + |ON ab.i = c.i + |""".stripMargin), + sql(""" + |SELECT a.i, a.j, a.k, c.i, c.j, c.k + |FROM bucketed_table a + |JOIN table1 b + |ON a.i = b.i + |JOIN table2 c + |ON a.i = c.i + |""".stripMargin)) + } + } + } + test("error if there exists any malformed bucket 
files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org