HeartSaVioR commented on code in PR #45467:
URL: https://github.com/apache/spark/pull/45467#discussion_r1540460865
##########
sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala:
##########
@@ -676,6 +676,42 @@ class KeyValueGroupedDataset[K, V] private[sql](
)
}
+ /**
+ * (Scala-specific)
+ * Invokes methods defined in the stateful processor used in arbitrary state API v2.
+ * Functions as the function above, but with additional initial state.
+ *
+ * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+ * @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
+ * @param statefulProcessor Instance of StatefulProcessorWithInitialState whose functions will
+ * be invoked by the operator.
+ * @param timeoutMode The timeout mode of the stateful processor.
+ * @param outputMode The output mode of the stateful processor. Defaults to APPEND mode.
+ * @param initialState User provided initial state that will be used to initiate state for
+ * the query in the first batch.
+ *
+ */
+ private[sql] def transformWithState[U: Encoder, S: Encoder](
+ statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
+ timeoutMode: TimeoutMode,
+ outputMode: OutputMode,
+ initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
+ Dataset[U](
+ sparkSession,
+ TransformWithState[K, V, U, S](
+ groupingAttributes,
+ dataAttributes,
+ statefulProcessor,
+ timeoutMode,
+ outputMode,
+ child = logicalPlan,
+ initialState.groupingAttributes,
+ initialState.dataAttributes,
+ initialState.queryExecution.logical
Review Comment:
Shall we follow the practice we did in flatMapGroupsWithState, for safety's sake?
`initialState.queryExecution.analyzed`
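For illustration, a minimal sketch of the suggested change, reusing the `TransformWithState` call from the diff above and swapping only the last argument:

```scala
Dataset[U](
  sparkSession,
  TransformWithState[K, V, U, S](
    groupingAttributes,
    dataAttributes,
    statefulProcessor,
    timeoutMode,
    outputMode,
    child = logicalPlan,
    initialState.groupingAttributes,
    initialState.dataAttributes,
    // Use the analyzed plan of the initial state Dataset, as flatMapGroupsWithState does,
    // instead of the raw (unanalyzed) logical plan.
    initialState.queryExecution.analyzed
  )
)
```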
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala:
##########
@@ -268,11 +268,13 @@ class IncrementalExecution(
)
case t: TransformWithStateExec =>
+ val hasInitialState = (isFirstBatch && t.hasInitialState)
Review Comment:
I don't think we want to allow adding state in the middle of the query lifecycle. Here `isFirstBatch` does not mean batch ID = 0; it means this is the first batch in this query run.
This should follow the logic above that we used for FlatMapGroupsWithStateExec, i.e. `currentBatchId == 0L`.
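To make the intent concrete, a sketch of the suggested condition (the `copy(...)` wiring around it is assumed, not quoted from the PR):

```scala
case t: TransformWithStateExec =>
  // Seed the user-provided initial state only on the very first batch of the query
  // (batch ID 0), not merely on the first batch of the current run/restart.
  val hasInitialState = (currentBatchId == 0L) && t.hasInitialState
  // ... then propagate it into the operator, e.g. t.copy(hasInitialState = hasInitialState)
```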
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala:
##########
@@ -268,11 +268,13 @@ class IncrementalExecution(
)
case t: TransformWithStateExec =>
+ val hasInitialState = (isFirstBatch && t.hasInitialState)
Review Comment:
Please let me know if this is different functionality from what we had in flatMapGroupsWithState.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -271,57 +320,111 @@ case class TransformWithStateExec(
case _ =>
}
- if (isStreaming) {
- child.execute().mapPartitionsWithStateStore[InternalRow](
+ if (hasInitialState) {
+ val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+ val hadoopConfBroadcast = sparkContext.broadcast(
Review Comment:
Yeah, there is a code comment. The practice seems to be that it's better to use a broadcast variable rather than task serialization, since the Hadoop configuration could be huge.
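For illustration, the pattern being discussed (a sketch, not the PR's code; `hadoopConfBroadcast` matches the name used in the diff below):

```scala
// Per-task serialization: the configuration is captured in each task closure and
// shipped with every task.
val hadoopConf = new SerializableConfiguration(
  session.sqlContext.sessionState.newHadoopConf())

// Broadcast: the (potentially large) configuration is shipped to each executor once
// and reused by all tasks running there.
val hadoopConfBroadcast = sparkContext.broadcast(
  new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
// Executor-side access: hadoopConfBroadcast.value.value
```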
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -271,57 +320,111 @@ case class TransformWithStateExec(
case _ =>
}
- if (isStreaming) {
- child.execute().mapPartitionsWithStateStore[InternalRow](
+ if (hasInitialState) {
+ val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+ val hadoopConfBroadcast = sparkContext.broadcast(
+ new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ child.execute().stateStoreAwareZipPartitions(
+ initialState.execute(),
getStateInfo,
- schemaForKeyRow,
- schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator),
- useColumnFamilies = true,
- useMultipleValuesPerKey = true
- ) {
- case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
- processData(store, singleIterator)
+ storeNames = Seq(),
+ session.sqlContext.streams.stateStoreCoordinator) {
+ // The state store aware zip partitions will provide us with two iterators,
+ // child data iterator and the initial state iterator per partition.
+ case (partitionId, childDataIterator, initStateIterator) =>
+ if (isStreaming) {
+ val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+ stateInfo.get.operatorId, partitionId)
+ val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
+ val store = StateStore.get(
+ storeProviderId = storeProviderId,
+ keySchema = schemaForKeyRow,
+ valueSchema = schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ version = stateInfo.get.storeVersion,
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = hadoopConfBroadcast.value.value
+ )
+
+ processDataWithInitialState(store, childDataIterator, initStateIterator)
+ } else {
+ val providerId = {
+ val tempDirPath = Utils.createTempDir().getAbsolutePath
+ new StateStoreProviderId(
+ StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId)
+ }
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+
+ // Create StateStoreProvider for this partition
+ val stateStoreProvider = StateStoreProvider.createAndInit(
+ providerId,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ useColumnFamilies = true,
+ storeConf = new StateStoreConf(sqlConf),
+ hadoopConf = hadoopConfBroadcast.value.value,
+ useMultipleValuesPerKey = true)
+ val store = stateStoreProvider.getStore(0)
+
+ processDataWithInitialState(store, childDataIterator, initStateIterator)
Review Comment:
We close the state store and the state store provider in the batch codepath (see below). Shall we do that here as well?
Also, this is a good illustration that we have duplicated code: the two batch parts share the logic for spinning up a state store provider and state store, and for closing them. That could be extracted out.
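A rough sketch of both points, assuming a hypothetical helper (`initBatchStateStore` is not in the PR) and the usual `CompletionIterator`-based teardown:

```scala
// Hypothetical helper that factors out the duplicated batch-mode setup.
private def initBatchStateStore(partitionId: Int): (StateStoreProvider, StateStore) = {
  val providerId = {
    val tempDirPath = Utils.createTempDir().getAbsolutePath
    new StateStoreProviderId(StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId)
  }
  val sqlConf = new SQLConf()
  sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
    classOf[RocksDBStateStoreProvider].getName)
  val provider = StateStoreProvider.createAndInit(
    providerId, schemaForKeyRow, schemaForValueRow,
    NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
    useColumnFamilies = true,
    storeConf = new StateStoreConf(sqlConf),
    hadoopConf = hadoopConfBroadcast.value.value,
    useMultipleValuesPerKey = true)
  (provider, provider.getStore(0))
}

// At the call site: close the provider once the partition's output has been fully
// consumed, mirroring what the existing batch codepath does.
val (stateStoreProvider, store) = initBatchStateStore(partitionId)
CompletionIterator[InternalRow, Iterator[InternalRow]](
  processDataWithInitialState(store, childDataIterator, initStateIterator),
  stateStoreProvider.close())
```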
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -271,57 +320,111 @@ case class TransformWithStateExec(
case _ =>
}
- if (isStreaming) {
- child.execute().mapPartitionsWithStateStore[InternalRow](
+ if (hasInitialState) {
+ val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+ val hadoopConfBroadcast = sparkContext.broadcast(
+ new SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ child.execute().stateStoreAwareZipPartitions(
+ initialState.execute(),
getStateInfo,
- schemaForKeyRow,
- schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator),
- useColumnFamilies = true,
- useMultipleValuesPerKey = true
- ) {
- case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
- processData(store, singleIterator)
+ storeNames = Seq(),
+ session.sqlContext.streams.stateStoreCoordinator) {
+ // The state store aware zip partitions will provide us with two iterators,
+ // child data iterator and the initial state iterator per partition.
+ case (partitionId, childDataIterator, initStateIterator) =>
+ if (isStreaming) {
+ val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+ stateInfo.get.operatorId, partitionId)
+ val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
+ val store = StateStore.get(
+ storeProviderId = storeProviderId,
+ keySchema = schemaForKeyRow,
+ valueSchema = schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ version = stateInfo.get.storeVersion,
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = hadoopConfBroadcast.value.value
+ )
+
+ processDataWithInitialState(store, childDataIterator, initStateIterator)
+ } else {
+ val providerId = {
+ val tempDirPath = Utils.createTempDir().getAbsolutePath
+ new StateStoreProviderId(
+ StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId)
+ }
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+
+ // Create StateStoreProvider for this partition
+ val stateStoreProvider = StateStoreProvider.createAndInit(
+ providerId,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ useColumnFamilies = true,
+ storeConf = new StateStoreConf(sqlConf),
+ hadoopConf = hadoopConfBroadcast.value.value,
+ useMultipleValuesPerKey = true)
+ val store = stateStoreProvider.getStore(0)
+
+ processDataWithInitialState(store, childDataIterator, initStateIterator)
+ }
}
} else {
- // If the query is running in batch mode, we need to create a new StateStore and instantiate
- // a temp directory on the executors in mapPartitionsWithIndex.
- val broadcastedHadoopConf =
+ if (isStreaming) {
+ child.execute().mapPartitionsWithStateStore[InternalRow](
+ getStateInfo,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ session.sqlContext.sessionState,
+ Some(session.sqlContext.streams.stateStoreCoordinator),
+ useColumnFamilies = true
+ ) {
+ case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
+ processData(store, singleIterator)
+ }
+ } else {
+ // If the query is running in batch mode, we need to create a new StateStore and instantiate
Review Comment:
nit: Shall we apply the same broadcast practice while we are here?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -341,8 +444,37 @@ case class TransformWithStateExec(
processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
processDataWithPartition(singleIterator, store, processorHandle)
}
+
+ private def processDataWithInitialState(
+ store: StateStore,
+ childDataIterator: Iterator[InternalRow],
+ initStateIterator: Iterator[InternalRow]):
+ CompletionIterator[InternalRow, Iterator[InternalRow]] = {
+ val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId,
+ keyEncoder, timeoutMode, isStreaming)
+ assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
+ statefulProcessor.setHandle(processorHandle)
+ statefulProcessor.init(outputMode, timeoutMode)
+ processorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+
+ // Check if is first batch
+ // Only process initial states for first batch
+ if (processorHandle.getQueryInfo().getBatchId == 0) {
Review Comment:
OK, I see we have multiple checks. It would still be better to change the condition in IncrementalExecution, though, as a reader could misread this as an inconsistency between flatMapGroupsWithState and transformWithState.