maropu commented on a change in pull request #32210:
URL: https://github.com/apache/spark/pull/32210#discussion_r619912231



##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1394,4 +1394,32 @@ class JoinSuite extends QueryTest with 
SharedSparkSession with AdaptiveSparkPlan
       checkAnswer(fullJoinDF, Row(100))
     }
   }
+
+  test("SPARK-32634: Sort-based fallback for shuffled hash join") {
+    val df1 = spark.range(300).map(_.toString).select($"value".as("k1"))
+    val df2 = spark.range(100).map(_.toString).select($"value".as("k2"))
+
+    val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2")
+    assert(collect(smjDF.queryExecution.executedPlan) {
+      case _: SortMergeJoinExec => true }.size === 1)
+    val smjResult = smjDF.collect()
+
+    Seq(
+      // All tasks fall back
+      0,
+      // Some tasks fall back
+      10,
+      // No task falls back
+      1000
+    ).foreach(fallbackStartsAt =>
+      withSQLConf(SQLConf.SHUFFLEDHASHJOIN_FALLBACK_ENABLED.key -> "true",
+        "spark.sql.ShuffledHashJoin.testFallbackStartsAt" -> 
fallbackStartsAt.toString) {
+        val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
+        assert(collect(shjDF.queryExecution.executedPlan) {
+          case _: ShuffledHashJoinExec => true }.size === 1)
+        // Same result between shuffled hash join and sort merge join
+        checkAnswer(shjDF, smjResult)

Review comment:
       Is this a test for the non-codegen path?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val spillThreshold = getSpillThreshold
+    val inMemoryThreshold = getInMemoryThreshold
+    val streamSortPlan = getStreamSortPlan
+    val buildSortPlan = getBuildSortPlan
+    val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType, 
condition, left, right)
+
     streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, 
buildIter) =>
-      val hashed = buildHashedRelation(buildIter)
-      joinType match {
-        case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
-        case _ => join(streamIter, hashed, numOutputRows)
+      buildHashedRelation(buildIter) match {
+        case r: UnfinishedUnsafeHashedRelation =>

Review comment:
       How about adding a new SQL metric for the number of fallbacks and then checking it in the test?

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
##########
@@ -475,18 +501,89 @@ private[joins] object UnsafeHashedRelation {
           key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
           row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
         if (!success) {
-          binaryMap.free()
-          throw 
QueryExecutionErrors.cannotAcquireMemoryToBuildUnsafeHashedRelationError()
+          if (allowsFallbackWithNoMemory) {
+            return new UnfinishedUnsafeHashedRelation(numFields, binaryMap, 
row)
+          } else {
+            // Clean up map and throw exception
+            binaryMap.free()
+            throw 
QueryExecutionErrors.cannotAcquireMemoryToBuildUnsafeHashedRelationError()
+          }
         }
       } else if (isNullAware) {
         return HashedRelationWithAllNullKeys
       }
+      i += 1
     }
 
     new UnsafeHashedRelation(key.size, numFields, binaryMap)
   }
 }
 
+/**
+ * An unfinished version of [[UnsafeHashedRelation]].
+ * This is intended to be used in the sort-based fallback of [[ShuffledHashJoinExec]],
+ * when there is not enough memory to build [[UnsafeHashedRelation]].
+ *
+ * @param numFields Number of fields in each row.
+ * @param binaryMap Backed [[BytesToBytesMap]] to hold keys and rows.
+ * @param pendingRow The row which cannot be added to `binaryMap` due to 
memory limit.
+ */
+private[joins] class UnfinishedUnsafeHashedRelation(

Review comment:
       Needs tests for this new class.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to