This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 51494a7f1216 [SPARK-51596][SS] Fix concurrent StateStoreProvider 
maintenance and closing
51494a7f1216 is described below

commit 51494a7f1216e56e0b88b00c42ec3f47f4343a45
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Fri Jul 18 15:18:38 2025 -0700

    [SPARK-51596][SS] Fix concurrent StateStoreProvider maintenance and closing
    
    ### What changes were proposed in this pull request?
    Moves the unload operation away from the task thread and into the 
maintenance thread. To ensure unloading still occurs ASAP (rather than 
potentially waiting for the next maintenance interval), as was introduced by 
https://issues.apache.org/jira/browse/SPARK-33827, we immediately submit work 
to the maintenance thread pool to do the unload.
    
    This has the added benefit that unloading other providers no longer blocks 
the task thread. To capitalize on this, unload() should not hold the 
loadedProviders lock the entire time (which would block other task threads); 
instead, it should release the lock once it has removed the providers being 
unloaded from the map, and then close those providers without the lock held.
    
    ### Why are the changes needed?
    Currently, both the task thread and the maintenance thread can call 
unload() on a provider. This leads to a race condition where the maintenance 
thread could be running maintenance on a provider while the task thread is 
closing it, leading to unexpected behavior.
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    Added unit test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #51565 from ericm-db/maint-changes.
    
    Authored-by: Eric Marnadi <eric.marn...@databricks.com>
    Signed-off-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 +
 .../sql/execution/streaming/state/StateStore.scala | 260 +++++++++++---
 .../execution/streaming/state/StateStoreConf.scala |   2 +
 .../streaming/state/StateStoreSuite.scala          | 388 ++++++++++++++++++++-
 4 files changed, 607 insertions(+), 53 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index e076f4ede051..d46edef90c2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2436,6 +2436,13 @@ object SQLConf {
       .timeConf(TimeUnit.SECONDS)
       .createWithDefault(300L)
 
+  val STATE_STORE_MAINTENANCE_PROCESSING_TIMEOUT =
+    buildConf("spark.sql.streaming.stateStore.maintenanceProcessingTimeout")
+      .internal()
+      .doc("Timeout in seconds to wait for maintenance to process this 
partition.")
+      .timeConf(TimeUnit.SECONDS)
+      .createWithDefault(30L)
+
   val STATE_SCHEMA_CHECK_ENABLED =
     buildConf("spark.sql.streaming.stateStore.stateSchemaCheck")
       .doc("When true, Spark will validate the state schema against schema on 
existing state and " +
@@ -6343,6 +6350,9 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def stateStoreMaintenanceShutdownTimeout: Long = 
getConf(STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT)
 
+  def stateStoreMaintenanceProcessingTimeout: Long =
+    getConf(STATE_STORE_MAINTENANCE_PROCESSING_TIMEOUT)
+
   def stateStoreMinDeltasForSnapshot: Int = 
getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
   def stateStoreFormatValidationEnabled: Boolean = 
getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 6e4befc25e6e..f4b36f2a1acb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.util.UUID
-import java.util.concurrent.{ScheduledFuture, TimeUnit}
+import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
@@ -31,13 +32,14 @@ import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods.{compact, render}
 
-import org.apache.spark.{SparkContext, SparkEnv, SparkException}
+import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.streaming.{StatefulOperatorStateInfo, 
StreamExecution}
+import org.apache.spark.sql.execution.streaming.state.MaintenanceTaskType._
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{NextIterator, ThreadUtils, Utils}
 
@@ -53,6 +55,14 @@ object StateStoreEncoding {
   case object Avro extends StateStoreEncoding
 }
 
+sealed trait MaintenanceTaskType
+
+object MaintenanceTaskType {
+  case object FromUnloadedProvidersQueue extends MaintenanceTaskType
+  case object FromTaskThread extends MaintenanceTaskType
+  case object FromLoadedProviders extends MaintenanceTaskType
+}
+
 /**
  * Base trait for a versioned key-value store which provides read operations. 
Each instance of a
  * `ReadStateStore` represents a specific version of state data, and such 
instances are created
@@ -554,7 +564,11 @@ trait StateStoreProvider {
    */
   def stateStoreId: StateStoreId
 
-  /** Called when the provider instance is unloaded from the executor */
+  /**
+   * Called when the provider instance is unloaded from the executor
+   * WARNING: IF PROVIDER FROM [[StateStore.loadedProviders]],
+   * CLOSE MUST ONLY BE CALLED FROM MAINTENANCE THREAD!
+   */
   def close(): Unit
 
   /**
@@ -843,6 +857,9 @@ object StateStore extends Logging {
 
   private val maintenanceThreadPoolLock = new Object
 
+  private val unloadedProvidersToClose =
+    new ConcurrentLinkedQueue[(StateStoreProviderId, StateStoreProvider)]
+
   // This set is to keep track of the partitions that are queued
   // for maintenance or currently have maintenance running on them
   // to prevent the same partition from being processed concurrently.
@@ -1012,7 +1029,21 @@ object StateStore extends Logging {
       if (!storeConf.unloadOnCommit) {
         val otherProviderIds = loadedProviders.keys.filter(_ != 
storeProviderId).toSeq
         val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, 
otherProviderIds)
-        providerIdsToUnload.foreach(unload(_))
+        val taskContextIdLogLine = Option(TaskContext.get()).map { tc =>
+          log"taskId=${MDC(LogKeys.TASK_ID, tc.taskAttemptId())}"
+        }.getOrElse(log"")
+        providerIdsToUnload.foreach(id => {
+          loadedProviders.remove(id).foreach( provider => {
+            // Trigger maintenance thread to immediately do maintenance on and 
close the provider.
+            // Doing maintenance first allows us to do maintenance for a 
constantly-moving state
+            // store.
+            logInfo(log"Submitted maintenance from task thread to close " +
+              log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}." + 
taskContextIdLogLine +
+              log"Removed provider from loadedProviders")
+            submitMaintenanceWorkForProvider(
+              id, provider, storeConf, MaintenanceTaskType.FromTaskThread)
+          })
+        })
       }
 
       provider
@@ -1029,14 +1060,30 @@ object StateStore extends Logging {
     }
   }
 
-  /** Unload a state store provider */
-  def unload(storeProviderId: StateStoreProviderId): Unit = 
loadedProviders.synchronized {
-    loadedProviders.remove(storeProviderId).foreach(_.close())
+  /**
+   * Unload a state store provider.
+   * If alreadyRemovedProvider is None, the provider will be
+   * removed from loadedProviders and closed.
+   * If alreadyRemovedProvider is Some, the passed-in provider will be
+   * closed directly.
+   * WARNING: CAN ONLY BE CALLED FROM MAINTENANCE THREAD!
+   */
+  def removeFromLoadedProvidersAndClose(
+      storeProviderId: StateStoreProviderId,
+      alreadyRemovedProvider: Option[StateStoreProvider] = None): Unit = {
+    val providerToClose = alreadyRemovedProvider.orElse {
+      loadedProviders.synchronized {
+        loadedProviders.remove(storeProviderId)
+      }
+    }
+    providerToClose.foreach { provider =>
+      provider.close()
+    }
   }
 
   /** Unload all state store providers: unit test purpose */
   private[sql] def unloadAll(): Unit = loadedProviders.synchronized {
-    loadedProviders.keySet.foreach { key => unload(key) }
+    loadedProviders.keySet.foreach { key => 
removeFromLoadedProvidersAndClose(key) }
     loadedProviders.clear()
   }
 
@@ -1075,7 +1122,7 @@ object StateStore extends Logging {
 
   /** Unload and stop all state store providers */
   def stop(): Unit = loadedProviders.synchronized {
-    loadedProviders.keySet.foreach { key => unload(key) }
+    loadedProviders.keySet.foreach { key => 
removeFromLoadedProvidersAndClose(key) }
     loadedProviders.clear()
     _coordRef = null
     stopMaintenanceTask()
@@ -1090,7 +1137,7 @@ object StateStore extends Logging {
       if (SparkEnv.get != null && !isMaintenanceRunning && 
!storeConf.unloadOnCommit) {
         maintenanceTask = new MaintenanceTask(
           storeConf.maintenanceInterval,
-          task = { doMaintenance() }
+          task = { doMaintenance(storeConf) }
         )
         maintenanceThreadPool = new 
MaintenanceThreadPool(numMaintenanceThreads,
           maintenanceShutdownTimeout)
@@ -1099,6 +1146,27 @@ object StateStore extends Logging {
     }
   }
 
+  // Wait until this partition can be processed
+  private def awaitProcessThisPartition(
+      id: StateStoreProviderId,
+      timeoutMs: Long): Boolean = maintenanceThreadPoolLock synchronized  {
+    val startTime = System.currentTimeMillis()
+    val endTime = startTime + timeoutMs
+
+    // If immediate processing fails, wait with timeout
+    var canProcessThisPartition = processThisPartition(id)
+    while (!canProcessThisPartition && System.currentTimeMillis() < endTime) {
+      maintenanceThreadPoolLock.wait(timeoutMs)
+      canProcessThisPartition = processThisPartition(id)
+    }
+    val elapsedTime = System.currentTimeMillis() - startTime
+    logInfo(log"Waited for ${MDC(LogKeys.TOTAL_TIME, elapsedTime)} ms to be 
able to process " +
+      log"maintenance for partition ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, 
id)}")
+    canProcessThisPartition
+  }
+
+  private def doMaintenance(): Unit = doMaintenance(StateStoreConf.empty)
+
   private def processThisPartition(id: StateStoreProviderId): Boolean = {
     maintenanceThreadPoolLock.synchronized {
       if (!maintenancePartitions.contains(id)) {
@@ -1114,56 +1182,42 @@ object StateStore extends Logging {
    * Execute background maintenance task in all the loaded store providers if 
they are still
    * the active instances according to the coordinator.
    */
-  private def doMaintenance(): Unit = {
+  private def doMaintenance(storeConf: StateStoreConf): Unit = {
     logDebug("Doing maintenance")
     if (SparkEnv.get == null) {
       throw new IllegalStateException("SparkEnv not active, cannot do 
maintenance on StateStores")
     }
+
+    // Providers that couldn't be processed now and need to be added back to 
the queue
+    val providersToRequeue = new ArrayBuffer[(StateStoreProviderId, 
StateStoreProvider)]()
+
+    // unloadedProvidersToClose are StateStoreProviders that have been removed 
from
+    // loadedProviders, and can now be processed for maintenance. This queue 
contains
+    // providers for which we weren't able to process for maintenance on the 
previous iteration
+    while (!unloadedProvidersToClose.isEmpty) {
+      val (providerId, provider) = unloadedProvidersToClose.poll()
+
+      if (processThisPartition(providerId)) {
+        submitMaintenanceWorkForProvider(
+          providerId, provider, storeConf, 
MaintenanceTaskType.FromUnloadedProvidersQueue)
+      } else {
+        providersToRequeue += ((providerId, provider))
+      }
+    }
+
+    if (providersToRequeue.nonEmpty) {
+      logInfo(log"Had to requeue ${MDC(LogKeys.SIZE, providersToRequeue.size)} 
providers " +
+        log"for maintenance in doMaintenance")
+    }
+
+    providersToRequeue.foreach(unloadedProvidersToClose.offer)
+
     loadedProviders.synchronized {
       loadedProviders.toSeq
     }.foreach { case (id, provider) =>
       if (processThisPartition(id)) {
-        maintenanceThreadPool.execute(() => {
-          val startTime = System.currentTimeMillis()
-          try {
-            provider.doMaintenance()
-            if (!verifyIfStoreInstanceActive(id)) {
-              unload(id)
-              logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER, 
provider)}")
-            }
-          } catch {
-            case NonFatal(e) =>
-              logWarning(log"Error managing 
${MDC(LogKeys.STATE_STORE_PROVIDER, provider)}, " +
-                log"unloading state store provider", e)
-              // When we get a non-fatal exception, we just unload the 
provider.
-              //
-              // By not bubbling the exception to the maintenance task thread 
or the query execution
-              // thread, it's possible for a maintenance thread pool task to 
continue failing on
-              // the same partition. Additionally, if there is some global 
issue that will cause
-              // all maintenance thread pool tasks to fail, then bubbling the 
exception and
-              // stopping the pool is faster than waiting for all tasks to see 
the same exception.
-              //
-              // However, we assume that repeated failures on the same 
partition and global issues
-              // are rare. The benefit to unloading just the partition with an 
exception is that
-              // transient issues on a given provider do not affect any other 
providers; so, in
-              // most cases, this should be a more performant solution.
-              unload(id)
-          } finally {
-            val duration = System.currentTimeMillis() - startTime
-            val logMsg =
-              log"Finished maintenance task for " +
-                log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" +
-                log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}\n"
-            if (duration > 5000) {
-              logInfo(logMsg)
-            } else {
-              logDebug(logMsg)
-            }
-            maintenanceThreadPoolLock.synchronized {
-              maintenancePartitions.remove(id)
-            }
-          }
-        })
+        submitMaintenanceWorkForProvider(
+          id, provider, storeConf, MaintenanceTaskType.FromLoadedProviders)
       } else {
         logInfo(log"Not processing partition ${MDC(LogKeys.PARTITION_ID, id)} 
" +
           log"for maintenance because it is currently " +
@@ -1172,6 +1226,108 @@ object StateStore extends Logging {
     }
   }
 
+  /**
+   * Submits maintenance work for a provider to the maintenance thread pool.
+   *
+   * @param id The StateStore provider ID to perform maintenance on
+   * @param provider The StateStore provider instance
+   */
+  private def submitMaintenanceWorkForProvider(
+      id: StateStoreProviderId,
+      provider: StateStoreProvider,
+      storeConf: StateStoreConf,
+      source: MaintenanceTaskType = FromLoadedProviders): Unit = {
+    maintenanceThreadPool.execute(() => {
+      val startTime = System.currentTimeMillis()
+      // Determine if we can process this partition based on the source
+      val canProcessThisPartition = source match {
+        case FromTaskThread =>
+          // Provider from task thread needs to wait for lock
+          // We potentially need to wait for ongoing maintenance to finish 
processing
+          // this partition
+          val timeoutMs = storeConf.stateStoreMaintenanceProcessingTimeout * 
1000
+          val ableToProcessNow = awaitProcessThisPartition(id, timeoutMs)
+          if (!ableToProcessNow) {
+            // Add to queue for later processing if we can't process now
+            // This will be resubmitted for maintenance later by the 
background maintenance task
+            unloadedProvidersToClose.add((id, provider))
+          }
+          ableToProcessNow
+
+        case FromUnloadedProvidersQueue =>
+          // Provider from queue can be processed immediately
+          // (we've already removed it from loadedProviders)
+          true
+
+        case FromLoadedProviders =>
+          // Provider from loadedProviders can be processed immediately
+          // as it's in maintenancePartitions
+          true
+      }
+
+      if (canProcessThisPartition) {
+        val awaitingPartitionDuration = System.currentTimeMillis() - startTime
+        try {
+          provider.doMaintenance()
+          // Handle unloading based on source
+          source match {
+            case FromTaskThread | FromUnloadedProvidersQueue =>
+              // Provider already removed from loadedProviders, just close it
+              removeFromLoadedProvidersAndClose(id, Some(provider))
+
+            case FromLoadedProviders =>
+              // Check if provider should be unloaded
+              if (!verifyIfStoreInstanceActive(id)) {
+                removeFromLoadedProvidersAndClose(id)
+              }
+          }
+          logInfo(log"Unloaded ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}")
+        } catch {
+          case NonFatal(e) =>
+            logWarning(log"Error doing maintenance on provider:" +
+              log" ${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}. " +
+              log"Could not unload state store provider", e)
+            // When we get a non-fatal exception, we just unload the provider.
+            //
+            // By not bubbling the exception to the maintenance task thread or 
the query execution
+            // thread, it's possible for a maintenance thread pool task to 
continue failing on
+            // the same partition. Additionally, if there is some global issue 
that will cause
+            // all maintenance thread pool tasks to fail, then bubbling the 
exception and
+            // stopping the pool is faster than waiting for all tasks to see 
the same exception.
+            //
+            // However, we assume that repeated failures on the same partition 
and global issues
+            // are rare. The benefit to unloading just the partition with an 
exception is that
+            // transient issues on a given provider do not affect any other 
providers; so, in
+            // most cases, this should be a more performant solution.
+            source match {
+              case FromTaskThread | FromUnloadedProvidersQueue =>
+                removeFromLoadedProvidersAndClose(id, Some(provider))
+
+              case FromLoadedProviders =>
+                removeFromLoadedProvidersAndClose(id)
+            }
+        } finally {
+          val duration = System.currentTimeMillis() - startTime
+          val logMsg =
+            log"Finished maintenance task for " +
+              log"provider=${MDC(LogKeys.STATE_STORE_PROVIDER_ID, id)}" +
+              log" in elapsed_time=${MDC(LogKeys.TIME_UNITS, duration)}" +
+              log" and awaiting_partition_time=" +
+              log"${MDC(LogKeys.TIME_UNITS, awaitingPartitionDuration)}\n"
+          if (duration > 5000) {
+            logInfo(logMsg)
+          } else {
+            logDebug(logMsg)
+          }
+          maintenanceThreadPoolLock.synchronized {
+            maintenancePartitions.remove(id)
+            maintenanceThreadPoolLock.notifyAll()
+          }
+        }
+      }
+    })
+  }
+
   private def reportActiveStoreInstance(
       storeProviderId: StateStoreProviderId,
       otherProviderIds: Seq[StateStoreProviderId]): Seq[StateStoreProviderId] 
= {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index cbf8227ac08c..b41d980b84fe 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -40,6 +40,8 @@ class StateStoreConf(
    */
   val stateStoreMaintenanceShutdownTimeout: Long = 
sqlConf.stateStoreMaintenanceShutdownTimeout
 
+  val stateStoreMaintenanceProcessingTimeout: Long = 
sqlConf.stateStoreMaintenanceProcessingTimeout
+
   /**
    * Minimum number of delta files in a chain after which HDFSBackedStateStore 
will
    * consider generating a snapshot.
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 0891c0702aea..eb8979a90c2d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, 
File, IOException,
 import java.net.URI
 import java.util
 import java.util.UUID
+import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable
@@ -38,6 +39,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
 import org.apache.spark.LocalSparkContext._
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.quietly
@@ -50,6 +52,134 @@ import org.apache.spark.tags.ExtendedSQLTest
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
+/**
+ * A test StateStoreProvider implementation that controls maintenance execution
+ * timing using a CountDownLatch to simulate concurrent maintenance scenarios.
+ *
+ * This provider is used to test the scenario where a task thread attempts to
+ * unload a provider via maintenance while it's already being processed by a
+ * maintenance thread. This tests the awaitProcessThisPartition functionality
+ * that ensures proper synchronization in StateStore's maintenance thread pool.
+ */
+class SignalingStateStoreProvider extends StateStoreProvider with Logging {
+  import SignalingStateStoreProvider._
+  private var id: StateStoreId = null
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      useColumnFamilies: Boolean,
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false,
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
+    id = stateStoreId
+  }
+
+  override def stateStoreId: StateStoreId = id
+
+  /**
+   * Records which thread called close() to verify that only maintenance 
threads close providers
+   */
+  override def close(): Unit = {
+    closeThreadName = Thread.currentThread.getName
+  }
+
+  /**
+   * This test implementation doesn't need to provide an actual store
+   */
+  override def getStore(
+      version: Long,
+      uniqueId: Option[String]): StateStore = null
+
+  /**
+   * Simulates a maintenance operation that blocks until a signal is received.
+   * This allows testing the scenario where a provider is already under 
maintenance
+   * when a task thread tries to trigger another maintenance operation on it.
+   */
+  override def doMaintenance(): Unit = {
+    maintenanceStarted = true
+    logInfo(s"Maintenance started on thread: 
${Thread.currentThread().getName}")
+
+    // Block until the test signals to continue
+    continueSignal.await()
+
+    logInfo(s"Maintenance continuing after signal on thread: 
${Thread.currentThread().getName}")
+  }
+}
+
+/**
+ * Companion object that tracks state and provides synchronization primitives
+ * for testing concurrent maintenance scenarios
+ */
+object SignalingStateStoreProvider extends Logging {
+  // For tracking state across threads
+  var maintenanceStarted: Boolean = false
+  var taskSubmittedMaintenance: Boolean = false
+  var closeThreadName: String = ""
+
+  // Added for queue testing
+  var providerWasQueued: Boolean = false
+
+  // For coordination between threads
+  var continueSignal = new CountDownLatch(1)
+  val maintenanceStartedLatch = new CountDownLatch(1)
+  val taskAttemptCompletedLatch = new CountDownLatch(1)
+
+  /**
+   * Resets all test state between test runs
+   */
+  def reset(): Unit = {
+    maintenanceStarted = false
+    taskSubmittedMaintenance = false
+    closeThreadName = ""
+
+    // Reset the latch to ensure maintenance will block again
+    try {
+      continueSignal = new CountDownLatch(1)
+    } catch {
+      case e: Exception =>
+        logError(s"Error resetting latch: ${e.getMessage}")
+    }
+  }
+}
+
+class FakeStateStoreProviderTracksCloseThread extends StateStoreProvider {
+  import FakeStateStoreProviderTracksCloseThread._
+  private var id: StateStoreId = null
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      useColumnFamilies: Boolean,
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false,
+      stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
+    id = stateStoreId
+  }
+
+  override def stateStoreId: StateStoreId = id
+
+  override def close(): Unit = {
+    closeThreadNames = Thread.currentThread.getName :: closeThreadNames
+  }
+
+  override def getStore(
+      version: Long,
+      uniqueId: Option[String]): StateStore = null
+
+  override def doMaintenance(): Unit = {}
+}
+
+private object FakeStateStoreProviderTracksCloseThread {
+  var closeThreadNames: List[String] = Nil
+}
+
 // MaintenanceErrorOnCertainPartitionsProvider is a test-only provider that throws an
 // exception during maintenance for partitions 0 and 1 (these are arbitrary choices). It is
 // used to test that an exception in a single provider's maintenance does not affect other
@@ -138,6 +268,262 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     require(!StateStore.isMaintenanceRunning)
   }
 
+  // Scenario: a maintenance run is blocked inside the provider, a simulated task thread
+  // then submits maintenance work for the same provider, and the provider is expected to
+  // land in `unloadedProvidersToClose` until a later maintenance run drains the queue.
+  test("SPARK-51596: submitMaintenanceWorkForProvider from task thread adds" +
+    " to queue when timeout occurs") {
+    // Reset tracking variables for a clean test
+    SignalingStateStoreProvider.reset()
+
+    val sqlConf = getDefaultSQLConf(
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get
+    )
+
+    // Critical: Set a very short timeout to ensure awaitProcessThisPartition fails quickly
+    sqlConf.setConf(SQLConf.STATE_STORE_MAINTENANCE_PROCESSING_TIMEOUT, 1L) // 1 second
+
+    // Maintenance interval large enough that we control timing manually
+    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L)
+    sqlConf.setConf(SQLConf.NUM_STATE_STORE_MAINTENANCE_THREADS, 4)
+
+    // Use our test provider
+    sqlConf.setConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS,
+      classOf[SignalingStateStoreProvider].getName
+    )
+
+    val conf = new SparkConf().setMaster("local").setAppName("test")
+
+    withSpark(SparkContext.getOrCreate(conf)) { sc =>
+      withCoordinatorRef(sc) { _ =>
+        val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-51596-timeout-queue"
+        val providerId = StateStoreProviderId(StateStoreId(rootLocation, 0, 0), UUID.randomUUID)
+
+        // Load the provider to start the maintenance system
+        StateStore.get(
+          providerId,
+          keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
+          0, None, None, useColumnFamilies = false,
+          new StateStoreConf(sqlConf), new Configuration()
+        )
+
+        // Access the queue via reflection for verification
+        val queueField = PrivateMethod[ConcurrentLinkedQueue[
+          (StateStoreProviderId, StateStoreProvider)]](
+          Symbol("unloadedProvidersToClose"))
+        val queue = StateStore invokePrivate queueField()
+        assert(queue.isEmpty, "Queue should start empty")
+
+        // Manually trigger maintenance which will block
+        val maintenanceMethod = PrivateMethod[Unit](Symbol("doMaintenance"))
+        StateStore invokePrivate maintenanceMethod()
+
+        // Wait for maintenance to start
+        eventually(timeout(5.seconds)) {
+          assert(SignalingStateStoreProvider.maintenanceStarted)
+          assert(StateStore.isLoaded(providerId))
+        }
+
+        // Now get access to the provider to simulate a task thread
+        val loadedProvidersField = PrivateMethod[
+          mutable.HashMap[StateStoreProviderId, StateStoreProvider]](
+          Symbol("loadedProviders"))
+        val loadedProviders = StateStore invokePrivate loadedProvidersField()
+        // apply() fails with "key not found: <id>" if the provider is unexpectedly missing
+        val provider = loadedProviders.synchronized { loadedProviders(providerId) }
+        val maintenancePartitionsField = PrivateMethod[
+          mutable.HashSet[StateStoreProviderId]](
+          Symbol("maintenancePartitions"))
+        val maintenancePartitions = StateStore invokePrivate maintenancePartitionsField()
+
+        // Holds any exception raised on the task thread so it can be reported from the
+        // test thread instead of surfacing as an unrelated latch timeout.
+        val taskFailure = new java.util.concurrent.atomic.AtomicReference[Throwable]()
+
+        // Create a task thread that will attempt to submit maintenance
+        val taskThread = new Thread(() => {
+          try {
+            // Call submitMaintenanceWorkForProvider directly since that's what we're testing
+            val submitMaintenanceMethod = PrivateMethod[Unit](
+              Symbol("submitMaintenanceWorkForProvider"))
+            StateStore invokePrivate submitMaintenanceMethod(
+              providerId, provider, new StateStoreConf(sqlConf),
+              MaintenanceTaskType.FromTaskThread)
+
+            SignalingStateStoreProvider.taskSubmittedMaintenance = true
+          } catch {
+            case e: Exception =>
+              logError(s"Error in task thread: ${e.getMessage}", e)
+              taskFailure.set(e)
+          } finally {
+            // Always release the waiting test thread, whether the attempt succeeded or not
+            SignalingStateStoreProvider.taskAttemptCompletedLatch.countDown()
+          }
+        })
+
+        // Start the task thread - it should timeout and add provider to queue
+        taskThread.start()
+
+        // Wait for task attempt to complete
+        assert(SignalingStateStoreProvider
+          .taskAttemptCompletedLatch.await(10, TimeUnit.SECONDS),
+          "Task thread didn't complete")
+        assert(taskFailure.get() == null, s"Task thread failed: ${taskFailure.get()}")
+
+        // Critical verification: After timeout, the provider should be in the queue
+        eventually(timeout(5.seconds)) {
+          assert(queue.size() == 1, "Provider should be queued after timeout")
+        }
+        val (queuedId, _) = queue.peek()
+        assert(queuedId == providerId, "Queued provider has wrong ID")
+
+        // Now allow the first maintenance to complete
+        SignalingStateStoreProvider.continueSignal.countDown()
+
+        eventually(timeout(5.seconds)) {
+          assert(maintenancePartitions.isEmpty,
+            "Provider should be removed from maintenance partitions once maintenance completes")
+        }
+        // Manually trigger another maintenance to process the queue
+        StateStore invokePrivate maintenanceMethod()
+
+        // Verify the queue eventually gets processed
+        eventually(timeout(5.seconds)) {
+          assert(queue.isEmpty, "Queue should be emptied after maintenance")
+        }
+      }
+    }
+  }
+
+  // Scenario: provider1 is stuck in maintenance when it becomes stale (reported active on
+  // another executor). Loading provider2 then requests that provider1 be unloaded, and
+  // the close is expected to run on a maintenance thread once the blocked run is released.
+  test("SPARK-51596: queued maintenance tasks get processed when lock is available") {
+    // Reset tracking variables for a clean test
+    SignalingStateStoreProvider.reset()
+
+    val sqlConf = getDefaultSQLConf(
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get
+    )
+    // Use a maintenance interval large enough that we control timing explicitly
+    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L)
+    // Set our special provider class that lets us control maintenance timing
+    sqlConf.setConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS,
+      classOf[SignalingStateStoreProvider].getName
+    )
+
+    val conf = new SparkConf().setMaster("local").setAppName("test")
+
+    withSpark(SparkContext.getOrCreate(conf)) { sc =>
+      withCoordinatorRef(sc) { coordinatorRef =>
+        val rootLocation = s"${Utils.createTempDir().getAbsolutePath}/spark-51596-queue"
+
+        // Create two providers that we'll use for the test
+        val provider1Id =
+          StateStoreProviderId(StateStoreId(rootLocation, 0, 0), UUID.randomUUID)
+        val provider2Id =
+          StateStoreProviderId(StateStoreId(rootLocation, 0, 1), UUID.randomUUID)
+
+        // Get the first provider to load it
+        StateStore.get(
+          provider1Id,
+          keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
+          0, None, None, useColumnFamilies = false,
+          new StateStoreConf(sqlConf), new Configuration()
+        )
+
+        // Manually trigger maintenance for provider1, which will block in doMaintenance()
+        val maintenanceMethod = PrivateMethod[Unit](Symbol("doMaintenance"))
+        StateStore invokePrivate maintenanceMethod()
+
+        // Wait for maintenance to start before continuing
+        eventually(timeout(5.seconds)) {
+          assert(SignalingStateStoreProvider.maintenanceStarted)
+          assert(StateStore.isLoaded(provider1Id))
+        }
+
+        // Now make the first provider "stale" by reporting it active on another executor
+        coordinatorRef.reportActiveInstance(provider1Id, "otherhost", "otherexec", Seq.empty)
+
+        // Get provider2 which will cause a maintenance task for provider1 to be queued
+        // (since provider1 is already under maintenance and can't be processed immediately)
+        StateStore.get(
+          provider2Id,
+          keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
+          0, None, None, useColumnFamilies = false,
+          new StateStoreConf(sqlConf), new Configuration()
+        )
+
+        // Mark that task submitted maintenance
+        SignalingStateStoreProvider.taskSubmittedMaintenance = true
+
+        // Unblock the first maintenance operation
+        SignalingStateStoreProvider.continueSignal.countDown()
+
+        // Verify that provider1 is eventually unloaded by the maintenance thread
+        // after the first maintenance completes and the queued maintenance runs
+        eventually(timeout(5.seconds)) {
+          // Provider1 should be unloaded
+          assert(!StateStore.isLoaded(provider1Id))
+          // Provider2 should still be loaded
+          assert(StateStore.isLoaded(provider2Id))
+          // Close should have been called on a maintenance thread
+          assert(SignalingStateStoreProvider.closeThreadName.contains("maintenance"))
+        }
+
+        // Get the unloadedProvidersToClose queue to check that it has been drained. The
+        // element type matches the (id, provider) pairs used by the other queue test.
+        val unloadedProvidersField = PrivateMethod[ConcurrentLinkedQueue[
+          (StateStoreProviderId, StateStoreProvider)]](
+          Symbol("unloadedProvidersToClose"))
+        val queue = StateStore invokePrivate unloadedProvidersField()
+        assert(queue.isEmpty, "Maintenance queue should be empty after processing queued tasks")
+      }
+    }
+  }
+
+  // Loads one provider, marks it as active elsewhere, then loads a second provider and
+  // checks that exactly one close happened and that it ran on a maintenance thread.
+  test("SPARK-51596: unloading only occurs on maintenance thread but occurs promptly") {
+    // Start from an empty record of closing threads
+    FakeStateStoreProviderTracksCloseThread.closeThreadNames = Nil
+
+    val sqlConf = getDefaultSQLConf(
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get
+    )
+    // With a 30s maintenance interval the task thread acts before periodic maintenance.
+    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 30000L)
+    // Providers created here record the name of the thread that closes them
+    sqlConf.setConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS,
+      classOf[FakeStateStoreProviderTracksCloseThread].getName
+    )
+
+    val sparkConf = new SparkConf().setMaster("local").setAppName("test")
+
+    withSpark(SparkContext.getOrCreate(sparkConf)) { sc =>
+      withCoordinatorRef(sc) { coordinatorRef =>
+        val checkpointRoot = s"${Utils.createTempDir().getAbsolutePath}/spark-51596"
+        val staleProviderId =
+          StateStoreProviderId(StateStoreId(checkpointRoot, 0, 0), UUID.randomUUID)
+        val freshProviderId =
+          StateStoreProviderId(StateStoreId(checkpointRoot, 0, 1), UUID.randomUUID)
+
+        // Both loads use identical arguments apart from the provider id
+        def loadProvider(id: StateStoreProviderId): StateStore =
+          StateStore.get(
+            id,
+            keySchema, valueSchema, NoPrefixKeyStateEncoderSpec(keySchema),
+            0, None, None, useColumnFamilies = false,
+            new StateStoreConf(sqlConf), new Configuration()
+          )
+
+        // Loading the first provider starts the maintenance task + pool
+        loadProvider(staleProviderId)
+
+        // Report the first provider as active on another executor
+        coordinatorRef.reportActiveInstance(
+          staleProviderId, "otherhost", "otherexec", Seq.empty)
+
+        // Loading a second provider triggers the unload of the first one
+        loadProvider(freshProviderId)
+
+        // The 5s wait is shorter than the maintenance interval, so a close observed here
+        // can only have been triggered by the task thread.
+        eventually(timeout(5.seconds)) {
+          val closers = FakeStateStoreProviderTracksCloseThread.closeThreadNames
+          assert(closers.size == 1)
+          for (threadName <- closers) {
+            assert(threadName.contains("state-store-maintenance-thread"))
+          }
+        }
+      }
+    }
+  }
+
+
   test("retaining only two latest versions when 
MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 2") {
     tryWithProviderResource(
       newStoreProvider(minDeltasForSnapshot = 10, numOfVersToRetainInMemory = 
2)) { provider =>
@@ -1611,7 +1997,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: 
StateStoreProvider]
           assert(rowPairsToDataSet(store0reloaded.iterator()) === Set.empty)
 
           // Verify that you can remove the store and still reload and use it
-          StateStore.unload(storeId)
+          StateStore.removeFromLoadedProvidersAndClose(storeId)
           assert(!StateStore.isLoaded(storeId))
 
           val store1reloaded = StateStore.get(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to