This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d53de53fff55 [SPARK-52989][SS] Add explicit close() API to State Store iterators
d53de53fff55 is described below

commit d53de53fff554d6d6eda7113dae91fbd75840ebb
Author: Dylan Wong <dylan.w...@databricks.com>
AuthorDate: Wed Aug 6 10:15:16 2025 +0900

    [SPARK-52989][SS] Add explicit close() API to State Store iterators
    
    ### What changes were proposed in this pull request?
    
    Add an explicit `close()` API to State Store iterators. This PR changes the
    `ReadStateStore` trait's `prefixScan` and `iterator` methods to return
    `StateStoreIterator[UnsafeRowPair]` instead of `Iterator[UnsafeRowPair]`. This new
    type exposes a `close()` method.
    
    The `exists()` method of MapStateImpl is also changed to explicitly close the
    iterator once it is no longer needed.
    
    Additionally, `close()` calls are added to the iterators in TimerStateImpl and
    MapStateImplWithTTL that consume state store iterators.
    
    ### Why are the changes needed?
    
    These changes expose the `close()` method on state store iterators. Callers of
    `StateStoreIterator` can now explicitly close the iterator and release its
    underlying resources as soon as it is no longer needed. Previously, those resources
    were held until either all rows had been consumed and `close()` was called, or the
    task completion/failure listener called `close()` on the iterators.
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Existing unit tests, new tests for the `StateStoreIterator` wrapper class, and a
    new test verifying that `close()` closes the underlying RocksDB iterator.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51701 from dylanwong250/SPARK-52989.
    
    Lead-authored-by: Dylan Wong <dylan.w...@databricks.com>
    Co-authored-by: dylanwong250 <dylanwong...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../statevariables/MapStateImpl.scala              |  5 +-
 .../transformwithstate/timers/TimerStateImpl.scala |  4 +-
 .../ttl/MapStateImplWithTTL.scala                  |  4 +-
 .../state/HDFSBackedStateStoreProvider.scala       | 26 +++++---
 .../sql/execution/streaming/state/RocksDB.scala    |  6 +-
 .../state/RocksDBStateStoreProvider.scala          | 26 ++++++--
 .../sql/execution/streaming/state/StateStore.scala | 37 ++++++++---
 .../streaming/state/MemoryStateStore.scala         | 10 ++-
 .../RocksDBStateStoreCheckpointFormatV2Suite.scala |  6 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   | 74 ++++++++++++++++++++++
 .../streaming/state/StateStoreSuite.scala          | 40 ++++++++++++
 11 files changed, 202 insertions(+), 36 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala
index 4e608a5d5dbb..b71d625b118e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/statevariables/MapStateImpl.scala
@@ -56,7 +56,10 @@ class MapStateImpl[K, V](
 
   /** Whether state exists or not. */
   override def exists(): Boolean = {
-    store.prefixScan(stateTypesEncoder.encodeGroupingKey(), stateName).nonEmpty
+    val iter = store.prefixScan(stateTypesEncoder.encodeGroupingKey(), 
stateName)
+    val result = iter.nonEmpty
+    iter.close()
+    result
   }
 
   /** Get the state value if it exists */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
index 6f6a9997b3ba..27c109f9de09 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala
@@ -199,7 +199,9 @@ class TimerStateImpl(
         }
       }
 
-      override protected def close(): Unit = { }
+      override protected def close(): Unit = {
+        iter.close()
+      }
     }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala
index 64581006555e..aa4446af6da7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/ttl/MapStateImplWithTTL.scala
@@ -128,7 +128,9 @@ metrics: Map[String, SQLMetric])
         }
       }
 
-      override protected def close(): Unit = {}
+      override protected def close(): Unit = {
+        unsafeRowPairIterator.close()
+      }
     }
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index c362ac916384..0ba4b1955c82 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -82,8 +82,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
 
     override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = 
map.get(key)
 
-    override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
-      map.iterator()
+    override def iterator(colFamilyName: String): 
StateStoreIterator[UnsafeRowPair] = {
+      val iter = map.iterator()
+      new StateStoreIterator(iter)
     }
 
     override def abort(): Unit = {}
@@ -94,9 +95,11 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       s"HDFSReadStateStore[stateStoreId=$stateStoreId_, version=$version]"
     }
 
-    override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-      Iterator[UnsafeRowPair] = {
-      map.prefixScan(prefixKey)
+    override def prefixScan(
+        prefixKey: UnsafeRow,
+        colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
+      val iter = map.prefixScan(prefixKey)
+      new StateStoreIterator(iter)
     }
 
     override def valuesIterator(key: UnsafeRow, colFamilyName: String): 
Iterator[UnsafeRow] = {
@@ -214,15 +217,18 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
      * Get an iterator of all the store data.
      * This can be called only after committing all the updates made in the 
current thread.
      */
-    override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
+    override def iterator(colFamilyName: String): 
StateStoreIterator[UnsafeRowPair] = {
       assertUseOfDefaultColFamily(colFamilyName)
-      mapToUpdate.iterator()
+      val iter = mapToUpdate.iterator()
+      new StateStoreIterator(iter)
     }
 
-    override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-      Iterator[UnsafeRowPair] = {
+    override def prefixScan(
+        prefixKey: UnsafeRow,
+        colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
       assertUseOfDefaultColFamily(colFamilyName)
-      mapToUpdate.prefixScan(prefixKey)
+      val iter = mapToUpdate.prefixScan(prefixKey)
+      new StateStoreIterator(iter)
     }
 
     override def metrics: StateStoreMetrics = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 4365d131d088..85e2d72ec163 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -964,7 +964,7 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
-  def iterator(): Iterator[ByteArrayPair] = {
+  def iterator(): NextIterator[ByteArrayPair] = {
     updateMemoryUsageIfNeeded()
     val iter = db.newIterator()
     logInfo(log"Getting iterator from version ${MDC(LogKeys.LOADED_VERSION, 
loadedVersion)}")
@@ -1001,7 +1001,7 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs for the 
given column family.
    */
-  def iterator(cfName: String): Iterator[ByteArrayPair] = {
+  def iterator(cfName: String): NextIterator[ByteArrayPair] = {
     updateMemoryUsageIfNeeded()
     if (!useColumnFamilies) {
       iterator()
@@ -1051,7 +1051,7 @@ class RocksDB(
 
   def prefixScan(
       prefix: Array[Byte],
-      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
Iterator[ByteArrayPair] = {
+      cfName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
NextIterator[ByteArrayPair] = {
     updateMemoryUsageIfNeeded()
     val iter = db.newIterator()
     val updatedPrefix = if (useColumnFamilies) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 36480691a516..6bc3dd568af7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -315,7 +315,7 @@ private[sql] class RocksDBStateStoreProvider
       rocksDB.remove(kvEncoder._1.encodeKey(key), colFamilyName)
     }
 
-    override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
+    override def iterator(colFamilyName: String): 
StateStoreIterator[UnsafeRowPair] = {
       validateAndTransitionState(UPDATE)
       // Note this verify function only verify on the colFamilyName being 
valid,
       // we are actually doing prefix when useColumnFamilies,
@@ -323,9 +323,10 @@ private[sql] class RocksDBStateStoreProvider
       verifyColFamilyOperations("iterator", colFamilyName)
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       val rowPair = new UnsafeRowPair()
-
       if (useColumnFamilies) {
-        rocksDB.iterator(colFamilyName).map { kv =>
+        val rocksDbIter = rocksDB.iterator(colFamilyName)
+
+        val iter = rocksDbIter.map { kv =>
           rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
             kvEncoder._2.decodeValue(kv.value))
           if (!isValidated && rowPair.value != null && !useColumnFamilies) {
@@ -335,8 +336,12 @@ private[sql] class RocksDBStateStoreProvider
           }
           rowPair
         }
+
+        new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
       } else {
-        rocksDB.iterator().map { kv =>
+        val rocksDbIter = rocksDB.iterator()
+
+        val iter = rocksDbIter.map { kv =>
           rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
             kvEncoder._2.decodeValue(kv.value))
           if (!isValidated && rowPair.value != null && !useColumnFamilies) {
@@ -346,11 +351,14 @@ private[sql] class RocksDBStateStoreProvider
           }
           rowPair
         }
+
+        new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
       }
     }
 
-    override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
-      Iterator[UnsafeRowPair] = {
+    override def prefixScan(
+        prefixKey: UnsafeRow,
+        colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
       validateAndTransitionState(UPDATE)
       verifyColFamilyOperations("prefixScan", colFamilyName)
 
@@ -360,11 +368,15 @@ private[sql] class RocksDBStateStoreProvider
 
       val rowPair = new UnsafeRowPair()
       val prefix = kvEncoder._1.encodePrefixKey(prefixKey)
-      rocksDB.prefixScan(prefix, colFamilyName).map { kv =>
+
+      val rocksDbIter = rocksDB.prefixScan(prefix, colFamilyName)
+      val iter = rocksDbIter.map { kv =>
         rowPair.withRows(kvEncoder._1.decodeKey(kv.key),
           kvEncoder._2.decodeValue(kv.value))
         rowPair
       }
+
+      new StateStoreIterator(iter, rocksDbIter.closeIfNeeded)
     }
 
     var checkpointInfo: Option[StateStoreCheckpointInfo] = None
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index af0e7069eeef..2f3c05b72388 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.io.Closeable
 import java.util.UUID
 import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
@@ -44,6 +45,25 @@ import 
org.apache.spark.sql.execution.streaming.state.MaintenanceTaskType._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, ThreadUtils, Utils}
 
+/**
+ * Represents an iterator that provides additional functionalities for state 
store use cases.
+ *
+ * `close()` is useful for freeing underlying iterator resources when the 
iterator is no longer
+ * needed.
+ *
+ * The caller MUST call `close()` on the iterator if it was not fully 
consumed, and it is no
+ * longer needed.
+ */
+class StateStoreIterator[A](
+    val iter: Iterator[A],
+    val onClose: () => Unit = () => {}) extends Iterator[A] with Closeable {
+  override def hasNext: Boolean = iter.hasNext
+
+  override def next(): A = iter.next()
+
+  override def close(): Unit = onClose()
+}
+
 sealed trait StateStoreEncoding {
   override def toString: String = this match {
     case StateStoreEncoding.UnsafeRow => "unsaferow"
@@ -117,10 +137,11 @@ trait ReadStateStore {
    */
   def prefixScan(
       prefixKey: UnsafeRow,
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
Iterator[UnsafeRowPair]
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
StateStoreIterator[UnsafeRowPair]
 
   /** Return an iterator containing all the key-value pairs in the StateStore. 
*/
-  def iterator(colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
Iterator[UnsafeRowPair]
+  def iterator(
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
StateStoreIterator[UnsafeRowPair]
 
   /**
    * Clean up the resource.
@@ -227,8 +248,8 @@ trait StateStore extends ReadStateStore {
    * performed after initialization of the iterator. Callers should perform 
all updates before
    * calling this method if all updates should be visible in the returned 
iterator.
    */
-  override def iterator(colFamilyName: String = 
StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[UnsafeRowPair]
+  override def iterator(colFamilyName: String = 
StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair]
 
   /** Current metrics of the state store */
   def metrics: StateStoreMetrics
@@ -260,16 +281,16 @@ class WrappedReadStateStore(store: StateStore) extends 
ReadStateStore {
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow = 
store.get(key,
     colFamilyName)
 
-  override def iterator(colFamilyName: String = 
StateStore.DEFAULT_COL_FAMILY_NAME):
-    Iterator[UnsafeRowPair] = store.iterator(colFamilyName)
+  override def iterator(colFamilyName: String = 
StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = store.iterator(colFamilyName)
 
   override def abort(): Unit = store.abort()
 
   override def release(): Unit = store.release()
 
   override def prefixScan(prefixKey: UnsafeRow,
-    colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
Iterator[UnsafeRowPair] =
-    store.prefixScan(prefixKey, colFamilyName)
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = store.prefixScan(prefixKey, 
colFamilyName)
 
   override def valuesIterator(key: UnsafeRow, colFamilyName: String): 
Iterator[UnsafeRow] = {
     store.valuesIterator(key, colFamilyName)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
index 5e74c3e1b1c1..931b00abc17c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -26,8 +26,10 @@ class MemoryStateStore extends StateStore() {
   import scala.jdk.CollectionConverters._
   private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-  override def iterator(colFamilyName: String): Iterator[UnsafeRowPair] = {
-    map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, 
e.getValue) }
+  override def iterator(colFamilyName: String): 
StateStoreIterator[UnsafeRowPair] = {
+    val iter =
+      map.entrySet.iterator.asScala.map { case e => new 
UnsafeRowPair(e.getKey, e.getValue) }
+    new StateStoreIterator(iter)
   }
 
   override def createColFamilyIfAbsent(
@@ -66,7 +68,9 @@ class MemoryStateStore extends StateStore() {
 
   override def hasCommitted: Boolean = true
 
-  override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): 
Iterator[UnsafeRowPair] = {
+  override def prefixScan(
+      prefixKey: UnsafeRow,
+      colFamilyName: String): StateStoreIterator[UnsafeRowPair] = {
     throw new UnsupportedOperationException("Doesn't support prefix scan!")
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
index ace8c4db6ff1..91117abf830e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
@@ -77,12 +77,14 @@ case class CkptIdCollectingStateStoreWrapper(innerStore: 
StateStore) extends Sta
 
   override def prefixScan(
       prefixKey: UnsafeRow,
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
Iterator[UnsafeRowPair] = {
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = {
     innerStore.prefixScan(prefixKey, colFamilyName)
   }
 
   override def iterator(
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): 
Iterator[UnsafeRowPair] = {
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME)
+    : StateStoreIterator[UnsafeRowPair] = {
     innerStore.iterator(colFamilyName)
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 1fb87de63fd6..e1f48441c4db 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -1650,6 +1650,80 @@ class RocksDBStateStoreSuite extends 
StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  testWithColumnFamiliesAndEncodingTypes(
+    "closing the iterator also closes the underlying rocksdb iterator",
+    TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled 
=>
+
+    // use the same schema as value schema for single col key schema
+    tryWithProviderResource(newStoreProvider(valueSchema,
+      RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)), colFamiliesEnabled)) 
{ provider =>
+      val store = provider.getStore(0)
+      try {
+        val cfName = if (colFamiliesEnabled) "testColFamily" else "default"
+        if (colFamiliesEnabled) {
+          store.createColFamilyIfAbsent(cfName,
+            valueSchema, valueSchema,
+            RangeKeyScanStateEncoderSpec(valueSchema, Seq(0)))
+        }
+
+        val timerTimestamps = Seq(1, 2, 3, 22)
+        timerTimestamps.foreach { ts =>
+          val keyRow = dataToValueRow(ts)
+          val valueRow = dataToValueRow(1)
+          store.put(keyRow, valueRow, cfName)
+          assert(valueRowToData(store.get(keyRow, cfName)) === 1)
+        }
+
+        val iter1 = store.iterator(cfName)
+        for (i <- 1 to 4) {
+          assert(iter1.hasNext)
+          iter1.next()
+        }
+        // We were fully able to process the 4 elements
+        assert(!iter1.hasNext)
+
+        val iter2 = store.iterator(cfName)
+        for (i <- 1 to 2) {
+          assert(iter2.hasNext)
+          iter2.next()
+        }
+        // Close the iterator
+        iter2.close()
+        // After closing, this will call AbstractRocksIterator.isValid which 
should throw and
+        // exception since it no longer owns the underlying rocksdb iterator
+        val exception1 = intercept[AssertionError] {
+          iter2.next()
+        }
+        // Check that the exception is thrown from 
AbstractRocksIterator.isValid
+        
assert(exception1.getStackTrace()(0).getClassName.contains("AbstractRocksIterator"))
+        assert(exception1.getStackTrace()(0).getMethodName.contains("isValid"))
+
+        // also check for prefix scan
+        val prefix = dataToValueRow(2)
+        val iter3 = store.prefixScan(prefix, cfName)
+
+        iter3.next()
+        assert(!iter3.hasNext)
+
+        val iter4 = store.prefixScan(prefix, cfName)
+        // Immediately close the iterator without calling next
+        iter4.close()
+
+        // Since we closed the iterator, this will throw an exception when we 
try to call next
+        val exception2 = intercept[AssertionError] {
+          iter4.next()
+        }
+        // Check that the exception is thrown from 
AbstractRocksIterator.isValid
+        
assert(exception2.getStackTrace()(0).getClassName.contains("AbstractRocksIterator"))
+        assert(exception2.getStackTrace()(0).getMethodName.contains("isValid"))
+
+        store.commit()
+      } finally {
+        if (!store.hasCommitted) store.abort()
+      }
+    }
+  }
+
   test("validate rocksdb values iterator correctness") {
     withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
       tryWithProviderResource(newStoreProvider(useColumnFamilies = true,
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index a15462b4baa6..6e795f236c8b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -2172,6 +2172,46 @@ abstract class StateStoreSuiteBase[ProviderClass <: 
StateStoreProvider]
     assert(combinedMetrics.customMetrics(customTimingMetric) == 400L)
   }
 
+  test("StateStoreIterator onClose method is called only when close() is 
called") {
+    // Test that the iterator functions as normal without closing
+    {
+      var closed = false
+
+      val iterator = new StateStoreIterator(Iterator(1, 2, 3, 4), () => {
+        closed = true
+      })
+
+      // next() should work as expected
+      for (i <- 1 to 4) {
+        assert(iterator.next() == i)
+      }
+
+      // close() is never called, so closed should remain false
+      assert(!closed)
+    }
+    // Test that the onClose method is called when close() is called
+    {
+      var closed = false
+
+      val iterator = new StateStoreIterator(Iterator(1, 2, 3, 4), () => {
+        closed = true
+      })
+
+      // next() should work as expected
+      assert(iterator.next() == 1)
+      assert(iterator.next() == 2)
+
+      // close() should call the onClose function which sets closed to true
+      assert(!closed)
+      iterator.close()
+      assert(closed)
+
+      // Calling close() again should not cause any issue
+      iterator.close()
+      assert(closed)
+    }
+  }
+
   test("SPARK-35659: StateStore.put cannot put null value") {
     tryWithProviderResource(newStoreProvider()) { provider =>
       // Verify state before starting a new set of updates


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to