jingz-db commented on code in PR #45467:
URL: https://github.com/apache/spark/pull/45467#discussion_r1542128879
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala:
##########
@@ -271,57 +320,111 @@ case class TransformWithStateExec(
case _ =>
}
- if (isStreaming) {
- child.execute().mapPartitionsWithStateStore[InternalRow](
+ if (hasInitialState) {
+ val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
+ val hadoopConfBroadcast = sparkContext.broadcast(
+ new
SerializableConfiguration(session.sqlContext.sessionState.newHadoopConf()))
+ child.execute().stateStoreAwareZipPartitions(
+ initialState.execute(),
getStateInfo,
- schemaForKeyRow,
- schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator),
- useColumnFamilies = true,
- useMultipleValuesPerKey = true
- ) {
- case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
- processData(store, singleIterator)
+ storeNames = Seq(),
+ session.sqlContext.streams.stateStoreCoordinator) {
+ // The state store aware zip partitions will provide us with two iterators,
+ // child data iterator and the initial state iterator per partition.
+ case (partitionId, childDataIterator, initStateIterator) =>
+ if (isStreaming) {
+ val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+ stateInfo.get.operatorId, partitionId)
+ val storeProviderId = StateStoreProviderId(stateStoreId,
stateInfo.get.queryRunId)
+ val store = StateStore.get(
+ storeProviderId = storeProviderId,
+ keySchema = schemaForKeyRow,
+ valueSchema = schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ version = stateInfo.get.storeVersion,
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = hadoopConfBroadcast.value.value
+ )
+
+ processDataWithInitialState(store, childDataIterator,
initStateIterator)
+ } else {
+ val providerId = {
+ val tempDirPath = Utils.createTempDir().getAbsolutePath
+ new StateStoreProviderId(
+ StateStoreId(tempDirPath, 0, partitionId),
getStateInfo.queryRunId)
+ }
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+
+ // Create StateStoreProvider for this partition
+ val stateStoreProvider = StateStoreProvider.createAndInit(
+ providerId,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ useColumnFamilies = true,
+ storeConf = new StateStoreConf(sqlConf),
+ hadoopConf = hadoopConfBroadcast.value.value,
+ useMultipleValuesPerKey = true)
+ val store = stateStoreProvider.getStore(0)
+
+ processDataWithInitialState(store, childDataIterator,
initStateIterator)
Review Comment:
Good advice! Refactored the duplicated code into
`initNewStateStoreAndProcessData()`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]