Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r140608418
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
---
@@ -89,61 +89,124 @@ class SymmetricHashJoinStateManager(
/**
* Remove using a predicate on keys. See class docs for more context and
implementation details.
*/
- def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
- val allKeyToNumValues = keyToNumValues.iterator
-
- while (allKeyToNumValues.hasNext) {
- val keyToNumValue = allKeyToNumValues.next
- if (condition(keyToNumValue.key)) {
- keyToNumValues.remove(keyToNumValue.key)
- keyWithIndexToValue.removeAllValues(keyToNumValue.key,
keyToNumValue.numValue)
+ def removeByKeyCondition(condition: UnsafeRow => Boolean):
Iterator[(UnsafeRow, UnsafeRow)] = {
+ new NextIterator[(UnsafeRow, UnsafeRow)] {
+
+ private val allKeyToNumValues = keyToNumValues.iterator
+
+ private var currentKeyToNumValue: Option[KeyAndNumValues] = None
+ private var currentValues: Option[Iterator[(UnsafeRow, Long)]] = None
+
+ private def currentKey = currentKeyToNumValue.get.key
+
+ private def getAndRemoveValue() = {
+ val (current, index) = currentValues.get.next()
--- End diff --
Rename `current` to `currentValue` in the destructuring above.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]