c21 commented on a change in pull request #32476: URL: https://github.com/apache/spark/pull/32476#discussion_r628976694
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala ########## @@ -554,67 +604,118 @@ case class SortMergeJoinExec( override def doProduce(ctx: CodegenContext): String = { // Inline mutable state since not many join operations in a task - val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", + val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", forceInline = true) - val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", + val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput", v => s"$v = inputs[1];", forceInline = true) - val (leftRow, matches) = genScanner(ctx) + val (streamedRow, matches) = genScanner(ctx) // Create variables for row from both sides. - val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow) - val rightRow = ctx.freshName("rightRow") - val rightVars = createRightVar(ctx, rightRow) + val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow) + val bufferedRow = ctx.freshName("bufferedRow") + val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan) val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") - val (beforeLoop, condCheck) = if (condition.isDefined) { + val resultVars = joinType match { + case _: InnerLike | LeftOuter => + streamedVars ++ bufferedVars + case RightOuter => + bufferedVars ++ streamedVars + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.doProduce should not take $x as the JoinType") + } + + val (beforeLoop, conditionCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. 
val loaded = ctx.freshName("loaded") - val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) - val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars) + val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars) // Generate code for condition - ctx.currentVars = leftVars ++ rightVars + ctx.currentVars = resultVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop - val before = s""" + val before = + s""" |boolean $loaded = false; - |$leftBefore + |$streamedBefore """.stripMargin - val checking = s""" - |$rightBefore - |${cond.code} - |if (${cond.isNull} || !${cond.value}) continue; - |if (!$loaded) { - | $loaded = true; - | $leftAfter - |} - |$rightAfter - """.stripMargin + val checking = + s""" + |$bufferedBefore + |if ($bufferedRow != null) { + | ${cond.code} + | if (${cond.isNull} || !${cond.value}) { + | continue; + | } + |} + |if (!$loaded) { + | $loaded = true; + | $streamedAfter + |} + |$bufferedAfter + """.stripMargin (before, checking) } else { - (evaluateVariables(leftVars), "") + (evaluateVariables(streamedVars), "") } val thisPlan = ctx.addReferenceObj("plan", this) val eagerCleanup = s"$thisPlan.cleanupResources();" - s""" - |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | ${leftVarDecl.mkString("\n")} - | ${beforeLoop.trim} - | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); - | while ($iterator.hasNext()) { - | InternalRow $rightRow = (InternalRow) $iterator.next(); - | ${condCheck.trim} - | $numOutput.add(1); - | ${consume(ctx, leftVars ++ rightVars)} - | } - | if (shouldStop()) return; - |} - |$eagerCleanup + lazy val innerJoin = + s""" + |while (findNextJoinRows($streamedInput, $bufferedInput)) { + | ${streamedVarDecl.mkString("\n")} + | ${beforeLoop.trim} + | 
scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); + | while ($iterator.hasNext()) { + | InternalRow $bufferedRow = (InternalRow) $iterator.next(); + | ${conditionCheck.trim} + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | if (shouldStop()) return; + |} + |$eagerCleanup """.stripMargin + + lazy val outerJoin = { + val foundMatch = ctx.freshName("foundMatch") + val foundJoinRows = ctx.freshName("foundJoinRows") Review comment: My bad, I forgot to remove it during code iterations. Will remove it. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org