Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r140904307
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
---
@@ -216,22 +232,70 @@ case class StreamingSymmetricHashJoinExec(
}
// Filter the joined rows based on the given condition.
- val outputFilterFunction =
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _
- val filteredOutputIter =
- (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row =>
- numOutputRows += 1
- row
- }
+ val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _
+
+ val filteredInnerOutputIter = (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction)
+
+ val outputIter: Iterator[InternalRow] = joinType match {
+ case Inner =>
+ filteredInnerOutputIter
+ case LeftOuter =>
+ // We generate the outer join input by:
+ // * Getting an iterator over the rows that have aged out on the left side. These rows are
+ // candidates for being null joined. Note that to avoid doing two passes, this iterator
+ // removes the rows from the state manager as they're processed.
+ // * Checking whether the current row matches a key in the right side state. If it doesn't,
+ // we know we can join with null, since there was never (including this batch) a match
+ // within the watermark period. If it does, there must have been a match at some point, so
+ // we know we can't join with null.
+ val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
+ val removedRowIter = leftSideJoiner.removeOldState()
+ val outerOutputIter = removedRowIter
+ .filterNot(pair => rightSideJoiner.containsKey(pair.key))
+ .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+
+ filteredInnerOutputIter ++ outerOutputIter
+ case RightOuter =>
+ // See comments for left outer case.
+ val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
+ val removedRowIter = rightSideJoiner.removeOldState()
+ val outerOutputIter = removedRowIter
+ .filterNot(pair => leftSideJoiner.containsKey(pair.key))
+ .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+ filteredInnerOutputIter ++ outerOutputIter
+ case _ =>
+ throwBadJoinTypeException()
+ Iterator()
+ }
+
+ val outputIterWithMetrics = outputIter.map { row =>
+ numOutputRows += 1
+ row
+ }
// Function to remove old state after all the input has been consumed and output generated
def onOutputCompletion = {
allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)
- // Remove old state if needed
+ // TODO: how to get this for removals as part of outer join?
allRemovalsTimeMs += timeTakenMs {
- leftSideJoiner.removeOldState()
- rightSideJoiner.removeOldState()
+ // Iterator which must be consumed after output completion before committing.
--- End diff --
nit: Rather than starting with how it's implemented (i.e. referring to the
iterator), first explain what this code does semantically (e.g. "Remove all
the state rows that are not needed anymore").
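To make the LeftOuter semantics in the diff comments concrete, here is a minimal, hypothetical Scala sketch: rows that have aged out of the left-side state are removed in a single lazy pass, and only those whose key is absent from the right-side state are emitted, joined with a null right side. `SideState`, `removeOldState(watermark)`, `leftOuterNullRows` and the use of `None` as the null stand-in are illustrative assumptions, not Spark's `SymmetricHashJoinStateManager` API or its `InternalRow` handling.

```scala
import scala.collection.mutable

// Hypothetical, simplified model of a stream-stream left outer join; it is NOT
// Spark's SymmetricHashJoinStateManager, and None stands in for the null row.
object OuterJoinSketch {
  // One side's buffered state: join key -> (eventTime, row) pairs.
  final class SideState[K, V] {
    private val rows = mutable.LinkedHashMap.empty[K, mutable.ArrayBuffer[(Long, V)]]

    def append(key: K, eventTime: Long, value: V): Unit =
      rows.getOrElseUpdate(key, mutable.ArrayBuffer.empty[(Long, V)]) += ((eventTime, value))

    def containsKey(key: K): Boolean = rows.contains(key)

    // Lazily yields (key, row) pairs whose event time is older than the watermark,
    // deleting them from the state as they are produced (a single pass).
    def removeOldState(watermark: Long): Iterator[(K, V)] =
      rows.keys.toList.iterator.flatMap { key =>
        val (expired, kept) = rows(key).partition { case (t, _) => t < watermark }
        if (kept.isEmpty) rows.remove(key) else rows.update(key, kept)
        expired.iterator.map { case (_, v) => (key, v) }
      }
  }

  // Aged-out left rows whose key is not in the right-side state are emitted
  // once, joined with a "null" (None) right side.
  def leftOuterNullRows[K, L, R](
      left: SideState[K, L],
      right: SideState[K, R],
      watermark: Long): Iterator[(L, Option[R])] =
    left.removeOldState(watermark)
      .filterNot { case (key, _) => right.containsKey(key) }
      .map { case (_, l) => (l, Option.empty[R]) }
}
```

In this sketch nothing is removed from the left state until the returned iterator is actually consumed, which is the property behind the diff's note that the removal iterator must be consumed before committing.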
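The reviewer's suggestion for the completion hook can be illustrated with another hypothetical sketch: the comment on `onOutputCompletion` states the semantics first (remove all state rows that are no longer needed) and only then mentions the iterator mechanics. It also shows that a counter wrapped around the output, as in `outputIterWithMetrics`, only advances as rows are pulled. `Counter`, `withRowCount`, `removeOldState` and this `onOutputCompletion` are illustrative stand-ins, not Spark's `SQLMetric` or the operator's real methods.

```scala
import scala.collection.mutable

// Hypothetical, simplified sketch; not Spark's operator code.
object CompletionSketch {
  // Stand-in for a metric such as numOutputRows: a plain mutable counter.
  final class Counter {
    var value: Long = 0L
    def add(n: Long): Unit = value += n
  }

  // Counts each row only when a consumer actually pulls it (Iterator.map is lazy).
  def withRowCount[T](rows: Iterator[T], counter: Counter): Iterator[T] =
    rows.map { row => counter.add(1); row }

  // Hypothetical removal iterator: snapshots the expired keys, then deletes each
  // key from `state` only at the moment it is produced.
  def removeOldState[K](state: mutable.Set[K], isExpired: K => Boolean): Iterator[K] =
    state.filter(isExpired).toList.iterator.map { key => state -= key; key }

  // Remove all state rows that are no longer needed, so that only live rows are
  // committed. Because the removal iterator deletes rows lazily, it is drained
  // here; returns how many rows this call removed.
  def onOutputCompletion[K](pendingRemovals: Iterator[K]): Int = {
    var removed = 0
    while (pendingRemovals.hasNext) {
      pendingRemovals.next()
      removed += 1
    }
    removed
  }
}
```

One consequence of this lazy design, which may be what the diff's TODO is getting at: a timer wrapped around the drain loop only measures whatever removals were not already performed while the output was being consumed.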