This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4c9c41e0b7c4 Revert "[SPARK-51827][SS][CONNECT] Support Spark Connect on transformWithState in PySpark"
4c9c41e0b7c4 is described below

commit 4c9c41e0b7c43618053408d34427ead2e05a2e23
Author: Kent Yao <y...@apache.org>
AuthorDate: Sun Apr 27 17:51:06 2025 +0800

    Revert "[SPARK-51827][SS][CONNECT] Support Spark Connect on transformWithState in PySpark"

    This reverts commit 81ede347e3e27c2c6adbb5f286e9ab701a290f84.
---
 python/pyspark/sql/connect/group.py                   | 52 ---------------
 python/pyspark/sql/connect/plan.py                    | 21 ++----
 .../pandas/test_parity_pandas_transform_with_state.py | 32 +---------
 .../sql/tests/test_connect_compatibility.py           |  3 +-
 .../sql/connect/planner/SparkConnectPlanner.scala     | 74 ++++++----------
 5 files changed, 29 insertions(+), 153 deletions(-)

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index ed7f920be5f0..5a4888fda6db 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -414,58 +414,6 @@ class GroupedData:

     transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__

-    def transformWithState(
-        self,
-        statefulProcessor: StatefulProcessor,
-        outputStructType: Union[StructType, str],
-        outputMode: str,
-        timeMode: str,
-        initialState: Optional["GroupedData"] = None,
-        eventTimeColumnName: str = "",
-    ) -> "DataFrame":
-        from pyspark.sql.connect.udf import UserDefinedFunction
-        from pyspark.sql.connect.dataframe import DataFrame
-        from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
-        )
-
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
-        if initialState is None:
-            udf_obj = UserDefinedFunction(
-                udf_util.transformWithStateUDF,
-                returnType=outputStructType,
-                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF,
-            )
-            initial_state_plan = None
-            initial_state_grouping_cols = None
-        else:
-            self._df._check_same_session(initialState._df)
-            udf_obj = UserDefinedFunction(
-                udf_util.transformWithStateWithInitStateUDF,
-                returnType=outputStructType,
-                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF,
-            )
-            initial_state_plan = initialState._df._plan
-            initial_state_grouping_cols = initialState._grouping_cols
-
-        return DataFrame(
-            plan.TransformWithStateInPySpark(
-                child=self._df._plan,
-                grouping_cols=self._grouping_cols,
-                function=udf_obj,
-                output_schema=outputStructType,
-                output_mode=outputMode,
-                time_mode=timeMode,
-                event_time_col_name=eventTimeColumnName,
-                cols=self._df.columns,
-                initial_state_plan=initial_state_plan,
-                initial_state_grouping_cols=initial_state_grouping_cols,
-            ),
-            session=self._df._session,
-        )
-
-    transformWithState.__doc__ = PySparkGroupedData.transformWithState.__doc__
-
     def applyInArrow(
         self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
     ) -> "DataFrame":
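
For context on what this revert takes away on the client side: the deleted
transformWithState above was the Row-based Spark Connect entry point. Below is
a hedged usage sketch, not Spark code — only the method signature comes from
the diff; RunningCount, the "id" column, and the schema strings are invented,
and the StatefulProcessor hooks follow the classic (non-reverted) PySpark API.

    # Hypothetical caller of the reverted connect-side API; illustrative only.
    from pyspark.sql import Row
    from pyspark.sql.streaming.stateful_processor import (
        StatefulProcessor,
        StatefulProcessorHandle,
    )

    class RunningCount(StatefulProcessor):
        def init(self, handle: StatefulProcessorHandle) -> None:
            # One value-state slot per grouping key holding a running count.
            self.count_state = handle.getValueState("count", "count LONG")

        def handleInputRows(self, key, rows, timer_values):
            count = self.count_state.get()[0] if self.count_state.exists() else 0
            count += sum(1 for _ in rows)
            self.count_state.update((count,))
            yield Row(id=key[0], count=count)

        def close(self) -> None:
            pass

    # df is assumed to be a streaming DataFrame with an "id" column.
    counts = df.groupBy("id").transformWithState(
        statefulProcessor=RunningCount(),
        outputStructType="id STRING, count LONG",
        outputMode="Update",
        timeMode="None",
    )
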
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c5b6f5430d6d..c4c7a6a63630 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -2546,8 +2546,8 @@ class ApplyInPandasWithState(LogicalPlan):
         return self._with_relations(plan, session)


-class BaseTransformWithStateInPySpark(LogicalPlan):
-    """Base implementation of logical plan object for a TransformWithStateIn(PySpark/Pandas)."""
+class TransformWithStateInPandas(LogicalPlan):
+    """Logical plan object for a TransformWithStateInPandas."""

     def __init__(
         self,
@@ -2600,7 +2600,7 @@ class BaseTransformWithStateInPySpark(LogicalPlan):
             [c.to_plan(session) for c in self._initial_state_grouping_cols]
         )

-        # fill in transformWithStateInPySpark/Pandas related fields
+        # fill in transformWithStateInPandas related fields
         tws_info = proto.TransformWithStateInfo()
         tws_info.time_mode = self._time_mode
         tws_info.event_time_column_name = self._event_time_col_name
@@ -2608,25 +2608,12 @@ class BaseTransformWithStateInPySpark(LogicalPlan):
         plan.group_map.transform_with_state_info.CopyFrom(tws_info)

-        # wrap transformWithStateInPySparkUdf in a function
+        # wrap transformWithStateInPandasUdf in a function
        plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))

         return self._with_relations(plan, session)


-class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
-    """Logical plan object for a TransformWithStateInPySpark."""
-
-    pass
-
-
-# Retaining this to avoid breaking backward compatibility.
-class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
-    """Logical plan object for a TransformWithStateInPandas."""
-
-    pass
-
-
 class PythonUDTF:
     """Represents a Python user-defined table function."""
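
A side note on the plan.py hunk above: it also undoes a rename. Before the
revert, the shared logic lived in BaseTransformWithStateInPySpark, and
TransformWithStateInPandas was retained as an empty subclass purely for
backward compatibility. A minimal sketch of that pattern, with the plan
internals elided:

    # Class names follow the reverted plan.py; the bodies are illustrative.
    class BaseTransformWithStateInPySpark:
        """Shared logical-plan implementation for both variants."""

        def __init__(self, child, time_mode):
            self.child = child
            self.time_mode = time_mode


    class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
        """Primary name for new call sites."""
        pass


    # Retained as an empty subclass so imports of the old name keep working.
    class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
        pass
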
diff --git a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
index e772c2139326..26f2941d3d1f 100644
--- a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
@@ -18,7 +18,6 @@ import unittest

 from pyspark.sql.tests.pandas.test_pandas_transform_with_state import (
     TransformWithStateInPandasTestsMixin,
-    TransformWithStateInPySparkTestsMixin,
 )
 from pyspark import SparkConf
 from pyspark.testing.connectutils import ReusedConnectTestCase
@@ -54,35 +53,8 @@ class TransformWithStateInPandasParityTests(
     pass


-class TransformWithStateInPySparkParityTests(
-    TransformWithStateInPySparkTestsMixin, ReusedConnectTestCase
-):
-    """
-    Spark connect parity tests for TransformWithStateInPySpark. Run every test case in
-    `TransformWithStateInPySparkTestsMixin` in spark connect mode.
-    """
-
-    @classmethod
-    def conf(cls):
-        # Due to multiple inheritance from the same level, we need to explicitly setting configs in
-        # both TransformWithStateInPySparkTestsMixin and ReusedConnectTestCase here
-        cfg = SparkConf(loadDefaults=False)
-        for base in cls.__bases__:
-            if hasattr(base, "conf"):
-                parent_cfg = base.conf()
-                for k, v in parent_cfg.getAll():
-                    cfg.set(k, v)
-
-        # Extra removing config for connect suites
-        if cfg._jconf is not None:
-            cfg._jconf.remove("spark.master")
-
-        return cfg
-
-    @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.")
-    def test_schema_evolution_scenarios(self):
-        pass
-
+# TODO(SPARK-51827): Need to copy the parity test when we implement transformWithState in
+# Python Spark Connect

 if __name__ == "__main__":
     from pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state import *  # noqa: F401,E501

diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index b2e0cc6229c4..7323dc9424de 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -395,7 +395,8 @@ class ConnectCompatibilityTestsMixin:
         """Test Grouping compatibility between classic and connect."""
         expected_missing_connect_properties = set()
         expected_missing_classic_properties = set()
-        expected_missing_connect_methods = set()
+        # TODO(SPARK-51827): Add missing method `transformWithState` to the connect version
+        expected_missing_connect_methods = {"transformWithState"}
         expected_missing_classic_methods = set()
         self.check_compatibility(
             ClassicGroupedData,
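
The compatibility test above guards API parity by diffing the public surface
of the classic and connect GroupedData classes and asserting that only an
expected set of methods differs. A hypothetical sketch of that idea follows —
public_methods and the trailing call are invented, not the real test helpers:

    import inspect

    def public_methods(cls):
        # Collect the public callables defined on a class.
        return {
            name
            for name, member in inspect.getmembers(cls)
            if callable(member) and not name.startswith("_")
        }

    def check_compatibility(classic_cls, connect_cls, expected_missing_connect):
        # Methods present on the classic class but absent on connect.
        missing = public_methods(classic_cls) - public_methods(connect_cls)
        assert missing == expected_missing_connect, f"unexpected API gap: {missing}"

    # After this revert, the connect class is expected to lack exactly one method:
    # check_compatibility(ClassicGroupedData, ConnectGroupedData, {"transformWithState"})
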
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 849dd9532405..911d79ecdb12 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -661,11 +661,7 @@ class SparkConnectPlanner(
       case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF |
           PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF =>
-        transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = true)
-
-      case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF |
-          PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
-        transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = false)
+        transformTransformWithStateInPandas(pythonUdf, group, rel)

       case _ =>
         throw InvalidPlanInput(
@@ -1106,11 +1102,10 @@ class SparkConnectPlanner(
       .logicalPlan
   }

-  private def transformTransformWithStateInPySpark(
+  private def transformTransformWithStateInPandas(
       pythonUdf: PythonUDF,
       groupedDs: RelationalGroupedDataset,
-      rel: proto.GroupMap,
-      usePandas: Boolean): LogicalPlan = {
+      rel: proto.GroupMap): LogicalPlan = {
     val twsInfo = rel.getTransformWithStateInfo
     val outputSchema: StructType = {
       transformDataType(twsInfo.getOutputSchema) match {
@@ -1136,52 +1131,25 @@ class SparkConnectPlanner(
         .builder(groupedDs.df.logicalPlan.output)
         .asInstanceOf[PythonUDF]

-      if (usePandas) {
-        groupedDs
-          .transformWithStateInPandas(
-            Column(resolvedPythonUDF),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            initialStateDs,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      } else {
-        // use Row
-        groupedDs
-          .transformWithStateInPySpark(
-            Column(resolvedPythonUDF),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            initialStateDs,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      }
-
+      groupedDs
+        .transformWithStateInPandas(
+          Column(resolvedPythonUDF),
+          outputSchema,
+          rel.getOutputMode,
+          twsInfo.getTimeMode,
+          initialStateDs,
+          twsInfo.getEventTimeColumnName)
+        .logicalPlan
     } else {
-      if (usePandas) {
-        groupedDs
-          .transformWithStateInPandas(
-            Column(pythonUdf),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            null,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      } else {
-        // use Row
-        groupedDs
-          .transformWithStateInPySpark(
-            Column(pythonUdf),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            null,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      }
+      groupedDs
+        .transformWithStateInPandas(
+          Column(pythonUdf),
+          outputSchema,
+          rel.getOutputMode,
+          twsInfo.getTimeMode,
+          null,
+          twsInfo.getEventTimeColumnName)
+        .logicalPlan
     }
   }
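
The planner hunk above collapses a two-way dispatch: before the revert, the
eval type stamped on the Python UDF told the server whether the stateful
processor consumes pandas.DataFrame batches or plain Rows. A hedged
Python-side sketch of that mapping — select_variant is invented; the
PythonEvalType constants are the ones referenced in the diffs above:

    from pyspark.util import PythonEvalType

    def select_variant(eval_type: int) -> str:
        # Pandas-batch variants (the surviving transformWithStateInPandas path).
        pandas_types = {
            PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
            PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
        }
        # Row-based variants (the reverted transformWithState path).
        row_types = {
            PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF,
            PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF,
        }
        if eval_type in pandas_types:
            return "pandas"  # processor sees pandas.DataFrame batches
        if eval_type in row_types:
            return "row"  # processor sees iterators of Row objects
        raise ValueError(f"unsupported eval type: {eval_type}")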