c21 commented on a change in pull request #32210:
URL: https://github.com/apache/spark/pull/32210#discussion_r619976792
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
+ val spillThreshold = getSpillThreshold
+ val inMemoryThreshold = getInMemoryThreshold
+ val streamSortPlan = getStreamSortPlan
+ val buildSortPlan = getBuildSortPlan
+ val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType,
condition, left, right)
+
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter,
buildIter) =>
- val hashed = buildHashedRelation(buildIter)
- joinType match {
- case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
- case _ => join(streamIter, hashed, numOutputRows)
+ buildHashedRelation(buildIter) match {
+ case r: UnfinishedUnsafeHashedRelation =>
+ joinWithSortFallback(streamIter, buildIter, r.destructiveValues(),
streamSortPlan,
Review comment:
@maropu - yes, more data is available in
https://github.com/apache/spark/pull/32210#issuecomment-823503243 .
> This can cause high performance penalties if the fallback happens only in
a single task
I'm wondering why that is the case.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
+ val spillThreshold = getSpillThreshold
+ val inMemoryThreshold = getInMemoryThreshold
+ val streamSortPlan = getStreamSortPlan
+ val buildSortPlan = getBuildSortPlan
+ val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType,
condition, left, right)
+
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter,
buildIter) =>
- val hashed = buildHashedRelation(buildIter)
- joinType match {
- case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
- case _ => join(streamIter, hashed, numOutputRows)
+ buildHashedRelation(buildIter) match {
+ case r: UnfinishedUnsafeHashedRelation =>
+ joinWithSortFallback(streamIter, buildIter, r.destructiveValues(),
streamSortPlan,
Review comment:
For runtime, yes: the total query runtime is dominated by the runtime of the
last task to finish. Just to point out, without this change, this case would
result in task and query failure.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -81,11 +83,22 @@ case class ShuffledHashJoinExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
+ val spillThreshold = getSpillThreshold
+ val inMemoryThreshold = getInMemoryThreshold
+ val streamSortPlan = getStreamSortPlan
+ val buildSortPlan = getBuildSortPlan
+ val fallbackSMJPlan = SortMergeJoinExec(leftKeys, rightKeys, joinType,
condition, left, right)
+
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter,
buildIter) =>
- val hashed = buildHashedRelation(buildIter)
- joinType match {
- case FullOuter => fullOuterJoin(streamIter, hashed, numOutputRows)
- case _ => join(streamIter, hashed, numOutputRows)
+ buildHashedRelation(buildIter) match {
+ case r: UnfinishedUnsafeHashedRelation =>
+ joinWithSortFallback(streamIter, buildIter, r.destructiveValues(),
streamSortPlan,
Review comment:
> is there any other faster ballback logic than the current approach.
@maropu - yes, I am open to and welcome any suggestions. Thanks to @sigmod and
you, @maropu, for the additional insights. I will take a deeper look at hybrid
join in the documentation you provided and get back to you later.
##########
File path: sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
##########
@@ -1394,4 +1394,32 @@ class JoinSuite extends QueryTest with
SharedSparkSession with AdaptiveSparkPlan
checkAnswer(fullJoinDF, Row(100))
}
}
+
+ test("SPARK-32634: Sort-based fallback for shuffled hash join") {
+ val df1 = spark.range(300).map(_.toString).select($"value".as("k1"))
+ val df2 = spark.range(100).map(_.toString).select($"value".as("k2"))
+
+ val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2")
+ assert(collect(smjDF.queryExecution.executedPlan) {
+ case _: SortMergeJoinExec => true }.size === 1)
+ val smjResult = smjDF.collect()
+
+ Seq(
+ // All tasks fall back
+ 0,
+ // Some tasks fall back
+ 10,
+ // No task falls back
+ 1000
+ ).foreach(fallbackStartsAt =>
+ withSQLConf(SQLConf.SHUFFLEDHASHJOIN_FALLBACK_ENABLED.key -> "true",
+ "spark.sql.ShuffledHashJoin.testFallbackStartsAt" ->
fallbackStartsAt.toString) {
+ val shjDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
+ assert(collect(shjDF.queryExecution.executedPlan) {
+ case _: ShuffledHashJoinExec => true }.size === 1)
+ // Same result between shuffled hash join and sort merge join
+ checkAnswer(shjDF, smjResult)
Review comment:
Yes, it is. In this PR, I explicitly disable codegen (in
`ShuffledHashJoinExec.supportCodegen`) when
`SQLConf.SHUFFLEDHASHJOIN_FALLBACK_ENABLED` is set to true.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]