cloud-fan commented on a change in pull request #34444:
URL: https://github.com/apache/spark/pull/34444#discussion_r740045221



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
##########
@@ -332,6 +327,266 @@ case class ShuffledHashJoinExec(
     HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
   }
 
+  override def doProduce(ctx: CodegenContext): String = {
+    // Specialize `doProduce` code for full outer join, because full outer 
join needs to
+    // iterate streamed and build side separately.
+    if (joinType != FullOuter) {
+      return super.doProduce(ctx)
+    }
+
+    val HashedRelationInfo(relationTerm, _, _) = prepareRelation(ctx)
+
+    // Inline mutable state since not many join operations in a task
+    val keyIsUnique = ctx.addMutableState("boolean", "keyIsUnique",
+      v => s"$v = $relationTerm.keyIsUnique();", forceInline = true)
+    val streamedInput = ctx.addMutableState("scala.collection.Iterator", 
"streamedInput",
+      v => s"$v = inputs[0];", forceInline = true)
+    val buildInput = ctx.addMutableState("scala.collection.Iterator", 
"buildInput",
+      v => s"$v = $relationTerm.valuesWithKeyIndex();", forceInline = true)
+    val streamedRow = ctx.addMutableState("InternalRow", "streamedRow", 
forceInline = true)
+    val buildRow = ctx.addMutableState("InternalRow", "buildRow", forceInline 
= true)
+
+    // Generate variables and related code from streamed side
+    val streamedVars = genOneSideJoinVars(ctx, streamedRow, streamedPlan, 
setDefaultValue = false)
+    val streamedKeyVariables = evaluateRequiredVariables(streamedOutput, 
streamedVars,
+      
AttributeSet.fromAttributeSets(HashJoin.rewriteKeyExpr(streamedKeys).map(_.references)))
+    ctx.currentVars = streamedVars
+    val streamedKeyExprCode = GenerateUnsafeProjection.createCode(ctx, 
streamedBoundKeys)
+    val streamedKeyEv =
+      s"""
+         |$streamedKeyVariables
+         |${streamedKeyExprCode.code}
+       """.stripMargin
+    val streamedKeyAnyNull = s"${streamedKeyExprCode.value}.anyNull()"
+
+    // Generate code for join condition
+    val (_, conditionCheck, _) =
+      getJoinCondition(ctx, streamedVars, streamedPlan, buildPlan, 
Some(buildRow))
+
+    // Generate code for result output in separate function, as we need to 
output result from
+    // multiple places in join code.
+    val streamedResultVars = genOneSideJoinVars(
+      ctx, streamedRow, streamedPlan, setDefaultValue = true)
+    val buildResultVars = genOneSideJoinVars(
+      ctx, buildRow, buildPlan, setDefaultValue = true)
+    val resultVars = buildSide match {
+      case BuildLeft => buildResultVars ++ streamedResultVars
+      case BuildRight => streamedResultVars ++ buildResultVars
+    }
+    val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
+    ctx.addNewFunction(consumeFullOuterJoinRow,
+      s"""
+         |private void $consumeFullOuterJoinRow() {
+         |  ${metricTerm(ctx, "numOutputRows")}.add(1);
+         |  ${consume(ctx, resultVars)}
+         |}
+       """.stripMargin)
+    val stopCheck = "if (shouldStop()) return;"

Review comment:
       This is so short. I think we can just inline `if (shouldStop()) return;` where it is needed, instead of creating a variable and passing it to the methods.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to