Github user tdas commented on a diff in the pull request: https://github.com/apache/spark/pull/19327#discussion_r140614816 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala --- @@ -89,61 +89,124 @@ class SymmetricHashJoinStateManager( /** * Remove using a predicate on keys. See class docs for more context and implement details. */ - def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator - - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - if (condition(keyToNumValue.key)) { - keyToNumValues.remove(keyToNumValue.key) - keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue) + def removeByKeyCondition(condition: UnsafeRow => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + new NextIterator[(UnsafeRow, UnsafeRow)] { + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKeyToNumValue: Option[KeyAndNumValues] = None + private var currentValues: Option[Iterator[(UnsafeRow, Long)]] = None + + private def currentKey = currentKeyToNumValue.get.key + + private def getAndRemoveValue() = { + val (current, index) = currentValues.get.next() + keyWithIndexToValue.remove(currentKey, index) + (currentKey, current) + } + + override def getNext(): (UnsafeRow, UnsafeRow) = { + if (currentValues.nonEmpty && currentValues.get.hasNext) { + return getAndRemoveValue() + } else { + while (allKeyToNumValues.hasNext) { + currentKeyToNumValue = Some(allKeyToNumValues.next()) + if (condition(currentKey)) { + currentValues = Some(keyWithIndexToValue.getAllWithIndex( + currentKey, currentKeyToNumValue.get.numValue)) + keyToNumValues.remove(currentKey) + + if (currentValues.nonEmpty && currentValues.get.hasNext) { + return getAndRemoveValue() + } + } + } + } + + finished = true + null } + + override def close: Unit = {} } } /** * Remove using a predicate on values. 
See class docs for more context and implementation details. */ - def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator + def removeByValueCondition(condition: UnsafeRow => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = { + new NextIterator[(UnsafeRow, UnsafeRow)] { + + private val allKeyToNumValues = keyToNumValues.iterator - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - val key = keyToNumValue.key + private var currentKeyToNumValue: Option[KeyAndNumValues] = None - var numValues: Long = keyToNumValue.numValue + private def currentKey = currentKeyToNumValue.get.key + + var numValues: Long = 0L var index: Long = 0L var valueRemoved: Boolean = false var valueForIndex: UnsafeRow = null - while (index < numValues) { - if (valueForIndex == null) { - valueForIndex = keyWithIndexToValue.get(key, index) + private def cleanupCurrentKey(): Unit = { + if (valueRemoved) { + if (numValues >= 1) { + keyToNumValues.put(currentKey, numValues) + } else { + keyToNumValues.remove(currentKey) + } + } + + numValues = 0 + index = 0 + valueRemoved = false + valueForIndex = null + } + + override def getNext(): (UnsafeRow, UnsafeRow) = { + // TODO: there has to be a better way to express this but I don't know what it is + while (valueForIndex == null && (index < numValues || allKeyToNumValues.hasNext)) { + if (index < numValues) { + val current = keyWithIndexToValue.get(currentKey, index) + if (condition(current)) { + valueForIndex = current + } else { + index += 1 + } + } else { + cleanupCurrentKey() + + currentKeyToNumValue = Some(allKeyToNumValues.next()) + numValues = currentKeyToNumValue.get.numValue --- End diff -- Suggestion: rename `numValues` to `currentKeyNumValues`, since it holds the number of values for the current key only.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org