This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1fbc7948e57 [SPARK-42891][CONNECT][PYTHON] Implement CoGrouped Map API
1fbc7948e57 is described below

commit 1fbc7948e57cbf05a46cb0c7fb2fad4ec25540e6
Author: Xinrong Meng <[email protected]>
AuthorDate: Thu Mar 23 12:38:30 2023 +0900

    [SPARK-42891][CONNECT][PYTHON] Implement CoGrouped Map API
    
    ### What changes were proposed in this pull request?
    Implement CoGrouped Map API: `applyInPandas`.
    
    ### Why are the changes needed?
    Parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The CoGrouped Map API is now supported, as shown below.
    
    ```python
    >>> import pandas as pd
    >>> df1 = spark.createDataFrame(
    ...   [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ("time", "id", "v1"))
    >>>
    >>> df2 = spark.createDataFrame(
    ...   [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2"))
    >>>
    >>> def asof_join(l, r):
    ...   return pd.merge_asof(l, r, on="time", by="id")
    ...
    >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    ...   asof_join, schema="time int, id int, v1 double, v2 string"
    ... ).show()
    +--------+---+---+---+
    |    time| id| v1| v2|
    +--------+---+---+---+
    |20000101|  1|1.0|  x|
    |20000102|  1|3.0|  x|
    |20000101|  2|2.0|  y|
    |20000102|  2|4.0|  y|
    +--------+---+---+---+
    ```
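    
    For reference, each cogrouped pair is handed to the UDF as two plain pandas DataFrames, so the join logic can be reproduced (and debugged) with pandas alone. A minimal sketch of the `id=1` pair from the example above, runnable without Spark (the names `l` and `r` are illustrative):
    
    ```python
    >>> import pandas as pd
    >>> l = pd.DataFrame({"time": [20000101, 20000102], "id": [1, 1], "v1": [1.0, 3.0]})
    >>> r = pd.DataFrame({"time": [20000101], "id": [1], "v2": ["x"]})
    >>> pd.merge_asof(l, r, on="time", by="id")
           time  id   v1 v2
    0  20000101   1  1.0  x
    1  20000102   1  3.0  x
    ```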
    
    ### How was this patch tested?
    Parity unit tests.
    
    Closes #40487 from xinrong-meng/cogroup_map.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/protobuf/spark/connect/relations.proto    |  18 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  22 ++
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/sql/connect/_typing.py              |   2 +
 python/pyspark/sql/connect/group.py                |  49 +++-
 python/pyspark/sql/connect/plan.py                 |  40 ++++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 246 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  80 +++++++
 python/pyspark/sql/pandas/group_ops.py             |   9 +
 .../sql/tests/connect/test_connect_basic.py        |   5 +-
 .../connect/test_parity_pandas_cogrouped_map.py    |  82 +++++++
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  |   6 +-
 12 files changed, 437 insertions(+), 123 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index aba965082ea..70aa399c424 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -64,6 +64,7 @@ message Relation {
     CollectMetrics collect_metrics = 29;
     Parse parse = 30;
     GroupMap group_map = 31;
+    CoGroupMap co_group_map = 32;
 
     // NA functions
     NAFill fill_na = 90;
@@ -800,6 +801,23 @@ message GroupMap {
   CommonInlineUserDefinedFunction func = 3;
 }
 
+message CoGroupMap {
+  // (Required) One input relation for CoGroup Map API - applyInPandas.
+  Relation input = 1;
+
+  // Expressions for grouping keys of the first input relation.
+  repeated Expression input_grouping_expressions = 2;
+
+  // (Required) The other input relation.
+  Relation other = 3;
+
+  // Expressions for grouping keys of the other input relation.
+  repeated Expression other_grouping_expressions = 4;
+
+  // (Required) Input user-defined function.
+  CommonInlineUserDefinedFunction func = 5;
+}
+
 // Collect arbitrary (named) metrics from a dataset.
 message CollectMetrics {
   // (Required) The input relation.
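
On the wire, the new message hangs off the `Relation` oneof as field 32, alongside `GroupMap` (31). A hedged sketch of populating it by hand with the generated Python bindings; the table names are placeholders, and real clients go through the plan layer further down:

```python
from pyspark.sql.connect.proto import relations_pb2

rel = relations_pb2.Relation()
co = rel.co_group_map
# Setting any subfield selects the co_group_map branch of the oneof.
co.input.read.named_table.unparsed_identifier = "left_table"    # placeholder
co.other.read.named_table.unparsed_identifier = "right_table"   # placeholder
# input_grouping_expressions, other_grouping_expressions and func are
# normally filled in by CoGroupMap.plan() in plan.py (see below).
assert rel.HasField("co_group_map")
```
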
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0faa2f74981..ba210d37231 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -119,6 +119,8 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformMapPartitions(rel.getMapPartitions)
       case proto.Relation.RelTypeCase.GROUP_MAP =>
         transformGroupMap(rel.getGroupMap)
+      case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
+        transformCoGroupMap(rel.getCoGroupMap)
       case proto.Relation.RelTypeCase.COLLECT_METRICS =>
         transformCollectMetrics(rel.getCollectMetrics)
       case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -509,6 +511,26 @@ class SparkConnectPlanner(val session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
+    val pythonUdf = transformPythonUDF(rel.getFunc)
+
+    val inputCols =
+      rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
+        Column(transformExpression(expr)))
+    val otherCols =
+      rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
+        Column(transformExpression(expr)))
+
+    val input = Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .groupBy(inputCols: _*)
+    val other = Dataset
+      .ofRows(session, transformRelation(rel.getOther))
+      .groupBy(otherCols: _*)
+
+    input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+  }
+
   private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
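
Server-side, `transformCoGroupMap` rebuilds the two grouped datasets and delegates to `flatMapCoGroupsInPandas`, the same entry point vanilla PySpark reaches through `PandasCogroupedOps.applyInPandas`, which is what yields the parity. An end-to-end sketch of the round trip (hedged; the function and frame names are illustrative, and a session named `spark`, classic or Connect, is assumed):

```python
import pandas as pd

def merge(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    # Invoked once per cogrouped pair on the executors.
    return pd.merge(left, right, on="id")

df1 = spark.createDataFrame([(1, 1.0)], ("id", "v1"))
df2 = spark.createDataFrame([(1, "x")], ("id", "v2"))
df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    merge, schema="id long, v1 double, v2 string"
).show()
# +---+---+---+
# | id| v1| v2|
# +---+---+---+
# |  1|1.0|  x|
# +---+---+---+
```
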
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f4efb412c87..c3c3b415a1f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -768,6 +768,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_pandas_map",
         "pyspark.sql.tests.connect.test_parity_arrow_map",
         "pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
+        "pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map",
         # ml doctests
         "pyspark.ml.connect.functions",
         # ml unittests
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 63aae5d2487..6bdde2926cb 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -61,6 +61,8 @@ PandasGroupedMapFunction = Union[
 
 GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
 
+PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], DataFrameLike]
+
 
 class UserDefinedFunctionLike(Protocol):
     func: Callable[..., Any]
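
The new alias only pins down the UDF shape: two pandas DataFrames in, one pandas DataFrame out. Any callable with that signature conforms; a trivial, pandas-only illustration (the name `combine` is ours):

```python
import pandas as pd

def combine(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    # Conforms to PandasCogroupedMapFunction.
    return pd.concat([left, right], ignore_index=True)

print(combine(pd.DataFrame({"a": [1]}), pd.DataFrame({"a": [2]})))
#    a
# 0  1
# 1  2
```
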
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index a75a50501bd..8377caac592 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -34,6 +34,7 @@ from typing import (
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
+from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
 from pyspark.sql.types import NumericType
 
 import pyspark.sql.connect.plan as plan
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
         LiteralType,
         PandasGroupedMapFunction,
         GroupedMapPandasUserDefinedFunction,
+        PandasCogroupedMapFunction,
     )
     from pyspark.sql.connect.dataframe import DataFrame
     from pyspark.sql.types import StructType
@@ -263,13 +265,56 @@ class GroupedData:
     def applyInPandasWithState(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("applyInPandasWithState() is not 
implemented.")
 
-    def cogroup(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("cogroup() is not implemented.")
+    def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
+        return PandasCogroupedOps(self, other)
+
+    cogroup.__doc__ = PySparkGroupedData.cogroup.__doc__
 
 
 GroupedData.__doc__ = PySparkGroupedData.__doc__
 
 
+class PandasCogroupedOps:
+    def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
+        self._gd1 = gd1
+        self._gd2 = gd2
+
+    def applyInPandas(
+        self, func: "PandasCogroupedMapFunction", schema: Union["StructType", str]
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        udf_obj = UserDefinedFunction(
+            func,
+            returnType=schema,
+            evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+        )
+
+        all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
+        return DataFrame.withPlan(
+            plan.CoGroupMap(
+                input=self._gd1._df._plan,
+                input_grouping_cols=self._gd1._grouping_cols,
+                other=self._gd2._df._plan,
+                other_grouping_cols=self._gd2._grouping_cols,
+                function=udf_obj,
+                cols=all_cols,
+            ),
+            session=self._gd1._df._session,
+        )
+
+    applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__
+
+    @staticmethod
+    def _extract_cols(gd: "GroupedData") -> List[Column]:
+        df = gd._df
+        return [df[col] for col in df.columns]
+
+
+PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__
+
+
 def _test() -> None:
     import sys
     import doctest
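
With this change, `cogroup()` on a Connect `GroupedData` returns a `PandasCogroupedOps` handle instead of raising `NotImplementedError` (hence the shrunken blanket test in test_connect_basic.py below), and `applyInPandas` stays lazy: it only assembles a `CoGroupMap` plan node. A hedged sketch, assuming a Connect session named `spark`:

```python
ops = spark.range(4).groupBy("id").cogroup(spark.range(4).groupBy("id"))
print(type(ops).__name__)  # PandasCogroupedOps; nothing has executed yet
```
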
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index dbfcfea7678..34eaf0d5bee 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1950,6 +1950,46 @@ class GroupMap(LogicalPlan):
         return plan
 
 
+class CoGroupMap(LogicalPlan):
+    """Logical plan object for a CoGroup Map API: applyInPandas."""
+
+    def __init__(
+        self,
+        input: Optional["LogicalPlan"],
+        input_grouping_cols: Sequence[Column],
+        other: Optional["LogicalPlan"],
+        other_grouping_cols: Sequence[Column],
+        function: "UserDefinedFunction",
+        cols: List[Column],
+    ):
+        assert isinstance(input_grouping_cols, list) and all(
+            isinstance(c, Column) for c in input_grouping_cols
+        )
+        assert isinstance(other_grouping_cols, list) and all(
+            isinstance(c, Column) for c in other_grouping_cols
+        )
+
+        super().__init__(input)
+        self._input_grouping_cols = input_grouping_cols
+        self._other_grouping_cols = other_grouping_cols
+        self._other = cast(LogicalPlan, other)
+        self._func = function._build_common_inline_user_defined_function(*cols)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.co_group_map.input.CopyFrom(self._child.plan(session))
+        plan.co_group_map.input_grouping_expressions.extend(
+            [c.to_plan(session) for c in self._input_grouping_cols]
+        )
+        plan.co_group_map.other.CopyFrom(self._other.plan(session))
+        plan.co_group_map.other_grouping_expressions.extend(
+            [c.to_plan(session) for c in self._other_grouping_cols]
+        )
+        plan.co_group_map.func.CopyFrom(self._func.to_plan_udf(session))
+        return plan
+
+
 class CachedRelation(LogicalPlan):
     def __init__(self, plan: proto.Relation) -> None:
         super(CachedRelation, self).__init__(None)
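
`CoGroupMap.plan()` above is where the client-side node serializes into the `CoGroupMap` proto message; nothing is sent until an action runs. One hedged way to observe the result, using internal developer APIs whose names are assumed from this era of the code (`df` built as in the Connect sketch after group.py):

```python
# df = df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(merge, "...")
proto_plan = df._plan.to_proto(spark.client)  # internal API, may change
assert proto_plan.root.HasField("co_group_map")
```
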
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index aa6d39cd4f0..60ec7081e09 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf0\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -93,6 +93,7 @@ _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
 _REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
 _MAPPARTITIONS = DESCRIPTOR.message_types_by_name["MapPartitions"]
 _GROUPMAP = DESCRIPTOR.message_types_by_name["GroupMap"]
+_COGROUPMAP = DESCRIPTOR.message_types_by_name["CoGroupMap"]
 _COLLECTMETRICS = DESCRIPTOR.message_types_by_name["CollectMetrics"]
 _PARSE = DESCRIPTOR.message_types_by_name["Parse"]
 _PARSE_OPTIONSENTRY = _PARSE.nested_types_by_name["OptionsEntry"]
@@ -652,6 +653,17 @@ GroupMap = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(GroupMap)
 
+CoGroupMap = _reflection.GeneratedProtocolMessageType(
+    "CoGroupMap",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _COGROUPMAP,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.CoGroupMap)
+    },
+)
+_sym_db.RegisterMessage(CoGroupMap)
+
 CollectMetrics = _reflection.GeneratedProtocolMessageType(
     "CollectMetrics",
     (_message.Message,),
@@ -697,119 +709,121 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _PARSE_OPTIONSENTRY._options = None
     _PARSE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2709
-    _UNKNOWN._serialized_start = 2711
-    _UNKNOWN._serialized_end = 2720
-    _RELATIONCOMMON._serialized_start = 2722
-    _RELATIONCOMMON._serialized_end = 2813
-    _SQL._serialized_start = 2816
-    _SQL._serialized_end = 2950
-    _SQL_ARGSENTRY._serialized_start = 2895
-    _SQL_ARGSENTRY._serialized_end = 2950
-    _READ._serialized_start = 2953
-    _READ._serialized_end = 3449
-    _READ_NAMEDTABLE._serialized_start = 3095
-    _READ_NAMEDTABLE._serialized_end = 3156
-    _READ_DATASOURCE._serialized_start = 3159
-    _READ_DATASOURCE._serialized_end = 3436
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3356
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3414
-    _PROJECT._serialized_start = 3451
-    _PROJECT._serialized_end = 3568
-    _FILTER._serialized_start = 3570
-    _FILTER._serialized_end = 3682
-    _JOIN._serialized_start = 3685
-    _JOIN._serialized_end = 4156
-    _JOIN_JOINTYPE._serialized_start = 3948
-    _JOIN_JOINTYPE._serialized_end = 4156
-    _SETOPERATION._serialized_start = 4159
-    _SETOPERATION._serialized_end = 4638
-    _SETOPERATION_SETOPTYPE._serialized_start = 4475
-    _SETOPERATION_SETOPTYPE._serialized_end = 4589
-    _LIMIT._serialized_start = 4640
-    _LIMIT._serialized_end = 4716
-    _OFFSET._serialized_start = 4718
-    _OFFSET._serialized_end = 4797
-    _TAIL._serialized_start = 4799
-    _TAIL._serialized_end = 4874
-    _AGGREGATE._serialized_start = 4877
-    _AGGREGATE._serialized_end = 5459
-    _AGGREGATE_PIVOT._serialized_start = 5216
-    _AGGREGATE_PIVOT._serialized_end = 5327
-    _AGGREGATE_GROUPTYPE._serialized_start = 5330
-    _AGGREGATE_GROUPTYPE._serialized_end = 5459
-    _SORT._serialized_start = 5462
-    _SORT._serialized_end = 5622
-    _DROP._serialized_start = 5625
-    _DROP._serialized_end = 5766
-    _DEDUPLICATE._serialized_start = 5769
-    _DEDUPLICATE._serialized_end = 5940
-    _LOCALRELATION._serialized_start = 5942
-    _LOCALRELATION._serialized_end = 6031
-    _SAMPLE._serialized_start = 6034
-    _SAMPLE._serialized_end = 6307
-    _RANGE._serialized_start = 6310
-    _RANGE._serialized_end = 6455
-    _SUBQUERYALIAS._serialized_start = 6457
-    _SUBQUERYALIAS._serialized_end = 6571
-    _REPARTITION._serialized_start = 6574
-    _REPARTITION._serialized_end = 6716
-    _SHOWSTRING._serialized_start = 6719
-    _SHOWSTRING._serialized_end = 6861
-    _STATSUMMARY._serialized_start = 6863
-    _STATSUMMARY._serialized_end = 6955
-    _STATDESCRIBE._serialized_start = 6957
-    _STATDESCRIBE._serialized_end = 7038
-    _STATCROSSTAB._serialized_start = 7040
-    _STATCROSSTAB._serialized_end = 7141
-    _STATCOV._serialized_start = 7143
-    _STATCOV._serialized_end = 7239
-    _STATCORR._serialized_start = 7242
-    _STATCORR._serialized_end = 7379
-    _STATAPPROXQUANTILE._serialized_start = 7382
-    _STATAPPROXQUANTILE._serialized_end = 7546
-    _STATFREQITEMS._serialized_start = 7548
-    _STATFREQITEMS._serialized_end = 7673
-    _STATSAMPLEBY._serialized_start = 7676
-    _STATSAMPLEBY._serialized_end = 7985
-    _STATSAMPLEBY_FRACTION._serialized_start = 7877
-    _STATSAMPLEBY_FRACTION._serialized_end = 7976
-    _NAFILL._serialized_start = 7988
-    _NAFILL._serialized_end = 8122
-    _NADROP._serialized_start = 8125
-    _NADROP._serialized_end = 8259
-    _NAREPLACE._serialized_start = 8262
-    _NAREPLACE._serialized_end = 8558
-    _NAREPLACE_REPLACEMENT._serialized_start = 8417
-    _NAREPLACE_REPLACEMENT._serialized_end = 8558
-    _TODF._serialized_start = 8560
-    _TODF._serialized_end = 8648
-    _WITHCOLUMNSRENAMED._serialized_start = 8651
-    _WITHCOLUMNSRENAMED._serialized_end = 8890
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8823
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8890
-    _WITHCOLUMNS._serialized_start = 8892
-    _WITHCOLUMNS._serialized_end = 9011
-    _HINT._serialized_start = 9014
-    _HINT._serialized_end = 9146
-    _UNPIVOT._serialized_start = 9149
-    _UNPIVOT._serialized_end = 9476
-    _UNPIVOT_VALUES._serialized_start = 9406
-    _UNPIVOT_VALUES._serialized_end = 9465
-    _TOSCHEMA._serialized_start = 9478
-    _TOSCHEMA._serialized_end = 9584
-    _REPARTITIONBYEXPRESSION._serialized_start = 9587
-    _REPARTITIONBYEXPRESSION._serialized_end = 9790
-    _MAPPARTITIONS._serialized_start = 9793
-    _MAPPARTITIONS._serialized_end = 9923
-    _GROUPMAP._serialized_start = 9926
-    _GROUPMAP._serialized_end = 10129
-    _COLLECTMETRICS._serialized_start = 10132
-    _COLLECTMETRICS._serialized_end = 10268
-    _PARSE._serialized_start = 10271
-    _PARSE._serialized_end = 10659
-    _PARSE_OPTIONSENTRY._serialized_start = 3356
-    _PARSE_OPTIONSENTRY._serialized_end = 3414
-    _PARSE_PARSEFORMAT._serialized_start = 10560
-    _PARSE_PARSEFORMAT._serialized_end = 10648
+    _RELATION._serialized_end = 2772
+    _UNKNOWN._serialized_start = 2774
+    _UNKNOWN._serialized_end = 2783
+    _RELATIONCOMMON._serialized_start = 2785
+    _RELATIONCOMMON._serialized_end = 2876
+    _SQL._serialized_start = 2879
+    _SQL._serialized_end = 3013
+    _SQL_ARGSENTRY._serialized_start = 2958
+    _SQL_ARGSENTRY._serialized_end = 3013
+    _READ._serialized_start = 3016
+    _READ._serialized_end = 3512
+    _READ_NAMEDTABLE._serialized_start = 3158
+    _READ_NAMEDTABLE._serialized_end = 3219
+    _READ_DATASOURCE._serialized_start = 3222
+    _READ_DATASOURCE._serialized_end = 3499
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3419
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3477
+    _PROJECT._serialized_start = 3514
+    _PROJECT._serialized_end = 3631
+    _FILTER._serialized_start = 3633
+    _FILTER._serialized_end = 3745
+    _JOIN._serialized_start = 3748
+    _JOIN._serialized_end = 4219
+    _JOIN_JOINTYPE._serialized_start = 4011
+    _JOIN_JOINTYPE._serialized_end = 4219
+    _SETOPERATION._serialized_start = 4222
+    _SETOPERATION._serialized_end = 4701
+    _SETOPERATION_SETOPTYPE._serialized_start = 4538
+    _SETOPERATION_SETOPTYPE._serialized_end = 4652
+    _LIMIT._serialized_start = 4703
+    _LIMIT._serialized_end = 4779
+    _OFFSET._serialized_start = 4781
+    _OFFSET._serialized_end = 4860
+    _TAIL._serialized_start = 4862
+    _TAIL._serialized_end = 4937
+    _AGGREGATE._serialized_start = 4940
+    _AGGREGATE._serialized_end = 5522
+    _AGGREGATE_PIVOT._serialized_start = 5279
+    _AGGREGATE_PIVOT._serialized_end = 5390
+    _AGGREGATE_GROUPTYPE._serialized_start = 5393
+    _AGGREGATE_GROUPTYPE._serialized_end = 5522
+    _SORT._serialized_start = 5525
+    _SORT._serialized_end = 5685
+    _DROP._serialized_start = 5688
+    _DROP._serialized_end = 5829
+    _DEDUPLICATE._serialized_start = 5832
+    _DEDUPLICATE._serialized_end = 6003
+    _LOCALRELATION._serialized_start = 6005
+    _LOCALRELATION._serialized_end = 6094
+    _SAMPLE._serialized_start = 6097
+    _SAMPLE._serialized_end = 6370
+    _RANGE._serialized_start = 6373
+    _RANGE._serialized_end = 6518
+    _SUBQUERYALIAS._serialized_start = 6520
+    _SUBQUERYALIAS._serialized_end = 6634
+    _REPARTITION._serialized_start = 6637
+    _REPARTITION._serialized_end = 6779
+    _SHOWSTRING._serialized_start = 6782
+    _SHOWSTRING._serialized_end = 6924
+    _STATSUMMARY._serialized_start = 6926
+    _STATSUMMARY._serialized_end = 7018
+    _STATDESCRIBE._serialized_start = 7020
+    _STATDESCRIBE._serialized_end = 7101
+    _STATCROSSTAB._serialized_start = 7103
+    _STATCROSSTAB._serialized_end = 7204
+    _STATCOV._serialized_start = 7206
+    _STATCOV._serialized_end = 7302
+    _STATCORR._serialized_start = 7305
+    _STATCORR._serialized_end = 7442
+    _STATAPPROXQUANTILE._serialized_start = 7445
+    _STATAPPROXQUANTILE._serialized_end = 7609
+    _STATFREQITEMS._serialized_start = 7611
+    _STATFREQITEMS._serialized_end = 7736
+    _STATSAMPLEBY._serialized_start = 7739
+    _STATSAMPLEBY._serialized_end = 8048
+    _STATSAMPLEBY_FRACTION._serialized_start = 7940
+    _STATSAMPLEBY_FRACTION._serialized_end = 8039
+    _NAFILL._serialized_start = 8051
+    _NAFILL._serialized_end = 8185
+    _NADROP._serialized_start = 8188
+    _NADROP._serialized_end = 8322
+    _NAREPLACE._serialized_start = 8325
+    _NAREPLACE._serialized_end = 8621
+    _NAREPLACE_REPLACEMENT._serialized_start = 8480
+    _NAREPLACE_REPLACEMENT._serialized_end = 8621
+    _TODF._serialized_start = 8623
+    _TODF._serialized_end = 8711
+    _WITHCOLUMNSRENAMED._serialized_start = 8714
+    _WITHCOLUMNSRENAMED._serialized_end = 8953
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8886
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8953
+    _WITHCOLUMNS._serialized_start = 8955
+    _WITHCOLUMNS._serialized_end = 9074
+    _HINT._serialized_start = 9077
+    _HINT._serialized_end = 9209
+    _UNPIVOT._serialized_start = 9212
+    _UNPIVOT._serialized_end = 9539
+    _UNPIVOT_VALUES._serialized_start = 9469
+    _UNPIVOT_VALUES._serialized_end = 9528
+    _TOSCHEMA._serialized_start = 9541
+    _TOSCHEMA._serialized_end = 9647
+    _REPARTITIONBYEXPRESSION._serialized_start = 9650
+    _REPARTITIONBYEXPRESSION._serialized_end = 9853
+    _MAPPARTITIONS._serialized_start = 9856
+    _MAPPARTITIONS._serialized_end = 9986
+    _GROUPMAP._serialized_start = 9989
+    _GROUPMAP._serialized_end = 10192
+    _COGROUPMAP._serialized_start = 10195
+    _COGROUPMAP._serialized_end = 10547
+    _COLLECTMETRICS._serialized_start = 10550
+    _COLLECTMETRICS._serialized_end = 10686
+    _PARSE._serialized_start = 10689
+    _PARSE._serialized_end = 11077
+    _PARSE_OPTIONSENTRY._serialized_start = 3419
+    _PARSE_OPTIONSENTRY._serialized_end = 3477
+    _PARSE_PARSEFORMAT._serialized_start = 10978
+    _PARSE_PARSEFORMAT._serialized_end = 11066
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 6ae4a323f6f..814b23ca26d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -93,6 +93,7 @@ class Relation(google.protobuf.message.Message):
     COLLECT_METRICS_FIELD_NUMBER: builtins.int
     PARSE_FIELD_NUMBER: builtins.int
     GROUP_MAP_FIELD_NUMBER: builtins.int
+    CO_GROUP_MAP_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -170,6 +171,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def group_map(self) -> global___GroupMap: ...
     @property
+    def co_group_map(self) -> global___CoGroupMap: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -237,6 +240,7 @@ class Relation(google.protobuf.message.Message):
         collect_metrics: global___CollectMetrics | None = ...,
         parse: global___Parse | None = ...,
         group_map: global___GroupMap | None = ...,
+        co_group_map: global___CoGroupMap | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -261,6 +265,8 @@ class Relation(google.protobuf.message.Message):
             b"approx_quantile",
             "catalog",
             b"catalog",
+            "co_group_map",
+            b"co_group_map",
             "collect_metrics",
             b"collect_metrics",
             "common",
@@ -358,6 +364,8 @@ class Relation(google.protobuf.message.Message):
             b"approx_quantile",
             "catalog",
             b"catalog",
+            "co_group_map",
+            b"co_group_map",
             "collect_metrics",
             b"collect_metrics",
             "common",
@@ -479,6 +487,7 @@ class Relation(google.protobuf.message.Message):
         "collect_metrics",
         "parse",
         "group_map",
+        "co_group_map",
         "fill_na",
         "drop_na",
         "replace",
@@ -2784,6 +2793,77 @@ class GroupMap(google.protobuf.message.Message):
 
 global___GroupMap = GroupMap
 
+class CoGroupMap(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    INPUT_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    OTHER_FIELD_NUMBER: builtins.int
+    OTHER_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    FUNC_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) One input relation for CoGroup Map API - 
applyInPandas."""
+    @property
+    def input_grouping_expressions(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """Expressions for grouping keys of the first input relation."""
+    @property
+    def other(self) -> global___Relation:
+        """(Required) The other input relation."""
+    @property
+    def other_grouping_expressions(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """Expressions for grouping keys of the other input relation."""
+    @property
+    def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
+        """(Required) Input user-defined function."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        input_grouping_expressions: collections.abc.Iterable[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]
+        | None = ...,
+        other: global___Relation | None = ...,
+        other_grouping_expressions: collections.abc.Iterable[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]
+        | None = ...,
+        func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+        | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "func", b"func", "input", b"input", "other", b"other"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "func",
+            b"func",
+            "input",
+            b"input",
+            "input_grouping_expressions",
+            b"input_grouping_expressions",
+            "other",
+            b"other",
+            "other_grouping_expressions",
+            b"other_grouping_expressions",
+        ],
+    ) -> None: ...
+
+global___CoGroupMap = CoGroupMap
+
 class CollectMetrics(google.protobuf.message.Message):
     """Collect arbitrary (named) metrics from a dataset."""
 
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index f03aa35bb83..12ceb56c79f 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -353,6 +353,9 @@ class PandasGroupedOpsMixin:
 
         .. versionadded:: 3.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         See :class:`PandasCogroupedOps` for the operations that can be run.
         """
         from pyspark.sql import GroupedData
@@ -369,6 +372,9 @@ class PandasCogroupedOps:
 
     .. versionadded:: 3.0.0
 
+    .. versionchanged:: 3.4.0
+        Support Spark Connect.
+
     Notes
     -----
     This API is experimental.
@@ -400,6 +406,9 @@ class PandasCogroupedOps:
 
         .. versionadded:: 3.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         func : function
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f911ca9ba78..682b3471a74 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2841,10 +2841,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
     def test_unsupported_group_functions(self):
         # SPARK-41927: Disable unsupported functions.
         cg = self.connect.read.table(self.tbl_name).groupBy("id")
-        for f in (
-            "applyInPandasWithState",
-            "cogroup",
-        ):
+        for f in ("applyInPandasWithState",):
             with self.assertRaises(NotImplementedError):
                 getattr(cg, f)()
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py
new file mode 100644
index 00000000000..c03bc5f8219
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import CogroupedApplyInPandasTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedConnectTestCase):
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_different_group_key_cardinality(self):
+        super().test_different_group_key_cardinality()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_apply_in_pandas_returning_incompatible_type(self):
+        super().test_apply_in_pandas_returning_incompatible_type()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_wrong_args(self):
+        super().test_wrong_args()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_apply_in_pandas_not_returning_pandas_dataframe(self):
+        super().test_apply_in_pandas_not_returning_pandas_dataframe()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_apply_in_pandas_returning_wrong_column_names(self):
+        super().test_apply_in_pandas_returning_wrong_column_names()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
+        super().test_apply_in_pandas_returning_no_column_names_and_wrong_amount()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_wrong_return_type(self):
+        super().test_wrong_return_type()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 47ed12d2f46..1756fcc822a 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -43,7 +43,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class CogroupedApplyInPandasTests(ReusedSQLTestCase):
+class CogroupedApplyInPandasTestsMixin:
     @property
     def data1(self):
         return (
@@ -538,6 +538,10 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
                 self.__test_merge(left, right, by, fn, output_schema)
 
 
+class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import *  # noqa: F401
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

