Github user tdas commented on a diff in the pull request: https://github.com/apache/spark/pull/19327#discussion_r140617927 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala --- @@ -216,22 +229,51 @@ case class StreamingSymmetricHashJoinExec( } // Filter the joined rows based on the given condition. - val outputFilterFunction = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _ - val filteredOutputIter = - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row => - numOutputRows += 1 - row - } + val outputFilterFunction = newPredicate(condition.getOrElse(Literal(true)), output).eval _ + + val filteredInnerOutputIter = (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction) + + val outputIter: Iterator[InternalRow] = joinType match { + case Inner => + filteredInnerOutputIter + case LeftOuter => + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + filteredInnerOutputIter ++ + leftSideJoiner + .removeOldState() + .filterNot { case (key, value) => rightSideJoiner.containsKey(key) } + .map { case (key, value) => joinedRow.withLeft(value).withRight(nullRight) } + case RightOuter => + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + filteredInnerOutputIter ++ + rightSideJoiner + .removeOldState() + .filterNot { case (key, value) => leftSideJoiner.containsKey(key) } + .map { case (key, value) => joinedRow.withLeft(nullLeft).withRight(value) } + case _ => throw badJoinTypeException + } + + val outputIterWithMetrics = outputIter.map { row => + numOutputRows += 1 + row + } + + // Iterator which must be consumed after output completion before committing. 
+ val cleanupIter = joinType match { + case Inner => + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case LeftOuter => rightSideJoiner.removeOldState() + case RightOuter => leftSideJoiner.removeOldState() + case _ => throw badJoinTypeException + } // Function to remove old state after all the input has been consumed and output generated def onOutputCompletion = { allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0) - // Remove old state if needed + // TODO: how to get this for removals as part of outer join? allRemovalsTimeMs += timeTakenMs { - leftSideJoiner.removeOldState() - rightSideJoiner.removeOldState() + cleanupIter.foreach(_ => ()) --- End diff -- Don't use `foreach`; Scala's `foreach` is pretty inefficient here. Use a `while` loop instead to drain the iterator.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org