This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0db63df2b28 [SPARK-42125][CONNECT][PYTHON] Pandas UDF in Spark Connect
0db63df2b28 is described below

commit 0db63df2b2829f1358fb711cd657a22b7838ece2
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Tue Jan 31 09:12:20 2023 +0800

    [SPARK-42125][CONNECT][PYTHON] Pandas UDF in Spark Connect
    
    ### What changes were proposed in this pull request?
    Support Pandas UDF in Spark Connect.
    
    Since Pandas UDF and scalar inline Python UDF share the same proto message, 
`ScalarInlineUserDefinedFunction` is renamed to `CommonInlineUserDefinedFunction`.
    
    ### Why are the changes needed?
    To reach parity with the vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Pandas UDF is supported in Spark Connect, as shown below.
    
    ```py
    >>> from pyspark.sql.functions import pandas_udf
    >>> import pandas as pd
    >>> pandas_udf("double")
    ... def mean_udf(v: pd.Series) -> float:
    ...     return v.mean()
    ...
    >>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 
10.0)], ("id", "v"))
    >>> type(df)
    <class 'pyspark.sql.connect.dataframe.DataFrame'>
    >>> df.groupby("id").agg(mean_udf("v")).show()
    +---+-----------+
    | id|mean_udf(v)|
    +---+-----------+
    |  1|        1.5|
    |  2|        6.0|
    +---+-----------+
    >>>
    ```
    
    ### How was this patch tested?
    Existing tests.
    
    Closes #39753 from xinrong-meng/connect_pd_udf.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto    |  4 ++--
 .../sql/connect/planner/SparkConnectPlanner.scala    | 12 ++++++------
 .../connect/messages/ConnectProtoMessagesSuite.scala | 10 +++++-----
 python/pyspark/sql/connect/expressions.py            | 13 +++++++------
 python/pyspark/sql/connect/proto/expressions_pb2.py  | 20 ++++++++++----------
 python/pyspark/sql/connect/proto/expressions_pb2.pyi | 20 ++++++++++----------
 python/pyspark/sql/connect/udf.py                    |  4 ++--
 python/pyspark/sql/pandas/functions.py               | 11 ++++++++++-
 8 files changed, 52 insertions(+), 42 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 7ae0a6c5008..5b27d4593db 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -44,7 +44,7 @@ message Expression {
     UnresolvedExtractValue unresolved_extract_value = 12;
     UpdateFields update_fields = 13;
     UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
-    ScalarInlineUserDefinedFunction scalar_inline_user_defined_function = 15;
+    CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
 
     // This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
     // relations they can add them here. During the planning the correct 
resolution is done.
@@ -297,7 +297,7 @@ message Expression {
   }
 }
 
-message ScalarInlineUserDefinedFunction {
+message CommonInlineUserDefinedFunction {
   // (Required) Name of the user-defined function.
   string function_name = 1;
   // (Required) Indicate if the user-defined function is deterministic.
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc921cee282..9b5c4b93f62 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -742,8 +742,8 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformWindowExpression(exp.getWindow)
       case proto.Expression.ExprTypeCase.EXTENSION =>
         transformExpressionPlugin(exp.getExtension)
-      case proto.Expression.ExprTypeCase.SCALAR_INLINE_USER_DEFINED_FUNCTION =>
-        
transformScalarInlineUserDefinedFunction(exp.getScalarInlineUserDefinedFunction)
+      case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION =>
+        
transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not 
supported")
@@ -826,10 +826,10 @@ class SparkConnectPlanner(val session: SparkSession) {
    * @return
    *   Expression.
    */
-  private def transformScalarInlineUserDefinedFunction(
-      fun: proto.ScalarInlineUserDefinedFunction): Expression = {
+  private def transformCommonInlineUserDefinedFunction(
+      fun: proto.CommonInlineUserDefinedFunction): Expression = {
     fun.getFunctionCase match {
-      case proto.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         transformPythonUDF(fun)
       case _ =>
         throw InvalidPlanInput(
@@ -845,7 +845,7 @@ class SparkConnectPlanner(val session: SparkSession) {
    * @return
    *   PythonUDF.
    */
-  private def transformPythonUDF(fun: proto.ScalarInlineUserDefinedFunction): 
PythonUDF = {
+  private def transformPythonUDF(fun: proto.CommonInlineUserDefinedFunction): 
PythonUDF = {
     val udf = fun.getPythonUdf
     PythonUDF(
       name = fun.getFunctionName,
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
index 3d8fae83428..240f6573c7d 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -51,7 +51,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
     assert(extLit.getLiteral.getInteger == 32)
   }
 
-  test("ScalarInlineUserDefinedFunction") {
+  test("CommonInlineUserDefinedFunction") {
     val arguments = proto.Expression
       .newBuilder()
       .setUnresolvedAttribute(
@@ -65,10 +65,10 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
       .setCommand(ByteString.copyFrom("command".getBytes()))
       .build()
 
-    val scalarInlineUserDefinedFunctionExpr = proto.Expression
+    val commonInlineUserDefinedFunctionExpr = proto.Expression
       .newBuilder()
-      .setScalarInlineUserDefinedFunction(
-        proto.ScalarInlineUserDefinedFunction
+      .setCommonInlineUserDefinedFunction(
+        proto.CommonInlineUserDefinedFunction
           .newBuilder()
           .setFunctionName("f")
           .setDeterministic(true)
@@ -76,7 +76,7 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
           .setPythonUdf(pythonUdf))
       .build()
 
-    val fun = 
scalarInlineUserDefinedFunctionExpr.getScalarInlineUserDefinedFunction()
+    val fun = 
commonInlineUserDefinedFunctionExpr.getCommonInlineUserDefinedFunction()
     assert(fun.getFunctionName == "f")
     assert(fun.getDeterministic == true)
     assert(fun.getArgumentsCount == 1)
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index 0fa67a5f8d0..04d70beeddd 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -507,8 +507,9 @@ class PythonUDF:
         )
 
 
-class ScalarInlineUserDefinedFunction(Expression):
-    """Represents a scalar inline user-defined function of any programming 
languages."""
+class CommonInlineUserDefinedFunction(Expression):
+    """Represents a user-defined function with an inlined defined function 
body of any programming
+    languages."""
 
     def __init__(
         self,
@@ -524,13 +525,13 @@ class ScalarInlineUserDefinedFunction(Expression):
 
     def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         expr = proto.Expression()
-        expr.scalar_inline_user_defined_function.function_name = 
self._function_name
-        expr.scalar_inline_user_defined_function.deterministic = 
self._deterministic
+        expr.common_inline_user_defined_function.function_name = 
self._function_name
+        expr.common_inline_user_defined_function.deterministic = 
self._deterministic
         if len(self._arguments) > 0:
-            expr.scalar_inline_user_defined_function.arguments.extend(
+            expr.common_inline_user_defined_function.arguments.extend(
                 [arg.to_plan(session) for arg in self._arguments]
             )
-        expr.scalar_inline_user_defined_function.python_udf.CopyFrom(
+        expr.common_inline_user_defined_function.python_udf.CopyFrom(
             self._function.to_plan(session)
         )
         return expr
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 0b2419fee35..f320eee54e0 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
 )
 
 
@@ -61,8 +61,8 @@ _EXPRESSION_LAMBDAFUNCTION = 
_EXPRESSION.nested_types_by_name["LambdaFunction"]
 _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE = _EXPRESSION.nested_types_by_name[
     "UnresolvedNamedLambdaVariable"
 ]
-_SCALARINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
-    "ScalarInlineUserDefinedFunction"
+_COMMONINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
+    "CommonInlineUserDefinedFunction"
 ]
 _PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"]
 _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = 
_EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
@@ -261,16 +261,16 @@ _sym_db.RegisterMessage(Expression.Alias)
 _sym_db.RegisterMessage(Expression.LambdaFunction)
 _sym_db.RegisterMessage(Expression.UnresolvedNamedLambdaVariable)
 
-ScalarInlineUserDefinedFunction = _reflection.GeneratedProtocolMessageType(
-    "ScalarInlineUserDefinedFunction",
+CommonInlineUserDefinedFunction = _reflection.GeneratedProtocolMessageType(
+    "CommonInlineUserDefinedFunction",
     (_message.Message,),
     {
-        "DESCRIPTOR": _SCALARINLINEUSERDEFINEDFUNCTION,
+        "DESCRIPTOR": _COMMONINLINEUSERDEFINEDFUNCTION,
         "__module__": "spark.connect.expressions_pb2"
-        # 
@@protoc_insertion_point(class_scope:spark.connect.ScalarInlineUserDefinedFunction)
+        # 
@@protoc_insertion_point(class_scope:spark.connect.CommonInlineUserDefinedFunction)
     },
 )
-_sym_db.RegisterMessage(ScalarInlineUserDefinedFunction)
+_sym_db.RegisterMessage(CommonInlineUserDefinedFunction)
 
 PythonUDF = _reflection.GeneratedProtocolMessageType(
     "PythonUDF",
@@ -331,8 +331,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4782
     _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
     _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
-    _SCALARINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
-    _SCALARINLINEUSERDEFINEDFUNCTION._serialized_end = 5098
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5098
     _PYTHONUDF._serialized_start = 5100
     _PYTHONUDF._serialized_end = 5199
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 0191a0cdaf4..d8b0485017c 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -932,7 +932,7 @@ class Expression(google.protobuf.message.Message):
     UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
     UPDATE_FIELDS_FIELD_NUMBER: builtins.int
     UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
-    SCALAR_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
+    COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
@@ -965,7 +965,7 @@ class Expression(google.protobuf.message.Message):
         self,
     ) -> global___Expression.UnresolvedNamedLambdaVariable: ...
     @property
-    def scalar_inline_user_defined_function(self) -> 
global___ScalarInlineUserDefinedFunction: ...
+    def common_inline_user_defined_function(self) -> 
global___CommonInlineUserDefinedFunction: ...
     @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
@@ -989,7 +989,7 @@ class Expression(google.protobuf.message.Message):
         update_fields: global___Expression.UpdateFields | None = ...,
         unresolved_named_lambda_variable: 
global___Expression.UnresolvedNamedLambdaVariable
         | None = ...,
-        scalar_inline_user_defined_function: 
global___ScalarInlineUserDefinedFunction | None = ...,
+        common_inline_user_defined_function: 
global___CommonInlineUserDefinedFunction | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -999,6 +999,8 @@ class Expression(google.protobuf.message.Message):
             b"alias",
             "cast",
             b"cast",
+            "common_inline_user_defined_function",
+            b"common_inline_user_defined_function",
             "expr_type",
             b"expr_type",
             "expression_string",
@@ -1009,8 +1011,6 @@ class Expression(google.protobuf.message.Message):
             b"lambda_function",
             "literal",
             b"literal",
-            "scalar_inline_user_defined_function",
-            b"scalar_inline_user_defined_function",
             "sort_order",
             b"sort_order",
             "unresolved_attribute",
@@ -1038,6 +1038,8 @@ class Expression(google.protobuf.message.Message):
             b"alias",
             "cast",
             b"cast",
+            "common_inline_user_defined_function",
+            b"common_inline_user_defined_function",
             "expr_type",
             b"expr_type",
             "expression_string",
@@ -1048,8 +1050,6 @@ class Expression(google.protobuf.message.Message):
             b"lambda_function",
             "literal",
             b"literal",
-            "scalar_inline_user_defined_function",
-            b"scalar_inline_user_defined_function",
             "sort_order",
             b"sort_order",
             "unresolved_attribute",
@@ -1087,13 +1087,13 @@ class Expression(google.protobuf.message.Message):
         "unresolved_extract_value",
         "update_fields",
         "unresolved_named_lambda_variable",
-        "scalar_inline_user_defined_function",
+        "common_inline_user_defined_function",
         "extension",
     ] | None: ...
 
 global___Expression = Expression
 
-class ScalarInlineUserDefinedFunction(google.protobuf.message.Message):
+class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
     FUNCTION_NAME_FIELD_NUMBER: builtins.int
@@ -1142,7 +1142,7 @@ class 
ScalarInlineUserDefinedFunction(google.protobuf.message.Message):
         self, oneof_group: typing_extensions.Literal["function", b"function"]
     ) -> typing_extensions.Literal["python_udf"] | None: ...
 
-global___ScalarInlineUserDefinedFunction = ScalarInlineUserDefinedFunction
+global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction
 
 class PythonUDF(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
diff --git a/python/pyspark/sql/connect/udf.py 
b/python/pyspark/sql/connect/udf.py
index d0eb2fdfe6c..79346ee07ef 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -24,7 +24,7 @@ from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.connect.expressions import (
     ColumnReference,
     PythonUDF,
-    ScalarInlineUserDefinedFunction,
+    CommonInlineUserDefinedFunction,
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.types import DataType, StringType
@@ -129,7 +129,7 @@ class UserDefinedFunction:
             command=CloudPickleSerializer().dumps((self.func, 
self._returnType)),
         )
         return Column(
-            ScalarInlineUserDefinedFunction(
+            CommonInlineUserDefinedFunction(
                 function_name=self._name,
                 deterministic=self.deterministic,
                 arguments=arg_exprs,
diff --git a/python/pyspark/sql/pandas/functions.py 
b/python/pyspark/sql/pandas/functions.py
index d0f81e2f633..b982480cf3e 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -25,6 +25,7 @@ from pyspark.sql.pandas.typehints import infer_eval_type
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, 
require_minimum_pyarrow_version
 from pyspark.sql.types import DataType
 from pyspark.sql.udf import _create_udf
+from pyspark.sql.utils import is_remote
 
 
 class PandasUDFType:
@@ -51,6 +52,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
 
     .. versionadded:: 2.3.0
 
+    .. versionchanged:: 3.4.0
+        Support Spark Connect.
+
     Parameters
     ----------
     f : function, optional
@@ -449,4 +453,9 @@ def _create_pandas_udf(f, returnType, evalType):
             "or three arguments (key, left, right)."
         )
 
-    return _create_udf(f, returnType, evalType)
+    if is_remote():
+        from pyspark.sql.connect.udf import _create_udf as _create_connect_udf
+
+        return _create_connect_udf(f, returnType, evalType)
+    else:
+        return _create_udf(f, returnType, evalType)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to