This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 603dc509821 [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
603dc509821 is described below
commit 603dc5098217d9580f611873165d25392f41cdfe
Author: Jungtaek Lim <[email protected]>
AuthorDate: Thu Sep 22 12:35:07 2022 +0900
[SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
### What changes were proposed in this pull request?
This PR proposes to introduce the new API `applyInPandasWithState` in
PySpark, which provides the functionality to perform arbitrary stateful
processing in Structured Streaming.
This will be a pair API with applyInPandas: applyInPandas in PySpark covers the use case of flatMapGroups in the Scala/Java API, while applyInPandasWithState in PySpark covers the use case of flatMapGroupsWithState in the Scala/Java API.
The signature of the API is as follows:
```
# call this function after groupBy
def applyInPandasWithState(
    self,
    func: "PandasGroupedMapFunctionWithState",
    outputStructType: Union[StructType, str],
    stateStructType: Union[StructType, str],
    outputMode: str,
    timeoutConf: str,
) -> DataFrame
```
and the signature of the user function is as follows:
```
def func(
    key: Tuple,
    pdf_iter: Iterator[pandas.DataFrame],
    state: GroupState
) -> Iterator[pandas.DataFrame]
```
(Please refer to the code diff for the function doc of the new function.)
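For illustration, a user function of this shape can be exercised locally. The `_StubState` class below is a hypothetical stand-in for the real `GroupState` (only `getOption` and `update` are mimicked), so this sketch runs without Spark:

```python
from typing import Iterator, Optional, Tuple

import pandas as pd


class _StubState:
    """Hypothetical stand-in for pyspark.sql.streaming.state.GroupState (illustration only)."""

    def __init__(self) -> None:
        self._value: Optional[Tuple] = None

    @property
    def getOption(self) -> Optional[Tuple]:
        # Return the state tuple if it exists, or None.
        return self._value

    def update(self, newValue: Tuple) -> None:
        # State values are Row-compatible tuples.
        self._value = newValue


def count_fn(key, pdf_iter, state):
    # Start from the previously stored count, if any.
    prev = state.getOption
    total = prev[0] if prev is not None else 0
    for pdf in pdf_iter:  # each element is one chunk of the group's data
        total += len(pdf)
    state.update((total,))
    yield pd.DataFrame({"id": [key[0]], "count": [total]})


# Drive the function with two chunks for grouping key (1,):
state = _StubState()
chunks = iter([pd.DataFrame({"x": [1, 2]}), pd.DataFrame({"x": [3]})])
out = list(count_fn((1,), chunks, state))
```

In a real query, a function of this form would instead be passed to `groupBy(...).applyInPandasWithState(...)` together with matching output and state schemas.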
Major design choices which differ from existing APIs:
1. The new API is untyped, while flatMapGroupsWithState is a typed API.
This follows from the nature of the Python language: it is duck-typed, and type definitions are just hints. We don't have a typed API implementation for the PySpark DataFrame.
This leads us to design the API to be untyped, meaning all types for (input, state, output) should be Row-compatible. While we don't require end users to deal with `Row` directly, the model they use for state and output must be convertible to Row with the default encoder. If they want a Python type for the state which is not compatible with Row (e.g. a custom class), they need to pickle it and use BinaryType to store it.
This requires end users to specify the types of state and output via a Spark SQL schema in the method.
Note that this helps ensure compatibility of state data across Spark versions, as long as the encoders for 1) Python type -> Python Row and 2) Python Row -> UnsafeRow are not changed. We won't change the underlying data layout for UnsafeRow, as that would break all existing stateful queries.
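For instance (a sketch; the `SessionInfo` class and the single-field schema `"blob binary"` are illustrative, not part of this PR), a non-Row-compatible object can be carried through the state as one pickled BinaryType field:

```python
import pickle


class SessionInfo:
    """Arbitrary Python object that is not Row-compatible by itself."""

    def __init__(self, count: int, last_event: str) -> None:
        self.count = count
        self.last_event = last_event


# Writing: serialize the object into the single binary field of the state tuple.
# The matching stateStructType would be the DDL string "blob binary" (illustrative).
info = SessionInfo(count=3, last_event="click")
state_tuple = (pickle.dumps(info),)

# Reading: restore the object from the tuple obtained via state.get / state.getOption.
restored = pickle.loads(state_tuple[0])
```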
2. The new API provides a pandas DataFrame to the user function, while flatMapGroupsWithState provides an iterator of rows.
We decided to follow the user experience applyInPandas provides, for both consistency and performance (Arrow batching, vectorization, etc.). This leads us to design the user function to leverage pandas DataFrames rather than an iterator of rows. While this introduces a UX inconsistency with the Scala/Java API, we don't think it will be a problem since pandas is considered the de-facto standard for Python data scientists.
3. The new API provides an iterator of pandas DataFrames to the user function and also requires an iterator of pandas DataFrames to be returned, to address scalability.
applyInPandas has a known scalability limitation: it requires the data in a specific group to fit into memory. During the design phase of the new API, we decided to address this limitation rather than inherit it.
To address the scalability, we tweak the user function to receive an iterator (generator) of pandas DataFrames instead of a single pandas DataFrame, and also to return an iterator (generator) of pandas DataFrames. We don't think this hurts the UX much, as a for-each loop and yield are enough to deal with the iterator.
From the implementation perspective, we split the data in a specific group into multiple chunks, where each chunk is stored and sent as "an" Arrow RecordBatch and finally materialized as "a" pandas DataFrame. This way, as long as end users don't materialize lots of pandas DataFrames from the iterator at the same time, only one chunk is materialized into memory, which is scalable. The same logic applies to the output of the user function, hence that is scalable as well.
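The chunking described above can be sketched without Spark (the chunk size and bookkeeping here are purely illustrative): as long as the consumer walks the iterator one element at a time, only a single chunk is live in memory.

```python
from typing import Iterable, Iterator, List


def chunked(rows: Iterable[int], max_records_per_batch: int) -> Iterator[List[int]]:
    """Split a group's rows into bounded chunks, yielding one chunk at a time."""
    batch: List[int] = []
    for row in rows:
        batch.append(row)
        if len(batch) == max_records_per_batch:
            yield batch
            batch = []
    if batch:  # trailing partial chunk
        yield batch


def user_func(chunk_iter: Iterator[List[int]]) -> Iterator[int]:
    """Consume chunks lazily and yield one result per chunk; never holds all rows at once."""
    for chunk in chunk_iter:
        yield sum(chunk)  # only this chunk is materialized here


partial_sums = list(user_func(chunked(range(10), max_records_per_batch=4)))
```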
4. The new API also bin-packs the data of multiple groups into "an" Arrow RecordBatch.
Given that the API is mainly used for streaming workloads, it is quite likely that the volume of data in a specific group is not large enough to leverage the benefit of Arrow columnar batching, which would hurt performance. To address this, we also do the opposite of what we do for scalability: bin-packing. That is, an Arrow RecordBatch can contain data for multiple groups, as well as a part of the data for a specific group. This addresses both aspects of the concern together, scalabilit [...]
Note that we are not implementing all of the features the Scala/Java API provides in this initial phase; e.g., support for batch queries and support for initial state are left as TODOs.
### Why are the changes needed?
PySpark users don't have a way to perform arbitrary stateful processing in Structured Streaming and are forced to use either Java or Scala, which is unacceptable for users in many cases. This PR enables PySpark users to deal with it without moving to the Java/Scala world.
### Does this PR introduce _any_ user-facing change?
Yes. We are exposing a new public API in PySpark which performs arbitrary stateful processing.
### How was this patch tested?
N/A. We will make sure test suites are constructed in an E2E manner under
[SPARK-40431](https://issues.apache.org/jira/browse/SPARK-40431) - #37894
Closes #37893 from HeartSaVioR/SPARK-40434-on-top-of-SPARK-40433-SPARK-40432.
Lead-authored-by: Jungtaek Lim <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../org/apache/spark/api/python/PythonRunner.scala | 2 +
python/pyspark/rdd.py | 2 +
python/pyspark/sql/pandas/_typing/__init__.pyi | 6 +
python/pyspark/sql/pandas/functions.py | 2 +
python/pyspark/sql/pandas/group_ops.py | 125 +++++++-
python/pyspark/sql/pandas/serializers.py | 355 ++++++++++++++++++++-
python/pyspark/sql/streaming/state.py | 55 +++-
python/pyspark/sql/udf.py | 9 +-
python/pyspark/worker.py | 143 +++++++++
.../analysis/UnsupportedOperationChecker.scala | 62 ++++
.../plans/logical/pythonLogicalOperators.scala | 34 ++
.../spark/sql/RelationalGroupedDataset.scala | 45 +++
.../spark/sql/execution/SparkStrategies.scala | 23 ++
.../spark/sql/execution/arrow/ArrowWriter.scala | 16 +-
.../ApplyInPandasWithStatePythonRunner.scala | 223 +++++++++++++
.../python/ApplyInPandasWithStateWriter.scala | 276 ++++++++++++++++
.../python/FlatMapCoGroupsInPandasExec.scala | 4 +-
.../python/FlatMapGroupsInPandasExec.scala | 2 +-
.../FlatMapGroupsInPandasWithStateExec.scala | 214 +++++++++++++
.../sql/execution/python/PandasGroupUtils.scala | 7 +-
.../sql/execution/python/PythonArrowInput.scala | 1 -
.../execution/streaming/IncrementalExecution.scala | 9 +
.../sql/execution/streaming/state/package.scala | 2 +-
23 files changed, 1599 insertions(+), 18 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 5a13674e8bf..7b31fa93c32 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -53,6 +53,7 @@ private[spark] object PythonEvalType {
val SQL_MAP_PANDAS_ITER_UDF = 205
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
+ val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
def toString(pythonEvalType: Int): String = pythonEvalType match {
case NON_UDF => "NON_UDF"
@@ -65,6 +66,7 @@ private[spark] object PythonEvalType {
case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
+    case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
}
}
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7ef0014ae75..5f4f4d494e1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -105,6 +105,7 @@ if TYPE_CHECKING:
PandasMapIterUDFType,
PandasCogroupedMapUDFType,
ArrowMapIterUDFType,
+ PandasGroupedMapUDFWithStateType,
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import AtomicType, StructType
@@ -147,6 +148,7 @@ class PythonEvalType:
SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205
SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
+    SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208
def portable_hash(x: Hashable) -> int:
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi
index 27ac64a7238..acca8c00f2a 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -30,6 +30,7 @@ from typing_extensions import Protocol, Literal
from types import FunctionType
from pyspark.sql._typing import LiteralType
+from pyspark.sql.streaming.state import GroupState
from pandas.core.frame import DataFrame as PandasDataFrame
from pandas.core.series import Series as PandasSeries
from numpy import ndarray as NDArray
@@ -51,6 +52,7 @@ PandasScalarIterUDFType = Literal[204]
PandasMapIterUDFType = Literal[205]
PandasCogroupedMapUDFType = Literal[206]
ArrowMapIterUDFType = Literal[207]
+PandasGroupedMapUDFWithStateType = Literal[208]
class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_:
...
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
Callable[[Any, DataFrameLike], DataFrameLike],
]
+PandasGroupedMapFunctionWithState = Callable[
+ [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
+]
+
class PandasVariadicGroupedAggFunction(Protocol):
def __call__(self, *_: SeriesLike) -> LiteralType: ...
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index 94fabdbb295..d0f81e2f633 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -369,6 +369,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
None,
]: # None means it should infer the type from type hints.
@@ -402,6 +403,7 @@ def _create_pandas_udf(f, returnType, evalType):
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
]:
        # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
        # at `apply` instead.
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index 6178433573e..0945c0078a2 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -15,18 +15,20 @@
# limitations under the License.
#
import sys
-from typing import List, Union, TYPE_CHECKING
+from typing import List, Union, TYPE_CHECKING, cast
import warnings
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.types import StructType
+from pyspark.sql.streaming.state import GroupStateTimeout
+from pyspark.sql.types import StructType, _parse_datatype_string
if TYPE_CHECKING:
from pyspark.sql.pandas._typing import (
GroupedMapPandasUserDefinedFunction,
PandasGroupedMapFunction,
+ PandasGroupedMapFunctionWithState,
PandasCogroupedMapFunction,
)
from pyspark.sql.group import GroupedData
@@ -216,6 +218,125 @@ class PandasGroupedOpsMixin:
jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
return DataFrame(jdf, self.session)
+ def applyInPandasWithState(
+ self,
+ func: "PandasGroupedMapFunctionWithState",
+ outputStructType: Union[StructType, str],
+ stateStructType: Union[StructType, str],
+ outputMode: str,
+ timeoutConf: str,
+ ) -> DataFrame:
+ """
+ Applies the given function to each group of data, while maintaining a
user-defined
+ per-group state. The result Dataset will represent the flattened
record returned by the
+ function.
+
+ For a streaming Dataset, the function will be invoked first for all
input groups and then
+ for all timed out states where the input data is set to be empty.
Updates to each group's
+ state will be saved across invocations.
+
+ The function should take parameters (key,
Iterator[`pandas.DataFrame`], state) and
+ return another Iterator[`pandas.DataFrame`]. The grouping key(s) will
be passed as a tuple
+ of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The
state will be passed as
+ :class:`pyspark.sql.streaming.state.GroupState`.
+
+ For each group, all columns are passed together as `pandas.DataFrame`
to the user-function,
+ and the returned `pandas.DataFrame` across all invocations are
combined as a
+ :class:`DataFrame`. Note that the user function should not make a
guess of the number of
+ elements in the iterator. To process all data, the user function needs
to iterate all
+ elements and process them. On the other hand, the user function is not
strictly required to
+ iterate through all elements in the iterator if it intends to read a
part of data.
+
+ The `outputStructType` should be a :class:`StructType` describing the
schema of all
+ elements in the returned value, `pandas.DataFrame`. The column labels
of all elements in
+ returned `pandas.DataFrame` must either match the field names in the
defined schema if
+ specified as strings, or match the field data types by position if not
strings,
+ e.g. integer indices.
+
+ The `stateStructType` should be :class:`StructType` describing the
schema of the
+ user-defined state. The value of the state will be presented as a
tuple, as well as the
+ update should be performed with the tuple. The corresponding Python
types for
+ :class:DataType are supported. Please refer to the page
+ https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python
tab).
+
+ The size of each DataFrame in both the input and output can be
arbitrary. The number of
+ DataFrames in both the input and output can also be arbitrary.
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupState`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
+
+ Examples
+ --------
+        >>> import pandas as pd  # doctest: +SKIP
+        >>> from pyspark.sql.streaming.state import GroupStateTimeout
+        >>> def count_fn(key, pdf_iter, state):
+        ...     assert isinstance(state, GroupState)
+        ...     total_len = 0
+        ...     for pdf in pdf_iter:
+        ...         total_len += len(pdf)
+        ...     state.update((total_len,))
+        ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
+        >>> df.groupby("id").applyInPandasWithState(
+        ...     count_fn, outputStructType="id long, countAsString string",
+        ...     stateStructType="len long", outputMode="Update",
+        ...     timeoutConf=GroupStateTimeout.NoTimeout)  # doctest: +SKIP
+
+ Notes
+ -----
+ This function requires a full shuffle.
+
+ This API is experimental.
+ """
+
+ from pyspark.sql import GroupedData
+ from pyspark.sql.functions import pandas_udf
+
+ assert isinstance(self, GroupedData)
+ assert timeoutConf in [
+ GroupStateTimeout.NoTimeout,
+ GroupStateTimeout.ProcessingTimeTimeout,
+ GroupStateTimeout.EventTimeTimeout,
+ ]
+
+        if isinstance(outputStructType, str):
+            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+        if isinstance(stateStructType, str):
+            stateStructType = cast(StructType, _parse_datatype_string(stateStructType))
+
+ udf = pandas_udf(
+ func, # type: ignore[call-overload]
+ returnType=outputStructType,
+ functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+ )
+ df = self._df
+ udf_column = udf(*[df[col] for col in df.columns])
+ jdf = self._jgd.applyInPandasWithState(
+ udf_column._jc.expr(),
+ self.session._jsparkSession.parseDataType(outputStructType.json()),
+ self.session._jsparkSession.parseDataType(stateStructType.json()),
+ outputMode,
+ timeoutConf,
+ )
+ return DataFrame(jdf, self.session)
+
def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
"""
        Cogroups this group with another group so that we can run cogrouped operations.
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 992e82b403a..ca249c75ea5 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -19,7 +19,9 @@
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for
more details.
"""
-from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer
+from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
+from pyspark.sql.pandas.types import to_arrow_type
+from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
class SpecialLengths:
@@ -371,3 +373,354 @@ class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer):
                raise ValueError(
                    "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group)
                )
+
+
+class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
+ """
+    Serializer used by Python worker to evaluate UDF for applyInPandasWithState.
+
+ Parameters
+ ----------
+ timezone : str
+ A timezone to respect when handling timestamp values
+ safecheck : bool
+ If True, conversion from Arrow to Pandas checks for overflow/truncation
+ assign_cols_by_name : bool
+ If True, then Pandas DataFrames will get columns by name
+ state_object_schema : StructType
+ The type of state object represented as Spark SQL type
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+ """
+
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ state_object_schema,
+ arrow_max_records_per_batch,
+ ):
+ super(ApplyInPandasWithStateSerializer, self).__init__(
+ timezone, safecheck, assign_cols_by_name
+ )
+ self.pickleSer = CPickleSerializer()
+ self.utf8_deserializer = UTF8Deserializer()
+ self.state_object_schema = state_object_schema
+
+ self.result_state_df_type = StructType(
+ [
+ StructField("properties", StringType()),
+ StructField("keyRowAsUnsafe", BinaryType()),
+ StructField("object", BinaryType()),
+ StructField("oldTimeoutTimestamp", LongType()),
+ ]
+ )
+
+        self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)
+ self.arrow_max_records_per_batch = arrow_max_records_per_batch
+
+ def load_stream(self, stream):
+ """
+        Read ArrowRecordBatches from stream, deserialize them to populate a list of pair
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer to the doc of the inner function `gen_data_and_state` for more details on how
+        this function works overall.
+
+        In addition, this function further groups the return of `gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an iterator of data
+        chunks for each group, so that the caller can lazily materialize the data chunk.
+ """
+
+ import pyarrow as pa
+ import json
+ from itertools import groupby
+ from pyspark.sql.streaming.state import GroupState
+
+ def construct_state(state_info_col):
+ """
+            Construct state instance from the value of state information column.
+ """
+
+ state_info_col_properties = state_info_col["properties"]
+ state_info_col_key_row = state_info_col["keyRowAsUnsafe"]
+ state_info_col_object = state_info_col["object"]
+
+ state_properties = json.loads(state_info_col_properties)
+ if state_info_col_object:
+ state_object = self.pickleSer.loads(state_info_col_object)
+ else:
+ state_object = None
+ state_properties["optionalValue"] = state_object
+
+ return GroupState(
+ keyAsUnsafe=state_info_col_key_row,
+ valueSchema=self.state_object_schema,
+ **state_properties,
+ )
+
+ def gen_data_and_state(batches):
+ """
+            Deserialize ArrowRecordBatches and return a generator of
+            `(a list of pandas.Series, state)`.
+
+            The logic on deserialization is following:
+
+            1. Read the entire data part from Arrow RecordBatch.
+            2. Read the entire state information part from Arrow RecordBatch.
+            3. Loop through each state information:
+               3.A. Extract the data out from entire data via the information of data range.
+               3.B. Construct a new state instance if the state information is the first occurrence
+                    for the current grouping key.
+               3.C. Leverage the existing state instance if it is already available for the current
+                    grouping key. (Meaning it's not the first occurrence.)
+               3.D. Remove the cache of state instance if the state information denotes the data is
+                    the last chunk for current grouping key.
+
+            This deserialization logic assumes that Arrow RecordBatches contain the data with the
+            ordering that data chunks for same grouping key will appear sequentially.
+
+            This function must avoid materializing multiple Arrow RecordBatches into memory at the
+            same time. And data chunks from the same grouping key should appear sequentially, to
+            further group them based on state instance (same state instance will be produced for
+            same grouping key).
+            """
+
+ state_for_current_group = None
+
+ for batch in batches:
+ batch_schema = batch.schema
+                data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)])
+ state_schema = pa.schema(
+ [
+ batch_schema[-1],
+ ]
+ )
+
+ batch_columns = batch.columns
+ data_columns = batch_columns[0:-1]
+ state_column = batch_columns[-1]
+
+                data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema)
+ state_batch = pa.RecordBatch.from_arrays(
+ [
+ state_column,
+ ],
+ schema=state_schema,
+ )
+
+                state_arrow = pa.Table.from_batches([state_batch]).itercolumns()
+                state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0]
+
+ for state_idx in range(0, len(state_pandas)):
+ state_info_col = state_pandas.iloc[state_idx]
+
+ if not state_info_col:
+ # no more data with grouping key + state
+ break
+
+ data_start_offset = state_info_col["startOffset"]
+ num_data_rows = state_info_col["numRows"]
+ is_last_chunk = state_info_col["isLastChunk"]
+
+                    if state_for_current_group:
+                        # use the state, we already have state for same group and there should be
+                        # some data in same group being processed earlier
+                        state = state_for_current_group
+                    else:
+                        # there is no state being stored for same group, construct one
+                        state = construct_state(state_info_col)
+
+                    if is_last_chunk:
+                        # discard the state being cached for same group
+                        state_for_current_group = None
+                    elif not state_for_current_group:
+                        # there's no cached state but expected to have additional data in same group
+                        # cache the current state
+                        state_for_current_group = state
+
+                    data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows)
+                    data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns()
+
+                    data_pandas = [self.arrow_to_pandas(c) for c in data_arrow]
+
+ # state info
+ yield (
+ data_pandas,
+ state,
+ )
+
+ _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+
+ data_state_generator = gen_data_and_state(_batches)
+
+ # state will be same object for same grouping key
+ for _state, _data in groupby(data_state_generator, key=lambda x: x[1]):
+ yield (
+ _data,
+ _state,
+ )
+
+ def dump_stream(self, iterator, stream):
+ """
+        Read through an iterator of (iterator of pandas DataFrame, state), serialize them to Arrow
+        RecordBatches, and write batches to stream.
+
+ import pandas as pd
+ import pyarrow as pa
+
+ def construct_state_pdf(state):
+ """
+ Construct a pandas DataFrame from the state instance.
+ """
+
+ state_properties = state.json().encode("utf-8")
+ state_key_row_as_binary = state._keyAsUnsafe
+ if state.exists:
+                state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value))
+ else:
+ state_object = None
+ state_old_timeout_timestamp = state.oldTimeoutTimestamp
+
+ state_dict = {
+ "properties": [
+ state_properties,
+ ],
+ "keyRowAsUnsafe": [
+ state_key_row_as_binary,
+ ],
+ "object": [
+ state_object,
+ ],
+ "oldTimeoutTimestamp": [
+ state_old_timeout_timestamp,
+ ],
+ }
+
+ return pd.DataFrame.from_dict(state_dict)
+
+        def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt):
+            """
+            Construct a new Arrow RecordBatch based on output pandas DataFrames and states. Each
+            one matches to the single struct field for Arrow schema, hence the return value of
+            Arrow RecordBatch will have schema with two fields, in `data`, `state` order.
+            (Readers are expected to access the field via position rather than the name. We do
+            not guarantee the name of the field.)
+
+            Note that Arrow RecordBatch requires all columns to have all same number of rows,
+            hence this function inserts empty data for state/data with less elements to compensate.
+            """
+
+ max_data_cnt = max(pdf_data_cnt, state_data_cnt)
+
+ empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt
+ empty_row_cnt_in_state = max_data_cnt - state_data_cnt
+
+ empty_rows_pdf = pd.DataFrame(
+ dict.fromkeys(pa.schema(pdf_schema).names),
+ index=[x for x in range(0, empty_row_cnt_in_data)],
+ )
+ empty_rows_state = pd.DataFrame(
+ columns=["properties", "keyRowAsUnsafe", "object",
"oldTimeoutTimestamp"],
+ index=[x for x in range(0, empty_row_cnt_in_state)],
+ )
+
+ pdfs.append(empty_rows_pdf)
+ state_pdfs.append(empty_rows_state)
+
+ merged_pdf = pd.concat(pdfs, ignore_index=True)
+ merged_state_pdf = pd.concat(state_pdfs, ignore_index=True)
+
+ return self._create_batch(
+                [(merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type)]
+ )
+
+ def serialize_batches():
+ """
+            Read through an iterator of (iterator of pandas DataFrame, state), and serialize them
+            to Arrow RecordBatches.
+
+            This function does batching on constructing the Arrow RecordBatch; a batch will be
+            serialized to the Arrow RecordBatch when the total number of records exceeds the
+            configured threshold.
+            """
+            # a set of variables for the state of current batch which will be converted to Arrow
+            # RecordBatch.
+ pdfs = []
+ state_pdfs = []
+ pdf_data_cnt = 0
+ state_data_cnt = 0
+
+ return_schema = None
+
+ for data in iterator:
+ # data represents the result of each call of user function
+ packaged_result = data[0]
+
+ # There are two results from the call of user function:
+ # 1) iterator of pandas DataFrame (output)
+ # 2) updated state instance
+ pdf_iter = packaged_result[0][0]
+ state = packaged_result[0][1]
+
+ # This is static and won't change across batches.
+ return_schema = packaged_result[1]
+
+ for pdf in pdf_iter:
+ # We ignore empty pandas DataFrame.
+ if len(pdf) > 0:
+ pdf_data_cnt += len(pdf)
+ pdfs.append(pdf)
+
+                        # If the total number of records in current batch exceeds the configured
+                        # threshold, time to construct the Arrow RecordBatch from the batch.
+                        if pdf_data_cnt > self.arrow_max_records_per_batch:
+                            batch = construct_record_batch(
+                                pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt
+                            )
+
+                            # Reset the variables to start with new batch for further data.
+ pdfs = []
+ state_pdfs = []
+ pdf_data_cnt = 0
+ state_data_cnt = 0
+
+ yield batch
+
+                # This has to be performed 'after' evaluating all elements in iterator, so that
+                # the user function has been completed and the state is guaranteed to be updated.
+ state_pdf = construct_state_pdf(state)
+
+ state_pdfs.append(state_pdf)
+ state_data_cnt += 1
+
+ # processed all output, but current batch may not be flushed yet.
+ if pdf_data_cnt > 0 or state_data_cnt > 0:
+ batch = construct_record_batch(
+                    pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt
+ )
+
+ yield batch
+
+ def init_stream_yield_batches(batches):
+ """
+            This function helps to ensure the requirement for Pandas UDFs - Pandas UDFs require a
+            START_ARROW_STREAM before the Arrow stream is sent.
+
+            START_ARROW_STREAM should be sent after creating the first record batch so in case of
+            an error, it can be sent back to the JVM before the Arrow stream starts.
+            """
+ should_write_start_length = True
+
+ for batch in batches:
+ if should_write_start_length:
+ write_int(SpecialLengths.START_ARROW_STREAM, stream)
+ should_write_start_length = False
+
+ yield batch
+
+ batches_to_write = init_stream_yield_batches(serialize_batches())
+
+        return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)
diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py
index 842eff32233..66b225e1b10 100644
--- a/python/pyspark/sql/streaming/state.py
+++ b/python/pyspark/sql/streaming/state.py
@@ -20,16 +20,24 @@ from typing import Tuple, Optional
from pyspark.sql.types import DateType, Row, StructType
-__all__ = ["GroupStateImpl", "GroupStateTimeout"]
+__all__ = ["GroupState", "GroupStateTimeout"]
class GroupStateTimeout:
+ """
+    Represents the type of timeouts possible for the Dataset operations applyInPandasWithState.
+ """
+
NoTimeout: str = "NoTimeout"
ProcessingTimeTimeout: str = "ProcessingTimeTimeout"
EventTimeTimeout: str = "EventTimeTimeout"
-class GroupStateImpl:
+class GroupState:
+ """
+    Wrapper class for interacting with per-group state data in `applyInPandasWithState`.
+ """
+
NO_TIMESTAMP: int = -1
def __init__(
@@ -76,10 +84,16 @@ class GroupStateImpl:
@property
def exists(self) -> bool:
+ """
+ Whether state exists or not.
+ """
return self._defined
@property
def get(self) -> Tuple:
+ """
+ Get the state value if it exists, or throw ValueError.
+ """
if self.exists:
return tuple(self._value)
else:
@@ -87,6 +101,9 @@ class GroupStateImpl:
@property
def getOption(self) -> Optional[Tuple]:
+ """
+ Get the state value if it exists, or return None.
+ """
if self.exists:
return tuple(self._value)
else:
@@ -94,6 +111,10 @@ class GroupStateImpl:
@property
def hasTimedOut(self) -> bool:
+ """
+ Whether the function has been called because the key has timed out.
+ This can return true only when timeouts are enabled.
+ """
return self._has_timed_out
    # NOTE: this function is only available to PySpark implementation due to underlying
@@ -103,6 +124,9 @@ class GroupStateImpl:
return self._old_timeout_timestamp
def update(self, newValue: Tuple) -> None:
+ """
+ Update the value of the state. The value of the state cannot be null.
+ """
if newValue is None:
raise ValueError("'None' is not a valid state value")
@@ -112,11 +136,18 @@ class GroupStateImpl:
self._removed = False
def remove(self) -> None:
+ """
+ Remove this state.
+ """
self._defined = False
self._updated = False
self._removed = True
def setTimeoutDuration(self, durationMs: int) -> None:
+ """
+ Set the timeout duration in ms for this key.
+ Processing time timeout must be enabled.
+ """
if isinstance(durationMs, str):
# TODO(SPARK-40437): Support string representation of durationMs.
raise ValueError("durationMs should be int but get :%s" % type(durationMs))
@@ -133,6 +164,11 @@ class GroupStateImpl:
# TODO(SPARK-40438): Implement additionalDuration parameter.
def setTimeoutTimestamp(self, timestampMs: int) -> None:
+ """
+ Set the timeout timestamp for this key as milliseconds in epoch time.
+ This timestamp cannot be older than the current watermark.
+ Event time timeout must be enabled.
+ """
if self._timeout_conf != GroupStateTimeout.EventTimeTimeout:
raise RuntimeError(
"Cannot set timeout duration without enabling processing time timeout in "
@@ -146,7 +182,7 @@ class GroupStateImpl:
raise ValueError("Timeout timestamp must be positive")
if (
- self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP
+ self._event_time_watermark_ms != GroupState.NO_TIMESTAMP
and timestampMs < self._event_time_watermark_ms
):
raise ValueError(
@@ -157,6 +193,10 @@ class GroupStateImpl:
self._timeout_timestamp = timestampMs
def getCurrentWatermarkMs(self) -> int:
+ """
+ Get the current event time watermark as milliseconds in epoch time.
+ In a streaming query, this can be called only when watermark is set.
+ """
if not self._watermark_present:
raise RuntimeError(
"Cannot get event time watermark timestamp without setting watermark before "
@@ -165,6 +205,11 @@ class GroupStateImpl:
return self._event_time_watermark_ms
def getCurrentProcessingTimeMs(self) -> int:
+ """
+ Get the current processing time as milliseconds in epoch time.
+ In a streaming query, this will return a constant value throughout the duration of a
+ trigger, even if the trigger is re-executed.
+ """
return self._batch_processing_time_ms
def __str__(self) -> str:
@@ -174,6 +219,10 @@ class GroupStateImpl:
return "GroupState(<undefined>)"
def json(self) -> str:
+ """
+ Convert the internal values of the instance into JSON. This is used to send out the update
+ from Python worker to executor.
+ """
return json.dumps(
{
# Constructor
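For context, a user function driving this GroupState API typically reads the previous value, folds in the new batches, and writes the state back. The sketch below is illustrative only: the `count_events` function, the `MockState` stand-in, and the single-column layouts are assumptions for demonstration, not part of this commit. The mock mimics just the `exists`/`getOption`/`update` subset so the control flow can be followed without a running Spark cluster:

```python
from typing import Iterator, Optional, Tuple
import pandas as pd


class MockState:
    """Hypothetical stand-in mimicking the subset of GroupState used below."""

    def __init__(self) -> None:
        self._value: Optional[Tuple] = None

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def getOption(self) -> Optional[Tuple]:
        return self._value

    def update(self, new_value: Tuple) -> None:
        if new_value is None:
            raise ValueError("'None' is not a valid state value")
        self._value = new_value


def count_events(
    key: Tuple, pdf_iter: Iterator[pd.DataFrame], state: MockState
) -> Iterator[pd.DataFrame]:
    # Start from the previously stored count; state is a Row-compatible tuple.
    (count,) = state.getOption if state.exists else (0,)
    for pdf in pdf_iter:
        count += len(pdf)
    state.update((count,))
    # Yield one output row per group, matching the declared output schema.
    yield pd.DataFrame({"id": [key[0]], "count": [count]})
```

Under Spark, the same `count_events` would be handed to `groupBy(...).applyInPandasWithState(...)` with matching output and state schemas; only the state object differs.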
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 6a01e399d04..da9a245bb71 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -144,20 +144,23 @@ class UserDefinedFunction:
"Invalid return type with scalar Pandas UDFs: %s is "
"not supported" % str(self._returnType_placeholder)
)
- elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+ elif (
+ self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+ or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE
+ ):
if isinstance(self._returnType_placeholder, StructType):
try:
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid return type with grouped map Pandas UDFs or "
- "at groupby.applyInPandas: %s is not supported"
+ "at groupby.applyInPandas(WithState): %s is not supported"
% str(self._returnType_placeholder)
)
else:
raise TypeError(
"Invalid return type for grouped map Pandas "
- "UDFs or at groupby.applyInPandas: return type must be a "
+ "UDFs or at groupby.applyInPandas(WithState): return type must be a "
"StructType."
)
elif (
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c486b7bed1d..c1c3669701f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,6 +23,7 @@ import sys
import time
from inspect import currentframe, getframeinfo, getfullargspec
import importlib
+import json
# 'resource' is a Unix specific module.
has_resource_module = True
@@ -57,6 +58,7 @@ from pyspark.sql.pandas.serializers import (
ArrowStreamPandasUDFSerializer,
CogroupUDFSerializer,
ArrowStreamUDFSerializer,
+ ApplyInPandasWithStateSerializer,
)
from pyspark.sql.pandas.types import to_arrow_type
from pyspark.sql.types import StructType
@@ -207,6 +209,90 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec):
return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+ """
+ Provides a new lambda instance wrapping user function of applyInPandasWithState.
+
+ The lambda instance receives (key series, iterator of value series, state) and performs
+ some conversion to adapt to the signature of the user function.
+
+ See the function doc of the inner function `wrapped` for more details on what the adapter does.
+ See the function doc of the `mapper` function for
+ `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on
+ the input parameters of the lambda function.
+
+ Along with the returned iterator, the lambda instance will also produce the return_type as
+ converted to the arrow schema.
+ """
+
+ def wrapped(key_series, value_series_gen, state):
+ """
+ Provide an adapter of the user function performing below:
+
+ - Extract the first value of all columns in key series and produce it as a tuple.
+ - If the state has timed out, call the user function with an empty pandas DataFrame.
+ - If not, construct a new generator which converts each element of value series to a
+ pandas DataFrame (lazy evaluation), and call the user function with the generator.
+ - Verify each element of the returned iterator to check the schema of pandas DataFrame.
+ """
+ import pandas as pd
+
+ key = tuple(s[0] for s in key_series)
+
+ if state.hasTimedOut:
+ # Timeout processing passes an empty iterator. Here we return an empty DataFrame instead.
+ values = [
+ pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+ ]
+ else:
+ values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+ result_iter = f(key, values, state)
+
+ def verify_element(result):
+ if not isinstance(result, pd.DataFrame):
+ raise TypeError(
+ "The type of element in return iterator of the user-defined function "
+ "should be pandas.DataFrame, but is {}".format(type(result))
+ )
+ # the number of columns of result have to match the return type
+ # but it is fine for result to have no columns at all if it is empty
+ if not (
+ len(result.columns) == len(return_type)
+ or (len(result.columns) == 0 and result.empty)
+ ):
+ raise RuntimeError(
+ "Number of columns of the element (pandas.DataFrame) in return iterator "
+ "doesn't match specified schema. "
+ "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
+ )
+
+ return result
+
+ if isinstance(result_iter, pd.DataFrame):
+ raise TypeError(
+ "Return type of the user-defined function should be "
+ "iterable of pandas.DataFrame, but is {}".format(type(result_iter))
+ )
+
+ try:
+ iter(result_iter)
+ except TypeError:
+ raise TypeError(
+ "Return type of the user-defined function should be "
+ "iterable, but is {}".format(type(result_iter))
+ )
+
+ result_iter_with_validation = (verify_element(x) for x in result_iter)
+
+ return (
+ result_iter_with_validation,
+ state,
+ )
+
+ return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))]
+
+
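The `verify_element` wrapping above, a generator that checks each element only when it is consumed, is a reusable pattern. A minimal stdlib-only sketch (the names `validated` and `check_width` are illustrative, not from the patch):

```python
from typing import Callable, Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def validated(it: Iterable[T], check: Callable[[T], None]) -> Iterator[T]:
    """Yield elements of `it` lazily, running `check` on each before handing it out."""
    for elem in it:
        check(elem)  # raises on the first bad element, at consumption time
        yield elem


def check_width(row: Tuple, expected: int = 2) -> None:
    # Analogous to the column-count check against the declared return schema.
    if len(row) != expected:
        raise RuntimeError(
            "Number of columns doesn't match specified schema. "
            "Expected: {} Actual: {}".format(expected, len(row))
        )
```

As with the generator expression in `wrap_grouped_map_pandas_udf_with_state`, nothing is validated (or even produced by the user function) until downstream code iterates.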
def wrap_grouped_agg_pandas_udf(f, return_type):
arrow_return_type = to_arrow_type(return_type)
@@ -311,6 +397,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+ return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
@@ -336,6 +424,7 @@ def read_udfs(pickleSer, infile, eval_type):
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
):
# Load conf used for pandas_udf evaluation
@@ -345,6 +434,10 @@ def read_udfs(pickleSer, infile, eval_type):
v = utf8_deserializer.loads(infile)
runner_conf[k] = v
+ state_object_schema = None
+ if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+ state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
+
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
timezone = runner_conf.get("spark.sql.session.timeZone", None)
safecheck = (
@@ -361,6 +454,19 @@ def read_udfs(pickleSer, infile, eval_type):
if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+ arrow_max_records_per_batch = runner_conf.get(
+ "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
+ )
+ arrow_max_records_per_batch = int(arrow_max_records_per_batch)
+
+ ser = ApplyInPandasWithStateSerializer(
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ state_object_schema,
+ arrow_max_records_per_batch,
+ )
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
ser = ArrowStreamUDFSerializer()
else:
@@ -486,6 +592,43 @@ def read_udfs(pickleSer, infile, eval_type):
vals = [a[o] for o in parsed_offsets[0][1]]
return f(keys, vals)
+ elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+ # We assume there is only one UDF here because grouped map doesn't
+ # support combining multiple UDFs.
+ assert num_udfs == 1
+
+ # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to
+ # distinguish between grouping attributes and data attributes
+ arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+ parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+ def mapper(a):
+ """
+ The function receives (iterator of data, state) and performs extraction of key and
+ value from the data, while retaining lazy evaluation.
+
+ See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input
+ and see `wrap_grouped_map_pandas_udf_with_state` for more details on how the output will
+ be used.
+ """
+ from itertools import tee
+
+ state = a[1]
+ data_gen = (x[0] for x in a[0])
+
+ # We know there should be at least one item in the iterator/generator.
+ # We want to peek the first element to construct the key, hence applying
+ # tee to construct the key while we retain another iterator/generator
+ # for values.
+ keys_gen, values_gen = tee(data_gen)
+ keys_elem = next(keys_gen)
+ keys = [keys_elem[o] for o in parsed_offsets[0][0]]
+
+ # This must be generator comprehension - do not materialize.
+ vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen)
+
+ return f(keys, vals, state)
+
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
# We assume there is only one UDF here because cogrouped map doesn't
# support combining multiple UDFs.
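The tee-based peek in the `mapper` above can be isolated into a small helper. This is a hedged sketch (the helper name and the flat offset layout are assumptions) showing why `itertools.tee` lets the key be read from the first element without losing that element for the value iterator:

```python
from itertools import tee
from typing import Any, Iterator, List, Sequence, Tuple


def peek_key(
    data_gen: Iterator[Sequence[Any]], key_offsets: Sequence[int]
) -> Tuple[List[Any], Iterator[Sequence[Any]]]:
    """Peek the first element to build the key while keeping a full iterator of values."""
    keys_gen, values_gen = tee(data_gen)
    first = next(keys_gen)  # consumes only the tee'd copy
    keys = [first[o] for o in key_offsets]
    return keys, values_gen  # values_gen still yields every element, including the first
```

Note that `tee` buffers whatever one copy has consumed ahead of the other, so only the single peeked element is held in memory here; the rest of the stream stays lazy.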
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index c11ce7d3b90..84795203fd1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -64,6 +64,7 @@ object UnsupportedOperationChecker extends Logging {
case s: Aggregate if s.isStreaming => true
case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true
case f: FlatMapGroupsWithState if f.isStreaming => true
+ case f: FlatMapGroupsInPandasWithState if f.isStreaming => true
case d: Deduplicate if d.isStreaming => true
case _ => false
}
@@ -142,6 +143,17 @@ object UnsupportedOperationChecker extends Logging {
" or the output mode is not append on a streaming DataFrames/Datasets")(plan)
}
+ val applyInPandasWithStates = plan.collect {
+ case f: FlatMapGroupsInPandasWithState if f.isStreaming => f
+ }
+
+ // Disallow multiple `applyInPandasWithState`s.
+ if (applyInPandasWithStates.size > 1) {
+ throwError(
+ "Multiple applyInPandasWithStates are not supported on a streaming " +
+ "DataFrames/Datasets")(plan)
+ }
+
// Disallow multiple streaming aggregations
val aggregates = collectStreamingAggregates(plan)
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
}
}
+ // applyInPandasWithState
+ case m: FlatMapGroupsInPandasWithState if m.isStreaming =>
+ // Check compatibility with output modes and aggregations in query
+ val aggsInQuery = collectStreamingAggregates(plan)
+
+ if (aggsInQuery.isEmpty) {
+ // applyInPandasWithState without aggregation: operation's output mode must
+ // match query output mode
+ m.outputMode match {
+ case InternalOutputModes.Update if outputMode != InternalOutputModes.Update =>
+ throwError(
+ "applyInPandasWithState in update mode is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+
+ case InternalOutputModes.Append if outputMode != InternalOutputModes.Append =>
+ throwError(
+ "applyInPandasWithState in append mode is not supported with " +
+ s"$outputMode output mode on a streaming DataFrame/Dataset")
+
+ case _ =>
+ }
+ } else {
+ // applyInPandasWithState with aggregation: update operation mode not allowed, and
+ // *groupsWithState after aggregation not allowed
+ if (m.outputMode == InternalOutputModes.Update) {
+ throwError(
+ "applyInPandasWithState in update mode is not supported with " +
+ "aggregation on a streaming DataFrame/Dataset")
+ } else if (collectStreamingAggregates(m).nonEmpty) {
+ throwError(
+ "applyInPandasWithState in append mode is not supported after " +
+ "aggregation on a streaming DataFrame/Dataset")
+ }
+ }
+
+ // Check compatibility with timeout configs
+ if (m.timeout == EventTimeTimeout) {
+ // With event time timeout, watermark must be defined.
+ val watermarkAttributes = m.child.output.collect {
+ case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a
+ }
+ if (watermarkAttributes.isEmpty) {
+ throwError(
+ "Watermark must be specified in the query using " +
+ "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " +
+ "applyInPandasWithState. Event-time timeout not supported without " +
+ "watermark.")(plan)
+ }
+ }
+
case d: Deduplicate if collectStreamingAggregates(d).nonEmpty =>
throwError("dropDuplicates is not supported after aggregation on a " +
"streaming DataFrame/Dataset")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index c2f74b35083..e97ff7808f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF}
import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
/**
* FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
@@ -98,6 +100,38 @@ case class FlatMapCoGroupsInPandas(
copy(left = newLeft, right = newRight)
}
+/**
+ * Similar to [[FlatMapGroupsWithState]]. Applies func to each unique group
+ * in `child`, based on the evaluation of `groupingAttributes`,
+ * while using state data.
+ * `functionExpr` is invoked with a pandas DataFrame representation and the
+ * grouping key (tuple).
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outputAttrs used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param outputMode the output mode of `func`
+ * @param timeout used to timeout groups that have not received data in a while
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithState(
+ functionExpr: Expression,
+ groupingAttributes: Seq[Attribute],
+ outputAttrs: Seq[Attribute],
+ stateType: StructType,
+ outputMode: OutputMode,
+ timeout: GroupStateTimeout,
+ child: LogicalPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = outputAttrs
+
+ override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
+
+ override protected def withNewChildInternal(
+ newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild)
+}
+
trait BaseEvalPython extends UnaryNode {
def udfs: Seq[PythonUDF]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 989ee325218..0429fd27a41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -30,9 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{NumericType, StructType}
/**
@@ -620,6 +622,49 @@ class RelationalGroupedDataset protected[sql](
Dataset.ofRows(df.sparkSession, plan)
}
+ /**
+ * Applies a grouped vectorized python user-defined function to each group of data.
+ * The user-defined function defines a transformation: iterator of `pandas.DataFrame` ->
+ * iterator of `pandas.DataFrame`.
+ * For each group, all elements in the group are passed as an iterator of `pandas.DataFrame`
+ * along with corresponding state, and the results for all groups are combined into a new
+ * [[DataFrame]].
+ *
+ * This function does not support partial aggregation, and requires shuffling all the data in
+ * the [[DataFrame]].
+ *
+ * This function uses Apache Arrow as serialization format between Java executors and Python
+ * workers.
+ */
+ private[sql] def applyInPandasWithState(
+ func: PythonUDF,
+ outputStructType: StructType,
+ stateStructType: StructType,
+ outputModeStr: String,
+ timeoutConfStr: String): DataFrame = {
+ val timeoutConf = org.apache.spark.sql.execution.streaming
+ .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr)
+ val outputMode = InternalOutputModes(outputModeStr)
+ if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) {
+ throw new IllegalArgumentException("The output mode of function should be append or update")
+ }
+ val groupingNamedExpressions = groupingExprs.map {
+ case ne: NamedExpression => ne
+ case other => Alias(other, other.toString)()
+ }
+ val groupingAttrs = groupingNamedExpressions.map(_.toAttribute)
+ val outputAttrs = outputStructType.toAttributes
+ val plan = FlatMapGroupsInPandasWithState(
+ func,
+ groupingAttrs,
+ outputAttrs,
+ stateStructType,
+ outputMode,
+ timeoutConf,
+ child = df.logicalPlan)
+ Dataset.ofRows(df.sparkSession, plan)
+ }
+
override def toString: String = {
val builder = new StringBuilder
builder.append("RelationalGroupedDataset: [grouping expressions: [")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6104104c7be..c64a123e3a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -684,6 +684,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ /**
+ * Strategy to convert [[FlatMapGroupsInPandasWithState]] logical operator to physical operator
+ * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
+ */
+ object FlatMapGroupsInPandasWithStateStrategy extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case FlatMapGroupsInPandasWithState(
+ func, groupAttr, outputAttr, stateType, outputMode, timeout, child) =>
+ val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
+ val execPlan = python.FlatMapGroupsInPandasWithStateExec(
+ func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout,
+ batchTimestampMs = None, eventTimeWatermark = None, planLater(child)
+ )
+ execPlan :: Nil
+ case _ =>
+ Nil
+ }
+ }
+
/**
* Strategy to convert EvalPython logical operator to physical operator.
*/
@@ -793,6 +812,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
hasInitialState, planLater(initialState), planLater(child)
) :: Nil
+ case _: FlatMapGroupsInPandasWithState =>
+ // TODO(SPARK-40443): support applyInPandasWithState in batch query
+ throw new UnsupportedOperationException(
+ "applyInPandasWithState is unsupported in batch query. Use applyInPandas instead.")
case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroupExec(
f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 7abca5f0e33..2988c0fb518 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -44,7 +44,7 @@ object ArrowWriter {
new ArrowWriter(root, children.toArray)
}
- private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+ private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
val field = vector.getField()
(ArrowUtils.fromArrowField(field), vector) match {
case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
@@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
count += 1
}
+ def sizeInBytes(): Int = {
+ var i = 0
+ var bytes = 0
+ while (i < fields.size) {
+ bytes += fields(i).getSizeInBytes()
+ i += 1
+ }
+ bytes
+ }
+
def finish(): Unit = {
root.setRowCount(count)
fields.foreach(_.finish())
@@ -132,6 +142,10 @@ private[arrow] abstract class ArrowFieldWriter {
count += 1
}
+ def getSizeInBytes(): Int = {
+ valueVector.getBufferSizeFor(count)
+ }
+
def finish(): Unit = {
valueVector.setValueCount(count)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
new file mode 100644
index 00000000000..bd8c72029dc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ *
+ * Unlike normal ArrowPythonRunner, where both input and output (executor <-> python worker)
+ * are InternalRow, applyInPandasWithState has side data (state information) in both input
+ * and output along with data, which requires a different struct on Arrow RecordBatch.
+ */
+class ApplyInPandasWithStatePythonRunner(
+ funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ inputSchema: StructType,
+ override protected val timeZoneId: String,
+ initialWorkerConf: Map[String, String],
+ stateEncoder: ExpressionEncoder[Row],
+ keySchema: StructType,
+ outputSchema: StructType,
+ stateValueSchema: StructType)
+ extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+ with PythonArrowInput[InType]
+ with PythonArrowOutput[OutType] {
+
+ private val sqlConf = SQLConf.get
+
+ override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+ override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback
+
+ override val bufferSize: Int = {
+ val configuredSize = sqlConf.pandasUDFBufferSize
+ if (configuredSize < 4) {
+ logWarning("Pandas execution requires more than 4 bytes. Please configure bigger value " +
+ s"for the configuration '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'. " +
+ "Force using the value '4'.")
+ 4
+ } else {
+ configuredSize
+ }
+ }
+
+ private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+
+ // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
+ // Configurations are both applied to executor and Python worker, set them to the worker conf
+ // to let Python worker read the config properly.
+ override protected val workerConf: Map[String, String] = initialWorkerConf +
+ (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
+
+ private val stateRowDeserializer = stateEncoder.createDeserializer()
+
+ /**
+ * This method sends out the additional metadata before sending out actual data.
+ *
+ * Specifically, this class overrides this method to also write the schema for state value.
+ */
+ override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+ super.handleMetadataBeforeExec(stream)
+ // Also write the schema for state value
+ PythonRDD.writeUTF(stateValueSchema.json, stream)
+ }
+
+ /**
+ * Read the (key, state, values) from input iterator and construct Arrow RecordBatches, and
+ * write constructed RecordBatches to the writer.
+ *
+ * See [[ApplyInPandasWithStateWriter]] for more details.
+ */
+ protected def writeIteratorToArrowStream(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ dataOut: DataOutputStream,
+ inputIterator: Iterator[InType]): Unit = {
+ val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
+
+ while (inputIterator.hasNext) {
+ val (keyRow, groupState, dataIter) = inputIterator.next()
+ assert(dataIter.hasNext, "should have at least one data row!")
+ w.startNewGroup(keyRow, groupState)
+
+ while (dataIter.hasNext) {
+ val dataRow = dataIter.next()
+ w.writeRow(dataRow)
+ }
+
+ w.finalizeGroup()
+ }
+
+ w.finalizeData()
+ }
+
+ /**
+ * Deserialize ColumnarBatch received from the Python worker to produce the output. Schema info
+ * for the given ColumnarBatch is provided as well.
+ */
+ protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = {
+ // This should at least have one row for state. Also, we ensure that all columns across
+ // data and state metadata have same number of rows, which is required by Arrow record
+ // batch.
+ assert(batch.numRows() > 0)
+ assert(schema.length == 2)
+
+ def getColumnarBatchForStructTypeColumn(
+ batch: ColumnarBatch,
+ ordinal: Int,
+ expectedType: StructType): ColumnarBatch = {
+ // UDF returns a StructType column in ColumnarBatch, select the children here
+ val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector]
+ val dataType = schema(ordinal).dataType.asInstanceOf[StructType]
+ assert(dataType.sameType(expectedType),
+ s"Schema equality check failure! type from Arrow: $dataType, expected type: $expectedType")
+
+ val outputVectors = dataType.indices.map(structVector.getChild)
+ val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
+ flattenedBatch.setNumRows(batch.numRows())
+
+ flattenedBatch
+ }
+
+ def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = {
+ val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, outputSchema)
+ dataBatch.rowIterator.asScala.flatMap { row =>
+ if (row.isNullAt(0)) {
+ // The entire row in record batch seems to be for state metadata.
+ None
+ } else {
+ Some(row)
+ }
+ }
+ }
+
+ def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = {
+ val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1,
+ STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER)
+
+ stateMetadataBatch.rowIterator().asScala.flatMap { row =>
+ implicit val formats = org.json4s.DefaultFormats
+
+ if (row.isNullAt(0)) {
+ // The entire row in record batch seems to be for data.
+ None
+ } else {
+ // NOTE: See ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER
+ // for the schema.
+ val propertiesAsJson = parse(row.getUTF8String(0).toString)
+ val keyRowAsUnsafeAsBinary = row.getBinary(1)
+ val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length)
+ keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length)
+ val maybeObjectRow = if (row.isNullAt(2)) {
+ None
+ } else {
+ val pickledStateValue = row.getBinary(2)
+ Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema,
+ stateRowDeserializer))
+ }
+ val oldTimeoutTimestamp = row.getLong(3)
+
+ Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson),
+ oldTimeoutTimestamp))
+ }
+ }
+ }
+
+ (constructIterForState(batch), constructIterForData(batch))
+ }
+}
+
+object ApplyInPandasWithStatePythonRunner {
+ type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow])
+ type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long)
+ type OutType = (Iterator[OutTypeForState], Iterator[InternalRow])
+
+ val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType(
+ Array(
+ StructField("properties", StringType),
+ StructField("keyRowAsUnsafe", BinaryType),
+ StructField("object", BinaryType),
+ StructField("oldTimeoutTimestamp", LongType)
+ )
+ )
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
new file mode 100644
index 00000000000..60a228ddd73
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow,
UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType,
StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity of constructing Arrow RecordBatches for data and state
+ * with bin-packing and chunking. The caller only needs to call the public methods of this
+ * class, `startNewGroup`, `writeRow`, `finalizeGroup`, and `finalizeData`, and this class will
+ * write the data and state into Arrow RecordBatches, performing bin-packing and chunking
+ * internally.
+ *
+ * This class requires that the parameter `root` has been initialized with an Arrow schema of
+ * the following form:
+ * - data fields
+ * - state field
+ *   - nested schema (refer to ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer to the code comments in the implementation to see how the writes of data and
+ * state into the Arrow RecordBatch work, with bin-packing and chunking taken into account.
+ */
+class ApplyInPandasWithStateWriter(
+ root: VectorSchemaRoot,
+ writer: ArrowStreamWriter,
+ arrowMaxRecordsPerBatch: Int) {
+
+ import ApplyInPandasWithStateWriter._
+
+ // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState requires
+ // producing the additional data `state`, along with the input data.
+ //
+ // ArrowStreamWriter supports only a single VectorSchemaRoot, which means all Arrow
+ // RecordBatches sent out from ArrowStreamWriter must have the same schema. Therefore, we
+ // have to construct a single Arrow schema containing both data and state, and construct
+ // Arrow RecordBatches containing both as well.
+ //
+ // To achieve this, we extend the schema for the input data with a column for state at the
+ // end. We also logically group the columns by family (data vs state) and initialize the
+ // writers separately, since it is much easier, and likely more performant, to write each
+ // row directly rather than projecting the row to match up with the overall schema.
+ //
+ // Although Arrow RecordBatch lets us write the data in columnar form, we found that it
+ // produces incorrect outputs unless all columns have the same number of values. Since there
+ // is at least one data row per grouping key (we ensure this for the case of handling timed
+ // out state as well) whereas there is only one state per grouping key, we have to fill in
+ // the empty rows on the state side to ensure both have the same number of rows.
+ private val arrowWriterForData = createArrowWriter(
+ root.getFieldVectors.asScala.toSeq.dropRight(1))
+ private val arrowWriterForState = createArrowWriter(
+ root.getFieldVectors.asScala.toSeq.takeRight(1))
+
+ // - Bin-packing
+ //
+ // We bin-pack the data from multiple groups into one Arrow RecordBatch to gain
+ // performance. In many cases, the amount of data per grouping key is quite small, which
+ // on its own would not maximize the benefits of using Arrow.
+ //
+ // We have to split the record batch back down into groups in the Python worker to convert
+ // each group's data to Pandas, and fortunately Arrow RecordBatch provides a way to slice a
+ // range of data as a zero-copy view. To help split out each range of data, we provide the
+ // "start offset" and the "number of rows" in the state metadata.
+ //
+ // We don't bin-pack all groups into a single record batch - we use a limit on the number
+ // of rows in the current Arrow RecordBatch to stop adding the next group.
+ //
+ // - Chunking
+ //
+ // We also chunk the data from a single group into multiple Arrow RecordBatches to ensure
+ // scalability. Note that we don't know the volume (number of rows, overall size) of the
+ // data for a specific group key before reading the entire data. The easiest approach to
+ // address both bin-packing and chunking is to check the number of rows in the current
+ // Arrow RecordBatch on each write of a row.
+ //
+ // - Data and State
+ //
+ // Since we apply bin-packing and chunking, we need a way to distinguish each chunk within
+ // the data part of the Arrow RecordBatch. We leverage the state metadata to also carry the
+ // "metadata" of the data part, distinguishing each chunk from the entire data. As a
+ // result, state metadata has a 1-1 relationship with a "chunk", instead of a "grouping
+ // key".
+ //
+ // - Consideration
+ //
+ // Since the number of rows in an Arrow RecordBatch does not represent its actual size
+ // (bytes), the limit should be set conservatively. Using a small limit does not introduce
+ // correctness issues.
+
+ // variables for tracking current grouping key and state
+ private var currentGroupKeyRow: UnsafeRow = _
+ private var currentGroupState: GroupStateImpl[Row] = _
+
+ // variables for tracking the status of current batch
+ private var totalNumRowsForBatch = 0
+ private var totalNumStatesForBatch = 0
+
+ // variables for tracking the status of current chunk
+ private var startOffsetForCurrentChunk = 0
+ private var numRowsForCurrentChunk = 0
+
+
+ /**
+ * Indicates to the writer that a new grouping key is starting.
+ *
+ * @param keyRow The grouping key row for the current group.
+ * @param groupState The instance of GroupStateImpl for the current group.
+ */
+ def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit
= {
+ currentGroupKeyRow = keyRow
+ currentGroupState = groupState
+ }
+
+ /**
+ * Writes a row in the current group.
+ *
+ * @param dataRow The row to write in the current group.
+ */
+ def writeRow(dataRow: InternalRow): Unit = {
+ // If the current batch has reached its limit (number of records) and there is more data
+ // for the same group, finalize the current chunk and batch, and start a new batch.
+
+ if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+ finalizeCurrentChunk(isLastChunkForGroup = false)
+ finalizeCurrentArrowBatch()
+ }
+
+ arrowWriterForData.write(dataRow)
+
+ numRowsForCurrentChunk += 1
+ totalNumRowsForBatch += 1
+ }
+
+ /**
+ * Indicates to the writer that the current group is finalized and no further rows will be
+ * bound to it.
+ */
+ def finalizeGroup(): Unit = {
+ finalizeCurrentChunk(isLastChunkForGroup = true)
+
+ // If the batch has reached its limit (number of records) once all the data is received
+ // for the same group, finalize the batch and start a new one.
+ if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+ finalizeCurrentArrowBatch()
+ }
+ }
+
+ /**
+ * Indicates to the writer that all groups have been processed.
+ */
+ def finalizeData(): Unit = {
+ if (totalNumRowsForBatch > 0) {
+ // We still have some rows in the current record batch. Need to finalize
them as well.
+ finalizeCurrentArrowBatch()
+ }
+ }
+
+ private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter =
{
+ val children = fieldVectors.map { vector =>
+ vector.allocateNew()
+ createFieldWriter(vector)
+ }
+
+ new ArrowWriter(root, children.toArray)
+ }
+
+ private def buildStateInfoRow(
+ keyRow: UnsafeRow,
+ groupState: GroupStateImpl[Row],
+ startOffset: Int,
+ numRows: Int,
+ isLastChunk: Boolean): InternalRow = {
+ // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+ val stateUnderlyingRow = new GenericInternalRow(
+ Array[Any](
+ UTF8String.fromString(groupState.json()),
+ keyRow.getBytes,
+ groupState.getOption.map(PythonSQLUtils.toPyRow).orNull,
+ startOffset,
+ numRows,
+ isLastChunk
+ )
+ )
+ new GenericInternalRow(Array[Any](stateUnderlyingRow))
+ }
+
+ private def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
+ val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+ startOffsetForCurrentChunk, numRowsForCurrentChunk, isLastChunkForGroup)
+ arrowWriterForState.write(stateInfoRow)
+ totalNumStatesForBatch += 1
+
+ // The start offset for the next chunk is the current total number of rows in the batch,
+ // unless the next chunk starts a new batch.
+ startOffsetForCurrentChunk = totalNumRowsForBatch
+ numRowsForCurrentChunk = 0
+ }
+
+ private def finalizeCurrentArrowBatch(): Unit = {
+ val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
+ (0 until remainingEmptyStateRows).foreach { _ =>
+ arrowWriterForState.write(EMPTY_STATE_METADATA_ROW)
+ }
+
+ arrowWriterForState.finish()
+ arrowWriterForData.finish()
+ writer.writeBatch()
+ arrowWriterForState.reset()
+ arrowWriterForData.reset()
+
+ startOffsetForCurrentChunk = 0
+ numRowsForCurrentChunk = 0
+ totalNumRowsForBatch = 0
+ totalNumStatesForBatch = 0
+ }
+}
+
+object ApplyInPandasWithStateWriter {
+ // This schema contains both the state metadata and the metadata of the chunk. Refer to
+ // the "Data and State" code comment above for more details.
+ val STATE_METADATA_SCHEMA: StructType = StructType(
+ Array(
+ /*
+ Metadata of the state
+ */
+
+ // properties of state instance (excluding state value) in json format
+ StructField("properties", StringType),
+ // key row as UnsafeRow; the Python worker won't touch this value but sends it back to
+ // the executor when sending a state update
+ StructField("keyRowAsUnsafe", BinaryType),
+ // state value
+ StructField("object", BinaryType),
+
+ /*
+ Metadata of the chunk
+ */
+
+ // start offset of the data chunk from entire data
+ StructField("startOffset", IntegerType),
+ // the number of rows for the data chunk
+ StructField("numRows", IntegerType),
+ // whether the current data chunk is the last one for the current grouping key
+ StructField("isLastChunk", BooleanType)
+ )
+ )
+
+ // Reused to avoid allocating a new row for each empty state metadata row.
+ val EMPTY_STATE_METADATA_ROW = new GenericInternalRow(
+ Array[Any](null, null, null, null, null, null))
+}
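
The bookkeeping in `ApplyInPandasWithStateWriter` above (`startOffset`, `numRows`, `isLastChunk` tracked against `arrowMaxRecordsPerBatch`) can be sketched as a minimal Python simulation. This class is illustrative only; the field names mirror the Scala code but no Arrow writing is performed:

```python
class ChunkTracker:
    """Minimal sketch of the bin-packing/chunking bookkeeping in
    ApplyInPandasWithStateWriter (illustrative, not the real writer)."""

    def __init__(self, max_records_per_batch):
        self.max_records_per_batch = max_records_per_batch
        self.total_rows_in_batch = 0
        self.chunk_start_offset = 0
        self.chunk_num_rows = 0
        self.emitted = []  # one (start_offset, num_rows, is_last_chunk) per chunk

    def write_row(self):
        # Mirrors writeRow: if the batch is already full, finalize the current
        # chunk (not the last one for the group) and start a new batch.
        if self.total_rows_in_batch >= self.max_records_per_batch:
            self._finalize_chunk(is_last_chunk=False)
            self._finalize_batch()
        self.chunk_num_rows += 1
        self.total_rows_in_batch += 1

    def finalize_group(self):
        # Mirrors finalizeGroup: close the last chunk for this group, then
        # finalize the batch if it has reached the limit.
        self._finalize_chunk(is_last_chunk=True)
        if self.total_rows_in_batch >= self.max_records_per_batch:
            self._finalize_batch()

    def _finalize_chunk(self, is_last_chunk):
        self.emitted.append((self.chunk_start_offset, self.chunk_num_rows, is_last_chunk))
        # Next chunk starts where the batch currently ends.
        self.chunk_start_offset = self.total_rows_in_batch
        self.chunk_num_rows = 0

    def _finalize_batch(self):
        self.chunk_start_offset = 0
        self.chunk_num_rows = 0
        self.total_rows_in_batch = 0

# Two groups of 3 rows each with a batch limit of 4: the first group is
# bin-packed whole, while the second group is chunked across two batches.
t = ChunkTracker(max_records_per_batch=4)
for _ in range(3):
    t.write_row()
t.finalize_group()
for _ in range(3):
    t.write_row()
t.finalize_group()
```

Tracing this example, `emitted` ends up as `[(0, 3, True), (3, 1, False), (0, 2, True)]`: one whole-group chunk, then the second group split into a one-row chunk that fills the first batch and a two-row chunk opening the next batch.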
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index e830ea6b546..b39787b12a4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec(
override protected def doExecute(): RDD[InternalRow] = {
- val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
- val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)
+ val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
+ val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output,
rightGroup)
// Map cogrouped rows to ArrowPythonRunner results, Only execute if
partition is not empty
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index 3a3a6022f99..f0e815e966e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec(
override protected def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute()
- val (dedupAttributes, argOffsets) = resolveArgOffsets(child,
groupingAttributes)
+ val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output,
groupingAttributes)
// Map grouped rows to ArrowPythonRunner results, Only execute if
partition is not empty
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
new file mode 100644
index 00000000000..159f805f734
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout,
ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan,
UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import
org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ *
[[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling
`functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store
for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a
while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(
+ functionExpr: Expression,
+ groupingAttributes: Seq[Attribute],
+ outAttributes: Seq[Attribute],
+ stateType: StructType,
+ stateInfo: Option[StatefulOperatorStateInfo],
+ stateFormatVersion: Int,
+ outputMode: OutputMode,
+ timeoutConf: GroupStateTimeout,
+ batchTimestampMs: Option[Long],
+ eventTimeWatermark: Option[Long],
+ child: SparkPlan) extends UnaryExecNode with
FlatMapGroupsWithStateExecBase {
+
+ // TODO(SPARK-40444): Add the support of initial state.
+ override protected val initialStateDeserializer: Expression = null
+ override protected val initialStateGroupAttrs: Seq[Attribute] = null
+ override protected val initialStateDataAttrs: Seq[Attribute] = null
+ override protected val initialState: SparkPlan = null
+ override protected val hasInitialState: Boolean = false
+
+ override protected val stateEncoder: ExpressionEncoder[Any] =
+ RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]
+
+ override def output: Seq[Attribute] = outAttributes
+
+ private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+
+ private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
+ private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+ private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
+ groupingAttributes ++ child.output, groupingAttributes)
+ private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes,
child.output)
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ StatefulOperatorPartitioning.getCompatibleDistribution(
+ groupingAttributes, getStateInfo, conf) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+ groupingAttributes.map(SortOrder(_, Ascending)))
+
+ override def shortName: String = "applyInPandasWithState"
+
+ override protected def withNewChildInternal(
+ newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child =
newChild)
+
+ override def createInputProcessor(
+ store: StateStore): InputProcessor = new InputProcessor(store:
StateStore) {
+
+ override def processNewData(dataIter: Iterator[InternalRow]):
Iterator[InternalRow] = {
+ val groupedIter = GroupedIterator(dataIter, groupingAttributes,
child.output)
+ val processIter = groupedIter.map { case (keyRow, valueRowIter) =>
+ val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
+ val stateData = stateManager.getState(store, keyUnsafeRow)
+ (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj))
+ }
+
+ process(processIter, hasTimedOut = false)
+ }
+
+ override def processNewDataWithInitialState(
+ childDataIter: Iterator[InternalRow],
+ initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+ throw new UnsupportedOperationException("Should not reach here!")
+ }
+
+ override def processTimedOutState(): Iterator[InternalRow] = {
+ if (isTimeoutEnabled) {
+ val timeoutThreshold = timeoutConf match {
+ case ProcessingTimeTimeout => batchTimestampMs.get
+ case EventTimeTimeout => eventTimeWatermark.get
+ case _ =>
+ throw new IllegalStateException(
+ s"Cannot filter timed out keys for $timeoutConf")
+ }
+ val timingOutPairs = stateManager.getAllState(store).filter { state =>
+ state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp <
timeoutThreshold
+ }
+
+ val processIter = timingOutPairs.map { stateData =>
+ val joinedKeyRow = unsafeProj(
+ new JoinedRow(
+ stateData.keyRow,
+ new GenericInternalRow(Array.fill(dedupAttributes.length)(null:
Any))))
+
+ (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
+ }
+
+ process(processIter, hasTimedOut = true)
+ } else Iterator.empty
+ }
+
+ private def process(
+ iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])],
+ hasTimedOut: Boolean): Iterator[InternalRow] = {
+ val runner = new ApplyInPandasWithStatePythonRunner(
+ chainedFunc,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+ Array(argOffsets),
+ StructType.fromAttributes(dedupAttributes),
+ sessionLocalTimeZone,
+ pythonRunnerConf,
+ stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
+ groupingAttributes.toStructType,
+ outAttributes.toStructType,
+ stateType)
+
+ val context = TaskContext.get()
+
+ val processIter = iter.map { case (keyRow, stateData, valueIter) =>
+ val groupedState = GroupStateImpl.createForStreaming(
+ Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r
},
+ batchTimestampMs.getOrElse(NO_TIMESTAMP),
+ eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+ timeoutConf,
+ hasTimedOut = hasTimedOut,
+ watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
+ (keyRow, groupedState, valueIter)
+ }
+ runner.compute(processIter, context.partitionId(), context).flatMap {
+ case (stateIter, outputIter) =>
+ // Once the iterator is fully consumed, write the changes to the state store.
+ // Groups' states do not affect each other, hence when the updates happen does not
+ // affect the result.
+ def onIteratorCompletion: Unit = {
+ stateIter.foreach { case (keyRow, newGroupState,
oldTimeoutTimestamp) =>
+ if (newGroupState.isRemoved &&
!newGroupState.getTimeoutTimestampMs.isPresent()) {
+ stateManager.removeState(store, keyRow)
+ numRemovedStateRows += 1
+ } else {
+ val currentTimeoutTimestamp =
newGroupState.getTimeoutTimestampMs
+ .orElse(NO_TIMESTAMP)
+ val hasTimeoutChanged = currentTimeoutTimestamp !=
oldTimeoutTimestamp
+ val shouldWriteState = newGroupState.isUpdated ||
newGroupState.isRemoved ||
+ hasTimeoutChanged
+
+ if (shouldWriteState) {
+ val updatedStateObj = if (newGroupState.exists)
newGroupState.get else null
+ stateManager.putState(store, keyRow, updatedStateObj,
+ currentTimeoutTimestamp)
+ numUpdatedStateRows += 1
+ }
+ }
+ }
+ }
+
+ CompletionIterator[InternalRow, Iterator[InternalRow]](
+ outputIter, onIteratorCompletion).map { row =>
+ numOutputRows += 1
+ row
+ }
+ }
+ }
+
+ override protected def callFunctionAndUpdateState(
+ stateData: StateData,
+ valueRowIter: Iterator[InternalRow],
+ hasTimedOut: Boolean): Iterator[InternalRow] = {
+ throw new UnsupportedOperationException("Should not reach here!")
+ }
+ }
+}
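
The per-key state-update rules applied in `onIteratorCompletion` above (remove vs. write vs. no-op) reduce to a small decision function. The following is an illustrative Python restatement, not the JVM code; an absent timeout is modeled as `NO_TIMESTAMP`, standing in for the empty `Optional` in the Scala code:

```python
NO_TIMESTAMP = -1  # stand-in for GroupStateImpl.NO_TIMESTAMP

def decide_state_action(is_removed, is_updated, new_timeout_ts, old_timeout_ts):
    """Sketch of the state-update decision in FlatMapGroupsInPandasWithStateExec:
    returns 'remove', 'put', or 'noop' (illustrative restatement only)."""
    # Removed state with no timeout set is deleted from the store outright.
    if is_removed and new_timeout_ts == NO_TIMESTAMP:
        return "remove"
    has_timeout_changed = new_timeout_ts != old_timeout_ts
    # Otherwise the state row is (re)written if anything observable changed:
    # the value was updated, the state was removed (but a timeout remains),
    # or the timeout timestamp moved.
    if is_updated or is_removed or has_timeout_changed:
        return "put"
    return "noop"

# A group whose state was removed with no timeout pending gets deleted:
action = decide_state_action(is_removed=True, is_updated=False,
                             new_timeout_ts=NO_TIMESTAMP, old_timeout_ts=1000)
```

A consequence worth noting: merely setting a new timeout (without touching the state value) still triggers a `put`, since the timeout timestamp is persisted alongside the state.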
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
index 2da0000dad4..07887666406 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.api.python.BasePythonRunner
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
-import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan}
+import org.apache.spark.sql.execution.GroupedIterator
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
/**
@@ -88,9 +88,10 @@ private[python] object PandasGroupUtils {
* argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes
*/
def resolveArgOffsets(
- child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute],
Array[Int]) = {
+ attributes: Seq[Attribute],
+ groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = {
- val dataAttributes = child.output.drop(groupingAttributes.length)
+ val dataAttributes = attributes.drop(groupingAttributes.length)
val groupingIndicesInData = groupingAttributes.map { attribute =>
dataAttributes.indexWhere(attribute.semanticEquals)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 6168d0f867a..bf66791183e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -76,7 +76,6 @@ private[python] trait PythonArrowInput[IN] { self:
BasePythonRunner[IN, _] =>
val root = VectorSchemaRoot.create(arrowSchema, allocator)
Utils.tryWithSafeFinally {
- val arrowWriter = ArrowWriter.create(root)
val writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 3f369ac5e97..f386282a0b3 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution,
SparkPlan, SparkPlanner, UnaryExecNode}
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec,
MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec,
UpdatingSessionsExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec
import
org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
@@ -62,6 +63,7 @@ class IncrementalExecution(
StreamingJoinStrategy ::
StatefulAggregationStrategy ::
FlatMapGroupsWithStateStrategy ::
+ FlatMapGroupsInPandasWithStateStrategy ::
StreamingRelationStrategy ::
StreamingDeduplicationStrategy ::
StreamingGlobalLimitStrategy(outputMode) :: Nil
@@ -210,6 +212,13 @@ class IncrementalExecution(
hasInitialState = hasInitialState
)
+ case m: FlatMapGroupsInPandasWithStateExec =>
+ m.copy(
+ stateInfo = Some(nextStatefulOperationStateInfo),
+ batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
+ eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)
+ )
+
case j: StreamingSymmetricHashJoinExec =>
j.copy(
stateInfo = Some(nextStatefulOperationStateInfo),
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 01ff72bac7b..022fd1239ce 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -49,7 +49,7 @@ package object state {
}
/** Map each partition of an RDD along with data in a [[StateStore]]. */
- private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
+ def mapPartitionsWithStateStore[U: ClassTag](
stateInfo: StatefulOperatorStateInfo,
keySchema: StructType,
valueSchema: StructType,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]