This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5432cef483c4 [SPARK-50152][SS] Support handleInitialState with state 
data source reader
5432cef483c4 is described below

commit 5432cef483c499e3548a30d608550bac9fce53ec
Author: jingz-db <[email protected]>
AuthorDate: Tue Nov 12 10:43:48 2024 +0900

    [SPARK-50152][SS] Support handleInitialState with state data source reader
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for users to provide a DataFrame that is used to 
instantiate state for the query in the first batch for arbitrary state API v2. 
More specifically, this DataFrame can come from the state data source reader.
    
    Remove the restriction that the initialState dataframe can only contain one 
value row per grouping key. This enables integration with the state data 
source reader: in the flattened state data source reader output for composite 
types, multiple value rows map to the same grouping key.
    
    For example, we can read each state variable into its own dataframe with 
the state data source reader, union those dataframes together, and use the 
resulting dataframe as the initial state for a transformWithState operator, like this:
    ```
    +-----------+-----+---------+----------+------------+
    |groupingKey|value|listValue|userMapKey|userMapValue|
    +-----------+-----+---------+----------+------------+
    |a          |3    |NULL     |NULL      |NULL        |
    |b          |2    |NULL     |NULL      |NULL        |
    |a          |NULL |1        |NULL      |NULL        |
    |a          |NULL |2        |NULL      |NULL        |
    |a          |NULL |3        |NULL      |NULL        |
    |b          |NULL |1        |NULL      |NULL        |
    |b          |NULL |2        |NULL      |NULL        |
    |a          |NULL |NULL     |a         |3           |
    |b          |NULL |NULL     |b         |2           |
    +-----------+-----+---------+----------+------------+
    ```
    ### Why are the changes needed?
    
    This change is for supporting initial state handling for integration with 
state data source reader.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. The user API is the same as in the prior PR, 
https://github.com/apache/spark/pull/45467, which added initial state support 
without the state data source reader.
    
    ### How was this patch tested?
    
    Unit test cases added in `TransformWithStateInitialStateSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #48686 from jingz-db/initial-state-reader-integration.
    
    Lead-authored-by: jingz-db <[email protected]>
    Co-authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   6 -
 .../spark/sql/streaming/StatefulProcessor.scala    |   3 +
 .../streaming/TransformWithStateExec.scala         |   8 +-
 .../streaming/state/StateStoreErrors.scala         |  10 -
 .../TransformWithStateInitialStateSuite.scala      | 286 ++++++++++++++++++---
 5 files changed, 257 insertions(+), 56 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 7ef6feae0845..154fee2eefb7 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4233,12 +4233,6 @@
     ],
     "sqlState" : "42802"
   },
-  "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY" : {
-    "message" : [
-      "Cannot re-initialize state on the same grouping key during initial 
state handling for stateful processor. Invalid grouping key=<groupingKey>."
-    ],
-    "sqlState" : "42802"
-  },
   "STATEFUL_PROCESSOR_DUPLICATE_STATE_VARIABLE_DEFINED" : {
     "message" : [
       "State variable with name <stateVarName> has already been defined in the 
StatefulProcessor."
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
index 719d1e572c20..55477b4dda0c 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala
@@ -121,6 +121,9 @@ private[sql] abstract class 
StatefulProcessorWithInitialState[K, I, O, S]
 
   /**
    * Function that will be invoked only in the first batch for users to 
process initial states.
+   * The provided initial state can be an arbitrary dataframe with the same 
grouping key schema as
+   * the input rows, e.g. a dataframe from the state data source reader of an 
existing streaming query
+   * checkpoint.
    *
    * @param key
    *   \- grouping key
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 4f7a10f88245..2b26d18019d1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -271,13 +271,9 @@ case class TransformWithStateExec(
     ImplicitGroupingKeyTracker.setImplicitKey(keyObj)
     val initStateObjIter = initStateIter.map(getInitStateValueObj.apply)
 
-    var seenInitStateOnKey = false
     initStateObjIter.foreach { initState =>
-      // cannot re-initialize state on the same grouping key during initial 
state handling
-      if (seenInitStateOnKey) {
-        throw StateStoreErrors.cannotReInitializeStateOnKey(keyObj.toString)
-      }
-      seenInitStateOnKey = true
+      // allow multiple initial state rows on the same grouping key for 
integration
+      // with state data source reader with initial state
       statefulProcessor
         .asInstanceOf[StatefulProcessorWithInitialState[Any, Any, Any, Any]]
         .handleInitialState(keyObj, initState,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index e4b370e67b01..45ad7e14c52d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -123,11 +123,6 @@ object StateStoreErrors {
     new 
StatefulProcessorCannotPerformOperationWithInvalidHandleState(operationType, 
handleState)
   }
 
-  def cannotReInitializeStateOnKey(groupingKey: String):
-    StatefulProcessorCannotReInitializeState = {
-    new StatefulProcessorCannotReInitializeState(groupingKey)
-  }
-
   def cannotProvideTTLConfigForTimeMode(stateName: String, timeMode: String):
     StatefulProcessorCannotAssignTTLInTimeMode = {
     new StatefulProcessorCannotAssignTTLInTimeMode(stateName, timeMode)
@@ -272,11 +267,6 @@ class 
StatefulProcessorCannotPerformOperationWithInvalidHandleState(
     messageParameters = Map("operationType" -> operationType, "handleState" -> 
handleState)
   )
 
-class StatefulProcessorCannotReInitializeState(groupingKey: String)
-  extends SparkUnsupportedOperationException(
-  errorClass = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
-  messageParameters = Map("groupingKey" -> groupingKey))
-
 class StateStoreUnsupportedOperationOnMissingColumnFamily(
     operationType: String,
     colFamilyName: String) extends SparkUnsupportedOperationException(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
index 35ac8a4687eb..360656a76f35 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.streaming
 
-import org.apache.spark.SparkUnsupportedOperationException
-import org.apache.spark.sql.{Dataset, Encoders, KeyValueGroupedDataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Encoders, 
KeyValueGroupedDataset}
+import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBStateStoreProvider}
-import org.apache.spark.sql.functions.timestamp_seconds
+import org.apache.spark.sql.functions.{col, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 
@@ -29,6 +29,21 @@ case class InitInputRow(key: String, action: String, value: 
Double)
 case class InputRowForInitialState(
     key: String, value: Double, entries: List[Double], mapping: Map[Double, 
Int])
 
+case class UnionInitialStateRow(
+    groupingKey: String,
+    value: Option[Long],
+    listValue: Option[Long],
+    userMapKey: Option[String],
+    userMapValue: Option[Long]
+)
+
+case class UnionUnflattenInitialStateRow(
+    groupingKey: String,
+    value: Option[Long],
+    listValue: Option[Seq[Long]],
+    mapValue: Option[Map[String, Long]]
+)
+
 abstract class StatefulProcessorWithInitialStateTestClass[V]
     extends StatefulProcessorWithInitialState[
         String, InitInputRow, (String, String, Double), V] {
@@ -86,6 +101,86 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
   }
 }
 
+/**
+ * Class that updates all state variables with input rows. Act as a counter -
+ * keep the count in value state; keep all the occurrences in list state; and
+ * keep a map of key -> occurrence count in the map state.
+ */
+abstract class InitialStateWithStateDataSourceBase[V]
+  extends StatefulProcessorWithInitialState[
+    String, String, (String, String), V] {
+  @transient var _valState: ValueState[Long] = _
+  @transient var _listState: ListState[Long] = _
+  @transient var _mapState: MapState[String, Long] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _valState = getHandle.getValueState[Long]("testVal", Encoders.scalaLong, 
TTLConfig.NONE)
+    _listState = getHandle.getListState[Long]("testList", Encoders.scalaLong, 
TTLConfig.NONE)
+    _mapState = getHandle.getMapState[String, Long](
+      "testMap", Encoders.STRING, Encoders.scalaLong, TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+    val curCountValue = if (_valState.exists()) {
+      _valState.get()
+    } else {
+      0L
+    }
+    var cnt = curCountValue
+    inputRows.foreach { row =>
+      cnt += 1
+      _listState.appendValue(cnt)
+      val mapCurVal = if (_mapState.containsKey(row)) {
+        _mapState.getValue(row)
+      } else {
+        0
+      }
+      _mapState.updateValue(row, mapCurVal + 1L)
+    }
+    _valState.update(cnt)
+    Iterator.single((key, cnt.toString))
+  }
+
+  override def close(): Unit = super.close()
+}
+
+class InitialStatefulProcessorWithStateDataSource
+  extends InitialStateWithStateDataSourceBase[UnionInitialStateRow] {
+  override def handleInitialState(
+      key: String, initialState: UnionInitialStateRow, timerValues: 
TimerValues): Unit = {
+    if (initialState.value.isDefined) {
+      _valState.update(initialState.value.get)
+    } else if (initialState.listValue.isDefined) {
+      _listState.appendValue(initialState.listValue.get)
+    } else if (initialState.userMapKey.isDefined) {
+      _mapState.updateValue(
+        initialState.userMapKey.get, initialState.userMapValue.get)
+    }
+  }
+}
+
+class InitialStatefulProcessorWithUnflattenStateDataSource
+  extends InitialStateWithStateDataSourceBase[UnionUnflattenInitialStateRow] {
+  override def handleInitialState(
+      key: String, initialState: UnionUnflattenInitialStateRow, timerValues: 
TimerValues): Unit = {
+    if (initialState.value.isDefined) {
+      _valState.update(initialState.value.get)
+    } else if (initialState.listValue.isDefined) {
+      _listState.appendList(
+        initialState.listValue.get.toArray)
+    } else if (initialState.mapValue.isDefined) {
+      initialState.mapValue.get.keys.foreach { key =>
+        _mapState.updateValue(key, initialState.mapValue.get.get(key).get)
+      }
+    }
+  }
+}
+
 class AccumulateStatefulProcessorWithInitState
     extends StatefulProcessorWithInitialStateTestClass[(String, Double)] {
   override def handleInitialState(
@@ -398,37 +493,6 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
     checkAnswer(df, Seq(("k1", "getOption", 37.0)).toDF())
   }
 
-  test("transformWithStateWithInitialState - " +
-    "cannot re-initialize state during initial state handling") {
-    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
-      classOf[RocksDBStateStoreProvider].getName) {
-      val initDf = Seq(("init_1", 40.0), ("init_2", 100.0), ("init_1", 
50.0)).toDS()
-        .groupByKey(x => x._1).mapValues(x => x)
-      val inputData = MemoryStream[InitInputRow]
-      val query = inputData.toDS()
-        .groupByKey(x => x.key)
-        .transformWithState(new AccumulateStatefulProcessorWithInitState(),
-          TimeMode.None(),
-          OutputMode.Append(),
-          initDf)
-
-      testStream(query, OutputMode.Update())(
-        AddData(inputData, InitInputRow("k1", "add", 50.0)),
-        Execute { q =>
-          val e = intercept[Exception] {
-            q.processAllAvailable()
-          }
-          checkError(
-            exception = 
e.getCause.asInstanceOf[SparkUnsupportedOperationException],
-            condition = "STATEFUL_PROCESSOR_CANNOT_REINITIALIZE_STATE_ON_KEY",
-            sqlState = Some("42802"),
-            parameters = Map("groupingKey" -> "init_1")
-          )
-        }
-      )
-    }
-  }
-
   test("transformWithStateWithInitialState - streaming with processing time 
timer, " +
     "can emit expired initial state rows when grouping key is not received for 
new input rows") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
@@ -503,4 +567,158 @@ class TransformWithStateInitialStateSuite extends 
StateStoreMetricsTest
       )
     }
   }
+
+  Seq(true, false).foreach { flattenOption =>
+    Seq(("5", "2"), ("5", "8"), ("5", "5")).foreach { partitions =>
+      test("state data source reader dataframe as initial state " +
+        s"(flatten option=$flattenOption, shuffle partition for 1st 
stream=${partitions._1}, " +
+        s"shuffle partition for 1st stream=${partitions._2})") {
+        withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+          classOf[RocksDBStateStoreProvider].getName) {
+          withTempPaths(2) { checkpointDirs =>
+            SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, 
partitions._1)
+            val inputData = MemoryStream[String]
+            val result = inputData.toDS()
+              .groupByKey(x => x)
+              .transformWithState(new 
InitialStatefulProcessorWithStateDataSource(),
+                TimeMode.None(),
+                OutputMode.Update())
+
+            testStream(result, OutputMode.Update())(
+              StartStream(checkpointLocation = 
checkpointDirs(0).getCanonicalPath),
+              AddData(inputData, "a", "b"),
+              CheckNewAnswer(("a", "1"), ("b", "1")),
+              AddData(inputData, "a", "b", "a"),
+              CheckNewAnswer(("a", "3"), ("b", "2"))
+            )
+
+            // We are trying to mimic a use case where users load all state 
data rows
+            // from a previous tws query as initial state and start a new tws 
query.
+            // In this use case, users will need to create a single dataframe 
with
+            // all the state rows from different state variables with 
different schema.
+            // We can only read from one state variable from one state data 
source reader
+            // query, and they are of different schema. We will get one 
dataframe from each
+            // state variable, and we union them together into a single 
dataframe.
+            val valueDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, 
checkpointDirs(0).getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "testVal")
+              .load()
+              .drop("partition_id")
+
+            val listDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, 
checkpointDirs(0).getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "testList")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+              .drop("partition_id")
+
+            val mapDf = spark.read
+              .format("statestore")
+              .option(StateSourceOptions.PATH, 
checkpointDirs(0).getAbsolutePath)
+              .option(StateSourceOptions.STATE_VAR_NAME, "testMap")
+              .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+              .load()
+              .drop("partition_id")
+
+            // build the initial state df by unioning the value, list and map state 
rows,
+            // filling the missing columns with null.
+            SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, 
partitions._2)
+            val inputData2 = MemoryStream[String]
+            val query = startQueryWithDataSourceDataframeAsInitState(
+              flattenOption, valueDf, listDf, mapDf, inputData2)
+
+            testStream(query, OutputMode.Update())(
+              StartStream(checkpointLocation = 
checkpointDirs(1).getCanonicalPath),
+              // check initial state is updated for state vars
+              AddData(inputData2, "c"),
+              CheckNewAnswer(("c", "1")),
+              Execute { _ =>
+                val valueDf2 = spark.read
+                  .format("statestore")
+                  .option(StateSourceOptions.PATH, 
checkpointDirs(1).getAbsolutePath)
+                  .option(StateSourceOptions.STATE_VAR_NAME, "testVal")
+                  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+                  .load()
+                  .drop("partition_id")
+                  .filter(col("key.value") =!= "c")
+
+                val listDf2 = spark.read
+                  .format("statestore")
+                  .option(StateSourceOptions.PATH, 
checkpointDirs(1).getAbsolutePath)
+                  .option(StateSourceOptions.STATE_VAR_NAME, "testList")
+                  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+                  .load()
+                  .drop("partition_id")
+                  .filter(col("key.value") =!= "c")
+
+                val mapDf2 = spark.read
+                  .format("statestore")
+                  .option(StateSourceOptions.PATH, 
checkpointDirs(1).getAbsolutePath)
+                  .option(StateSourceOptions.STATE_VAR_NAME, "testMap")
+                  .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, 
flattenOption)
+                  .load()
+                  .drop("partition_id")
+                  .filter(col("key.value") =!= "c")
+
+                checkAnswer(valueDf, valueDf2)
+                checkAnswer(listDf, listDf2)
+                checkAnswer(mapDf, mapDf2)
+              }
+            )
+          }
+        }
+      }
+    }
+  }
+
+  private def startQueryWithDataSourceDataframeAsInitState(
+      flattenOption: Boolean,
+      valDf: DataFrame,
+      listDf: DataFrame,
+      mapDf: DataFrame,
+      inputData: MemoryStream[String]): DataFrame = {
+    if (flattenOption) {
+      // when we read the state rows with flattened option set to true, values 
of a composite
+      // state variable will be flattened into multiple rows where each row is 
a
+      // key -> single value pair
+      val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value 
AS value")
+      val flattenListDf = listDf
+        .selectExpr("key.value AS groupingKey", "list_element.value AS 
listValue")
+      val flattenMapDf = mapDf
+        .selectExpr(
+          "key.value AS groupingKey",
+          "user_map_key.value AS userMapKey",
+          "user_map_value.value AS userMapValue")
+      val df_joined =
+        valueDf.unionByName(flattenListDf, true)
+          .unionByName(flattenMapDf, true)
+      val kvDataSet = inputData.toDS().groupByKey(x => x)
+      val initDf = df_joined.as[UnionInitialStateRow].groupByKey(x => 
x.groupingKey)
+      (kvDataSet.transformWithState(
+        new InitialStatefulProcessorWithStateDataSource(),
+        TimeMode.None(), OutputMode.Append(), initDf).toDF())
+    } else {
+      // when we read the state rows with flattened option set to false, 
values of a composite
+      // state variable will be composed into a single row of list/map type
+      val valueDf = valDf.selectExpr("key.value AS groupingKey", "value.value 
AS value")
+      val unflattenListDf = listDf
+        .selectExpr("key.value AS groupingKey",
+          "list_value.value as listValue")
+      val unflattenMapDf = mapDf
+        .selectExpr(
+          "key.value AS groupingKey",
+          "map_from_entries(transform(map_entries(map_value), x -> " +
+            "struct(x.key.value, x.value.value))) as mapValue")
+      val df_joined =
+        valueDf.unionByName(unflattenListDf, true)
+          .unionByName(unflattenMapDf, true)
+      val kvDataSet = inputData.toDS().groupByKey(x => x)
+      val initDf = df_joined.as[UnionUnflattenInitialStateRow].groupByKey(x => 
x.groupingKey)
+      kvDataSet.transformWithState(
+        new InitialStatefulProcessorWithUnflattenStateDataSource(),
+        TimeMode.None(), OutputMode.Append(), initDf).toDF()
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to