HeartSaVioR commented on code in PR #48005:
URL: https://github.com/apache/spark/pull/48005#discussion_r1827111574
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -409,6 +410,9 @@ def transformWithStateInPandas(
The output mode of the stateful processor.
timeMode : str
The time mode semantics of the stateful processor for timers and TTL.
+    initialState : :class:`pyspark.sql.GroupedData`
+        Optional. The grouped dataframe on given grouping key as initial states used for initialization
Review Comment:
nit: The method docs for the Scala version and the PySpark version have now diverged, not only in the type (which is expected) but also in the description itself.
For example, here is the explanation of `initialState` in the Scala API:
> User provided initial state that will be used to initiate state for the query in the first batch.
It would probably be better to revisit both API docs at some point and sync the two.
Until then, I think the phrase `on given grouping key` is redundant and confusing. We should have already checked the compatibility of the grouping key between the two groups (the current Dataset and the Dataset for initialState), right? If so, we can just remove it.
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
+ return result
+
+ def transformWithStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ ) -> Iterator["PandasDataFrameLike"]:
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ return result
+
+ def transformWithStateWithInitStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ initialStates: Iterator["PandasDataFrameLike"] = None,
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+        UDF for TWS operator with non-empty initial states. Possible input combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+        and initial states contains the grouping key, both input rows and initial states contains data.
+        - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+        initial states contains the grouping key and data, and it is first batch.
+        - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+        contains the grouping key and data, and it is first batch.
+        - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
+        is initialized to the positional value as None.
+ """
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+ # only process initial state if first batch
+ is_first_batch = statefulProcessorApiClient.is_first_batch()
+ if is_first_batch and initialStates is not None:
Review Comment:
I'd expect the caller to handle this; providing initialStates for a non-first batch already adds unnecessary overhead, and ideally the caller should pass None for non-first batches. I'm OK with double-checking here for safety, but maybe I'd do the opposite and assert that initialStates is None whenever it is not the first batch.
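A minimal sketch of that alternative check, as I read it (the helper name and message are mine, not from the PR):
```python
from typing import Any, Iterator, Optional


def assert_initial_state_contract(
    is_first_batch: bool, initial_states: Optional[Iterator[Any]]
) -> None:
    # For non-first batches the caller is expected to pass None; fail loudly
    # instead of silently ignoring a provided iterator.
    assert is_first_batch or initial_states is None, (
        "initialStates should be None for any batch other than the first"
    )
```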
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -338,6 +339,27 @@ def get_map_state(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing map state: "
f"{response_message[1]}")
+ def is_first_batch(self) -> bool:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ is_first_batch = stateMessage.IsFirstBatch()
+ request = stateMessage.UtilsCallCommand(isFirstBatch=is_first_batch)
+        stateful_processor_call = stateMessage.StatefulProcessorCall(utilsCall=request)
+        message = stateMessage.StateRequest(statefulProcessorCall=stateful_processor_call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 0:
+ return True
+ elif status == 1:
+ return False
+ else:
Review Comment:
For all the other calls, status == 1 is mostly treated as the error case.
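If the server followed the convention described later in this review (0 = true, 2 = false, 1 = error), the client branch could look roughly like the sketch below; the exact codes and the error type used here are assumptions on my part:
```python
def parse_is_first_batch_response(status: int, error_message: str) -> bool:
    # Assumed convention: 0 -> True, 2 -> False, anything else -> error.
    if status == 0:
        return True
    elif status == 2:
        return False
    else:
        raise RuntimeError(f"Error checking isFirstBatch: {error_message}")
```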
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
+ return result
+
+ def transformWithStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ ) -> Iterator["PandasDataFrameLike"]:
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ return result
+
+ def transformWithStateWithInitStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ initialStates: Iterator["PandasDataFrameLike"] = None,
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+        UDF for TWS operator with non-empty initial states. Possible input combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+        and initial states contains the grouping key, both input rows and initial states contains data.
+        - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+        initial states contains the grouping key and data, and it is first batch.
+        - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+        contains the grouping key and data, and it is first batch.
+        - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
+        is initialized to the positional value as None.
+ """
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+ # only process initial state if first batch
+ is_first_batch = statefulProcessorApiClient.is_first_batch()
+ if is_first_batch and initialStates is not None:
+ for cur_initial_state in initialStates:
+ statefulProcessorApiClient.set_implicit_key(key)
+                # TODO(SPARK-50194) integration with new timer API & initial state timer register
+                statefulProcessor.handleInitialState(key, cur_initial_state)
+
+        # if we don't have input rows for the given key but only have initial state
+ # for the grouping key, the inputRows iterator could be empty
+ input_rows_empty = False
+ try:
+ first = next(inputRows)
+ except StopIteration:
+ input_rows_empty = True
+ else:
+ inputRows = itertools.chain([first], inputRows)
+
+ if not input_rows_empty:
Review Comment:
If you don't have a test covering this scenario, please add it as well.
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
+ return result
+
+ def transformWithStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ ) -> Iterator["PandasDataFrameLike"]:
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ return result
+
+ def transformWithStateWithInitStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ initialStates: Iterator["PandasDataFrameLike"] = None,
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+        UDF for TWS operator with non-empty initial states. Possible input combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+        and initial states contains the grouping key, both input rows and initial states contains data.
+        - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+        initial states contains the grouping key and data, and it is first batch.
+        - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+        contains the grouping key and data, and it is first batch.
+        - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
Review Comment:
This relies on the difference between an empty Dataset (or iterator) and None, right? Just to make that clear.
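To illustrate the distinction I have in mind (purely my reading; the helper is hypothetical):
```python
import itertools
from typing import Any, Iterator, Optional


def describe_initial_states(initial_states: Optional[Iterator[Any]]) -> str:
    if initial_states is None:
        # No initial-state argument was supplied at all, i.e. not the first batch.
        return "not first batch"
    peeked = list(itertools.islice(initial_states, 1))
    if not peeked:
        # First batch, but no initial-state rows exist for this grouping key.
        return "first batch without initial state for this key"
    return "first batch with initial state for this key"
```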
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
+ return result
+
+ def transformWithStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ ) -> Iterator["PandasDataFrameLike"]:
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ return result
+
+ def transformWithStateWithInitStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ initialStates: Iterator["PandasDataFrameLike"] = None,
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+        UDF for TWS operator with non-empty initial states. Possible input combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
Review Comment:
nit: `both input rows and initial states contains the grouping key` sounds redundant, since we already call out `for the given key`. inputRows and initialStates are expected to be flattened Datasets (not grouped ones), right? Their grouping key is the given key.
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
+ return result
+
+ def transformWithStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ ) -> Iterator["PandasDataFrameLike"]:
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ return result
+
+ def transformWithStateWithInitStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ initialStates: Iterator["PandasDataFrameLike"] = None,
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+        UDF for TWS operator with non-empty initial states. Possible input combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+        and initial states contains the grouping key, both input rows and initial states contains data.
+        - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
+        initial states contains the grouping key and data, and it is first batch.
+        - `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
+        contains the grouping key and data, and it is first batch.
+        - `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
+        is initialized to the positional value as None.
+ """
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+ # only process initial state if first batch
+ is_first_batch = statefulProcessorApiClient.is_first_batch()
+ if is_first_batch and initialStates is not None:
+ for cur_initial_state in initialStates:
+ statefulProcessorApiClient.set_implicit_key(key)
+                # TODO(SPARK-50194) integration with new timer API & initial state timer register
+                statefulProcessor.handleInitialState(key, cur_initial_state)
+
+        # if we don't have input rows for the given key but only have initial state
+ # for the grouping key, the inputRows iterator could be empty
+ input_rows_empty = False
+ try:
+ first = next(inputRows)
+ except StopIteration:
+ input_rows_empty = True
+ else:
+ inputRows = itertools.chain([first], inputRows)
+
+ if not input_rows_empty:
Review Comment:
Wait, isn't there a case where the inputRows iterator is empty but a timer is expected to expire?
My understanding is that if you pass transformWithStateWithInitStateUDF as the udf, it will be used for all batches, right? Then how does this handle a grouping key in batch N that has no data but does have a timer to expire?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
* methods for new rows in each trigger and the user's state/state variables will be stored
* persistently across invocations.
* @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe
* @param outputAttrs used to define the output rows
* @param outputMode defines the output mode for the statefulProcessor
* @param timeMode the time mode semantics of the stateful processor for timers and TTL.
* @param child logical plan of the underlying data
+ * @param initialState logical plan of initial state
+ * @param initGroupingAttrsLen length of the seq of grouping attributes for initial state dataframe
*/
case class TransformWithStateInPandas(
functionExpr: Expression,
- groupingAttributes: Seq[Attribute],
+ groupingAttributesLen: Int,
outputAttrs: Seq[Attribute],
outputMode: OutputMode,
timeMode: TimeMode,
- child: LogicalPlan) extends UnaryNode {
+ child: LogicalPlan,
+ hasInitialState: Boolean,
+ initialState: LogicalPlan,
+ initGroupingAttrsLen: Int,
+ initialStateSchema: StructType) extends BinaryNode {
+ override def left: LogicalPlan = child
+
+ override def right: LogicalPlan = initialState
override def output: Seq[Attribute] = outputAttrs
override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
- override protected def withNewChildInternal(
-    newChild: LogicalPlan): TransformWithStateInPandas = copy(child = newChild)
+ override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+
+ override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas =
+ copy(child = newLeft, initialState = newRight)
+
+  // We call the following attributes from `SparkStrategies` because attribute were
+  // resolved when we got there; if we directly pass Seq of attributes before it
+  // is fully analyzed, we will have conflicting attributes when resolving (initial state
+ // and child have the same names on grouping attributes)
+ def leftAttributes: Seq[Attribute] = left.output.take(groupingAttributesLen)
Review Comment:
You can also assert that this node is resolved when this method is called, e.g.
```
assert(resolved, "this method is expected to be called after resolution")
```
Also, we could turn the code comment into a method doc so that it's easily visible from an IDE.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -788,15 +788,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object TransformWithStateInPandasStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case TransformWithStateInPandas(
- func, groupingAttributes, outputAttrs, outputMode, timeMode, child) =>
+ case t@TransformWithStateInPandas(
Review Comment:
nit: spaces around `@`
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -426,3 +426,10 @@ def close(self) -> None:
operations.
"""
...
+
+    def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None:
+ """
+        Optional to implement. Will act as no-op if not defined or no initial state input. Function
Review Comment:
nit: Same here, it has diverged from the method doc of the Scala API:
> Function that will be invoked only in the first batch for users to process initial states.
We can either revisit the docs as a whole, or at least do the sync now for the code changed in this PR, while we are here.
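For example, the merged wording could look something like this (a suggestion only, not text from the PR):
```python
from typing import Any


class StatefulProcessorDocSketch:
    """Hypothetical stub, only to show a docstring wording synced with the Scala API."""

    def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None:
        """
        Optional to implement. Acts as a no-op if not defined or if there is no
        initial state input. Function that will be invoked only in the first
        batch for users to process initial states.
        """
        ...
```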
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
+ return result
+
+ def transformWithStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ ) -> Iterator["PandasDataFrameLike"]:
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ return result
+
+ def transformWithStateWithInitStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ initialStates: Iterator["PandasDataFrameLike"] = None,
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+        UDF for TWS operator with non-empty initial states. Possible input combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
+        and initial states contains the grouping key, both input rows and initial states contains data.
+        - `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
Review Comment:
nit: "`InitialStates` is non-empty, while `initialStates` is empty" — both names refer to the same parameter, so you may want to change one of them.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
* methods for new rows in each trigger and the user's state/state variables will be stored
* persistently across invocations.
* @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe
Review Comment:
Shall we describe the prerequisite on the child (left) and initialState (right), namely that the grouping key attributes should be duplicated at the front, along with the output attributes? That would make it easier to understand how this works and how to provide the params from the caller site.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
* methods for new rows in each trigger and the user's state/state variables
will be stored
* persistently across invocations.
* @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe
* @param outputAttrs used to define the output rows
* @param outputMode defines the output mode for the statefulProcessor
* @param timeMode the time mode semantics of the stateful processor for timers and TTL.
* @param child logical plan of the underlying data
+ * @param initialState logical plan of initial state
+ * @param initGroupingAttrsLen length of the seq of grouping attributes for initial state dataframe
*/
case class TransformWithStateInPandas(
functionExpr: Expression,
- groupingAttributes: Seq[Attribute],
+ groupingAttributesLen: Int,
outputAttrs: Seq[Attribute],
outputMode: OutputMode,
timeMode: TimeMode,
- child: LogicalPlan) extends UnaryNode {
+ child: LogicalPlan,
+ hasInitialState: Boolean,
+ initialState: LogicalPlan,
+ initGroupingAttrsLen: Int,
+ initialStateSchema: StructType) extends BinaryNode {
+ override def left: LogicalPlan = child
+
+ override def right: LogicalPlan = initialState
override def output: Seq[Attribute] = outputAttrs
override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
- override protected def withNewChildInternal(
-    newChild: LogicalPlan): TransformWithStateInPandas = copy(child = newChild)
+ override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+
+ override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas =
+ copy(child = newLeft, initialState = newRight)
+
+  // We call the following attributes from `SparkStrategies` because attribute were
+  // resolved when we got there; if we directly pass Seq of attributes before it
+  // is fully analyzed, we will have conflicting attributes when resolving (initial state
+ // and child have the same names on grouping attributes)
+ def leftAttributes: Seq[Attribute] = left.output.take(groupingAttributesLen)
+
+ def rightAttributes: Seq[Attribute] = if (hasInitialState) {
Review Comment:
Same here, please leave a method doc stating the requirement.
##########
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -474,24 +474,61 @@ class RelationalGroupedDataset protected[sql](
func: Column,
outputStructType: StructType,
outputModeStr: String,
- timeModeStr: String): DataFrame = {
- val groupingNamedExpressions = groupingExprs.map {
- case ne: NamedExpression => ne
- case other => Alias(other, other.toString)()
+ timeModeStr: String,
+ initialState: RelationalGroupedDataset): DataFrame = {
+ def exprToAttr(expr: Seq[Expression]): Seq[Attribute] = {
+ expr.map {
+ case ne: NamedExpression => ne
+ case other => Alias(other, other.toString)()
+ }.map(_.toAttribute)
}
- val groupingAttrs = groupingNamedExpressions.map(_.toAttribute)
+
+ val groupingAttrs = exprToAttr(groupingExprs)
val outputAttrs = toAttributes(outputStructType)
val outputMode = InternalOutputModes(outputModeStr)
val timeMode = TimeModes(timeModeStr)
- val plan = TransformWithStateInPandas(
- func.expr,
- groupingAttrs,
- outputAttrs,
- outputMode,
- timeMode,
- child = df.logicalPlan
- )
+ val plan: LogicalPlan = if (initialState == null) {
+ TransformWithStateInPandas(
Review Comment:
Are the grouping attributes guaranteed to be placed first? And are they also duplicated along with the output attributes?
I suspect we have to do a projection here, at least for the child. If that's not needed, please leave a code comment explaining why.
If a projection is needed and no test was failing without it, please add a test which picks non-contiguous columns as the grouping key, for both the non-initial-state and initial-state cases.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -348,6 +351,23 @@ class TransformWithStateInPandasStateServer(
}
}
+  private[sql] def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): Unit = {
+ message.getMethodCase match {
+ case UtilsCallCommand.MethodCase.ISFIRSTBATCH =>
+ if (!hasInitialState) {
+ // In physical planning, hasInitialState will always be flipped
Review Comment:
What's the reason we have to deduce isFirstBatch from hasInitialState? They aren't exactly the same, right? hasInitialState can be false even for the first batch. How would we deal with it if isFirstBatch were used outside the context of initial state?
Please avoid coupling things in a way that introduces a hidden dependency and assumption, and try not to couple with how the "current" caller happens to use it. If you have to rely on something (e.g. you can't add more info to carry over in the plan), please find a dependency that gives a consistent result. For example, the state operator info has a version from which you can get the batchId as version - 1, and isFirstBatch is batchId == 0.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -30,15 +30,16 @@ import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, Py
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType, OutType}
+import org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner._
Review Comment:
nit: we generally avoid wildcard imports unless there are tens of imported names.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -788,15 +788,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object TransformWithStateInPandasStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case TransformWithStateInPandas(
- func, groupingAttributes, outputAttrs, outputMode, timeMode, child) =>
+ case t@TransformWithStateInPandas(
+ func, _, outputAttrs, outputMode, timeMode, child,
+ hasInitialState, i, _, initialStateSchema) =>
Review Comment:
nit: let's make sure the param name is understandable; we don't shorten the others, but initialState became `i`. Why not just use the full name?
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -561,6 +565,13 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
verify(arrowStreamWriter).finalizeCurrentArrowBatch()
}
+ test("stateful processor - is first batch") {
Review Comment:
Once my comment above is addressed, please add a negative test as well.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -348,6 +351,23 @@ class TransformWithStateInPandasStateServer(
}
}
+  private[sql] def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): Unit = {
+ message.getMethodCase match {
+ case UtilsCallCommand.MethodCase.ISFIRSTBATCH =>
+ if (!hasInitialState) {
+ // In physical planning, hasInitialState will always be flipped
+ // if it is not first batch
+ sendResponse(1)
Review Comment:
We have been using code 1 as "error", not as "true"/"false". We use 2 for the false case of a boolean result. Let's be consistent.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
- batchTimestampMs: Option[Long] = None,
- eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
- with PythonArrowInput[InType]
- with BasicPythonArrowOutput
- with Logging {
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[InType](
+ funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[InType] {
- private val sqlConf = SQLConf.get
- private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+ private var pandasWriter: BaseStreamingArrowWriter = _
+
+ override protected def writeNextInputToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator: Iterator[InType]): Boolean = {
+ if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+ }
+
+ if (inputIterator.hasNext) {
+ val startData = dataOut.size()
+ val next = inputIterator.next()
+ val dataIter = next._2
+
+ while (dataIter.hasNext) {
+ val dataRow = dataIter.next()
+ pandasWriter.writeRow(dataRow)
+ }
+ pandasWriter.finalizeCurrentArrowBatch()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
+ true
+ } else {
+ super[PythonArrowInput].close()
+ false
+ }
+ }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+ funcs: Seq[(ChainedPythonFunctions, Long)],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ dataSchema: StructType,
+ initStateSchema: StructType,
+ processorHandle: StatefulProcessorHandleImpl,
+ _timeZoneId: String,
+ initialWorkerConf: Map[String, String],
+ override val pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String],
+ groupingKeySchema: StructType,
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+ funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[GroupedInType] {
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -61,6 +62,8 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
var arrowStreamWriter: BaseStreamingArrowWriter = _
var batchTimestampMs: Option[Long] = _
var eventTimeWatermarkForEviction: Option[Long] = _
+ var initialStateSchema: StructType = StructType(Seq())
Review Comment:
nit: these two are not used.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
- batchTimestampMs: Option[Long] = None,
- eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
- with PythonArrowInput[InType]
- with BasicPythonArrowOutput
- with Logging {
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[InType](
+ funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[InType] {
Review Comment:
nit: shift left 2 spaces
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -536,6 +536,108 @@ def check_results(batch_df, batch_id):
EventTimeStatefulProcessor(), check_results
)
+    def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results):
+ input_path = tempfile.mkdtemp()
+ self._prepare_test_resource1(input_path)
Review Comment:
I see you are covering both cases in this test, which is great!
* grouping key in input, but not in initial state (1)
* grouping key in initial state, but not in input (3)
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
- batchTimestampMs: Option[Long] = None,
- eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
- with PythonArrowInput[InType]
- with BasicPythonArrowOutput
- with Logging {
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[InType](
+ funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[InType] {
- private val sqlConf = SQLConf.get
- private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+ private var pandasWriter: BaseStreamingArrowWriter = _
+
+ override protected def writeNextInputToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator: Iterator[InType]): Boolean = {
+ if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+ }
+
+ if (inputIterator.hasNext) {
+ val startData = dataOut.size()
+ val next = inputIterator.next()
+ val dataIter = next._2
+
+ while (dataIter.hasNext) {
+ val dataRow = dataIter.next()
+ pandasWriter.writeRow(dataRow)
+ }
+ pandasWriter.finalizeCurrentArrowBatch()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
+ true
+ } else {
+ super[PythonArrowInput].close()
+ false
+ }
+ }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+ funcs: Seq[(ChainedPythonFunctions, Long)],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ dataSchema: StructType,
+ initStateSchema: StructType,
+ processorHandle: StatefulProcessorHandleImpl,
+ _timeZoneId: String,
+ initialWorkerConf: Map[String, String],
+ override val pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String],
+ groupingKeySchema: StructType,
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+ funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[GroupedInType] {
+
+ override protected lazy val schema: StructType = new StructType()
+ .add("inputData", dataSchema)
+ .add("initState", initStateSchema)
+
+ private var pandasWriter: BaseStreamingArrowWriter = _
+
+ override protected def writeNextInputToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator:
+ Iterator[GroupedInType]): Boolean = {
+ if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+ }
+
+ if (inputIterator.hasNext) {
+ val startData = dataOut.size()
+ // a new grouping key with data & init state iter
+ val next = inputIterator.next()
+ val dataIter = next._2
+ val initIter = next._3
+
+ while (dataIter.hasNext || initIter.hasNext) {
+ val dataRow =
+ if (dataIter.hasNext) dataIter.next()
+ else InternalRow.empty
+ val initRow =
+ if (initIter.hasNext) initIter.next()
+ else InternalRow.empty
+ pandasWriter.writeRow(InternalRow(dataRow, initRow))
+ }
+ pandasWriter.finalizeCurrentArrowBatch()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
+ true
+ } else {
+ super[PythonArrowInput].close()
+ false
+ }
+ }
+}
+
+/**
+ * Base Python runner implementation for TransformWithStateInPandas.
+ */
+abstract class TransformWithStateInPandasPythonBaseRunner[I](
+ funcs: Seq[(ChainedPythonFunctions, Long)],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ _schema: StructType,
+ processorHandle: StatefulProcessorHandleImpl,
+ _timeZoneId: String,
+ initialWorkerConf: Map[String, String],
+ override val pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String],
+ groupingKeySchema: StructType,
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+  extends BasePythonRunner[I, ColumnarBatch](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
+ with PythonArrowInput[I]
Review Comment:
ditto for all `with` lines
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
- batchTimestampMs: Option[Long] = None,
- eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
- with PythonArrowInput[InType]
- with BasicPythonArrowOutput
- with Logging {
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[InType](
+ funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[InType] {
- private val sqlConf = SQLConf.get
- private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+ private var pandasWriter: BaseStreamingArrowWriter = _
+
+ override protected def writeNextInputToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator: Iterator[InType]): Boolean = {
+ if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+ }
+
+ if (inputIterator.hasNext) {
+ val startData = dataOut.size()
+ val next = inputIterator.next()
+ val dataIter = next._2
+
+ while (dataIter.hasNext) {
+ val dataRow = dataIter.next()
+ pandasWriter.writeRow(dataRow)
+ }
+ pandasWriter.finalizeCurrentArrowBatch()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
+ true
+ } else {
+ super[PythonArrowInput].close()
+ false
+ }
+ }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+ funcs: Seq[(ChainedPythonFunctions, Long)],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ dataSchema: StructType,
+ initStateSchema: StructType,
+ processorHandle: StatefulProcessorHandleImpl,
+ _timeZoneId: String,
+ initialWorkerConf: Map[String, String],
+ override val pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String],
+ groupingKeySchema: StructType,
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+ funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[GroupedInType] {
+
+ override protected lazy val schema: StructType = new StructType()
+ .add("inputData", dataSchema)
+ .add("initState", initStateSchema)
+
+ private var pandasWriter: BaseStreamingArrowWriter = _
+
+ override protected def writeNextInputToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator:
+ Iterator[GroupedInType]): Boolean = {
Review Comment:
nit: shift this up one line (any reason it's placed on the next line?)
##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1190,3 +1190,70 @@ def dump_stream(self, iterator, stream):
"""
result = [(b, t) for x in iterator for y, t in x for b in y]
super().dump_stream(result, stream)
+
+
+class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
+ """
+ Serializer used by Python worker to evaluate UDF for
+    :meth:`pyspark.sql.GroupedData.transformWithStateInPandasInitStateSerializer`.
+ Parameters
+ ----------
+ Same as input parameters in TransformWithStateInPandasSerializer.
+ """
+
+    def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
+ super(TransformWithStateInPandasInitStateSerializer, self).__init__(
+            timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
+ )
+ self.init_key_offsets = None
+
+ def load_stream(self, stream):
+ import pyarrow as pa
+
+ def generate_data_batches(batches):
+ """
+            Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
+            The deserialization logic assumes that Arrow RecordBatches contain the data with the
+            ordering that data chunks for same grouping key will appear sequentially.
+            See `TransformWithStateInPandasPythonBaseRunner` for arrow batch schema sent from JVM.
+            This function flatten the columns of input rows and initial state rows and feed them into
+            the data generator.
+ """
+
+ def flatten_columns(cur_batch, col_name):
+                state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+ state_field_names = [
+                    state_column.type[i].name for i in range(state_column.type.num_fields)
+ ]
+ state_field_arrays = [
+                    state_column.field(i) for i in range(state_column.type.num_fields)
+ ]
+ table_from_fields = pa.Table.from_arrays(
+ state_field_arrays, names=state_field_names
+ )
+ return table_from_fields
+
+ for batch in batches:
Review Comment:
Maybe better to have a brief comment about how the batch is constructed or what its key characteristics are, or at least where to read the code to understand the data structure. Personally, I read this code before reading the part that builds the batch, and had to assume that a batch only ever holds data from a single grouping key, otherwise this won't work.
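As one concrete option, the comment could spell out the assumption along these lines (wording is mine; the JVM-side runner remains the authoritative description of the batch layout):
```python
# Each Arrow RecordBatch sent from the JVM is expected to hold rows for a single
# grouping key only, with two struct columns ("inputData" and "initState")
# carrying the flattened input rows and initial state rows for that key; the
# shorter side is padded with empty structs. See
# TransformWithStateInPandasPythonInitialStateRunner#writeNextInputToArrowStream.
```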
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -168,26 +168,53 @@ case class FlatMapGroupsInPandasWithState(
* methods for new rows in each trigger and the user's state/state variables will be stored
* persistently across invocations.
* @param functionExpr function called on each group
- * @param groupingAttributes used to group the data
+ * @param groupingAttributesLen length of the seq of grouping attributes for input dataframe
* @param outputAttrs used to define the output rows
* @param outputMode defines the output mode for the statefulProcessor
* @param timeMode the time mode semantics of the stateful processor for timers and TTL.
* @param child logical plan of the underlying data
+ * @param initialState logical plan of initial state
+ * @param initGroupingAttrsLen length of the seq of grouping attributes for initial state dataframe
*/
case class TransformWithStateInPandas(
functionExpr: Expression,
- groupingAttributes: Seq[Attribute],
+ groupingAttributesLen: Int,
outputAttrs: Seq[Attribute],
outputMode: OutputMode,
timeMode: TimeMode,
- child: LogicalPlan) extends UnaryNode {
+ child: LogicalPlan,
+ hasInitialState: Boolean,
+ initialState: LogicalPlan,
+ initGroupingAttrsLen: Int,
+ initialStateSchema: StructType) extends BinaryNode {
+ override def left: LogicalPlan = child
+
+ override def right: LogicalPlan = initialState
override def output: Seq[Attribute] = outputAttrs
override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
- override protected def withNewChildInternal(
-    newChild: LogicalPlan): TransformWithStateInPandas = copy(child = newChild)
+ override lazy val references: AttributeSet =
+    AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
+
+ override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPandas =
+ copy(child = newLeft, initialState = newRight)
+
+  // We call the following attributes from `SparkStrategies` because attribute were
+  // resolved when we got there; if we directly pass Seq of attributes before it
+  // is fully analyzed, we will have conflicting attributes when resolving (initial state
+ // and child have the same names on grouping attributes)
+ def leftAttributes: Seq[Attribute] = left.output.take(groupingAttributesLen)
Review Comment:
Also, for the code comment, we don't need to explain here where this method is called from; we would have to update the comment whenever a new caller appears. Let's just describe the requirement from this method's point of view: it needs to be called "after" this node is resolved.
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -551,25 +550,103 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
result = itertools.chain(*result_iter_list)
+ return result
+
+ def transformWithStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ ) -> Iterator["PandasDataFrameLike"]:
+ handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+        if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+ statefulProcessor.init(handle)
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.INITIALIZED
+ )
+
+        result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ return result
+
+ def transformWithStateWithInitStateUDF(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ key: Any,
+ inputRows: Iterator["PandasDataFrameLike"],
+ initialStates: Iterator["PandasDataFrameLike"] = None,
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+        UDF for TWS operator with non-empty initial states. Possible input combinations
+        of inputRows and initialStates iterator:
+        - Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
Review Comment:
ditto for all others
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -51,15 +52,137 @@ class TransformWithStateInPandasPythonRunner(
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
- batchTimestampMs: Option[Long] = None,
- eventTimeWatermarkForEviction: Option[Long] = None)
-  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
- with PythonArrowInput[InType]
- with BasicPythonArrowOutput
- with Logging {
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[InType](
+ funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[InType] {
- private val sqlConf = SQLConf.get
- private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+ private var pandasWriter: BaseStreamingArrowWriter = _
+
+ override protected def writeNextInputToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator: Iterator[InType]): Boolean = {
+ if (pandasWriter == null) {
+      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+ }
+
+ if (inputIterator.hasNext) {
+ val startData = dataOut.size()
+ val next = inputIterator.next()
+ val dataIter = next._2
+
+ while (dataIter.hasNext) {
+ val dataRow = dataIter.next()
+ pandasWriter.writeRow(dataRow)
+ }
+ pandasWriter.finalizeCurrentArrowBatch()
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
+ true
+ } else {
+ super[PythonArrowInput].close()
+ false
+ }
+ }
+}
+
+/**
+ * Python runner with initial state in TransformWithStateInPandas.
+ * Write input data as one InternalRow(inputRow, initialState) in each row in arrow batch.
+ */
+class TransformWithStateInPandasPythonInitialStateRunner(
+ funcs: Seq[(ChainedPythonFunctions, Long)],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ dataSchema: StructType,
+ initStateSchema: StructType,
+ processorHandle: StatefulProcessorHandleImpl,
+ _timeZoneId: String,
+ initialWorkerConf: Map[String, String],
+ override val pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String],
+ groupingKeySchema: StructType,
+ batchTimestampMs: Option[Long],
+ eventTimeWatermarkForEviction: Option[Long],
+ hasInitialState: Boolean)
+ extends TransformWithStateInPandasPythonBaseRunner[GroupedInType](
+ funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
+ initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+ batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
+ with PythonArrowInput[GroupedInType] {
+
+ override protected lazy val schema: StructType = new StructType()
+ .add("inputData", dataSchema)
+ .add("initState", initStateSchema)
+
+ private var pandasWriter: BaseStreamingArrowWriter = _
+
+ override protected def writeNextInputToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator:
+ Iterator[GroupedInType]): Boolean = {
Review Comment:
If the combined line would exceed 100 chars, only `: Boolean = {` should go on this line, indented 2 spaces less than the parameters.