Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r140614563
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
 ---
    @@ -89,61 +89,124 @@ class SymmetricHashJoinStateManager(
       /**
        * Remove using a predicate on keys. See class docs for more context and 
implement details.
        */
    -  def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = {
    -    val allKeyToNumValues = keyToNumValues.iterator
    -
    -    while (allKeyToNumValues.hasNext) {
    -      val keyToNumValue = allKeyToNumValues.next
    -      if (condition(keyToNumValue.key)) {
    -        keyToNumValues.remove(keyToNumValue.key)
    -        keyWithIndexToValue.removeAllValues(keyToNumValue.key, 
keyToNumValue.numValue)
    +  def removeByKeyCondition(condition: UnsafeRow => Boolean): 
Iterator[(UnsafeRow, UnsafeRow)] = {
    +    new NextIterator[(UnsafeRow, UnsafeRow)] {
    +
    +      private val allKeyToNumValues = keyToNumValues.iterator
    +
    +      private var currentKeyToNumValue: Option[KeyAndNumValues] = None
    +      private var currentValues: Option[Iterator[(UnsafeRow, Long)]] = None
    +
    +      private def currentKey = currentKeyToNumValue.get.key
    +
    +      private def getAndRemoveValue() = {
    +        val (current, index) = currentValues.get.next()
    +        keyWithIndexToValue.remove(currentKey, index)
    +        (currentKey, current)
    +      }
    +
    +      override def getNext(): (UnsafeRow, UnsafeRow) = {
    +        if (currentValues.nonEmpty && currentValues.get.hasNext) {
    +          return getAndRemoveValue()
    +        } else {
    +          while (allKeyToNumValues.hasNext) {
    +            currentKeyToNumValue = Some(allKeyToNumValues.next())
    +            if (condition(currentKey)) {
    +              currentValues = Some(keyWithIndexToValue.getAllWithIndex(
    +                currentKey, currentKeyToNumValue.get.numValue))
    +              keyToNumValues.remove(currentKey)
    +
    +              if (currentValues.nonEmpty && currentValues.get.hasNext) {
    +                return getAndRemoveValue()
    +              }
    +            }
    +          }
    +        }
    +
    +        finished = true
    +        null
           }
    +
    +      override def close: Unit = {}
         }
       }
     
       /**
        * Remove using a predicate on values. See class docs for more context 
and implementation details.
        */
    -  def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = {
    -    val allKeyToNumValues = keyToNumValues.iterator
    +  def removeByValueCondition(condition: UnsafeRow => Boolean): 
Iterator[(UnsafeRow, UnsafeRow)] = {
    +    new NextIterator[(UnsafeRow, UnsafeRow)] {
    +
    +      private val allKeyToNumValues = keyToNumValues.iterator
     
    -    while (allKeyToNumValues.hasNext) {
    -      val keyToNumValue = allKeyToNumValues.next
    -      val key = keyToNumValue.key
    +      private var currentKeyToNumValue: Option[KeyAndNumValues] = None
     
    -      var numValues: Long = keyToNumValue.numValue
    +      private def currentKey = currentKeyToNumValue.get.key
    +
    +      var numValues: Long = 0L
           var index: Long = 0L
           var valueRemoved: Boolean = false
           var valueForIndex: UnsafeRow = null
     
    -      while (index < numValues) {
    -        if (valueForIndex == null) {
    -          valueForIndex = keyWithIndexToValue.get(key, index)
    +      private def cleanupCurrentKey(): Unit = {
    +        if (valueRemoved) {
    +          if (numValues >= 1) {
    +            keyToNumValues.put(currentKey, numValues)
    +          } else {
    +            keyToNumValues.remove(currentKey)
    +          }
    +        }
    +
    +        numValues = 0
    +        index = 0
    +        valueRemoved = false
    +        valueForIndex = null
    +      }
    +
    +      override def getNext(): (UnsafeRow, UnsafeRow) = {
    +        // TODO: there has to be a better way to express this but I don't 
know what it is
    +        while (valueForIndex == null && (index < numValues || 
allKeyToNumValues.hasNext)) {
    +          if (index < numValues) {
    +            val current = keyWithIndexToValue.get(currentKey, index)
    --- End diff --
    
    Rename current -> nextValue, to avoid confusion with valueWithIndex (assuming that 
gets renamed to currentValue).


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to