This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d84b2d4565c5 [SPARK-50428][SS][PYTHON] Support
TransformWithStateInPandas in batch queries
d84b2d4565c5 is described below
commit d84b2d4565c5e29c912de4e86d6960fff49ffbd2
Author: bogao007 <[email protected]>
AuthorDate: Thu Dec 12 12:04:35 2024 +0900
[SPARK-50428][SS][PYTHON] Support TransformWithStateInPandas in batch
queries
### What changes were proposed in this pull request?
Support TransformWithStateInPandas in batch queries.
### Why are the changes needed?
Bring parity with Scala. Scala batch support for TransformWithState is done
in https://issues.apache.org/jira/browse/SPARK-46865.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Added new unit test cases.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49113 from bogao007/tws-batch.
Authored-by: bogao007 <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../pandas/test_pandas_transform_with_state.py | 48 ++++++
.../spark/sql/execution/SparkStrategies.scala | 21 ++-
.../python/TransformWithStateInPandasExec.scala | 181 +++++++++++++++++----
3 files changed, 207 insertions(+), 43 deletions(-)
diff --git
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 60f2c9348db3..15089f2cb0d6 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -859,6 +859,54 @@ class TransformWithStateInPandasTestsMixin:
self.test_transform_with_state_in_pandas_event_time()
self.test_transform_with_state_in_pandas_proc_timer()
+ def test_transform_with_state_in_pandas_batch_query(self):
+ data = [("0", 123), ("0", 46), ("1", 146), ("1", 346)]
+ df = self.spark.createDataFrame(data, "id string, temperature int")
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("countAsString", StringType(), True),
+ ]
+ )
+ batch_result = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=MapStateProcessor(),
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ )
+ assert set(batch_result.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+
+ def test_transform_with_state_in_pandas_batch_query_initial_state(self):
+ data = [("0", 123), ("0", 46), ("1", 146), ("1", 346)]
+ df = self.spark.createDataFrame(data, "id string, temperature int")
+
+ init_data = [("0", 789), ("3", 987)]
+ initial_state = self.spark.createDataFrame(init_data, "id string,
initVal int").groupBy(
+ "id"
+ )
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("value", StringType(), True),
+ ]
+ )
+ batch_result = df.groupBy("id").transformWithStateInPandas(
+ statefulProcessor=SimpleStatefulProcessorWithInitialState(),
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ initialState=initial_state,
+ )
+ assert set(batch_result.sort("id").collect()) == {
+ Row(id="0", value=str(789 + 123 + 46)),
+ Row(id="1", value=str(146 + 346)),
+ }
+
class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e77c050fe887..36e25773f834 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, AnalysisException, Strategy}
-import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight,
BuildSide, JoinSelectionHelper, NormalizeFloatingNumbers}
@@ -794,8 +794,8 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
object TransformWithStateInPandasStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case t @ TransformWithStateInPandas(
- func, _, outputAttrs, outputMode, timeMode, child,
- hasInitialState, initialState, _, initialStateSchema) =>
+ func, _, outputAttrs, outputMode, timeMode, child,
+ hasInitialState, initialState, _, initialStateSchema) =>
val execPlan = TransformWithStateInPandasExec(
func, t.leftAttributes, outputAttrs, outputMode, timeMode,
stateInfo = None,
@@ -803,6 +803,7 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
eventTimeWatermarkForLateEvents = None,
eventTimeWatermarkForEviction = None,
planLater(child),
+ isStreaming = true,
hasInitialState,
planLater(initialState),
t.rightAttributes,
@@ -967,18 +968,16 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
keyEncoder, outputObjAttr, planLater(child), hasInitialState,
initialStateGroupingAttrs, initialStateDataAttrs,
initialStateDeserializer, planLater(initialState)) :: Nil
+ case t @ TransformWithStateInPandas(
+ func, _, outputAttrs, outputMode, timeMode, child,
+ hasInitialState, initialState, _, initialStateSchema) =>
+ TransformWithStateInPandasExec.generateSparkPlanForBatchQueries(func,
+ t.leftAttributes, outputAttrs, outputMode, timeMode,
planLater(child), hasInitialState,
+ planLater(initialState), t.rightAttributes, initialStateSchema) ::
Nil
case _: FlatMapGroupsInPandasWithState =>
// TODO(SPARK-40443): support applyInPandasWithState in batch query
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3176")
- case t: TransformWithStateInPandas =>
- // TODO(SPARK-50428): support TransformWithStateInPandas in batch query
- throw new ExtendedAnalysisException(
- new AnalysisException(
- "_LEGACY_ERROR_TEMP_3102",
- Map(
- "msg" -> "TransformWithStateInPandas is not supported with batch
DataFrames/Datasets")
- ), plan = t)
case logical.CoGroup(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, lOrder, rOrder,
oAttr, left, right) =>
execution.CoGroupExec(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
index 617c20c3a782..f8e9f11f4d73 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
@@ -16,12 +16,15 @@
*/
package org.apache.spark.sql.execution.python
+import java.util.UUID
+
import scala.concurrent.duration.NANOSECONDS
import org.apache.hadoop.conf.Configuration
import org.apache.spark.JobArtifactSet
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -34,10 +37,11 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython,
groupAndProject, resolveArgOffsets}
import org.apache.spark.sql.execution.streaming.{StatefulOperatorCustomMetric,
StatefulOperatorCustomSumMetric, StatefulOperatorPartitioning,
StatefulOperatorStateInfo, StatefulProcessorHandleImpl, StateStoreWriter,
WatermarkSupport}
import
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
-import
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec,
StateSchemaValidationResult, StateStore, StateStoreConf, StateStoreId,
StateStoreOps, StateStoreProviderId}
+import
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec,
RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore,
StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider,
StateStoreProviderId}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
-import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
+import org.apache.spark.util.{CompletionIterator, SerializableConfiguration,
Utils}
/**
* Physical operator for executing
@@ -53,8 +57,11 @@ import org.apache.spark.util.{CompletionIterator,
SerializableConfiguration}
* @param eventTimeWatermarkForLateEvents event time watermark for filtering
late events
* @param eventTimeWatermarkForEviction event time watermark for state eviction
* @param child the physical plan for the underlying data
+ * @param isStreaming defines whether the query is streaming or batch
+ * @param hasInitialState defines whether the query has initial state
* @param initialState the physical plan for the input initial state
* @param initialStateGroupingAttrs grouping attributes for initial state
+ * @param initialStateSchema schema for initial state
*/
case class TransformWithStateInPandasExec(
functionExpr: Expression,
@@ -67,6 +74,7 @@ case class TransformWithStateInPandasExec(
eventTimeWatermarkForLateEvents: Option[Long],
eventTimeWatermarkForEviction: Option[Long],
child: SparkPlan,
+ isStreaming: Boolean = true,
hasInitialState: Boolean,
initialState: SparkPlan,
initialStateGroupingAttrs: Seq[Attribute],
@@ -190,18 +198,32 @@ case class TransformWithStateInPandasExec(
metrics
if (!hasInitialState) {
- child.execute().mapPartitionsWithStateStore[InternalRow](
- getStateInfo,
- schemaForKeyRow,
- schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
- session.sqlContext.sessionState,
- Some(session.sqlContext.streams.stateStoreCoordinator),
- useColumnFamilies = true,
- useMultipleValuesPerKey = true
- ) {
- case (store: StateStore, dataIterator: Iterator[InternalRow]) =>
- processDataWithPartition(store, dataIterator)
+ if (isStreaming) {
+ child.execute().mapPartitionsWithStateStore[InternalRow](
+ getStateInfo,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ session.sqlContext.sessionState,
+ Some(session.sqlContext.streams.stateStoreCoordinator),
+ useColumnFamilies = true,
+ useMultipleValuesPerKey = true
+ ) {
+ case (store: StateStore, dataIterator: Iterator[InternalRow]) =>
+ processDataWithPartition(store, dataIterator)
+ }
+ } else {
+ // If the query is running in batch mode, we need to create a new
StateStore and instantiate
+ // a temp directory on the executors in mapPartitionsWithIndex.
+ val hadoopConfBroadcast = sparkContext.broadcast(
+ new SerializableConfiguration(session.sessionState.newHadoopConf()))
+ child.execute().mapPartitionsWithIndex[InternalRow](
+ (partitionId: Int, dataIterator: Iterator[InternalRow]) => {
+ initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast)
{ store =>
+ processDataWithPartition(store, dataIterator)
+ }
+ }
+ )
}
} else {
val storeConf = new StateStoreConf(session.sqlContext.sessionState.conf)
@@ -216,25 +238,71 @@ case class TransformWithStateInPandasExec(
// The state store aware zip partitions will provide us with two
iterators,
// child data iterator and the initial state iterator per partition.
case (partitionId, childDataIterator, initStateIterator) =>
- val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
- stateInfo.get.operatorId, partitionId)
- val storeProviderId = StateStoreProviderId(stateStoreId,
stateInfo.get.queryRunId)
- val store = StateStore.get(
- storeProviderId = storeProviderId,
- keySchema = schemaForKeyRow,
- valueSchema = schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
- version = stateInfo.get.storeVersion,
- stateStoreCkptId =
stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
- useColumnFamilies = true,
- storeConf = storeConf,
- hadoopConf = hadoopConfBroadcast.value.value
- )
- processDataWithPartition(store, childDataIterator, initStateIterator)
+ if (isStreaming) {
+ val stateStoreId = StateStoreId(stateInfo.get.checkpointLocation,
+ stateInfo.get.operatorId, partitionId)
+ val storeProviderId = StateStoreProviderId(stateStoreId,
stateInfo.get.queryRunId)
+ val store = StateStore.get(
+ storeProviderId = storeProviderId,
+ keySchema = schemaForKeyRow,
+ valueSchema = schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ version = stateInfo.get.storeVersion,
+ stateStoreCkptId =
stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = hadoopConfBroadcast.value.value
+ )
+ processDataWithPartition(store, childDataIterator,
initStateIterator)
+ } else {
+ initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast)
{ store =>
+ processDataWithPartition(store, childDataIterator,
initStateIterator)
+ }
+ }
}
}
}
+ /**
+ * Create a new StateStore for given partitionId and instantiate a temp
directory
+ * on the executors. Process data and close the stateStore provider
afterwards.
+ */
+ private def initNewStateStoreAndProcessData(
+ partitionId: Int,
+ hadoopConfBroadcast: Broadcast[SerializableConfiguration])
+ (f: StateStore => Iterator[InternalRow]): Iterator[InternalRow] = {
+
+ val providerId = {
+ val tempDirPath = Utils.createTempDir().getAbsolutePath
+ new StateStoreProviderId(
+ StateStoreId(tempDirPath, 0, partitionId), getStateInfo.queryRunId)
+ }
+
+ val sqlConf = new SQLConf()
+ sqlConf.setConfString(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+ classOf[RocksDBStateStoreProvider].getName)
+ val storeConf = new StateStoreConf(sqlConf)
+
+ // Create StateStoreProvider for this partition
+ val stateStoreProvider = StateStoreProvider.createAndInit(
+ providerId,
+ schemaForKeyRow,
+ schemaForValueRow,
+ NoPrefixKeyStateEncoderSpec(schemaForKeyRow),
+ useColumnFamilies = true,
+ storeConf = storeConf,
+ hadoopConf = hadoopConfBroadcast.value.value,
+ useMultipleValuesPerKey = true)
+
+ val store = stateStoreProvider.getStore(0, None)
+ val outputIterator = f(store)
+ CompletionIterator[InternalRow,
Iterator[InternalRow]](outputIterator.iterator, {
+ stateStoreProvider.close()
+ }).map { row =>
+ row
+ }
+ }
+
private def processDataWithPartition(
store: StateStore,
dataIterator: Iterator[InternalRow],
@@ -259,7 +327,7 @@ case class TransformWithStateInPandasExec(
val data = groupAndProject(filteredIter, groupingAttributes, child.output,
dedupAttributes)
val processorHandle = new StatefulProcessorHandleImpl(store,
getStateInfo.queryRunId,
- groupingKeyExprEncoder, timeMode, isStreaming = true, batchTimestampMs,
metrics)
+ groupingKeyExprEncoder, timeMode, isStreaming, batchTimestampMs, metrics)
val outputIterator = if (!hasInitialState) {
val runner = new TransformWithStateInPandasPythonRunner(
@@ -311,8 +379,12 @@ case class TransformWithStateInPandasExec(
// by the upstream (consumer) operators in addition to the processing in
this operator.
allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime -
updatesStartTimeNs)
commitTimeMs += timeTakenMs {
- processorHandle.doTtlCleanup()
- store.commit()
+ if (isStreaming) {
+ processorHandle.doTtlCleanup()
+ store.commit()
+ } else {
+ store.abort()
+ }
}
setStoreMetrics(store)
setOperatorMetrics()
@@ -334,3 +406,48 @@ case class TransformWithStateInPandasExec(
override def right: SparkPlan = initialState
}
+
+// scalastyle:off argcount
+object TransformWithStateInPandasExec {
+
+ // Plan logical transformWithStateInPandas for batch queries
+ def generateSparkPlanForBatchQueries(
+ functionExpr: Expression,
+ groupingAttributes: Seq[Attribute],
+ output: Seq[Attribute],
+ outputMode: OutputMode,
+ timeMode: TimeMode,
+ child: SparkPlan,
+ hasInitialState: Boolean = false,
+ initialState: SparkPlan,
+ initialStateGroupingAttrs: Seq[Attribute],
+ initialStateSchema: StructType): SparkPlan = {
+ val shufflePartitions =
child.session.sessionState.conf.numShufflePartitions
+ val statefulOperatorStateInfo = StatefulOperatorStateInfo(
+ checkpointLocation = "", // empty checkpointLocation will be populated
in doExecute
+ queryRunId = UUID.randomUUID(),
+ operatorId = 0,
+ storeVersion = 0,
+ numPartitions = shufflePartitions,
+ stateStoreCkptIds = None
+ )
+
+ new TransformWithStateInPandasExec(
+ functionExpr,
+ groupingAttributes,
+ output,
+ outputMode,
+ timeMode,
+ Some(statefulOperatorStateInfo),
+ Some(System.currentTimeMillis),
+ None,
+ None,
+ child,
+ isStreaming = false,
+ hasInitialState,
+ initialState,
+ initialStateGroupingAttrs,
+ initialStateSchema)
+ }
+}
+// scalastyle:on argcount
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]