Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r140904307
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala ---
    @@ -216,22 +232,70 @@ case class StreamingSymmetricHashJoinExec(
         }
     
         // Filter the joined rows based on the given condition.
    -    val outputFilterFunction =
    -      newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _
    -    val filteredOutputIter =
    -      (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row =>
    -        numOutputRows += 1
    -        row
    -      }
    +    val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _
    +
    +    val filteredInnerOutputIter = (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction)
    +
    +    val outputIter: Iterator[InternalRow] = joinType match {
    +      case Inner =>
    +        filteredInnerOutputIter
    +      case LeftOuter =>
    +        // We generate the outer join input by:
    +        // * Getting an iterator over the rows that have aged out on the left side. These rows are
    +        //   candidates for being null joined. Note that to avoid doing two passes, this iterator
    +        //   removes the rows from the state manager as they're processed.
    +        // * Checking whether the current row matches a key in the right side state. If it doesn't,
    +        //   we know we can join with null, since there was never (including this batch) a match
    +        //   within the watermark period. If it does, there must have been a match at some point, so
    +        //   we know we can't join with null.
    +        val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
    +        val removedRowIter = leftSideJoiner.removeOldState()
    +        val outerOutputIter = removedRowIter
    +          .filterNot(pair => rightSideJoiner.containsKey(pair.key))
    +          .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
    +
    +        filteredInnerOutputIter ++ outerOutputIter
    +      case RightOuter =>
    +        // See comments for left outer case.
    +        val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
    +        val removedRowIter = rightSideJoiner.removeOldState()
    +        val outerOutputIter = removedRowIter
    +          .filterNot(pair => leftSideJoiner.containsKey(pair.key))
    +          .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
    +
    +        filteredInnerOutputIter ++ outerOutputIter
    +      case _ =>
    +        throwBadJoinTypeException()
    +        Iterator()
    +    }
    +
    +    val outputIterWithMetrics = outputIter.map { row =>
    +      numOutputRows += 1
    +      row
    +    }
     
         // Function to remove old state after all the input has been consumed and output generated
         def onOutputCompletion = {
           allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)
     
    -      // Remove old state if needed
    +      // TODO: how to get this for removals as part of outer join?
           allRemovalsTimeMs += timeTakenMs {
    -        leftSideJoiner.removeOldState()
    -        rightSideJoiner.removeOldState()
    +        // Iterator which must be consumed after output completion before committing.
    --- End diff ---
    
    nit: Rather than starting with how it's implemented (i.e. referring to the iterator), first explain what this code does semantically (e.g. remove all the state rows that are not needed anymore).
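
    For illustration, the comment could lead with the semantics and only then mention the iterator (hypothetical wording, not the actual PR text):

        // Remove all the state rows that are not needed anymore, i.e. rows that
        // have aged out of the watermark and can never produce further join output.
        // The removal is done lazily through an iterator, so the iterator must be
        // fully consumed after output completion and before the state store is
        // committed.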


---
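
    As a side note, here is a minimal, self-contained sketch (plain Scala with hypothetical types, not the Spark internals used in the PR) of the left-outer idea described in the diff comment: rows evicted from the left state whose key never matched anything in the right state are emitted with the right side padded with null.

        // Hypothetical stand-ins for the state rows and the right-side key set.
        case class Pair(key: Int, value: String)

        def leftOuterTail(
            removedLeftRows: Iterator[Pair],   // rows aged out of the left state
            rightKeys: Set[Int]                // keys currently in the right state
          ): Iterator[(String, Option[String])] =
          removedLeftRows
            .filterNot(pair => rightKeys.contains(pair.key)) // never matched -> null-join candidate
            .map(pair => (pair.value, Option.empty[String])) // right side represented as None

        // Keys 1 and 2 aged out on the left; only key 1 ever had a match on the right.
        val tail = leftOuterTail(Iterator(Pair(1, "a"), Pair(2, "b")), Set(1)).toList
        // tail == List(("b", None))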
