Repository: spark
Updated Branches:
  refs/heads/master 94439997d -> 7f16c6910


[SPARK-19122][SQL] Unnecessary shuffle+sort added if join predicate ordering differs from bucketing and sorting order

## What changes were proposed in this pull request?

Jira : https://issues.apache.org/jira/browse/SPARK-19122

`leftKeys` and `rightKeys` in `SortMergeJoinExec` are reordered to match the 
ordering of the join keys in the child's `outputPartitioning`. This is done 
every time `requiredChildDistribution` is invoked during query planning.
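
For illustration, here is a minimal sketch of the pattern this addresses (the `t1`/`t2` tables, the `df1`/`df2` DataFrames, and the `i`/`j` columns are hypothetical, not part of the patch): both sides are bucketed and sorted by `(i, j)`, but the query lists the join keys as `(j, i)`. Without the reordering, that mismatch made `EnsureRequirements` insert a shuffle and a sort on each side even though the existing bucketing already satisfies the join.

```scala
// Hypothetical setup: both tables bucketed and sorted by (i, j).
df1.write.format("parquet").bucketBy(8, "i", "j").sortBy("i", "j").saveAsTable("t1")
df2.write.format("parquet").bucketBy(8, "i", "j").sortBy("i", "j").saveAsTable("t2")

val t1 = spark.table("t1")
val t2 = spark.table("t2")

// Join keys written as (j, i), i.e. not in bucketing order. With this patch the
// keys are reordered to (i, j) to match the children's output partitioning, so
// no extra exchange or sort is introduced.
val joined = t1.join(t2, t1("j") === t2("j") && t1("i") === t2("i"))
```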

## How was this patch tested?

- Added new test case
- Existing tests

Author: Tejas Patil <tej...@fb.com>

Closes #16985 from tejasapatil/SPARK-19122_join_order_shuffle.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7f16c691
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7f16c691
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7f16c691

Branch: refs/heads/master
Commit: 7f16c6910700ab90fe8e382b4a99022f67696317
Parents: 9443999
Author: Tejas Patil <tej...@fb.com>
Authored: Fri Aug 11 15:13:42 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Fri Aug 11 15:13:42 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/QueryExecution.scala    |  2 +
 .../execution/joins/ReorderJoinPredicates.scala | 94 ++++++++++++++++++++
 .../spark/sql/sources/BucketedReadSuite.scala   | 59 ++++++++++++
 3 files changed, 155 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7f16c691/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b56fbd4..4accf54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
+import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
 import org.apache.spark.util.Utils
 
@@ -103,6 +104,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
+    new ReorderJoinPredicates,
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),

http://git-wip-us.apache.org/repos/asf/spark/blob/7f16c691/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
new file mode 100644
index 0000000..534d8c5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * When the physical operators are created for a join, the ordering of the join keys is based
+ * on the order in which they appear in the user query. That might not match the output
+ * partitioning of the join node's children (thus leading to an extra sort / shuffle being
+ * introduced). This rule changes the ordering of the join keys to match the partitioning of
+ * the join nodes' children.
+ */
+class ReorderJoinPredicates extends Rule[SparkPlan] {
+  private def reorderJoinKeys(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      leftPartitioning: Partitioning,
+      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+    def reorder(
+        expectedOrderOfKeys: Seq[Expression],
+        currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+      val leftKeysBuffer = ArrayBuffer[Expression]()
+      val rightKeysBuffer = ArrayBuffer[Expression]()
+
+      expectedOrderOfKeys.foreach(expression => {
+        val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+        leftKeysBuffer.append(leftKeys(index))
+        rightKeysBuffer.append(rightKeys(index))
+      })
+      (leftKeysBuffer, rightKeysBuffer)
+    }
+
+    if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+      leftPartitioning match {
+        case HashPartitioning(leftExpressions, _)
+          if leftExpressions.length == leftKeys.length &&
+            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+          reorder(leftExpressions, leftKeys)
+
+        case _ => rightPartitioning match {
+          case HashPartitioning(rightExpressions, _)
+            if rightExpressions.length == rightKeys.length &&
+              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+            reorder(rightExpressions, rightKeys)
+
+          case _ => (leftKeys, rightKeys)
+        }
+      }
+    } else {
+      (leftKeys, rightKeys)
+    }
+  }
+
+  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+        left, right)
+
+    case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+        left, right)
+
+    case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+      val (reorderedLeftKeys, reorderedRightKeys) =
+        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+      SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+  }
+}
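
As a side note, the matching above finds, for each expression in the child's `HashPartitioning`, its position among the current join keys via `semanticEquals`, and emits the left/right key pairs in that order. Below is a self-contained sketch of just that step, simplified to plain strings in place of Catalyst expressions (the simplification and names are ours, not part of the patch):

```scala
// Sketch of the reordering step in isolation. `expectedOrder` stands in for the
// HashPartitioning expressions; `leftKeys`/`rightKeys` are the join keys in the
// order they appear in the query. The real rule matches via semanticEquals;
// plain equality is used here for brevity.
def reorderSketch[K](
    leftKeys: Seq[K],
    rightKeys: Seq[K],
    expectedOrder: Seq[K]): (Seq[K], Seq[K]) = {
  expectedOrder.map { key =>
    val idx = leftKeys.indexOf(key)  // position of this key in query order
    (leftKeys(idx), rightKeys(idx))
  }.unzip
}

// The query wrote the predicates as (j, i); the table is bucketed by (i, j).
val (l, r) = reorderSketch(Seq("j", "i"), Seq("j2", "i2"), Seq("i", "j"))
assert(l == Seq("i", "j") && r == Seq("i2", "j2"))
```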

http://git-wip-us.apache.org/repos/asf/spark/blob/7f16c691/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index ba0ca66..eb9e645 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -543,6 +543,65 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     )
   }
 
+  test("SPARK-19122 Re-order join predicates if they match with the child's 
output partitioning") {
+    val bucketedTableTestSpec = BucketedTableTestSpec(
+      Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))),
+      numPartitions = 1,
+      expectedShuffle = false,
+      expectedSort = false)
+
+    // If the set of join columns equals the set of bucketed + sort columns, then the order of
+    // the join keys in the query should not matter, and no shuffle or sort should be added to
+    // the query plan
+    Seq(
+      Seq("i", "j", "k"),
+      Seq("i", "k", "j"),
+      Seq("j", "k", "i"),
+      Seq("j", "i", "k"),
+      Seq("k", "j", "i"),
+      Seq("k", "i", "j")
+    ).foreach(joinKeys => {
+      testBucketing(
+        bucketedTableTestSpecLeft = bucketedTableTestSpec,
+        bucketedTableTestSpecRight = bucketedTableTestSpec,
+        joinCondition = joinCondition(joinKeys)
+      )
+    })
+  }
+
+  test("SPARK-19122 No re-ordering should happen if set of join columns != set 
of child's " +
+    "partitioning columns") {
+
+    // join predicates are a superset of the child's partitioning columns
+    val bucketedTableTestSpec1 =
+      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
+    testBucketing(
+      bucketedTableTestSpecLeft = bucketedTableTestSpec1,
+      bucketedTableTestSpecRight = bucketedTableTestSpec1,
+      joinCondition = joinCondition(Seq("i", "j", "k"))
+    )
+
+    // the child's partitioning columns are a superset of the join predicates
+    val bucketedTableTestSpec2 =
+      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j", "k"), Seq("i", "j", "k"))),
+        numPartitions = 1)
+    testBucketing(
+      bucketedTableTestSpecLeft = bucketedTableTestSpec2,
+      bucketedTableTestSpecRight = bucketedTableTestSpec2,
+      joinCondition = joinCondition(Seq("i", "j"))
+    )
+
+    // set of the child's partitioning columns != set of join predicates (even though the two
+    // sets have the same size)
+    val bucketedTableTestSpec3 =
+      BucketedTableTestSpec(Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), numPartitions = 1)
+    testBucketing(
+      bucketedTableTestSpecLeft = bucketedTableTestSpec3,
+      bucketedTableTestSpecRight = bucketedTableTestSpec3,
+      joinCondition = joinCondition(Seq("j", "k"))
+    )
+  }
+
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

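For readers who want to verify the effect by hand, one option is to inspect the executed plan for exchange and sort operators (a sketch that reuses the hypothetical `joined` Dataset from the example at the top; class names are as of this commit):

```scala
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.exchange.ShuffleExchange

// With both sides bucketed and sorted on the (reordered) join keys, the plan
// should contain no shuffle and no extra sort below the sort-merge join.
val plan = joined.queryExecution.executedPlan
assert(plan.collect { case e: ShuffleExchange => e }.isEmpty)
assert(plan.collect { case s: SortExec => s }.isEmpty)
```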
