This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1fbc7948e57 [SPARK-42891][CONNECT][PYTHON] Implement CoGrouped Map API
1fbc7948e57 is described below
commit 1fbc7948e57cbf05a46cb0c7fb2fad4ec25540e6
Author: Xinrong Meng <[email protected]>
AuthorDate: Thu Mar 23 12:38:30 2023 +0900
[SPARK-42891][CONNECT][PYTHON] Implement CoGrouped Map API
### What changes were proposed in this pull request?
Implement the CoGrouped Map API, `applyInPandas`, for Spark Connect.
### Why are the changes needed?
Parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. The CoGrouped Map API is now supported, as shown below:
```python
>>> import pandas as pd
>>> df1 = spark.createDataFrame(
... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0),
... (20000102, 2, 4.0)], ("time", "id", "v1"))
>>>
>>> df2 = spark.createDataFrame(
... [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2"))
>>>
>>> def asof_join(l, r):
... return pd.merge_asof(l, r, on="time", by="id")
...
>>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
... asof_join, schema="time int, id int, v1 double, v2 string"
... ).show()
+--------+---+---+---+
| time| id| v1| v2|
+--------+---+---+---+
|20000101| 1|1.0| x|
|20000102| 1|3.0| x|
|20000101| 2|2.0| y|
|20000102| 2|4.0| y|
+--------+---+---+---+
```
### How was this patch tested?
Parity unit tests.
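The parity suite re-runs the existing cogrouped-map tests against a Spark Connect session. A minimal sketch of that pattern, mirroring the new test file added below (the committed file additionally skips a few cases that rely on `sc._jvm` log4j access):

```python
# Sketch of the parity-test pattern used by this patch: the existing
# cogrouped-map tests become a mixin and are re-run on a Spark Connect session.
# Module and class names are taken from the diff below.
import unittest

from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import CogroupedApplyInPandasTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedConnectTestCase):
    pass


if __name__ == "__main__":
    unittest.main(verbosity=2)
```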
Closes #40487 from xinrong-meng/cogroup_map.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 18 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 22 ++
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/_typing.py | 2 +
python/pyspark/sql/connect/group.py | 49 +++-
python/pyspark/sql/connect/plan.py | 40 ++++
python/pyspark/sql/connect/proto/relations_pb2.py | 246 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 80 +++++++
python/pyspark/sql/pandas/group_ops.py | 9 +
.../sql/tests/connect/test_connect_basic.py | 5 +-
.../connect/test_parity_pandas_cogrouped_map.py | 82 +++++++
.../sql/tests/pandas/test_pandas_cogrouped_map.py | 6 +-
12 files changed, 437 insertions(+), 123 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index aba965082ea..70aa399c424 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -64,6 +64,7 @@ message Relation {
CollectMetrics collect_metrics = 29;
Parse parse = 30;
GroupMap group_map = 31;
+ CoGroupMap co_group_map = 32;
// NA functions
NAFill fill_na = 90;
@@ -800,6 +801,23 @@ message GroupMap {
CommonInlineUserDefinedFunction func = 3;
}
+message CoGroupMap {
+ // (Required) One input relation for CoGroup Map API - applyInPandas.
+ Relation input = 1;
+
+ // Expressions for grouping keys of the first input relation.
+ repeated Expression input_grouping_expressions = 2;
+
+ // (Required) The other input relation.
+ Relation other = 3;
+
+ // Expressions for grouping keys of the other input relation.
+ repeated Expression other_grouping_expressions = 4;
+
+ // (Required) Input user-defined function.
+ CommonInlineUserDefinedFunction func = 5;
+}
+
// Collect arbitrary (named) metrics from a dataset.
message CollectMetrics {
// (Required) The input relation.
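As a side note, the new message can also be populated directly through the regenerated Python bindings. A rough sketch, assuming a single string-named grouping column on each side; the real client builds this message via `pyspark.sql.connect.plan.CoGroupMap`, shown further below:

```python
# Hedged sketch: filling in the new CoGroupMap relation by hand via the
# regenerated protobuf bindings. The nested input/other relations are left
# empty here; in practice they carry the plans of the two cogrouped DataFrames,
# and `func` carries the serialized pandas UDF.
from pyspark.sql.connect.proto import expressions_pb2, relations_pb2

rel = relations_pb2.Relation()
co = rel.co_group_map
co.input.CopyFrom(relations_pb2.Relation())   # plan of the first input
co.other.CopyFrom(relations_pb2.Relation())   # plan of the other input

key = expressions_pb2.Expression()
key.unresolved_attribute.unparsed_identifier = "id"
co.input_grouping_expressions.append(key)
co.other_grouping_expressions.append(key)
```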
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 0faa2f74981..ba210d37231 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -119,6 +119,8 @@ class SparkConnectPlanner(val session: SparkSession) {
transformMapPartitions(rel.getMapPartitions)
case proto.Relation.RelTypeCase.GROUP_MAP =>
transformGroupMap(rel.getGroupMap)
+ case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
+ transformCoGroupMap(rel.getCoGroupMap)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -509,6 +511,26 @@ class SparkConnectPlanner(val session: SparkSession) {
.logicalPlan
}
+ private def transformCoGroupMap(rel: proto.CoGroupMap): LogicalPlan = {
+ val pythonUdf = transformPythonUDF(rel.getFunc)
+
+ val inputCols =
+ rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
+ Column(transformExpression(expr)))
+ val otherCols =
+ rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
+ Column(transformExpression(expr)))
+
+ val input = Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .groupBy(inputCols: _*)
+ val other = Dataset
+ .ofRows(session, transformRelation(rel.getOther))
+ .groupBy(otherCols: _*)
+
+ input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+ }
+
private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f4efb412c87..c3c3b415a1f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -768,6 +768,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_parity_pandas_map",
"pyspark.sql.tests.connect.test_parity_arrow_map",
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
+ "pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map",
# ml doctests
"pyspark.ml.connect.functions",
# ml unittests
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 63aae5d2487..6bdde2926cb 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -61,6 +61,8 @@ PandasGroupedMapFunction = Union[
GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
+PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], DataFrameLike]
+
class UserDefinedFunctionLike(Protocol):
func: Callable[..., Any]
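The new `PandasCogroupedMapFunction` alias only describes the shape of the user function: any callable taking one `pandas.DataFrame` per co-grouped side and returning a `pandas.DataFrame` matches it. An illustrative example (the `merge_groups` helper is hypothetical):

```python
# Illustrative only: a callable compatible with the new
# PandasCogroupedMapFunction alias. It receives one pandas DataFrame per
# co-grouped side and returns a single pandas DataFrame.
import pandas as pd


def merge_groups(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
    return pd.merge(left, right, on="id")
```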
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index a75a50501bd..8377caac592 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -34,6 +34,7 @@ from typing import (
from pyspark.rdd import PythonEvalType
from pyspark.sql.group import GroupedData as PySparkGroupedData
+from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps
from pyspark.sql.types import NumericType
import pyspark.sql.connect.plan as plan
@@ -45,6 +46,7 @@ if TYPE_CHECKING:
LiteralType,
PandasGroupedMapFunction,
GroupedMapPandasUserDefinedFunction,
+ PandasCogroupedMapFunction,
)
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.types import StructType
@@ -263,13 +265,56 @@ class GroupedData:
def applyInPandasWithState(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("applyInPandasWithState() is not implemented.")
- def cogroup(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("cogroup() is not implemented.")
+ def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
+ return PandasCogroupedOps(self, other)
+
+ cogroup.__doc__ = PySparkGroupedData.cogroup.__doc__
GroupedData.__doc__ = PySparkGroupedData.__doc__
+class PandasCogroupedOps:
+ def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
+ self._gd1 = gd1
+ self._gd2 = gd2
+
+ def applyInPandas(
+ self, func: "PandasCogroupedMapFunction", schema: Union["StructType", str]
+ ) -> "DataFrame":
+ from pyspark.sql.connect.udf import UserDefinedFunction
+ from pyspark.sql.connect.dataframe import DataFrame
+
+ udf_obj = UserDefinedFunction(
+ func,
+ returnType=schema,
+ evalType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+ )
+
+ all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
+ return DataFrame.withPlan(
+ plan.CoGroupMap(
+ input=self._gd1._df._plan,
+ input_grouping_cols=self._gd1._grouping_cols,
+ other=self._gd2._df._plan,
+ other_grouping_cols=self._gd2._grouping_cols,
+ function=udf_obj,
+ cols=all_cols,
+ ),
+ session=self._gd1._df._session,
+ )
+
+ applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__
+
+ @staticmethod
+ def _extract_cols(gd: "GroupedData") -> List[Column]:
+ df = gd._df
+ return [df[col] for col in df.columns]
+
+
+PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__
+
+
def _test() -> None:
import sys
import doctest
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index dbfcfea7678..34eaf0d5bee 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1950,6 +1950,46 @@ class GroupMap(LogicalPlan):
return plan
+class CoGroupMap(LogicalPlan):
+ """Logical plan object for a CoGroup Map API: applyInPandas."""
+
+ def __init__(
+ self,
+ input: Optional["LogicalPlan"],
+ input_grouping_cols: Sequence[Column],
+ other: Optional["LogicalPlan"],
+ other_grouping_cols: Sequence[Column],
+ function: "UserDefinedFunction",
+ cols: List[Column],
+ ):
+ assert isinstance(input_grouping_cols, list) and all(
+ isinstance(c, Column) for c in input_grouping_cols
+ )
+ assert isinstance(other_grouping_cols, list) and all(
+ isinstance(c, Column) for c in other_grouping_cols
+ )
+
+ super().__init__(input)
+ self._input_grouping_cols = input_grouping_cols
+ self._other_grouping_cols = other_grouping_cols
+ self._other = cast(LogicalPlan, other)
+ self._func = function._build_common_inline_user_defined_function(*cols)
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+ plan = self._create_proto_relation()
+ plan.co_group_map.input.CopyFrom(self._child.plan(session))
+ plan.co_group_map.input_grouping_expressions.extend(
+ [c.to_plan(session) for c in self._input_grouping_cols]
+ )
+ plan.co_group_map.other.CopyFrom(self._other.plan(session))
+ plan.co_group_map.other_grouping_expressions.extend(
+ [c.to_plan(session) for c in self._other_grouping_cols]
+ )
+ plan.co_group_map.func.CopyFrom(self._func.to_plan_udf(session))
+ return plan
+
+
class CachedRelation(LogicalPlan):
def __init__(self, plan: proto.Relation) -> None:
super(CachedRelation, self).__init__(None)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index aa6d39cd4f0..60ec7081e09 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf0\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -93,6 +93,7 @@ _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
_REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
_MAPPARTITIONS = DESCRIPTOR.message_types_by_name["MapPartitions"]
_GROUPMAP = DESCRIPTOR.message_types_by_name["GroupMap"]
+_COGROUPMAP = DESCRIPTOR.message_types_by_name["CoGroupMap"]
_COLLECTMETRICS = DESCRIPTOR.message_types_by_name["CollectMetrics"]
_PARSE = DESCRIPTOR.message_types_by_name["Parse"]
_PARSE_OPTIONSENTRY = _PARSE.nested_types_by_name["OptionsEntry"]
@@ -652,6 +653,17 @@ GroupMap = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(GroupMap)
+CoGroupMap = _reflection.GeneratedProtocolMessageType(
+ "CoGroupMap",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _COGROUPMAP,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.CoGroupMap)
+ },
+)
+_sym_db.RegisterMessage(CoGroupMap)
+
CollectMetrics = _reflection.GeneratedProtocolMessageType(
"CollectMetrics",
(_message.Message,),
@@ -697,119 +709,121 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_PARSE_OPTIONSENTRY._options = None
_PARSE_OPTIONSENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 165
- _RELATION._serialized_end = 2709
- _UNKNOWN._serialized_start = 2711
- _UNKNOWN._serialized_end = 2720
- _RELATIONCOMMON._serialized_start = 2722
- _RELATIONCOMMON._serialized_end = 2813
- _SQL._serialized_start = 2816
- _SQL._serialized_end = 2950
- _SQL_ARGSENTRY._serialized_start = 2895
- _SQL_ARGSENTRY._serialized_end = 2950
- _READ._serialized_start = 2953
- _READ._serialized_end = 3449
- _READ_NAMEDTABLE._serialized_start = 3095
- _READ_NAMEDTABLE._serialized_end = 3156
- _READ_DATASOURCE._serialized_start = 3159
- _READ_DATASOURCE._serialized_end = 3436
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3356
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3414
- _PROJECT._serialized_start = 3451
- _PROJECT._serialized_end = 3568
- _FILTER._serialized_start = 3570
- _FILTER._serialized_end = 3682
- _JOIN._serialized_start = 3685
- _JOIN._serialized_end = 4156
- _JOIN_JOINTYPE._serialized_start = 3948
- _JOIN_JOINTYPE._serialized_end = 4156
- _SETOPERATION._serialized_start = 4159
- _SETOPERATION._serialized_end = 4638
- _SETOPERATION_SETOPTYPE._serialized_start = 4475
- _SETOPERATION_SETOPTYPE._serialized_end = 4589
- _LIMIT._serialized_start = 4640
- _LIMIT._serialized_end = 4716
- _OFFSET._serialized_start = 4718
- _OFFSET._serialized_end = 4797
- _TAIL._serialized_start = 4799
- _TAIL._serialized_end = 4874
- _AGGREGATE._serialized_start = 4877
- _AGGREGATE._serialized_end = 5459
- _AGGREGATE_PIVOT._serialized_start = 5216
- _AGGREGATE_PIVOT._serialized_end = 5327
- _AGGREGATE_GROUPTYPE._serialized_start = 5330
- _AGGREGATE_GROUPTYPE._serialized_end = 5459
- _SORT._serialized_start = 5462
- _SORT._serialized_end = 5622
- _DROP._serialized_start = 5625
- _DROP._serialized_end = 5766
- _DEDUPLICATE._serialized_start = 5769
- _DEDUPLICATE._serialized_end = 5940
- _LOCALRELATION._serialized_start = 5942
- _LOCALRELATION._serialized_end = 6031
- _SAMPLE._serialized_start = 6034
- _SAMPLE._serialized_end = 6307
- _RANGE._serialized_start = 6310
- _RANGE._serialized_end = 6455
- _SUBQUERYALIAS._serialized_start = 6457
- _SUBQUERYALIAS._serialized_end = 6571
- _REPARTITION._serialized_start = 6574
- _REPARTITION._serialized_end = 6716
- _SHOWSTRING._serialized_start = 6719
- _SHOWSTRING._serialized_end = 6861
- _STATSUMMARY._serialized_start = 6863
- _STATSUMMARY._serialized_end = 6955
- _STATDESCRIBE._serialized_start = 6957
- _STATDESCRIBE._serialized_end = 7038
- _STATCROSSTAB._serialized_start = 7040
- _STATCROSSTAB._serialized_end = 7141
- _STATCOV._serialized_start = 7143
- _STATCOV._serialized_end = 7239
- _STATCORR._serialized_start = 7242
- _STATCORR._serialized_end = 7379
- _STATAPPROXQUANTILE._serialized_start = 7382
- _STATAPPROXQUANTILE._serialized_end = 7546
- _STATFREQITEMS._serialized_start = 7548
- _STATFREQITEMS._serialized_end = 7673
- _STATSAMPLEBY._serialized_start = 7676
- _STATSAMPLEBY._serialized_end = 7985
- _STATSAMPLEBY_FRACTION._serialized_start = 7877
- _STATSAMPLEBY_FRACTION._serialized_end = 7976
- _NAFILL._serialized_start = 7988
- _NAFILL._serialized_end = 8122
- _NADROP._serialized_start = 8125
- _NADROP._serialized_end = 8259
- _NAREPLACE._serialized_start = 8262
- _NAREPLACE._serialized_end = 8558
- _NAREPLACE_REPLACEMENT._serialized_start = 8417
- _NAREPLACE_REPLACEMENT._serialized_end = 8558
- _TODF._serialized_start = 8560
- _TODF._serialized_end = 8648
- _WITHCOLUMNSRENAMED._serialized_start = 8651
- _WITHCOLUMNSRENAMED._serialized_end = 8890
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8823
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8890
- _WITHCOLUMNS._serialized_start = 8892
- _WITHCOLUMNS._serialized_end = 9011
- _HINT._serialized_start = 9014
- _HINT._serialized_end = 9146
- _UNPIVOT._serialized_start = 9149
- _UNPIVOT._serialized_end = 9476
- _UNPIVOT_VALUES._serialized_start = 9406
- _UNPIVOT_VALUES._serialized_end = 9465
- _TOSCHEMA._serialized_start = 9478
- _TOSCHEMA._serialized_end = 9584
- _REPARTITIONBYEXPRESSION._serialized_start = 9587
- _REPARTITIONBYEXPRESSION._serialized_end = 9790
- _MAPPARTITIONS._serialized_start = 9793
- _MAPPARTITIONS._serialized_end = 9923
- _GROUPMAP._serialized_start = 9926
- _GROUPMAP._serialized_end = 10129
- _COLLECTMETRICS._serialized_start = 10132
- _COLLECTMETRICS._serialized_end = 10268
- _PARSE._serialized_start = 10271
- _PARSE._serialized_end = 10659
- _PARSE_OPTIONSENTRY._serialized_start = 3356
- _PARSE_OPTIONSENTRY._serialized_end = 3414
- _PARSE_PARSEFORMAT._serialized_start = 10560
- _PARSE_PARSEFORMAT._serialized_end = 10648
+ _RELATION._serialized_end = 2772
+ _UNKNOWN._serialized_start = 2774
+ _UNKNOWN._serialized_end = 2783
+ _RELATIONCOMMON._serialized_start = 2785
+ _RELATIONCOMMON._serialized_end = 2876
+ _SQL._serialized_start = 2879
+ _SQL._serialized_end = 3013
+ _SQL_ARGSENTRY._serialized_start = 2958
+ _SQL_ARGSENTRY._serialized_end = 3013
+ _READ._serialized_start = 3016
+ _READ._serialized_end = 3512
+ _READ_NAMEDTABLE._serialized_start = 3158
+ _READ_NAMEDTABLE._serialized_end = 3219
+ _READ_DATASOURCE._serialized_start = 3222
+ _READ_DATASOURCE._serialized_end = 3499
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3419
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3477
+ _PROJECT._serialized_start = 3514
+ _PROJECT._serialized_end = 3631
+ _FILTER._serialized_start = 3633
+ _FILTER._serialized_end = 3745
+ _JOIN._serialized_start = 3748
+ _JOIN._serialized_end = 4219
+ _JOIN_JOINTYPE._serialized_start = 4011
+ _JOIN_JOINTYPE._serialized_end = 4219
+ _SETOPERATION._serialized_start = 4222
+ _SETOPERATION._serialized_end = 4701
+ _SETOPERATION_SETOPTYPE._serialized_start = 4538
+ _SETOPERATION_SETOPTYPE._serialized_end = 4652
+ _LIMIT._serialized_start = 4703
+ _LIMIT._serialized_end = 4779
+ _OFFSET._serialized_start = 4781
+ _OFFSET._serialized_end = 4860
+ _TAIL._serialized_start = 4862
+ _TAIL._serialized_end = 4937
+ _AGGREGATE._serialized_start = 4940
+ _AGGREGATE._serialized_end = 5522
+ _AGGREGATE_PIVOT._serialized_start = 5279
+ _AGGREGATE_PIVOT._serialized_end = 5390
+ _AGGREGATE_GROUPTYPE._serialized_start = 5393
+ _AGGREGATE_GROUPTYPE._serialized_end = 5522
+ _SORT._serialized_start = 5525
+ _SORT._serialized_end = 5685
+ _DROP._serialized_start = 5688
+ _DROP._serialized_end = 5829
+ _DEDUPLICATE._serialized_start = 5832
+ _DEDUPLICATE._serialized_end = 6003
+ _LOCALRELATION._serialized_start = 6005
+ _LOCALRELATION._serialized_end = 6094
+ _SAMPLE._serialized_start = 6097
+ _SAMPLE._serialized_end = 6370
+ _RANGE._serialized_start = 6373
+ _RANGE._serialized_end = 6518
+ _SUBQUERYALIAS._serialized_start = 6520
+ _SUBQUERYALIAS._serialized_end = 6634
+ _REPARTITION._serialized_start = 6637
+ _REPARTITION._serialized_end = 6779
+ _SHOWSTRING._serialized_start = 6782
+ _SHOWSTRING._serialized_end = 6924
+ _STATSUMMARY._serialized_start = 6926
+ _STATSUMMARY._serialized_end = 7018
+ _STATDESCRIBE._serialized_start = 7020
+ _STATDESCRIBE._serialized_end = 7101
+ _STATCROSSTAB._serialized_start = 7103
+ _STATCROSSTAB._serialized_end = 7204
+ _STATCOV._serialized_start = 7206
+ _STATCOV._serialized_end = 7302
+ _STATCORR._serialized_start = 7305
+ _STATCORR._serialized_end = 7442
+ _STATAPPROXQUANTILE._serialized_start = 7445
+ _STATAPPROXQUANTILE._serialized_end = 7609
+ _STATFREQITEMS._serialized_start = 7611
+ _STATFREQITEMS._serialized_end = 7736
+ _STATSAMPLEBY._serialized_start = 7739
+ _STATSAMPLEBY._serialized_end = 8048
+ _STATSAMPLEBY_FRACTION._serialized_start = 7940
+ _STATSAMPLEBY_FRACTION._serialized_end = 8039
+ _NAFILL._serialized_start = 8051
+ _NAFILL._serialized_end = 8185
+ _NADROP._serialized_start = 8188
+ _NADROP._serialized_end = 8322
+ _NAREPLACE._serialized_start = 8325
+ _NAREPLACE._serialized_end = 8621
+ _NAREPLACE_REPLACEMENT._serialized_start = 8480
+ _NAREPLACE_REPLACEMENT._serialized_end = 8621
+ _TODF._serialized_start = 8623
+ _TODF._serialized_end = 8711
+ _WITHCOLUMNSRENAMED._serialized_start = 8714
+ _WITHCOLUMNSRENAMED._serialized_end = 8953
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8886
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8953
+ _WITHCOLUMNS._serialized_start = 8955
+ _WITHCOLUMNS._serialized_end = 9074
+ _HINT._serialized_start = 9077
+ _HINT._serialized_end = 9209
+ _UNPIVOT._serialized_start = 9212
+ _UNPIVOT._serialized_end = 9539
+ _UNPIVOT_VALUES._serialized_start = 9469
+ _UNPIVOT_VALUES._serialized_end = 9528
+ _TOSCHEMA._serialized_start = 9541
+ _TOSCHEMA._serialized_end = 9647
+ _REPARTITIONBYEXPRESSION._serialized_start = 9650
+ _REPARTITIONBYEXPRESSION._serialized_end = 9853
+ _MAPPARTITIONS._serialized_start = 9856
+ _MAPPARTITIONS._serialized_end = 9986
+ _GROUPMAP._serialized_start = 9989
+ _GROUPMAP._serialized_end = 10192
+ _COGROUPMAP._serialized_start = 10195
+ _COGROUPMAP._serialized_end = 10547
+ _COLLECTMETRICS._serialized_start = 10550
+ _COLLECTMETRICS._serialized_end = 10686
+ _PARSE._serialized_start = 10689
+ _PARSE._serialized_end = 11077
+ _PARSE_OPTIONSENTRY._serialized_start = 3419
+ _PARSE_OPTIONSENTRY._serialized_end = 3477
+ _PARSE_PARSEFORMAT._serialized_start = 10978
+ _PARSE_PARSEFORMAT._serialized_end = 11066
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 6ae4a323f6f..814b23ca26d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -93,6 +93,7 @@ class Relation(google.protobuf.message.Message):
COLLECT_METRICS_FIELD_NUMBER: builtins.int
PARSE_FIELD_NUMBER: builtins.int
GROUP_MAP_FIELD_NUMBER: builtins.int
+ CO_GROUP_MAP_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -170,6 +171,8 @@ class Relation(google.protobuf.message.Message):
@property
def group_map(self) -> global___GroupMap: ...
@property
+ def co_group_map(self) -> global___CoGroupMap: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -237,6 +240,7 @@ class Relation(google.protobuf.message.Message):
collect_metrics: global___CollectMetrics | None = ...,
parse: global___Parse | None = ...,
group_map: global___GroupMap | None = ...,
+ co_group_map: global___CoGroupMap | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -261,6 +265,8 @@ class Relation(google.protobuf.message.Message):
b"approx_quantile",
"catalog",
b"catalog",
+ "co_group_map",
+ b"co_group_map",
"collect_metrics",
b"collect_metrics",
"common",
@@ -358,6 +364,8 @@ class Relation(google.protobuf.message.Message):
b"approx_quantile",
"catalog",
b"catalog",
+ "co_group_map",
+ b"co_group_map",
"collect_metrics",
b"collect_metrics",
"common",
@@ -479,6 +487,7 @@ class Relation(google.protobuf.message.Message):
"collect_metrics",
"parse",
"group_map",
+ "co_group_map",
"fill_na",
"drop_na",
"replace",
@@ -2784,6 +2793,77 @@ class GroupMap(google.protobuf.message.Message):
global___GroupMap = GroupMap
+class CoGroupMap(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ INPUT_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+ OTHER_FIELD_NUMBER: builtins.int
+ OTHER_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+ FUNC_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """(Required) One input relation for CoGroup Map API - applyInPandas."""
+ @property
+ def input_grouping_expressions(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """Expressions for grouping keys of the first input relation."""
+ @property
+ def other(self) -> global___Relation:
+ """(Required) The other input relation."""
+ @property
+ def other_grouping_expressions(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """Expressions for grouping keys of the other input relation."""
+ @property
+ def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
+ """(Required) Input user-defined function."""
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ input_grouping_expressions: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]
+ | None = ...,
+ other: global___Relation | None = ...,
+ other_grouping_expressions: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]
+ | None = ...,
+ func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+ | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "func", b"func", "input", b"input", "other", b"other"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "func",
+ b"func",
+ "input",
+ b"input",
+ "input_grouping_expressions",
+ b"input_grouping_expressions",
+ "other",
+ b"other",
+ "other_grouping_expressions",
+ b"other_grouping_expressions",
+ ],
+ ) -> None: ...
+
+global___CoGroupMap = CoGroupMap
+
class CollectMetrics(google.protobuf.message.Message):
"""Collect arbitrary (named) metrics from a dataset."""
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index f03aa35bb83..12ceb56c79f 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -353,6 +353,9 @@ class PandasGroupedOpsMixin:
.. versionadded:: 3.0.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
See :class:`PandasCogroupedOps` for the operations that can be run.
"""
from pyspark.sql import GroupedData
@@ -369,6 +372,9 @@ class PandasCogroupedOps:
.. versionadded:: 3.0.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Notes
-----
This API is experimental.
@@ -400,6 +406,9 @@ class PandasCogroupedOps:
.. versionadded:: 3.0.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
func : function
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f911ca9ba78..682b3471a74 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2841,10 +2841,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
def test_unsupported_group_functions(self):
# SPARK-41927: Disable unsupported functions.
cg = self.connect.read.table(self.tbl_name).groupBy("id")
- for f in (
- "applyInPandasWithState",
- "cogroup",
- ):
+ for f in ("applyInPandasWithState",):
with self.assertRaises(NotImplementedError):
getattr(cg, f)()
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py
new file mode 100644
index 00000000000..c03bc5f8219
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_cogrouped_map.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import CogroupedApplyInPandasTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedConnectTestCase):
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_different_group_key_cardinality(self):
+ super().test_different_group_key_cardinality()
+
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_returning_incompatible_type(self):
+ super().test_apply_in_pandas_returning_incompatible_type()
+
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_wrong_args(self):
+ super().test_wrong_args()
+
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_not_returning_pandas_dataframe(self):
+ super().test_apply_in_pandas_not_returning_pandas_dataframe()
+
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_returning_wrong_column_names(self):
+ super().test_apply_in_pandas_returning_wrong_column_names()
+
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
+ super().test_apply_in_pandas_returning_no_column_names_and_wrong_amount()
+
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_returning_incompatible_type(self):
+ super().test_apply_in_pandas_returning_incompatible_type()
+
+ @unittest.skip(
+ "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_wrong_return_type(self):
+ super().test_wrong_return_type()
+
+
+if __name__ == "__main__":
+ from pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map import *  # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 47ed12d2f46..1756fcc822a 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -43,7 +43,7 @@ if have_pyarrow:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
-class CogroupedApplyInPandasTests(ReusedSQLTestCase):
+class CogroupedApplyInPandasTestsMixin:
@property
def data1(self):
return (
@@ -538,6 +538,10 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
self.__test_merge(left, right, by, fn, output_schema)
+class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
+ pass
+
+
if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import *  # noqa: F401
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]