Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r140614146
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
---
@@ -89,61 +89,124 @@ class SymmetricHashJoinStateManager(
/**
* Remove using a predicate on keys. See class docs for more context and
implement details.
*/
- def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
- val allKeyToNumValues = keyToNumValues.iterator
-
- while (allKeyToNumValues.hasNext) {
- val keyToNumValue = allKeyToNumValues.next
- if (condition(keyToNumValue.key)) {
- keyToNumValues.remove(keyToNumValue.key)
- keyWithIndexToValue.removeAllValues(keyToNumValue.key,
keyToNumValue.numValue)
+ def removeByKeyCondition(condition: UnsafeRow => Boolean):
Iterator[(UnsafeRow, UnsafeRow)] = {
+ new NextIterator[(UnsafeRow, UnsafeRow)] {
+
+ private val allKeyToNumValues = keyToNumValues.iterator
+
+ private var currentKeyToNumValue: Option[KeyAndNumValues] = None
+ private var currentValues: Option[Iterator[(UnsafeRow, Long)]] = None
+
+ private def currentKey = currentKeyToNumValue.get.key
+
+ private def getAndRemoveValue() = {
+ val (current, index) = currentValues.get.next()
+ keyWithIndexToValue.remove(currentKey, index)
+ (currentKey, current)
+ }
+
+ override def getNext(): (UnsafeRow, UnsafeRow) = {
+ if (currentValues.nonEmpty && currentValues.get.hasNext) {
+ return getAndRemoveValue()
+ } else {
+ while (allKeyToNumValues.hasNext) {
+ currentKeyToNumValue = Some(allKeyToNumValues.next())
+ if (condition(currentKey)) {
+ currentValues = Some(keyWithIndexToValue.getAllWithIndex(
+ currentKey, currentKeyToNumValue.get.numValue))
+ keyToNumValues.remove(currentKey)
+
+ if (currentValues.nonEmpty && currentValues.get.hasNext) {
+ return getAndRemoveValue()
+ }
+ }
+ }
+ }
+
+ finished = true
+ null
}
+
+ override def close: Unit = {}
}
}
/**
* Remove using a predicate on values. See class docs for more context
and implementation details.
*/
- def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
- val allKeyToNumValues = keyToNumValues.iterator
+ def removeByValueCondition(condition: UnsafeRow => Boolean):
Iterator[(UnsafeRow, UnsafeRow)] = {
+ new NextIterator[(UnsafeRow, UnsafeRow)] {
+
+ private val allKeyToNumValues = keyToNumValues.iterator
- while (allKeyToNumValues.hasNext) {
- val keyToNumValue = allKeyToNumValues.next
- val key = keyToNumValue.key
+ private var currentKeyToNumValue: Option[KeyAndNumValues] = None
- var numValues: Long = keyToNumValue.numValue
+ private def currentKey = currentKeyToNumValue.get.key
+
+ var numValues: Long = 0L
var index: Long = 0L
var valueRemoved: Boolean = false
var valueForIndex: UnsafeRow = null
- while (index < numValues) {
- if (valueForIndex == null) {
- valueForIndex = keyWithIndexToValue.get(key, index)
+ private def cleanupCurrentKey(): Unit = {
+ if (valueRemoved) {
+ if (numValues >= 1) {
+ keyToNumValues.put(currentKey, numValues)
+ } else {
+ keyToNumValues.remove(currentKey)
+ }
+ }
+
+ numValues = 0
+ index = 0
+ valueRemoved = false
+ valueForIndex = null
+ }
+
+ override def getNext(): (UnsafeRow, UnsafeRow) = {
+ // TODO: there has to be a better way to express this but I don't
know what it is
+ while (valueForIndex == null && (index < numValues ||
allKeyToNumValues.hasNext)) {
--- End diff --
Maybe rename `valueForIndex` -> `currentValue` and `index` ->
`nextValueIndex`.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]