GitHub user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19327#discussion_r140914684
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala ---
    @@ -309,19 +396,24 @@ class SymmetricHashJoinStateManager(
           stateStore.get(keyWithIndexRow(key, valueIndex))
         }
     
    -    /** Get all the values for key and all indices. */
    -    def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = {
    +    /**
    +     * Get all values and indices for the provided key.
    +     * Should not return null.
    +     */
    +    def getAll(key: UnsafeRow, numValues: Long): 
Iterator[KeyWithIndexAndValue] = {
    +      val keyWithIndexAndValue = new KeyWithIndexAndValue()
           var index = 0
    -      new NextIterator[UnsafeRow] {
    -        override protected def getNext(): UnsafeRow = {
    +      new NextIterator[KeyWithIndexAndValue] {
    +        override protected def getNext(): KeyWithIndexAndValue = {
               if (index >= numValues) {
                 finished = true
                 null
               } else {
                 val keyWithIndex = keyWithIndexRow(key, index)
                 val value = stateStore.get(keyWithIndex)
                 index += 1
    -            value
    +            // return original index
    +            keyWithIndexAndValue.withNew(key, index - 1, value)
    --- End diff --
    
    Super nit: it would be better to write this as:
    ```
    keyWithIndexAndValue.set(key, index, value)
    index +=1 
    keyWithIndexAndValue
    ```



---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to