GitHub user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/12102#discussion_r58289799
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala ---
    @@ -113,166 +114,112 @@ trait HashJoin {
       protected def buildSideKeyGenerator: Projection =
         UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
     
    -  protected def streamSideKeyGenerator: Projection =
    +  protected def streamSideKeyGenerator: UnsafeProjection =
         UnsafeProjection.create(rewriteKeyExpr(streamedKeys), 
streamedPlan.output)
     
       @transient private[this] lazy val boundCondition = if 
(condition.isDefined) {
    -    newPredicate(condition.getOrElse(Literal(true)), left.output ++ 
right.output)
    +    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
       } else {
         (r: InternalRow) => true
       }
     
    -  protected def createResultProjection: (InternalRow) => InternalRow =
    -    UnsafeProjection.create(self.schema)
    -
    -  protected def hashJoin(
    -      streamIter: Iterator[InternalRow],
    -      hashedRelation: HashedRelation,
    -      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    -    new Iterator[InternalRow] {
    -      private[this] var currentStreamedRow: InternalRow = _
    -      private[this] var currentHashMatches: Seq[InternalRow] = _
    -      private[this] var currentMatchPosition: Int = -1
    -
    -      // Mutable per row objects.
    -      private[this] val joinRow = new JoinedRow
    -      private[this] val resultProjection = createResultProjection
    -
    -      private[this] val joinKeys = streamSideKeyGenerator
    -
    -      override final def hasNext: Boolean = {
    -        while (true) {
    -          // check if it's end of current matches
    -          if (currentHashMatches != null && currentMatchPosition == 
currentHashMatches.length) {
    -            currentHashMatches = null
    -            currentMatchPosition = -1
    -          }
    -
    -          // find the next match
    -          while (currentHashMatches == null && streamIter.hasNext) {
    -            currentStreamedRow = streamIter.next()
    -            val key = joinKeys(currentStreamedRow)
    -            if (!key.anyNull) {
    -              currentHashMatches = hashedRelation.get(key)
    -              if (currentHashMatches != null) {
    -                currentMatchPosition = 0
    -              }
    -            }
    -          }
    -          if (currentHashMatches == null) {
    -            return false
    -          }
    -
    -          // found some matches
    -          buildSide match {
    -            case BuildRight => joinRow(currentStreamedRow, 
currentHashMatches(currentMatchPosition))
    -            case BuildLeft => 
joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
    -          }
    -          if (boundCondition(joinRow)) {
    -            return true
    -          } else {
    -            currentMatchPosition += 1
    -          }
    -        }
    -        false  // unreachable
    -      }
    -
    -      override final def next(): InternalRow = {
    -        // next() could be called without calling hasNext()
    -        if (hasNext) {
    -          currentMatchPosition += 1
    -          numOutputRows += 1
    -          resultProjection(joinRow)
    -        } else {
    -          throw new NoSuchElementException
    -        }
    -      }
    +  protected def createResultProjection: (InternalRow) => InternalRow = {
    +    if (joinType == LeftSemi) {
    +      UnsafeProjection.create(output, output)
    +    } else {
    +      // Always put the stream side on left to simplify implementation
    +      // both of left and right side could be null
    +      UnsafeProjection.create(
    +        output, (streamedPlan.output ++ 
buildPlan.output).map(_.withNullability(true)))
         }
       }
     
    -  @transient protected[this] lazy val EMPTY_LIST = 
CompactBuffer[InternalRow]()
    -
    -  @transient private[this] lazy val leftNullRow = new 
GenericInternalRow(left.output.length)
    -  @transient private[this] lazy val rightNullRow = new 
GenericInternalRow(right.output.length)
    -
    -  protected[this] def leftOuterIterator(
    -      key: InternalRow,
    -      joinedRow: JoinedRow,
    -      rightIter: Iterable[InternalRow],
    -      resultProjection: InternalRow => InternalRow,
    -      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    -    val ret: Iterable[InternalRow] = {
    -      if (!key.anyNull) {
    -        val temp = if (rightIter != null) {
    -          rightIter.collect {
    -            case r if boundCondition(joinedRow.withRight(r)) => {
    -              numOutputRows += 1
    -              resultProjection(joinedRow).copy()
    -            }
    -          }
    -        } else {
    -          List.empty
    -        }
    -        if (temp.isEmpty) {
    -          numOutputRows += 1
    -          resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
    -        } else {
    -          temp
    -        }
    +  protected def innerJoin(
    +      streamIter: Iterator[InternalRow],
    +      hashedRelation: HashedRelation): Iterator[InternalRow] = {
    +    val joinRow = new JoinedRow
    +    val joinKeys = streamSideKeyGenerator
    +    streamIter.flatMap { srow =>
    +      joinRow.withLeft(srow)
    +      val matches = hashedRelation.get(joinKeys(srow))
    +      if (matches != null) {
    +        matches.map(joinRow.withRight(_)).filter(boundCondition)
           } else {
    -        numOutputRows += 1
    -        resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
    +        Seq.empty
           }
         }
    -    ret.iterator
       }
     
    -  protected[this] def rightOuterIterator(
    -      key: InternalRow,
    -      leftIter: Iterable[InternalRow],
    +  @transient private[this] lazy val nullRow = new 
GenericInternalRow(buildPlan.output.length)
    +
    +  protected[this] def outerIterator(
           joinedRow: JoinedRow,
    -      resultProjection: InternalRow => InternalRow,
    -      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    -    val ret: Iterable[InternalRow] = {
    -      if (!key.anyNull) {
    -        val temp = if (leftIter != null) {
    -          leftIter.collect {
    -            case l if boundCondition(joinedRow.withLeft(l)) => {
    -              numOutputRows += 1
    -              resultProjection(joinedRow).copy()
    -            }
    +      buildIter: Iterator[InternalRow]): Iterator[InternalRow] = {
    +    new RowIterator {
    +      private var found = false
    +      override def advanceNext(): Boolean = {
    +        while (buildIter != null && buildIter.hasNext) {
    +          val nextBuildRow = buildIter.next()
    +          if (boundCondition(joinedRow.withRight(nextBuildRow))) {
    +            found = true
    +            return true
               }
    -        } else {
    -          List.empty
             }
    -        if (temp.isEmpty) {
    -          numOutputRows += 1
    -          resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
    -        } else {
    -          temp
    +        if (!found) {
    +          joinedRow.withRight(nullRow)
    +          found = true
    +          return true
             }
    -      } else {
    -        numOutputRows += 1
    -        resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
    +        false
           }
    -    }
    -    ret.iterator
    +      override def getRow: InternalRow = joinedRow
    +    }.toScala
       }
     
       protected def hashSemiJoin(
         streamIter: Iterator[InternalRow],
    -    hashedRelation: HashedRelation,
    -    numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    +    hashedRelation: HashedRelation): Iterator[InternalRow] = {
         val joinKeys = streamSideKeyGenerator
         val joinedRow = new JoinedRow
         streamIter.filter { current =>
           val key = joinKeys(current)
    -      lazy val rowBuffer = hashedRelation.get(key)
    -      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || 
rowBuffer.exists {
    +      lazy val buildIter = hashedRelation.get(key)
    +      !key.anyNull && buildIter != null && (condition.isEmpty || 
buildIter.exists {
             (row: InternalRow) => boundCondition(joinedRow(current, row))
           })
    -      if (r) numOutputRows += 1
    -      r
    +    }
    +  }
    +
    +  protected def join(
    +      streamedIter: Iterator[InternalRow],
    +      hashed: HashedRelation,
    +      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
    +    val joinedRow = new JoinedRow()
    +    val keyGenerator = streamSideKeyGenerator
    +
    +    val joinedIter = joinType match {
    +      case Inner =>
    +        innerJoin(streamedIter, hashed)
    +
    +      case LeftOuter | RightOuter =>
    +        streamedIter.flatMap { currentRow =>
    --- End diff --
    
    Hm, I see why you named it outerIterator. Can we move this and
outerIterator together into a single function called outerJoin?


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and you would like it enabled, or if the feature is enabled but not
working, please contact infrastructure at [email protected] or file a JIRA
ticket with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to