Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r140617927
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
 ---
    @@ -216,22 +229,51 @@ case class StreamingSymmetricHashJoinExec(
         }
     
         // Filter the joined rows based on the given condition.
    -    val outputFilterFunction =
    -      newPredicate(condition.getOrElse(Literal(true)), left.output ++ 
right.output).eval _
    -    val filteredOutputIter =
    -      (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map 
{ row =>
    -        numOutputRows += 1
    -        row
    -      }
    +    val outputFilterFunction = 
newPredicate(condition.getOrElse(Literal(true)), output).eval _
    +
    +    val filteredInnerOutputIter = (leftOutputIter ++ 
rightOutputIter).filter(outputFilterFunction)
    +
    +    val outputIter: Iterator[InternalRow] = joinType match {
    +      case Inner =>
    +        filteredInnerOutputIter
    +      case LeftOuter =>
    +        val nullRight = new 
GenericInternalRow(right.output.map(_.withNullability(true)).length)
    +        filteredInnerOutputIter ++
    +          leftSideJoiner
    +            .removeOldState()
    +            .filterNot { case (key, value) => 
rightSideJoiner.containsKey(key) }
    +            .map { case (key, value) => 
joinedRow.withLeft(value).withRight(nullRight) }
    +      case RightOuter =>
    +        val nullLeft = new 
GenericInternalRow(left.output.map(_.withNullability(true)).length)
    +        filteredInnerOutputIter ++
    +          rightSideJoiner
    +            .removeOldState()
    +            .filterNot { case (key, value) => 
leftSideJoiner.containsKey(key) }
    +            .map { case (key, value) => 
joinedRow.withLeft(nullLeft).withRight(value) }
    +      case _ => throw badJoinTypeException
    +    }
    +
    +    val outputIterWithMetrics = outputIter.map { row =>
    +      numOutputRows += 1
    +      row
    +    }
    +
    +    // Iterator which must be consumed after output completion before 
committing.
    +    val cleanupIter = joinType match {
    +      case Inner =>
    +        leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
    +      case LeftOuter => rightSideJoiner.removeOldState()
    +      case RightOuter => leftSideJoiner.removeOldState()
    +      case _ => throw badJoinTypeException
    +    }
     
         // Function to remove old state after all the input has been consumed 
and output generated
         def onOutputCompletion = {
           allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - 
updateStartTimeNs), 0)
     
    -      // Remove old state if needed
    +      // TODO: how to get this for removals as part of outer join?
           allRemovalsTimeMs += timeTakenMs {
    -        leftSideJoiner.removeOldState()
    -        rightSideJoiner.removeOldState()
    +        cleanupIter.foreach(_ => ())
    --- End diff --
    
    Don't use foreach. Scala's foreach is pretty inefficient; use a while loop instead.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to