Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r140910588
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
 ---
    @@ -87,70 +87,157 @@ class SymmetricHashJoinStateManager(
       }
     
       /**
    -   * Remove using a predicate on keys. See class docs for more context and 
implement details.
    +   * Remove using a predicate on keys.
    +   *
    +   * This produces an iterator over the (key, value) pairs satisfying 
condition(key), where the
    +   * underlying store is updated as a side-effect of producing next.
    +   *
    +   * This implies the iterator must be consumed fully without any other 
operations on this manager
    +   * or the underlying store being interleaved.
        */
    -  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
    -    val allKeyToNumValues = keyToNumValues.iterator
    -
    -    while (allKeyToNumValues.hasNext) {
    -      val keyToNumValue = allKeyToNumValues.next
    -      if (condition(keyToNumValue.key)) {
    -        keyToNumValues.remove(keyToNumValue.key)
    -        keyWithIndexToValue.removeAllValues(keyToNumValue.key, 
keyToNumValue.numValue)
    +  def removeByKeyCondition(condition: UnsafeRow => Boolean): 
Iterator[UnsafeRowPair] = {
    +    new NextIterator[UnsafeRowPair] {
    +
    +      private val allKeyToNumValues = keyToNumValues.iterator
    +
    +      private var currentKeyToNumValue: KeyAndNumValues = null
    +      private var currentValues: Iterator[KeyWithIndexAndValue] = null
    +
    +      private def currentKey = currentKeyToNumValue.key
    +
    +      private val reusedPair = new UnsafeRowPair()
    +
    +      private def getAndRemoveValue() = {
    +        val keyWithIndexAndValue = currentValues.next()
    +        keyWithIndexToValue.remove(currentKey, 
keyWithIndexAndValue.valueIndex)
    +        reusedPair.withRows(currentKey, keyWithIndexAndValue.value)
    +      }
    +
    +      override def getNext(): UnsafeRowPair = {
    +        if (currentValues != null && currentValues.hasNext) {
    +          return getAndRemoveValue()
    +        } else {
    +          while (allKeyToNumValues.hasNext) {
    --- End diff --
    
    Add comments on what this while loop is doing.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to