This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new f1377a856e8 [SPARK-44131][SQL][PYTHON][CONNECT][FOLLOWUP] Support qualified function name for call_function
f1377a856e8 is described below
commit f1377a856e85977aafe3bf13cce1da7b4d4ed195
Author: Jiaan Geng <[email protected]>
AuthorDate: Tue Jul 25 08:54:00 2023 +0800
[SPARK-44131][SQL][PYTHON][CONNECT][FOLLOWUP] Support qualified function name for call_function
### What changes were proposed in this pull request?
https://github.com/apache/spark/pull/41687 added `call_function` and
deprecated `call_udf` for the Scala API.
Sometimes the function name is qualified; we should let users pass such
names to `call_function` so they can invoke persistent functions as well.
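For illustration, a minimal sketch of the intended usage (assuming a
persistent function `custom_avg` has been created, as in the tests below;
`df` stands for any DataFrame):

    import org.apache.spark.sql.functions.call_function
    // unqualified name, resolved as before
    df.select(call_function("custom_avg", $"id"))
    // qualified name, now parsed into name parts and resolved via the catalog
    df.select(call_function("spark_catalog.default.custom_avg", $"id"))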
### Why are the changes needed?
Support qualified function name for `call_function`.
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
New test cases.
Closes #41932 from beliefer/SPARK-44131_followup.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit d97a4e214c7e11bcc9b7d6e126bf06e214a29988)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../scala/org/apache/spark/sql/functions.scala | 10 +-
.../spark/sql/application/ReplE2ESuite.scala | 10 ++
.../main/protobuf/spark/connect/expressions.proto | 9 ++
.../queries/function_call_function.json | 2 +-
.../queries/function_call_function.proto.bin | Bin 174 -> 175 bytes
.../sql/connect/planner/SparkConnectPlanner.scala | 19 ++++
python/pyspark/sql/connect/expressions.py | 24 +++++
python/pyspark/sql/connect/functions.py | 6 +-
.../pyspark/sql/connect/proto/expressions_pb2.py | 118 +++++++++++----------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 36 +++++++
python/pyspark/sql/functions.py | 23 +++-
.../scala/org/apache/spark/sql/functions.scala | 22 ++--
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 20 ++++
.../spark/sql/hive/execution/HiveUDFSuite.scala | 15 ++-
14 files changed, 238 insertions(+), 76 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 17d1cdca350..eac3f652320 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7923,15 +7923,19 @@ object functions {
def call_udf(udfName: String, cols: Column*): Column =
call_function(udfName, cols: _*)
/**
- * Call a builtin or temp function.
+ * Call a SQL function.
*
* @param funcName
- * function name
+ * function name that follows the SQL identifier syntax (can be quoted, can be qualified)
* @param cols
* the expression parameters of function
* @since 3.5.0
*/
@scala.annotation.varargs
- def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*)
+ def call_function(funcName: String, cols: Column*): Column = Column { builder =>
+ builder.getCallFunctionBuilder
+ .setFunctionName(funcName)
+ .addAllArguments(cols.map(_.expr).asJava)
+ }
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 800ce43a60d..ad2ca383e4f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -239,4 +239,14 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
val output = runCommandsInShell(input)
assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
}
+
+ test("call_function") {
+ val input = """
+ |val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+ |spark.udf.register("simpleUDF", (v: Int) => v * v)
+ |df.select($"id", call_function("simpleUDF", $"value")).collect()
+ """.stripMargin
+ val output = runCommandsInShell(input)
+ assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
+ }
}
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 37a8778865d..557b9db9123 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -46,6 +46,7 @@ message Expression {
UpdateFields update_fields = 13;
UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
+ CallFunction call_function = 16;
// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
@@ -371,3 +372,11 @@ message JavaUDF {
// (Required) Indicate if the Java user-defined function is an aggregate function
bool aggregate = 3;
}
+
+message CallFunction {
+ // (Required) Unparsed name of the SQL function.
+ string function_name = 1;
+
+ // (Optional) Function arguments. Empty arguments are allowed.
+ repeated Expression arguments = 2;
+}
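For reference, a minimal sketch of how a client could populate this message
through the generated builders (`argExpr` is a placeholder for an
already-built argument expression, not part of this patch):

    // Scala sketch against the generated protobuf classes
    val expr = proto.Expression.newBuilder()
    expr.getCallFunctionBuilder
      .setFunctionName("spark_catalog.default.custom_avg") // unparsed, may be qualified
      .addArguments(argExpr) // zero or more argument expressions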
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
index f7fe5beba2c..6db0a614682 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
@@ -12,7 +12,7 @@
}
},
"expressions": [{
- "unresolvedFunction": {
+ "callFunction": {
"functionName": "lower",
"arguments": [{
"unresolvedAttribute": {
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin
index 7c736d93f77..ef985e42131 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin differ
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 92a9524f67a..36037cce7eb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1380,6 +1380,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
transformExpressionPlugin(exp.getExtension)
case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION =>
transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
+ case proto.Expression.ExprTypeCase.CALL_FUNCTION =>
+ transformCallFunction(exp.getCallFunction)
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not
supported")
@@ -1484,6 +1486,23 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
}
}
+ /**
+ * Translates a SQL function from proto to the Catalyst expression.
+ *
+ * @param fun
+ * Proto representation of the function call.
+ * @return
+ * Expression.
+ */
+ private def transformCallFunction(fun: proto.CallFunction): Expression = {
+ val funcName = fun.getFunctionName
+ val nameParts = session.sessionState.sqlParser.parseMultipartIdentifier(funcName)
+ UnresolvedFunction(
+ nameParts,
+ fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+ false)
+ }
+
private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
Utils.deserialize[UdfPacket](
fun.getScalarScalaUdf.getPayload.toByteArray,
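As context for transformCallFunction above, a rough sketch of how
parseMultipartIdentifier splits a function name (illustrative values,
matching the behavior exercised by the new tests):

    // unquoted dots separate name parts
    parser.parseMultipartIdentifier("spark_catalog.default.custom_avg")
    // => Seq("spark_catalog", "default", "custom_avg")
    // backtick-quoted names remain a single part
    parser.parseMultipartIdentifier("`default.custom_func`")
    // => Seq("default.custom_func")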
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index e1b648c7bb8..44e6e174f70 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1027,3 +1027,27 @@ class DistributedSequenceID(Expression):
def __repr__(self) -> str:
return "DistributedSequenceID()"
+
+
+class CallFunction(Expression):
+ def __init__(self, name: str, args: Sequence["Expression"]):
+ super().__init__()
+
+ assert isinstance(name, str)
+ self._name = name
+
+ assert isinstance(args, list) and all(isinstance(arg, Expression) for
arg in args)
+ self._args = args
+
+ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+ expr = proto.Expression()
+ expr.call_function.function_name = self._name
+ if len(self._args) > 0:
+ expr.call_function.arguments.extend([arg.to_plan(session) for arg in self._args])
+ return expr
+
+ def __repr__(self) -> str:
+ if len(self._args) > 0:
+ return f"CallFunction('{self._name}', {', '.join([str(arg) for arg
in self._args])})"
+ else:
+ return f"CallFunction('{self._name}')"
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index a1c0516ee0d..a92f89c0f6c 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -51,6 +51,7 @@ from pyspark.sql.connect.expressions import (
SQLExpression,
LambdaFunction,
UnresolvedNamedLambdaVariable,
+ CallFunction,
)
from pyspark.sql.connect.udf import _create_py_udf
from pyspark.sql.connect.udtf import _create_py_udtf
@@ -3909,8 +3910,9 @@ def udtf(
udtf.__doc__ = pysparkfuncs.udtf.__doc__
-def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
- return _invoke_function(udfName, *[_to_col(c) for c in cols])
+def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
+ expressions = [_to_col(c)._expr for c in cols]
+ return Column(CallFunction(funcName, expressions))
call_function.__doc__ = pysparkfuncs.call_function.__doc__
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 7a68d831a99..51d1a5d48a1 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x95+\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xd9+\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,61 +45,63 @@ if _descriptor._USE_C_DESCRIPTORS == False:
b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
)
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 5630
- _EXPRESSION_WINDOW._serialized_start = 1475
- _EXPRESSION_WINDOW._serialized_end = 2258
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2258
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2032
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2177
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2179
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2258
- _EXPRESSION_SORTORDER._serialized_start = 2261
- _EXPRESSION_SORTORDER._serialized_end = 2686
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2491
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2599
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2601
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2686
- _EXPRESSION_CAST._serialized_start = 2689
- _EXPRESSION_CAST._serialized_end = 2834
- _EXPRESSION_LITERAL._serialized_start = 2837
- _EXPRESSION_LITERAL._serialized_end = 4400
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3672
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3789
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3791
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3889
- _EXPRESSION_LITERAL_ARRAY._serialized_start = 3892
- _EXPRESSION_LITERAL_ARRAY._serialized_end = 4022
- _EXPRESSION_LITERAL_MAP._serialized_start = 4025
- _EXPRESSION_LITERAL_MAP._serialized_end = 4252
- _EXPRESSION_LITERAL_STRUCT._serialized_start = 4255
- _EXPRESSION_LITERAL_STRUCT._serialized_end = 4384
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4402
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4514
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4517
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4721
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4723
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4773
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4775
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4857
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4859
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4945
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4948
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5080
- _EXPRESSION_UPDATEFIELDS._serialized_start = 5083
- _EXPRESSION_UPDATEFIELDS._serialized_end = 5270
- _EXPRESSION_ALIAS._serialized_start = 5272
- _EXPRESSION_ALIAS._serialized_end = 5392
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5395
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5553
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5555
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5617
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5633
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5997
- _PYTHONUDF._serialized_start = 6000
- _PYTHONUDF._serialized_end = 6155
- _SCALARSCALAUDF._serialized_start = 6158
- _SCALARSCALAUDF._serialized_end = 6342
- _JAVAUDF._serialized_start = 6345
- _JAVAUDF._serialized_end = 6494
+ _EXPRESSION._serialized_end = 5698
+ _EXPRESSION_WINDOW._serialized_start = 1543
+ _EXPRESSION_WINDOW._serialized_end = 2326
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1833
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2326
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2100
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2245
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2247
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2326
+ _EXPRESSION_SORTORDER._serialized_start = 2329
+ _EXPRESSION_SORTORDER._serialized_end = 2754
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2559
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2667
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2669
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2754
+ _EXPRESSION_CAST._serialized_start = 2757
+ _EXPRESSION_CAST._serialized_end = 2902
+ _EXPRESSION_LITERAL._serialized_start = 2905
+ _EXPRESSION_LITERAL._serialized_end = 4468
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3740
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3857
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3859
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3957
+ _EXPRESSION_LITERAL_ARRAY._serialized_start = 3960
+ _EXPRESSION_LITERAL_ARRAY._serialized_end = 4090
+ _EXPRESSION_LITERAL_MAP._serialized_start = 4093
+ _EXPRESSION_LITERAL_MAP._serialized_end = 4320
+ _EXPRESSION_LITERAL_STRUCT._serialized_start = 4323
+ _EXPRESSION_LITERAL_STRUCT._serialized_end = 4452
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4470
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4582
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4585
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4789
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4791
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4841
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4843
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4925
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4927
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5013
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5016
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5148
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 5151
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 5338
+ _EXPRESSION_ALIAS._serialized_start = 5340
+ _EXPRESSION_ALIAS._serialized_end = 5460
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5463
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5621
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5623
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5685
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5701
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6065
+ _PYTHONUDF._serialized_start = 6068
+ _PYTHONUDF._serialized_end = 6223
+ _SCALARSCALAUDF._serialized_start = 6226
+ _SCALARSCALAUDF._serialized_end = 6410
+ _JAVAUDF._serialized_start = 6413
+ _JAVAUDF._serialized_end = 6562
+ _CALLFUNCTION._serialized_start = 6564
+ _CALLFUNCTION._serialized_end = 6672
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index bef87203b55..b9b16ce35e3 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1101,6 +1101,7 @@ class Expression(google.protobuf.message.Message):
UPDATE_FIELDS_FIELD_NUMBER: builtins.int
UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
+ CALL_FUNCTION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def literal(self) -> global___Expression.Literal: ...
@@ -1135,6 +1136,8 @@ class Expression(google.protobuf.message.Message):
@property
def common_inline_user_defined_function(self) -> global___CommonInlineUserDefinedFunction: ...
@property
+ def call_function(self) -> global___CallFunction: ...
+ @property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins
generate arbitrary
relations they can add them here. During the planning the correct
resolution is done.
@@ -1158,6 +1161,7 @@ class Expression(google.protobuf.message.Message):
unresolved_named_lambda_variable: global___Expression.UnresolvedNamedLambdaVariable
| None = ...,
common_inline_user_defined_function: global___CommonInlineUserDefinedFunction
| None = ...,
+ call_function: global___CallFunction | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1165,6 +1169,8 @@ class Expression(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"alias",
b"alias",
+ "call_function",
+ b"call_function",
"cast",
b"cast",
"common_inline_user_defined_function",
@@ -1204,6 +1210,8 @@ class Expression(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"alias",
b"alias",
+ "call_function",
+ b"call_function",
"cast",
b"cast",
"common_inline_user_defined_function",
@@ -1256,6 +1264,7 @@ class Expression(google.protobuf.message.Message):
"update_fields",
"unresolved_named_lambda_variable",
"common_inline_user_defined_function",
+ "call_function",
"extension",
] | None: ...
@@ -1469,3 +1478,30 @@ class JavaUDF(google.protobuf.message.Message):
) -> typing_extensions.Literal["output_type"] | None: ...
global___JavaUDF = JavaUDF
+
+class CallFunction(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ FUNCTION_NAME_FIELD_NUMBER: builtins.int
+ ARGUMENTS_FIELD_NUMBER: builtins.int
+ function_name: builtins.str
+ """(Required) Unparsed name of the SQL function."""
+ @property
+ def arguments(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
+ """(Optional) Function arguments. Empty arguments are allowed."""
+ def __init__(
+ self,
+ *,
+ function_name: builtins.str = ...,
+ arguments: collections.abc.Iterable[global___Expression] | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "arguments", b"arguments", "function_name", b"function_name"
+ ],
+ ) -> None: ...
+
+global___CallFunction = CallFunction
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f566fcee0e3..b45e1daa0fd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -14395,16 +14395,16 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
@try_remote_functions
-def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
+def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
"""
- Call a builtin or temp function.
+ Call a SQL function.
.. versionadded:: 3.5.0
Parameters
----------
- udfName : str
- name of the function
+ funcName : str
+ function name that follows the SQL identifier syntax (can be quoted, can be qualified)
cols : :class:`~pyspark.sql.Column` or str
column names or :class:`~pyspark.sql.Column`\\s to be used in the function
@@ -14442,9 +14442,22 @@ def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
+-------+
| 2.0|
+-------+
+ >>> _ = spark.sql("CREATE FUNCTION custom_avg AS
'test.org.apache.spark.sql.MyDoubleAvg'")
+ >>> df.select(call_function("custom_avg", col("id"))).show()
+ +------------------------------------+
+ |spark_catalog.default.custom_avg(id)|
+ +------------------------------------+
+ | 102.0|
+ +------------------------------------+
+ >>> df.select(call_function("spark_catalog.default.custom_avg",
col("id"))).show()
+ +------------------------------------+
+ |spark_catalog.default.custom_avg(id)|
+ +------------------------------------+
+ | 102.0|
+ +------------------------------------+
"""
sc = get_active_spark_context()
- return _invoke_function("call_function", udfName, _to_seq(sc, cols, _to_java_column))
+ return _invoke_function("call_function", funcName, _to_seq(sc, cols, _to_java_column))
@try_remote_functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2a8cfd250c9..ca5e4422ca9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -8338,7 +8338,7 @@ object functions {
@scala.annotation.varargs
@deprecated("Use call_udf")
def callUDF(udfName: String, cols: Column*): Column =
- call_function(udfName, cols: _*)
+ call_function(Seq(udfName), cols: _*)
/**
* Call an user-defined function.
@@ -8357,18 +8357,28 @@ object functions {
*/
@scala.annotation.varargs
def call_udf(udfName: String, cols: Column*): Column =
- call_function(udfName, cols: _*)
+ call_function(Seq(udfName), cols: _*)
/**
- * Call a builtin or temp function.
+ * Call a SQL function.
*
- * @param funcName function name
+ * @param funcName function name that follows the SQL identifier syntax
+ * (can be quoted, can be qualified)
* @param cols the expression parameters of function
* @since 3.5.0
*/
@scala.annotation.varargs
- def call_function(funcName: String, cols: Column*): Column =
- withExpr { UnresolvedFunction(funcName, cols.map(_.expr), false) }
+ def call_function(funcName: String, cols: Column*): Column = {
+ val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
+ new SparkSqlParser()
+ }
+ val nameParts = parser.parseMultipartIdentifier(funcName)
+ call_function(nameParts, cols: _*)
+ }
+
+ private def call_function(nameParts: Seq[String], cols: Column*): Column = withExpr {
+ UnresolvedFunction(nameParts, cols.map(_.expr), false)
+ }
/**
* Unwrap UDT data type column into its underlying type.
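A short sketch of the net effect of this change (function names as
registered in the suites below):

    // The name is now parsed before resolution, so all of these work:
    testData2.select(call_function("avg", $"a"))                              // builtin
    testData2.select(call_function("default.custom_sum", $"a"))               // qualified
    testData2.select(call_function("spark_catalog.default.custom_sum", $"a")) // fully qualified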
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 9781a8e3ff4..c7dcb575ff0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -5918,6 +5918,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
test("call_function") {
checkAnswer(testData2.select(call_function("avg", $"a")), testData2.selectExpr("avg(a)"))
+
+ withUserDefinedFunction("custom_func" -> true, "custom_sum" -> false) {
+ spark.udf.register("custom_func", (i: Int) => { i + 2 })
+ checkAnswer(
+ testData2.select(call_function("custom_func", $"a")),
+ Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))
+ spark.udf.register("default.custom_func", (i: Int) => { i + 2 })
+ checkAnswer(
+ testData2.select(call_function("`default.custom_func`", $"a")),
+ Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))
+
+ sql("CREATE FUNCTION custom_sum AS
'test.org.apache.spark.sql.MyDoubleSum'")
+ checkAnswer(
+ testData2.select(
+ call_function("custom_sum", $"a"),
+ call_function("default.custom_sum", $"a"),
+ call_function("spark_catalog.default.custom_sum", $"a")),
+ Row(12.0, 12.0, 12.0))
+ }
+
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index ef430f4b6a2..d12ebae0f5f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.{SparkException, SparkFiles, TestUtils}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.functions.max
+import org.apache.spark.sql.functions.{call_function, max}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -552,6 +552,19 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
+ test("Invoke a persist hive function with call_function") {
+ val testData = spark.range(5).repartition(1)
+ withUserDefinedFunction("custom_avg" -> false) {
+ sql(s"CREATE FUNCTION custom_avg AS
'${classOf[GenericUDAFAverage].getName}'")
+ checkAnswer(
+ testData.select(
+ call_function("custom_avg", $"id"),
+ call_function("default.custom_avg", $"id"),
+ call_function("spark_catalog.default.custom_avg", $"id")),
+ Row(2.0, 2.0, 2.0))
+ }
+ }
+
test("Temp function has dots in the names") {
withUserDefinedFunction("test_avg" -> false, "`default.test_avg`" -> true)
{
sql(s"CREATE FUNCTION test_avg AS
'${classOf[GenericUDAFAverage].getName}'")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]