Github user joseph-torres commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r140968754
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala ---
    @@ -87,70 +87,157 @@ class SymmetricHashJoinStateManager(
       }
     
       /**
    -   * Remove using a predicate on keys. See class docs for more context and 
implement details.
    +   * Remove using a predicate on keys.
    +   *
    +   * This produces an iterator over the (key, value) pairs satisfying 
condition(key), where the
    +   * underlying store is updated as a side-effect of producing next.
    +   *
    +   * This implies the iterator must be consumed fully without any other 
operations on this manager
    +   * or the underlying store being interleaved.
        */
    -  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
    -    val allKeyToNumValues = keyToNumValues.iterator
    -
    -    while (allKeyToNumValues.hasNext) {
    -      val keyToNumValue = allKeyToNumValues.next
    -      if (condition(keyToNumValue.key)) {
    -        keyToNumValues.remove(keyToNumValue.key)
    -        keyWithIndexToValue.removeAllValues(keyToNumValue.key, 
keyToNumValue.numValue)
    +  def removeByKeyCondition(condition: UnsafeRow => Boolean): 
Iterator[UnsafeRowPair] = {
    +    new NextIterator[UnsafeRowPair] {
    +
    +      private val allKeyToNumValues = keyToNumValues.iterator
    +
    +      private var currentKeyToNumValue: KeyAndNumValues = null
    +      private var currentValues: Iterator[KeyWithIndexAndValue] = null
    +
    +      private def currentKey = currentKeyToNumValue.key
    +
    +      private val reusedPair = new UnsafeRowPair()
    +
    +      private def getAndRemoveValue() = {
    +        val keyWithIndexAndValue = currentValues.next()
    +        keyWithIndexToValue.remove(currentKey, 
keyWithIndexAndValue.valueIndex)
    +        reusedPair.withRows(currentKey, keyWithIndexAndValue.value)
    +      }
    +
    +      override def getNext(): UnsafeRowPair = {
    +        if (currentValues != null && currentValues.hasNext) {
    +          return getAndRemoveValue()
    +        } else {
    +          while (allKeyToNumValues.hasNext) {
    +            currentKeyToNumValue = allKeyToNumValues.next()
    +            if (condition(currentKey)) {
    +              currentValues = keyWithIndexToValue.getAll(
    +                currentKey, currentKeyToNumValue.numValue)
    +              keyToNumValues.remove(currentKey)
    +
    +              if (currentValues.hasNext) {
    +                return getAndRemoveValue()
    +              }
    +            }
    +          }
    +        }
    +
    +        finished = true
    +        null
           }
    +
    +      override def close: Unit = {}
         }
       }
     
       /**
    -   * Remove using a predicate on values. See class docs for more context 
and implementation details.
    +   * Remove using a predicate on values.
    +   *
    +   * At a high level, this produces an iterator over the (key, value) 
pairs such that value
    +   * satisfies the predicate, where producing an element removes the value 
from the state store
    +   * and producing all elements with a given key updates it accordingly.
    +   *
    +   * This implies the iterator must be consumed fully without any other 
operations on this manager
    +   * or the underlying store being interleaved.
        */
    -  def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
    -    val allKeyToNumValues = keyToNumValues.iterator
    +  def removeByValueCondition(condition: UnsafeRow => Boolean): 
Iterator[UnsafeRowPair] = {
    +    new NextIterator[UnsafeRowPair] {
     
    -    while (allKeyToNumValues.hasNext) {
    -      val keyToNumValue = allKeyToNumValues.next
    -      val key = keyToNumValue.key
    +      // Reuse this object to avoid creation+GC overhead.
    +      private val reusedPair = new UnsafeRowPair()
     
    -      var numValues: Long = keyToNumValue.numValue
    -      var index: Long = 0L
    -      var valueRemoved: Boolean = false
    -      var valueForIndex: UnsafeRow = null
    +      private val allKeyToNumValues = keyToNumValues.iterator
     
    -      while (index < numValues) {
    -        if (valueForIndex == null) {
    -          valueForIndex = keyWithIndexToValue.get(key, index)
    +      private var currentKey: UnsafeRow = null
    +      private var numValues: Long = 0L
    +      private var index: Long = 0L
    +      private var valueRemoved: Boolean = false
    +
    +      // Push the data for the current key to the numValues store, and 
reset the tracking variables
    +      // to their empty state.
    +      private def storeCurrentKey(): Unit = {
    +        if (valueRemoved) {
    +          if (numValues >= 1) {
    +            keyToNumValues.put(currentKey, numValues)
    +          } else {
    +            keyToNumValues.remove(currentKey)
    +          }
             }
    -        if (condition(valueForIndex)) {
    -          if (numValues > 1) {
    -            val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 
1)
    -            keyWithIndexToValue.put(key, index, valueAtMaxIndex)
    -            keyWithIndexToValue.remove(key, numValues - 1)
    -            valueForIndex = valueAtMaxIndex
    +
    +        currentKey = null
    +        numValues = 0
    +        index = 0
    +        valueRemoved = false
    +      }
    +
    +      // Find the next value satisfying the condition, updating 
`currentKey` and `numValues` if
    +      // needed. Returns null when no value can be found.
    +      private def findNextValueForIndex(): UnsafeRow = {
    +        while (index < numValues || allKeyToNumValues.hasNext) {
    +          // Note that index < numValues can only be true if we have a 
currentKey, since numValues
    +          // is only initialized to 0 or the current key's numValues.
    +          if (index < numValues) {
    +            // First search the values for the current key.
    +            val current = keyWithIndexToValue.get(currentKey, index)
    +            if (condition(current)) {
    +              return current
    +            } else {
    +              index += 1
    +            }
    +          } else if (allKeyToNumValues.hasNext) {
    +            // If we can't find a value for the current key, cleanup and 
start looking at the next.
    +            // This will also happen the first time the iterator is called.
    +            storeCurrentKey()
    +
    +            val currentKeyToNumValue = allKeyToNumValues.next()
    +            currentKey = currentKeyToNumValue.key
    +            numValues = currentKeyToNumValue.numValue
    --- End diff --
    
    updateNumValueForCurrentKey handles that resetting. I can pull it out to 
another function if you think it's best, but I'm unsure about the advantages, 
since updateNumValueForCurrentKey is already pretty short.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to