This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e363e9d1b7b [SPARK-42124][PYTHON][CONNECT] Scalar Inline Python UDF in Spark Connect
e363e9d1b7b is described below
commit e363e9d1b7b2ff19c7ff39760521f83481f78c1c
Author: Xinrong Meng <[email protected]>
AuthorDate: Wed Jan 25 19:42:46 2023 +0900
[SPARK-42124][PYTHON][CONNECT] Scalar Inline Python UDF in Spark Connect
### What changes were proposed in this pull request?
Support scalar inline Python user-defined functions (a.k.a. unregistered Python UDFs) in Spark Connect.
Currently, the user-specified return type must be a `pyspark.sql.types.DataType`.
There will be follow-up PRs on:
- Support Pandas UDF [jira](https://issues.apache.org/jira/browse/SPARK-42125)
- Support user-specified return type in DDL-formatted strings [jira](https://issues.apache.org/jira/browse/SPARK-42126)
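A minimal sketch of the current constraint, assuming a working Spark Connect session (the DDL-string form below is the SPARK-42126 follow-up and does not work yet):
```
>>> from pyspark.sql.connect.functions import udf
>>> from pyspark.sql.types import LongType
>>> plus_one = udf(lambda x: x + 1, LongType())  # DataType instance: supported
>>> # udf(lambda x: x + 1, "long")               # DDL-formatted string: follow-up (SPARK-42126)
```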
### Why are the changes needed?
Feature parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. Unregistered Python UDFs are now supported, as shown below:
```
>>> spark.range(2).withColumn('plus_one', udf(lambda x: x + 1)('id')).show()
+---+--------+
| id|plus_one|
+---+--------+
| 0| 1|
| 1| 2|
+---+--------+
>>> from pyspark.sql.types import LongType
>>> @udf(LongType())
... def f(x):
... return x + 1
...
>>> spark.range(2).withColumn('plus_one', f('id')).show()
+---+--------+
| id|plus_one|
+---+--------+
| 0| 1|
| 1| 2|
+---+--------+
```
### How was this patch tested?
Unit tests.
Closes #39585 from xinrong-meng/connect_udf.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 24 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 65 ++++++++
.../messages/ConnectProtoMessagesSuite.scala | 34 +++++
python/pyspark/sql/connect/_typing.py | 21 ++-
python/pyspark/sql/connect/expressions.py | 62 ++++++++
python/pyspark/sql/connect/functions.py | 30 +++-
.../pyspark/sql/connect/proto/expressions_pb2.py | 118 +++++++++------
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 88 +++++++++++
python/pyspark/sql/connect/udf.py | 165 +++++++++++++++++++++
python/pyspark/sql/functions.py | 4 +
.../sql/tests/connect/test_connect_function.py | 47 +++++-
11 files changed, 604 insertions(+), 54 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index f7feae0e2f0..7ae0a6c5008 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -44,6 +44,7 @@ message Expression {
UnresolvedExtractValue unresolved_extract_value = 12;
UpdateFields update_fields = 13;
UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
+ ScalarInlineUserDefinedFunction scalar_inline_user_defined_function = 15;
    // This field is used to mark extensions to the protocol. When plugins generate arbitrary
    // relations they can add them here. During the planning the correct resolution is done.
@@ -295,3 +296,26 @@ message Expression {
repeated string name_parts = 1;
}
}
+
+message ScalarInlineUserDefinedFunction {
+ // (Required) Name of the user-defined function.
+ string function_name = 1;
+ // (Required) Indicate if the user-defined function is deterministic.
+ bool deterministic = 2;
+ // (Optional) Function arguments. Empty arguments are allowed.
+ repeated Expression arguments = 3;
+ // (Required) Indicate the function type of the user-defined function.
+ oneof function {
+ PythonUDF python_udf = 4;
+ }
+}
+
+message PythonUDF {
+ // (Required) Output type of the Python UDF
+ string output_type = 1;
+ // (Required) EvalType of the Python UDF
+ int32 eval_type = 2;
+ // (Required) The encoded commands of the Python UDF
+ bytes command = 3;
+}
+
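For illustration only (not part of the patch): once the Python bindings below are regenerated, the new messages can be built roughly as follows; the field values are hypothetical placeholders.
```
from pyspark.sql.connect.proto import expressions_pb2 as pb

py_udf = pb.PythonUDF(
    output_type='"long"',           # JSON-encoded return type, e.g. LongType().json()
    eval_type=100,                  # PythonEvalType.SQL_BATCHED_UDF
    command=b"<pickled function>",  # placeholder; in practice CloudPickle output
)
fun = pb.ScalarInlineUserDefinedFunction(
    function_name="f",
    deterministic=True,
    python_udf=py_udf,              # selects the `function` oneof
)
```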
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f95f065c5b3..dc921cee282 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -742,6 +742,8 @@ class SparkConnectPlanner(val session: SparkSession) {
transformWindowExpression(exp.getWindow)
case proto.Expression.ExprTypeCase.EXTENSION =>
transformExpressionPlugin(exp.getExtension)
+      case proto.Expression.ExprTypeCase.SCALAR_INLINE_USER_DEFINED_FUNCTION =>
+        transformScalarInlineUserDefinedFunction(exp.getScalarInlineUserDefinedFunction)
      case _ =>
        throw InvalidPlanInput(
          s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -816,6 +818,65 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
+ /**
+ * Translates a user-defined function from proto to the Catalyst expression.
+ *
+ * @param fun
+ * Proto representation of the function call.
+ * @return
+ * Expression.
+ */
+ private def transformScalarInlineUserDefinedFunction(
+ fun: proto.ScalarInlineUserDefinedFunction): Expression = {
+ fun.getFunctionCase match {
+ case proto.ScalarInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+ transformPythonUDF(fun)
+ case _ =>
+ throw InvalidPlanInput(
+          s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
+ }
+ }
+
+ /**
+   * Translates a Python user-defined function from proto to the Catalyst expression.
+ *
+ * @param fun
+ * Proto representation of the Python user-defined function.
+ * @return
+ * PythonUDF.
+ */
+  private def transformPythonUDF(fun: proto.ScalarInlineUserDefinedFunction): PythonUDF = {
+ val udf = fun.getPythonUdf
+ PythonUDF(
+ name = fun.getFunctionName,
+ func = transformPythonFunction(udf),
+ dataType = DataType.parseTypeWithFallback(
+ schema = udf.getOutputType,
+ parser = DataType.fromDDL,
+ fallbackParser = DataType.fromJson) match {
+ case s: DataType => s
+ case other => throw InvalidPlanInput(s"Invalid return type $other")
+ },
+ children = fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+ evalType = udf.getEvalType,
+ udfDeterministic = fun.getDeterministic)
+ }
+
+  private def transformPythonFunction(fun: proto.PythonUDF): SimplePythonFunction = {
+ SimplePythonFunction(
+ command = fun.getCommand.toByteArray,
+ // Empty environment variables
+ envVars = Maps.newHashMap(),
+ // No imported Python libraries
+ pythonIncludes = Lists.newArrayList(),
+ pythonExec = pythonExec,
+      pythonVer = "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+ // Empty broadcast variables
+ broadcastVars = Lists.newArrayList(),
+ // Null accumulator
+ accumulator = null)
+ }
+
/**
* Translates a LambdaFunction from proto to the Catalyst expression.
*/
@@ -1351,11 +1412,15 @@ class SparkConnectPlanner(val session: SparkSession) {
  private def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = {
val function = SimplePythonFunction(
cf.getSerializedFunction.toByteArray,
+ // Empty environment variables
Maps.newHashMap(),
+ // No imported Python libraries
Lists.newArrayList(),
pythonExec,
"3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+ // Empty broadcast variables
Lists.newArrayList(),
+ // Null accumulator
null)
val udf = UserDefinedPythonFunction(
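Side note on the type handling above: `transformPythonUDF` parses `PythonUDF.output_type` with `DataType.fromDDL` first and falls back to `DataType.fromJson`. A small sketch of the JSON encoding the client actually produces (based on `udf.py` later in this diff; illustrative only):
```
from pyspark.sql.types import IntegerType, LongType, StructField, StructType

LongType().json()                                     # '"long"' -- handled by the fromJson fallback
StructType([StructField("a", IntegerType())]).json()  # nested JSON for complex types
```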
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
index 08f12aa6d08..3d8fae83428 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.connect.messages
+import com.google.protobuf.ByteString
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto
@@ -48,4 +50,36 @@ class ConnectProtoMessagesSuite extends SparkFunSuite {
assert(extLit.getLiteral.hasInteger)
assert(extLit.getLiteral.getInteger == 32)
}
+
+ test("ScalarInlineUserDefinedFunction") {
+ val arguments = proto.Expression
+ .newBuilder()
+ .setUnresolvedAttribute(
+        proto.Expression.UnresolvedAttribute.newBuilder().setUnparsedIdentifier("id"))
+ .build()
+
+ val pythonUdf = proto.PythonUDF
+ .newBuilder()
+ .setEvalType(100)
+ .setOutputType("\"integer\"")
+ .setCommand(ByteString.copyFrom("command".getBytes()))
+ .build()
+
+ val scalarInlineUserDefinedFunctionExpr = proto.Expression
+ .newBuilder()
+ .setScalarInlineUserDefinedFunction(
+ proto.ScalarInlineUserDefinedFunction
+ .newBuilder()
+ .setFunctionName("f")
+ .setDeterministic(true)
+ .addArguments(arguments)
+ .setPythonUdf(pythonUdf))
+ .build()
+
+    val fun = scalarInlineUserDefinedFunctionExpr.getScalarInlineUserDefinedFunction()
+ assert(fun.getFunctionName == "f")
+ assert(fun.getDeterministic == true)
+ assert(fun.getArgumentsCount == 1)
+ assert(fun.hasPythonUdf == true)
+ }
}
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 29a14384c82..66b08d898fe 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -22,11 +22,12 @@ if sys.version_info >= (3, 8):
else:
from typing_extensions import Protocol
-from typing import Union, Optional
+from typing import Any, Callable, Union, Optional
import datetime
import decimal
from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.types import DataType
ColumnOrName = Union[Column, str]
@@ -41,6 +42,24 @@ DecimalLiteral = decimal.Decimal
DateTimeLiteral = Union[datetime.datetime, datetime.date]
+DataTypeOrString = Union[DataType, str]
+
+
+class UserDefinedFunctionLike(Protocol):
+ func: Callable[..., Any]
+ evalType: int
+ deterministic: bool
+
+ @property
+ def returnType(self) -> DataType:
+ ...
+
+ def __call__(self, *args: ColumnOrName) -> Column:
+ ...
+
+ def asNondeterministic(self) -> "UserDefinedFunctionLike":
+ ...
+
class UserDefinedFunctionCallable(Protocol):
def __call__(self, *_: ColumnOrName) -> Column:
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index c8d361af2a5..0fa67a5f8d0 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -480,6 +480,68 @@ class UnresolvedFunction(Expression):
        return f"{self._name}({', '.join([str(arg) for arg in self._args])})"
+class PythonUDF:
+ """Represents a Python user-defined function."""
+
+ def __init__(
+ self,
+ output_type: str,
+ eval_type: int,
+ command: bytes,
+ ) -> None:
+ self._output_type = output_type
+ self._eval_type = eval_type
+ self._command = command
+
+ def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDF:
+ expr = proto.PythonUDF()
+ expr.output_type = self._output_type
+ expr.eval_type = self._eval_type
+ expr.command = self._command
+ return expr
+
+ def __repr__(self) -> str:
+ return (
+ f"{self._output_type}, {self._eval_type}, "
+ f"{self._command}" # type: ignore[str-bytes-safe]
+ )
+
+
+class ScalarInlineUserDefinedFunction(Expression):
+    """Represents a scalar inline user-defined function of any programming languages."""
+
+ def __init__(
+ self,
+ function_name: str,
+ deterministic: bool,
+ arguments: Sequence[Expression],
+ function: PythonUDF,
+ ):
+ self._function_name = function_name
+ self._deterministic = deterministic
+ self._arguments = arguments
+ self._function = function
+
+ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+ expr = proto.Expression()
+        expr.scalar_inline_user_defined_function.function_name = self._function_name
+        expr.scalar_inline_user_defined_function.deterministic = self._deterministic
+ if len(self._arguments) > 0:
+ expr.scalar_inline_user_defined_function.arguments.extend(
+ [arg.to_plan(session) for arg in self._arguments]
+ )
+ expr.scalar_inline_user_defined_function.python_udf.CopyFrom(
+ self._function.to_plan(session)
+ )
+ return expr
+
+ def __repr__(self) -> str:
+ return (
+            f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])}), "
+ f"{self._deterministic}, {self._function}"
+ )
+
+
class WithField(Expression):
def __init__(
self,
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 75f6ba1ff64..ee7b45622b3 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -17,6 +17,7 @@
import inspect
import warnings
+import functools
from typing import (
Any,
Dict,
@@ -49,11 +50,16 @@ from pyspark.sql.connect.expressions import (
LambdaFunction,
UnresolvedNamedLambdaVariable,
)
+from pyspark.sql.connect.udf import _create_udf
from pyspark.sql import functions as pysparkfuncs
-from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType
+from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType
if TYPE_CHECKING:
- from pyspark.sql.connect._typing import ColumnOrName
+ from pyspark.sql.connect._typing import (
+ ColumnOrName,
+ DataTypeOrString,
+ UserDefinedFunctionLike,
+ )
from pyspark.sql.connect.dataframe import DataFrame
@@ -2401,8 +2407,24 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
unwrap_udt.__doc__ = pysparkfuncs.unwrap_udt.__doc__
-def udf(*args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("udf() is not implemented.")
+def udf(
+    f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
+    returnType: "DataTypeOrString" = StringType(),
+) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
+    from pyspark.rdd import PythonEvalType
+
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(
+            _create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
+        )
+    else:
+        return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
+
+
+udf.__doc__ = pysparkfuncs.udf.__doc__
def pandas_udf(*args: Any, **kwargs: Any) -> None:
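The `udf` above mirrors vanilla PySpark's dual calling convention: called with a `DataType` (or nothing) it returns a decorator via `functools.partial`; called with a callable it builds the UDF directly. A hedged usage sketch, assuming a Connect session:
```
from pyspark.sql.connect import functions as CF
from pyspark.sql.types import LongType

# Direct form: f is a callable, the return type is passed explicitly.
plus_one = CF.udf(lambda x: x + 1, LongType())

# Decorator form: f is a DataType, so udf() returns a decorator.
@CF.udf(LongType())
def plus_two(x):
    return x + 2
```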
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 87c16964102..0b2419fee35 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92$\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
)
@@ -61,6 +61,10 @@ _EXPRESSION_LAMBDAFUNCTION = _EXPRESSION.nested_types_by_name["LambdaFunction"]
_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE = _EXPRESSION.nested_types_by_name[
"UnresolvedNamedLambdaVariable"
]
+_SCALARINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
+ "ScalarInlineUserDefinedFunction"
+]
+_PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"]
_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
"FrameType"
]
@@ -257,52 +261,78 @@ _sym_db.RegisterMessage(Expression.Alias)
_sym_db.RegisterMessage(Expression.LambdaFunction)
_sym_db.RegisterMessage(Expression.UnresolvedNamedLambdaVariable)
+ScalarInlineUserDefinedFunction = _reflection.GeneratedProtocolMessageType(
+ "ScalarInlineUserDefinedFunction",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SCALARINLINEUSERDEFINEDFUNCTION,
+ "__module__": "spark.connect.expressions_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.ScalarInlineUserDefinedFunction)
+ },
+)
+_sym_db.RegisterMessage(ScalarInlineUserDefinedFunction)
+
+PythonUDF = _reflection.GeneratedProtocolMessageType(
+ "PythonUDF",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _PYTHONUDF,
+ "__module__": "spark.connect.expressions_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.PythonUDF)
+ },
+)
+_sym_db.RegisterMessage(PythonUDF)
+
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
    DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 4731
- _EXPRESSION_WINDOW._serialized_start = 1347
- _EXPRESSION_WINDOW._serialized_end = 2130
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1637
- _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2130
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 1904
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2049
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2051
- _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2130
- _EXPRESSION_SORTORDER._serialized_start = 2133
- _EXPRESSION_SORTORDER._serialized_end = 2558
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2363
- _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2471
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2473
- _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2558
- _EXPRESSION_CAST._serialized_start = 2561
- _EXPRESSION_CAST._serialized_end = 2706
- _EXPRESSION_LITERAL._serialized_start = 2709
- _EXPRESSION_LITERAL._serialized_end = 3585
- _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3352
- _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3469
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3471
- _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3569
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3587
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3657
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3660
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3864
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3866
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 3916
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 3918
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4000
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4002
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4046
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4049
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4181
- _EXPRESSION_UPDATEFIELDS._serialized_start = 4184
- _EXPRESSION_UPDATEFIELDS._serialized_end = 4371
- _EXPRESSION_ALIAS._serialized_start = 4373
- _EXPRESSION_ALIAS._serialized_end = 4493
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4496
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4654
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4656
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4718
+ _EXPRESSION._serialized_end = 4859
+ _EXPRESSION_WINDOW._serialized_start = 1475
+ _EXPRESSION_WINDOW._serialized_end = 2258
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
+ _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2258
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2032
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2177
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2179
+ _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2258
+ _EXPRESSION_SORTORDER._serialized_start = 2261
+ _EXPRESSION_SORTORDER._serialized_end = 2686
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2491
+ _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2599
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2601
+ _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2686
+ _EXPRESSION_CAST._serialized_start = 2689
+ _EXPRESSION_CAST._serialized_end = 2834
+ _EXPRESSION_LITERAL._serialized_start = 2837
+ _EXPRESSION_LITERAL._serialized_end = 3713
+ _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3480
+ _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3597
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3599
+ _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3697
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3715
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3785
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3788
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3992
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3994
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4044
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4046
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4128
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4130
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4174
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4177
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4309
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 4312
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 4499
+ _EXPRESSION_ALIAS._serialized_start = 4501
+ _EXPRESSION_ALIAS._serialized_end = 4621
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4624
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4782
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
+ _SCALARINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
+ _SCALARINLINEUSERDEFINEDFUNCTION._serialized_end = 5098
+ _PYTHONUDF._serialized_start = 5100
+ _PYTHONUDF._serialized_end = 5199
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 45889c1518f..0191a0cdaf4 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -932,6 +932,7 @@ class Expression(google.protobuf.message.Message):
UNRESOLVED_EXTRACT_VALUE_FIELD_NUMBER: builtins.int
UPDATE_FIELDS_FIELD_NUMBER: builtins.int
UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
+ SCALAR_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def literal(self) -> global___Expression.Literal: ...
@@ -964,6 +965,8 @@ class Expression(google.protobuf.message.Message):
self,
) -> global___Expression.UnresolvedNamedLambdaVariable: ...
@property
+    def scalar_inline_user_defined_function(self) -> global___ScalarInlineUserDefinedFunction: ...
+    @property
    def extension(self) -> google.protobuf.any_pb2.Any:
        """This field is used to mark extensions to the protocol. When plugins generate arbitrary
        relations they can add them here. During the planning the correct resolution is done.
@@ -986,6 +989,7 @@ class Expression(google.protobuf.message.Message):
update_fields: global___Expression.UpdateFields | None = ...,
        unresolved_named_lambda_variable: global___Expression.UnresolvedNamedLambdaVariable
        | None = ...,
+        scalar_inline_user_defined_function: global___ScalarInlineUserDefinedFunction | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1005,6 +1009,8 @@ class Expression(google.protobuf.message.Message):
b"lambda_function",
"literal",
b"literal",
+ "scalar_inline_user_defined_function",
+ b"scalar_inline_user_defined_function",
"sort_order",
b"sort_order",
"unresolved_attribute",
@@ -1042,6 +1048,8 @@ class Expression(google.protobuf.message.Message):
b"lambda_function",
"literal",
b"literal",
+ "scalar_inline_user_defined_function",
+ b"scalar_inline_user_defined_function",
"sort_order",
b"sort_order",
"unresolved_attribute",
@@ -1079,7 +1087,87 @@ class Expression(google.protobuf.message.Message):
"unresolved_extract_value",
"update_fields",
"unresolved_named_lambda_variable",
+ "scalar_inline_user_defined_function",
"extension",
] | None: ...
global___Expression = Expression
+
+class ScalarInlineUserDefinedFunction(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ FUNCTION_NAME_FIELD_NUMBER: builtins.int
+ DETERMINISTIC_FIELD_NUMBER: builtins.int
+ ARGUMENTS_FIELD_NUMBER: builtins.int
+ PYTHON_UDF_FIELD_NUMBER: builtins.int
+ function_name: builtins.str
+ """(Required) Name of the user-defined function."""
+ deterministic: builtins.bool
+ """(Required) Indicate if the user-defined function is deterministic."""
+ @property
+ def arguments(
+ self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
+ """(Optional) Function arguments. Empty arguments are allowed."""
+ @property
+ def python_udf(self) -> global___PythonUDF: ...
+ def __init__(
+ self,
+ *,
+ function_name: builtins.str = ...,
+ deterministic: builtins.bool = ...,
+ arguments: collections.abc.Iterable[global___Expression] | None = ...,
+ python_udf: global___PythonUDF | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+        field_name: typing_extensions.Literal["function", b"function", "python_udf", b"python_udf"],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "arguments",
+ b"arguments",
+ "deterministic",
+ b"deterministic",
+ "function",
+ b"function",
+ "function_name",
+ b"function_name",
+ "python_udf",
+ b"python_udf",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["function", b"function"]
+ ) -> typing_extensions.Literal["python_udf"] | None: ...
+
+global___ScalarInlineUserDefinedFunction = ScalarInlineUserDefinedFunction
+
+class PythonUDF(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ OUTPUT_TYPE_FIELD_NUMBER: builtins.int
+ EVAL_TYPE_FIELD_NUMBER: builtins.int
+ COMMAND_FIELD_NUMBER: builtins.int
+ output_type: builtins.str
+ """(Required) Output type of the Python UDF"""
+ eval_type: builtins.int
+ """(Required) EvalType of the Python UDF"""
+ command: builtins.bytes
+ """(Required) The encoded commands of the Python UDF"""
+ def __init__(
+ self,
+ *,
+ output_type: builtins.str = ...,
+ eval_type: builtins.int = ...,
+ command: builtins.bytes = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+            "command", b"command", "eval_type", b"eval_type", "output_type", b"output_type"
+ ],
+ ) -> None: ...
+
+global___PythonUDF = PythonUDF
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
new file mode 100644
index 00000000000..4a465084838
--- /dev/null
+++ b/python/pyspark/sql/connect/udf.py
@@ -0,0 +1,165 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined function related classes and functions
+"""
+import functools
+from typing import Callable, Any, TYPE_CHECKING, Optional
+
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.sql.connect.expressions import (
+ ColumnReference,
+ PythonUDF,
+ ScalarInlineUserDefinedFunction,
+)
+from pyspark.sql.connect.column import Column
+from pyspark.sql.types import DataType, StringType
+
+
+if TYPE_CHECKING:
+ from pyspark.sql.connect._typing import (
+ ColumnOrName,
+ DataTypeOrString,
+ UserDefinedFunctionLike,
+ )
+ from pyspark.sql.types import StringType
+
+
+def _create_udf(
+ f: Callable[..., Any],
+ returnType: "DataTypeOrString",
+ evalType: int,
+ name: Optional[str] = None,
+ deterministic: bool = True,
+) -> "UserDefinedFunctionLike":
+    # Set the name of the UserDefinedFunction object to be the name of function f
+    udf_obj = UserDefinedFunction(
+        f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
+ )
+ return udf_obj._wrapped()
+
+
+class UserDefinedFunction:
+ """
+ User defined function in Python
+
+ Notes
+ -----
+ The constructor of this class is not supposed to be directly called.
+    Use :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
+    to create this instance.
+ """
+
+ def __init__(
+ self,
+ func: Callable[..., Any],
+ returnType: "DataTypeOrString" = StringType(),
+ name: Optional[str] = None,
+ evalType: int = 100,
+ deterministic: bool = True,
+ ):
+ if not callable(func):
+ raise TypeError(
+                "Invalid function: not a function or callable (__call__ is not defined): "
+ "{0}".format(type(func))
+ )
+
+ if not isinstance(returnType, (DataType, str)):
+ raise TypeError(
+ "Invalid return type: returnType should be DataType or str "
+ "but is {}".format(returnType)
+ )
+
+ if not isinstance(evalType, int):
+ raise TypeError(
+                "Invalid evaluation type: evalType should be an int but is {}".format(evalType)
+ )
+
+ self.func = func
+ self._returnType = returnType
+ self._name = name or (
+            func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
+ )
+ self.evalType = evalType
+ self.deterministic = deterministic
+
+ def __call__(self, *cols: "ColumnOrName") -> Column:
+ arg_cols = [
+            col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols
+ ]
+ arg_exprs = [col._expr for col in arg_cols]
+ data_type_str = (
+            self._returnType.json() if isinstance(self._returnType, DataType) else self._returnType
+ )
+ py_udf = PythonUDF(
+ output_type=data_type_str,
+ eval_type=self.evalType,
+            command=CloudPickleSerializer().dumps((self.func, self._returnType)),
+ )
+ return Column(
+ ScalarInlineUserDefinedFunction(
+ function_name=self._name,
+ deterministic=self.deterministic,
+ arguments=arg_exprs,
+ function=py_udf,
+ )
+ )
+
+    # This function is for improving the online help system in the interactive interpreter.
+    # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
+ # argument annotation. (See: SPARK-19161)
+ def _wrapped(self) -> "UserDefinedFunctionLike":
+ """
+ Wrap this udf with a function and attach docstring from func
+ """
+
+        # It is possible for a callable instance without __name__ attribute or/and
+        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
+        # we should avoid wrapping the attributes from the wrapped function to the wrapper
+        # function. So, we take out these attribute names from the default names to set and
+        # then manually assign it after being wrapped.
+        assignments = tuple(
+            a for a in functools.WRAPPER_ASSIGNMENTS if a != "__name__" and a != "__module__"
+        )
+
+ @functools.wraps(self.func, assigned=assignments)
+ def wrapper(*args: "ColumnOrName") -> Column:
+ return self(*args)
+
+ wrapper.__name__ = self._name
+ wrapper.__module__ = (
+ self.func.__module__
+ if hasattr(self.func, "__module__")
+ else self.func.__class__.__module__
+ )
+
+ wrapper.func = self.func # type: ignore[attr-defined]
+ wrapper.returnType = self._returnType # type: ignore[attr-defined]
+ wrapper.evalType = self.evalType # type: ignore[attr-defined]
+        wrapper.deterministic = self.deterministic  # type: ignore[attr-defined]
+        wrapper.asNondeterministic = functools.wraps(  # type: ignore[attr-defined]
+ self.asNondeterministic
+ )(lambda: self.asNondeterministic()._wrapped())
+ wrapper._unwrapped = self # type: ignore[attr-defined]
+ return wrapper # type: ignore[return-value]
+
+ def asNondeterministic(self) -> "UserDefinedFunction":
+ """
+ Updates UserDefinedFunction to nondeterministic.
+ """
+ self.deterministic = False
+ return self
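A short sketch of the wrapped UDF's surface as exposed by `_wrapped()` above (assumes a Connect DataFrame `cdf` with an `id` column; illustrative, not part of the patch):
```
from pyspark.sql.connect.functions import udf
from pyspark.sql.types import LongType

f = udf(lambda x: x + 1, LongType())  # udf() returns the _wrapped() callable
f.returnType                          # LongType(), copied onto the wrapper
f.deterministic                       # True by default
g = f.asNondeterministic()            # flips the flag and re-wraps
cdf.select(f("id"))                   # __call__ builds a ScalarInlineUserDefinedFunction
```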
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2c025db0f36..3426f2bdaf6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -10043,6 +10043,7 @@ def udf(
...
+@try_remote_functions
def udf(
f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
returnType: "DataTypeOrString" = StringType(),
@@ -10053,6 +10054,9 @@ def udf(
.. versionadded:: 1.3.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
f : function
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index e61aeced30c..b74b1a9ee69 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -19,7 +19,7 @@ import tempfile
from pyspark.errors import PySparkTypeError
from pyspark.sql import SparkSession
-from pyspark.sql.types import StructType, StructField, ArrayType, IntegerType
+from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import ReusedPySparkTestCase
@@ -2282,15 +2282,52 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase):
).toPandas(),
)
+ def test_udf(self):
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
+
+ query = """
+ SELECT a, b, c FROM VALUES
+ (1, 1.0, 'x'), (2, 2.0, 'y'), (3, 3.0, 'z')
+ AS tab(a, b, c)
+ """
+ # +---+---+---+
+ # | a| b| c|
+ # +---+---+---+
+ # | 1|1.0| x|
+ # | 2|2.0| y|
+ # | 3|3.0| z|
+ # +---+---+---+
+
+ cdf = self.connect.sql(query)
+ sdf = self.spark.sql(query)
+
+ # as a normal function
+ self.assert_eq(
+ cdf.withColumn("A", CF.udf(lambda x: x + 1)(cdf.a)).toPandas(),
+ sdf.withColumn("A", SF.udf(lambda x: x + 1)(sdf.a)).toPandas(),
+ )
+
+ # as a decorator
+ @CF.udf(StringType())
+ def cfun(x):
+ return x + "a"
+
+ @SF.udf(StringType())
+ def sfun(x):
+ return x + "a"
+
+ self.assert_eq(
+ cdf.withColumn("A", cfun(cdf.c)).toPandas(),
+ sdf.withColumn("A", sfun(sdf.c)).toPandas(),
+ )
+
def test_unsupported_functions(self):
# SPARK-41928: Disable unsupported functions.
from pyspark.sql.connect import functions as CF
- for f in (
- "udf",
- "pandas_udf",
- ):
+ for f in ("pandas_udf",):
with self.assertRaises(NotImplementedError):
getattr(CF, f)()