This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fee47c77e5d [SPARK-42340][CONNECT][PYTHON] Implement Grouped Map API
fee47c77e5d is described below

commit fee47c77e5d31bce592bf6e2bd33c2dabfc57bd3
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Mon Mar 20 20:04:00 2023 +0900

    [SPARK-42340][CONNECT][PYTHON] Implement Grouped Map API
    
    ### What changes were proposed in this pull request?
    Implement the Grouped Map API: `GroupedData.applyInPandas` and `GroupedData.apply`.
    
    ### Why are the changes needed?
    Parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `GroupedData.applyInPandas` and `GroupedData.apply` are supported now, as shown below.
    ```python
    >>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
    >>> def normalize(pdf):
    ...     v = pdf.v
    ...     return pdf.assign(v=(v - v.mean()) / v.std())
    ...
    >>> df.groupby("id").applyInPandas(normalize, schema="id long, v double").show()
    
    +---+-------------------+
    | id|                  v|
    +---+-------------------+
    |  1|-0.7071067811865475|
    |  1| 0.7071067811865475|
    |  2|-0.8320502943378437|
    |  2|-0.2773500981126146|
    |  2| 1.1094003924504583|
    +---+-------------------+
    ```
    
    ```python
    >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
    ... def normalize(pdf):
    ...     v = pdf.v
    ...     return pdf.assign(v=(v - v.mean()) / v.std())
    ...
    >>> df.groupby("id").apply(normalize).show()
    /Users/xinrong.meng/spark/python/pyspark/sql/connect/group.py:228: UserWarning: It is preferred to use 'applyInPandas' over this API. This API will be deprecated in the future releases. See SPARK-28264 for more details.
      warnings.warn(
    +---+-------------------+
    | id|                  v|
    +---+-------------------+
    |  1|-0.7071067811865475|
    |  1| 0.7071067811865475|
    |  2|-0.8320502943378437|
    |  2|-0.2773500981126146|
    |  2| 1.1094003924504583|
    +---+-------------------+
    ```
    
    ### How was this patch tested?
    (Parity) Unit tests.
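    
    For reference, a minimal sketch of running the new parity suite locally (not part of this patch; it assumes a Spark dev checkout with `python/` on `PYTHONPATH` and a working Spark Connect test environment):
    
    ```python
    # Hypothetical local runner for the parity suite added by this patch.
    import unittest
    
    # The module name below is the one introduced in this change.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "pyspark.sql.tests.connect.test_parity_pandas_grouped_map"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```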
    
    Closes #40405 from xinrong-meng/group_map.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  12 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  14 ++
 dev/sparktestsupport/modules.py                    |   2 +-
 python/pyspark/sql/connect/_typing.py              |  10 +-
 python/pyspark/sql/connect/group.py                |  61 +++++-
 python/pyspark/sql/connect/plan.py                 |  27 +++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 242 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  51 +++++
 python/pyspark/sql/pandas/group_ops.py             |   6 +
 .../sql/tests/connect/test_connect_basic.py        |   2 -
 .../connect/test_parity_pandas_grouped_map.py      | 102 +++++++++
 .../sql/tests/connect/test_parity_pandas_udf.py    |   5 -
 .../sql/tests/pandas/test_pandas_grouped_map.py    |  16 +-
 13 files changed, 416 insertions(+), 134 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 69451e7b76e..aba965082ea 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -63,6 +63,7 @@ message Relation {
     MapPartitions map_partitions = 28;
     CollectMetrics collect_metrics = 29;
     Parse parse = 30;
+    GroupMap group_map = 31;
 
     // NA functions
     NAFill fill_na = 90;
@@ -788,6 +789,17 @@ message MapPartitions {
   CommonInlineUserDefinedFunction func = 2;
 }
 
+message GroupMap {
+  // (Required) Input relation for Group Map API: apply, applyInPandas.
+  Relation input = 1;
+
+  // (Required) Expressions for grouping keys.
+  repeated Expression grouping_expressions = 2;
+
+  // (Required) Input user-defined function.
+  CommonInlineUserDefinedFunction func = 3;
+}
+
 // Collect arbitrary (named) metrics from a dataset.
 message CollectMetrics {
   // (Required) The input relation.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b023adac98a..c8fdaa6641a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -117,6 +117,8 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformRepartitionByExpression(rel.getRepartitionByExpression)
       case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
         transformMapPartitions(rel.getMapPartitions)
+      case proto.Relation.RelTypeCase.GROUP_MAP =>
+        transformGroupMap(rel.getGroupMap)
       case proto.Relation.RelTypeCase.COLLECT_METRICS =>
         transformCollectMetrics(rel.getCollectMetrics)
       case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -495,6 +497,18 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
+    val pythonUdf = transformPythonUDF(rel.getFunc)
+    val cols =
+      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .groupBy(cols: _*)
+      .flatMapGroupsInPandas(pythonUdf)
+      .logicalPlan
+  }
+
   private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 5379f883815..c31a9362cd7 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -513,7 +513,6 @@ pyspark_sql = Module(
     ],
 )
 
-
 pyspark_resource = Module(
     name="pyspark-resource",
     dependencies=[pyspark_core],
@@ -779,6 +778,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_pandas_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_map",
         "pyspark.sql.tests.connect.test_parity_arrow_map",
+        "pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
         # ml doctests
         "pyspark.ml.connect.functions",
         # ml unittests
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 6df3f15d87d..63aae5d2487 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -22,7 +22,8 @@ if sys.version_info >= (3, 8):
 else:
     from typing_extensions import Protocol
 
-from typing import Any, Callable, Iterable, Union, Optional
+from types import FunctionType
+from typing import Any, Callable, Iterable, Union, Optional, NewType
 import datetime
 import decimal
 
@@ -53,6 +54,13 @@ PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLi
 
 ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]
 
+PandasGroupedMapFunction = Union[
+    Callable[[DataFrameLike], DataFrameLike],
+    Callable[[Any, DataFrameLike], DataFrameLike],
+]
+
+GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
+
 
 class UserDefinedFunctionLike(Protocol):
     func: Callable[..., Any]
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index e699ce7105a..a75a50501bd 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import warnings
+
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -30,6 +32,7 @@ from typing import (
     cast,
 )
 
+from pyspark.rdd import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.types import NumericType
 
@@ -38,8 +41,13 @@ from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.functions import _invoke_function, col, lit
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect._typing import LiteralType
+    from pyspark.sql.connect._typing import (
+        LiteralType,
+        PandasGroupedMapFunction,
+        GroupedMapPandasUserDefinedFunction,
+    )
     from pyspark.sql.connect.dataframe import DataFrame
+    from pyspark.sql.types import StructType
 
 
 class GroupedData:
@@ -203,11 +211,54 @@ class GroupedData:
 
     pivot.__doc__ = PySparkGroupedData.pivot.__doc__
 
-    def apply(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("apply() is not implemented.")
+    def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
+        # Columns are special because hasattr always returns True
+        if (
+            isinstance(udf, Column)
+            or not hasattr(udf, "func")
+            or (
+                udf.evalType  # type: ignore[attr-defined]
+                != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+            )
+        ):
+            raise ValueError(
+                "Invalid udf: the udf argument must be a pandas_udf of type " "GROUPED_MAP."
+            )
+
+        warnings.warn(
+            "It is preferred to use 'applyInPandas' over this "
+            "API. This API will be deprecated in the future releases. See SPARK-28264 for "
+            "more details.",
+            UserWarning,
+        )
+
+        return self.applyInPandas(udf.func, schema=udf.returnType)  # type: ignore[attr-defined]
+
+    apply.__doc__ = PySparkGroupedData.apply.__doc__
+
+    def applyInPandas(
+        self, func: "PandasGroupedMapFunction", schema: Union["StructType", str]
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        udf_obj = UserDefinedFunction(
+            func,
+            returnType=schema,
+            evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        )
+
+        return DataFrame.withPlan(
+            plan.GroupMap(
+                child=self._df._plan,
+                grouping_cols=self._grouping_cols,
+                function=udf_obj,
+                cols=self._df.columns,
+            ),
+            session=self._df._session,
+        )
 
-    def applyInPandas(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("applyInPandas() is not implemented.")
+    applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__
 
     def applyInPandasWithState(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("applyInPandasWithState() is not implemented.")
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 9807c9722a6..dbfcfea7678 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1923,6 +1923,33 @@ class MapPartitions(LogicalPlan):
         return plan
 
 
+class GroupMap(LogicalPlan):
+    """Logical plan object for a Group Map API: apply, applyInPandas."""
+
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        grouping_cols: Sequence[Column],
+        function: "UserDefinedFunction",
+        cols: List[str],
+    ):
+        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
+
+        super().__init__(child)
+        self._grouping_cols = grouping_cols
+        self._func = function._build_common_inline_user_defined_function(*cols)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.group_map.input.CopyFrom(self._child.plan(session))
+        plan.group_map.grouping_expressions.extend(
+            [c.to_plan(session) for c in self._grouping_cols]
+        )
+        plan.group_map.func.CopyFrom(self._func.to_plan_udf(session))
+        return plan
+
+
 class CachedRelation(LogicalPlan):
     def __init__(self, plan: proto.Relation) -> None:
         super(CachedRelation, self).__init__(None)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 521a10f214c..aa6d39cd4f0 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb8\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf0\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -92,6 +92,7 @@ _UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
 _REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
 _MAPPARTITIONS = DESCRIPTOR.message_types_by_name["MapPartitions"]
+_GROUPMAP = DESCRIPTOR.message_types_by_name["GroupMap"]
 _COLLECTMETRICS = DESCRIPTOR.message_types_by_name["CollectMetrics"]
 _PARSE = DESCRIPTOR.message_types_by_name["Parse"]
 _PARSE_OPTIONSENTRY = _PARSE.nested_types_by_name["OptionsEntry"]
@@ -640,6 +641,17 @@ MapPartitions = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(MapPartitions)
 
+GroupMap = _reflection.GeneratedProtocolMessageType(
+    "GroupMap",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _GROUPMAP,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.GroupMap)
+    },
+)
+_sym_db.RegisterMessage(GroupMap)
+
 CollectMetrics = _reflection.GeneratedProtocolMessageType(
     "CollectMetrics",
     (_message.Message,),
@@ -685,117 +697,119 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _PARSE_OPTIONSENTRY._options = None
     _PARSE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2653
-    _UNKNOWN._serialized_start = 2655
-    _UNKNOWN._serialized_end = 2664
-    _RELATIONCOMMON._serialized_start = 2666
-    _RELATIONCOMMON._serialized_end = 2757
-    _SQL._serialized_start = 2760
-    _SQL._serialized_end = 2894
-    _SQL_ARGSENTRY._serialized_start = 2839
-    _SQL_ARGSENTRY._serialized_end = 2894
-    _READ._serialized_start = 2897
-    _READ._serialized_end = 3393
-    _READ_NAMEDTABLE._serialized_start = 3039
-    _READ_NAMEDTABLE._serialized_end = 3100
-    _READ_DATASOURCE._serialized_start = 3103
-    _READ_DATASOURCE._serialized_end = 3380
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3300
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3358
-    _PROJECT._serialized_start = 3395
-    _PROJECT._serialized_end = 3512
-    _FILTER._serialized_start = 3514
-    _FILTER._serialized_end = 3626
-    _JOIN._serialized_start = 3629
-    _JOIN._serialized_end = 4100
-    _JOIN_JOINTYPE._serialized_start = 3892
-    _JOIN_JOINTYPE._serialized_end = 4100
-    _SETOPERATION._serialized_start = 4103
-    _SETOPERATION._serialized_end = 4582
-    _SETOPERATION_SETOPTYPE._serialized_start = 4419
-    _SETOPERATION_SETOPTYPE._serialized_end = 4533
-    _LIMIT._serialized_start = 4584
-    _LIMIT._serialized_end = 4660
-    _OFFSET._serialized_start = 4662
-    _OFFSET._serialized_end = 4741
-    _TAIL._serialized_start = 4743
-    _TAIL._serialized_end = 4818
-    _AGGREGATE._serialized_start = 4821
-    _AGGREGATE._serialized_end = 5403
-    _AGGREGATE_PIVOT._serialized_start = 5160
-    _AGGREGATE_PIVOT._serialized_end = 5271
-    _AGGREGATE_GROUPTYPE._serialized_start = 5274
-    _AGGREGATE_GROUPTYPE._serialized_end = 5403
-    _SORT._serialized_start = 5406
-    _SORT._serialized_end = 5566
-    _DROP._serialized_start = 5569
-    _DROP._serialized_end = 5710
-    _DEDUPLICATE._serialized_start = 5713
-    _DEDUPLICATE._serialized_end = 5884
-    _LOCALRELATION._serialized_start = 5886
-    _LOCALRELATION._serialized_end = 5975
-    _SAMPLE._serialized_start = 5978
-    _SAMPLE._serialized_end = 6251
-    _RANGE._serialized_start = 6254
-    _RANGE._serialized_end = 6399
-    _SUBQUERYALIAS._serialized_start = 6401
-    _SUBQUERYALIAS._serialized_end = 6515
-    _REPARTITION._serialized_start = 6518
-    _REPARTITION._serialized_end = 6660
-    _SHOWSTRING._serialized_start = 6663
-    _SHOWSTRING._serialized_end = 6805
-    _STATSUMMARY._serialized_start = 6807
-    _STATSUMMARY._serialized_end = 6899
-    _STATDESCRIBE._serialized_start = 6901
-    _STATDESCRIBE._serialized_end = 6982
-    _STATCROSSTAB._serialized_start = 6984
-    _STATCROSSTAB._serialized_end = 7085
-    _STATCOV._serialized_start = 7087
-    _STATCOV._serialized_end = 7183
-    _STATCORR._serialized_start = 7186
-    _STATCORR._serialized_end = 7323
-    _STATAPPROXQUANTILE._serialized_start = 7326
-    _STATAPPROXQUANTILE._serialized_end = 7490
-    _STATFREQITEMS._serialized_start = 7492
-    _STATFREQITEMS._serialized_end = 7617
-    _STATSAMPLEBY._serialized_start = 7620
-    _STATSAMPLEBY._serialized_end = 7929
-    _STATSAMPLEBY_FRACTION._serialized_start = 7821
-    _STATSAMPLEBY_FRACTION._serialized_end = 7920
-    _NAFILL._serialized_start = 7932
-    _NAFILL._serialized_end = 8066
-    _NADROP._serialized_start = 8069
-    _NADROP._serialized_end = 8203
-    _NAREPLACE._serialized_start = 8206
-    _NAREPLACE._serialized_end = 8502
-    _NAREPLACE_REPLACEMENT._serialized_start = 8361
-    _NAREPLACE_REPLACEMENT._serialized_end = 8502
-    _TODF._serialized_start = 8504
-    _TODF._serialized_end = 8592
-    _WITHCOLUMNSRENAMED._serialized_start = 8595
-    _WITHCOLUMNSRENAMED._serialized_end = 8834
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8767
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8834
-    _WITHCOLUMNS._serialized_start = 8836
-    _WITHCOLUMNS._serialized_end = 8955
-    _HINT._serialized_start = 8958
-    _HINT._serialized_end = 9090
-    _UNPIVOT._serialized_start = 9093
-    _UNPIVOT._serialized_end = 9420
-    _UNPIVOT_VALUES._serialized_start = 9350
-    _UNPIVOT_VALUES._serialized_end = 9409
-    _TOSCHEMA._serialized_start = 9422
-    _TOSCHEMA._serialized_end = 9528
-    _REPARTITIONBYEXPRESSION._serialized_start = 9531
-    _REPARTITIONBYEXPRESSION._serialized_end = 9734
-    _MAPPARTITIONS._serialized_start = 9737
-    _MAPPARTITIONS._serialized_end = 9867
-    _COLLECTMETRICS._serialized_start = 9870
-    _COLLECTMETRICS._serialized_end = 10006
-    _PARSE._serialized_start = 10009
-    _PARSE._serialized_end = 10397
-    _PARSE_OPTIONSENTRY._serialized_start = 3300
-    _PARSE_OPTIONSENTRY._serialized_end = 3358
-    _PARSE_PARSEFORMAT._serialized_start = 10298
-    _PARSE_PARSEFORMAT._serialized_end = 10386
+    _RELATION._serialized_end = 2709
+    _UNKNOWN._serialized_start = 2711
+    _UNKNOWN._serialized_end = 2720
+    _RELATIONCOMMON._serialized_start = 2722
+    _RELATIONCOMMON._serialized_end = 2813
+    _SQL._serialized_start = 2816
+    _SQL._serialized_end = 2950
+    _SQL_ARGSENTRY._serialized_start = 2895
+    _SQL_ARGSENTRY._serialized_end = 2950
+    _READ._serialized_start = 2953
+    _READ._serialized_end = 3449
+    _READ_NAMEDTABLE._serialized_start = 3095
+    _READ_NAMEDTABLE._serialized_end = 3156
+    _READ_DATASOURCE._serialized_start = 3159
+    _READ_DATASOURCE._serialized_end = 3436
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3356
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3414
+    _PROJECT._serialized_start = 3451
+    _PROJECT._serialized_end = 3568
+    _FILTER._serialized_start = 3570
+    _FILTER._serialized_end = 3682
+    _JOIN._serialized_start = 3685
+    _JOIN._serialized_end = 4156
+    _JOIN_JOINTYPE._serialized_start = 3948
+    _JOIN_JOINTYPE._serialized_end = 4156
+    _SETOPERATION._serialized_start = 4159
+    _SETOPERATION._serialized_end = 4638
+    _SETOPERATION_SETOPTYPE._serialized_start = 4475
+    _SETOPERATION_SETOPTYPE._serialized_end = 4589
+    _LIMIT._serialized_start = 4640
+    _LIMIT._serialized_end = 4716
+    _OFFSET._serialized_start = 4718
+    _OFFSET._serialized_end = 4797
+    _TAIL._serialized_start = 4799
+    _TAIL._serialized_end = 4874
+    _AGGREGATE._serialized_start = 4877
+    _AGGREGATE._serialized_end = 5459
+    _AGGREGATE_PIVOT._serialized_start = 5216
+    _AGGREGATE_PIVOT._serialized_end = 5327
+    _AGGREGATE_GROUPTYPE._serialized_start = 5330
+    _AGGREGATE_GROUPTYPE._serialized_end = 5459
+    _SORT._serialized_start = 5462
+    _SORT._serialized_end = 5622
+    _DROP._serialized_start = 5625
+    _DROP._serialized_end = 5766
+    _DEDUPLICATE._serialized_start = 5769
+    _DEDUPLICATE._serialized_end = 5940
+    _LOCALRELATION._serialized_start = 5942
+    _LOCALRELATION._serialized_end = 6031
+    _SAMPLE._serialized_start = 6034
+    _SAMPLE._serialized_end = 6307
+    _RANGE._serialized_start = 6310
+    _RANGE._serialized_end = 6455
+    _SUBQUERYALIAS._serialized_start = 6457
+    _SUBQUERYALIAS._serialized_end = 6571
+    _REPARTITION._serialized_start = 6574
+    _REPARTITION._serialized_end = 6716
+    _SHOWSTRING._serialized_start = 6719
+    _SHOWSTRING._serialized_end = 6861
+    _STATSUMMARY._serialized_start = 6863
+    _STATSUMMARY._serialized_end = 6955
+    _STATDESCRIBE._serialized_start = 6957
+    _STATDESCRIBE._serialized_end = 7038
+    _STATCROSSTAB._serialized_start = 7040
+    _STATCROSSTAB._serialized_end = 7141
+    _STATCOV._serialized_start = 7143
+    _STATCOV._serialized_end = 7239
+    _STATCORR._serialized_start = 7242
+    _STATCORR._serialized_end = 7379
+    _STATAPPROXQUANTILE._serialized_start = 7382
+    _STATAPPROXQUANTILE._serialized_end = 7546
+    _STATFREQITEMS._serialized_start = 7548
+    _STATFREQITEMS._serialized_end = 7673
+    _STATSAMPLEBY._serialized_start = 7676
+    _STATSAMPLEBY._serialized_end = 7985
+    _STATSAMPLEBY_FRACTION._serialized_start = 7877
+    _STATSAMPLEBY_FRACTION._serialized_end = 7976
+    _NAFILL._serialized_start = 7988
+    _NAFILL._serialized_end = 8122
+    _NADROP._serialized_start = 8125
+    _NADROP._serialized_end = 8259
+    _NAREPLACE._serialized_start = 8262
+    _NAREPLACE._serialized_end = 8558
+    _NAREPLACE_REPLACEMENT._serialized_start = 8417
+    _NAREPLACE_REPLACEMENT._serialized_end = 8558
+    _TODF._serialized_start = 8560
+    _TODF._serialized_end = 8648
+    _WITHCOLUMNSRENAMED._serialized_start = 8651
+    _WITHCOLUMNSRENAMED._serialized_end = 8890
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8823
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8890
+    _WITHCOLUMNS._serialized_start = 8892
+    _WITHCOLUMNS._serialized_end = 9011
+    _HINT._serialized_start = 9014
+    _HINT._serialized_end = 9146
+    _UNPIVOT._serialized_start = 9149
+    _UNPIVOT._serialized_end = 9476
+    _UNPIVOT_VALUES._serialized_start = 9406
+    _UNPIVOT_VALUES._serialized_end = 9465
+    _TOSCHEMA._serialized_start = 9478
+    _TOSCHEMA._serialized_end = 9584
+    _REPARTITIONBYEXPRESSION._serialized_start = 9587
+    _REPARTITIONBYEXPRESSION._serialized_end = 9790
+    _MAPPARTITIONS._serialized_start = 9793
+    _MAPPARTITIONS._serialized_end = 9923
+    _GROUPMAP._serialized_start = 9926
+    _GROUPMAP._serialized_end = 10129
+    _COLLECTMETRICS._serialized_start = 10132
+    _COLLECTMETRICS._serialized_end = 10268
+    _PARSE._serialized_start = 10271
+    _PARSE._serialized_end = 10659
+    _PARSE_OPTIONSENTRY._serialized_start = 3356
+    _PARSE_OPTIONSENTRY._serialized_end = 3414
+    _PARSE_PARSEFORMAT._serialized_start = 10560
+    _PARSE_PARSEFORMAT._serialized_end = 10648
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index ab1561996ef..6ae4a323f6f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -92,6 +92,7 @@ class Relation(google.protobuf.message.Message):
     MAP_PARTITIONS_FIELD_NUMBER: builtins.int
     COLLECT_METRICS_FIELD_NUMBER: builtins.int
     PARSE_FIELD_NUMBER: builtins.int
+    GROUP_MAP_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -167,6 +168,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def parse(self) -> global___Parse: ...
     @property
+    def group_map(self) -> global___GroupMap: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -233,6 +236,7 @@ class Relation(google.protobuf.message.Message):
         map_partitions: global___MapPartitions | None = ...,
         collect_metrics: global___CollectMetrics | None = ...,
         parse: global___Parse | None = ...,
+        group_map: global___GroupMap | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -283,6 +287,8 @@ class Relation(google.protobuf.message.Message):
             b"filter",
             "freq_items",
             b"freq_items",
+            "group_map",
+            b"group_map",
             "hint",
             b"hint",
             "join",
@@ -378,6 +384,8 @@ class Relation(google.protobuf.message.Message):
             b"filter",
             "freq_items",
             b"freq_items",
+            "group_map",
+            b"group_map",
             "hint",
             b"hint",
             "join",
@@ -470,6 +478,7 @@ class Relation(google.protobuf.message.Message):
         "map_partitions",
         "collect_metrics",
         "parse",
+        "group_map",
         "fill_na",
         "drop_na",
         "replace",
@@ -2733,6 +2742,48 @@ class MapPartitions(google.protobuf.message.Message):
 
 global___MapPartitions = MapPartitions
 
+class GroupMap(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    FUNC_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) Input relation for Group Map API: apply, applyInPandas."""
+    @property
+    def grouping_expressions(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Required) Expressions for grouping keys."""
+    @property
+    def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
+        """(Required) Input user-defined function."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        grouping_expressions: collections.abc.Iterable[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]
+        | None = ...,
+        func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+        | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "func", b"func", "grouping_expressions", b"grouping_expressions", "input", b"input"
+        ],
+    ) -> None: ...
+
+global___GroupMap = GroupMap
+
 class CollectMetrics(google.protobuf.message.Message):
     """Collect arbitrary (named) metrics from a dataset."""
 
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index bca96eaf205..f03aa35bb83 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -48,6 +48,9 @@ class PandasGroupedOpsMixin:
 
         .. versionadded:: 2.3.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         udf : :func:`pyspark.sql.functions.pandas_udf`
@@ -128,6 +131,9 @@ class PandasGroupedOpsMixin:
 
         .. versionadded:: 3.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         func : function
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a8e161a42a6..491865ad9c9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2845,8 +2845,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         # SPARK-41927: Disable unsupported functions.
         cg = self.connect.read.table(self.tbl_name).groupBy("id")
         for f in (
-            "apply",
-            "applyInPandas",
             "applyInPandasWithState",
             "cogroup",
         ):
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
new file mode 100644
index 00000000000..1736e395723
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_grouped_map import GroupedApplyInPandasTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedConnectTestCase):
+    # TODO(SPARK-42822): Fix ambiguous reference for case-insensitive grouping column
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_case_insensitive_grouping_column(self):
+        super().test_case_insensitive_grouping_column()
+
+    # TODO(SPARK-42857): Support CreateDataFrame from Decimal128
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_supported_types(self):
+        super().test_supported_types()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_wrong_return_type(self):
+        super().test_wrong_return_type()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_wrong_args(self):
+        super().test_wrong_args()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_unsupported_types(self):
+        super().test_unsupported_types()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_register_grouped_map_udf(self):
+        super().test_register_grouped_map_udf()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_column_order(self):
+        super().test_column_order()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_returning_wrong_column_names(self):
+        super().test_apply_in_pandas_returning_wrong_column_names()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
+        super().test_apply_in_pandas_returning_no_column_names_and_wrong_amount()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_returning_incompatible_type(self):
+        super().test_apply_in_pandas_returning_incompatible_type()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_not_returning_pandas_dataframe(self):
+        super().test_apply_in_pandas_not_returning_pandas_dataframe()
+
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.")
+    def test_grouped_with_empty_partition(self):
+        super().test_grouped_with_empty_partition()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_grouped_map import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
index 571ee74287e..d2eab7fa4f3 100644
--- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
@@ -66,11 +66,6 @@ class PandasUDFParityTests(PandasUDFTestsMixin, ReusedConnectTestCase):
         self.assertEqual(udf.returnType, UnparsedDataType("v double"))
         self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)
 
-    # TODO(SPARK-42340): implement GroupedData.applyInPandas
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_stopiteration_in_grouped_map(self):
-        super().test_stopiteration_in_grouped_map()
-
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 88e68b04303..36bdae02944 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -73,7 +73,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class GroupedApplyInPandasTests(ReusedSQLTestCase):
+class GroupedApplyInPandasTestsMixin:
     @property
     def data(self):
         return (
@@ -289,17 +289,17 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
         return pd.DataFrame([key + (pdf.v.mean(),)])
 
     def test_apply_in_pandas_returning_column_names(self):
-        self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_column_names)
+        self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_column_names)
 
     def test_apply_in_pandas_returning_no_column_names(self):
-        self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_no_column_names)
+        self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_no_column_names)
 
     def test_apply_in_pandas_returning_column_names_sometimes(self):
         def stats(key, pdf):
             if key[0] % 2:
-                return GroupedApplyInPandasTests.stats_with_column_names(key, pdf)
+                return GroupedApplyInPandasTestsMixin.stats_with_column_names(key, pdf)
             else:
-                return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
+                return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
 
         self._test_apply_in_pandas(stats)
 
@@ -782,7 +782,7 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
 
         def stats(key, pdf):
             if key[0] % 2 == 0:
-                return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
+                return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
             return empty_df
 
         result = (
@@ -805,6 +805,10 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
                 self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
 
 
+class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_grouped_map import *  # noqa: F401
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

