maropu commented on a change in pull request #32210:
URL: https://github.com/apache/spark/pull/32210#discussion_r619912231



##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1394,4 +1394,32 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
       checkAnswer(fullJoinDF, Row(100))
     }
   }
+
+  test("SPARK-32634: Sort-based fallback for shuffled hash join") {
+    val df1 = spark.range(300).map(_.toString).select($"value".as("k1"))
+    val df2 = spark.range(100).map(_.toString).select($"value".as("k2"))
+
+    val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2")
+    assert(collect(smjDF.queryExecution.executedPlan) {
+      case _: SortMergeJoinExec => true }.size === 1)
+    val smjResult = smjDF.collect()
+
+    Seq(
+      // All tasks fall back
+      0,
+      // Some tasks fall back
+      10,
+      // No task falls back
+      1000
+    ).foreach(fallbackStartsAt =>
+      withSQLConf(SQLConf.SHUFFLEDHASHJOIN_FALLBACK_ENABLED.key -> "true",
+        "spark.sql.ShuffledHashJoin.testFallbackStartsAt" -> 
fallbackStartsAt.toString) {
+        val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
+        assert(collect(shjDF.queryExecution.executedPlan) {
+          case _: ShuffledHashJoinExec => true }.size === 1)
+        // Same result between shuffled hash join and sort merge join
+        checkAnswer(shjDF, smjResult)

Review comment:
       Is this a test for the non-codegen path?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>

Review comment:
       How about adding a new SQL metric for #fallbacks and then checking it in 
the test?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -475,18 +501,89 @@ private[joins] object UnsafeHashedRelation {
           key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
           row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
         if (!success) {
-          binaryMap.free()
-          throw 
QueryExecutionErrors.cannotAcquireMemoryToBuildUnsafeHashedRelationError()
+          if (allowsFallbackWithNoMemory) {
+            return new UnfinishedUnsafeHashedRelation(numFields, binaryMap, 
row)
+          } else {
+            // Clean up map and throw exception
+            binaryMap.free()
+            throw 
QueryExecutionErrors.cannotAcquireMemoryToBuildUnsafeHashedRelationError()
+          }
         }
       } else if (isNullAware) {
         return HashedRelationWithAllNullKeys
       }
+      i += 1
     }
 
     new UnsafeHashedRelation(key.size, numFields, binaryMap)
   }
 }
 
+/**
+ * An unfinished version of [[UnsafeHashedRelation]].
+ * This is intended to use in sort-based fallback of [[ShuffledHashJoinExec]],
+ * when there is no enough memory to build [[UnsafeHashedRelation]].
+ *
+ * @param numFields Number of fields in each row.
+ * @param binaryMap Backed [[BytesToBytesMap]] to hold keys and rows.
+ * @param pendingRow The row which cannot be added to `binaryMap` due to 
memory limit.
+ */
+private[joins] class UnfinishedUnsafeHashedRelation(

Review comment:
       Needs tests for this new class.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>
+          joinWithSortFallback(streamIter, buildIter, r.destructiveValues(), 
streamSortPlan,

Review comment:
       @c21 Has this fallback logic already been deployed in your production? I 
just want to know whether it works well for real workloads (it can cause 
high performance penalties if the fallback happens only in a single task).

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>
+          joinWithSortFallback(streamIter, buildIter, r.destructiveValues(), 
streamSortPlan,

Review comment:
       For example, we have three tasks for a shuffle hash join. If two 
tasks have no fallback and the last one has the fallback, the running time of 
the fallback one (building time of a hash table and sorting time of left/right 
input rows) can be much longer than the other two non-fallback ones? Or 
did I misunderstand the current logic?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>
+          joinWithSortFallback(streamIter, buildIter, r.destructiveValues(), 
streamSortPlan,

Review comment:
       > @maropu - yes, more data is in #32210 (comment) .
   
   Ah, thanks. I missed that comment.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>
+          joinWithSortFallback(streamIter, buildIter, r.destructiveValues(), 
streamSortPlan,

Review comment:
       > @maropu - yes, more data is in #32210 (comment) .
   
   Ah, thanks. I missed that comment. It looks nice.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>
+          joinWithSortFallback(streamIter, buildIter, r.destructiveValues(), 
streamSortPlan,

Review comment:
       > For runtime, yes. The total query run-time is dominated by the last 
finished task runtime. Just to point it out in case, without this change, this 
would be task and query failure.
   
   Yea, I basically agree to make the shuffle hash-join more robust (since a 
user may use inappropriate join hints in some cases). What I'm interested 
in is whether there is any faster fallback logic than the current approach.

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1394,4 +1394,32 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
       checkAnswer(fullJoinDF, Row(100))
     }
   }
+
+  test("SPARK-32634: Sort-based fallback for shuffled hash join") {
+    val df1 = spark.range(300).map(_.toString).select($"value".as("k1"))
+    val df2 = spark.range(100).map(_.toString).select($"value".as("k2"))
+
+    val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2")
+    assert(collect(smjDF.queryExecution.executedPlan) {
+      case _: SortMergeJoinExec => true }.size === 1)
+    val smjResult = smjDF.collect()
+
+    Seq(
+      // All tasks fall back
+      0,
+      // Some tasks fall back
+      10,
+      // No task falls back
+      1000
+    ).foreach(fallbackStartsAt =>
+      withSQLConf(SQLConf.SHUFFLEDHASHJOIN_FALLBACK_ENABLED.key -> "true",
+        "spark.sql.ShuffledHashJoin.testFallbackStartsAt" -> 
fallbackStartsAt.toString) {
+        val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
+        assert(collect(shjDF.queryExecution.executedPlan) {
+          case _: ShuffledHashJoinExec => true }.size === 1)
+        // Same result between shuffled hash join and sort merge join
+        checkAnswer(shjDF, smjResult)

Review comment:
       Oh, I see.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to