maropu commented on a change in pull request #32476: URL: https://github.com/apache/spark/pull/32476#discussion_r628960873
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala ########## @@ -418,115 +443,140 @@ case class SortMergeJoinExec( // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - // Copy the left keys as class members so they could be used in next function call. - val matchedKeyVars = copyKeys(ctx, leftKeyVars) + // Copy the streamed keys as class members so they could be used in next function call. + val matchedKeyVars = copyKeys(ctx, streamedKeyVars) + + // Handle the case when streamed rows has any NULL keys. + val handleStreamedAnyNull = joinType match { + case _: InnerLike => + // Skip streamed row. + s""" + |$streamedRow = null; + |continue; + """.stripMargin + case LeftOuter | RightOuter => + // Eagerly return streamed row. + s""" + |if (!$matches.isEmpty()) { + | $matches.clear(); + |} + |return false; + """.stripMargin + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.genScanner should not take $x as the JoinType") + } - ctx.addNewFunction("findNextInnerJoinRows", + // Handle the case when streamed keys less than buffered keys. + val handleStreamedLessThanBuffered = joinType match { + case _: InnerLike => + // Skip streamed row. + s"$streamedRow = null;" + case LeftOuter | RightOuter => + // Eagerly return with streamed row. + "return false;" + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.genScanner should not take $x as the JoinType") + } + + ctx.addNewFunction("findNextJoinRows", s""" - |private boolean findNextInnerJoinRows( - | scala.collection.Iterator leftIter, - | scala.collection.Iterator rightIter) { - | $leftRow = null; + |private boolean findNextJoinRows( Review comment: In the outer case, is the return value unused? 
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala ########## @@ -554,67 +604,118 @@ case class SortMergeJoinExec( override def doProduce(ctx: CodegenContext): String = { // Inline mutable state since not many join operations in a task - val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput", + val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput", v => s"$v = inputs[0];", forceInline = true) - val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput", + val bufferedInput = ctx.addMutableState("scala.collection.Iterator", "bufferedInput", v => s"$v = inputs[1];", forceInline = true) - val (leftRow, matches) = genScanner(ctx) + val (streamedRow, matches) = genScanner(ctx) // Create variables for row from both sides. - val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow) - val rightRow = ctx.freshName("rightRow") - val rightVars = createRightVar(ctx, rightRow) + val (streamedVars, streamedVarDecl) = createStreamedVars(ctx, streamedRow) + val bufferedRow = ctx.freshName("bufferedRow") + val bufferedVars = genBuildSideVars(ctx, bufferedRow, bufferedPlan) val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") - val (beforeLoop, condCheck) = if (condition.isDefined) { + val resultVars = joinType match { + case _: InnerLike | LeftOuter => + streamedVars ++ bufferedVars + case RightOuter => + bufferedVars ++ streamedVars + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.doProduce should not take $x as the JoinType") + } + + val (beforeLoop, conditionCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. 
val loaded = ctx.freshName("loaded") - val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) - val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + val (streamedBefore, streamedAfter) = splitVarsByCondition(streamedOutput, streamedVars) + val (bufferedBefore, bufferedAfter) = splitVarsByCondition(bufferedOutput, bufferedVars) // Generate code for condition - ctx.currentVars = leftVars ++ rightVars + ctx.currentVars = resultVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop - val before = s""" + val before = + s""" |boolean $loaded = false; - |$leftBefore + |$streamedBefore """.stripMargin - val checking = s""" - |$rightBefore - |${cond.code} - |if (${cond.isNull} || !${cond.value}) continue; - |if (!$loaded) { - | $loaded = true; - | $leftAfter - |} - |$rightAfter - """.stripMargin + val checking = + s""" + |$bufferedBefore + |if ($bufferedRow != null) { + | ${cond.code} + | if (${cond.isNull} || !${cond.value}) { + | continue; + | } + |} + |if (!$loaded) { + | $loaded = true; + | $streamedAfter + |} + |$bufferedAfter + """.stripMargin (before, checking) } else { - (evaluateVariables(leftVars), "") + (evaluateVariables(streamedVars), "") } val thisPlan = ctx.addReferenceObj("plan", this) val eagerCleanup = s"$thisPlan.cleanupResources();" - s""" - |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | ${leftVarDecl.mkString("\n")} - | ${beforeLoop.trim} - | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); - | while ($iterator.hasNext()) { - | InternalRow $rightRow = (InternalRow) $iterator.next(); - | ${condCheck.trim} - | $numOutput.add(1); - | ${consume(ctx, leftVars ++ rightVars)} - | } - | if (shouldStop()) return; - |} - |$eagerCleanup + lazy val innerJoin = + s""" + |while (findNextJoinRows($streamedInput, $bufferedInput)) { + | ${streamedVarDecl.mkString("\n")} + | ${beforeLoop.trim} + | 
scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); + | while ($iterator.hasNext()) { + | InternalRow $bufferedRow = (InternalRow) $iterator.next(); + | ${conditionCheck.trim} + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | if (shouldStop()) return; + |} + |$eagerCleanup """.stripMargin + + lazy val outerJoin = { + val foundMatch = ctx.freshName("foundMatch") + val foundJoinRows = ctx.freshName("foundJoinRows") Review comment: `foundJoinRows` not used? ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala ########## @@ -418,115 +443,140 @@ case class SortMergeJoinExec( // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - // Copy the left keys as class members so they could be used in next function call. - val matchedKeyVars = copyKeys(ctx, leftKeyVars) + // Copy the streamed keys as class members so they could be used in next function call. + val matchedKeyVars = copyKeys(ctx, streamedKeyVars) + + // Handle the case when streamed rows has any NULL keys. + val handleStreamedAnyNull = joinType match { + case _: InnerLike => + // Skip streamed row. + s""" + |$streamedRow = null; + |continue; + """.stripMargin + case LeftOuter | RightOuter => + // Eagerly return streamed row. + s""" + |if (!$matches.isEmpty()) { + | $matches.clear(); + |} + |return false; + """.stripMargin + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.genScanner should not take $x as the JoinType") + } - ctx.addNewFunction("findNextInnerJoinRows", + // Handle the case when streamed keys less than buffered keys. + val handleStreamedLessThanBuffered = joinType match { + case _: InnerLike => + // Skip streamed row. + s"$streamedRow = null;" + case LeftOuter | RightOuter => + // Eagerly return with streamed row. 
+ "return false;" + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.genScanner should not take $x as the JoinType") + } + + ctx.addNewFunction("findNextJoinRows", s""" - |private boolean findNextInnerJoinRows( - | scala.collection.Iterator leftIter, - | scala.collection.Iterator rightIter) { - | $leftRow = null; + |private boolean findNextJoinRows( Review comment: It looks like reusing the inner-case code makes the outer-case code inefficient. For example, if there are too many matched duplicate rows in the buffered side, it seems we don't need to put all the rows in `matches`, right? ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala ########## @@ -418,115 +443,140 @@ case class SortMergeJoinExec( // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - // Copy the left keys as class members so they could be used in next function call. - val matchedKeyVars = copyKeys(ctx, leftKeyVars) + // Copy the streamed keys as class members so they could be used in next function call. + val matchedKeyVars = copyKeys(ctx, streamedKeyVars) + + // Handle the case when streamed rows has any NULL keys. + val handleStreamedAnyNull = joinType match { + case _: InnerLike => + // Skip streamed row. + s""" + |$streamedRow = null; + |continue; + """.stripMargin + case LeftOuter | RightOuter => + // Eagerly return streamed row. + s""" + |if (!$matches.isEmpty()) { + | $matches.clear(); + |} + |return false; + """.stripMargin + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.genScanner should not take $x as the JoinType") + } - ctx.addNewFunction("findNextInnerJoinRows", + // Handle the case when streamed keys less than buffered keys. + val handleStreamedLessThanBuffered = joinType match { + case _: InnerLike => + // Skip streamed row. 
+ s"$streamedRow = null;" + case LeftOuter | RightOuter => + // Eagerly return with streamed row. + "return false;" + case x => + throw new IllegalArgumentException( + s"SortMergeJoin.genScanner should not take $x as the JoinType") + } + + ctx.addNewFunction("findNextJoinRows", s""" - |private boolean findNextInnerJoinRows( - | scala.collection.Iterator leftIter, - | scala.collection.Iterator rightIter) { - | $leftRow = null; + |private boolean findNextJoinRows( + | scala.collection.Iterator streamedIter, + | scala.collection.Iterator bufferedIter) { + | $streamedRow = null; | int comp = 0; - | while ($leftRow == null) { - | if (!leftIter.hasNext()) return false; - | $leftRow = (InternalRow) leftIter.next(); - | ${leftKeyVars.map(_.code).mkString("\n")} - | if ($leftAnyNull) { - | $leftRow = null; - | continue; + | while ($streamedRow == null) { + | if (!streamedIter.hasNext()) return false; + | $streamedRow = (InternalRow) streamedIter.next(); + | ${streamedKeyVars.map(_.code).mkString("\n")} + | if ($streamedAnyNull) { + | $handleStreamedAnyNull | } | if (!$matches.isEmpty()) { - | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} + | ${genComparison(ctx, streamedKeyVars, matchedKeyVars)} | if (comp == 0) { | return true; | } | $matches.clear(); | } | | do { - | if ($rightRow == null) { - | if (!rightIter.hasNext()) { + | if ($bufferedRow == null) { + | if (!bufferedIter.hasNext()) { | ${matchedKeyVars.map(_.code).mkString("\n")} | return !$matches.isEmpty(); | } - | $rightRow = (InternalRow) rightIter.next(); - | ${rightKeyTmpVars.map(_.code).mkString("\n")} - | if ($rightAnyNull) { - | $rightRow = null; + | $bufferedRow = (InternalRow) bufferedIter.next(); + | ${bufferedKeyTmpVars.map(_.code).mkString("\n")} + | if ($bufferedAnyNull) { + | $bufferedRow = null; | continue; | } - | ${rightKeyVars.map(_.code).mkString("\n")} + | ${bufferedKeyVars.map(_.code).mkString("\n")} | } - | ${genComparison(ctx, leftKeyVars, rightKeyVars)} + | ${genComparison(ctx, 
streamedKeyVars, bufferedKeyVars)} | if (comp > 0) { - | $rightRow = null; + | $bufferedRow = null; | } else if (comp < 0) { | if (!$matches.isEmpty()) { | ${matchedKeyVars.map(_.code).mkString("\n")} | return true; + | } else { + | $handleStreamedLessThanBuffered | } - | $leftRow = null; | } else { - | $matches.add((UnsafeRow) $rightRow); - | $rightRow = null; + | $matches.add((UnsafeRow) $bufferedRow); + | $bufferedRow = null; | } - | } while ($leftRow != null); + | } while ($streamedRow != null); | } | return false; // unreachable Review comment: (This is not related to this PR though) In this case, could we throw an illegal state exception? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org