This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new fee47c77e5d  [SPARK-42340][CONNECT][PYTHON] Implement Grouped Map API
fee47c77e5d is described below

commit fee47c77e5d31bce592bf6e2bd33c2dabfc57bd3
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Mon Mar 20 20:04:00 2023 +0900

[SPARK-42340][CONNECT][PYTHON] Implement Grouped Map API

### What changes were proposed in this pull request?

Implement the Grouped Map API: `GroupedData.applyInPandas` and `GroupedData.apply`.

### Why are the changes needed?

Parity with vanilla PySpark.

### Does this PR introduce _any_ user-facing change?

Yes. `GroupedData.applyInPandas` and `GroupedData.apply` are now supported, as shown below.

```python
>>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> def normalize(pdf):
...     v = pdf.v
...     return pdf.assign(v=(v - v.mean()) / v.std())
...
>>> df.groupby("id").applyInPandas(normalize, schema="id long, v double").show()
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+
```

```python
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
... def normalize(pdf):
...     v = pdf.v
...     return pdf.assign(v=(v - v.mean()) / v.std())
...
>>> df.groupby("id").apply(normalize).show()
/Users/xinrong.meng/spark/python/pyspark/sql/connect/group.py:228: UserWarning: It is preferred to use 'applyInPandas' over this API. This API will be deprecated in the future releases. See SPARK-28264 for more details.
  warnings.warn(
+---+-------------------+
| id|                  v|
+---+-------------------+
|  1|-0.7071067811865475|
|  1| 0.7071067811865475|
|  2|-0.8320502943378437|
|  2|-0.2773500981126146|
|  2| 1.1094003924504583|
+---+-------------------+
```
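`applyInPandas` also accepts a two-argument function that receives the grouping key as a tuple. A minimal sketch, reusing `df` from above (this example is illustrative and not from the original commit message; the output shown is what this data should produce):

```python
>>> import pandas as pd
>>> def mean_v(key, pdf):
...     # key is a tuple of the grouping values, e.g. (1,) for the id == 1 group
...     return pd.DataFrame([key + (pdf.v.mean(),)])
...
>>> df.groupby("id").applyInPandas(mean_v, schema="id long, v double").show()
+---+---+
| id|  v|
+---+---+
|  1|1.5|
|  2|6.0|
+---+---+
```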
### How was this patch tested?

(Parity) Unit tests.

Closes #40405 from xinrong-meng/group_map.

Authored-by: Xinrong Meng <xinr...@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  12 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  14 ++
 dev/sparktestsupport/modules.py                    |   2 +-
 python/pyspark/sql/connect/_typing.py              |  10 +-
 python/pyspark/sql/connect/group.py                |  61 +++++-
 python/pyspark/sql/connect/plan.py                 |  27 +++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 242 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  51 +++++
 python/pyspark/sql/pandas/group_ops.py             |   6 +
 .../sql/tests/connect/test_connect_basic.py        |   2 -
 .../connect/test_parity_pandas_grouped_map.py      | 102 +++++++++
 .../sql/tests/connect/test_parity_pandas_udf.py    |   5 -
 .../sql/tests/pandas/test_pandas_grouped_map.py    |  16 +-
 13 files changed, 416 insertions(+), 134 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 69451e7b76e..aba965082ea 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -63,6 +63,7 @@ message Relation {
     MapPartitions map_partitions = 28;
     CollectMetrics collect_metrics = 29;
     Parse parse = 30;
+    GroupMap group_map = 31;
 
     // NA functions
     NAFill fill_na = 90;
@@ -788,6 +789,17 @@ message MapPartitions {
   CommonInlineUserDefinedFunction func = 2;
 }
 
+message GroupMap {
+  // (Required) Input relation for Group Map API: apply, applyInPandas.
+  Relation input = 1;
+
+  // (Required) Expressions for grouping keys.
+  repeated Expression grouping_expressions = 2;
+
+  // (Required) Input user-defined function.
+  CommonInlineUserDefinedFunction func = 3;
+}
+
 // Collect arbitrary (named) metrics from a dataset.
 message CollectMetrics {
   // (Required) The input relation.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b023adac98a..c8fdaa6641a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -117,6 +117,8 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformRepartitionByExpression(rel.getRepartitionByExpression)
       case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
         transformMapPartitions(rel.getMapPartitions)
+      case proto.Relation.RelTypeCase.GROUP_MAP =>
+        transformGroupMap(rel.getGroupMap)
       case proto.Relation.RelTypeCase.COLLECT_METRICS =>
         transformCollectMetrics(rel.getCollectMetrics)
       case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -495,6 +497,18 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
+    val pythonUdf = transformPythonUDF(rel.getFunc)
+    val cols =
+      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .groupBy(cols: _*)
+      .flatMapGroupsInPandas(pythonUdf)
+      .logicalPlan
+  }
+
   private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 5379f883815..c31a9362cd7 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -513,7 +513,6 @@ pyspark_sql = Module(
     ],
 )
 
-
 pyspark_resource = Module(
     name="pyspark-resource",
     dependencies=[pyspark_core],
@@ -779,6 +778,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_pandas_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_map",
         "pyspark.sql.tests.connect.test_parity_arrow_map",
+        "pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
         # ml doctests
         "pyspark.ml.connect.functions",
         # ml unittests
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 6df3f15d87d..63aae5d2487 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -22,7 +22,8 @@ if sys.version_info >= (3, 8):
 else:
     from typing_extensions import Protocol
 
-from typing import Any, Callable, Iterable, Union, Optional
+from types import FunctionType
+from typing import Any, Callable, Iterable, Union, Optional, NewType
 
 import datetime
 import decimal
@@ -53,6 +54,13 @@ PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]]
 
 ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]
 
+PandasGroupedMapFunction = Union[
+    Callable[[DataFrameLike], DataFrameLike],
+    Callable[[Any, DataFrameLike], DataFrameLike],
+]
+
+GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
+
 
 class UserDefinedFunctionLike(Protocol):
     func: Callable[..., Any]
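The `PandasGroupedMapFunction` union above admits both call shapes that `applyInPandas` accepts. A minimal sketch (the function names here are illustrative, not from the patch):

```python
import pandas as pd

def normalize(pdf: pd.DataFrame) -> pd.DataFrame:
    # Matches Callable[[DataFrameLike], DataFrameLike]: receives one group.
    return pdf.assign(v=(pdf.v - pdf.v.mean()) / pdf.v.std())

def demean(key: tuple, pdf: pd.DataFrame) -> pd.DataFrame:
    # Matches Callable[[Any, DataFrameLike], DataFrameLike]: the grouping key
    # arrives first, as a tuple of the grouping values for this group.
    return pdf.assign(v=pdf.v - pdf.v.mean())
```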
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index e699ce7105a..a75a50501bd 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import warnings
+
 from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
@@ -30,6 +32,7 @@ from typing import (
     cast,
 )
 
+from pyspark.rdd import PythonEvalType
 from pyspark.sql.group import GroupedData as PySparkGroupedData
 from pyspark.sql.types import NumericType
 
@@ -38,8 +41,13 @@ from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.functions import _invoke_function, col, lit
 
 if TYPE_CHECKING:
-    from pyspark.sql.connect._typing import LiteralType
+    from pyspark.sql.connect._typing import (
+        LiteralType,
+        PandasGroupedMapFunction,
+        GroupedMapPandasUserDefinedFunction,
+    )
     from pyspark.sql.connect.dataframe import DataFrame
+    from pyspark.sql.types import StructType
 
 
 class GroupedData:
@@ -203,11 +211,54 @@ class GroupedData:
 
     pivot.__doc__ = PySparkGroupedData.pivot.__doc__
 
-    def apply(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("apply() is not implemented.")
+    def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
+        # Columns are special because hasattr always return True
+        if (
+            isinstance(udf, Column)
+            or not hasattr(udf, "func")
+            or (
+                udf.evalType  # type: ignore[attr-defined]
+                != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+            )
+        ):
+            raise ValueError(
+                "Invalid udf: the udf argument must be a pandas_udf of type " "GROUPED_MAP."
+            )
+
+        warnings.warn(
+            "It is preferred to use 'applyInPandas' over this "
+            "API. This API will be deprecated in the future releases. See SPARK-28264 for "
+            "more details.",
+            UserWarning,
+        )
+
+        return self.applyInPandas(udf.func, schema=udf.returnType)  # type: ignore[attr-defined]
+
+    apply.__doc__ = PySparkGroupedData.apply.__doc__
+
+    def applyInPandas(
+        self, func: "PandasGroupedMapFunction", schema: Union["StructType", str]
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        udf_obj = UserDefinedFunction(
+            func,
+            returnType=schema,
+            evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+        )
+
+        return DataFrame.withPlan(
+            plan.GroupMap(
+                child=self._df._plan,
+                grouping_cols=self._grouping_cols,
+                function=udf_obj,
+                cols=self._df.columns,
+            ),
+            session=self._df._session,
+        )
 
-    def applyInPandas(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("applyInPandas() is not implemented.")
+    applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__
 
     def applyInPandasWithState(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("applyInPandasWithState() is not implemented.")
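To make the client-side check in `apply()` above concrete: a plain function that was not decorated with `pandas_udf(..., PandasUDFType.GROUPED_MAP)` has no `func`/`evalType` attributes and is rejected before any plan is sent to the server. A small sketch (`df` as in the commit-message examples):

```python
def not_a_pandas_udf(pdf):
    return pdf

try:
    df.groupby("id").apply(not_a_pandas_udf)
except ValueError as e:
    print(e)
    # Invalid udf: the udf argument must be a pandas_udf of type GROUPED_MAP.
```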
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 9807c9722a6..dbfcfea7678 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1923,6 +1923,33 @@ class MapPartitions(LogicalPlan):
         return plan
 
 
+class GroupMap(LogicalPlan):
+    """Logical plan object for a Group Map API: apply, applyInPandas."""
+
+    def __init__(
+        self,
+        child: Optional["LogicalPlan"],
+        grouping_cols: Sequence[Column],
+        function: "UserDefinedFunction",
+        cols: List[str],
+    ):
+        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
+
+        super().__init__(child)
+        self._grouping_cols = grouping_cols
+        self._func = function._build_common_inline_user_defined_function(*cols)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.group_map.input.CopyFrom(self._child.plan(session))
+        plan.group_map.grouping_expressions.extend(
+            [c.to_plan(session) for c in self._grouping_cols]
+        )
+        plan.group_map.func.CopyFrom(self._func.to_plan_udf(session))
+        return plan
+
+
 class CachedRelation(LogicalPlan):
     def __init__(self, plan: proto.Relation) -> None:
         super(CachedRelation, self).__init__(None)
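Putting `GroupMap.__init__` and `GroupMap.plan()` together, the client fills the three required fields of the new relation message roughly as follows (a sketch, not part of the patch; `session`, `child_plan`, `grouping_cols`, `cols`, and a `udf_obj` built as in `applyInPandas` are assumed to exist):

```python
import pyspark.sql.connect.proto as proto

rel = proto.Relation()
# (Required) The relation being grouped.
rel.group_map.input.CopyFrom(child_plan.plan(session))
# (Required) One expression per grouping key.
rel.group_map.grouping_expressions.extend([c.to_plan(session) for c in grouping_cols])
# (Required) The GROUPED_MAP pandas UDF, inlined with the input column names.
func = udf_obj._build_common_inline_user_defined_function(*cols)
rel.group_map.func.CopyFrom(func.to_plan_udf(session))
```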
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 521a10f214c..aa6d39cd4f0 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catalog__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb8\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf0\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
@@ -92,6 +92,7 @@ _UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
 _REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
 _MAPPARTITIONS = DESCRIPTOR.message_types_by_name["MapPartitions"]
+_GROUPMAP = DESCRIPTOR.message_types_by_name["GroupMap"]
 _COLLECTMETRICS = DESCRIPTOR.message_types_by_name["CollectMetrics"]
 _PARSE = DESCRIPTOR.message_types_by_name["Parse"]
 _PARSE_OPTIONSENTRY = _PARSE.nested_types_by_name["OptionsEntry"]
@@ -640,6 +641,17 @@ MapPartitions = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(MapPartitions)
 
+GroupMap = _reflection.GeneratedProtocolMessageType(
+    "GroupMap",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _GROUPMAP,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.GroupMap)
+    },
+)
+_sym_db.RegisterMessage(GroupMap)
+
 CollectMetrics = _reflection.GeneratedProtocolMessageType(
     "CollectMetrics",
     (_message.Message,),
@@ -685,117 +697,119 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _PARSE_OPTIONSENTRY._options = None
     _PARSE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2653
-    _UNKNOWN._serialized_start = 2655
-    _UNKNOWN._serialized_end = 2664
-    _RELATIONCOMMON._serialized_start = 2666
-    _RELATIONCOMMON._serialized_end = 2757
-    _SQL._serialized_start = 2760
-    _SQL._serialized_end = 2894
-    _SQL_ARGSENTRY._serialized_start = 2839
-    _SQL_ARGSENTRY._serialized_end = 2894
-    _READ._serialized_start = 2897
-    _READ._serialized_end = 3393
-    _READ_NAMEDTABLE._serialized_start = 3039
-    _READ_NAMEDTABLE._serialized_end = 3100
-    _READ_DATASOURCE._serialized_start = 3103
-    _READ_DATASOURCE._serialized_end = 3380
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3300
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3358
-    _PROJECT._serialized_start = 3395
-    _PROJECT._serialized_end = 3512
-    _FILTER._serialized_start = 3514
-    _FILTER._serialized_end = 3626
-    _JOIN._serialized_start = 3629
-    _JOIN._serialized_end = 4100
-    _JOIN_JOINTYPE._serialized_start = 3892
-    _JOIN_JOINTYPE._serialized_end = 4100
-    _SETOPERATION._serialized_start = 4103
-    _SETOPERATION._serialized_end = 4582
-    _SETOPERATION_SETOPTYPE._serialized_start = 4419
-    _SETOPERATION_SETOPTYPE._serialized_end = 4533
-    _LIMIT._serialized_start = 4584
-    _LIMIT._serialized_end = 4660
-    _OFFSET._serialized_start = 4662
-    _OFFSET._serialized_end = 4741
-    _TAIL._serialized_start = 4743
-    _TAIL._serialized_end = 4818
-    _AGGREGATE._serialized_start = 4821
-    _AGGREGATE._serialized_end = 5403
-    _AGGREGATE_PIVOT._serialized_start = 5160
-    _AGGREGATE_PIVOT._serialized_end = 5271
-    _AGGREGATE_GROUPTYPE._serialized_start = 5274
-    _AGGREGATE_GROUPTYPE._serialized_end = 5403
-    _SORT._serialized_start = 5406
-    _SORT._serialized_end = 5566
-    _DROP._serialized_start = 5569
-    _DROP._serialized_end = 5710
-    _DEDUPLICATE._serialized_start = 5713
-    _DEDUPLICATE._serialized_end = 5884
-    _LOCALRELATION._serialized_start = 5886
-    _LOCALRELATION._serialized_end = 5975
-    _SAMPLE._serialized_start = 5978
-    _SAMPLE._serialized_end = 6251
-    _RANGE._serialized_start = 6254
-    _RANGE._serialized_end = 6399
-    _SUBQUERYALIAS._serialized_start = 6401
-    _SUBQUERYALIAS._serialized_end = 6515
-    _REPARTITION._serialized_start = 6518
-    _REPARTITION._serialized_end = 6660
-    _SHOWSTRING._serialized_start = 6663
-    _SHOWSTRING._serialized_end = 6805
-    _STATSUMMARY._serialized_start = 6807
-    _STATSUMMARY._serialized_end = 6899
-    _STATDESCRIBE._serialized_start = 6901
-    _STATDESCRIBE._serialized_end = 6982
-    _STATCROSSTAB._serialized_start = 6984
-    _STATCROSSTAB._serialized_end = 7085
-    _STATCOV._serialized_start = 7087
-    _STATCOV._serialized_end = 7183
-    _STATCORR._serialized_start = 7186
-    _STATCORR._serialized_end = 7323
-    _STATAPPROXQUANTILE._serialized_start = 7326
-    _STATAPPROXQUANTILE._serialized_end = 7490
-    _STATFREQITEMS._serialized_start = 7492
-    _STATFREQITEMS._serialized_end = 7617
-    _STATSAMPLEBY._serialized_start = 7620
-    _STATSAMPLEBY._serialized_end = 7929
-    _STATSAMPLEBY_FRACTION._serialized_start = 7821
-    _STATSAMPLEBY_FRACTION._serialized_end = 7920
-    _NAFILL._serialized_start = 7932
-    _NAFILL._serialized_end = 8066
-    _NADROP._serialized_start = 8069
-    _NADROP._serialized_end = 8203
-    _NAREPLACE._serialized_start = 8206
-    _NAREPLACE._serialized_end = 8502
-    _NAREPLACE_REPLACEMENT._serialized_start = 8361
-    _NAREPLACE_REPLACEMENT._serialized_end = 8502
-    _TODF._serialized_start = 8504
-    _TODF._serialized_end = 8592
-    _WITHCOLUMNSRENAMED._serialized_start = 8595
-    _WITHCOLUMNSRENAMED._serialized_end = 8834
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8767
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8834
-    _WITHCOLUMNS._serialized_start = 8836
-    _WITHCOLUMNS._serialized_end = 8955
-    _HINT._serialized_start = 8958
-    _HINT._serialized_end = 9090
-    _UNPIVOT._serialized_start = 9093
-    _UNPIVOT._serialized_end = 9420
-    _UNPIVOT_VALUES._serialized_start = 9350
-    _UNPIVOT_VALUES._serialized_end = 9409
-    _TOSCHEMA._serialized_start = 9422
-    _TOSCHEMA._serialized_end = 9528
-    _REPARTITIONBYEXPRESSION._serialized_start = 9531
-    _REPARTITIONBYEXPRESSION._serialized_end = 9734
-    _MAPPARTITIONS._serialized_start = 9737
-    _MAPPARTITIONS._serialized_end = 9867
-    _COLLECTMETRICS._serialized_start = 9870
-    _COLLECTMETRICS._serialized_end = 10006
-    _PARSE._serialized_start = 10009
-    _PARSE._serialized_end = 10397
-    _PARSE_OPTIONSENTRY._serialized_start = 3300
-    _PARSE_OPTIONSENTRY._serialized_end = 3358
-    _PARSE_PARSEFORMAT._serialized_start = 10298
-    _PARSE_PARSEFORMAT._serialized_end = 10386
+    _RELATION._serialized_end = 2709
+    _UNKNOWN._serialized_start = 2711
+    _UNKNOWN._serialized_end = 2720
+    _RELATIONCOMMON._serialized_start = 2722
+    _RELATIONCOMMON._serialized_end = 2813
+    _SQL._serialized_start = 2816
+    _SQL._serialized_end = 2950
+    _SQL_ARGSENTRY._serialized_start = 2895
+    _SQL_ARGSENTRY._serialized_end = 2950
+    _READ._serialized_start = 2953
+    _READ._serialized_end = 3449
+    _READ_NAMEDTABLE._serialized_start = 3095
+    _READ_NAMEDTABLE._serialized_end = 3156
+    _READ_DATASOURCE._serialized_start = 3159
+    _READ_DATASOURCE._serialized_end = 3436
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3356
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3414
+    _PROJECT._serialized_start = 3451
+    _PROJECT._serialized_end = 3568
+    _FILTER._serialized_start = 3570
+    _FILTER._serialized_end = 3682
+    _JOIN._serialized_start = 3685
+    _JOIN._serialized_end = 4156
+    _JOIN_JOINTYPE._serialized_start = 3948
+    _JOIN_JOINTYPE._serialized_end = 4156
+    _SETOPERATION._serialized_start = 4159
+    _SETOPERATION._serialized_end = 4638
+    _SETOPERATION_SETOPTYPE._serialized_start = 4475
+    _SETOPERATION_SETOPTYPE._serialized_end = 4589
+    _LIMIT._serialized_start = 4640
+    _LIMIT._serialized_end = 4716
+    _OFFSET._serialized_start = 4718
+    _OFFSET._serialized_end = 4797
+    _TAIL._serialized_start = 4799
+    _TAIL._serialized_end = 4874
+    _AGGREGATE._serialized_start = 4877
+    _AGGREGATE._serialized_end = 5459
+    _AGGREGATE_PIVOT._serialized_start = 5216
+    _AGGREGATE_PIVOT._serialized_end = 5327
+    _AGGREGATE_GROUPTYPE._serialized_start = 5330
+    _AGGREGATE_GROUPTYPE._serialized_end = 5459
+    _SORT._serialized_start = 5462
+    _SORT._serialized_end = 5622
+    _DROP._serialized_start = 5625
+    _DROP._serialized_end = 5766
+    _DEDUPLICATE._serialized_start = 5769
+    _DEDUPLICATE._serialized_end = 5940
+    _LOCALRELATION._serialized_start = 5942
+    _LOCALRELATION._serialized_end = 6031
+    _SAMPLE._serialized_start = 6034
+    _SAMPLE._serialized_end = 6307
+    _RANGE._serialized_start = 6310
+    _RANGE._serialized_end = 6455
+    _SUBQUERYALIAS._serialized_start = 6457
+    _SUBQUERYALIAS._serialized_end = 6571
+    _REPARTITION._serialized_start = 6574
+    _REPARTITION._serialized_end = 6716
+    _SHOWSTRING._serialized_start = 6719
+    _SHOWSTRING._serialized_end = 6861
+    _STATSUMMARY._serialized_start = 6863
+    _STATSUMMARY._serialized_end = 6955
+    _STATDESCRIBE._serialized_start = 6957
+    _STATDESCRIBE._serialized_end = 7038
+    _STATCROSSTAB._serialized_start = 7040
+    _STATCROSSTAB._serialized_end = 7141
+    _STATCOV._serialized_start = 7143
+    _STATCOV._serialized_end = 7239
+    _STATCORR._serialized_start = 7242
+    _STATCORR._serialized_end = 7379
+    _STATAPPROXQUANTILE._serialized_start = 7382
+    _STATAPPROXQUANTILE._serialized_end = 7546
+    _STATFREQITEMS._serialized_start = 7548
+    _STATFREQITEMS._serialized_end = 7673
+    _STATSAMPLEBY._serialized_start = 7676
+    _STATSAMPLEBY._serialized_end = 7985
+    _STATSAMPLEBY_FRACTION._serialized_start = 7877
+    _STATSAMPLEBY_FRACTION._serialized_end = 7976
+    _NAFILL._serialized_start = 7988
+    _NAFILL._serialized_end = 8122
+    _NADROP._serialized_start = 8125
+    _NADROP._serialized_end = 8259
+    _NAREPLACE._serialized_start = 8262
+    _NAREPLACE._serialized_end = 8558
+    _NAREPLACE_REPLACEMENT._serialized_start = 8417
+    _NAREPLACE_REPLACEMENT._serialized_end = 8558
+    _TODF._serialized_start = 8560
+    _TODF._serialized_end = 8648
+    _WITHCOLUMNSRENAMED._serialized_start = 8651
+    _WITHCOLUMNSRENAMED._serialized_end = 8890
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8823
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8890
+    _WITHCOLUMNS._serialized_start = 8892
+    _WITHCOLUMNS._serialized_end = 9011
+    _HINT._serialized_start = 9014
+    _HINT._serialized_end = 9146
+    _UNPIVOT._serialized_start = 9149
+    _UNPIVOT._serialized_end = 9476
+    _UNPIVOT_VALUES._serialized_start = 9406
+    _UNPIVOT_VALUES._serialized_end = 9465
+    _TOSCHEMA._serialized_start = 9478
+    _TOSCHEMA._serialized_end = 9584
+    _REPARTITIONBYEXPRESSION._serialized_start = 9587
+    _REPARTITIONBYEXPRESSION._serialized_end = 9790
+    _MAPPARTITIONS._serialized_start = 9793
+    _MAPPARTITIONS._serialized_end = 9923
+    _GROUPMAP._serialized_start = 9926
+    _GROUPMAP._serialized_end = 10129
+    _COLLECTMETRICS._serialized_start = 10132
+    _COLLECTMETRICS._serialized_end = 10268
+    _PARSE._serialized_start = 10271
+    _PARSE._serialized_end = 10659
+    _PARSE_OPTIONSENTRY._serialized_start = 3356
+    _PARSE_OPTIONSENTRY._serialized_end = 3414
+    _PARSE_PARSEFORMAT._serialized_start = 10560
+    _PARSE_PARSEFORMAT._serialized_end = 10648
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index ab1561996ef..6ae4a323f6f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -92,6 +92,7 @@ class Relation(google.protobuf.message.Message):
    MAP_PARTITIONS_FIELD_NUMBER: builtins.int
    COLLECT_METRICS_FIELD_NUMBER: builtins.int
    PARSE_FIELD_NUMBER: builtins.int
+    GROUP_MAP_FIELD_NUMBER: builtins.int
    FILL_NA_FIELD_NUMBER: builtins.int
    DROP_NA_FIELD_NUMBER: builtins.int
    REPLACE_FIELD_NUMBER: builtins.int
@@ -167,6 +168,8 @@ class Relation(google.protobuf.message.Message):
    @property
    def parse(self) -> global___Parse: ...
    @property
+    def group_map(self) -> global___GroupMap: ...
+    @property
    def fill_na(self) -> global___NAFill:
        """NA functions"""
    @property
@@ -233,6 +236,7 @@ class Relation(google.protobuf.message.Message):
        map_partitions: global___MapPartitions | None = ...,
        collect_metrics: global___CollectMetrics | None = ...,
        parse: global___Parse | None = ...,
+        group_map: global___GroupMap | None = ...,
        fill_na: global___NAFill | None = ...,
        drop_na: global___NADrop | None = ...,
        replace: global___NAReplace | None = ...,
@@ -283,6 +287,8 @@ class Relation(google.protobuf.message.Message):
            b"filter",
            "freq_items",
            b"freq_items",
+            "group_map",
+            b"group_map",
            "hint",
            b"hint",
            "join",
@@ -378,6 +384,8 @@ class Relation(google.protobuf.message.Message):
            b"filter",
            "freq_items",
            b"freq_items",
+            "group_map",
+            b"group_map",
            "hint",
            b"hint",
            "join",
@@ -470,6 +478,7 @@ class Relation(google.protobuf.message.Message):
        "map_partitions",
        "collect_metrics",
        "parse",
+        "group_map",
        "fill_na",
        "drop_na",
        "replace",
@@ -2733,6 +2742,48 @@ class MapPartitions(google.protobuf.message.Message):
 
 global___MapPartitions = MapPartitions
 
+class GroupMap(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+    FUNC_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) Input relation for Group Map API: apply, applyInPandas."""
+    @property
+    def grouping_expressions(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Required) Expressions for grouping keys."""
+    @property
+    def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
+        """(Required) Input user-defined function."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        grouping_expressions: collections.abc.Iterable[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]
+        | None = ...,
+        func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+        | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "func", b"func", "grouping_expressions", b"grouping_expressions", "input", b"input"
+        ],
+    ) -> None: ...
+
+global___GroupMap = GroupMap
+
 class CollectMetrics(google.protobuf.message.Message):
     """Collect arbitrary (named) metrics from a dataset."""
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index bca96eaf205..f03aa35bb83 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -48,6 +48,9 @@ class PandasGroupedOpsMixin:
 
         .. versionadded:: 2.3.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         udf : :func:`pyspark.sql.functions.pandas_udf`
@@ -128,6 +131,9 @@ class PandasGroupedOpsMixin:
 
         .. versionadded:: 3.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         func : function
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a8e161a42a6..491865ad9c9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2845,8 +2845,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         # SPARK-41927: Disable unsupported functions.
         cg = self.connect.read.table(self.tbl_name).groupBy("id")
         for f in (
-            "apply",
-            "applyInPandas",
             "applyInPandasWithState",
             "cogroup",
         ):
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
new file mode 100644
index 00000000000..1736e395723
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_grouped_map import GroupedApplyInPandasTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedConnectTestCase):
+    # TODO(SPARK-42822): Fix ambiguous reference for case-insensitive grouping column
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_case_insensitive_grouping_column(self):
+        super().test_case_insensitive_grouping_column()
+
+    # TODO(SPARK-42857): Support CreateDataFrame from Decimal128
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_supported_types(self):
+        super().test_supported_types()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_wrong_return_type(self):
+        super().test_wrong_return_type()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_wrong_args(self):
+        super().test_wrong_args()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_unsupported_types(self):
+        super().test_unsupported_types()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_register_grouped_map_udf(self):
+        super().test_register_grouped_map_udf()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_column_order(self):
+        super().test_column_order()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_returning_wrong_column_names(self):
+        super().test_apply_in_pandas_returning_wrong_column_names()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
+        super().test_apply_in_pandas_returning_no_column_names_and_wrong_amount()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_returning_incompatible_type(self):
+        super().test_apply_in_pandas_returning_incompatible_type()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+    )
+    def test_apply_in_pandas_not_returning_pandas_dataframe(self):
+        super().test_apply_in_pandas_not_returning_pandas_dataframe()
+
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.")
+    def test_grouped_with_empty_partition(self):
+        super().test_grouped_with_empty_partition()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_grouped_map import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
index 571ee74287e..d2eab7fa4f3 100644
--- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
@@ -66,11 +66,6 @@ class PandasUDFParityTests(PandasUDFTestsMixin, ReusedConnectTestCase):
         self.assertEqual(udf.returnType, UnparsedDataType("v double"))
         self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)
 
-    # TODO(SPARK-42340): implement GroupedData.applyInPandas
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_stopiteration_in_grouped_map(self):
-        super().test_stopiteration_in_grouped_map()
-
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 88e68b04303..36bdae02944 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -73,7 +73,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class GroupedApplyInPandasTests(ReusedSQLTestCase):
+class GroupedApplyInPandasTestsMixin:
     @property
     def data(self):
         return (
@@ -289,17 +289,17 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
         return pd.DataFrame([key + (pdf.v.mean(),)])
 
     def test_apply_in_pandas_returning_column_names(self):
-        self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_column_names)
+        self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_column_names)
 
     def test_apply_in_pandas_returning_no_column_names(self):
-        self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_no_column_names)
+        self._test_apply_in_pandas(GroupedApplyInPandasTestsMixin.stats_with_no_column_names)
 
     def test_apply_in_pandas_returning_column_names_sometimes(self):
         def stats(key, pdf):
             if key[0] % 2:
-                return GroupedApplyInPandasTests.stats_with_column_names(key, pdf)
+                return GroupedApplyInPandasTestsMixin.stats_with_column_names(key, pdf)
             else:
-                return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
+                return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
 
         self._test_apply_in_pandas(stats)
 
@@ -782,7 +782,7 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
 
         def stats(key, pdf):
             if key[0] % 2 == 0:
-                return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
+                return GroupedApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
             return empty_df
 
         result = (
@@ -805,6 +805,10 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
         self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
 
 
+class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_grouped_map import *  # noqa: F401

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org