This is an automated email from the ASF dual-hosted git repository. ashrigondekar pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9a1c1ab4710a [SPARK-52008] Throwing an error if State Stores do not commit at the end of a batch when ForeachBatch is used 9a1c1ab4710a is described below commit 9a1c1ab4710ab8efe35683197df5d3539f2106a7 Author: Eric Marnadi <eric.marn...@databricks.com> AuthorDate: Wed Jul 30 11:31:28 2025 -0700 [SPARK-52008] Throwing an error if State Stores do not commit at the end of a batch when ForeachBatch is used ### What changes were proposed in this pull request? This PR adds validation to ensure that all StateStore instances commit their changes at the end of each streaming batch. The implementation tracks expected StateStore commits through the StateStoreCoordinator and validates that all expected stores (across all operators and partitions) have committed before completing a batch, throwing a STATE_STORE_COMMIT_VALIDATION_FAILED error if any commits are missing. This is only a problem when ForeachBatch is used. ### Why are the changes needed? To mitigate missing state store file issues when the whole DataFrame is not consumed. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit Tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #51706 from ericm-db/feb-df-consumption. 
Authored-by: Eric Marnadi <eric.marn...@databricks.com> Signed-off-by: Anish Shrigondekar <anish.shrigonde...@databricks.com> --- .../src/main/resources/error/error-conditions.json | 10 + .../pandas/test_pandas_transform_with_state.py | 24 ++ .../org/apache/spark/sql/internal/SQLConf.scala | 13 + .../streaming/runtime/MicroBatchExecution.scala | 61 ++- .../streaming/runtime/ProgressReporter.scala | 12 + .../state/HDFSBackedStateStoreProvider.scala | 6 + .../state/RocksDBStateStoreProvider.scala | 6 + .../sql/execution/streaming/state/StateStore.scala | 23 ++ .../execution/streaming/state/StateStoreConf.scala | 7 + .../streaming/state/StateStoreCoordinator.scala | 141 +++++++ .../streaming/state/StateStoreErrors.scala | 23 ++ .../streaming/sources/ForeachBatchSinkSuite.scala | 423 +++++++++++++++++++++ .../RocksDBStateStoreCheckpointFormatV2Suite.scala | 48 +-- .../sql/streaming/TransformWithStateSuite.scala | 9 +- 14 files changed, 778 insertions(+), 28 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 50116bdc0a9c..febc92eba902 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5155,6 +5155,16 @@ ], "sqlState" : "42802" }, + "STATE_STORE_COMMIT_VALIDATION_FAILED" : { + "message" : [ + "State store commit validation failed for batch <batchId>.", + "Expected <expectedCommits> commits but got <actualCommits>.", + "Missing commits: <missingCommits>.", + "This error typically occurs when using operations like show() or limit() in foreachBatch that don't process all partitions, or if you are swallowing an exception and returning from the function early.", + "To fix: ensure your foreachBatch function processes the entire DataFrame." 
+ ], + "sqlState" : "XXKST" + }, "STATE_STORE_HANDLE_NOT_INITIALIZED" : { "message" : [ "The handle has not been initialized for this StatefulProcessor.", diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 00e03c6da19b..f3b705a44bb0 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -195,6 +195,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_basic(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), @@ -210,6 +211,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_non_exist_value_state(self): def check_results(batch_df, _): + batch_df.collect() assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="0"), Row(id="1", countAsString="0"), @@ -295,6 +297,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_list_state(self): def check_results(batch_df, _): + batch_df.collect() assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), Row(id="1", countAsString="2"), @@ -306,6 +309,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_list_state_large_list(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: expected_prev_elements = "" expected_updated_elements = ",".join(map(lambda x: str(x), range(90))) @@ -380,6 +384,7 @@ class TransformWithStateTestsMixin: # test list state with ttl has the same behavior as list state when state doesn't expire. 
def test_transform_with_state_list_state_large_ttl(self): def check_results(batch_df, batch_id): + batch_df.collect() assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), Row(id="1", countAsString="2"), @@ -391,6 +396,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_map_state(self): def check_results(batch_df, _): + batch_df.collect() assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), Row(id="1", countAsString="2"), @@ -401,6 +407,7 @@ class TransformWithStateTestsMixin: # test map state with ttl has the same behavior as map state when state doesn't expire. def test_transform_with_state_map_state_large_ttl(self): def check_results(batch_df, batch_id): + batch_df.collect() assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), Row(id="1", countAsString="2"), @@ -414,6 +421,7 @@ class TransformWithStateTestsMixin: # state doesn't expire. def test_value_state_ttl_basic(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), @@ -433,6 +441,7 @@ class TransformWithStateTestsMixin: @unittest.skip("test is flaky and it is only a timing issue, skipping until we can resolve") def test_value_state_ttl_expiration(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: assertDataFrameEqual( batch_df, @@ -581,6 +590,8 @@ class TransformWithStateTestsMixin: def test_transform_with_state_proc_timer(self): def check_results(batch_df, batch_id): + batch_df.collect() + # helper function to check expired timestamp is smaller than current processing time def check_timestamp(batch_df): expired_df = ( @@ -696,6 +707,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_event_time(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: # watermark for late event = 0 # watermark for eviction = 0 @@ -727,6 
+739,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_with_wmark_and_non_event_time(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: # watermark for late event = 0 and min event = 20 assert set(batch_df.sort("id").collect()) == { @@ -824,6 +837,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_init_state(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: # for key 0, initial state was processed and it was only processed once; # for key 1, it did not appear in the initial state df; @@ -847,6 +861,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_init_state_with_extra_transformation(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: # for key 0, initial state was processed and it was only processed once; # for key 1, it did not appear in the initial state df; @@ -925,6 +940,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_non_contiguous_grouping_cols(self): def check_results(batch_df, batch_id): + batch_df.collect() assert set(batch_df.collect()) == { Row(id1="0", id2="1", value=str(123 + 46)), Row(id1="1", id2="2", value=str(146 + 346)), @@ -936,6 +952,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_non_contiguous_grouping_cols_with_init_state(self): def check_results(batch_df, batch_id): + batch_df.collect() # initial state for key (0, 1) is processed assert set(batch_df.collect()) == { Row(id1="0", id2="1", value=str(789 + 123 + 46)), @@ -1018,6 +1035,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_chaining_ops(self): def check_results(batch_df, batch_id): + batch_df.collect() import datetime if batch_id == 0: @@ -1053,6 +1071,7 @@ class TransformWithStateTestsMixin: def test_transform_with_state_init_state_with_timers(self): def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: # timers are registered and 
handled in the first batch for # rows in initial state; For key=0 and key=3 which contains @@ -1177,6 +1196,7 @@ class TransformWithStateTestsMixin: expected_operator_name = "transformWithStateInPySparkExec" def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), @@ -1293,6 +1313,7 @@ class TransformWithStateTestsMixin: checkpoint_path = tempfile.mktemp() def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), @@ -1372,6 +1393,7 @@ class TransformWithStateTestsMixin: checkpoint_path = tempfile.mktemp() def check_results(batch_df, batch_id): + batch_df.collect() if batch_id == 0: assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), @@ -1459,12 +1481,14 @@ class TransformWithStateTestsMixin: def test_transform_with_state_restart_with_multiple_rows_init_state(self): def check_results(batch_df, _): + batch_df.collect() assert set(batch_df.sort("id").collect()) == { Row(id="0", countAsString="2"), Row(id="1", countAsString="2"), } def check_results_for_new_query(batch_df, batch_id): + batch_df.collect() if batch_id == 0: assert set(batch_df.sort("id").collect()) == { Row(id="0", value=str(123 + 46)), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b2254eca5bf2..61bf7a9c46fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2691,6 +2691,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val STATE_STORE_COMMIT_VALIDATION_ENABLED = + buildConf("spark.sql.streaming.stateStore.commitValidation.enabled") + .doc("When true, Spark will validate that all StateStore instances have committed for " + + "stateful 
streaming queries using foreachBatch. This helps detect cases where " + + "user-defined functions in foreachBatch (e.g., show(), limit()) don't process all " + + "partitions, which can lead to incorrect results. The validation only applies to " + + "foreachBatch sinks without global aggregates or limits.") + .version("4.1.0") + .booleanConf + .createWithDefault(true) + val CHECKPOINT_RENAMEDFILE_CHECK_ENABLED = buildConf("spark.sql.streaming.checkpoint.renamedFileCheck.enabled") .doc("When true, Spark will validate if renamed checkpoint file exists.") @@ -6552,6 +6563,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def stateStoreUnloadOnCommit: Boolean = getConf(STATE_STORE_UNLOAD_ON_COMMIT) + def stateStoreCommitValidationEnabled: Boolean = getConf(STATE_STORE_COMMIT_VALIDATION_ENABLED) + def streamingMaintenanceInterval: Long = getConf(STREAMING_MAINTENANCE_INTERVAL) def stateStoreCompressionCodec: String = getConf(STATE_STORE_COMPRESSION_CODEC) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala index 1dd70ad985cc..45712cf087c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.streaming import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.internal.{LogKeys, MDC} +import org.apache.spark.internal.LogKeys.BATCH_ID import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp, FileSourceMetadataAttribute, LocalTimestamp} -import 
org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, GlobalLimit, LeafNode, LocalRelation, LogicalPlan, Project, StreamSourceAwareLogicalPlan} import org.apache.spark.sql.catalyst.streaming.{StreamingRelationV2, WriteToStream} import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE import org.apache.spark.sql.catalyst.util.truncatedString @@ -35,7 +37,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.{WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} +import org.apache.spark.sql.execution.streaming.sources.{ForeachBatchSink, WriteToMicroBatchDataSource, WriteToMicroBatchDataSourceV1} import org.apache.spark.sql.execution.streaming.state.StateSchemaBroadcast import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.Trigger @@ -303,6 +305,15 @@ class MicroBatchExecution( } private val watermarkPropagator = WatermarkPropagator(sparkSession.sessionState.conf) + private lazy val hasGlobalAggregateOrLimit = containsGlobalAggregateOrLimit(logicalPlan) + + private def containsGlobalAggregateOrLimit(logicalPlan: LogicalPlan): Boolean = { + logicalPlan.collect { + case agg: Aggregate if agg.groupingExpressions.isEmpty => agg + case limit: GlobalLimit => limit + }.nonEmpty + } + override def cleanup(): Unit = { super.cleanup() @@ -862,6 +873,8 @@ class MicroBatchExecution( isTerminatingTrigger = trigger.isInstanceOf[AvailableNowTrigger.type]) execCtx.executionPlan.executedPlan // Force the lazy generation of 
execution plan } + // Set up StateStore commit tracking before execution begins + setupStateStoreCommitTracking(execCtx) markMicroBatchExecutionStart(execCtx) @@ -965,6 +978,50 @@ class MicroBatchExecution( } } + + /** + * Set up tracking for StateStore commits before batch execution begins. + * This collects information about expected stateful operators and initializes + * commit tracking, but only for ForeachBatchSink without global aggregates or limits. + */ + private def setupStateStoreCommitTracking(execCtx: MicroBatchExecutionContext): Unit = { + try { + // Collect stateful operators from the executed plan + val statefulOps = execCtx.executionPlan.executedPlan.collect { + case s: StateStoreWriter => s + } + + if (statefulOps.nonEmpty && + sparkSession.sessionState.conf.stateStoreCommitValidationEnabled) { + + // Start tracking before execution begins + // We only validate commits for ForeachBatchSink because it's the only sink where + // user-defined functions can cause partial processing (e.g., using show() or limit()). + // We exclude queries with global aggregates or limits because they naturally don't + // process all partitions, making commit validation unnecessary and potentially noisy. 
+ if (sink.isInstanceOf[ForeachBatchSink[_]] && !hasGlobalAggregateOrLimit) { + progressReporter.shouldValidateStateStoreCommit.set(true) + // Build expected stores map: operatorId -> (storeName -> numPartitions) + val expectedStores = statefulOps.map { op => + val operatorId = op.getStateInfo.operatorId + val numPartitions = op.getStateInfo.numPartitions + val storeNames = op.stateStoreNames.map(_ -> numPartitions).toMap + operatorId -> storeNames + }.toMap + sparkSession.streams.stateStoreCoordinator + .startStateStoreCommitTrackingForBatch(runId, execCtx.batchId, expectedStores) + } + // TODO: Find out how to dynamically set the SQLConf at this point to disable + // the commit tracking + } + } catch { + case NonFatal(e) => + // Log but don't fail the query for tracking setup errors + logWarning(log"Error during StateStore commit tracking setup for batch " + + log"${MDC(BATCH_ID, execCtx.batchId)}", e) + } + } + /** * Called after the microbatch has completed execution. It takes care of committing the offset * to commit log and other bookkeeping. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala index dc04ba3331e7..8f07126a33bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ProgressReporter.scala @@ -21,6 +21,7 @@ import java.time.Instant import java.time.ZoneId import java.time.format.DateTimeFormatter import java.util.{Optional, UUID} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -56,6 +57,8 @@ class ProgressReporter( // The timestamp we report an event that has not executed anything var lastNoExecutionProgressEventTime = Long.MinValue + val shouldValidateStateStoreCommit = new AtomicBoolean(false) + /** Holds the most recent query progress updates. Accesses must lock on the queue itself. 
*/ private val progressBuffer = new mutable.Queue[StreamingQueryProgress]() @@ -277,6 +280,15 @@ abstract class ProgressContext( currentTriggerStartOffsets != null && currentTriggerEndOffsets != null && currentTriggerLatestOffsets != null ) + + // Only validate commits if enabled and the query has stateful operators + if (progressReporter.shouldValidateStateStoreCommit.get()) { + progressReporter.stateStoreCoordinator.validateStateStoreCommitForBatch( + lastExecution.runId, + lastExecution.currentBatchId + ) + } + currentTriggerEndTimestamp = triggerClock.getTimeMillis() val processingTimeMills = currentTriggerEndTimestamp - currentTriggerStartTimestamp assert(lastExecution != null, "executed batch should provide the information for execution.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 68e8555b314f..358d2d50e406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -184,6 +184,12 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with logInfo(log"Committed version ${MDC(LogKeys.COMMITTED_VERSION, newVersion)} " + log"for ${MDC(LogKeys.STATE_STORE_PROVIDER, this)} to file " + log"${MDC(LogKeys.FILE_NAME, finalDeltaFile)}") + + // Report the commit to StateStoreCoordinator for tracking + if (storeConf.commitValidationEnabled) { + StateStore.reportCommitToCoordinator(newVersion, stateStoreId, hadoopConf) + } + newVersion } catch { case e: Throwable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 
bfcb2cdda296..a702d041de7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -380,6 +380,12 @@ private[sql] class RocksDBStateStoreProvider validateAndTransitionState(COMMIT) logInfo(log"Committed ${MDC(VERSION_NUM, newVersion)} " + log"for ${MDC(STATE_STORE_ID, id)}") + + // Report the commit to StateStoreCoordinator for tracking + if (storeConf.commitValidationEnabled) { + StateStore.reportCommitToCoordinator(newVersion, stateStoreId, hadoopConf) + } + newVersion } catch { case e: Throwable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 903bc87f5a22..a9d3c75776e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -34,6 +34,7 @@ import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.{Logging, LogKeys, MDC} +import org.apache.spark.internal.LogKeys.{EXCEPTION, STATE_STORE_ID, VERSION_NUM} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.UnsafeRowUtils import org.apache.spark.sql.errors.QueryExecutionErrors @@ -913,6 +914,28 @@ object StateStore extends Logging { @GuardedBy("maintenanceThreadPoolLock") private val maintenancePartitions = new mutable.HashSet[StateStoreProviderId] + /** Reports to the coordinator that a StateStore has committed */ + def reportCommitToCoordinator( + version: Long, + stateStoreId: StateStoreId, + hadoopConf: Configuration): Unit = { + try { + val runId = 
UUID.fromString(StateStoreProvider.getRunId(hadoopConf)) + val providerId = StateStoreProviderId(stateStoreId, runId) + // The coordinator will handle whether tracking is active for this batch + // If tracking is not active, it will just reply without processing + StateStoreProvider.coordinatorRef.foreach( + _.reportStateStoreCommit(providerId, version, stateStoreId.storeName) + ) + logDebug(log"Reported commit for store " + + log"${MDC(STATE_STORE_ID, stateStoreId)} at version ${MDC(VERSION_NUM, version)}") + } catch { + case NonFatal(e) => + // Log but don't fail the commit if reporting fails + logWarning(log"Failed to report StateStore commit: ${MDC(EXCEPTION, e)}") + } + } + /** * Runs the `task` periodically and bubbles any exceptions that it encounters. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index b41d980b84fe..4026effcb088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -71,6 +71,13 @@ class StateStoreConf( /** Whether validate the underlying format or not. */ val formatValidationEnabled: Boolean = sqlConf.stateStoreFormatValidationEnabled + /** + * Whether to validate StateStore commits for ForeachBatch sinks to ensure all partitions + * are processed. This helps detect incomplete processing due to operations like show() + * or limit(). + */ + val commitValidationEnabled = sqlConf.stateStoreCommitValidationEnabled + /** * Whether to validate the value side. 
This config is applied to both validators as below: * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 903f27fb2a22..f280553b9540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -53,6 +53,24 @@ private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executo private case class GetLocation(storeId: StateStoreProviderId) extends StateStoreCoordinatorMessage +/** Report that a StateStore has committed for tracking purposes */ +private case class ReportStateStoreCommit( + storeId: StateStoreProviderId, + version: Long, + storeName: String = StateStoreId.DEFAULT_STORE_NAME) + extends StateStoreCoordinatorMessage + +/** Start tracking StateStore commits for a batch */ +private case class StartStateStoreCommitTrackingForBatch( + runId: UUID, + batchId: Long, + expectedStores: Map[Long, Map[String, Int]]) // operatorId -> (storeName -> numPartitions) + extends StateStoreCoordinatorMessage + +/** Validate that all expected StateStores have committed for a batch */ +private case class ValidateStateStoreCommitForBatch(runId: UUID, batchId: Long) + extends StateStoreCoordinatorMessage + private case class DeactivateInstances(runId: UUID) extends StateStoreCoordinatorMessage @@ -176,6 +194,29 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { LogLaggingStateStores(queryRunId, latestVersion, isTerminatingTrigger)) } + + /** Start tracking StateStore commits for a batch */ + private[sql] def startStateStoreCommitTrackingForBatch( + runId: UUID, + batchId: Long, + expectedStores: Map[Long, Map[String, Int]]): Unit = { + rpcEndpointRef.askSync[Unit]( + StartStateStoreCommitTrackingForBatch(runId, batchId, 
expectedStores)) + } + + /** Report that a StateStore has committed */ + private[sql] def reportStateStoreCommit( + storeId: StateStoreProviderId, + version: Long, + storeName: String = StateStoreId.DEFAULT_STORE_NAME): Unit = { + rpcEndpointRef.askSync[Unit](ReportStateStoreCommit(storeId, version, storeName)) + } + + /** Validate that all expected StateStores have committed for a batch */ + private[sql] def validateStateStoreCommitForBatch(runId: UUID, batchId: Long): Unit = { + rpcEndpointRef.askSync[Unit](ValidateStateStoreCommitForBatch(runId, batchId)) + } + /** * Endpoint used for testing. * Get the latest snapshot version uploaded for a state store. @@ -222,6 +263,10 @@ private class StateStoreCoordinator( // Default snapshot upload event to use when a provider has never uploaded a snapshot private val defaultSnapshotUploadEvent = SnapshotUploadEvent(0, 0) + // Tracking structure for StateStore commits per batch + // Key: (runId, batchId) -> Value: CommitTracker + private val batchCommitTrackers = new mutable.HashMap[(UUID, Long), BatchCommitTracker] + // Stores the last timestamp in milliseconds for each queryRunId indicating when the // coordinator did a report on instances lagging behind on snapshot uploads. // The initial timestamp is defaulted to 0 milliseconds. 
@@ -264,6 +309,10 @@ private class StateStoreCoordinator( val storeIdsToRemove = instances.keys.filter(_.queryRunId == runId).toSeq instances --= storeIdsToRemove + + val runIdsToRemove = batchCommitTrackers.keys.filter(_._1 == runId) + batchCommitTrackers --= runIdsToRemove + // Also remove these instances from snapshot upload event tracking stateStoreLatestUploadedSnapshot --= storeIdsToRemove // Remove the corresponding run id entries for report time and starting time @@ -336,6 +385,49 @@ private class StateStoreCoordinator( } context.reply(true) + case StartStateStoreCommitTrackingForBatch(runId, batchId, expectedStores) => + val key = (runId, batchId) + if (batchCommitTrackers.contains(key)) { + context.sendFailure(new IllegalStateException( + s"Batch commit tracker already exists for runId=$runId, batchId=$batchId")) + } else { + batchCommitTrackers.put(key, new BatchCommitTracker(runId, batchId, expectedStores)) + logInfo(s"Started tracking commits for batch $batchId with " + + s"${expectedStores.values.map(_.values.sum).sum} expected stores") + context.reply() + } + + case ReportStateStoreCommit(storeId, version, storeName) => + // StateStore version = batchId + 1, so we need to adjust + val batchId = version - 1 + val key = (storeId.queryRunId, batchId) + batchCommitTrackers.get(key) match { + case Some(tracker) => + tracker.recordCommit(storeId, storeName) + context.reply() + case None => + // In case no commit tracker for this batch was found + context.reply() + } + + case ValidateStateStoreCommitForBatch(runId, batchId) => + val key = (runId, batchId) + batchCommitTrackers.get(key) match { + case Some(tracker) => + try { + tracker.validateAllCommitted() + batchCommitTrackers.remove(key) // Clean up after validation + context.reply() + } catch { + case e: StateStoreCommitValidationFailed => + batchCommitTrackers.remove(key) // Clean up even on failure + context.sendFailure(e) + } + case None => + context.sendFailure(new IllegalStateException( + s"No 
commit tracker found for runId=$runId, batchId=$batchId")) + } + case GetLatestSnapshotVersionForTesting(providerId) => val version = stateStoreLatestUploadedSnapshot.get(providerId).map(_.version) logDebug(s"Got latest snapshot version of the state store $providerId: $version") @@ -402,6 +494,55 @@ private class StateStoreCoordinator( } } +/** + * Tracks StateStore commits for a batch to ensure all expected stores commit + */ +private class BatchCommitTracker( + runId: UUID, + batchId: Long, + expectedStores: Map[Long, Map[String, Int]]) extends Logging { + + // Track committed stores: (operatorId, partitionId, storeName) -> committed + private val committedStores = new mutable.HashSet[(Long, Int, String)]() + + def recordCommit(storeId: StateStoreProviderId, storeName: String): Unit = { + val key = (storeId.storeId.operatorId, storeId.storeId.partitionId, storeName) + committedStores.add(key) + logDebug(s"Recorded commit for store $storeId with name $storeName for batch $batchId") + } + + def validateAllCommitted(): Unit = { + val missingCommits = new mutable.ArrayBuffer[String]() + + expectedStores.foreach { case (operatorId, storeMap) => + storeMap.foreach { case (storeName, numPartitions) => + for (partitionId <- 0 until numPartitions) { + val key = (operatorId, partitionId, storeName) + if (!committedStores.contains(key)) { + missingCommits += s"(operator=$operatorId, partition=$partitionId, store=$storeName)" + } + } + } + } + + if (missingCommits.nonEmpty) { + val totalExpected = expectedStores.values.map(_.values.sum).sum + val errorMsg = s"Not all StateStores committed for batch $batchId. " + + s"Expected $totalExpected commits but got ${committedStores.size}. 
" + + s"Missing commits: ${missingCommits.mkString(", ")}" + logError(errorMsg) + throw StateStoreErrors.stateStoreCommitValidationFailed( + batchId, + totalExpected, + committedStores.size, + missingCommits.mkString(", ") + ) + } + + logInfo(s"All ${committedStores.size} StateStores successfully committed for batch $batchId") + } +} + case class SnapshotUploadEvent( version: Long, timestamp: Long diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala index a4a261342b87..455b06f8d9dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala @@ -216,6 +216,14 @@ object StateStoreErrors { new StateStoreInvalidConfigAfterRestart(configName, oldConfig, newConfig) } + def stateStoreCommitValidationFailed( + batchId: Long, + expectedCommits: Int, + actualCommits: Int, + missingCommits: String): StateStoreCommitValidationFailed = { + new StateStoreCommitValidationFailed(batchId, expectedCommits, actualCommits, missingCommits) + } + def duplicateStateVariableDefined(stateName: String): StateStoreDuplicateStateVariableDefined = { new StateStoreDuplicateStateVariableDefined(stateName) @@ -524,3 +532,18 @@ class StateStoreOperationOutOfOrder(errorMsg: String) errorClass = "STATE_STORE_OPERATION_OUT_OF_ORDER", messageParameters = Map("errorMsg" -> errorMsg) ) + +class StateStoreCommitValidationFailed( + batchId: Long, + expectedCommits: Int, + actualCommits: Int, + missingCommits: String) + extends SparkRuntimeException( + errorClass = "STATE_STORE_COMMIT_VALIDATION_FAILED", + messageParameters = Map( + "batchId" -> batchId.toString, + "expectedCommits" -> expectedCommits.toString, + "actualCommits" -> actualCommits.toString, + "missingCommits" -> missingCommits + ) + ) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index fc36235667ec..ec4aa8dacf00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources +import java.io.File + import scala.collection.mutable import scala.language.implicitConversions @@ -25,7 +27,9 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.SerializeFromObjectExec import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStoreCommitValidationFailed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ import org.apache.spark.util.ArrayImplicits._ @@ -255,6 +259,425 @@ class ForeachBatchSinkSuite extends StreamTest { query.awaitTermination() } + test("SPARK-52008: foreachBatch with show() should fail with appropriate error") { + // This test verifies that commit validation is enabled by default + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + withTempDir { tempDir => + val checkpointPath = tempDir.getCanonicalPath + // Create a simple streaming DataFrame + val streamingDF = spark.readStream + .format("rate") + .option("rowsPerSecond", 3) + .load() + .withColumn("pt_id", (rand() * 99 + 1).cast("int")) + .withColumn("event_time", + expr("timestampadd(SECOND, cast(rand() * 2 * 86400 - 86400 as int), timestamp)")) + .withColumn("in_map", (rand() * 2).cast("int") === 1) + .drop("value") + + // Create a stateful streaming query + val windowedDF = streamingDF + .withWatermark("event_time", "1 day") + 
.groupBy("pt_id") + .agg( + max("event_time").as("latest_event_time"), + last("in_map").as("in_map") + ) + .withColumn("output_time", current_timestamp()) + + // Define a foreachBatch function that uses show(), which only consumes some partitions + def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit = { + // show() only processes enough partitions to display the specified number of rows + // This doesn't consume all partitions, causing state files to be missing + batchDF.show(2) // Only shows 2 rows, not processing all partitions + } + + // Start the streaming query + val queryEx = intercept[StreamingQueryException] { + val query = windowedDF.writeStream + .queryName("reproducer_test") + .option("checkpointLocation", checkpointPath) + .foreachBatch(problematicBatchProcessor _) + .outputMode("update") + .start() + + // Wait for the exception to be thrown + query.awaitTermination() + } + + // Verify we get the StateStore commit validation error + // The error is wrapped by RPC framework, so we need to check the cause chain + val rootCause = queryEx.getCause + assert(rootCause != null, "Expected a root cause for the StreamingQueryException") + + // The RPC framework wraps our exception, so check the cause of the cause + val actualException = rootCause.getCause + assert(actualException != null, "Expected a cause for the RPC exception") + assert(actualException.isInstanceOf[StateStoreCommitValidationFailed], + s"Expected StateStoreCommitValidationFailed but got ${actualException.getClass.getName}") + + val errorMessage = actualException.getMessage + assert(errorMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"), + s"Expected STATE_STORE_COMMIT_VALIDATION_FAILED error, but got: $errorMessage") + assert(errorMessage.contains("State store commit validation failed"), + s"Expected state store commit validation message, but got: $errorMessage") + assert(errorMessage.contains("Missing commits"), + s"Expected missing commits message, but got: 
$errorMessage") + + // Extract and validate the expected vs actual commit counts + val expectedPattern = "Expected (\\d+) commits but got (\\d+)".r + val missingPattern = "Missing commits: (.+)".r + + expectedPattern.findFirstMatchIn(errorMessage) match { + case Some(m) => + val expectedCommits = m.group(1).toInt + val actualCommits = m.group(2).toInt + + // We should have fewer actual commits than expected due to show(2) + // not processing all partitions + assert(actualCommits < expectedCommits, + s"Expected fewer actual commits ($actualCommits)" + + s" than expected commits ($expectedCommits)") + assert(actualCommits >= 1, + s"Expected at least 1 actual commit from show(2), but got $actualCommits") + assert(expectedCommits == 5, + s"Expected more than 5 commits but got $expectedCommits") + + case None => + fail(s"Could not find expected/actual commit counts in error message: $errorMessage") + } + + // Validate that missing commits are reported with proper structure + missingPattern.findFirstMatchIn(errorMessage) match { + case Some(m) => + val missingCommits = m.group(1) + assert(missingCommits.nonEmpty, + s"Expected non-empty missing commits list, but got: '$missingCommits'") + // Should contain operator and partition information + assert(missingCommits.contains("operator=") && missingCommits.contains("partition="), + s"Expected missing commits to contain operator and" + + s" partition info, but got: '$missingCommits'") + case None => + fail(s"Could not find missing commits in error message: $errorMessage") + } + } + } + } + + test("StateStore commit validation should detect missing commits") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + withTempDir { tempDir => + val checkpointPath = tempDir.getCanonicalPath + + // Create a streaming DataFrame with controlled partitioning + val streamingDF = spark.readStream + .format("rate") + .option("rowsPerSecond", 10) + .option("numPartitions", 4) // Ensure we have multiple partitions + .load() + 
.withColumn("key", col("value") % 100) + + // Create a stateful operation that requires all partitions to process + val aggregatedDF = streamingDF + .groupBy("key") + .agg(count("*").as("count")) + + // ForeachBatch function that only processes some data and then throws exception + // This should cause some StateStore commits to be missing + def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit = { + // Force evaluation of only a subset of partitions by using limit + // This should cause some partitions to not process their StateStores + batchDF.limit(5).collect() // Only process first 5 rows + + // Simulate a failure that prevents remaining partitions from committing + if (batchId > 0) { + throw new RuntimeException("Simulated batch processing failure") + } + } + + // This should fail with StateStore commit validation error + val queryEx = intercept[StreamingQueryException] { + val query = aggregatedDF.writeStream + .queryName("commit_validation_test") + .option("checkpointLocation", checkpointPath) + .foreachBatch(problematicBatchProcessor _) + .outputMode("complete") + .start() + + query.awaitTermination() + } + + // Should fail with either our new validation error or the simulated RuntimeException + // Check the cause chain since RPC wraps exceptions + val rootCause = queryEx.getCause + val actualException = if (rootCause != null) rootCause.getCause else null + + val hasCommitValidationError = actualException != null && ( + actualException.isInstanceOf[StateStoreCommitValidationFailed] || + actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]")) + val hasSimulatedError = queryEx.getMessage.contains("Simulated batch processing failure") + + assert(hasCommitValidationError || hasSimulatedError, + s"Expected StateStore commit validation error or simulated error," + + s" but got: ${queryEx.getMessage}") + } + } + } + + test("StateStore commit validation with AvailableNow trigger") { + 
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + withTempDir { tempDir => + val checkpointPath = tempDir.getCanonicalPath + + // Create a temporary file with data + val inputPath = new File(tempDir, "input").getCanonicalPath + val inputData = spark.range(0, 100) + .selectExpr("id", "id % 10 as key") + inputData.write + .mode("overwrite") + .parquet(inputPath) + + // Get the schema from the written data + val schema = inputData.schema + + // Create a streaming DataFrame with AvailableNow trigger + val streamingDF = spark.readStream + .format("parquet") + .schema(schema) // Provide the schema explicitly + .load(inputPath) + + // Stateful aggregation + val aggregatedDF = streamingDF + .groupBy("key") + .agg(count("*").as("count")) + + // ForeachBatch that only processes partial data + def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit = { + // Only show first 2 rows, won't process all partitions + batchDF.show(2) + } + + val queryEx = intercept[StreamingQueryException] { + val query = aggregatedDF.writeStream + .queryName("availablenow_commit_test") + .option("checkpointLocation", checkpointPath) + .foreachBatch(problematicBatchProcessor _) + .outputMode("complete") + .trigger(Trigger.AvailableNow()) + .start() + + query.awaitTermination() + } + + // Check the cause chain since RPC wraps exceptions + val rootCause = queryEx.getCause + assert(rootCause != null, "Expected a root cause for the StreamingQueryException") + val actualException = rootCause.getCause + assert(actualException != null, "Expected a cause for the RPC exception") + + assert(actualException.isInstanceOf[StateStoreCommitValidationFailed] || + actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"), + s"Expected STATE_STORE_COMMIT_VALIDATION_FAILED error," + + s" but got: ${actualException.getMessage}") + } + } + } + + test("StateStore commit validation with swallowed exceptions in foreachBatch") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + 
withTempDir { tempDir => + val checkpointPath = tempDir.getCanonicalPath + + // Create streaming DataFrame + val streamingDF = spark.readStream + .format("rate") + .option("rowsPerSecond", 10) + .option("numPartitions", 4) + .load() + .withColumn("key", col("value") % 10) + + // Stateful aggregation + val aggregatedDF = streamingDF + .groupBy("key") + .agg(sum("value").as("sum")) + + // ForeachBatch that swallows exceptions and returns early + def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit = { + try { + // Process only first few rows + val firstRows = batchDF.limit(2).collect() + + // Simulate some processing that might fail + if (firstRows.length > 1) { + throw new RuntimeException("Processing failed!") + } + } catch { + case _: Exception => + // Swallow the exception and return early + // This means remaining partitions won't be processed + return + } + + // This code is never reached due to early return + batchDF.collect() + } + + val queryEx = intercept[StreamingQueryException] { + val query = aggregatedDF.writeStream + .queryName("swallowed_exception_test") + .option("checkpointLocation", checkpointPath) + .foreachBatch(problematicBatchProcessor _) + .outputMode("update") + .start() + + query.awaitTermination() + } + + // Check the cause chain since RPC wraps exceptions + val rootCause = queryEx.getCause + assert(rootCause != null, "Expected a root cause for the StreamingQueryException") + val actualException = rootCause.getCause + assert(actualException != null, "Expected a cause for the RPC exception") + + assert(actualException.isInstanceOf[StateStoreCommitValidationFailed] || + actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"), + s"Expected STATE_STORE_COMMIT_VALIDATION_FAILED error," + + s" but got: ${actualException.getMessage}") + } + } + } + + test("StateStore commit validation with multiple swallowed exceptions") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + withTempDir { tempDir => + 
val checkpointPath = tempDir.getCanonicalPath + + val streamingDF = spark.readStream + .format("rate") + .option("rowsPerSecond", 10) + .load() + .withColumn("key", col("value") % 5) + + // Multiple aggregations to create multiple StateStores + val aggregatedDF = streamingDF + .groupBy("key") + .agg( + count("*").as("count"), + sum("value").as("sum"), + avg("value").as("avg") + ) + + var processedCount = 0 + + // ForeachBatch with multiple exception swallowing points + def problematicBatchProcessor(batchDF: DataFrame, batchId: Long): Unit = { + try { + // First processing attempt + batchDF.limit(1).collect() + processedCount += 1 + + try { + // Second processing attempt that fails + if (processedCount > 0) { + throw new IllegalStateException("Second processing failed") + } + } catch { + case _: IllegalStateException => + // Swallow and continue + } + + try { + // Third processing attempt + batchDF.limit(2).collect() + throw new RuntimeException("Third processing failed") + } catch { + case _: RuntimeException => + // Swallow and return early + return + } + + // Never reached - full processing + batchDF.collect() + } catch { + case _: Exception => + // Outer catch that swallows everything + } + } + + val queryEx = intercept[StreamingQueryException] { + val query = aggregatedDF.writeStream + .queryName("multiple_swallowed_exceptions_test") + .option("checkpointLocation", checkpointPath) + .foreachBatch(problematicBatchProcessor _) + .outputMode("complete") + .start() + + query.awaitTermination() + } + + // Check the cause chain since RPC wraps exceptions + val rootCause = queryEx.getCause + assert(rootCause != null, "Expected a root cause for the StreamingQueryException") + val actualException = rootCause.getCause + assert(actualException != null, "Expected a cause for the RPC exception") + + assert(actualException.isInstanceOf[StateStoreCommitValidationFailed] || + actualException.getMessage.contains("[STATE_STORE_COMMIT_VALIDATION_FAILED]"), + s"Expected 
STATE_STORE_COMMIT_VALIDATION_FAILED error," + + s" but got: ${actualException.getMessage}") + } + } + } + + test("StateStore commit validation can be disabled via configuration") { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.STATE_STORE_COMMIT_VALIDATION_ENABLED.key -> "false") { + withTempDir { tempDir => + val checkpointPath = tempDir.getCanonicalPath + + val streamingDF = spark.readStream + .format("rate") + .option("rowsPerSecond", 3) + .load() + .withColumn("key", col("value") % 10) + + val aggregatedDF = streamingDF + .groupBy("key") + .agg(count("*").as("count")) + + // ForeachBatch that only processes partial data + // With validation disabled, this should not fail + def partialProcessor(batchDF: DataFrame, batchId: Long): Unit = { + // Only show first 2 rows, won't process all partitions + batchDF.show(2) + } + + // This should complete successfully with validation disabled + val query = aggregatedDF.writeStream + .queryName("validation_disabled_test") + .option("checkpointLocation", checkpointPath) + .foreachBatch(partialProcessor _) + .outputMode("complete") + .trigger(Trigger.ProcessingTime("1 second")) + .start() + + try { + // Wait for at least 2-3 batches to be processed + eventually(timeout(streamingTimeout)) { + assert(query.lastProgress != null, "Query should have made progress") + assert(query.lastProgress.batchId >= 2, + s"Query should have processed at least 3 batches, " + + s"but only processed ${query.lastProgress.batchId + 1}") + } + } finally { + query.stop() + query.awaitTermination() + } + } + } + } + // ============== Helper classes and methods ================= private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala index 
081383dd66c9..ace8c4db6ff1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkContext, SparkException, TaskContext} import org.apache.spark.sql.{DataFrame, ForeachWriter} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{CommitLog, MemoryStream, StreamExecution} +import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorSuite.withCoordinatorRef import org.apache.spark.sql.execution.streaming.state.StateStoreTestsHelper import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -157,7 +158,6 @@ class CkptIdCollectingStateStoreProviderWrapper extends StateStoreProvider { hadoopConf: Configuration, useMultipleValuesPerKey: Boolean = false, stateSchemaProvider: Option[StateSchemaProvider] = None): Unit = { - hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString) innerProvider.init( stateStoreId, keySchema, @@ -1203,28 +1203,30 @@ class RocksDBStateStoreCheckpointFormatV2Suite extends StreamTest test("checkpointFormatVersion2 racing commits don't return incorrect checkpointInfo") { val sqlConf = new SQLConf() sqlConf.setConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION, 2) - - withTempDir { checkpointDir => - val provider = new CkptIdCollectingStateStoreProviderWrapper() - provider.init( - StateStoreId(checkpointDir.toString, 0, 0), - StateStoreTestsHelper.keySchema, - StateStoreTestsHelper.valueSchema, - PrefixKeyScanStateEncoderSpec(StateStoreTestsHelper.keySchema, 1), - useColumnFamilies = false, - new StateStoreConf(sqlConf), - new Configuration - ) - - val store1 = provider.getStore(0) - val store1NewVersion = store1.commit() - val store2 = provider.getStore(1) - val store2NewVersion = 
store2.commit() - val store1CheckpointInfo = store1.getStateStoreCheckpointInfo() - val store2CheckpointInfo = store2.getStateStoreCheckpointInfo() - - assert(store1CheckpointInfo.batchVersion == store1NewVersion) - assert(store2CheckpointInfo.batchVersion == store2NewVersion) + val sc = spark.sparkContext + withCoordinatorRef(sc) { _ => + withTempDir { checkpointDir => + val hadoopConf = new Configuration() + hadoopConf.set(StreamExecution.RUN_ID_KEY, UUID.randomUUID().toString) + val provider = new CkptIdCollectingStateStoreProviderWrapper() + provider.init( + StateStoreId(checkpointDir.toString, 0, 0), + StateStoreTestsHelper.keySchema, + StateStoreTestsHelper.valueSchema, + PrefixKeyScanStateEncoderSpec(StateStoreTestsHelper.keySchema, 1), + useColumnFamilies = false, + new StateStoreConf(sqlConf), + hadoopConf + ) + val store1 = provider.getStore(0) + val store1NewVersion = store1.commit() + val store2 = provider.getStore(1) + val store2NewVersion = store2.commit() + val store1CheckpointInfo = store1.getStateStoreCheckpointInfo() + val store2CheckpointInfo = store2.getStateStoreCheckpointInfo() + assert(store1CheckpointInfo.batchVersion == store1NewVersion) + assert(store2CheckpointInfo.batchVersion == store2NewVersion) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 148c451f37af..ece3b8bf942b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -1532,7 +1532,8 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest var index = 0 val foreachBatchDf = df.writeStream - .foreachBatch((_: Dataset[(String, String)], _: Long) => { + .foreachBatch((ds: Dataset[(String, String)], _: Long) => { + ds.collect() index += 1 }) .trigger(Trigger.AvailableNow()) @@ -1559,7 
+1560,8 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest def startTriggerAvailableNowQueryAndCheck(expectedIdx: Int): Unit = { val q = df.writeStream - .foreachBatch((_: Dataset[(String, String)], _: Long) => { + .foreachBatch((ds: Dataset[(String, String)], _: Long) => { + ds.collect() index += 1 }) .trigger(Trigger.AvailableNow) @@ -2024,7 +2026,8 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest var index = 0 val q = df.writeStream - .foreachBatch((_: Dataset[(String, String)], _: Long) => { + .foreachBatch((ds: Dataset[(String, String)], _: Long) => { + ds.collect() index += 1 }) .trigger(Trigger.AvailableNow) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org