This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d53de53fff55 [SPARK-52989][SS] Add explicit close() API to State Store iterators d53de53fff55 is described below commit d53de53fff554d6d6eda7113dae91fbd75840ebb Author: Dylan Wong <dylan.w...@databricks.com> AuthorDate: Wed Aug 6 10:15:16 2025 +0900 [SPARK-52989][SS] Add explicit close() API to State Store iterators ### What changes were proposed in this pull request? Add explicit ```close()``` API to State Store iterators. This PR changes the ```ReadStateStore``` trait's ```prefixScan``` and ```iterator``` methods to return ```StateStoreIterator[UnsafeRowPair]``` instead of ```Iterator[UnsafeRowPair]```. This new type has the ```close()``` method. The ```exists()``` method of MapStateImpl is also changed to close the iterator explicitly when it is no longer needed. Additionally, the iterators in TimerStateImpl and MapStateImplWithTTL that consume state store iterators now call ```close()``` on them. ### Why are the changes needed? These changes expose the close() method on state store iterators. This allows users of the StateStoreIterator to explicitly close it and its underlying resources when it's no longer needed. This avoids having to hold on to the iterators until all rows are consumed and close() is called, or until the task completion/failure listener calls close() on the iterators. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Existing unit tests, new tests for the wrapper ```StateStoreIterator``` class, and a new test to verify that ```close()``` closes the underlying RocksDB iterator. ### Was this patch authored or co-authored using generative AI tooling? No Closes #51701 from dylanwong250/SPARK-52989. 
Lead-authored-by: Dylan Wong <dylan.w...@databricks.com> Co-authored-by: dylanwong250 <dylanwong...@gmail.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../statevariables/MapStateImpl.scala | 5 +- .../transformwithstate/timers/TimerStateImpl.scala | 4 +- .../ttl/MapStateImplWithTTL.scala | 4 +- .../state/HDFSBackedStateStoreProvider.scala | 26 +++++--- .../sql/execution/streaming/state/RocksDB.scala | 6 +- .../state/RocksDBStateStoreProvider.scala | 26 ++++++-- .../sql/execution/streaming/state/StateStore.scala | 37 ++++++++--- .../streaming/state/MemoryStateStore.scala | 10 ++- .../RocksDBStateStoreCheckpointFormatV2Suite.scala | 6 +- .../streaming/state/RocksDBStateStoreSuite.scala | 74 ++++++++++++++++++++++ .../streaming/state/StateStoreSuite.scala | 40 ++++++++++++ 11 files changed, 202 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala index 4e608a5d5dbb..b71d625b118e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala @@ -56,7 +56,10 @@ class MapStateImpl[K, V]( /** Whether state exists or not. 
*/ override def exists(): Boolean = { - store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty + val iter = store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName) + val result = iter.nonEmpty + iter.close() + result } /** Get the state value if it exists */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala index 6f6a9997b3ba..27c109f9de09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala @@ -199,7 +199,9 @@ class TimerStateImpl( } } - override protected def close(): Unit = { } + override protected def close(): Unit = { + iter.close() + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala index 64581006555e..aa4446af6da7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala @@ -128,7 +128,9 @@ metrics: Map[String, SQLMetric]) } } - override protected def close(): Unit = {} + override protected def close(): Unit = { + unsafeRowPairIterator.close() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index c362ac916384..0ba4b1955c82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -82,8 +82,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = map.get(key) - override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = { - map.iterator() + override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + val iter = map.iterator() + new StateStoreIterator(iter) } override def abort(): Unit = {} @@ -94,9 +95,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with s"HDFSReadStateStore[stateStoreId=$stateStoreId_, version=$version]" } - override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): - Iterator[UnsafeRowPair] = { - map.prefixScan(prefixKey) + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + val iter = map.prefixScan(prefixKey) + new StateStoreIterator(iter) } override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = { @@ -214,15 +217,18 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with * Get an iterator of all the store data. * This can be called only after committing all the updates made in the current thread. 
*/ - override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = { + override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { assertUseOfDefaultColFamily(colFamilyName) - mapToUpdate.iterator() + val iter = mapToUpdate.iterator() + new StateStoreIterator(iter) } - override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): - Iterator[UnsafeRowPair] = { + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { assertUseOfDefaultColFamily(colFamilyName) - mapToUpdate.prefixScan(prefixKey) + val iter = mapToUpdate.prefixScan(prefixKey) + new StateStoreIterator(iter) } override def metrics: StateStoreMetrics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 4365d131d088..85e2d72ec163 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -964,7 +964,7 @@ class RocksDB( /** * Get an iterator of all committed and uncommitted key-value pairs. */ - def iterator(): Iterator[ByteArrayPair] = { + def iterator(): NextIterator[ByteArrayPair] = { updateMemoryUsageIfNeeded() val iter = db.newIterator() logInfo(log"Getting iterator from version ${MDC(LogKeys.LOADED_VERSION, loadedVersion)}") @@ -1001,7 +1001,7 @@ class RocksDB( /** * Get an iterator of all committed and uncommitted key-value pairs for the given column family. 
*/ - def iterator(cfName: String): Iterator[ByteArrayPair] = { + def iterator(cfName: String): NextIterator[ByteArrayPair] = { updateMemoryUsageIfNeeded() if (!useColumnFamilies) { iterator() @@ -1051,7 +1051,7 @@ class RocksDB( def prefixScan( prefix: Array[Byte], - cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[ByteArrayPair] = { + cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): NextIterator[ByteArrayPair] = { updateMemoryUsageIfNeeded() val iter = db.newIterator() val updatedPrefix = if (useColumnFamilies) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 36480691a516..6bc3dd568af7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -315,7 +315,7 @@ private[sql] class RocksDBStateStoreProvider rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName) } - override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = { + override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) // Note this verify function only verify on the colFamilyName being valid, // we are actually doing prefix when useColumnFamilies, @@ -323,9 +323,10 @@ private[sql] class RocksDBStateStoreProvider verifyColFamilyOperations("iterator", colFamilyName) val kvEncoder = keyValueEncoderMap.get(colFamilyName) val rowPair = new UnsafeRowPair() - if (useColumnFamilies) { - rocksDB.iterator(colFamilyName).map { kv => + val rocksDbIter = rocksDB.iterator(colFamilyName) + + val iter = rocksDbIter.map { kv => rowPair.withRows(kvEncoder._1.decodeKey(kv.key), kvEncoder._2.decodeValue(kv.value)) if (!isValidated && rowPair.value != null && !useColumnFamilies) { @@ 
-335,8 +336,12 @@ private[sql] class RocksDBStateStoreProvider } rowPair } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } else { - rocksDB.iterator().map { kv => + val rocksDbIter = rocksDB.iterator() + + val iter = rocksDbIter.map { kv => rowPair.withRows(kvEncoder._1.decodeKey(kv.key), kvEncoder._2.decodeValue(kv.value)) if (!isValidated && rowPair.value != null && !useColumnFamilies) { @@ -346,11 +351,14 @@ private[sql] class RocksDBStateStoreProvider } rowPair } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } } - override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): - Iterator[UnsafeRowPair] = { + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { validateAndTransitionState(UPDATE) verifyColFamilyOperations("prefixScan", colFamilyName) @@ -360,11 +368,15 @@ private[sql] class RocksDBStateStoreProvider val rowPair = new UnsafeRowPair() val prefix = kvEncoder._1.encodePrefixKey(prefixKey) - rocksDB.prefixScan(prefix, colFamilyName).map { kv => + + val rocksDbIter = rocksDB.prefixScan(prefix, colFamilyName) + val iter = rocksDbIter.map { kv => rowPair.withRows(kvEncoder._1.decodeKey(kv.key), kvEncoder._2.decodeValue(kv.value)) rowPair } + + new StateStoreIterator(iter, rocksDbIter.closeIfNeeded) } var checkpointInfo: Option[StateStoreCheckpointInfo] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index af0e7069eeef..2f3c05b72388 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming.state +import java.io.Closeable import java.util.UUID import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledFuture, 
TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -44,6 +45,25 @@ import org.apache.spark.sql.execution.streaming.state.MaintenanceTaskType._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, ThreadUtils, Utils} +/** + * Represents an iterator that provides additional functionalities for state store use cases. + * + * `close()` is useful for freeing underlying iterator resources when the iterator is no longer + * needed. + * + * The caller MUST call `close()` on the iterator if it was not fully consumed, and it is no + * longer needed. + */ +class StateStoreIterator[A]( + val iter: Iterator[A], + val onClose: () => Unit = () => {}) extends Iterator[A] with Closeable { + override def hasNext: Boolean = iter.hasNext + + override def next(): A = iter.next() + + override def close(): Unit = onClose() +} + sealed trait StateStoreEncoding { override def toString: String = this match { case StateStoreEncoding.UnsafeRow => "unsaferow" @@ -117,10 +137,11 @@ trait ReadStateStore { */ def prefixScan( prefixKey: UnsafeRow, - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] /** Return an iterator containing all the key-value pairs in the StateStore. */ - def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] + def iterator( + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): StateStoreIterator[UnsafeRowPair] /** * Clean up the resource. @@ -227,8 +248,8 @@ trait StateStore extends ReadStateStore { * performed after initialization of the iterator. Callers should perform all updates before * calling this method if all updates should be visible in the returned iterator. 
*/ - override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): - Iterator[UnsafeRowPair] + override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] /** Current metrics of the state store */ def metrics: StateStoreMetrics @@ -260,16 +281,16 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore { colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = store.get(key, colFamilyName) - override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): - Iterator[UnsafeRowPair] = store.iterator(colFamilyName) + override def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = store.iterator(colFamilyName) override def abort(): Unit = store.abort() override def release(): Unit = store.release() override def prefixScan(prefixKey: UnsafeRow, - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = - store.prefixScan(prefixKey, colFamilyName) + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = store.prefixScan(prefixKey, colFamilyName) override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = { store.valuesIterator(key, colFamilyName) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 5e74c3e1b1c1..931b00abc17c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -26,8 +26,10 @@ class MemoryStateStore extends StateStore() { import scala.jdk.CollectionConverters._ private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow] - override def iterator(colFamilyName: 
String): Iterator[UnsafeRowPair] = { - map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } + override def iterator(colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { + val iter = + map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) } + new StateStoreIterator(iter) } override def createColFamilyIfAbsent( @@ -66,7 +68,9 @@ class MemoryStateStore extends StateStore() { override def hasCommitted: Boolean = true - override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): Iterator[UnsafeRowPair] = { + override def prefixScan( + prefixKey: UnsafeRow, + colFamilyName: String): StateStoreIterator[UnsafeRowPair] = { throw new UnsupportedOperationException("Doesn't support prefix scan!") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index ace8c4db6ff1..91117abf830e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -77,12 +77,14 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: StateStore) extends Sta override def prefixScan( prefixKey: UnsafeRow, - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { innerStore.prefixScan(prefixKey, colFamilyName) } override def iterator( - colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] = { + colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME) + : StateStoreIterator[UnsafeRowPair] = { innerStore.iterator(colFamilyName) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala index 1fb87de63fd6..e1f48441c4db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala @@ -1650,6 +1650,80 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid } } + testWithColumnFamiliesAndEncodingTypes( + "closing the iterator also closes the underlying rocksdb iterator", + TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled => + + // use the same schema as value schema for single col key schema + tryWithProviderResource(newStoreProvider(valueSchema, + RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) { provider => + val store = provider.getStore(0) + try { + val cfName = if (colFamiliesEnabled) "testColFamily" else "default" + if (colFamiliesEnabled) { + store.createColFamilyIfAbsent(cfName, + valueSchema, valueSchema, + RangeKeyScanStateEncoderSpec(valueSchema, Seq(0))) + } + + val timerTimestamps = Seq(1, 2, 3, 22) + timerTimestamps.foreach { ts => + val keyRow = dataToValueRow(ts) + val valueRow = dataToValueRow(1) + store.put(keyRow, valueRow, cfName) + assert(valueRowToData(store.get(keyRow, cfName)) === 1) + } + + val iter1 = store.iterator(cfName) + for (i <- 1 to 4) { + assert(iter1.hasNext) + iter1.next() + } + // We were fully able to process the 4 elements + assert(!iter1.hasNext) + + val iter2 = store.iterator(cfName) + for (i <- 1 to 2) { + assert(iter2.hasNext) + iter2.next() + } + // Close the iterator + iter2.close() + // After closing, this will call AbstractRocksIterator.isValid which should throw and + // exception since it no longer owns the underlying rocksdb iterator + val exception1 = 
intercept[AssertionError] { + iter2.next() + } + // Check that the exception is thrown from AbstractRocksIterator.isValid + assert(exception1.getStackTrace()(0).getClassName.contains("AbstractRocksIterator")) + assert(exception1.getStackTrace()(0).getMethodName.contains("isValid")) + + // also check for prefix scan + val prefix = dataToValueRow(2) + val iter3 = store.prefixScan(prefix, cfName) + + iter3.next() + assert(!iter3.hasNext) + + val iter4 = store.prefixScan(prefix, cfName) + // Immediately close the iterator without calling next + iter4.close() + + // Since we closed the iterator, this will throw an exception when we try to call next + val exception2 = intercept[AssertionError] { + iter4.next() + } + // Check that the exception is thrown from AbstractRocksIterator.isValid + assert(exception2.getStackTrace()(0).getClassName.contains("AbstractRocksIterator")) + assert(exception2.getStackTrace()(0).getMethodName.contains("isValid")) + + store.commit() + } finally { + if (!store.hasCommitted) store.abort() + } + } + } + test("validate rocksdb values iterator correctness") { withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") { tryWithProviderResource(newStoreProvider(useColumnFamilies = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index a15462b4baa6..6e795f236c8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -2172,6 +2172,46 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] assert(combinedMetrics.customMetrics(customTimingMetric) == 400L) } + test("StateStoreIterator onClose method is called only when close() is called") { + // Test that the iterator functions as normal without closing + { + var 
closed = false + + val iterator = new StateStoreIterator(Iterator(1, 2, 3, 4), () => { + closed = true + }) + + // next() should work as expected + for (i <- 1 to 4) { + assert(iterator.next() == i) + } + + // close() is never called, so closed should remain false + assert(!closed) + } + // Test that the onClose method is called when close() is called + { + var closed = false + + val iterator = new StateStoreIterator(Iterator(1, 2, 3, 4), () => { + closed = true + }) + + // next() should work as expected + assert(iterator.next() == 1) + assert(iterator.next() == 2) + + // close() should call the onClose function which sets closed to true + assert(!closed) + iterator.close() + assert(closed) + + // Calling close() again should not cause any issue + iterator.close() + assert(closed) + } + } + test("SPARK-35659: StateStore.put cannot put null value") { tryWithProviderResource(newStoreProvider()) { provider => // Verify state before starting a new set of updates --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org