Github user JoshRosen commented on a diff in the pull request:
https://github.com/apache/spark/pull/7904#discussion_r36582977
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
---
@@ -56,117 +53,274 @@ case class SortMergeJoin(
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys,
left.output)
@transient protected lazy val rightKeyGenerator =
newProjection(rightKeys, right.output)
+ protected[this] def isUnsafeMode: Boolean = {
+ (codegenEnabled && unsafeEnabled
+ && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
+
private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
// This must be ascending in order to agree with the `keyOrdering`
defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
}
protected override def doExecute(): RDD[InternalRow] = {
- val leftResults = left.execute().map(_.copy())
- val rightResults = right.execute().map(_.copy())
-
- leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
- new Iterator[InternalRow] {
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter)
=>
+ new RowIterator {
// An ordering that can be used to compare keys from both sides.
private[this] val keyOrdering =
newNaturalAscendingOrdering(leftKeys.map(_.dataType))
- // Mutable per row objects.
+ private[this] var currentLeftRow: InternalRow = _
+ private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
+ private[this] var currentMatchIdx: Int = -1
+ private[this] val smjScanner = new SortMergeJoinScanner(
+ leftKeyGenerator,
+ rightKeyGenerator,
+ keyOrdering,
+ RowIterator.fromScala(leftIter),
+ RowIterator.fromScala(rightIter)
+ )
private[this] val joinRow = new JoinedRow
- private[this] var leftElement: InternalRow = _
- private[this] var rightElement: InternalRow = _
- private[this] var leftKey: InternalRow = _
- private[this] var rightKey: InternalRow = _
- private[this] var rightMatches: CompactBuffer[InternalRow] = _
- private[this] var rightPosition: Int = -1
- private[this] var stop: Boolean = false
- private[this] var matchKey: InternalRow = _
-
- // initialize iterator
- initialize()
-
- override final def hasNext: Boolean = nextMatchingPair()
-
- override final def next(): InternalRow = {
- if (hasNext) {
- // we are using the buffered right rows and run down left
iterator
- val joinedRow = joinRow(leftElement,
rightMatches(rightPosition))
- rightPosition += 1
- if (rightPosition >= rightMatches.size) {
- rightPosition = 0
- fetchLeft()
- if (leftElement == null || keyOrdering.compare(leftKey,
matchKey) != 0) {
- stop = false
- rightMatches = null
- }
- }
- joinedRow
+ private[this] val resultProjection: (InternalRow) => InternalRow =
{
+ if (isUnsafeMode) {
+ UnsafeProjection.create(schema)
} else {
- // no more result
- throw new NoSuchElementException
+ identity[InternalRow]
}
}
- private def fetchLeft() = {
- if (leftIter.hasNext) {
- leftElement = leftIter.next()
- leftKey = leftKeyGenerator(leftElement)
- } else {
- leftElement = null
+ override def advanceNext(): Boolean = {
+ if (currentMatchIdx == -1 || currentMatchIdx ==
currentRightMatches.length) {
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ } else {
+ currentRightMatches = null
+ currentLeftRow = null
+ currentMatchIdx = -1
+ }
}
- }
-
- private def fetchRight() = {
- if (rightIter.hasNext) {
- rightElement = rightIter.next()
- rightKey = rightKeyGenerator(rightElement)
+ if (currentLeftRow != null) {
+ joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
+ currentMatchIdx += 1
+ true
} else {
- rightElement = null
+ false
}
}
- private def initialize() = {
- fetchLeft()
- fetchRight()
+ override def getRow: InternalRow = resultProjection(joinRow)
+ }.toScala
+ }
+ }
+}
+
+/**
+ * Helper class that is used to implement [[SortMergeJoin]] and
[[SortMergeOuterJoin]].
+ *
+ * The streamed input is the left side of a left outer join or the right
side of a right outer join.
+ *
+ * To perform an inner (outer) join, users of this class call
[[findNextInnerJoinRows()]]
+ * ([[findNextOuterJoinRows()]]), which returns `true` if a result has
been produced and `false`
+ * otherwise. If a result has been produced, then the caller may call
[[getStreamedRow]] to return
+ * the matching row from the streamed input and may call
[[getBufferedMatches]] to return the
+ * sequence of matching rows from the buffered input (in the case of an
outer join, this will return
+ * an empty sequence if there are no matches). For efficiency, both of these methods return
mutable objects which are
+ * re-used across calls to the `findNext*JoinRows()` methods.
+ *
+ * @param streamedKeyGenerator a projection that produces join keys from
the streamed input.
+ * @param bufferedKeyGenerator a projection that produces join keys from
the buffered input.
+ * @param keyOrdering an ordering which can be used to compare join keys.
+ * @param streamedIter an input whose rows will be streamed.
+ * @param bufferedIter an input whose rows will be buffered to construct
sequences of rows that
+ * have the same join key.
+ */
+private[joins] class SortMergeJoinScanner(
+ streamedKeyGenerator: Projection,
+ bufferedKeyGenerator: Projection,
+ keyOrdering: Ordering[InternalRow],
+ streamedIter: RowIterator,
+ bufferedIter: RowIterator) {
+ private[this] var streamedRow: InternalRow = _
+ private[this] var streamedRowKey: InternalRow = _
+ private[this] var bufferedRow: InternalRow = _
+ private[this] var bufferedRowKey: InternalRow = _
+ /**
+ * The join key for the rows buffered in `bufferedMatches`, or null if
`bufferedMatches` is empty
+ */
+ private[this] var matchJoinKey: InternalRow = _
+ /** Buffered rows from the buffered side of the join. This is empty if
there are no matches. */
+ private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new
ArrayBuffer[InternalRow]
+
+ // Initialization (note: do _not_ want to advance streamed here).
+ advancedBuffered()
+
+ // --- Public methods
---------------------------------------------------------------------------
+
+ def getStreamedRow: InternalRow = streamedRow
+
+ def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+
+ /**
+ * Advances both input iterators, stopping when we have found rows with
matching join keys.
+ * @return true if matching rows have been found and false otherwise. If
this returns true, then
+ * [[getStreamedRow]] and [[getBufferedMatches]] can be called
to construct the join
+ * results.
+ */
+ final def findNextInnerJoinRows(): Boolean = {
+ while (advancedStreamed() && streamedRowKey.anyNull) {
+ // Advance the streamed side of the join until we find the next row
whose join key contains
+ // no nulls or we hit the end of the streamed iterator.
+ }
+ if (streamedRow == null) {
+ // We have consumed the entire streamed iterator, so there can be no
more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey,
matchJoinKey) == 0) {
+ // The new streamed row has the same join key as the previous row,
so return the same matches.
+ true
+ } else if (bufferedRow == null) {
+ // The streamed row's join key does not match the current batch of
buffered rows and there are
+ // no more rows to read from the buffered iterator, so there can be
no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ // Advance both the streamed and buffered iterators to find the next
pair of matching rows.
+ var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
--- End diff --
It's used to break out of the loop.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]