This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4d72be3abdc4 [SPARK-47363][SS] Initial State without state reader implementation for State API v2
4d72be3abdc4 is described below
commit 4d72be3abdc4c651da029bdbd24a574099d45e7c
Author: jingz-db <[email protected]>
AuthorDate: Thu Mar 28 14:50:46 2024 +0900
[SPARK-47363][SS] Initial State without state reader implementation for State API v2
### What changes were proposed in this pull request?
This PR adds support for users to provide a DataFrame that can be used to
instantiate state for the query in the first batch, for the arbitrary state API v2
(`transformWithState`).
Note that populating the initial state only happens in the first batch of the
new streaming query. Trying to re-initialize state for the same grouping key
results in an error.
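For illustration, a user-defined processor for this API could look like the
sketch below. This is not part of the patch: the class name, state variable and
counting logic are made up, while the `StatefulProcessorWithInitialState` trait
and the state handle calls follow the APIs added and exercised in this change.
```
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// Illustrative processor: seeds a per-key running total from the
// user-provided initial state, then keeps counting input rows.
class RunningCountWithInitialState
  extends StatefulProcessorWithInitialState[
    String, String, (String, Double), (String, Double)] {

  @transient private var countState: ValueState[Double] = _

  override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
    countState = getHandle.getValueState[Double]("countState", Encoders.scalaDouble)
  }

  // Invoked at most once per grouping key, and only in the first batch.
  override def handleInitialState(
      key: String, initialState: (String, Double)): Unit = {
    countState.update(initialState._2)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Double)] = {
    val updated = countState.getOption().getOrElse(0.0) + inputRows.size
    countState.update(updated)
    Iterator.single((key, updated))
  }

  override def close(): Unit = {}
}
```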
### Why are the changes needed?
These changes are needed to support user-provided initial state. They are part
of the broader work on adding a new stateful streaming operator for arbitrary
state management, which provides a number of new features listed in the SPIP
JIRA here - https://issues.apache.org/jira/browse/SPARK-45939
### Does this PR introduce _any_ user-facing change?
Yes.
This PR introduces a new function on `KeyValueGroupedDataset`:
```
def transformWithState(
    statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
    timeoutMode: TimeoutMode,
    outputMode: OutputMode,
    initialState: KeyValueGroupedDataset[K, S]): Dataset[U]
```
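As a rough usage sketch, assuming `spark` is an active SparkSession,
`streamingInput` is a streaming `Dataset[String]`, and
`RunningCountWithInitialState` is the illustrative processor sketched above;
the wiring mirrors the new `TransformWithStateInitialStateSuite`:
```
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.streaming.{OutputMode, TimeoutMode}
import spark.implicits._

def buildQuery(streamingInput: Dataset[String]): Dataset[(String, Double)] = {
  // Initial state keyed with the same grouping expression as the input,
  // so the two sides can be co-partitioned by the operator.
  val initialState = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS()
    .groupByKey(_._1)
    .mapValues(x => x)

  streamingInput
    .groupByKey(x => x)
    .transformWithState(
      new RunningCountWithInitialState(),
      TimeoutMode.NoTimeouts(),
      OutputMode.Append(),
      initialState)
}
```
Note that in this patch the new overload is still marked `private[sql]`, so the
wiring above is only exercised from Spark's own test suites for now.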
### How was this patch tested?
Unit tests in `TransformWithStateInitialStateSuite`
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45467 from jingz-db/initial-state-state-v2.
Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jing Zhan <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../src/main/resources/error/error-classes.json | 6 +
docs/sql-error-conditions.md | 6 +
.../spark/sql/streaming/StatefulProcessor.scala | 19 ++
.../spark/sql/catalyst/plans/logical/object.scala | 55 +++-
.../apache/spark/sql/KeyValueGroupedDataset.scala | 38 ++-
.../spark/sql/execution/SparkStrategies.scala | 20 +-
.../execution/streaming/IncrementalExecution.scala | 4 +-
.../streaming/TransformWithStateExec.scala | 254 ++++++++++++++----
.../streaming/state/StateStoreErrors.scala | 10 +
.../sql/streaming/TransformWithMapStateSuite.scala | 5 +-
.../TransformWithStateInitialStateSuite.scala | 293 +++++++++++++++++++++
.../sql/streaming/TransformWithStateSuite.scala | 20 ++
12 files changed, 661 insertions(+), 69 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json
b/common/utils/src/main/resources/error/error-classes.json
index 185e86853dfd..11c8204d2c93 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3553,6 +3553,12 @@
],
"sqlState" : "42802"
},
+ "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : {
+ "message" : [
+ "Cannot re-initialize state on the same grouping key during initial
state handling for stateful processor. Invalid grouping key=<groupingKey>."
+ ],
+ "sqlState" : "42802"
+ },
"STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
"message" : [
"Failed to create column family with unsupported starting character and
name=<colFamilyName>."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 838ca2fa33c9..85b9e85ac420 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -2162,6 +2162,12 @@ Failed to perform stateful processor
operation=`<operationType>` with invalid ha
Failed to perform stateful processor operation=`<operationType>` with invalid
timeoutMode=`<timeoutMode>`
+### STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY
+
+[SQLSTATE:
42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Cannot re-initialize state on the same grouping key during initial state
handling for stateful processor. Invalid grouping key=`<groupingKey>`.
+
### STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS
[SQLSTATE:
42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index ad9b807ddf5a..1a61972f0ed0 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -91,3 +91,22 @@ private[sql] trait StatefulProcessor[K, I, O] extends
Serializable {
statefulProcessorHandle
}
}
+
+/**
+ * Stateful processor with support for specifying initial state.
+ * Accepts a user-defined type as initial state to be initialized in the first
batch.
+ * This can be used for starting a new streaming query with existing state
from a
+ * previous streaming query.
+ */
+@Experimental
+@Evolving
+trait StatefulProcessorWithInitialState[K, I, O, S] extends
StatefulProcessor[K, I, O] {
+
+ /**
+ * Function that will be invoked only in the first batch for users to
process initial states.
+ *
+ * @param key - grouping key
+ * @param initialState - A row in the initial state to be processed
+ */
+ def handleInitialState(key: K, initialState: S): Unit
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index cb8673d20ed3..b2c443a8cce0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -588,7 +588,46 @@ object TransformWithState {
outputMode,
keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
CatalystSerde.generateObjAttr[U],
- child
+ child,
+ hasInitialState = false,
+ // the following parameters will not be used in physical plan if
hasInitialState = false
+ initialStateGroupingAttrs = groupingAttributes,
+ initialStateDataAttrs = dataAttributes,
+ initialStateDeserializer =
+ UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+ initialState = LocalRelation(encoderFor[K].schema) // empty data set
+ )
+ CatalystSerde.serialize[U](mapped)
+ }
+
+ // This apply() is to invoke TransformWithState object with hasInitialState
set to true
+ def apply[K: Encoder, V: Encoder, U: Encoder, S: Encoder](
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ statefulProcessor: StatefulProcessor[K, V, U],
+ timeoutMode: TimeoutMode,
+ outputMode: OutputMode,
+ child: LogicalPlan,
+ initialStateGroupingAttrs: Seq[Attribute],
+ initialStateDataAttrs: Seq[Attribute],
+ initialState: LogicalPlan): LogicalPlan = {
+ val keyEncoder = encoderFor[K]
+ val mapped = new TransformWithState(
+ UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+ UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
+ groupingAttributes,
+ dataAttributes,
+ statefulProcessor.asInstanceOf[StatefulProcessor[Any, Any, Any]],
+ timeoutMode,
+ outputMode,
+ keyEncoder.asInstanceOf[ExpressionEncoder[Any]],
+ CatalystSerde.generateObjAttr[U],
+ child,
+ hasInitialState = true,
+ initialStateGroupingAttrs,
+ initialStateDataAttrs,
+ UnresolvedDeserializer(encoderFor[S].deserializer,
initialStateDataAttrs),
+ initialState
)
CatalystSerde.serialize[U](mapped)
}
@@ -604,10 +643,18 @@ case class TransformWithState(
outputMode: OutputMode,
keyEncoder: ExpressionEncoder[Any],
outputObjAttr: Attribute,
- child: LogicalPlan) extends UnaryNode with ObjectProducer {
+ child: LogicalPlan,
+ hasInitialState: Boolean = false,
+ initialStateGroupingAttrs: Seq[Attribute],
+ initialStateDataAttrs: Seq[Attribute],
+ initialStateDeserializer: Expression,
+ initialState: LogicalPlan) extends BinaryNode with ObjectProducer {
- override protected def withNewChildInternal(newChild: LogicalPlan):
TransformWithState =
- copy(child = newChild)
+ override def left: LogicalPlan = child
+ override def right: LogicalPlan = initialState
+ override protected def withNewChildrenInternal(
+ newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState =
+ copy(child = newLeft, initialState = newRight)
}
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 50ab2a41612b..95ad973aee51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.expressions.ReduceAggregator
import org.apache.spark.sql.internal.TypedAggUtils
-import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout,
OutputMode, StatefulProcessor, TimeoutMode}
+import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout,
OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}
/**
* A [[Dataset]] has been logically grouped by a user specified grouping key.
Users should not
@@ -676,6 +676,42 @@ class KeyValueGroupedDataset[K, V] private[sql](
)
}
+ /**
+ * (Scala-specific)
+ * Invokes methods defined in the stateful processor used in arbitrary state
API v2.
+ * Functions as the function above, but with additional initial state.
+ *
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL
types.
+ * @tparam S The type of initial state objects. Must be encodable to Spark
SQL types.
+ * @param statefulProcessor Instance of statefulProcessor whose functions
will
+ * be invoked by the operator.
+ * @param timeoutMode The timeout mode of the stateful processor.
+ * @param outputMode The output mode of the stateful processor.
Defaults to APPEND mode.
+ * @param initialState User provided initial state that will be used to
initiate state for
+ * the query in the first batch.
+ *
+ */
+ private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ timeoutMode: TimeoutMode,
+ outputMode: OutputMode,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+ Dataset[U](
+ sparkSession,
+ TransformWithState[K, V, U, S](
+ groupingAttributes,
+ dataAttributes,
+ statefulProcessor,
+ timeoutMode,
+ outputMode,
+ child = logicalPlan,
+ initialState.groupingAttributes,
+ initialState.dataAttributes,
+ initialState.queryExecution.analyzed
+ )
+ )
+ }
+
/**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary
function.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f77d0fef4eb9..cc212d99f299 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -752,7 +752,9 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
case TransformWithState(
keyDeserializer, valueDeserializer, groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode,
- keyEncoder, outputAttr, child) =>
+ keyEncoder, outputAttr, child, hasInitialState,
+ initialStateGroupingAttrs, initialStateDataAttrs,
+ initialStateDeserializer, initialState) =>
val execPlan = TransformWithStateExec(
keyDeserializer,
valueDeserializer,
@@ -767,7 +769,13 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
batchTimestampMs = None,
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
- planLater(child))
+ planLater(child),
+ isStreaming = true,
+ hasInitialState,
+ initialStateGroupingAttrs,
+ initialStateDataAttrs,
+ initialStateDeserializer,
+ planLater(initialState))
execPlan :: Nil
case _ =>
Nil
@@ -918,10 +926,14 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
) :: Nil
case logical.TransformWithState(keyDeserializer, valueDeserializer,
groupingAttributes,
dataAttributes, statefulProcessor, timeoutMode, outputMode,
keyEncoder,
- outputObjAttr, child) =>
+ outputObjAttr, child, hasInitialState,
+ initialStateGroupingAttrs, initialStateDataAttrs,
+ initialStateDeserializer, initialState) =>
TransformWithStateExec.generateSparkPlanForBatchQueries(keyDeserializer,
valueDeserializer,
groupingAttributes, dataAttributes, statefulProcessor, timeoutMode,
outputMode,
- keyEncoder, outputObjAttr, planLater(child)) :: Nil
+ keyEncoder, outputObjAttr, planLater(child), hasInitialState,
+ initialStateGroupingAttrs, initialStateDataAttrs,
+ initialStateDeserializer, planLater(initialState)) :: Nil
case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 14007eb4b101..cfccfff3a138 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -268,11 +268,13 @@ class IncrementalExecution(
)
case t: TransformWithStateExec =>
+ val hasInitialState = (currentBatchId == 0L && t.hasInitialState)
t.copy(
stateInfo = Some(nextStatefulOperationStateInfo()),
batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
eventTimeWatermarkForLateEvents = None,
- eventTimeWatermarkForEviction = None
+ eventTimeWatermarkForEviction = None,
+ hasInitialState = hasInitialState
)
case m: FlatMapGroupsInPandasWithStateExec =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index d3640ebd8850..36b957f9d430 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
import java.util.UUID
import java.util.concurrent.TimeUnit.NANOSECONDS
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -26,9 +27,10 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending,
Attribute, Expressi
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.Distribution
import org.apache.spark.sql.execution._
+import
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor,
TimeoutMode}
+import org.apache.spark.sql.streaming.{OutputMode, StatefulProcessor,
StatefulProcessorWithInitialState, TimeoutMode}
import org.apache.spark.sql.types._
import org.apache.spark.util.{CompletionIterator, SerializableConfiguration,
Utils}
@@ -65,8 +67,13 @@ case class TransformWithStateExec(
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
child: SparkPlan,
- isStreaming: Boolean = true)
- extends UnaryExecNode with StateStoreWriter with WatermarkSupport with
ObjectProducerExec {
+ isStreaming: Boolean = true,
+ hasInitialState: Boolean = false,
+ initialStateGroupingAttrs: Seq[Attribute],
+ initialStateDataAttrs: Seq[Attribute],
+ initialStateDeserializer: Expression,
+ initialState: SparkPlan)
+ extends BinaryExecNode with StateStoreWriter with WatermarkSupport with
ObjectProducerExec {
override def shortName: String = "transformWithStateExec"
@@ -85,8 +92,13 @@ case class TransformWithStateExec(
}
}
- override protected def withNewChildInternal(
- newChild: SparkPlan): TransformWithStateExec = copy(child = newChild)
+ override def left: SparkPlan = child
+
+ override def right: SparkPlan = initialState
+
+ override protected def withNewChildrenInternal(
+ newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec =
+ copy(child = newLeft, initialState = newRight)
override def keyExpressions: Seq[Attribute] = groupingAttributes
@@ -94,14 +106,25 @@ case class TransformWithStateExec(
protected val schemaForValueRow: StructType = new StructType().add("value",
BinaryType)
+ /**
+ * Distribute by grouping attributes - We need the underlying data and the
initial state data
+ * to have the same grouping so that the data are co-located on the same
task.
+ */
override def requiredChildDistribution: Seq[Distribution] = {
- StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes,
- getStateInfo, conf) ::
- Nil
+ StatefulOperatorPartitioning.getCompatibleDistribution(
+ groupingAttributes, getStateInfo, conf) ::
+ StatefulOperatorPartitioning.getCompatibleDistribution(
+ initialStateGroupingAttrs, getStateInfo, conf) ::
+ Nil
}
+ /**
+ * We need the initial state to also use the ordering as the data so that we
can co-locate the
+ * keys from the underlying data and the initial state.
+ */
override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
- groupingAttributes.map(SortOrder(_, Ascending)))
+ groupingAttributes.map(SortOrder(_, Ascending)),
+ initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
private def handleInputRows(keyRow: UnsafeRow, valueRowIter:
Iterator[InternalRow]):
Iterator[InternalRow] = {
@@ -127,6 +150,33 @@ case class TransformWithStateExec(
mappedIterator
}
+ private def processInitialStateRows(
+ keyRow: UnsafeRow,
+ initStateIter: Iterator[InternalRow]): Unit = {
+ val getKeyObj =
+ ObjectOperator.deserializeRowToObject(keyDeserializer,
groupingAttributes)
+
+ val getInitStateValueObj =
+ ObjectOperator.deserializeRowToObject(initialStateDeserializer,
initialStateDataAttrs)
+
+ val keyObj = getKeyObj(keyRow) // convert key to objects
+ ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
+ val initStateObjIter = initStateIter.map(getInitStateValueObj.apply)
+
+ var seenInitStateOnKey = false
+ initStateObjIter.foreach { initState =>
+ // cannot re-initialize state on the same grouping key during initial
state handling
+ if (seenInitStateOnKey) {
+ throw StateStoreErrors.cannotReInitializeStateOnKey(keyObj.toString)
+ }
+ seenInitStateOnKey = true
+ statefulProcessor
+ .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
+ .handleInitialState(keyObj, initState)
+ }
+ ImplicitGroupingKeyTracker.removeImplicitKey()
+ }
+
private def processNewData(dataIter: Iterator[InternalRow]):
Iterator[InternalRow] = {
val groupedIter = GroupedIterator(dataIter, groupingAttributes,
child.output)
groupedIter.flatMap { case (keyRow, valueRowIter) =>
@@ -263,58 +313,108 @@ case class TransformWithStateExec(
case _ =>
}
- if (isStreaming) {
- child.execute().mapPartitionsWithStateStore[InternalRow](
+ if (hasInitialState) {
+ val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+ val hadoopConfBroadcast = sparkContext.broadcast(
+ new
SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ child.execute().stateStoreAwareZipPartitions(
+ initialState.execute(),
getStateInfo,
- schemaForKeyRow,
- schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator),
- useColumnFamilies = true,
- useMultipleValuesPerKey = true
- ) {
- case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
- processData(store, singleIterator)
+ storeNames = Seq(),
+ session.sqlContext.streams.stateStoreCoordinator) {
+ // The state store aware zip partitions will provide us with two
iterators,
+ // child data iterator and the initial state iterator per partition.
+ case (partitionId, childDataIterator, initStateIterator) =>
+ if (isStreaming) {
+ val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+ stateInfo.get.operatorId, partitionId)
+ val storeProviderId = StateStoreProviderId(stateStoreId,
stateInfo.get.queryRunId)
+ val store = StateStore.get(
+ storeProviderId = storeProviderId,
+ keySchema = schemaForKeyRow,
+ valueSchema = schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ version = stateInfo.get.storeVersion,
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = hadoopConfBroadcast.value.value
+ )
+
+ processDataWithInitialState(store, childDataIterator,
initStateIterator)
+ } else {
+ initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast)
{ store =>
+ processDataWithInitialState(store, childDataIterator,
initStateIterator)
+ }
+ }
}
} else {
- // If the query is running in batch mode, we need to create a new
StateStore and instantiate
- // a temp directory on the executors in mapPartitionsWithIndex.
- val broadcastedHadoopConf =
- new SerializableConfiguration(session.sessionState.newHadoopConf())
- child.execute().mapPartitionsWithIndex[InternalRow](
- (i, iter) => {
- val providerId = {
- val tempDirPath = Utils.createTempDir().getAbsolutePath
- new StateStoreProviderId(
- StateStoreId(tempDirPath, 0, i), getStateInfo.queryRunId)
+ if (isStreaming) {
+ child.execute().mapPartitionsWithStateStore[InternalRow](
+ getStateInfo,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ session.sqlContext.sessionState,
+ Some(session.sqlContext.streams.stateStoreCoordinator),
+ useColumnFamilies = true
+ ) {
+ case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
+ processData(store, singleIterator)
+ }
+ } else {
+ // If the query is running in batch mode, we need to create a new
StateStore and instantiate
+ // a temp directory on the executors in mapPartitionsWithIndex.
+ val hadoopConfBroadcast = sparkContext.broadcast(
+ new
SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ child.execute().mapPartitionsWithIndex[InternalRow](
+ (i: Int, iter: Iterator[InternalRow]) => {
+ initNewStateStoreAndProcessData(i, hadoopConfBroadcast) { store =>
+ processData(store, iter)
+ }
}
+ )
+ }
+ }
+ }
- val sqlConf = new SQLConf()
- sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
- classOf[RocksDBStateStoreProvider].getName)
- val storeConf = new StateStoreConf(sqlConf)
-
- // Create StateStoreProvider for this partition
- val stateStoreProvider = StateStoreProvider.createAndInit(
- providerId,
- schemaForKeyRow,
- schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
- useColumnFamilies = true,
- storeConf = storeConf,
- hadoopConf = broadcastedHadoopConf.value,
- useMultipleValuesPerKey = true)
-
- val store = stateStoreProvider.getStore(0)
- val outputIterator = processData(store, iter)
- CompletionIterator[InternalRow,
Iterator[InternalRow]](outputIterator.iterator, {
- stateStoreProvider.close()
- statefulProcessor.close()
- })
- }
- )
+ /**
+ * Create a new StateStore for given partitionId and instantiate a temp
directory
+ * on the executors. Process data and close the stateStore provider
afterwards.
+ */
+ private def initNewStateStoreAndProcessData(
+ partitionId: Int,
+ hadoopConfBroadcast: Broadcast[SerializableConfiguration])
+ (f: StateStore => CompletionIterator[InternalRow, Iterator[InternalRow]]):
+ CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+
+ val providerId = {
+ val tempDirPath = Utils.createTempDir().getAbsolutePath
+ new StateStoreProviderId(
+ StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId)
}
+
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+ val storeConf = new StateStoreConf(sqlConf)
+
+ // Create StateStoreProvider for this partition
+ val stateStoreProvider = StateStoreProvider.createAndInit(
+ providerId,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = hadoopConfBroadcast.value.value,
+ useMultipleValuesPerKey = true)
+
+ val store = stateStoreProvider.getStore(0)
+ val outputIterator = f(store)
+ CompletionIterator[InternalRow,
Iterator[InternalRow]](outputIterator.iterator, {
+ stateStoreProvider.close()
+ statefulProcessor.close()
+ })
}
/**
@@ -333,8 +433,37 @@ case class TransformWithStateExec(
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
processDataWithPartition(singleIterator, store, processorHandle)
}
+
+ private def processDataWithInitialState(
+ store: StateStore,
+ childDataIterator: Iterator[InternalRow],
+ initStateIterator: Iterator[InternalRow]):
+ CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+ val processorHandle = new StatefulProcessorHandleImpl(store,
getStateInfo.queryRunId,
+ keyEncoder, timeoutMode, isStreaming)
+ assert(processorHandle.getHandleState ==
StatefulProcessorHandleState.CREATED)
+ statefulProcessor.setHandle(processorHandle)
+ statefulProcessor.init(outputMode, timeoutMode)
+ processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+
+ // Check if is first batch
+ // Only process initial states for first batch
+ if (processorHandle.getQueryInfo().getBatchId == 0) {
+ // If the user provided initial state, we need to have the initial state
and the
+ // data in the same partition so that we can still have just one commit
at the end.
+ val groupedInitialStateIter = GroupedIterator(initStateIterator,
+ initialStateGroupingAttrs, initialState.output)
+ groupedInitialStateIter.foreach {
+ case (keyRow, valueRowIter) =>
+ processInitialStateRows(keyRow.asInstanceOf[UnsafeRow], valueRowIter)
+ }
+ }
+
+ processDataWithPartition(childDataIterator, store, processorHandle)
+ }
}
+// scalastyle:off
object TransformWithStateExec {
// Plan logical transformWithState for batch queries
@@ -348,7 +477,12 @@ object TransformWithStateExec {
outputMode: OutputMode,
keyEncoder: ExpressionEncoder[Any],
outputObjAttr: Attribute,
- child: SparkPlan): SparkPlan = {
+ child: SparkPlan,
+ hasInitialState: Boolean = false,
+ initialStateGroupingAttrs: Seq[Attribute],
+ initialStateDataAttrs: Seq[Attribute],
+ initialStateDeserializer: Expression,
+ initialState: SparkPlan): SparkPlan = {
val shufflePartitions =
child.session.sessionState.conf.numShufflePartitions
val statefulOperatorStateInfo = StatefulOperatorStateInfo(
checkpointLocation = "", // empty checkpointLocation will be populated
in doExecute
@@ -373,6 +507,12 @@ object TransformWithStateExec {
None,
None,
child,
- isStreaming = false)
+ isStreaming = false,
+ hasInitialState,
+ initialStateGroupingAttrs,
+ initialStateDataAttrs,
+ initialStateDeserializer,
+ initialState)
}
}
+// scalastyle:on
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index a8d4c06bc83c..2f72cbb0b0fc 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -112,6 +112,11 @@ object StateStoreErrors {
handleState: String):
StatefulProcessorCannotPerformOperationWithInvalidHandleState = {
new
StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType,
handleState)
}
+
+ def cannotReInitializeStateOnKey(groupingKey: String):
+ StatefulProcessorCannotReInitializeState = {
+ new StatefulProcessorCannotReInitializeState(groupingKey)
+ }
}
class
StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider:
String)
@@ -157,6 +162,11 @@ class
StatefulProcessorCannotPerformOperationWithInvalidHandleState(
messageParameters = Map("operationType" -> operationType, "handleState" ->
handleState)
)
+class StatefulProcessorCannotReInitializeState(groupingKey: String)
+ extends SparkUnsupportedOperationException(
+ errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
+ messageParameters = Map("groupingKey" -> groupingKey))
+
class StateStoreUnsupportedOperationOnMissingColumnFamily(
operationType: String,
colFamilyName: String) extends SparkUnsupportedOperationException(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index d7c5ce3815b0..db8cb8b810af 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.streaming
import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
+import
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
RocksDBStateStoreProvider}
import org.apache.spark.sql.internal.SQLConf
case class InputMapRow(key: String, action: String, value: (String, String))
@@ -82,7 +82,8 @@ class TestMapStateProcessor
* Class that adds integration tests for MapState types used in arbitrary
stateful
* operators such as transformWithState.
*/
-class TransformWithMapStateSuite extends StreamTest {
+class TransformWithMapStateSuite extends StreamTest
+ with AlsoTestWithChangelogCheckpointingEnabled {
import testImplicits._
private def testMapStateWithNullUserKey(inputMapRow: InputMapRow): Unit = {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
new file mode 100644
index 000000000000..9f2e2c2d9f02
--- /dev/null
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.{Encoders, KeyValueGroupedDataset}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
RocksDBStateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+
+case class InitInputRow(key: String, action: String, value: Double)
+case class InputRowForInitialState(
+ key: String, value: Double, entries: List[Double], mapping: Map[Double,
Int])
+
+abstract class StatefulProcessorWithInitialStateTestClass[V]
+ extends StatefulProcessorWithInitialState[
+ String, InitInputRow, (String, String, Double), V] {
+ @transient var _valState: ValueState[Double] = _
+ @transient var _listState: ListState[Double] = _
+ @transient var _mapState: MapState[Double, Int] = _
+
+ override def init(outputMode: OutputMode, timeoutMode: TimeoutMode): Unit = {
+ _valState = getHandle.getValueState[Double]("testValueInit",
Encoders.scalaDouble)
+ _listState = getHandle.getListState[Double]("testListInit",
Encoders.scalaDouble)
+ _mapState = getHandle.getMapState[Double, Int](
+ "testMapInit", Encoders.scalaDouble, Encoders.scalaInt)
+ }
+
+ override def close(): Unit = {}
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[InitInputRow],
+ timerValues: TimerValues,
+ expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)]
= {
+ var output = List[(String, String, Double)]()
+ for (row <- inputRows) {
+ if (row.action == "getOption") {
+ output = (key, row.action, _valState.getOption().getOrElse(-1.0)) ::
output
+ } else if (row.action == "update") {
+ _valState.update(row.value)
+ } else if (row.action == "remove") {
+ _valState.clear()
+ } else if (row.action == "getList") {
+ _listState.get().foreach { element =>
+ output = (key, row.action, element) :: output
+ }
+ } else if (row.action == "appendList") {
+ _listState.appendValue(row.value)
+ } else if (row.action == "clearList") {
+ _listState.clear()
+ } else if (row.action == "getCount") {
+ val count =
+ if (!_mapState.containsKey(row.value)) 0
+ else _mapState.getValue(row.value)
+ output = (key, row.action, count.toDouble) :: output
+ } else if (row.action == "incCount") {
+ val count =
+ if (!_mapState.containsKey(row.value)) 0
+ else _mapState.getValue(row.value)
+ _mapState.updateValue(row.value, count + 1)
+ } else if (row.action == "clearCount") {
+ _mapState.removeKey(row.value)
+ }
+ }
+ output.iterator
+ }
+}
+
+class AccumulateStatefulProcessorWithInitState
+ extends StatefulProcessorWithInitialStateTestClass[(String, Double)] {
+ override def handleInitialState(
+ key: String,
+ initialState: (String, Double)): Unit = {
+ _valState.update(initialState._2)
+ }
+
+ override def handleInputRows(
+ key: String,
+ inputRows: Iterator[InitInputRow],
+ timerValues: TimerValues,
+ expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String, Double)]
= {
+ var output = List[(String, String, Double)]()
+ for (row <- inputRows) {
+ if (row.action == "getOption") {
+ output = (key, row.action, _valState.getOption().getOrElse(0.0)) ::
output
+ } else if (row.action == "add") {
+ // Update state variable as accumulative sum
+ val accumulateSum = _valState.getOption().getOrElse(0.0) + row.value
+ _valState.update(accumulateSum)
+ } else if (row.action == "remove") {
+ _valState.clear()
+ }
+ }
+ output.iterator
+ }
+}
+
+class InitialStateInMemoryTestClass
+ extends StatefulProcessorWithInitialStateTestClass[InputRowForInitialState] {
+ override def handleInitialState(
+ key: String,
+ initialState: InputRowForInitialState): Unit = {
+ _valState.update(initialState.value)
+ _listState.appendList(initialState.entries.toArray)
+ val inMemoryMap = initialState.mapping
+ inMemoryMap.foreach { kvPair =>
+ _mapState.updateValue(kvPair._1, kvPair._2)
+ }
+ }
+}
+
+/**
+ * Class that adds tests for transformWithState stateful
+ * streaming operator with user-defined initial state
+ */
+class TransformWithStateInitialStateSuite extends StateStoreMetricsTest
+ with AlsoTestWithChangelogCheckpointingEnabled {
+
+ import testImplicits._
+
+ private def createInitialDfForTest: KeyValueGroupedDataset[String, (String,
Double)] = {
+ Seq(("init_1", 40.0), ("init_2", 100.0)).toDS()
+ .groupByKey(x => x._1)
+ .mapValues(x => x)
+ }
+
+
+ test("transformWithStateWithInitialState - correctness test, " +
+ "run with multiple state variables - in-memory type") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+
+ val inputData = MemoryStream[InitInputRow]
+ val kvDataSet = inputData.toDS()
+ .groupByKey(x => x.key)
+ val initStateDf =
+ Seq(InputRowForInitialState("init_1", 40.0, List(40.0), Map(40.0 ->
1)),
+ InputRowForInitialState("init_2", 100.0, List(100.0), Map(100.0 ->
1)))
+ .toDS().groupByKey(x => x.key).mapValues(x => x)
+ val query = kvDataSet.transformWithState(new
InitialStateInMemoryTestClass(),
+ TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf)
+
+ testStream(query, OutputMode.Update())(
+ // non-exist key test
+ AddData(inputData, InitInputRow("k1", "update", 37.0)),
+ AddData(inputData, InitInputRow("k2", "update", 40.0)),
+ AddData(inputData, InitInputRow("non-exist", "getOption", -1.0)),
+ CheckNewAnswer(("non-exist", "getOption", -1.0)),
+ AddData(inputData, InitInputRow("k1", "appendList", 37.0)),
+ AddData(inputData, InitInputRow("k2", "appendList", 40.0)),
+ AddData(inputData, InitInputRow("non-exist", "getList", -1.0)),
+ CheckNewAnswer(),
+
+ AddData(inputData, InitInputRow("k1", "incCount", 37.0)),
+ AddData(inputData, InitInputRow("k2", "incCount", 40.0)),
+ AddData(inputData, InitInputRow("non-exist", "getCount", -1.0)),
+ CheckNewAnswer(("non-exist", "getCount", 0.0)),
+ AddData(inputData, InitInputRow("k2", "incCount", 40.0)),
+ AddData(inputData, InitInputRow("k2", "getCount", 40.0)),
+ CheckNewAnswer(("k2", "getCount", 2.0)),
+
+ // test every row in initial State is processed
+ AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
+ CheckNewAnswer(("init_1", "getOption", 40.0)),
+ AddData(inputData, InitInputRow("init_2", "getOption", -1.0)),
+ CheckNewAnswer(("init_2", "getOption", 100.0)),
+
+ AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
+ CheckNewAnswer(("init_1", "getList", 40.0)),
+ AddData(inputData, InitInputRow("init_2", "getList", -1.0)),
+ CheckNewAnswer(("init_2", "getList", 100.0)),
+
+ AddData(inputData, InitInputRow("init_1", "getCount", 40.0)),
+ CheckNewAnswer(("init_1", "getCount", 1.0)),
+ AddData(inputData, InitInputRow("init_2", "getCount", 100.0)),
+ CheckNewAnswer(("init_2", "getCount", 1.0)),
+
+ // Update row with key in initial row will work
+ AddData(inputData, InitInputRow("init_1", "update", 50.0)),
+ AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
+ CheckNewAnswer(("init_1", "getOption", 50.0)),
+ AddData(inputData, InitInputRow("init_1", "remove", -1.0)),
+ AddData(inputData, InitInputRow("init_1", "getOption", -1.0)),
+ CheckNewAnswer(("init_1", "getOption", -1.0)),
+
+ AddData(inputData, InitInputRow("init_1", "appendList", 50.0)),
+ AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
+ CheckNewAnswer(("init_1", "getList", 50.0), ("init_1", "getList",
40.0)),
+
+ AddData(inputData, InitInputRow("init_1", "incCount", 40.0)),
+ AddData(inputData, InitInputRow("init_1", "getCount", 40.0)),
+ CheckNewAnswer(("init_1", "getCount", 2.0)),
+
+ // test remove
+ AddData(inputData, InitInputRow("k1", "remove", -1.0)),
+ AddData(inputData, InitInputRow("k1", "getOption", -1.0)),
+ CheckNewAnswer(("k1", "getOption", -1.0)),
+
+ AddData(inputData, InitInputRow("init_1", "clearCount", -1.0)),
+ AddData(inputData, InitInputRow("init_1", "getCount", -1.0)),
+ CheckNewAnswer(("init_1", "getCount", 0.0)),
+
+ AddData(inputData, InitInputRow("init_1", "clearList", -1.0)),
+ AddData(inputData, InitInputRow("init_1", "getList", -1.0)),
+ CheckNewAnswer()
+ )
+ }
+ }
+
+ test("transformWithStateWithInitialState -" +
+ " correctness test, processInitialState should only run once") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+ val initStateDf = createInitialDfForTest
+ val inputData = MemoryStream[InitInputRow]
+ val query = inputData.toDS()
+ .groupByKey(x => x.key)
+ .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+ TimeoutMode.NoTimeouts(), OutputMode.Append(), initStateDf
+ )
+ testStream(query, OutputMode.Update())(
+ AddData(inputData, InitInputRow("init_1", "add", 50.0)),
+ AddData(inputData, InitInputRow("init_2", "add", 60.0)),
+ AddData(inputData, InitInputRow("init_1", "add", 50.0)),
+ // If processInitialState was processed multiple times,
+ // following checks will fail
+ AddData(inputData,
+ InitInputRow("init_1", "getOption", -1.0), InitInputRow("init_2",
"getOption", -1.0)),
+ CheckNewAnswer(("init_2", "getOption", 160.0), ("init_1", "getOption",
140.0))
+ )
+ }
+ }
+
+ test("transformWithStateWithInitialState - batch should succeed") {
+ val inputData = Seq(InitInputRow("k1", "add", 37.0), InitInputRow("k1",
"getOption", -1.0))
+ val result = inputData.toDS()
+ .groupByKey(x => x.key)
+ .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+ TimeoutMode.NoTimeouts(),
+ OutputMode.Append(),
+ createInitialDfForTest)
+
+ val df = result.toDF()
+ checkAnswer(df, Seq(("k1", "getOption", 37.0)).toDF())
+ }
+
+ test("transformWithStateWithInitialState - " +
+ "cannot re-initialize state during initial state handling") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+ val initDf = Seq(("init_1", 40.0), ("init_2", 100.0), ("init_1",
50.0)).toDS()
+ .groupByKey(x => x._1).mapValues(x => x)
+ val inputData = MemoryStream[InitInputRow]
+ val query = inputData.toDS()
+ .groupByKey(x => x.key)
+ .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+ TimeoutMode.NoTimeouts(),
+ OutputMode.Append(),
+ initDf)
+
+ testStream(query, OutputMode.Update())(
+ AddData(inputData, InitInputRow("k1", "add", 50.0)),
+ Execute { q =>
+ val e = intercept[Exception] {
+ q.processAllAvailable()
+ }
+ checkError(
+ exception =
e.getCause.asInstanceOf[SparkUnsupportedOperationException],
+ errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
+ sqlState = Some("42802"),
+ parameters = Map("groupingKey" -> "init_1")
+ )
+ }
+ )
+ }
+ }
+}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 24b0d59c45c5..24e68e3db9d8 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -769,4 +769,24 @@ class TransformWithStateValidationSuite extends
StateStoreMetricsTest {
}
)
}
+
+ test("transformWithStateWithInitialState - streaming with
hdfsStateStoreProvider should fail") {
+ val inputData = MemoryStream[InitInputRow]
+ val initDf = Seq(("init_1", 40.0), ("init_2", 100.0)).toDS()
+ .groupByKey(x => x._1)
+ .mapValues(x => x)
+ val result = inputData.toDS()
+ .groupByKey(x => x.key)
+ .transformWithState(new AccumulateStatefulProcessorWithInitState(),
+ TimeoutMode.NoTimeouts(), OutputMode.Append(), initDf
+ )
+ testStream(result, OutputMode.Update())(
+ AddData(inputData, InitInputRow("a", "add", -1.0)),
+ ExpectFailure[StateStoreMultipleColumnFamiliesNotSupportedException] {
+ (t: Throwable) => {
+ assert(t.getMessage.contains("not supported"))
+ }
+ }
+ )
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]