Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r140914684
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
---
@@ -309,19 +396,24 @@ class SymmetricHashJoinStateManager(
stateStore.get(keyWithIndexRow(key, valueIndex))
}
- /** Get all the values for key and all indices. */
- def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = {
+ /**
+ * Get all values and indices for the provided key.
+ * Should not return null.
+ */
+ def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
+ val keyWithIndexAndValue = new KeyWithIndexAndValue()
var index = 0
- new NextIterator[UnsafeRow] {
- override protected def getNext(): UnsafeRow = {
+ new NextIterator[KeyWithIndexAndValue] {
+ override protected def getNext(): KeyWithIndexAndValue = {
if (index >= numValues) {
finished = true
null
} else {
val keyWithIndex = keyWithIndexRow(key, index)
val value = stateStore.get(keyWithIndex)
index += 1
- value
+ // return original index
+ keyWithIndexAndValue.withNew(key, index - 1, value)
--- End diff --
Super nit: it would be better to write this as:
```
keyWithIndexAndValue.set(key, index, value)
index += 1
keyWithIndexAndValue
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]