GitHub user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r140913768
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
---
@@ -87,70 +87,157 @@ class SymmetricHashJoinStateManager(
}
/**
- * Remove using a predicate on keys. See class docs for more context and
implement details.
+ * Remove using a predicate on keys.
+ *
+ * This produces an iterator over the (key, value) pairs satisfying
condition(key), where the
+ * underlying store is updated as a side-effect of producing next.
+ *
+ * This implies the iterator must be consumed fully without any other
operations on this manager
+ * or the underlying store being interleaved.
*/
- def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
- val allKeyToNumValues = keyToNumValues.iterator
-
- while (allKeyToNumValues.hasNext) {
- val keyToNumValue = allKeyToNumValues.next
- if (condition(keyToNumValue.key)) {
- keyToNumValues.remove(keyToNumValue.key)
- keyWithIndexToValue.removeAllValues(keyToNumValue.key,
keyToNumValue.numValue)
+ def removeByKeyCondition(condition: UnsafeRow => Boolean):
Iterator[UnsafeRowPair] = {
+ new NextIterator[UnsafeRowPair] {
+
+ private val allKeyToNumValues = keyToNumValues.iterator
+
+ private var currentKeyToNumValue: KeyAndNumValues = null
+ private var currentValues: Iterator[KeyWithIndexAndValue] = null
+
+ private def currentKey = currentKeyToNumValue.key
+
+ private val reusedPair = new UnsafeRowPair()
+
+ private def getAndRemoveValue() = {
+ val keyWithIndexAndValue = currentValues.next()
+ keyWithIndexToValue.remove(currentKey,
keyWithIndexAndValue.valueIndex)
+ reusedPair.withRows(currentKey, keyWithIndexAndValue.value)
+ }
+
+ override def getNext(): UnsafeRowPair = {
+ if (currentValues != null && currentValues.hasNext) {
+ return getAndRemoveValue()
+ } else {
+ while (allKeyToNumValues.hasNext) {
+ currentKeyToNumValue = allKeyToNumValues.next()
+ if (condition(currentKey)) {
+ currentValues = keyWithIndexToValue.getAll(
+ currentKey, currentKeyToNumValue.numValue)
+ keyToNumValues.remove(currentKey)
+
+ if (currentValues.hasNext) {
+ return getAndRemoveValue()
+ }
+ }
+ }
+ }
+
+ finished = true
+ null
}
+
+ override def close: Unit = {}
}
}
/**
- * Remove using a predicate on values. See class docs for more context
and implementation details.
+ * Remove using a predicate on values.
+ *
+ * At a high level, this produces an iterator over the (key, value)
pairs such that value
+ * satisfies the predicate, where producing an element removes the value
from the state store
+ * and producing all elements with a given key updates it accordingly.
+ *
+ * This implies the iterator must be consumed fully without any other
operations on this manager
+ * or the underlying store being interleaved.
*/
- def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
- val allKeyToNumValues = keyToNumValues.iterator
+ def removeByValueCondition(condition: UnsafeRow => Boolean):
Iterator[UnsafeRowPair] = {
+ new NextIterator[UnsafeRowPair] {
- while (allKeyToNumValues.hasNext) {
- val keyToNumValue = allKeyToNumValues.next
- val key = keyToNumValue.key
+ // Reuse this object to avoid creation+GC overhead.
+ private val reusedPair = new UnsafeRowPair()
- var numValues: Long = keyToNumValue.numValue
- var index: Long = 0L
- var valueRemoved: Boolean = false
- var valueForIndex: UnsafeRow = null
+ private val allKeyToNumValues = keyToNumValues.iterator
- while (index < numValues) {
- if (valueForIndex == null) {
- valueForIndex = keyWithIndexToValue.get(key, index)
+ private var currentKey: UnsafeRow = null
+ private var numValues: Long = 0L
+ private var index: Long = 0L
+ private var valueRemoved: Boolean = false
+
+ // Push the data for the current key to the numValues store, and
reset the tracking variables
+ // to their empty state.
+ private def storeCurrentKey(): Unit = {
+ if (valueRemoved) {
+ if (numValues >= 1) {
+ keyToNumValues.put(currentKey, numValues)
+ } else {
+ keyToNumValues.remove(currentKey)
+ }
}
- if (condition(valueForIndex)) {
- if (numValues > 1) {
- val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues -
1)
- keyWithIndexToValue.put(key, index, valueAtMaxIndex)
- keyWithIndexToValue.remove(key, numValues - 1)
- valueForIndex = valueAtMaxIndex
+
+ currentKey = null
+ numValues = 0
+ index = 0
+ valueRemoved = false
+ }
+
+ // Find the next value satisfying the condition, updating
`currentKey` and `numValues` if
+ // needed. Returns null when no value can be found.
+ private def findNextValueForIndex(): UnsafeRow = {
+ while (index < numValues || allKeyToNumValues.hasNext) {
--- End diff --
In fact, to make it even more obviously clear (at the cost of being a bit more verbose), you can do:
```
def hasMoreValuesForCurrentKey = currentKey != null && index < numValues
def hasMoreKeys = ... .hasNext
while (hasMoreValuesForCurrentKey || hasMoreValues) {
...
}
```
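Then you don't even need comments to explain it :) (Note: the loop condition should use `hasMoreKeys`, the helper defined above, rather than `hasMoreValues`.)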
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]