This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c9c41e0b7c4 Revert "[SPARK-51827][SS][CONNECT] Support Spark Connect on transformWithState in PySpark"
4c9c41e0b7c4 is described below

commit 4c9c41e0b7c43618053408d34427ead2e05a2e23
Author: Kent Yao <y...@apache.org>
AuthorDate: Sun Apr 27 17:51:06 2025 +0800

    Revert "[SPARK-51827][SS][CONNECT] Support Spark Connect on 
transformWithState in PySpark"
    
    This reverts commit 81ede347e3e27c2c6adbb5f286e9ab701a290f84.
---
 python/pyspark/sql/connect/group.py                | 52 ---------------
 python/pyspark/sql/connect/plan.py                 | 21 ++----
 .../test_parity_pandas_transform_with_state.py     | 32 +---------
 .../sql/tests/test_connect_compatibility.py        |  3 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 74 ++++++----------------
 5 files changed, 29 insertions(+), 153 deletions(-)
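
For context on what this revert takes away: the reverted commit had wired the new row-based transformWithState API into Spark Connect's GroupedData, so a call like the one below would have run over a remote session; after this revert it only works against a classic (non-Connect) session, as the test_connect_compatibility change further down records. A minimal sketch of the user-facing call, assuming the StatefulProcessor API from pyspark.sql.streaming (the RunningCountProcessor class and the rate-source input are illustrative, not part of this commit):

    from typing import Any, Iterator

    from pyspark.sql import Row, SparkSession
    from pyspark.sql.streaming.stateful_processor import (
        StatefulProcessor,
        StatefulProcessorHandle,
    )
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType


    class RunningCountProcessor(StatefulProcessor):
        """Illustrative processor: keeps a per-key running count in value state."""

        def init(self, handle: StatefulProcessorHandle) -> None:
            schema = StructType([StructField("count", IntegerType(), True)])
            self.count_state = handle.getValueState("count", schema)

        def handleInputRows(
            self, key: Any, rows: Iterator[Row], timerValues: Any
        ) -> Iterator[Row]:
            count = self.count_state.get()[0] if self.count_state.exists() else 0
            count += sum(1 for _ in rows)  # row-based variant: plain Rows, no pandas
            self.count_state.update((count,))
            yield Row(id=key[0], count=count)

        def close(self) -> None:
            pass


    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    df = spark.readStream.format("rate").load().selectExpr("CAST(value % 10 AS STRING) AS id")
    output_schema = StructType(
        [StructField("id", StringType(), True), StructField("count", IntegerType(), True)]
    )

    # Signature as in the reverted group.py below.
    result = df.groupBy("id").transformWithState(
        statefulProcessor=RunningCountProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    query = result.writeStream.format("console").outputMode("update").start()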

diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index ed7f920be5f0..5a4888fda6db 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -414,58 +414,6 @@ class GroupedData:
 
     transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__
 
-    def transformWithState(
-        self,
-        statefulProcessor: StatefulProcessor,
-        outputStructType: Union[StructType, str],
-        outputMode: str,
-        timeMode: str,
-        initialState: Optional["GroupedData"] = None,
-        eventTimeColumnName: str = "",
-    ) -> "DataFrame":
-        from pyspark.sql.connect.udf import UserDefinedFunction
-        from pyspark.sql.connect.dataframe import DataFrame
-        from pyspark.sql.streaming.stateful_processor_util import (
-            TransformWithStateInPySparkUdfUtils,
-        )
-
-        udf_util = TransformWithStateInPySparkUdfUtils(statefulProcessor, timeMode)
-        if initialState is None:
-            udf_obj = UserDefinedFunction(
-                udf_util.transformWithStateUDF,
-                returnType=outputStructType,
-                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_UDF,
-            )
-            initial_state_plan = None
-            initial_state_grouping_cols = None
-        else:
-            self._df._check_same_session(initialState._df)
-            udf_obj = UserDefinedFunction(
-                udf_util.transformWithStateWithInitStateUDF,
-                returnType=outputStructType,
-                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_INIT_STATE_UDF,
-            )
-            initial_state_plan = initialState._df._plan
-            initial_state_grouping_cols = initialState._grouping_cols
-
-        return DataFrame(
-            plan.TransformWithStateInPySpark(
-                child=self._df._plan,
-                grouping_cols=self._grouping_cols,
-                function=udf_obj,
-                output_schema=outputStructType,
-                output_mode=outputMode,
-                time_mode=timeMode,
-                event_time_col_name=eventTimeColumnName,
-                cols=self._df.columns,
-                initial_state_plan=initial_state_plan,
-                initial_state_grouping_cols=initial_state_grouping_cols,
-            ),
-            session=self._df._session,
-        )
-
-    transformWithState.__doc__ = PySparkGroupedData.transformWithState.__doc__
-
     def applyInArrow(
         self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
     ) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index c5b6f5430d6d..c4c7a6a63630 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -2546,8 +2546,8 @@ class ApplyInPandasWithState(LogicalPlan):
         return self._with_relations(plan, session)
 
 
-class BaseTransformWithStateInPySpark(LogicalPlan):
-    """Base implementation of logical plan object for a 
TransformWithStateIn(PySpark/Pandas)."""
+class TransformWithStateInPandas(LogicalPlan):
+    """Logical plan object for a TransformWithStateInPandas."""
 
     def __init__(
         self,
@@ -2600,7 +2600,7 @@ class BaseTransformWithStateInPySpark(LogicalPlan):
                 [c.to_plan(session) for c in self._initial_state_grouping_cols]
             )
 
-        # fill in transformWithStateInPySpark/Pandas related fields
+        # fill in transformWithStateInPandas related fields
         tws_info = proto.TransformWithStateInfo()
         tws_info.time_mode = self._time_mode
         tws_info.event_time_column_name = self._event_time_col_name
@@ -2608,25 +2608,12 @@ class BaseTransformWithStateInPySpark(LogicalPlan):
 
         plan.group_map.transform_with_state_info.CopyFrom(tws_info)
 
-        # wrap transformWithStateInPySparkUdf in a function
+        # wrap transformWithStateInPandasUdf in a function
         plan.group_map.func.CopyFrom(self._function.to_plan_udf(session))
 
         return self._with_relations(plan, session)
 
 
-class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
-    """Logical plan object for a TransformWithStateInPySpark."""
-
-    pass
-
-
-# Retaining this to avoid breaking backward compatibility.
-class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
-    """Logical plan object for a TransformWithStateInPandas."""
-
-    pass
-
-
 class PythonUDTF:
     """Represents a Python user-defined table function."""
 
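What plan.py loses here is a small backward-compatibility pattern: the reverted change had hoisted the shared plan construction into BaseTransformWithStateInPySpark and left two empty subclasses, one carrying the new name and one keeping the old TransformWithStateInPandas import path alive. In outline (a sketch only; bodies elided, as in the removed hunk above):

    class BaseTransformWithStateInPySpark(LogicalPlan):
        """Shared plan construction for the row- and pandas-based variants."""
        # __init__ and plan(...) as shown in the removed hunk above


    class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
        """Primary name for the row-based API."""
        pass


    # Retained purely so code importing the old name keeps working.
    class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
        pass
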
diff --git a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
index e772c2139326..26f2941d3d1f 100644
--- a/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/connect/pandas/test_parity_pandas_transform_with_state.py
@@ -18,7 +18,6 @@ import unittest
 
 from pyspark.sql.tests.pandas.test_pandas_transform_with_state import (
     TransformWithStateInPandasTestsMixin,
-    TransformWithStateInPySparkTestsMixin,
 )
 from pyspark import SparkConf
 from pyspark.testing.connectutils import ReusedConnectTestCase
@@ -54,35 +53,8 @@ class TransformWithStateInPandasParityTests(
         pass
 
 
-class TransformWithStateInPySparkParityTests(
-    TransformWithStateInPySparkTestsMixin, ReusedConnectTestCase
-):
-    """
-    Spark connect parity tests for TransformWithStateInPySpark. Run every test case in
-     `TransformWithStateInPySparkTestsMixin` in spark connect mode.
-    """
-
-    @classmethod
-    def conf(cls):
-        # Due to multiple inheritance from the same level, we need to explicitly setting configs in
-        # both TransformWithStateInPySparkTestsMixin and ReusedConnectTestCase here
-        cfg = SparkConf(loadDefaults=False)
-        for base in cls.__bases__:
-            if hasattr(base, "conf"):
-                parent_cfg = base.conf()
-                for k, v in parent_cfg.getAll():
-                    cfg.set(k, v)
-
-        # Extra removing config for connect suites
-        if cfg._jconf is not None:
-            cfg._jconf.remove("spark.master")
-
-        return cfg
-
-    @unittest.skip("Flaky in spark connect on CI. Skip for now. See SPARK-51368 for details.")
-    def test_schema_evolution_scenarios(self):
-        pass
-
+# TODO(SPARK-51827): Need to copy the parity test when we implement transformWithState in
+#  Python Spark Connect
 
 if __name__ == "__main__":
    from pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state import *  # noqa: F401,E501
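
The conf() method deleted above is a pattern worth keeping in mind for these parity suites: when a test class inherits settings from two unrelated bases, it merges each base's conf() into a fresh SparkConf and then strips spark.master, which a Connect session must not set. A standalone sketch of the same pattern with hypothetical base classes:

    import unittest

    from pyspark import SparkConf


    class StreamingMixin:
        """Hypothetical stand-in for TransformWithStateInPySparkTestsMixin."""

        @classmethod
        def conf(cls):
            cfg = SparkConf(loadDefaults=False)
            cfg.set("spark.sql.shuffle.partitions", "4")
            return cfg


    class ConnectCaseBase(unittest.TestCase):
        """Hypothetical stand-in for ReusedConnectTestCase."""

        @classmethod
        def conf(cls):
            cfg = SparkConf(loadDefaults=False)
            cfg.set("spark.master", "local[2]")
            return cfg


    class ParitySuite(StreamingMixin, ConnectCaseBase):
        @classmethod
        def conf(cls):
            # Merge every base's conf into a fresh one; later bases win on
            # key collisions, mirroring the removed suite's simple loop.
            cfg = SparkConf(loadDefaults=False)
            for base in cls.__bases__:
                if hasattr(base, "conf"):
                    for k, v in base.conf().getAll():
                        cfg.set(k, v)
            # A Connect suite must not carry spark.master; the Python-side
            # SparkConf has no remove(), hence the guarded _jconf call.
            if cfg._jconf is not None:
                cfg._jconf.remove("spark.master")
            return cfg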
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index b2e0cc6229c4..7323dc9424de 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -395,7 +395,8 @@ class ConnectCompatibilityTestsMixin:
         """Test Grouping compatibility between classic and connect."""
         expected_missing_connect_properties = set()
         expected_missing_classic_properties = set()
-        expected_missing_connect_methods = set()
+        # TODO(SPARK-51827): Add missing method `transformWithState` to the connect version
+        expected_missing_connect_methods = {"transformWithState"}
         expected_missing_classic_methods = set()
         self.check_compatibility(
             ClassicGroupedData,
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 849dd9532405..911d79ecdb12 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -661,11 +661,7 @@ class SparkConnectPlanner(
 
           case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF |
               PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF =>
-            transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = true)
-
-          case PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF |
-              PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF =>
-            transformTransformWithStateInPySpark(pythonUdf, group, rel, usePandas = false)
+            transformTransformWithStateInPandas(pythonUdf, group, rel)
 
           case _ =>
             throw InvalidPlanInput(
@@ -1106,11 +1102,10 @@ class SparkConnectPlanner(
       .logicalPlan
   }
 
-  private def transformTransformWithStateInPySpark(
+  private def transformTransformWithStateInPandas(
       pythonUdf: PythonUDF,
       groupedDs: RelationalGroupedDataset,
-      rel: proto.GroupMap,
-      usePandas: Boolean): LogicalPlan = {
+      rel: proto.GroupMap): LogicalPlan = {
     val twsInfo = rel.getTransformWithStateInfo
     val outputSchema: StructType = {
       transformDataType(twsInfo.getOutputSchema) match {
@@ -1136,52 +1131,25 @@ class SparkConnectPlanner(
         .builder(groupedDs.df.logicalPlan.output)
         .asInstanceOf[PythonUDF]
 
-      if (usePandas) {
-        groupedDs
-          .transformWithStateInPandas(
-            Column(resolvedPythonUDF),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            initialStateDs,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      } else {
-        // use Row
-        groupedDs
-          .transformWithStateInPySpark(
-            Column(resolvedPythonUDF),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            initialStateDs,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      }
-
+      groupedDs
+        .transformWithStateInPandas(
+          Column(resolvedPythonUDF),
+          outputSchema,
+          rel.getOutputMode,
+          twsInfo.getTimeMode,
+          initialStateDs,
+          twsInfo.getEventTimeColumnName)
+        .logicalPlan
     } else {
-      if (usePandas) {
-        groupedDs
-          .transformWithStateInPandas(
-            Column(pythonUdf),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            null,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      } else {
-        // use Row
-        groupedDs
-          .transformWithStateInPySpark(
-            Column(pythonUdf),
-            outputSchema,
-            rel.getOutputMode,
-            twsInfo.getTimeMode,
-            null,
-            twsInfo.getEventTimeColumnName)
-          .logicalPlan
-      }
+      groupedDs
+        .transformWithStateInPandas(
+          Column(pythonUdf),
+          outputSchema,
+          rel.getOutputMode,
+          twsInfo.getTimeMode,
+          null,
+          twsInfo.getEventTimeColumnName)
+        .logicalPlan
     }
   }
 
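Net effect on the server side: the planner's transformWithState dispatch shrinks from four PythonEvalType cases back to the two pandas ones, and the usePandas flag goes away with it. Sketched in Python for brevity (the constants mirror org.apache.spark.api.python.PythonEvalType as shown in the Scala diff above; the handler is a hypothetical stand-in for the planner method):

    from pyspark.util import PythonEvalType  # assumed Spark 4.x location of the constants

    PANDAS_TWS = {
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
    }
    ROW_TWS = {
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
    }


    def plan_transform_with_state(eval_type: int, use_pandas: bool) -> str:
        # Hypothetical stand-in for transformTransformWithStateInPySpark above.
        return f"TransformWithState(usePandas={use_pandas})"


    def dispatch(eval_type: int) -> str:
        # Pre-revert shape: four eval types, one handler, a usePandas flag.
        if eval_type in PANDAS_TWS:
            return plan_transform_with_state(eval_type, use_pandas=True)
        if eval_type in ROW_TWS:
            return plan_transform_with_state(eval_type, use_pandas=False)
        raise ValueError(f"unsupported eval type: {eval_type}")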

