This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d84b2d4565c5 [SPARK-50428][SS][PYTHON] Support 
TransformWithStateInPandas in batch queries
d84b2d4565c5 is described below

commit d84b2d4565c5e29c912de4e86d6960fff49ffbd2
Author: bogao007 <[email protected]>
AuthorDate: Thu Dec 12 12:04:35 2024 +0900

    [SPARK-50428][SS][PYTHON] Support TransformWithStateInPandas in batch 
queries
    
    ### What changes were proposed in this pull request?
    
    Support TransformWithStateInPandas in batch queries.
    
    ### Why are the changes needed?
    
    Bring parity with Scala. Scala batch support for TransformWithState is done 
in https://issues.apache.org/jira/browse/SPARK-46865.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    
    ### How was this patch tested?
    
    Added new unit test cases.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49113 from bogao007/tws-batch.
    
    Authored-by: bogao007 <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../pandas/test_pandas_transform_with_state.py     |  48 ++++++
 .../spark/sql/execution/SparkStrategies.scala      |  21 ++-
 .../python/TransformWithStateInPandasExec.scala    | 181 +++++++++++++++++----
 3 files changed, 207 insertions(+), 43 deletions(-)

diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 60f2c9348db3..15089f2cb0d6 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -859,6 +859,54 @@ class TransformWithStateInPandasTestsMixin:
             self.test_transform_with_state_in_pandas_event_time()
             self.test_transform_with_state_in_pandas_proc_timer()
 
+    def test_transform_with_state_in_pandas_batch_query(self):
+        data = [("0", 123), ("0", 46), ("1", 146), ("1", 346)]
+        df = self.spark.createDataFrame(data, "id string, temperature int")
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("countAsString", StringType(), True),
+            ]
+        )
+        batch_result = df.groupBy("id").transformWithStateInPandas(
+            statefulProcessor=MapStateProcessor(),
+            outputStructType=output_schema,
+            outputMode="Update",
+            timeMode="None",
+        )
+        assert set(batch_result.sort("id").collect()) == {
+            Row(id="0", countAsString="2"),
+            Row(id="1", countAsString="2"),
+        }
+
+    def test_transform_with_state_in_pandas_batch_query_initial_state(self):
+        data = [("0", 123), ("0", 46), ("1", 146), ("1", 346)]
+        df = self.spark.createDataFrame(data, "id string, temperature int")
+
+        init_data = [("0", 789), ("3", 987)]
+        initial_state = self.spark.createDataFrame(init_data, "id string, 
initVal int").groupBy(
+            "id"
+        )
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("value", StringType(), True),
+            ]
+        )
+        batch_result = df.groupBy("id").transformWithStateInPandas(
+            statefulProcessor=SimpleStatefulProcessorWithInitialState(),
+            outputStructType=output_schema,
+            outputMode="Update",
+            timeMode="None",
+            initialState=initial_state,
+        )
+        assert set(batch_result.sort("id").collect()) == {
+            Row(id="0", value=str(789 + 123 + 46)),
+            Row(id="1", value=str(146 + 346)),
+        }
+
 
 class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
     # this dict is the same as input initial state dataframe
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e77c050fe887..36e25773f834 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, 
BuildSide, JoinSelectionHelper, NormalizeFloatingNumbers}
@@ -794,8 +794,8 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
   object TransformWithStateInPandasStrategy extends Strategy {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case t @ TransformWithStateInPandas(
-      func, _, outputAttrs, outputMode, timeMode, child,
-      hasInitialState, initialState, _, initialStateSchema) =>
+        func, _, outputAttrs, outputMode, timeMode, child,
+        hasInitialState, initialState, _, initialStateSchema) =>
         val execPlan = TransformWithStateInPandasExec(
           func, t.leftAttributes, outputAttrs, outputMode, timeMode,
           stateInfo = None,
@@ -803,6 +803,7 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           eventTimeWatermarkForLateEvents = None,
           eventTimeWatermarkForEviction = None,
           planLater(child),
+          isStreaming = true,
           hasInitialState,
           planLater(initialState),
           t.rightAttributes,
@@ -967,18 +968,16 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           keyEncoder, outputObjAttr, planLater(child), hasInitialState,
           initialStateGroupingAttrs, initialStateDataAttrs,
           initialStateDeserializer, planLater(initialState)) :: Nil
+      case t @ TransformWithStateInPandas(
+        func, _, outputAttrs, outputMode, timeMode, child,
+        hasInitialState, initialState, _, initialStateSchema) =>
+        TransformWithStateInPandasExec.generateSparkPlanForBatchQueries(func,
+          t.leftAttributes, outputAttrs, outputMode, timeMode, 
planLater(child), hasInitialState,
+          planLater(initialState), t.rightAttributes, initialStateSchema) :: 
Nil
 
       case _: FlatMapGroupsInPandasWithState =>
         // TODO(SPARK-40443): support applyInPandasWithState in batch query
         throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176")
-      case t: TransformWithStateInPandas =>
-        // TODO(SPARK-50428): support TransformWithStateInPandas in batch query
-        throw new ExtendedAnalysisException(
-          new AnalysisException(
-            "_LEGACY_ERROR_TEMP_3102",
-            Map(
-              "msg" -> "TransformWithStateInPandas is not supported with batch 
DataFrames/Datasets")
-          ), plan = t)
       case logical.CoGroup(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder, 
oAttr, left, right) =>
         execution.CoGroupExec(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
index 617c20c3a782..f8e9f11f4d73 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
@@ -16,12 +16,15 @@
  */
 package org.apache.spark.sql.execution.python
 
+import java.util.UUID
+
 import scala.concurrent.duration.NANOSECONDS
 
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.JobArtifactSet
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -34,10 +37,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, 
groupAndProject, resolveArgOffsets}
 import org.apache.spark.sql.execution.streaming.{StatefulOperatorCustomMetric, 
StatefulOperatorCustomSumMetric, StatefulOperatorPartitioning, 
StatefulOperatorStateInfo, StatefulProcessorHandleImpl, StateStoreWriter, 
WatermarkSupport}
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
-import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateSchemaValidationResult, StateStore, StateStoreConf, StateStoreId, 
StateStoreOps, StateStoreProviderId}
+import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, 
StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, 
StateStoreProviderId}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
 import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
-import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
+import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, 
Utils}
 
 /**
  * Physical operator for executing
@@ -53,8 +57,11 @@ import org.apache.spark.util.{CompletionIterator, 
SerializableConfiguration}
  * @param eventTimeWatermarkForLateEvents event time watermark for filtering 
late events
  * @param eventTimeWatermarkForEviction event time watermark for state eviction
  * @param child the physical plan for the underlying data
+ * @param isStreaming defines whether the query is streaming or batch
+ * @param hasInitialState defines whether the query has initial state
  * @param initialState the physical plan for the input initial state
  * @param initialStateGroupingAttrs grouping attributes for initial state
+ * @param initialStateSchema schema for initial state
  */
 case class TransformWithStateInPandasExec(
     functionExpr: Expression,
@@ -67,6 +74,7 @@ case class TransformWithStateInPandasExec(
     eventTimeWatermarkForLateEvents: Option[Long],
     eventTimeWatermarkForEviction: Option[Long],
     child: SparkPlan,
+    isStreaming: Boolean = true,
     hasInitialState: Boolean,
     initialState: SparkPlan,
     initialStateGroupingAttrs: Seq[Attribute],
@@ -190,18 +198,32 @@ case class TransformWithStateInPandasExec(
     metrics
 
     if (!hasInitialState) {
-      child.execute().mapPartitionsWithStateStore[InternalRow](
-        getStateInfo,
-        schemaForKeyRow,
-        schemaForValueRow,
-        NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
-        session.sqlContext.sessionState,
-        Some(session.sqlContext.streams.stateStoreCoordinator),
-        useColumnFamilies = true,
-        useMultipleValuesPerKey = true
-      ) {
-        case (store: StateStore, dataIterator: Iterator[InternalRow]) =>
-          processDataWithPartition(store, dataIterator)
+      if (isStreaming) {
+        child.execute().mapPartitionsWithStateStore[InternalRow](
+          getStateInfo,
+          schemaForKeyRow,
+          schemaForValueRow,
+          NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+          session.sqlContext.sessionState,
+          Some(session.sqlContext.streams.stateStoreCoordinator),
+          useColumnFamilies = true,
+          useMultipleValuesPerKey = true
+        ) {
+          case (store: StateStore, dataIterator: Iterator[InternalRow]) =>
+            processDataWithPartition(store, dataIterator)
+        }
+      } else {
+        // If the query is running in batch mode, we need to create a new 
StateStore and instantiate
+        // a temp directory on the executors in mapPartitionsWithIndex.
+        val hadoopConfBroadcast = sparkContext.broadcast(
+          new SerializableConfiguration(session.sessionState.newHadoopConf()))
+        child.execute().mapPartitionsWithIndex[InternalRow](
+          (partitionId: Int, dataIterator: Iterator[InternalRow]) => {
+            initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast) 
{ store =>
+              processDataWithPartition(store, dataIterator)
+            }
+          }
+        )
       }
     } else {
       val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
@@ -216,25 +238,71 @@ case class TransformWithStateInPandasExec(
         // The state store aware zip partitions will provide us with two 
iterators,
         // child data iterator and the initial state iterator per partition.
         case (partitionId, childDataIterator, initStateIterator) =>
-          val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
-            stateInfo.get.operatorId, partitionId)
-          val storeProviderId = StateStoreProviderId(stateStoreId, 
stateInfo.get.queryRunId)
-          val store = StateStore.get(
-            storeProviderId = storeProviderId,
-            keySchema = schemaForKeyRow,
-            valueSchema = schemaForValueRow,
-            NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
-            version = stateInfo.get.storeVersion,
-            stateStoreCkptId = 
stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
-            useColumnFamilies = true,
-            storeConf = storeConf,
-            hadoopConf = hadoopConfBroadcast.value.value
-          )
-          processDataWithPartition(store, childDataIterator, initStateIterator)
+          if (isStreaming) {
+            val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+              stateInfo.get.operatorId, partitionId)
+            val storeProviderId = StateStoreProviderId(stateStoreId, 
stateInfo.get.queryRunId)
+            val store = StateStore.get(
+              storeProviderId = storeProviderId,
+              keySchema = schemaForKeyRow,
+              valueSchema = schemaForValueRow,
+              NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+              version = stateInfo.get.storeVersion,
+              stateStoreCkptId = 
stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
+              useColumnFamilies = true,
+              storeConf = storeConf,
+              hadoopConf = hadoopConfBroadcast.value.value
+            )
+            processDataWithPartition(store, childDataIterator, 
initStateIterator)
+          } else {
+            initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast) 
{ store =>
+              processDataWithPartition(store, childDataIterator, 
initStateIterator)
+            }
+          }
       }
     }
   }
 
+  /**
+   * Create a new StateStore for given partitionId and instantiate a temp 
directory
+   * on the executors. Process data and close the stateStore provider 
afterwards.
+   */
+  private def initNewStateStoreAndProcessData(
+      partitionId: Int,
+      hadoopConfBroadcast: Broadcast[SerializableConfiguration])
+    (f: StateStore => Iterator[InternalRow]): Iterator[InternalRow] = {
+
+    val providerId = {
+      val tempDirPath = Utils.createTempDir().getAbsolutePath
+      new StateStoreProviderId(
+        StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId)
+    }
+
+    val sqlConf = new SQLConf()
+    sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[RocksDBStateStoreProvider].getName)
+    val storeConf = new StateStoreConf(sqlConf)
+
+    // Create StateStoreProvider for this partition
+    val stateStoreProvider = StateStoreProvider.createAndInit(
+      providerId,
+      schemaForKeyRow,
+      schemaForValueRow,
+      NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+      useColumnFamilies = true,
+      storeConf = storeConf,
+      hadoopConf = hadoopConfBroadcast.value.value,
+      useMultipleValuesPerKey = true)
+
+    val store = stateStoreProvider.getStore(0, None)
+    val outputIterator = f(store)
+    CompletionIterator[InternalRow, 
Iterator[InternalRow]](outputIterator.iterator, {
+      stateStoreProvider.close()
+    }).map { row =>
+      row
+    }
+  }
+
   private def processDataWithPartition(
       store: StateStore,
       dataIterator: Iterator[InternalRow],
@@ -259,7 +327,7 @@ case class TransformWithStateInPandasExec(
     val data = groupAndProject(filteredIter, groupingAttributes, child.output, 
dedupAttributes)
 
     val processorHandle = new StatefulProcessorHandleImpl(store, 
getStateInfo.queryRunId,
-      groupingKeyExprEncoder, timeMode, isStreaming = true, batchTimestampMs, 
metrics)
+      groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
 
     val outputIterator = if (!hasInitialState) {
       val runner = new TransformWithStateInPandasPythonRunner(
@@ -311,8 +379,12 @@ case class TransformWithStateInPandasExec(
       // by the upstream (consumer) operators in addition to the processing in 
this operator.
       allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - 
updatesStartTimeNs)
       commitTimeMs += timeTakenMs {
-        processorHandle.doTtlCleanup()
-        store.commit()
+        if (isStreaming) {
+          processorHandle.doTtlCleanup()
+          store.commit()
+        } else {
+          store.abort()
+        }
       }
       setStoreMetrics(store)
       setOperatorMetrics()
@@ -334,3 +406,48 @@ case class TransformWithStateInPandasExec(
 
   override def right: SparkPlan = initialState
 }
+
+// scalastyle:off argcount
+object TransformWithStateInPandasExec {
+
+  // Plan logical transformWithStateInPandas for batch queries
+  def generateSparkPlanForBatchQueries(
+      functionExpr: Expression,
+      groupingAttributes: Seq[Attribute],
+      output: Seq[Attribute],
+      outputMode: OutputMode,
+      timeMode: TimeMode,
+      child: SparkPlan,
+      hasInitialState: Boolean = false,
+      initialState: SparkPlan,
+      initialStateGroupingAttrs: Seq[Attribute],
+      initialStateSchema: StructType): SparkPlan = {
+    val shufflePartitions = 
child.session.sessionState.conf.numShufflePartitions
+    val statefulOperatorStateInfo = StatefulOperatorStateInfo(
+      checkpointLocation = "", // empty checkpointLocation will be populated 
in doExecute
+      queryRunId = UUID.randomUUID(),
+      operatorId = 0,
+      storeVersion = 0,
+      numPartitions = shufflePartitions,
+      stateStoreCkptIds = None
+    )
+
+    new TransformWithStateInPandasExec(
+      functionExpr,
+      groupingAttributes,
+      output,
+      outputMode,
+      timeMode,
+      Some(statefulOperatorStateInfo),
+      Some(System.currentTimeMillis),
+      None,
+      None,
+      child,
+      isStreaming = false,
+      hasInitialState,
+      initialState,
+      initialStateGroupingAttrs,
+      initialStateSchema)
+  }
+}
+// scalastyle:on argcount


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to