This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new becfb94e1c71 [SPARK-46865][SS] Add Batch Support for TransformWithState Operator
becfb94e1c71 is described below
commit becfb94e1c713d10dac83300d096be490a912fd2
Author: Eric Marnadi <[email protected]>
AuthorDate: Thu Feb 8 12:15:20 2024 +0900
[SPARK-46865][SS] Add Batch Support for TransformWithState Operator
### What changes were proposed in this pull request?
We are allowing batch queries to define and use the `TransformWithState`
operator, which was originally introduced for streaming only.
### Why are the changes needed?
This is needed to maintain parity between the streaming and batch APIs,
since we want everything supported in streaming to be supported in batch as
well.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a unit test that runs the TransformWithState operator in a batch
query, and removed the test that asserted batch queries fail.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #44884 from ericm-db/tws-batch.
Lead-authored-by: Eric Marnadi <[email protected]>
Co-authored-by: ericm-db <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../analysis/UnsupportedOperationChecker.scala | 3 -
.../spark/sql/execution/SparkStrategies.scala | 9 +-
.../execution/streaming/IncrementalExecution.scala | 2 +-
.../streaming/StatefulProcessorHandleImpl.scala | 25 ++--
.../streaming/TransformWithStateExec.scala | 138 ++++++++++++++++++---
.../sql/streaming/TransformWithStateSuite.scala | 29 ++---
6 files changed, 151 insertions(+), 55 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 15a856b273ed..d57464fcefc0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -43,9 +43,6 @@ object UnsupportedOperationChecker extends Logging {
throwError("dropDuplicatesWithinWatermark is not supported with batch
" +
"DataFrames/DataSets")(d)
- case t: TransformWithState =>
- throwError("transformWithState is not supported with batch
DataFrames/Datasets")(t)
-
case _ =>
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f5c2f17f8826..65347fc9d237 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -723,7 +723,7 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
* Strategy to convert [[TransformWithState]] logical operator to physical
operator
* in streaming plans.
*/
- object TransformWithStateStrategy extends Strategy {
+ object StreamingTransformWithStateStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case TransformWithState(
keyDeserializer, valueDeserializer, groupingAttributes,
@@ -892,6 +892,13 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
hasInitialState, planLater(initialState), planLater(child)
) :: Nil
+ case logical.TransformWithState(keyDeserializer, valueDeserializer,
groupingAttributes,
+ dataAttributes, statefulProcessor, timeoutMode, outputMode,
keyEncoder,
+ outputObjAttr, child) =>
+
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer,
valueDeserializer,
+ groupingAttributes, dataAttributes, statefulProcessor, timeoutMode,
outputMode,
+ keyEncoder, outputObjAttr, planLater(child)) :: Nil
+
case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176")
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 08d41b840d04..4469d52618e8 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -73,7 +73,7 @@ class IncrementalExecution(
StreamingRelationStrategy ::
StreamingDeduplicationStrategy ::
StreamingGlobalLimitStrategy(outputMode) ::
- TransformWithStateStrategy :: Nil
+ StreamingTransformWithStateStrategy :: Nil
}
private lazy val hadoopConf = sparkSession.sessionState.newHadoopConf()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index d06938ffeafb..fed18fc7e458 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -69,29 +69,28 @@ class QueryInfoImpl(
* @param store - instance of state store
* @param runId - unique id for the current run
* @param keyEncoder - encoder for the key
+ * @param isStreaming - defines whether the query is streaming or batch
*/
class StatefulProcessorHandleImpl(
store: StateStore,
runId: UUID,
- keyEncoder: ExpressionEncoder[Any])
+ keyEncoder: ExpressionEncoder[Any],
+ isStreaming: Boolean = true)
extends StatefulProcessorHandle with Logging {
import StatefulProcessorHandleState._
+ private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000"
private def buildQueryInfo(): QueryInfo = {
- val taskCtxOpt = Option(TaskContext.get())
- // Task context is not available in tests, so we generate a random query
id and batch id here
- val queryId = if (taskCtxOpt.isDefined) {
- taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY)
- } else {
- assert(Utils.isTesting, "Failed to find query id in task context")
- UUID.randomUUID().toString
- }
- val batchId = if (taskCtxOpt.isDefined) {
- taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong
+ val taskCtxOpt = Option(TaskContext.get())
+ val (queryId, batchId) = if (!isStreaming) {
+ (BATCH_QUERY_ID, 0L)
+ } else if (taskCtxOpt.isDefined) {
+ (taskCtxOpt.get.getLocalProperty(StreamExecution.QUERY_ID_KEY),
+
taskCtxOpt.get.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY).toLong)
} else {
- assert(Utils.isTesting, "Failed to find batch id in task context")
- 0
+ assert(Utils.isTesting, "Failed to find query id/batch Id in task
context")
+ (UUID.randomUUID().toString, 0L)
}
new QueryInfoImpl(UUID.fromString(queryId), runId, batchId)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 82e827685b47..818bef5f34a2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution.streaming
+import java.util.UUID
import java.util.concurrent.TimeUnit.NANOSECONDS
import org.apache.spark.rdd.RDD
@@ -25,9 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending,
Attribute, Expressi
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor,
TimeoutMode}
import org.apache.spark.sql.types._
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.{CompletionIterator, SerializableConfiguration,
Utils}
/**
* Physical operator for executing `TransformWithState`
@@ -44,6 +46,7 @@ import org.apache.spark.util.CompletionIterator
* @param batchTimestampMs processing timestamp of the current batch.
* @param eventTimeWatermarkForLateEvents event time watermark for filtering
late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
+ * @param isStreaming defines whether the query is streaming or batch
* @param child the physical plan for the underlying data
*/
case class TransformWithStateExec(
@@ -60,7 +63,8 @@ case class TransformWithStateExec(
batchTimestampMs: Option[Long],
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
- child: SparkPlan)
+ child: SparkPlan,
+ isStreaming: Boolean = true)
extends UnaryExecNode with StateStoreWriter with WatermarkSupport with
ObjectProducerExec {
override def shortName: String = "transformWithStateExec"
@@ -143,7 +147,11 @@ case class TransformWithStateExec(
// by the upstream (consumer) operators in addition to the processing in
this operator.
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime -
updatesStartTimeNs)
commitTimeMs += timeTakenMs {
- store.commit()
+ if (isStreaming) {
+ store.commit()
+ } else {
+ store.abort()
+ }
}
setStoreMetrics(store)
setOperatorMetrics()
@@ -155,23 +163,113 @@ case class TransformWithStateExec(
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
- child.execute().mapPartitionsWithStateStore[InternalRow](
- getStateInfo,
- schemaForKeyRow,
- schemaForValueRow,
- numColsPrefixKey = 0,
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator),
- useColumnFamilies = true
- ) {
- case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
- val processorHandle = new StatefulProcessorHandleImpl(store,
getStateInfo.queryRunId,
- keyEncoder)
- assert(processorHandle.getHandleState ==
StatefulProcessorHandleState.CREATED)
- statefulProcessor.init(processorHandle, outputMode)
-
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
- val result = processDataWithPartition(singleIterator, store,
processorHandle)
- result
+ if (isStreaming) {
+ child.execute().mapPartitionsWithStateStore[InternalRow](
+ getStateInfo,
+ schemaForKeyRow,
+ schemaForValueRow,
+ numColsPrefixKey = 0,
+ session.sqlContext.sessionState,
+ Some(session.sqlContext.streams.stateStoreCoordinator),
+ useColumnFamilies = true
+ ) {
+ case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
+ processData(store, singleIterator)
+ }
+ } else {
+ // If the query is running in batch mode, we need to create a new
StateStore and instantiate
+ // a temp directory on the executors in mapPartitionsWithIndex.
+ val broadcastedHadoopConf =
+ new SerializableConfiguration(session.sessionState.newHadoopConf())
+ child.execute().mapPartitionsWithIndex[InternalRow](
+ (i, iter) => {
+ val providerId = {
+ val tempDirPath = Utils.createTempDir().getAbsolutePath
+ new StateStoreProviderId(
+ StateStoreId(tempDirPath, 0, i), getStateInfo.queryRunId)
+ }
+
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+ val storeConf = new StateStoreConf(sqlConf)
+
+ // Create StateStoreProvider for this partition
+ val stateStoreProvider = StateStoreProvider.createAndInit(
+ providerId,
+ schemaForKeyRow,
+ schemaForValueRow,
+ numColsPrefixKey = 0,
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = broadcastedHadoopConf.value)
+
+ val store = stateStoreProvider.getStore(0)
+ val outputIterator = processData(store, iter)
+ CompletionIterator[InternalRow,
Iterator[InternalRow]](outputIterator.iterator, {
+ stateStoreProvider.close()
+ statefulProcessor.close()
+ })
+ }
+ )
}
}
+
+ /**
+ * Process the data in the partition using the state store and the stateful
processor.
+ * @param store The state store to use
+ * @param singleIterator The iterator of rows to process
+ * @return An iterator of rows that are the result of processing the input
rows
+ */
+ private def processData(store: StateStore, singleIterator:
Iterator[InternalRow]):
+ CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+ val processorHandle = new StatefulProcessorHandleImpl(
+ store, getStateInfo.queryRunId, keyEncoder, isStreaming)
+ assert(processorHandle.getHandleState ==
StatefulProcessorHandleState.CREATED)
+ statefulProcessor.init(processorHandle, outputMode)
+ processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+ processDataWithPartition(singleIterator, store, processorHandle)
+ }
+}
+
+object TransformWithStateExec {
+
+ // Plan logical transformWithState for batch queries
+ def generateSparkPlanForBatchQueries(
+ keyDeserializer: Expression,
+ valueDeserializer: Expression,
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ statefulProcessor: StatefulProcessor[Any, Any, Any],
+ timeoutMode: TimeoutMode,
+ outputMode: OutputMode,
+ keyEncoder: ExpressionEncoder[Any],
+ outputObjAttr: Attribute,
+ child: SparkPlan): SparkPlan = {
+ val shufflePartitions =
child.session.sessionState.conf.numShufflePartitions
+ val statefulOperatorStateInfo = StatefulOperatorStateInfo(
+ checkpointLocation = "", // empty checkpointLocation will be populated
in doExecute
+ queryRunId = UUID.randomUUID(),
+ operatorId = 0,
+ storeVersion = 0,
+ numPartitions = shufflePartitions
+ )
+
+ new TransformWithStateExec(
+ keyDeserializer,
+ valueDeserializer,
+ groupingAttributes,
+ dataAttributes,
+ statefulProcessor,
+ timeoutMode,
+ outputMode,
+ keyEncoder,
+ outputObjAttr,
+ Some(statefulOperatorStateInfo),
+ Some(System.currentTimeMillis),
+ None,
+ None,
+ child,
+ isStreaming = false)
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 569e6852315c..7b448ac93419 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.streaming
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.execution.streaming._
import
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
RocksDBStateStoreProvider,
StateStoreMultipleColumnFamiliesNotSupportedException}
import org.apache.spark.sql.internal.SQLConf
@@ -196,6 +195,18 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}
}
+ test("transformWithState - batch should succeed") {
+ val inputData = Seq("a", "b")
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(new RunningCountStatefulProcessor(),
+ TimeoutMode.NoTimeouts(),
+ OutputMode.Append())
+
+ val df = result.toDF()
+ checkAnswer(df, Seq(("a", "1"), ("b", "1")).toDF())
+ }
+
test("transformWithState - test deleteIfExists operator") {
withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName,
@@ -333,22 +344,6 @@ class TransformWithStateSuite extends StateStoreMetricsTest
class TransformWithStateValidationSuite extends StateStoreMetricsTest {
import testImplicits._
- test("transformWithState - batch should fail") {
- val ex = intercept[Exception] {
- val df = Seq("a", "a", "b").toDS()
- .groupByKey(x => x)
- .transformWithState(new RunningCountStatefulProcessor,
- TimeoutMode.NoTimeouts(),
- OutputMode.Append())
- .write
- .format("noop")
- .mode(SaveMode.Append)
- .save()
- }
- assert(ex.isInstanceOf[AnalysisException])
- assert(ex.getMessage.contains("not supported"))
- }
-
test("transformWithState - streaming with hdfsStateStoreProvider should
fail") {
val inputData = MemoryStream[String]
val result = inputData.toDS()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]