GitHub user JoshRosen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5717#discussion_r32330289
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
 ---
    @@ -84,84 +129,164 @@ case class SortMergeJoin(
     
             override final def next(): Row = {
               if (hasNext) {
    -            // we are using the buffered right rows and run down left 
iterator
    -            val joinedRow = joinRow(leftElement, 
rightMatches(rightPosition))
    -            rightPosition += 1
    -            if (rightPosition >= rightMatches.size) {
    -              rightPosition = 0
    -              fetchLeft()
    -              if (leftElement == null || keyOrdering.compare(leftKey, 
matchKey) != 0) {
    -                stop = false
    -                rightMatches = null
    +            if (bufferedMatches == null || bufferedMatches.size == 0) {
    +              // we just found a row with no join match and we are here to 
produce a row
    +              // with this row with a standard null row from the other 
side.
    +              if (continueStreamed) {
    +                val joinedRow = smartJoinRow(streamedElement, 
bufferedNullRow.copy())
    +                fetchStreamed()
    +                joinedRow
    +              } else {
    +                val joinedRow = smartJoinRow(streamedNullRow.copy(), 
bufferedElement)
    +                fetchBuffered()
    +                joinedRow
    +              }
    +            } else {
    +              // we are using the buffered right rows and run down left 
iterator
    +              val joinedRow = smartJoinRow(streamedElement, 
bufferedMatches(bufferedPosition))
    +              bufferedPosition += 1
    +              if (bufferedPosition >= bufferedMatches.size) {
    +                bufferedPosition = 0
    +                if (joinType != FullOuter || secondStreamedElement == 
null) {
    +                  fetchStreamed()
    +                  if (streamedElement == null || 
keyOrdering.compare(streamedKey, matchKey) != 0) {
    +                    stop = false
    +                    bufferedMatches = null
    +                  }
    +                } else {
    +                  // in FullOuter join and the first time we finish the 
match buffer,
    +                  // we still want to generate all rows with streamed null 
row and buffered
    +                  // rows that match the join key but not the conditions.
    +                  streamedElement = secondStreamedElement
    +                  bufferedMatches = secondBufferedMatches
    +                  secondStreamedElement = null
    +                  secondBufferedMatches = null
    +                }
                   }
    +              joinedRow
                 }
    -            joinedRow
               } else {
                 // no more result
                 throw new NoSuchElementException
               }
             }
     
    -        private def fetchLeft() = {
    -          if (leftIter.hasNext) {
    -            leftElement = leftIter.next()
    -            leftKey = leftKeyGenerator(leftElement)
    +        private def smartJoinRow(streamedRow: Row, bufferedRow: Row): Row 
= joinType match {
    +          case RightOuter => joinRow(bufferedRow, streamedRow)
    +          case _ => joinRow(streamedRow, bufferedRow)
    +        }
    +
    +        private def fetchStreamed() = {
    +          if (streamedIter.hasNext) {
    +            streamedElement = streamedIter.next()
    +            streamedKey = streamedKeyGenerator(streamedElement)
               } else {
    -            leftElement = null
    +            streamedElement = null
               }
             }
     
    -        private def fetchRight() = {
    -          if (rightIter.hasNext) {
    -            rightElement = rightIter.next()
    -            rightKey = rightKeyGenerator(rightElement)
    +        private def fetchBuffered() = {
    +          if (bufferedIter.hasNext) {
    +            bufferedElement = bufferedIter.next()
    +            bufferedKey = bufferedKeyGenerator(bufferedElement)
               } else {
    -            rightElement = null
    +            bufferedElement = null
               }
             }
     
             private def initialize() = {
    -          fetchLeft()
    -          fetchRight()
    +          fetchStreamed()
    +          fetchBuffered()
             }
     
             /**
              * Searches the right iterator for the next rows that have matches 
in left side, and store
              * them in a buffer.
    +         * When this is not a Inner join, we will also return true when we 
get a row with no match
    +         * on the other side. This search will jump out every time from 
the same position until
    +         * `next()` is called.
              *
              * @return true if the search is successful, and false if the 
right iterator runs out of
              *         tuples.
              */
             private def nextMatchingPair(): Boolean = {
    --- End diff --
    
    Reading through the old version of the code, one tricky thing stood out to
me: if a consumer of this iterator calls `hasNext()` followed by `next()`,
`nextMatchingPair()` ends up being called twice in a row. It might be nice to
add a comment here explaining why this is safe and correct.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it enabled, or if the feature is enabled but not
working, please contact infrastructure at [email protected] or file a JIRA
ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to