GitHub user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r140905683
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala ---
@@ -216,22 +232,70 @@ case class StreamingSymmetricHashJoinExec(
}
// Filter the joined rows based on the given condition.
- val outputFilterFunction =
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _
- val filteredOutputIter =
- (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row =>
- numOutputRows += 1
- row
- }
+ val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _
+
+ val filteredInnerOutputIter = (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction)
+
+ val outputIter: Iterator[InternalRow] = joinType match {
+ case Inner =>
+ filteredInnerOutputIter
+ case LeftOuter =>
+ // We generate the outer join input by:
+ // * Getting an iterator over the rows that have aged out on the left side. These rows are
+ // candidates for being null joined. Note that to avoid doing two passes, this iterator
+ // removes the rows from the state manager as they're processed.
+ // * Checking whether the current row matches a key in the right side state. If it doesn't,
+ // we know we can join with null, since there was never (including this batch) a match
+ // within the watermark period. If it does, there must have been a match at some point, so
+ // we know we can't join with null.
+ val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
+ val removedRowIter = leftSideJoiner.removeOldState()
+ val outerOutputIter = removedRowIter
+ .filterNot(pair => rightSideJoiner.containsKey(pair.key))
+ .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight))
+
+ filteredInnerOutputIter ++ outerOutputIter
+ case RightOuter =>
+ // See comments for left outer case.
+ val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length)
+ val removedRowIter = rightSideJoiner.removeOldState()
+ val outerOutputIter = removedRowIter
+ .filterNot(pair => leftSideJoiner.containsKey(pair.key))
+ .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value))
+
+ filteredInnerOutputIter ++ outerOutputIter
+ case _ =>
+ throwBadJoinTypeException()
+ Iterator()
+ }
+
+ val outputIterWithMetrics = outputIter.map { row =>
+ numOutputRows += 1
+ row
+ }
// Function to remove old state after all the input has been consumed and output generated
def onOutputCompletion = {
allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0)
- // Remove old state if needed
+ // TODO: how to get this for removals as part of outer join?
allRemovalsTimeMs += timeTakenMs {
- leftSideJoiner.removeOldState()
- rightSideJoiner.removeOldState()
+ // Iterator which must be consumed after output completion before committing.
+ // For outer joins, we've removed old state from the appropriate side inline while we
+ // produced the null rows. So we need to finish cleaning the other side. For inner joins
--- End diff --
The phrase "appropriate side inline" does not make sense to me. Something
like this would be clearer:
"For inner joins, we have to remove unnecessary state rows from both sides
if possible. For outer joins, we have already removed unnecessary state rows
from the outer side (e.g., left side for left outer join) while generating the
outer "null" outputs. Now, we have to remove unnecessary state rows from the
other side (e.g., right side for left outer join) if possible. In all
cases, nothing needs to be output, hence the removal needs to be done
greedily by immediately consuming the returned iterator."
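
To illustrate why the returned iterator has to be consumed immediately, here
is a minimal standalone sketch, not the PR's code. `SideState` and
`CleanupSketch` are hypothetical stand-ins for the per-side joiner, with
string keys and a fixed watermark assumed purely for illustration. Only the
method names `removeOldState` and `containsKey` mirror the diff above.

```scala
import scala.collection.mutable

// Hypothetical stand-in for one side's join state (not Spark's actual joiner).
final class SideState(watermark: Long) {
  // key -> event time
  private val rows = mutable.LinkedHashMap.empty[String, Long]

  def put(key: String, eventTime: Long): Unit = { rows.update(key, eventTime) }
  def containsKey(key: String): Boolean = rows.contains(key)
  def size: Int = rows.size

  // Returns an iterator over rows older than the watermark. Each row is removed
  // from the state only when the iterator reaches it, so an unconsumed iterator
  // removes nothing.
  def removeOldState(): Iterator[(String, Long)] = {
    val expired = rows.iterator.filter(_._2 < watermark).toList // snapshot of candidates
    expired.iterator.map { kv => rows.remove(kv._1); kv }
  }
}

object CleanupSketch {
  // Returns (outer "null" outputs, remaining left rows, remaining right rows).
  def run(): (List[String], Int, Int) = {
    val left = new SideState(watermark = 10L)
    val right = new SideState(watermark = 10L)
    left.put("a", 5L); left.put("b", 7L); left.put("c", 20L)
    right.put("b", 6L); right.put("d", 30L)

    // Left outer join: removing old rows from the outer (left) side also
    // produces the null-joined outputs, so this iterator is consumed as output.
    val outerOutput = left.removeOldState()
      .filterNot { case (key, _) => right.containsKey(key) }
      .map { case (key, _) => s"$key,null" }
      .toList

    // Other (right) side: nothing needs to be output, so the iterator is
    // consumed greedily just to trigger the removals.
    right.removeOldState().foreach(_ => ())

    (outerOutput, left.size, right.size)
  }
}
```

In this sketch the right side is cleaned only after the outer outputs have
been generated. Cleaning it earlier would drop key "b" from the right state
and wrongly null-join the left row for "b".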
---