Github user JoshRosen commented on a diff in the pull request:
https://github.com/apache/spark/pull/7904#discussion_r36585292
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
---
@@ -56,117 +53,274 @@ case class SortMergeJoin(
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys,
left.output)
@transient protected lazy val rightKeyGenerator =
newProjection(rightKeys, right.output)
+ protected[this] def isUnsafeMode: Boolean = {
+ (codegenEnabled && unsafeEnabled
+ && UnsafeProjection.canSupport(leftKeys)
+ && UnsafeProjection.canSupport(schema))
+ }
+
+ override def outputsUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessUnsafeRows: Boolean = isUnsafeMode
+ override def canProcessSafeRows: Boolean = !isUnsafeMode
+
private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
// This must be ascending in order to agree with the `keyOrdering`
defined in `doExecute()`.
keys.map(SortOrder(_, Ascending))
}
protected override def doExecute(): RDD[InternalRow] = {
- val leftResults = left.execute().map(_.copy())
- val rightResults = right.execute().map(_.copy())
-
- leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
- new Iterator[InternalRow] {
+ left.execute().zipPartitions(right.execute()) { (leftIter, rightIter)
=>
+ new RowIterator {
// An ordering that can be used to compare keys from both sides.
private[this] val keyOrdering =
newNaturalAscendingOrdering(leftKeys.map(_.dataType))
- // Mutable per row objects.
+ private[this] var currentLeftRow: InternalRow = _
+ private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
+ private[this] var currentMatchIdx: Int = -1
+ private[this] val smjScanner = new SortMergeJoinScanner(
+ leftKeyGenerator,
+ rightKeyGenerator,
+ keyOrdering,
+ RowIterator.fromScala(leftIter),
+ RowIterator.fromScala(rightIter)
+ )
private[this] val joinRow = new JoinedRow
- private[this] var leftElement: InternalRow = _
- private[this] var rightElement: InternalRow = _
- private[this] var leftKey: InternalRow = _
- private[this] var rightKey: InternalRow = _
- private[this] var rightMatches: CompactBuffer[InternalRow] = _
- private[this] var rightPosition: Int = -1
- private[this] var stop: Boolean = false
- private[this] var matchKey: InternalRow = _
-
- // initialize iterator
- initialize()
-
- override final def hasNext: Boolean = nextMatchingPair()
-
- override final def next(): InternalRow = {
- if (hasNext) {
- // we are using the buffered right rows and run down left
iterator
- val joinedRow = joinRow(leftElement,
rightMatches(rightPosition))
- rightPosition += 1
- if (rightPosition >= rightMatches.size) {
- rightPosition = 0
- fetchLeft()
- if (leftElement == null || keyOrdering.compare(leftKey,
matchKey) != 0) {
- stop = false
- rightMatches = null
- }
- }
- joinedRow
+ private[this] val resultProjection: (InternalRow) => InternalRow =
{
+ if (isUnsafeMode) {
+ UnsafeProjection.create(schema)
} else {
- // no more result
- throw new NoSuchElementException
+ identity[InternalRow]
}
}
- private def fetchLeft() = {
- if (leftIter.hasNext) {
- leftElement = leftIter.next()
- leftKey = leftKeyGenerator(leftElement)
- } else {
- leftElement = null
+ override def advanceNext(): Boolean = {
+ if (currentMatchIdx == -1 || currentMatchIdx ==
currentRightMatches.length) {
+ if (smjScanner.findNextInnerJoinRows()) {
+ currentRightMatches = smjScanner.getBufferedMatches
+ currentLeftRow = smjScanner.getStreamedRow
+ currentMatchIdx = 0
+ } else {
+ currentRightMatches = null
+ currentLeftRow = null
+ currentMatchIdx = -1
+ }
}
- }
-
- private def fetchRight() = {
- if (rightIter.hasNext) {
- rightElement = rightIter.next()
- rightKey = rightKeyGenerator(rightElement)
+ if (currentLeftRow != null) {
+ joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
+ currentMatchIdx += 1
+ true
} else {
- rightElement = null
+ false
}
}
- private def initialize() = {
- fetchLeft()
- fetchRight()
+ override def getRow: InternalRow = resultProjection(joinRow)
+ }.toScala
+ }
+ }
+}
+
+/**
+ * Helper class that is used to implement [[SortMergeJoin]] and
[[SortMergeOuterJoin]].
+ *
+ * The streamed input is the left side of a left outer join or the right
side of a right outer join.
+ *
+ * To perform an inner (outer) join, users of this class call
[[findNextInnerJoinRows()]]
+ * ([[findNextOuterJoinRows()]]), which returns `true` if a result has
been produced and `false`
+ * otherwise. If a result has been produced, then the caller may call
[[getStreamedRow]] to return
+ * the matching row from the streamed input and may call
[[getBufferedMatches]] to return the
+ * sequence of matching rows from the buffered input (in the case of an
outer join, this will return
+ * an empty sequence if there are no matches). For efficiency, both of these methods return
mutable objects which are
+ * re-used across calls to the `findNext*JoinRows()` methods.
+ *
+ * @param streamedKeyGenerator a projection that produces join keys from
the streamed input.
+ * @param bufferedKeyGenerator a projection that produces join keys from
the buffered input.
+ * @param keyOrdering an ordering which can be used to compare join keys.
+ * @param streamedIter an input whose rows will be streamed.
+ * @param bufferedIter an input whose rows will be buffered to construct
sequences of rows that
+ * have the same join key.
+ */
+private[joins] class SortMergeJoinScanner(
+ streamedKeyGenerator: Projection,
+ bufferedKeyGenerator: Projection,
+ keyOrdering: Ordering[InternalRow],
+ streamedIter: RowIterator,
+ bufferedIter: RowIterator) {
+ private[this] var streamedRow: InternalRow = _
+ private[this] var streamedRowKey: InternalRow = _
+ private[this] var bufferedRow: InternalRow = _
+ private[this] var bufferedRowKey: InternalRow = _
+ /**
+ * The join key for the rows buffered in `bufferedMatches`, or null if
`bufferedMatches` is empty
+ */
+ private[this] var matchJoinKey: InternalRow = _
+ /** Buffered rows from the buffered side of the join. This is empty if
there are no matches. */
+ private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new
ArrayBuffer[InternalRow]
+
+ // Initialization (note: do _not_ want to advance streamed here).
+ advancedBuffered()
+
+ // --- Public methods
---------------------------------------------------------------------------
+
+ def getStreamedRow: InternalRow = streamedRow
+
+ def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+
+ /**
+ * Advances both input iterators, stopping when we have found rows with
matching join keys.
+ * @return true if matching rows have been found and false otherwise. If
this returns true, then
+ * [[getStreamedRow]] and [[getBufferedMatches]] can be called
to construct the join
+ * results.
+ */
+ final def findNextInnerJoinRows(): Boolean = {
+ while (advancedStreamed() && streamedRowKey.anyNull) {
+ // Advance the streamed side of the join until we find the next row
whose join key contains
+ // no nulls or we hit the end of the streamed iterator.
+ }
+ if (streamedRow == null) {
+ // We have consumed the entire streamed iterator, so there can be no
more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey,
matchJoinKey) == 0) {
+ // The new streamed row has the same join key as the previous row,
so return the same matches.
+ true
+ } else if (bufferedRow == null) {
+ // The streamed row's join key does not match the current batch of
buffered rows and there are
+ // no more rows to read from the buffered iterator, so there can be
no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ // Advance both the streamed and buffered iterators to find the next
pair of matching rows.
+ var comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ do {
+ if (streamedRowKey.anyNull) {
+ advancedStreamed()
+ } else if (bufferedRowKey.anyNull) {
+ advancedBuffered()
+ } else {
+ comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ if (comp > 0) advancedBuffered()
+ else if (comp < 0) advancedStreamed()
}
+ } while (streamedRow != null && bufferedRow != null && comp != 0)
+ if (streamedRow == null || bufferedRow == null) {
+ // We have hit the end of one of the iterators, so there
can be no more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ // The streamed row's join key matches the current buffered row's
join key, so walk through the
+ // buffered iterator to buffer the rest of the matching rows.
+ assert(comp == 0)
+ bufferMatchingRows()
+ true
+ }
+ }
+ }
- /**
- * Searches the right iterator for the next rows that have matches
in left side, and store
- * them in a buffer.
- *
- * @return true if the search is successful, and false if the
right iterator runs out of
- * tuples.
- */
- private def nextMatchingPair(): Boolean = {
- if (!stop && rightElement != null) {
- // run both side to get the first match pair
- while (!stop && leftElement != null && rightElement != null) {
- val comparing = keyOrdering.compare(leftKey, rightKey)
- // for inner join, we need to filter those null keys
- stop = comparing == 0 && !leftKey.anyNull
- if (comparing > 0 || rightKey.anyNull) {
- fetchRight()
- } else if (comparing < 0 || leftKey.anyNull) {
- fetchLeft()
- }
- }
- rightMatches = new CompactBuffer[InternalRow]()
- if (stop) {
- stop = false
- // iterate the right side to buffer all rows that matches
- // as the records should be ordered, exit when we meet the
first that not match
- while (!stop && rightElement != null) {
- rightMatches += rightElement
- fetchRight()
- stop = keyOrdering.compare(leftKey, rightKey) != 0
- }
- if (rightMatches.size > 0) {
- rightPosition = 0
- matchKey = leftKey
- }
+ /**
+ * Advances the streamed input iterator and buffers all rows from the
buffered input that
+ * have matching keys.
+ * @return true if the streamed iterator returned a row, false
otherwise. If this returns true,
+ * then [[getStreamedRow]] and [[getBufferedMatches]] can be called
to produce the outer
+ * join results.
+ */
+ final def findNextOuterJoinRows(): Boolean = {
+ while (advancedStreamed() && streamedRowKey.anyNull) {
+ // Advance the streamed side of the join until we find the next row
whose join key contains
+ // no nulls or we hit the end of the streamed iterator.
+ }
+ if (streamedRow == null) {
+ // We have consumed the entire streamed iterator, so there can be no
more matches.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ false
+ } else {
+ if (matchJoinKey != null && keyOrdering.compare(streamedRowKey,
matchJoinKey) == 0) {
+ // Matches the current group, so do nothing.
+ } else {
+ // The streamed row does not match the current group.
+ matchJoinKey = null
+ bufferedMatches.clear()
+ if (bufferedRow != null) {
+ // The buffered iterator could still contain matching rows, so
we'll need to walk through
+ // it until we either find matches or pass where they would be
found.
+ var comp =
+ if (bufferedRowKey.anyNull) 1 else
keyOrdering.compare(streamedRowKey, bufferedRowKey)
+ while (comp > 0 && advancedBuffered()) {
+ comp = if (bufferedRowKey.anyNull) {
+ 1
+ } else {
+ keyOrdering.compare(streamedRowKey, bufferedRowKey)
}
}
- rightMatches != null && rightMatches.size > 0
+ if (comp == 0) {
+ // We have found matches, so buffer them (this updates
matchJoinKey)
+ bufferMatchingRows()
+ } else {
+ // We have overshot the position where the row would be found,
hence no matches.
+ }
}
}
+ // If there is a streamed input with a non-null join key, then we
always return true
+ true
+ }
+ }
+
+ // --- Private methods
--------------------------------------------------------------------------
+
+ /**
+ * Advance the streamed iterator and compute the new row's join key.
+ * @return true if the streamed iterator returned a row and false
otherwise.
+ */
+ private def advancedStreamed(): Boolean = {
+ if (streamedIter.advanceNext()) {
+ streamedRow = streamedIter.getRow
+ streamedRowKey = streamedKeyGenerator(streamedRow)
+ true
+ } else {
+ streamedRow = null
+ streamedRowKey = null
+ false
}
}
+
+ /**
+ * Advance the buffered iterator and compute the new row's join key.
+ * @return true if the buffered iterator returned a row and false
otherwise.
+ */
+ private def advancedBuffered(): Boolean = {
+ if (bufferedIter.advanceNext()) {
--- End diff --
Yep, I agree; I'll move the null checking to here.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]