cloud-fan commented on code in PR #41884:
URL: https://github.com/apache/spark/pull/41884#discussion_r1277106792


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala:
##########
@@ -128,249 +122,27 @@ case class SortMergeJoinExec(
     val spillSize = longMetric("spillSize")
     val spillThreshold = getSpillThreshold
     val inMemoryThreshold = getInMemoryThreshold
-    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
-      val boundCondition: (InternalRow) => Boolean = {
-        condition.map { cond =>
-          Predicate.create(cond, left.output ++ right.output).eval _
-        }.getOrElse {
-          (r: InternalRow) => true
-        }
-      }
-
-      // An ordering that can be used to compare keys from both sides.
-      val keyOrdering = 
RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType))
-      val resultProj: InternalRow => InternalRow = 
UnsafeProjection.create(output, output)
-
-      joinType match {
-        case _: InnerLike =>
-          new RowIterator {
-            private[this] var currentLeftRow: InternalRow = _
-            private[this] var currentRightMatches: 
ExternalAppendOnlyUnsafeRowArray = _
-            private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
-            private[this] val smjScanner = new SortMergeJoinScanner(
-              createLeftKeyGenerator(),
-              createRightKeyGenerator(),
-              keyOrdering,
-              RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter),
-              inMemoryThreshold,
-              spillThreshold,
-              spillSize,
-              cleanupResources
-            )
-            private[this] val joinRow = new JoinedRow
-
-            if (smjScanner.findNextInnerJoinRows()) {
-              currentRightMatches = smjScanner.getBufferedMatches
-              currentLeftRow = smjScanner.getStreamedRow
-              rightMatchesIterator = currentRightMatches.generateIterator()
-            }
-
-            override def advanceNext(): Boolean = {
-              while (rightMatchesIterator != null) {
-                if (!rightMatchesIterator.hasNext) {
-                  if (smjScanner.findNextInnerJoinRows()) {
-                    currentRightMatches = smjScanner.getBufferedMatches
-                    currentLeftRow = smjScanner.getStreamedRow
-                    rightMatchesIterator = 
currentRightMatches.generateIterator()
-                  } else {
-                    currentRightMatches = null
-                    currentLeftRow = null
-                    rightMatchesIterator = null
-                    return false
-                  }
-                }
-                joinRow(currentLeftRow, rightMatchesIterator.next())
-                if (boundCondition(joinRow)) {
-                  numOutputRows += 1
-                  return true
-                }
-              }
-              false
-            }
-
-            override def getRow: InternalRow = resultProj(joinRow)
-          }.toScala
-
-        case LeftOuter =>
-          val smjScanner = new SortMergeJoinScanner(
-            streamedKeyGenerator = createLeftKeyGenerator(),
-            bufferedKeyGenerator = createRightKeyGenerator(),
-            keyOrdering,
-            streamedIter = RowIterator.fromScala(leftIter),
-            bufferedIter = RowIterator.fromScala(rightIter),
-            inMemoryThreshold,
-            spillThreshold,
-            spillSize,
-            cleanupResources
-          )
-          val rightNullRow = new GenericInternalRow(right.output.length)
-          new LeftOuterIterator(
-            smjScanner, rightNullRow, boundCondition, resultProj, 
numOutputRows).toScala
-
-        case RightOuter =>
-          val smjScanner = new SortMergeJoinScanner(
-            streamedKeyGenerator = createRightKeyGenerator(),
-            bufferedKeyGenerator = createLeftKeyGenerator(),
-            keyOrdering,
-            streamedIter = RowIterator.fromScala(rightIter),
-            bufferedIter = RowIterator.fromScala(leftIter),
-            inMemoryThreshold,
-            spillThreshold,
-            spillSize,
-            cleanupResources
-          )
-          val leftNullRow = new GenericInternalRow(left.output.length)
-          new RightOuterIterator(
-            smjScanner, leftNullRow, boundCondition, resultProj, 
numOutputRows).toScala
-
-        case FullOuter =>
-          val leftNullRow = new GenericInternalRow(left.output.length)
-          val rightNullRow = new GenericInternalRow(right.output.length)
-          val smjScanner = new SortMergeFullOuterJoinScanner(
-            leftKeyGenerator = createLeftKeyGenerator(),
-            rightKeyGenerator = createRightKeyGenerator(),
-            keyOrdering,
-            leftIter = RowIterator.fromScala(leftIter),
-            rightIter = RowIterator.fromScala(rightIter),
-            boundCondition,
-            leftNullRow,
-            rightNullRow)
-
-          new FullOuterIterator(
-            smjScanner,
-            resultProj,
-            numOutputRows).toScala
-
-        case LeftSemi =>
-          new RowIterator {
-            private[this] var currentLeftRow: InternalRow = _
-            private[this] val smjScanner = new SortMergeJoinScanner(
-              createLeftKeyGenerator(),
-              createRightKeyGenerator(),
-              keyOrdering,
-              RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter),
-              inMemoryThreshold,
-              spillThreshold,
-              spillSize,
-              cleanupResources,
-              onlyBufferFirstMatchedRow
-            )
-            private[this] val joinRow = new JoinedRow
-
-            override def advanceNext(): Boolean = {
-              while (smjScanner.findNextInnerJoinRows()) {
-                val currentRightMatches = smjScanner.getBufferedMatches
-                currentLeftRow = smjScanner.getStreamedRow
-                if (currentRightMatches != null && currentRightMatches.length 
> 0) {
-                  val rightMatchesIterator = 
currentRightMatches.generateIterator()
-                  while (rightMatchesIterator.hasNext) {
-                    joinRow(currentLeftRow, rightMatchesIterator.next())
-                    if (boundCondition(joinRow)) {
-                      numOutputRows += 1
-                      return true
-                    }
-                  }
-                }
-              }
-              false
-            }
-
-            override def getRow: InternalRow = currentLeftRow
-          }.toScala
-
-        case LeftAnti =>
-          new RowIterator {
-            private[this] var currentLeftRow: InternalRow = _
-            private[this] val smjScanner = new SortMergeJoinScanner(
-              createLeftKeyGenerator(),
-              createRightKeyGenerator(),
-              keyOrdering,
-              RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter),
-              inMemoryThreshold,
-              spillThreshold,
-              spillSize,
-              cleanupResources,
-              onlyBufferFirstMatchedRow
-            )
-            private[this] val joinRow = new JoinedRow
-
-            override def advanceNext(): Boolean = {
-              while (smjScanner.findNextOuterJoinRows()) {
-                currentLeftRow = smjScanner.getStreamedRow
-                val currentRightMatches = smjScanner.getBufferedMatches
-                if (currentRightMatches == null || currentRightMatches.length 
== 0) {
-                  numOutputRows += 1
-                  return true
-                }
-                var found = false
-                val rightMatchesIterator = 
currentRightMatches.generateIterator()
-                while (!found && rightMatchesIterator.hasNext) {
-                  joinRow(currentLeftRow, rightMatchesIterator.next())
-                  if (boundCondition(joinRow)) {
-                    found = true
-                  }
-                }
-                if (!found) {
-                  numOutputRows += 1
-                  return true
-                }
-              }
-              false
-            }
-
-            override def getRow: InternalRow = currentLeftRow
-          }.toScala
-
-        case j: ExistenceJoin =>
-          new RowIterator {
-            private[this] var currentLeftRow: InternalRow = _
-            private[this] val result: InternalRow = new 
GenericInternalRow(Array[Any](null))
-            private[this] val smjScanner = new SortMergeJoinScanner(
-              createLeftKeyGenerator(),
-              createRightKeyGenerator(),
-              keyOrdering,
-              RowIterator.fromScala(leftIter),
-              RowIterator.fromScala(rightIter),
-              inMemoryThreshold,
-              spillThreshold,
-              spillSize,
-              cleanupResources,
-              onlyBufferFirstMatchedRow
-            )
-            private[this] val joinRow = new JoinedRow
-
-            override def advanceNext(): Boolean = {
-              while (smjScanner.findNextOuterJoinRows()) {
-                currentLeftRow = smjScanner.getStreamedRow
-                val currentRightMatches = smjScanner.getBufferedMatches
-                var found = false
-                if (currentRightMatches != null && currentRightMatches.length 
> 0) {
-                  val rightMatchesIterator = 
currentRightMatches.generateIterator()
-                  while (!found && rightMatchesIterator.hasNext) {
-                    joinRow(currentLeftRow, rightMatchesIterator.next())
-                    if (boundCondition(joinRow)) {
-                      found = true
-                    }
-                  }
-                }
-                result.setBoolean(0, found)
-                numOutputRows += 1
-                return true
-              }
-              false
-            }
-
-            override def getRow: InternalRow = 
resultProj(joinRow(currentLeftRow, result))
-          }.toScala
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"SortMergeJoin should not take $x as the JoinType")
+    val evaluatorFactory = new SortMergeJoinEvaluatorFactory(
+      leftKeys,
+      rightKeys,
+      joinType,
+      condition,
+      left,
+      right,
+      output,
+      inMemoryThreshold,
+      spillThreshold,
+      numOutputRows,
+      spillSize,
+      onlyBufferFirstMatchedRow
+    )
+    if (conf.usePartitionEvaluator) {
+      left.execute().zipPartitionsWithEvaluator(right.execute(), 
evaluatorFactory)
+    } else {
+      left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+        val evaluator = evaluatorFactory.createEvaluator()
+        evaluator.eval(0, leftIter, rightIter)

Review Comment:
   Never mind, it's different from the index. Maybe we should just leave a note here about 
why we always use `0`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to