Github user yhuai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7904#discussion_r36697131
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
 ---
    @@ -56,117 +53,265 @@ case class SortMergeJoin(
       @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, 
left.output)
       @transient protected lazy val rightKeyGenerator = 
newProjection(rightKeys, right.output)
     
    +  protected[this] def isUnsafeMode: Boolean = {
    +    (codegenEnabled && unsafeEnabled
    +      && UnsafeProjection.canSupport(leftKeys)
    +      && UnsafeProjection.canSupport(schema))
    +  }
    +
    +  override def outputsUnsafeRows: Boolean = isUnsafeMode
    +  override def canProcessUnsafeRows: Boolean = isUnsafeMode
    +  override def canProcessSafeRows: Boolean = !isUnsafeMode
    +
       private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
         // This must be ascending in order to agree with the `keyOrdering` 
defined in `doExecute()`.
         keys.map(SortOrder(_, Ascending))
       }
     
       protected override def doExecute(): RDD[InternalRow] = {
    -    val leftResults = left.execute().map(_.copy())
    -    val rightResults = right.execute().map(_.copy())
    -
    -    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
    -      new Iterator[InternalRow] {
    +    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) 
=>
    +      new RowIterator {
             // An ordering that can be used to compare keys from both sides.
             private[this] val keyOrdering = 
newNaturalAscendingOrdering(leftKeys.map(_.dataType))
    -        // Mutable per row objects.
    +        private[this] var currentLeftRow: InternalRow = _
    +        private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
    +        private[this] var currentMatchIdx: Int = -1
    +        private[this] val smjScanner = new SortMergeJoinScanner(
    +          leftKeyGenerator,
    +          rightKeyGenerator,
    +          keyOrdering,
    +          RowIterator.fromScala(leftIter),
    +          RowIterator.fromScala(rightIter)
    +        )
             private[this] val joinRow = new JoinedRow
    -        private[this] var leftElement: InternalRow = _
    -        private[this] var rightElement: InternalRow = _
    -        private[this] var leftKey: InternalRow = _
    -        private[this] var rightKey: InternalRow = _
    -        private[this] var rightMatches: CompactBuffer[InternalRow] = _
    -        private[this] var rightPosition: Int = -1
    -        private[this] var stop: Boolean = false
    -        private[this] var matchKey: InternalRow = _
    -
    -        // initialize iterator
    -        initialize()
    -
    -        override final def hasNext: Boolean = nextMatchingPair()
    -
    -        override final def next(): InternalRow = {
    -          if (hasNext) {
    -            // we are using the buffered right rows and run down left 
iterator
    -            val joinedRow = joinRow(leftElement, 
rightMatches(rightPosition))
    -            rightPosition += 1
    -            if (rightPosition >= rightMatches.size) {
    -              rightPosition = 0
    -              fetchLeft()
    -              if (leftElement == null || keyOrdering.compare(leftKey, 
matchKey) != 0) {
    -                stop = false
    -                rightMatches = null
    -              }
    -            }
    -            joinedRow
    +        private[this] val resultProjection: (InternalRow) => InternalRow = 
{
    +          if (isUnsafeMode) {
    +            UnsafeProjection.create(schema)
               } else {
    -            // no more result
    -            throw new NoSuchElementException
    +            identity[InternalRow]
               }
             }
     
    -        private def fetchLeft() = {
    -          if (leftIter.hasNext) {
    -            leftElement = leftIter.next()
    -            leftKey = leftKeyGenerator(leftElement)
    -          } else {
    -            leftElement = null
    +        override def advanceNext(): Boolean = {
    +          if (currentMatchIdx == -1 || currentMatchIdx == 
currentRightMatches.length) {
    +            if (smjScanner.findNextInnerJoinRows()) {
    +              currentRightMatches = smjScanner.getBufferedMatches
    +              currentLeftRow = smjScanner.getStreamedRow
    +              currentMatchIdx = 0
    +            } else {
    +              currentRightMatches = null
    +              currentLeftRow = null
    +              currentMatchIdx = -1
    +            }
               }
    -        }
    -
    -        private def fetchRight() = {
    -          if (rightIter.hasNext) {
    -            rightElement = rightIter.next()
    -            rightKey = rightKeyGenerator(rightElement)
    +          if (currentLeftRow != null) {
    +            joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
    +            currentMatchIdx += 1
    +            true
               } else {
    -            rightElement = null
    +            false
               }
             }
     
    -        private def initialize() = {
    -          fetchLeft()
    -          fetchRight()
    +        override def getRow: InternalRow = resultProjection(joinRow)
    +      }.toScala
    +    }
    +  }
    +}
    +
    +/**
    + * Helper class that is used to implement [[SortMergeJoin]] and 
[[SortMergeOuterJoin]].
    + *
    + * The streamed input is the left side of a left outer join or the right 
side of a right outer join.
    + *
    + * To perform an inner (outer) join, users of this class call 
[[findNextInnerJoinRows()]]
    + * ([[findNextOuterJoinRows()]]), which returns `true` if a result has 
been produced and `false`
    + * otherwise. If a result has been produced, then the caller may call 
[[getStreamedRow]] to return
    + * the matching row from the streamed input and may call 
[[getBufferedMatches]] to return the
    + * sequence of matching rows from the buffered input (in the case of an 
outer join, this will return
    + * an empty sequence). For efficiency, both of these methods return 
mutable objects which are
    --- End diff --
    
    Why does `getBufferedMatches` return an empty sequence when we have an outer
join?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like to enable it, or if the feature is enabled but not
working, please contact infrastructure at [email protected] or file a JIRA
ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to