This is an automated email from the ASF dual-hosted git repository. weichenxu123 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2a1ac07132b [SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"
2a1ac07132b is described below

commit 2a1ac07132b7abc13e56b9a632b3dece7e4b60ea
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Mon Mar 27 17:50:23 2023 +0800

    [SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"

    ### What changes were proposed in this pull request?

    Make `mapInPandas` / `mapInArrow` in Spark Connect support the `is_barrier` argument. The flag is carried in a new optional `is_barrier` field of the `MapPartitions` proto message and is passed by the server planner to the `MapInPandas` / `PythonMapInArrow` logical plans (previously it was hard-coded to `false`).

    ### Why are the changes needed?

    Feature parity with the regular (non-Connect) PySpark DataFrame API.

    ### Does this PR introduce _any_ user-facing change?

    Yes. `DataFrame.mapInPandas` and `DataFrame.mapInArrow` in Spark Connect accept a new optional argument `is_barrier` (default `False`).

    ### How was this patch tested?

    Manually, in `bin/pyspark --remote local`:

    ```
    import pyarrow
    from pyspark.sql.functions import pandas_udf

    df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

    def filter_func(iterator):
        for pdf in iterator:
            yield pdf[pdf.id == 1]

    df.mapInPandas(filter_func, df.schema, is_barrier=True).collect()

    def filter_func(iterator):
        for batch in iterator:
            pdf = batch.to_pandas()
            yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])

    df.mapInArrow(filter_func, df.schema, is_barrier=True).collect()
    ```

    Closes #40559 from WeichenXu123/spark-connect-barrier-mode.

Authored-by: Weichen Xu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> --- .../main/protobuf/spark/connect/relations.proto | 3 +++ .../sql/connect/planner/SparkConnectPlanner.scala | 5 ++-- python/pyspark/sql/connect/dataframe.py | 21 +++++++++++---- python/pyspark/sql/connect/plan.py | 8 +++++- python/pyspark/sql/connect/proto/relations_pb2.py | 24 ++++++++--------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 31 ++++++++++++++++++++-- 6 files changed, 70 insertions(+), 22 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 976bd68e7fe..c965a6c8d32 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -794,6 +794,9 @@ message MapPartitions { // (Required) Input user-defined function. CommonInlineUserDefinedFunction func = 2; + + // (Optional) isBarrier. 
+ optional bool is_barrier = 3; } message GroupMap { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index e7911ccdf11..e7e88cab643 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -484,19 +484,20 @@ class SparkConnectPlanner(val session: SparkSession) { private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = { val commonUdf = rel.getFunc val pythonUdf = transformPythonUDF(commonUdf) + val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false pythonUdf.evalType match { case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => logical.MapInPandas( pythonUdf, pythonUdf.dataType.asInstanceOf[StructType].toAttributes, transformRelation(rel.getInput), - false) + isBarrier) case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => logical.PythonMapInArrow( pythonUdf, pythonUdf.dataType.asInstanceOf[StructType].toAttributes, transformRelation(rel.getInput), - false) + isBarrier) case _ => throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported") } diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 2dfc8e72193..10426c3c28d 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1623,6 +1623,7 @@ class DataFrame: func: "PandasMapIterFunction", schema: Union[StructType, str], evalType: int, + is_barrier: bool, ) -> "DataFrame": from pyspark.sql.connect.udf import UserDefinedFunction @@ -1636,21 +1637,31 @@ class DataFrame: ) return DataFrame.withPlan( - plan.MapPartitions(child=self._plan, function=udf_obj, cols=self.columns), + plan.MapPartitions( + child=self._plan, function=udf_obj, 
cols=self.columns, is_barrier=is_barrier + ), session=self._session, ) def mapInPandas( - self, func: "PandasMapIterFunction", schema: Union[StructType, str] + self, + func: "PandasMapIterFunction", + schema: Union[StructType, str], + is_barrier: bool = False, ) -> "DataFrame": - return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF) + return self._map_partitions( + func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, is_barrier + ) mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__ def mapInArrow( - self, func: "ArrowMapIterFunction", schema: Union[StructType, str] + self, + func: "ArrowMapIterFunction", + schema: Union[StructType, str], + is_barrier: bool = False, ) -> "DataFrame": - return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF) + return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, is_barrier) mapInArrow.__doc__ = PySparkDataFrame.mapInArrow.__doc__ diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 7988cc33009..12a5879db0f 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1912,17 +1912,23 @@ class MapPartitions(LogicalPlan): """Logical plan object for a mapPartitions-equivalent API: mapInPandas, mapInArrow.""" def __init__( - self, child: Optional["LogicalPlan"], function: "UserDefinedFunction", cols: List[str] + self, + child: Optional["LogicalPlan"], + function: "UserDefinedFunction", + cols: List[str], + is_barrier: bool, ) -> None: super().__init__(child) self._func = function._build_common_inline_user_defined_function(*cols) + self._is_barrier = is_barrier def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.map_partitions.input.CopyFrom(self._child.plan(session)) plan.map_partitions.func.CopyFrom(self._func.to_plan_udf(session)) + plan.map_partitions.is_barrier = self._is_barrier return plan 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 81a0499666b..80e66fd4aae 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] 
) @@ -828,17 +828,17 @@ if _descriptor._USE_C_DESCRIPTORS == False: _REPARTITIONBYEXPRESSION._serialized_start = 9782 _REPARTITIONBYEXPRESSION._serialized_end = 9985 _MAPPARTITIONS._serialized_start = 9988 - _MAPPARTITIONS._serialized_end = 10118 - _GROUPMAP._serialized_start = 10121 - _GROUPMAP._serialized_end = 10324 - _COGROUPMAP._serialized_start = 10327 - _COGROUPMAP._serialized_end = 10679 - _COLLECTMETRICS._serialized_start = 10682 - _COLLECTMETRICS._serialized_end = 10818 - _PARSE._serialized_start = 10821 - _PARSE._serialized_end = 11209 + _MAPPARTITIONS._serialized_end = 10169 + _GROUPMAP._serialized_start = 10172 + _GROUPMAP._serialized_end = 10375 + _COGROUPMAP._serialized_start = 10378 + _COGROUPMAP._serialized_end = 10730 + _COLLECTMETRICS._serialized_start = 10733 + _COLLECTMETRICS._serialized_end = 10869 + _PARSE._serialized_start = 10872 + _PARSE._serialized_end = 11260 _PARSE_OPTIONSENTRY._serialized_start = 3293 _PARSE_OPTIONSENTRY._serialized_end = 3351 - _PARSE_PARSEFORMAT._serialized_start = 11110 - _PARSE_PARSEFORMAT._serialized_end = 11198 + _PARSE_PARSEFORMAT._serialized_start = 11161 + _PARSE_PARSEFORMAT._serialized_end = 11249 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index f287a740346..c3cf733a995 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -2759,25 +2759,52 @@ class MapPartitions(google.protobuf.message.Message): INPUT_FIELD_NUMBER: builtins.int FUNC_FIELD_NUMBER: builtins.int + IS_BARRIER_FIELD_NUMBER: builtins.int @property def input(self) -> global___Relation: """(Required) Input relation for a mapPartitions-equivalent API: mapInPandas, mapInArrow.""" @property def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction: """(Required) Input user-defined function.""" + is_barrier: builtins.bool + """(Optional) 
isBarrier.""" def __init__( self, *, input: global___Relation | None = ..., func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction | None = ..., + is_barrier: builtins.bool | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"] + self, + field_name: typing_extensions.Literal[ + "_is_barrier", + b"_is_barrier", + "func", + b"func", + "input", + b"input", + "is_barrier", + b"is_barrier", + ], ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"] + self, + field_name: typing_extensions.Literal[ + "_is_barrier", + b"_is_barrier", + "func", + b"func", + "input", + b"input", + "is_barrier", + b"is_barrier", + ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_is_barrier", b"_is_barrier"] + ) -> typing_extensions.Literal["is_barrier"] | None: ... global___MapPartitions = MapPartitions --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org