This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a1c1ab4710a [SPARK-52008] Throwing an error if State Stores do not commit at the end of a batch when ForeachBatch is used
9a1c1ab4710a is described below

commit 9a1c1ab4710ab8efe35683197df5d3539f2106a7
Author: Eric Marnadi <eric.marn...@databricks.com>
AuthorDate: Wed Jul 30 11:31:28 2025 -0700

    [SPARK-52008] Throwing an error if State Stores do not commit at the end of a batch when ForeachBatch is used
    
    ### What changes were proposed in this pull request?
    
    This PR adds validation to ensure that all StateStore instances commit
    their changes at the end of each streaming batch. The implementation tracks
    expected StateStore commits through the StateStoreCoordinator and validates
    that all expected stores (across all operators and partitions) have
    committed before the batch completes, throwing a
    STATE_STORE_COMMIT_VALIDATION_FAILED error if any commits are missing.
    This is only a problem when a ForeachBatch sink is used.
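
    For illustration, here is a minimal sketch of the pattern this validation is meant
    to catch, assuming an active SparkSession `spark`. The query shape, column names,
    and checkpoint path are hypothetical and not part of this patch; only the
    validation behavior described above is from the change itself.

    ```scala
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._

    // Hypothetical stateful query; any streaming aggregation behaves the same way.
    val aggregated = spark.readStream
      .format("rate")
      .load()
      .withColumn("key", col("value") % 10)
      .groupBy("key")
      .agg(count("*").as("count"))

    aggregated.writeStream
      .outputMode("update")
      .option("checkpointLocation", "/tmp/ckpt") // hypothetical path
      .foreachBatch((batchDF: DataFrame, batchId: Long) => {
        // batchDF.show(2) would consume only enough partitions to print two rows,
        // leaving some state stores uncommitted; with this change such a batch now
        // fails with STATE_STORE_COMMIT_VALIDATION_FAILED instead of silently
        // producing missing state store files.
        batchDF.collect() // consuming the whole DataFrame lets every store commit
      })
      .start()
    ```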
    
    ### Why are the changes needed?
    
    To mitigate missing state store file issues when the whole DataFrame is not
    consumed by the foreachBatch function.
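
    The check is governed by the conf added in this PR,
    `spark.sql.streaming.stateStore.commitValidation.enabled` (default `true`,
    version 4.1.0), so a pipeline that intentionally consumes only part of each
    batch can opt out. A minimal sketch, assuming an active SparkSession `spark`:

    ```scala
    // Opt-out sketch: disables the commit validation added by this PR.
    spark.conf.set("spark.sql.streaming.stateStore.commitValidation.enabled", "false")
    ```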
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit Tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51706 from ericm-db/feb-df-consumption.
    
    Authored-by: Eric Marnadi <eric.marn...@databricks.com>
    Signed-off-by: Anish Shrigondekar <anish.shrigonde...@databricks.com>
---
 .../src/main/resources/error/error-conditions.json |  10 +
 .../pandas/test_pandas_transform_with_state.py     |  24 ++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  13 +
 .../streaming/runtime/MicroBatchExecution.scala    |  61 ++-
 .../streaming/runtime/ProgressReporter.scala       |  12 +
 .../state/HDFSBackedStateStoreProvider.scala       |   6 +
 .../state/RocksDBStateStoreProvider.scala          |   6 +
 .../sql/execution/streaming/state/StateStore.scala |  23 ++
 .../execution/streaming/state/StateStoreConf.scala |   7 +
 .../streaming/state/StateStoreCoordinator.scala    | 141 +++++++
 .../streaming/state/StateStoreErrors.scala         |  23 ++
 .../streaming/sources/ForeachBatchSinkSuite.scala  | 423 +++++++++++++++++++++
 .../RocksDBStateStoreCheckpointFormatV2Suite.scala |  48 +--
 .../sql/streaming/TransformWithStateSuite.scala    |   9 +-
 14 files changed, 778 insertions(+), 28 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 50116bdc0a9c..febc92eba902 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5155,6 +5155,16 @@
     ],
     "sqlState" : "42802"
   },
+  "STATE_STORE_COMMIT_VALIDATION_FAILED" : {
+    "message" : [
+      "State store commit validation failed for batch <batchId>.",
+      "Expected <expectedCommits> commits but got <actualCommits>.",
+      "Missing commits: <missingCommits>.",
+      "This error typically occurs when using operations like show() or 
limit() in foreachBatch that don't process all partitions, or if you are 
swallowing an exception and returning from the function early.",
+      "To fix: ensure your foreachBatch function processes the entire 
DataFrame."
+    ],
+    "sqlState" : "XXKST"
+  },
   "STATE_STORE_HANDLE_NOT_INITIALIZED" : {
     "message" : [
       "The handle has not been initialized for this StatefulProcessor.",
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 00e03c6da19b..f3b705a44bb0 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -195,6 +195,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_basic(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", countAsString="2"),
@@ -210,6 +211,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_non_exist_value_state(self):
         def check_results(batch_df, _):
+            batch_df.collect()
             assert set(batch_df.sort("id").collect()) == {
                 Row(id="0", countAsString="0"),
                 Row(id="1", countAsString="0"),
@@ -295,6 +297,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_list_state(self):
         def check_results(batch_df, _):
+            batch_df.collect()
             assert set(batch_df.sort("id").collect()) == {
                 Row(id="0", countAsString="2"),
                 Row(id="1", countAsString="2"),
@@ -306,6 +309,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_list_state_large_list(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 expected_prev_elements = ""
                 expected_updated_elements = ",".join(map(lambda x: str(x), 
range(90)))
@@ -380,6 +384,7 @@ class TransformWithStateTestsMixin:
     # test list state with ttl has the same behavior as list state when state 
doesn't expire.
     def test_transform_with_state_list_state_large_ttl(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             assert set(batch_df.sort("id").collect()) == {
                 Row(id="0", countAsString="2"),
                 Row(id="1", countAsString="2"),
@@ -391,6 +396,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_map_state(self):
         def check_results(batch_df, _):
+            batch_df.collect()
             assert set(batch_df.sort("id").collect()) == {
                 Row(id="0", countAsString="2"),
                 Row(id="1", countAsString="2"),
@@ -401,6 +407,7 @@ class TransformWithStateTestsMixin:
     # test map state with ttl has the same behavior as map state when state 
doesn't expire.
     def test_transform_with_state_map_state_large_ttl(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             assert set(batch_df.sort("id").collect()) == {
                 Row(id="0", countAsString="2"),
                 Row(id="1", countAsString="2"),
@@ -414,6 +421,7 @@ class TransformWithStateTestsMixin:
     # state doesn't expire.
     def test_value_state_ttl_basic(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", countAsString="2"),
@@ -433,6 +441,7 @@ class TransformWithStateTestsMixin:
     @unittest.skip("test is flaky and it is only a timing issue, skipping 
until we can resolve")
     def test_value_state_ttl_expiration(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 assertDataFrameEqual(
                     batch_df,
@@ -581,6 +590,8 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_proc_timer(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
+
             # helper function to check expired timestamp is smaller than 
current processing time
             def check_timestamp(batch_df):
                 expired_df = (
@@ -696,6 +707,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_event_time(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 # watermark for late event = 0
                 # watermark for eviction = 0
@@ -727,6 +739,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_with_wmark_and_non_event_time(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 # watermark for late event = 0 and min event = 20
                 assert set(batch_df.sort("id").collect()) == {
@@ -824,6 +837,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_init_state(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 # for key 0, initial state was processed and it was only 
processed once;
                 # for key 1, it did not appear in the initial state df;
@@ -847,6 +861,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_init_state_with_extra_transformation(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 # for key 0, initial state was processed and it was only 
processed once;
                 # for key 1, it did not appear in the initial state df;
@@ -925,6 +940,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_non_contiguous_grouping_cols(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             assert set(batch_df.collect()) == {
                 Row(id1="0", id2="1", value=str(123 + 46)),
                 Row(id1="1", id2="2", value=str(146 + 346)),
@@ -936,6 +952,7 @@ class TransformWithStateTestsMixin:
 
     def 
test_transform_with_state_non_contiguous_grouping_cols_with_init_state(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             # initial state for key (0, 1) is processed
             assert set(batch_df.collect()) == {
                 Row(id1="0", id2="1", value=str(789 + 123 + 46)),
@@ -1018,6 +1035,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_chaining_ops(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             import datetime
 
             if batch_id == 0:
@@ -1053,6 +1071,7 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_init_state_with_timers(self):
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 # timers are registered and handled in the first batch for
                 # rows in initial state; For key=0 and key=3 which contains
@@ -1177,6 +1196,7 @@ class TransformWithStateTestsMixin:
             expected_operator_name = "transformWithStateInPySparkExec"
 
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", countAsString="2"),
@@ -1293,6 +1313,7 @@ class TransformWithStateTestsMixin:
         checkpoint_path = tempfile.mktemp()
 
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", countAsString="2"),
@@ -1372,6 +1393,7 @@ class TransformWithStateTestsMixin:
         checkpoint_path = tempfile.mktemp()
 
         def check_results(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", countAsString="2"),
@@ -1459,12 +1481,14 @@ class TransformWithStateTestsMixin:
 
     def test_transform_with_state_restart_with_multiple_rows_init_state(self):
         def check_results(batch_df, _):
+            batch_df.collect()
             assert set(batch_df.sort("id").collect()) == {
                 Row(id="0", countAsString="2"),
                 Row(id="1", countAsString="2"),
             }
 
         def check_results_for_new_query(batch_df, batch_id):
+            batch_df.collect()
             if batch_id == 0:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", value=str(123 + 46)),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b2254eca5bf2..61bf7a9c46fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2691,6 +2691,17 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val STATE_STORE_COMMIT_VALIDATION_ENABLED =
+    buildConf("spark.sql.streaming.stateStore.commitValidation.enabled")
+      .doc("When true, Spark will validate that all StateStore instances have 
committed for " +
+        "stateful streaming queries using foreachBatch. This helps detect 
cases where " +
+        "user-defined functions in foreachBatch (e.g., show(), limit()) don't 
process all " +
+        "partitions, which can lead to incorrect results. The validation only 
applies to " +
+        "foreachBatch sinks without global aggregates or limits.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val CHECKPOINT_RENAMEDFILE_CHECK_ENABLED =
     buildConf("spark.sql.streaming.checkpoint.renamedFileCheck.enabled")
       .doc("When true, Spark will validate if renamed checkpoint file exists.")
@@ -6552,6 +6563,8 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def stateStoreUnloadOnCommit: Boolean = getConf(STATE_STORE_UNLOAD_ON_COMMIT)
 
+  def stateStoreCommitValidationEnabled: Boolean = 
getConf(STATE_STORE_COMMIT_VALIDATION_ENABLED)
+
   def streamingMaintenanceInterval: Long = 
getConf(STREAMING_MAINTENANCE_INTERVAL)
 
   def stateStoreCompressionCodec: String = 
getConf(STATE_STORE_COMPRESSION_CODEC)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
index 1dd70ad985cc..45712cf087c4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.streaming
 
 import scala.collection.mutable.{Map => MutableMap}
 import scala.collection.mutable
+import scala.util.control.NonFatal
 
 import org.apache.spark.internal.{LogKeys, MDC}
+import org.apache.spark.internal.LogKeys.BATCH_ID
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, 
FileSourceMetadataAttribute, LocalTimestamp}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, 
LogicalPlan, Project, StreamSourceAwareLogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, GlobalLimit, 
LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan}
 import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, 
WriteToStream}
 import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -35,7 +37,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, 
StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, 
StreamWriterCommitProgress, WriteToDataSourceV2Exec}
-import 
org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, 
WriteToMicroBatchDataSourceV1}
+import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, 
WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1}
 import org.apache.spark.sql.execution.streaming.state.StateSchemaBroadcast
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.Trigger
@@ -303,6 +305,15 @@ class MicroBatchExecution(
   }
 
   private val watermarkPropagator = 
WatermarkPropagator(sparkSession.sessionState.conf)
+  private lazy val hasGlobalAggregateOrLimit = 
containsGlobalAggregateOrLimit(logicalPlan)
+
+  private def containsGlobalAggregateOrLimit(logicalPlan: LogicalPlan): 
Boolean = {
+    logicalPlan.collect {
+      case agg: Aggregate if agg.groupingExpressions.isEmpty => agg
+      case limit: GlobalLimit => limit
+    }.nonEmpty
+  }
+
 
   override def cleanup(): Unit = {
     super.cleanup()
@@ -862,6 +873,8 @@ class MicroBatchExecution(
         isTerminatingTrigger = trigger.isInstanceOf[AvailableNowTrigger.type])
       execCtx.executionPlan.executedPlan // Force the lazy generation of 
execution plan
     }
+    // Set up StateStore commit tracking before execution begins
+    setupStateStoreCommitTracking(execCtx)
 
     markMicroBatchExecutionStart(execCtx)
 
@@ -965,6 +978,50 @@ class MicroBatchExecution(
     }
   }
 
+
+  /**
+   * Set up tracking for StateStore commits before batch execution begins.
+   * This collects information about expected stateful operators and 
initializes
+   * commit tracking, but only for ForeachBatchSink without global aggregates 
or limits.
+   */
+  private def setupStateStoreCommitTracking(execCtx: 
MicroBatchExecutionContext): Unit = {
+    try {
+      // Collect stateful operators from the executed plan
+      val statefulOps = execCtx.executionPlan.executedPlan.collect {
+        case s: StateStoreWriter => s
+      }
+
+      if (statefulOps.nonEmpty &&
+        sparkSession.sessionState.conf.stateStoreCommitValidationEnabled) {
+
+        // Start tracking before execution begins
+        // We only validate commits for ForeachBatchSink because it's the only 
sink where
+        // user-defined functions can cause partial processing (e.g., using 
show() or limit()).
+        // We exclude queries with global aggregates or limits because they 
naturally don't
+        // process all partitions, making commit validation unnecessary and 
potentially noisy.
+        if (sink.isInstanceOf[ForeachBatchSink[_]] && 
!hasGlobalAggregateOrLimit) {
+          progressReporter.shouldValidateStateStoreCommit.set(true)
+          // Build expected stores map: operatorId -> (storeName -> 
numPartitions)
+          val expectedStores = statefulOps.map { op =>
+            val operatorId = op.getStateInfo.operatorId
+            val numPartitions = op.getStateInfo.numPartitions
+            val storeNames = op.stateStoreNames.map(_ -> numPartitions).toMap
+            operatorId -> storeNames
+          }.toMap
+          sparkSession.streams.stateStoreCoordinator
+            .startStateStoreCommitTrackingForBatch(runId, execCtx.batchId, 
expectedStores)
+        }
+        // TODO: Find out how to dynamically set the SQLConf at this point to 
disable
+        //  the commit tracking
+      }
+    } catch {
+      case NonFatal(e) =>
+        // Log but don't fail the query for tracking setup errors
+        logWarning(log"Error during StateStore commit tracking setup for batch 
" +
+          log"${MDC(BATCH_ID, execCtx.batchId)}", e)
+    }
+  }
+
   /**
    * Called after the microbatch has completed execution. It takes care of 
committing the offset
    * to commit log and other bookkeeping.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala
index dc04ba3331e7..8f07126a33bb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala
@@ -21,6 +21,7 @@ import java.time.Instant
 import java.time.ZoneId
 import java.time.format.DateTimeFormatter
 import java.util.{Optional, UUID}
+import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
@@ -56,6 +57,8 @@ class ProgressReporter(
   // The timestamp we report an event that has not executed anything
   var lastNoExecutionProgressEventTime = Long.MinValue
 
+  val shouldValidateStateStoreCommit = new AtomicBoolean(false)
+
   /** Holds the most recent query progress updates.  Accesses must lock on the 
queue itself. */
   private val progressBuffer = new mutable.Queue[StreamingQueryProgress]()
 
@@ -277,6 +280,15 @@ abstract class ProgressContext(
       currentTriggerStartOffsets != null && currentTriggerEndOffsets != null &&
         currentTriggerLatestOffsets != null
     )
+
+    // Only validate commits if enabled and the query has stateful operators
+    if (progressReporter.shouldValidateStateStoreCommit.get()) {
+      progressReporter.stateStoreCoordinator.validateStateStoreCommitForBatch(
+        lastExecution.runId,
+        lastExecution.currentBatchId
+      )
+    }
+
     currentTriggerEndTimestamp = triggerClock.getTimeMillis()
     val processingTimeMills = currentTriggerEndTimestamp - 
currentTriggerStartTimestamp
     assert(lastExecution != null, "executed batch should provide the 
information for execution.")
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 68e8555b314f..358d2d50e406 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -184,6 +184,12 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
         logInfo(log"Committed version ${MDC(LogKeys.COMMITTED_VERSION, 
newVersion)} " +
           log"for ${MDC(LogKeys.STATE_STORE_PROVIDER, this)} to file " +
           log"${MDC(LogKeys.FILE_NAME, finalDeltaFile)}")
+
+        // Report the commit to StateStoreCoordinator for tracking
+        if (storeConf.commitValidationEnabled) {
+          StateStore.reportCommitToCoordinator(newVersion, stateStoreId, 
hadoopConf)
+        }
+
         newVersion
       } catch {
         case e: Throwable =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index bfcb2cdda296..a702d041de7e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -380,6 +380,12 @@ private[sql] class RocksDBStateStoreProvider
         validateAndTransitionState(COMMIT)
         logInfo(log"Committed ${MDC(VERSION_NUM, newVersion)} " +
           log"for ${MDC(STATE_STORE_ID, id)}")
+
+        // Report the commit to StateStoreCoordinator for tracking
+        if (storeConf.commitValidationEnabled) {
+          StateStore.reportCommitToCoordinator(newVersion, stateStoreId, 
hadoopConf)
+        }
+
         newVersion
       } catch {
         case e: Throwable =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 903bc87f5a22..a9d3c75776e0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -34,6 +34,7 @@ import org.json4s.jackson.JsonMethods.{compact, render}
 
 import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
+import org.apache.spark.internal.LogKeys.{EXCEPTION, STATE_STORE_ID, 
VERSION_NUM}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -913,6 +914,28 @@ object StateStore extends Logging {
   @GuardedBy("maintenanceThreadPoolLock")
   private val maintenancePartitions = new mutable.HashSet[StateStoreProviderId]
 
+  /** Reports to the coordinator that a StateStore has committed */
+  def reportCommitToCoordinator(
+      version: Long,
+      stateStoreId: StateStoreId,
+      hadoopConf: Configuration): Unit = {
+    try {
+      val runId = UUID.fromString(StateStoreProvider.getRunId(hadoopConf))
+      val providerId = StateStoreProviderId(stateStoreId, runId)
+      // The coordinator will handle whether tracking is active for this batch
+      // If tracking is not active, it will just reply without processing
+      StateStoreProvider.coordinatorRef.foreach(
+        _.reportStateStoreCommit(providerId, version, stateStoreId.storeName)
+      )
+      logDebug(log"Reported commit for store " +
+        log"${MDC(STATE_STORE_ID, stateStoreId)} at version ${MDC(VERSION_NUM, 
version)}")
+    } catch {
+      case NonFatal(e) =>
+        // Log but don't fail the commit if reporting fails
+        logWarning(log"Failed to report StateStore commit: ${MDC(EXCEPTION, 
e)}")
+    }
+  }
+
   /**
    * Runs the `task` periodically and bubbles any exceptions that it 
encounters.
    *
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index b41d980b84fe..4026effcb088 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -71,6 +71,13 @@ class StateStoreConf(
   /** Whether validate the underlying format or not. */
   val formatValidationEnabled: Boolean = 
sqlConf.stateStoreFormatValidationEnabled
 
+  /**
+   * Whether to validate StateStore commits for ForeachBatch sinks to ensure 
all partitions
+   * are processed. This helps detect incomplete processing due to operations 
like show()
+   * or limit().
+   */
+  val commitValidationEnabled = sqlConf.stateStoreCommitValidationEnabled
+
   /**
    * Whether to validate the value side. This config is applied to both 
validators as below:
    *
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index 903f27fb2a22..f280553b9540 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -53,6 +53,24 @@ private case class VerifyIfInstanceActive(storeId: 
StateStoreProviderId, executo
 private case class GetLocation(storeId: StateStoreProviderId)
   extends StateStoreCoordinatorMessage
 
+/** Report that a StateStore has committed for tracking purposes */
+private case class ReportStateStoreCommit(
+    storeId: StateStoreProviderId,
+    version: Long,
+    storeName: String = StateStoreId.DEFAULT_STORE_NAME)
+  extends StateStoreCoordinatorMessage
+
+/** Start tracking StateStore commits for a batch */
+private case class StartStateStoreCommitTrackingForBatch(
+    runId: UUID,
+    batchId: Long,
+    expectedStores: Map[Long, Map[String, Int]]) // operatorId -> (storeName 
-> numPartitions)
+  extends StateStoreCoordinatorMessage
+
+/** Validate that all expected StateStores have committed for a batch */
+private case class ValidateStateStoreCommitForBatch(runId: UUID, batchId: Long)
+  extends StateStoreCoordinatorMessage
+
 private case class DeactivateInstances(runId: UUID)
   extends StateStoreCoordinatorMessage
 
@@ -176,6 +194,29 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: 
RpcEndpointRef) {
       LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger))
   }
 
+
+  /** Start tracking StateStore commits for a batch */
+  private[sql] def startStateStoreCommitTrackingForBatch(
+      runId: UUID,
+      batchId: Long,
+      expectedStores: Map[Long, Map[String, Int]]): Unit = {
+    rpcEndpointRef.askSync[Unit](
+      StartStateStoreCommitTrackingForBatch(runId, batchId, expectedStores))
+  }
+
+  /** Report that a StateStore has committed */
+  private[sql] def reportStateStoreCommit(
+      storeId: StateStoreProviderId,
+      version: Long,
+      storeName: String = StateStoreId.DEFAULT_STORE_NAME): Unit = {
+    rpcEndpointRef.askSync[Unit](ReportStateStoreCommit(storeId, version, 
storeName))
+  }
+
+  /** Validate that all expected StateStores have committed for a batch */
+  private[sql] def validateStateStoreCommitForBatch(runId: UUID, batchId: 
Long): Unit = {
+    rpcEndpointRef.askSync[Unit](ValidateStateStoreCommitForBatch(runId, 
batchId))
+  }
+
   /**
    * Endpoint used for testing.
    * Get the latest snapshot version uploaded for a state store.
@@ -222,6 +263,10 @@ private class StateStoreCoordinator(
   // Default snapshot upload event to use when a provider has never uploaded a 
snapshot
   private val defaultSnapshotUploadEvent = SnapshotUploadEvent(0, 0)
 
+  // Tracking structure for StateStore commits per batch
+  // Key: (runId, batchId) -> Value: CommitTracker
+  private val batchCommitTrackers = new mutable.HashMap[(UUID, Long), 
BatchCommitTracker]
+
   // Stores the last timestamp in milliseconds for each queryRunId indicating 
when the
   // coordinator did a report on instances lagging behind on snapshot uploads.
   // The initial timestamp is defaulted to 0 milliseconds.
@@ -264,6 +309,10 @@ private class StateStoreCoordinator(
       val storeIdsToRemove =
         instances.keys.filter(_.queryRunId == runId).toSeq
       instances --= storeIdsToRemove
+
+      val runIdsToRemove = batchCommitTrackers.keys.filter(_._1 == runId)
+      batchCommitTrackers --= runIdsToRemove
+
       // Also remove these instances from snapshot upload event tracking
       stateStoreLatestUploadedSnapshot --= storeIdsToRemove
       // Remove the corresponding run id entries for report time and starting 
time
@@ -336,6 +385,49 @@ private class StateStoreCoordinator(
       }
       context.reply(true)
 
+    case StartStateStoreCommitTrackingForBatch(runId, batchId, expectedStores) 
=>
+      val key = (runId, batchId)
+      if (batchCommitTrackers.contains(key)) {
+        context.sendFailure(new IllegalStateException(
+          s"Batch commit tracker already exists for runId=$runId, 
batchId=$batchId"))
+      } else {
+        batchCommitTrackers.put(key, new BatchCommitTracker(runId, batchId, 
expectedStores))
+        logInfo(s"Started tracking commits for batch $batchId with " +
+          s"${expectedStores.values.map(_.values.sum).sum} expected stores")
+        context.reply()
+      }
+
+    case ReportStateStoreCommit(storeId, version, storeName) =>
+      // StateStore version = batchId + 1, so we need to adjust
+      val batchId = version - 1
+      val key = (storeId.queryRunId, batchId)
+      batchCommitTrackers.get(key) match {
+        case Some(tracker) =>
+          tracker.recordCommit(storeId, storeName)
+          context.reply()
+        case None =>
+          // In case no commit tracker for this batch was found
+          context.reply()
+      }
+
+    case ValidateStateStoreCommitForBatch(runId, batchId) =>
+      val key = (runId, batchId)
+      batchCommitTrackers.get(key) match {
+        case Some(tracker) =>
+          try {
+            tracker.validateAllCommitted()
+            batchCommitTrackers.remove(key) // Clean up after validation
+            context.reply()
+          } catch {
+            case e: StateStoreCommitValidationFailed =>
+              batchCommitTrackers.remove(key) // Clean up even on failure
+              context.sendFailure(e)
+          }
+        case None =>
+          context.sendFailure(new IllegalStateException(
+            s"No commit tracker found for runId=$runId, batchId=$batchId"))
+      }
+
     case GetLatestSnapshotVersionForTesting(providerId) =>
       val version = 
stateStoreLatestUploadedSnapshot.get(providerId).map(_.version)
       logDebug(s"Got latest snapshot version of the state store $providerId: 
$version")
@@ -402,6 +494,55 @@ private class StateStoreCoordinator(
   }
 }
 
+/**
+ * Tracks StateStore commits for a batch to ensure all expected stores commit
+ */
+private class BatchCommitTracker(
+    runId: UUID,
+    batchId: Long,
+    expectedStores: Map[Long, Map[String, Int]]) extends Logging {
+
+  // Track committed stores: (operatorId, partitionId, storeName) -> committed
+  private val committedStores = new mutable.HashSet[(Long, Int, String)]()
+
+  def recordCommit(storeId: StateStoreProviderId, storeName: String): Unit = {
+    val key = (storeId.storeId.operatorId, storeId.storeId.partitionId, 
storeName)
+    committedStores.add(key)
+    logDebug(s"Recorded commit for store $storeId with name $storeName for 
batch $batchId")
+  }
+
+  def validateAllCommitted(): Unit = {
+    val missingCommits = new mutable.ArrayBuffer[String]()
+
+    expectedStores.foreach { case (operatorId, storeMap) =>
+      storeMap.foreach { case (storeName, numPartitions) =>
+        for (partitionId <- 0 until numPartitions) {
+          val key = (operatorId, partitionId, storeName)
+          if (!committedStores.contains(key)) {
+            missingCommits += s"(operator=$operatorId, partition=$partitionId, 
store=$storeName)"
+          }
+        }
+      }
+    }
+
+    if (missingCommits.nonEmpty) {
+      val totalExpected = expectedStores.values.map(_.values.sum).sum
+      val errorMsg = s"Not all StateStores committed for batch $batchId. " +
+        s"Expected $totalExpected commits but got ${committedStores.size}. " +
+        s"Missing commits: ${missingCommits.mkString(", ")}"
+      logError(errorMsg)
+      throw StateStoreErrors.stateStoreCommitValidationFailed(
+        batchId,
+        totalExpected,
+        committedStores.size,
+        missingCommits.mkString(", ")
+      )
+    }
+
+    logInfo(s"All ${committedStores.size} StateStores successfully committed 
for batch $batchId")
+  }
+}
+
 case class SnapshotUploadEvent(
     version: Long,
     timestamp: Long
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index a4a261342b87..455b06f8d9dc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -216,6 +216,14 @@ object StateStoreErrors {
     new StateStoreInvalidConfigAfterRestart(configName, oldConfig, newConfig)
   }
 
+  def stateStoreCommitValidationFailed(
+      batchId: Long,
+      expectedCommits: Int,
+      actualCommits: Int,
+      missingCommits: String): StateStoreCommitValidationFailed = {
+    new StateStoreCommitValidationFailed(batchId, expectedCommits, 
actualCommits, missingCommits)
+  }
+
   def duplicateStateVariableDefined(stateName: String):
     StateStoreDuplicateStateVariableDefined = {
     new StateStoreDuplicateStateVariableDefined(stateName)
@@ -524,3 +532,18 @@ class StateStoreOperationOutOfOrder(errorMsg: String)
     errorClass = "STATE_STORE_OPERATION_OUT_OF_ORDER",
     messageParameters = Map("errorMsg" -> errorMsg)
   )
+
+class StateStoreCommitValidationFailed(
+    batchId: Long,
+    expectedCommits: Int,
+    actualCommits: Int,
+    missingCommits: String)
+  extends SparkRuntimeException(
+    errorClass = "STATE_STORE_COMMIT_VALIDATION_FAILED",
+    messageParameters = Map(
+      "batchId" -> batchId.toString,
+      "expectedCommits" -> expectedCommits.toString,
+      "actualCommits" -> actualCommits.toString,
+      "missingCommits" -> missingCommits
+    )
+  )
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
index fc36235667ec..ec4aa8dacf00 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.sources
 
+import java.io.File
+
 import scala.collection.mutable
 import scala.language.implicitConversions
 
@@ -25,7 +27,9 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.SerializeFromObjectExec
 import org.apache.spark.sql.execution.streaming.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.StateStoreCommitValidationFailed
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.util.ArrayImplicits._
 
@@ -255,6 +259,425 @@ class ForeachBatchSinkSuite extends StreamTest {
     query.awaitTermination()
   }
 
+  test("SPARK-52008: foreachBatch with show() should fail with appropriate 
error") {
+    // This test verifies that commit validation is enabled by default
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      withTempDir { tempDir =>
+        val checkpointPath = tempDir.getCanonicalPath
+        // Create a simple streaming DataFrame
+        val streamingDF = spark.readStream
+          .format("rate")
+          .option("rowsPerSecond", 3)
+          .load()
+          .withColumn("pt_id", (rand() * 99 + 1).cast("int"))
+          .withColumn("event_time",
+            expr("timestampadd(SECOND, cast(rand() * 2 * 86400 - 86400 as 
int), timestamp)"))
+          .withColumn("in_map", (rand() * 2).cast("int") === 1)
+          .drop("value")
+
+        // Create a stateful streaming query
+        val windowedDF = streamingDF
+          .withWatermark("event_time", "1 day")
+          .groupBy("pt_id")
+          .agg(
+            max("event_time").as("latest_event_time"),
+            last("in_map").as("in_map")
+          )
+          .withColumn("output_time", current_timestamp())
+
+        // Define a foreachBatch function that uses show(), which only 
consumes some partitions
+        def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit 
= {
+          // show() only processes enough partitions to display the specified 
number of rows
+          // This doesn't consume all partitions, causing state files to be 
missing
+          batchDF.show(2) // Only shows 2 rows, not processing all partitions
+        }
+
+        // Start the streaming query
+        val queryEx = intercept[StreamingQueryException] {
+          val query = windowedDF.writeStream
+            .queryName("reproducer_test")
+            .option("checkpointLocation", checkpointPath)
+            .foreachBatch(problematicBatchProcessor _)
+            .outputMode("update")
+            .start()
+
+          // Wait for the exception to be thrown
+          query.awaitTermination()
+        }
+
+        // Verify we get the StateStore commit validation error
+        // The error is wrapped by RPC framework, so we need to check the 
cause chain
+        val rootCause = queryEx.getCause
+        assert(rootCause != null, "Expected a root cause for the 
StreamingQueryException")
+
+        // The RPC framework wraps our exception, so check the cause of the 
cause
+        val actualException = rootCause.getCause
+        assert(actualException != null, "Expected a cause for the RPC 
exception")
+        assert(actualException.isInstanceOf[StateStoreCommitValidationFailed],
+          s"Expected StateStoreCommitValidationFailed but got 
${actualException.getClass.getName}")
+
+        val errorMessage = actualException.getMessage
+        assert(errorMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"),
+          s"Expected STATE_STORE_COMMIT_VALIDATION_FAILED error, but got: 
$errorMessage")
+        assert(errorMessage.contains("State store commit validation failed"),
+          s"Expected state store commit validation message, but got: 
$errorMessage")
+        assert(errorMessage.contains("Missing commits"),
+          s"Expected missing commits message, but got: $errorMessage")
+
+        // Extract and validate the expected vs actual commit counts
+        val expectedPattern = "Expected (\\d+) commits but got (\\d+)".r
+        val missingPattern = "Missing commits: (.+)".r
+
+        expectedPattern.findFirstMatchIn(errorMessage) match {
+          case Some(m) =>
+            val expectedCommits = m.group(1).toInt
+            val actualCommits = m.group(2).toInt
+
+            // We should have fewer actual commits than expected due to show(2)
+            // not processing all partitions
+            assert(actualCommits < expectedCommits,
+              s"Expected fewer actual commits ($actualCommits)" +
+                s" than expected commits ($expectedCommits)")
+            assert(actualCommits >= 1,
+              s"Expected at least 1 actual commit from show(2), but got 
$actualCommits")
+            assert(expectedCommits == 5,
+              s"Expected more than 5 commits but got $expectedCommits")
+
+          case None =>
+            fail(s"Could not find expected/actual commit counts in error 
message: $errorMessage")
+        }
+
+        // Validate that missing commits are reported with proper structure
+        missingPattern.findFirstMatchIn(errorMessage) match {
+          case Some(m) =>
+            val missingCommits = m.group(1)
+            assert(missingCommits.nonEmpty,
+              s"Expected non-empty missing commits list, but got: 
'$missingCommits'")
+            // Should contain operator and partition information
+            assert(missingCommits.contains("operator=") && 
missingCommits.contains("partition="),
+              s"Expected missing commits to contain operator and" +
+                s" partition info, but got: '$missingCommits'")
+          case None =>
+            fail(s"Could not find missing commits in error message: 
$errorMessage")
+        }
+      }
+    }
+  }
+
+  test("StateStore commit validation should detect missing commits") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      withTempDir { tempDir =>
+        val checkpointPath = tempDir.getCanonicalPath
+
+        // Create a streaming DataFrame with controlled partitioning
+        val streamingDF = spark.readStream
+          .format("rate")
+          .option("rowsPerSecond", 10)
+          .option("numPartitions", 4) // Ensure we have multiple partitions
+          .load()
+          .withColumn("key", col("value") % 100)
+
+        // Create a stateful operation that requires all partitions to process
+        val aggregatedDF = streamingDF
+          .groupBy("key")
+          .agg(count("*").as("count"))
+
+        // ForeachBatch function that only processes some data and then throws 
exception
+        // This should cause some StateStore commits to be missing
+        def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit 
= {
+          // Force evaluation of only a subset of partitions by using limit
+          // This should cause some partitions to not process their StateStores
+          batchDF.limit(5).collect() // Only process first 5 rows
+
+          // Simulate a failure that prevents remaining partitions from 
committing
+          if (batchId > 0) {
+            throw new RuntimeException("Simulated batch processing failure")
+          }
+        }
+
+        // This should fail with StateStore commit validation error
+        val queryEx = intercept[StreamingQueryException] {
+          val query = aggregatedDF.writeStream
+            .queryName("commit_validation_test")
+            .option("checkpointLocation", checkpointPath)
+            .foreachBatch(problematicBatchProcessor _)
+            .outputMode("complete")
+            .start()
+
+          query.awaitTermination()
+        }
+
+        // Should fail with either our new validation error or the simulated 
RuntimeException
+        // Check the cause chain since RPC wraps exceptions
+        val rootCause = queryEx.getCause
+        val actualException = if (rootCause != null) rootCause.getCause else 
null
+
+        val hasCommitValidationError = actualException != null && (
+          actualException.isInstanceOf[StateStoreCommitValidationFailed] ||
+          
actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"))
+        val hasSimulatedError = queryEx.getMessage.contains("Simulated batch 
processing failure")
+
+        assert(hasCommitValidationError || hasSimulatedError,
+          s"Expected StateStore commit validation error or simulated error," +
+            s" but got: ${queryEx.getMessage}")
+      }
+    }
+  }
+
+  test("StateStore commit validation with AvailableNow trigger") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      withTempDir { tempDir =>
+        val checkpointPath = tempDir.getCanonicalPath
+
+        // Create a temporary file with data
+        val inputPath = new File(tempDir, "input").getCanonicalPath
+        val inputData = spark.range(0, 100)
+          .selectExpr("id", "id % 10 as key")
+        inputData.write
+          .mode("overwrite")
+          .parquet(inputPath)
+
+        // Get the schema from the written data
+        val schema = inputData.schema
+
+        // Create a streaming DataFrame with AvailableNow trigger
+        val streamingDF = spark.readStream
+          .format("parquet")
+          .schema(schema) // Provide the schema explicitly
+          .load(inputPath)
+
+        // Stateful aggregation
+        val aggregatedDF = streamingDF
+          .groupBy("key")
+          .agg(count("*").as("count"))
+
+        // ForeachBatch that only processes partial data
+        def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit 
= {
+          // Only show first 2 rows, won't process all partitions
+          batchDF.show(2)
+        }
+
+        val queryEx = intercept[StreamingQueryException] {
+          val query = aggregatedDF.writeStream
+            .queryName("availablenow_commit_test")
+            .option("checkpointLocation", checkpointPath)
+            .foreachBatch(problematicBatchProcessor _)
+            .outputMode("complete")
+            .trigger(Trigger.AvailableNow())
+            .start()
+
+          query.awaitTermination()
+        }
+
+        // Check the cause chain since RPC wraps exceptions
+        val rootCause = queryEx.getCause
+        assert(rootCause != null, "Expected a root cause for the 
StreamingQueryException")
+        val actualException = rootCause.getCause
+        assert(actualException != null, "Expected a cause for the RPC 
exception")
+
+        assert(actualException.isInstanceOf[StateStoreCommitValidationFailed] 
||
+          
actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"),
+          s"Expected STATE_STORE_COMMIT_VALIDATION_FAILED error," +
+            s" but got: ${actualException.getMessage}")
+      }
+    }
+  }
+
+  test("StateStore commit validation with swallowed exceptions in 
foreachBatch") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      withTempDir { tempDir =>
+        val checkpointPath = tempDir.getCanonicalPath
+
+        // Create streaming DataFrame
+        val streamingDF = spark.readStream
+          .format("rate")
+          .option("rowsPerSecond", 10)
+          .option("numPartitions", 4)
+          .load()
+          .withColumn("key", col("value") % 10)
+
+        // Stateful aggregation
+        val aggregatedDF = streamingDF
+          .groupBy("key")
+          .agg(sum("value").as("sum"))
+
+        // ForeachBatch that swallows exceptions and returns early
+        def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit 
= {
+          try {
+            // Process only first few rows
+            val firstRows = batchDF.limit(2).collect()
+
+            // Simulate some processing that might fail
+            if (firstRows.length > 1) {
+              throw new RuntimeException("Processing failed!")
+            }
+          } catch {
+            case _: Exception =>
+              // Swallow the exception and return early
+              // This means remaining partitions won't be processed
+              return
+          }
+
+          // This code is never reached due to early return
+          batchDF.collect()
+        }
+
+        val queryEx = intercept[StreamingQueryException] {
+          val query = aggregatedDF.writeStream
+            .queryName("swallowed_exception_test")
+            .option("checkpointLocation", checkpointPath)
+            .foreachBatch(problematicBatchProcessor _)
+            .outputMode("update")
+            .start()
+
+          query.awaitTermination()
+        }
+
+        // Check the cause chain since RPC wraps exceptions
+        val rootCause = queryEx.getCause
+        assert(rootCause != null, "Expected a root cause for the 
StreamingQueryException")
+        val actualException = rootCause.getCause
+        assert(actualException != null, "Expected a cause for the RPC 
exception")
+
+        assert(actualException.isInstanceOf[StateStoreCommitValidationFailed] 
||
+          
actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"),
+          s"Expected STATE_STORE_COMMIT_VALIDATION_FAILED error," +
+            s" but got: ${actualException.getMessage}")
+      }
+    }
+  }
+
+  test("StateStore commit validation with multiple swallowed exceptions") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
+      withTempDir { tempDir =>
+        val checkpointPath = tempDir.getCanonicalPath
+
+        val streamingDF = spark.readStream
+          .format("rate")
+          .option("rowsPerSecond", 10)
+          .load()
+          .withColumn("key", col("value") % 5)
+
+        // Multiple aggregations to create multiple StateStores
+        val aggregatedDF = streamingDF
+          .groupBy("key")
+          .agg(
+            count("*").as("count"),
+            sum("value").as("sum"),
+            avg("value").as("avg")
+          )
+
+        var processedCount = 0
+
+        // ForeachBatch with multiple exception swallowing points
+        def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit 
= {
+          try {
+            // First processing attempt
+            batchDF.limit(1).collect()
+            processedCount += 1
+
+            try {
+              // Second processing attempt that fails
+              if (processedCount > 0) {
+                throw new IllegalStateException("Second processing failed")
+              }
+            } catch {
+              case _: IllegalStateException =>
+              // Swallow and continue
+            }
+
+            try {
+              // Third processing attempt
+              batchDF.limit(2).collect()
+              throw new RuntimeException("Third processing failed")
+            } catch {
+              case _: RuntimeException =>
+                // Swallow and return early
+                return
+            }
+
+            // Never reached - full processing
+            batchDF.collect()
+          } catch {
+            case _: Exception =>
+            // Outer catch that swallows everything
+          }
+        }
+
+        val queryEx = intercept[StreamingQueryException] {
+          val query = aggregatedDF.writeStream
+            .queryName("multiple_swallowed_exceptions_test")
+            .option("checkpointLocation", checkpointPath)
+            .foreachBatch(problematicBatchProcessor _)
+            .outputMode("complete")
+            .start()
+
+          query.awaitTermination()
+        }
+
+        // Check the cause chain since RPC wraps exceptions
+        val rootCause = queryEx.getCause
+        assert(rootCause != null, "Expected a root cause for the 
StreamingQueryException")
+        val actualException = rootCause.getCause
+        assert(actualException != null, "Expected a cause for the RPC 
exception")
+
+        assert(actualException.isInstanceOf[StateStoreCommitValidationFailed] 
||
+          
actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"),
+          s"Expected STATE_STORE_COMMIT_VALIDATION_FAILED error," +
+            s" but got: ${actualException.getMessage}")
+      }
+    }
+  }
+
+  test("StateStore commit validation can be disabled via configuration") {
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+      SQLConf.STATE_STORE_COMMIT_VALIDATION_ENABLED.key -> "false") {
+      withTempDir { tempDir =>
+        val checkpointPath = tempDir.getCanonicalPath
+
+        val streamingDF = spark.readStream
+          .format("rate")
+          .option("rowsPerSecond", 3)
+          .load()
+          .withColumn("key", col("value") % 10)
+
+        val aggregatedDF = streamingDF
+          .groupBy("key")
+          .agg(count("*").as("count"))
+
+        // ForeachBatch that only processes partial data
+        // With validation disabled, this should not fail
+        def partialProcessor(batchDF: DataFrame, batchId: Long): Unit = {
+          // Only show first 2 rows, won't process all partitions
+          batchDF.show(2)
+        }
+
+        // This should complete successfully with validation disabled
+        val query = aggregatedDF.writeStream
+          .queryName("validation_disabled_test")
+          .option("checkpointLocation", checkpointPath)
+          .foreachBatch(partialProcessor _)
+          .outputMode("complete")
+          .trigger(Trigger.ProcessingTime("1 second"))
+          .start()
+
+        try {
+          // Wait for at least 2-3 batches to be processed
+          eventually(timeout(streamingTimeout)) {
+            assert(query.lastProgress != null, "Query should have made 
progress")
+            assert(query.lastProgress.batchId >= 2,
+              s"Query should have processed at least 3 batches, " +
+                s"but only processed ${query.lastProgress.batchId + 1}")
+          }
+        } finally {
+          query.stop()
+          query.awaitTermination()
+        }
+      }
+    }
+  }
+
   // ============== Helper classes and methods =================
 
   private class ForeachBatchTester[T: Encoder](memoryStream: 
MemoryStream[Int]) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
index 081383dd66c9..ace8c4db6ff1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{SparkContext, SparkException, 
TaskContext}
 import org.apache.spark.sql.{DataFrame, ForeachWriter}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, 
StreamExecution}
+import 
org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorSuite.withCoordinatorRef
 import org.apache.spark.sql.execution.streaming.state.StateStoreTestsHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -157,7 +158,6 @@ class CkptIdCollectingStateStoreProviderWrapper extends 
StateStoreProvider {
       hadoopConf: Configuration,
       useMultipleValuesPerKey: Boolean = false,
       stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = {
-    hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
     innerProvider.init(
       stateStoreId,
       keySchema,
@@ -1203,28 +1203,30 @@ class RocksDBStateStoreCheckpointFormatV2Suite extends 
StreamTest
   test("checkpointFormatVersion2 racing commits don't return incorrect 
checkpointInfo") {
     val sqlConf = new SQLConf()
     sqlConf.setConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2)
-
-    withTempDir { checkpointDir =>
-      val provider = new CkptIdCollectingStateStoreProviderWrapper()
-      provider.init(
-        StateStoreId(checkpointDir.toString, 0, 0),
-        StateStoreTestsHelper.keySchema,
-        StateStoreTestsHelper.valueSchema,
-        PrefixKeyScanStateEncoderSpec(StateStoreTestsHelper.keySchema, 1),
-        useColumnFamilies = false,
-        new StateStoreConf(sqlConf),
-        new Configuration
-      )
-
-      val store1 = provider.getStore(0)
-      val store1NewVersion = store1.commit()
-      val store2 = provider.getStore(1)
-      val store2NewVersion = store2.commit()
-      val store1CheckpointInfo = store1.getStateStoreCheckpointInfo()
-      val store2CheckpointInfo = store2.getStateStoreCheckpointInfo()
-
-      assert(store1CheckpointInfo.batchVersion == store1NewVersion)
-      assert(store2CheckpointInfo.batchVersion == store2NewVersion)
+    val sc = spark.sparkContext
+    withCoordinatorRef(sc) { _ =>
+      withTempDir { checkpointDir =>
+        val hadoopConf = new Configuration()
+        hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString)
+        val provider = new CkptIdCollectingStateStoreProviderWrapper()
+        provider.init(
+          StateStoreId(checkpointDir.toString, 0, 0),
+          StateStoreTestsHelper.keySchema,
+          StateStoreTestsHelper.valueSchema,
+          PrefixKeyScanStateEncoderSpec(StateStoreTestsHelper.keySchema, 1),
+          useColumnFamilies = false,
+          new StateStoreConf(sqlConf),
+          hadoopConf
+        )
+        val store1 = provider.getStore(0)
+        val store1NewVersion = store1.commit()
+        val store2 = provider.getStore(1)
+        val store2NewVersion = store2.commit()
+        val store1CheckpointInfo = store1.getStateStoreCheckpointInfo()
+        val store2CheckpointInfo = store2.getStateStoreCheckpointInfo()
+        assert(store1CheckpointInfo.batchVersion == store1NewVersion)
+        assert(store2CheckpointInfo.batchVersion == store2NewVersion)
+      }
     }
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 148c451f37af..ece3b8bf942b 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -1532,7 +1532,8 @@ abstract class TransformWithStateSuite extends 
StateStoreMetricsTest
 
         var index = 0
         val foreachBatchDf = df.writeStream
-          .foreachBatch((_: Dataset[(String, String)], _: Long) => {
+          .foreachBatch((ds: Dataset[(String, String)], _: Long) => {
+            ds.collect()
             index += 1
           })
           .trigger(Trigger.AvailableNow())
@@ -1559,7 +1560,8 @@ abstract class TransformWithStateSuite extends 
StateStoreMetricsTest
 
         def startTriggerAvailableNowQueryAndCheck(expectedIdx: Int): Unit = {
           val q = df.writeStream
-            .foreachBatch((_: Dataset[(String, String)], _: Long) => {
+            .foreachBatch((ds: Dataset[(String, String)], _: Long) => {
+              ds.collect()
               index += 1
             })
             .trigger(Trigger.AvailableNow)
@@ -2024,7 +2026,8 @@ abstract class TransformWithStateSuite extends 
StateStoreMetricsTest
         var index = 0
 
         val q = df.writeStream
-          .foreachBatch((_: Dataset[(String, String)], _: Long) => {
+          .foreachBatch((ds: Dataset[(String, String)], _: Long) => {
+            ds.collect()
             index += 1
           })
           .trigger(Trigger.AvailableNow)


