This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new f1377a856e8 [SPARK-44131][SQL][PYTHON][CONNECT][FOLLOWUP] Support 
qualified function name for call_function
f1377a856e8 is described below

commit f1377a856e85977aafe3bf13cce1da7b4d4ed195
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Tue Jul 25 08:54:00 2023 +0800

    [SPARK-44131][SQL][PYTHON][CONNECT][FOLLOWUP] Support qualified function 
name for call_function
    
    ### What changes were proposed in this pull request?
    https://github.com/apache/spark/pull/41687 added `call_function` and 
deprecated `call_udf` for the Scala API.
    
    Sometimes the function name is qualified; we should let users pass a 
qualified name to invoke persistent functions as well.
    
    ### Why are the changes needed?
    Support qualified function name for `call_function`.
    
    ### Does this PR introduce _any_ user-facing change?
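    Yes. This is a new feature: `call_function` now accepts qualified 
function names, so it can also invoke persistent functions.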
    
    ### How was this patch tested?
    New test cases.
    
    Closes #41932 from beliefer/SPARK-44131_followup.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit d97a4e214c7e11bcc9b7d6e126bf06e214a29988)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../scala/org/apache/spark/sql/functions.scala     |  10 +-
 .../spark/sql/application/ReplE2ESuite.scala       |  10 ++
 .../main/protobuf/spark/connect/expressions.proto  |   9 ++
 .../queries/function_call_function.json            |   2 +-
 .../queries/function_call_function.proto.bin       | Bin 174 -> 175 bytes
 .../sql/connect/planner/SparkConnectPlanner.scala  |  19 ++++
 python/pyspark/sql/connect/expressions.py          |  24 +++++
 python/pyspark/sql/connect/functions.py            |   6 +-
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 118 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  36 +++++++
 python/pyspark/sql/functions.py                    |  23 +++-
 .../scala/org/apache/spark/sql/functions.scala     |  22 ++--
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  20 ++++
 .../spark/sql/hive/execution/HiveUDFSuite.scala    |  15 ++-
 14 files changed, 238 insertions(+), 76 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 17d1cdca350..eac3f652320 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7923,15 +7923,19 @@ object functions {
   def call_udf(udfName: String, cols: Column*): Column = 
call_function(udfName, cols: _*)
 
   /**
-   * Call a builtin or temp function.
+   * Call a SQL function.
    *
    * @param funcName
-   *   function name
+   *   function name that follows the SQL identifier syntax (can be quoted, 
can be qualified)
    * @param cols
    *   the expression parameters of function
    * @since 3.5.0
    */
   @scala.annotation.varargs
-  def call_function(funcName: String, cols: Column*): Column = 
Column.fn(funcName, cols: _*)
+  def call_function(funcName: String, cols: Column*): Column = Column { 
builder =>
+    builder.getCallFunctionBuilder
+      .setFunctionName(funcName)
+      .addAllArguments(cols.map(_.expr).asJava)
+  }
 
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 800ce43a60d..ad2ca383e4f 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -239,4 +239,14 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
     val output = runCommandsInShell(input)
     assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], 
[id3,25])", output)
   }
+
+  test("call_function") {
+    val input = """
+        |val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+        |spark.udf.register("simpleUDF", (v: Int) => v * v)
+        |df.select($"id", call_function("simpleUDF", $"value")).collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], 
[id3,25])", output)
+  }
 }
diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 37a8778865d..557b9db9123 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -46,6 +46,7 @@ message Expression {
     UpdateFields update_fields = 13;
     UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
     CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
+    CallFunction call_function = 16;
 
     // This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
     // relations they can add them here. During the planning the correct 
resolution is done.
@@ -371,3 +372,11 @@ message JavaUDF {
   // (Required) Indicate if the Java user-defined function is an aggregate 
function
   bool aggregate = 3;
 }
+
+message CallFunction {
+  // (Required) Unparsed name of the SQL function.
+  string function_name = 1;
+
+  // (Optional) Function arguments. Empty arguments are allowed.
+  repeated Expression arguments = 2;
+}
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
 
b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
index f7fe5beba2c..6db0a614682 100644
--- 
a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
+++ 
b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
@@ -12,7 +12,7 @@
       }
     },
     "expressions": [{
-      "unresolvedFunction": {
+      "callFunction": {
         "functionName": "lower",
         "arguments": [{
           "unresolvedAttribute": {
diff --git 
a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin
 
b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin
index 7c736d93f77..ef985e42131 100644
Binary files 
a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin
 and 
b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin
 differ
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 92a9524f67a..36037cce7eb 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1380,6 +1380,8 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
         transformExpressionPlugin(exp.getExtension)
       case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION =>
         
transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
+      case proto.Expression.ExprTypeCase.CALL_FUNCTION =>
+        transformCallFunction(exp.getCallFunction)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not 
supported")
@@ -1484,6 +1486,23 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
     }
   }
 
+  /**
+   * Translates a SQL function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the function call.
+   * @return
+   *   Expression.
+   */
+  private def transformCallFunction(fun: proto.CallFunction): Expression = {
+    val funcName = fun.getFunctionName
+    val nameParts = 
session.sessionState.sqlParser.parseMultipartIdentifier(funcName)
+    UnresolvedFunction(
+      nameParts,
+      fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+      false)
+  }
+
   private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket 
= {
     Utils.deserialize[UdfPacket](
       fun.getScalarScalaUdf.getPayload.toByteArray,
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index e1b648c7bb8..44e6e174f70 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1027,3 +1027,27 @@ class DistributedSequenceID(Expression):
 
     def __repr__(self) -> str:
         return "DistributedSequenceID()"
+
+
+class CallFunction(Expression):
+    def __init__(self, name: str, args: Sequence["Expression"]):
+        super().__init__()
+
+        assert isinstance(name, str)
+        self._name = name
+
+        assert isinstance(args, list) and all(isinstance(arg, Expression) for 
arg in args)
+        self._args = args
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        expr = proto.Expression()
+        expr.call_function.function_name = self._name
+        if len(self._args) > 0:
+            expr.call_function.arguments.extend([arg.to_plan(session) for arg 
in self._args])
+        return expr
+
+    def __repr__(self) -> str:
+        if len(self._args) > 0:
+            return f"CallFunction('{self._name}', {', '.join([str(arg) for arg 
in self._args])})"
+        else:
+            return f"CallFunction('{self._name}')"
diff --git a/python/pyspark/sql/connect/functions.py 
b/python/pyspark/sql/connect/functions.py
index a1c0516ee0d..a92f89c0f6c 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -51,6 +51,7 @@ from pyspark.sql.connect.expressions import (
     SQLExpression,
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
+    CallFunction,
 )
 from pyspark.sql.connect.udf import _create_py_udf
 from pyspark.sql.connect.udtf import _create_py_udtf
@@ -3909,8 +3910,9 @@ def udtf(
 udtf.__doc__ = pysparkfuncs.udtf.__doc__
 
 
-def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
-    return _invoke_function(udfName, *[_to_col(c) for c in cols])
+def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
+    expressions = [_to_col(c)._expr for c in cols]
+    return Column(CallFunction(funcName, expressions))
 
 
 call_function.__doc__ = pysparkfuncs.call_function.__doc__
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 7a68d831a99..51d1a5d48a1 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x95+\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xd9+\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,61 +45,63 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5630
-    _EXPRESSION_WINDOW._serialized_start = 1475
-    _EXPRESSION_WINDOW._serialized_end = 2258
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2258
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2032
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2177
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2179
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2258
-    _EXPRESSION_SORTORDER._serialized_start = 2261
-    _EXPRESSION_SORTORDER._serialized_end = 2686
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2491
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2599
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2601
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2686
-    _EXPRESSION_CAST._serialized_start = 2689
-    _EXPRESSION_CAST._serialized_end = 2834
-    _EXPRESSION_LITERAL._serialized_start = 2837
-    _EXPRESSION_LITERAL._serialized_end = 4400
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3672
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3789
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3791
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3889
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 3892
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4022
-    _EXPRESSION_LITERAL_MAP._serialized_start = 4025
-    _EXPRESSION_LITERAL_MAP._serialized_end = 4252
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4255
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4384
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4402
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4514
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4517
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4721
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4723
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4773
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4775
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4857
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4859
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4945
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4948
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5080
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5083
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5270
-    _EXPRESSION_ALIAS._serialized_start = 5272
-    _EXPRESSION_ALIAS._serialized_end = 5392
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5395
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5553
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5555
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5617
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5633
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5997
-    _PYTHONUDF._serialized_start = 6000
-    _PYTHONUDF._serialized_end = 6155
-    _SCALARSCALAUDF._serialized_start = 6158
-    _SCALARSCALAUDF._serialized_end = 6342
-    _JAVAUDF._serialized_start = 6345
-    _JAVAUDF._serialized_end = 6494
+    _EXPRESSION._serialized_end = 5698
+    _EXPRESSION_WINDOW._serialized_start = 1543
+    _EXPRESSION_WINDOW._serialized_end = 2326
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1833
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2326
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2100
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2245
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2247
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2326
+    _EXPRESSION_SORTORDER._serialized_start = 2329
+    _EXPRESSION_SORTORDER._serialized_end = 2754
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2559
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2667
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2669
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2754
+    _EXPRESSION_CAST._serialized_start = 2757
+    _EXPRESSION_CAST._serialized_end = 2902
+    _EXPRESSION_LITERAL._serialized_start = 2905
+    _EXPRESSION_LITERAL._serialized_end = 4468
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3740
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3857
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3859
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3957
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 3960
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4090
+    _EXPRESSION_LITERAL_MAP._serialized_start = 4093
+    _EXPRESSION_LITERAL_MAP._serialized_end = 4320
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4323
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4452
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4470
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4582
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4585
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4789
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4791
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4841
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4843
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4925
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4927
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5013
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5016
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5148
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5151
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5338
+    _EXPRESSION_ALIAS._serialized_start = 5340
+    _EXPRESSION_ALIAS._serialized_end = 5460
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5463
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5621
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5623
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5685
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5701
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6065
+    _PYTHONUDF._serialized_start = 6068
+    _PYTHONUDF._serialized_end = 6223
+    _SCALARSCALAUDF._serialized_start = 6226
+    _SCALARSCALAUDF._serialized_end = 6410
+    _JAVAUDF._serialized_start = 6413
+    _JAVAUDF._serialized_end = 6562
+    _CALLFUNCTION._serialized_start = 6564
+    _CALLFUNCTION._serialized_end = 6672
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index bef87203b55..b9b16ce35e3 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1101,6 +1101,7 @@ class Expression(google.protobuf.message.Message):
     UPDATE_FIELDS_FIELD_NUMBER: builtins.int
     UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
     COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
+    CALL_FUNCTION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
@@ -1135,6 +1136,8 @@ class Expression(google.protobuf.message.Message):
     @property
     def common_inline_user_defined_function(self) -> 
global___CommonInlineUserDefinedFunction: ...
     @property
+    def call_function(self) -> global___CallFunction: ...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
         relations they can add them here. During the planning the correct 
resolution is done.
@@ -1158,6 +1161,7 @@ class Expression(google.protobuf.message.Message):
         unresolved_named_lambda_variable: 
global___Expression.UnresolvedNamedLambdaVariable
         | None = ...,
         common_inline_user_defined_function: 
global___CommonInlineUserDefinedFunction | None = ...,
+        call_function: global___CallFunction | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -1165,6 +1169,8 @@ class Expression(google.protobuf.message.Message):
         field_name: typing_extensions.Literal[
             "alias",
             b"alias",
+            "call_function",
+            b"call_function",
             "cast",
             b"cast",
             "common_inline_user_defined_function",
@@ -1204,6 +1210,8 @@ class Expression(google.protobuf.message.Message):
         field_name: typing_extensions.Literal[
             "alias",
             b"alias",
+            "call_function",
+            b"call_function",
             "cast",
             b"cast",
             "common_inline_user_defined_function",
@@ -1256,6 +1264,7 @@ class Expression(google.protobuf.message.Message):
         "update_fields",
         "unresolved_named_lambda_variable",
         "common_inline_user_defined_function",
+        "call_function",
         "extension",
     ] | None: ...
 
@@ -1469,3 +1478,30 @@ class JavaUDF(google.protobuf.message.Message):
     ) -> typing_extensions.Literal["output_type"] | None: ...
 
 global___JavaUDF = JavaUDF
+
+class CallFunction(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    FUNCTION_NAME_FIELD_NUMBER: builtins.int
+    ARGUMENTS_FIELD_NUMBER: builtins.int
+    function_name: builtins.str
+    """(Required) Unparsed name of the SQL function."""
+    @property
+    def arguments(
+        self,
+    ) -> 
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
+        """(Optional) Function arguments. Empty arguments are allowed."""
+    def __init__(
+        self,
+        *,
+        function_name: builtins.str = ...,
+        arguments: collections.abc.Iterable[global___Expression] | None = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "arguments", b"arguments", "function_name", b"function_name"
+        ],
+    ) -> None: ...
+
+global___CallFunction = CallFunction
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f566fcee0e3..b45e1daa0fd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -14395,16 +14395,16 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> 
Column:
 
 
 @try_remote_functions
-def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
+def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
     """
-    Call a builtin or temp function.
+    Call a SQL function.
 
     .. versionadded:: 3.5.0
 
     Parameters
     ----------
-    udfName : str
-        name of the function
+    funcName : str
+        function name that follows the SQL identifier syntax (can be quoted, 
can be qualified)
     cols : :class:`~pyspark.sql.Column` or str
         column names or :class:`~pyspark.sql.Column`\\s to be used in the 
function
 
@@ -14442,9 +14442,22 @@ def call_function(udfName: str, *cols: "ColumnOrName") 
-> Column:
     +-------+
     |    2.0|
     +-------+
+    >>> _ = spark.sql("CREATE FUNCTION custom_avg AS 
'test.org.apache.spark.sql.MyDoubleAvg'")
+    >>> df.select(call_function("custom_avg", col("id"))).show()
+    +------------------------------------+
+    |spark_catalog.default.custom_avg(id)|
+    +------------------------------------+
+    |                               102.0|
+    +------------------------------------+
+    >>> df.select(call_function("spark_catalog.default.custom_avg", 
col("id"))).show()
+    +------------------------------------+
+    |spark_catalog.default.custom_avg(id)|
+    +------------------------------------+
+    |                               102.0|
+    +------------------------------------+
     """
     sc = get_active_spark_context()
-    return _invoke_function("call_function", udfName, _to_seq(sc, cols, 
_to_java_column))
+    return _invoke_function("call_function", funcName, _to_seq(sc, cols, 
_to_java_column))
 
 
 @try_remote_functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2a8cfd250c9..ca5e4422ca9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -8338,7 +8338,7 @@ object functions {
   @scala.annotation.varargs
   @deprecated("Use call_udf")
   def callUDF(udfName: String, cols: Column*): Column =
-    call_function(udfName, cols: _*)
+    call_function(Seq(udfName), cols: _*)
 
   /**
    * Call an user-defined function.
@@ -8357,18 +8357,28 @@ object functions {
    */
   @scala.annotation.varargs
   def call_udf(udfName: String, cols: Column*): Column =
-    call_function(udfName, cols: _*)
+    call_function(Seq(udfName), cols: _*)
 
   /**
-   * Call a builtin or temp function.
+   * Call a SQL function.
    *
-   * @param funcName function name
+   * @param funcName function name that follows the SQL identifier syntax
+   *                 (can be quoted, can be qualified)
    * @param cols the expression parameters of function
    * @since 3.5.0
    */
   @scala.annotation.varargs
-  def call_function(funcName: String, cols: Column*): Column =
-    withExpr { UnresolvedFunction(funcName, cols.map(_.expr), false) }
+  def call_function(funcName: String, cols: Column*): Column = {
+    val parser = 
SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
+      new SparkSqlParser()
+    }
+    val nameParts = parser.parseMultipartIdentifier(funcName)
+    call_function(nameParts, cols: _*)
+  }
+
+  private def call_function(nameParts: Seq[String], cols: Column*): Column = 
withExpr {
+    UnresolvedFunction(nameParts, cols.map(_.expr), false)
+  }
 
   /**
    * Unwrap UDT data type column into its underlying type.
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 9781a8e3ff4..c7dcb575ff0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -5918,6 +5918,26 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession {
 
   test("call_function") {
     checkAnswer(testData2.select(call_function("avg", $"a")), 
testData2.selectExpr("avg(a)"))
+
+    withUserDefinedFunction("custom_func" -> true, "custom_sum" -> false) {
+      spark.udf.register("custom_func", (i: Int) => { i + 2 })
+      checkAnswer(
+        testData2.select(call_function("custom_func", $"a")),
+        Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))
+      spark.udf.register("default.custom_func", (i: Int) => { i + 2 })
+      checkAnswer(
+        testData2.select(call_function("`default.custom_func`", $"a")),
+        Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))
+
+      sql("CREATE FUNCTION custom_sum AS 
'test.org.apache.spark.sql.MyDoubleSum'")
+      checkAnswer(
+        testData2.select(
+          call_function("custom_sum", $"a"),
+          call_function("default.custom_sum", $"a"),
+          call_function("spark_catalog.default.custom_sum", $"a")),
+        Row(12.0, 12.0, 12.0))
+    }
+
   }
 }
 
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index ef430f4b6a2..d12ebae0f5f 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.{SparkException, SparkFiles, 
TestUtils}
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.functions.max
+import org.apache.spark.sql.functions.{call_function, max}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
@@ -552,6 +552,19 @@ class HiveUDFSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils {
     }
   }
 
+  test("Invoke a persist hive function with call_function") {
+    val testData = spark.range(5).repartition(1)
+    withUserDefinedFunction("custom_avg" -> false) {
+      sql(s"CREATE FUNCTION custom_avg AS 
'${classOf[GenericUDAFAverage].getName}'")
+      checkAnswer(
+        testData.select(
+          call_function("custom_avg", $"id"),
+          call_function("default.custom_avg", $"id"),
+          call_function("spark_catalog.default.custom_avg", $"id")),
+        Row(2.0, 2.0, 2.0))
+    }
+  }
+
   test("Temp function has dots in the names") {
     withUserDefinedFunction("test_avg" -> false, "`default.test_avg`" -> true) 
{
       sql(s"CREATE FUNCTION test_avg AS 
'${classOf[GenericUDAFAverage].getName}'")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to