This is an automated email from the ASF dual-hosted git repository. xinrong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0db63df2b28 [SPARK-42125][CONNECT][PYTHON] Pandas UDF in Spark Connect 0db63df2b28 is described below commit 0db63df2b2829f1358fb711cd657a22b7838ece2 Author: Xinrong Meng <xinr...@apache.org> AuthorDate: Tue Jan 31 09:12:20 2023 +0800 [SPARK-42125][CONNECT][PYTHON] Pandas UDF in Spark Connect ### What changes were proposed in this pull request? Support Pandas UDF in Spark Connect. Since Pandas UDF and scalar inline Python UDF share the same proto message, `ScalarInlineUserDefinedFunction` is renamed to `CommonInlineUserDefinedFunction`. ### Why are the changes needed? To reach parity with the vanilla PySpark. ### Does this PR introduce _any_ user-facing change? Yes. Pandas UDF is supported in Spark Connect, as shown below. ```py >>> from pyspark.sql.functions import pandas_udf >>> import pandas as pd >>> @pandas_udf("double") ... def mean_udf(v: pd.Series) -> float: ... return v.mean() ... >>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) >>> type(df) <class 'pyspark.sql.connect.dataframe.DataFrame'> >>> df.groupby("id").agg(mean_udf("v")).show() +---+-----------+ | id|mean_udf(v)| +---+-----------+ | 1| 1.5| | 2| 6.0| +---+-----------+ >>> ``` ### How was this patch tested? Existing tests. Closes #39753 from xinrong-meng/connect_pd_udf. 
Authored-by: Xinrong Meng <xinr...@apache.org> Signed-off-by: Xinrong Meng <xinr...@apache.org> --- .../main/protobuf/spark/connect/expressions.proto | 4 ++-- .../sql/connect/planner/SparkConnectPlanner.scala | 12 ++++++------ .../connect/messages/ConnectProtoMessagesSuite.scala | 10 +++++----- python/pyspark/sql/connect/expressions.py | 13 +++++++------ python/pyspark/sql/connect/proto/expressions_pb2.py | 20 ++++++++++---------- python/pyspark/sql/connect/proto/expressions_pb2.pyi | 20 ++++++++++---------- python/pyspark/sql/connect/udf.py | 4 ++-- python/pyspark/sql/pandas/functions.py | 11 ++++++++++- 8 files changed, 52 insertions(+), 42 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 7ae0a6c5008..5b27d4593db 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -44,7 +44,7 @@ message Expression { UnresolvedExtractValue unresolved_extract_value = 12; UpdateFields update_fields = 13; UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14; - ScalarInlineUserDefinedFunction scalar_inline_user_defined_function = 15; + CommonInlineUserDefinedFunction common_inline_user_defined_function = 15; // This field is used to mark extensions to the protocol. When plugins generate arbitrary // relations they can add them here. During the planning the correct resolution is done. @@ -297,7 +297,7 @@ message Expression { } } -message ScalarInlineUserDefinedFunction { +message CommonInlineUserDefinedFunction { // (Required) Name of the user-defined function. string function_name = 1; // (Required) Indicate if the user-defined function is deterministic. 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index dc921cee282..9b5c4b93f62 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -742,8 +742,8 @@ class SparkConnectPlanner(val session: SparkSession) { transformWindowExpression(exp.getWindow) case proto.Expression.ExprTypeCase.EXTENSION => transformExpressionPlugin(exp.getExtension) - case proto.Expression.ExprTypeCase.SCALAR_INLINE_USER_DEFINED_FUNCTION => - transformScalarInlineUserDefinedFunction(exp.getScalarInlineUserDefinedFunction) + case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION => + transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction) case _ => throw InvalidPlanInput( s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported") @@ -826,10 +826,10 @@ class SparkConnectPlanner(val session: SparkSession) { * @return * Expression. */ - private def transformScalarInlineUserDefinedFunction( - fun: proto.ScalarInlineUserDefinedFunction): Expression = { + private def transformCommonInlineUserDefinedFunction( + fun: proto.CommonInlineUserDefinedFunction): Expression = { fun.getFunctionCase match { - case proto.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => + case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => transformPythonUDF(fun) case _ => throw InvalidPlanInput( @@ -845,7 +845,7 @@ class SparkConnectPlanner(val session: SparkSession) { * @return * PythonUDF. 
*/ - private def transformPythonUDF(fun: proto.ScalarInlineUserDefinedFunction): PythonUDF = { + private def transformPythonUDF(fun: proto.CommonInlineUserDefinedFunction): PythonUDF = { val udf = fun.getPythonUdf PythonUDF( name = fun.getFunctionName, diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala index 3d8fae83428..240f6573c7d 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala @@ -51,7 +51,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite { assert(extLit.getLiteral.getInteger == 32) } - test("ScalarInlineUserDefinedFunction") { + test("CommonInlineUserDefinedFunction") { val arguments = proto.Expression .newBuilder() .setUnresolvedAttribute( @@ -65,10 +65,10 @@ class ConnectProtoMessagesSuite extends SparkFunSuite { .setCommand(ByteString.copyFrom("command".getBytes())) .build() - val scalarInlineUserDefinedFunctionExpr = proto.Expression + val commonInlineUserDefinedFunctionExpr = proto.Expression .newBuilder() - .setScalarInlineUserDefinedFunction( - proto.ScalarInlineUserDefinedFunction + .setCommonInlineUserDefinedFunction( + proto.CommonInlineUserDefinedFunction .newBuilder() .setFunctionName("f") .setDeterministic(true) @@ -76,7 +76,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite { .setPythonUdf(pythonUdf)) .build() - val fun = scalarInlineUserDefinedFunctionExpr.getScalarInlineUserDefinedFunction() + val fun = commonInlineUserDefinedFunctionExpr.getCommonInlineUserDefinedFunction() assert(fun.getFunctionName == "f") assert(fun.getDeterministic == true) assert(fun.getArgumentsCount == 1) diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py index 0fa67a5f8d0..04d70beeddd 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -507,8 +507,9 @@ class PythonUDF: ) -class ScalarInlineUserDefinedFunction(Expression): - """Represents a scalar inline user-defined function of any programming languages.""" +class CommonInlineUserDefinedFunction(Expression): + """Represents a user-defined function with an inlined defined function body of any programming + languages.""" def __init__( self, @@ -524,13 +525,13 @@ class ScalarInlineUserDefinedFunction(Expression): def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": expr = proto.Expression() - expr.scalar_inline_user_defined_function.function_name = self._function_name - expr.scalar_inline_user_defined_function.deterministic = self._deterministic + expr.common_inline_user_defined_function.function_name = self._function_name + expr.common_inline_user_defined_function.deterministic = self._deterministic if len(self._arguments) > 0: - expr.scalar_inline_user_defined_function.arguments.extend( + expr.common_inline_user_defined_function.arguments.extend( [arg.to_plan(session) for arg in self._arguments] ) - expr.scalar_inline_user_defined_function.python_udf.CopyFrom( + expr.common_inline_user_defined_function.python_udf.CopyFrom( self._function.to_plan(session) ) return expr diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 0b2419fee35..f320eee54e0 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - 
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] ) @@ -61,8 +61,8 @@ _EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"] _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE = _EXPRESSION.nested_types_by_name[ "UnresolvedNamedLambdaVariable" ] -_SCALARINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[ - "ScalarInlineUserDefinedFunction" +_COMMONINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[ + "CommonInlineUserDefinedFunction" ] _PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"] _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[ @@ -261,16 +261,16 @@ _sym_db.RegisterMessage(Expression.Alias) _sym_db.RegisterMessage(Expression.LambdaFunction) _sym_db.RegisterMessage(Expression.UnresolvedNamedLambdaVariable) -ScalarInlineUserDefinedFunction = _reflection.GeneratedProtocolMessageType( - "ScalarInlineUserDefinedFunction", +CommonInlineUserDefinedFunction = _reflection.GeneratedProtocolMessageType( + "CommonInlineUserDefinedFunction", 
(_message.Message,), { - "DESCRIPTOR": _SCALARINLINEUSERDEFINEDFUNCTION, + "DESCRIPTOR": _COMMONINLINEUSERDEFINEDFUNCTION, "__module__": "spark.connect.expressions_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.ScalarInlineUserDefinedFunction) + # @@protoc_insertion_point(class_scope:spark.connect.CommonInlineUserDefinedFunction) }, ) -_sym_db.RegisterMessage(ScalarInlineUserDefinedFunction) +_sym_db.RegisterMessage(CommonInlineUserDefinedFunction) PythonUDF = _reflection.GeneratedProtocolMessageType( "PythonUDF", @@ -331,8 +331,8 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4782 _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784 _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846 - _SCALARINLINEUSERDEFINEDFUNCTION._serialized_start = 4862 - _SCALARINLINEUSERDEFINEDFUNCTION._serialized_end = 5098 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5098 _PYTHONUDF._serialized_start = 5100 _PYTHONUDF._serialized_end = 5199 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 0191a0cdaf4..d8b0485017c 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -932,7 +932,7 @@ class Expression(google.protobuf.message.Message): UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int UPDATE_FIELDS_FIELD_NUMBER: builtins.int UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int - SCALAR_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int + COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int @property def literal(self) -> global___Expression.Literal: ... @@ -965,7 +965,7 @@ class Expression(google.protobuf.message.Message): self, ) -> global___Expression.UnresolvedNamedLambdaVariable: ... 
@property - def scalar_inline_user_defined_function(self) -> global___ScalarInlineUserDefinedFunction: ... + def common_inline_user_defined_function(self) -> global___CommonInlineUserDefinedFunction: ... @property def extension(self) -> google.protobuf.any_pb2.Any: """This field is used to mark extensions to the protocol. When plugins generate arbitrary @@ -989,7 +989,7 @@ class Expression(google.protobuf.message.Message): update_fields: global___Expression.UpdateFields | None = ..., unresolved_named_lambda_variable: global___Expression.UnresolvedNamedLambdaVariable | None = ..., - scalar_inline_user_defined_function: global___ScalarInlineUserDefinedFunction | None = ..., + common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., ) -> None: ... def HasField( @@ -999,6 +999,8 @@ class Expression(google.protobuf.message.Message): b"alias", "cast", b"cast", + "common_inline_user_defined_function", + b"common_inline_user_defined_function", "expr_type", b"expr_type", "expression_string", @@ -1009,8 +1011,6 @@ class Expression(google.protobuf.message.Message): b"lambda_function", "literal", b"literal", - "scalar_inline_user_defined_function", - b"scalar_inline_user_defined_function", "sort_order", b"sort_order", "unresolved_attribute", @@ -1038,6 +1038,8 @@ class Expression(google.protobuf.message.Message): b"alias", "cast", b"cast", + "common_inline_user_defined_function", + b"common_inline_user_defined_function", "expr_type", b"expr_type", "expression_string", @@ -1048,8 +1050,6 @@ class Expression(google.protobuf.message.Message): b"lambda_function", "literal", b"literal", - "scalar_inline_user_defined_function", - b"scalar_inline_user_defined_function", "sort_order", b"sort_order", "unresolved_attribute", @@ -1087,13 +1087,13 @@ class Expression(google.protobuf.message.Message): "unresolved_extract_value", "update_fields", "unresolved_named_lambda_variable", - 
"scalar_inline_user_defined_function", + "common_inline_user_defined_function", "extension", ] | None: ... global___Expression = Expression -class ScalarInlineUserDefinedFunction(google.protobuf.message.Message): +class CommonInlineUserDefinedFunction(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor FUNCTION_NAME_FIELD_NUMBER: builtins.int @@ -1142,7 +1142,7 @@ class ScalarInlineUserDefinedFunction(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["function", b"function"] ) -> typing_extensions.Literal["python_udf"] | None: ... -global___ScalarInlineUserDefinedFunction = ScalarInlineUserDefinedFunction +global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction class PythonUDF(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index d0eb2fdfe6c..79346ee07ef 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -24,7 +24,7 @@ from pyspark.serializers import CloudPickleSerializer from pyspark.sql.connect.expressions import ( ColumnReference, PythonUDF, - ScalarInlineUserDefinedFunction, + CommonInlineUserDefinedFunction, ) from pyspark.sql.connect.column import Column from pyspark.sql.types import DataType, StringType @@ -129,7 +129,7 @@ class UserDefinedFunction: command=CloudPickleSerializer().dumps((self.func, self._returnType)), ) return Column( - ScalarInlineUserDefinedFunction( + CommonInlineUserDefinedFunction( function_name=self._name, deterministic=self.deterministic, arguments=arg_exprs, diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index d0f81e2f633..b982480cf3e 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -25,6 +25,7 @@ from pyspark.sql.pandas.typehints import infer_eval_type from pyspark.sql.pandas.utils import 
require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.sql.types import DataType from pyspark.sql.udf import _create_udf +from pyspark.sql.utils import is_remote class PandasUDFType: @@ -51,6 +52,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. versionadded:: 2.3.0 + .. versionchanged:: 3.4.0 + Support Spark Connect. + Parameters ---------- f : function, optional @@ -449,4 +453,9 @@ def _create_pandas_udf(f, returnType, evalType): "or three arguments (key, left, right)." ) - return _create_udf(f, returnType, evalType) + if is_remote(): + from pyspark.sql.connect.udf import _create_udf as _create_connect_udf + + return _create_connect_udf(f, returnType, evalType) + else: + return _create_udf(f, returnType, evalType) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org