This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ce6b5f3751e [SPARK-44918][SQL][PYTHON] Support named arguments in scalar Python/Pandas UDFs
ce6b5f3751e is described below
commit ce6b5f3751e9ea5d1cb4b63c8e14235914817766
Author: Takuya UESHIN <[email protected]>
AuthorDate: Thu Aug 24 10:55:54 2023 -0700
[SPARK-44918][SQL][PYTHON] Support named arguments in scalar Python/Pandas UDFs
### What changes were proposed in this pull request?
Supports named arguments in scalar Python/Pandas UDFs.
For example:
```py
>>> udf("int")
... def test_udf(a, b):
... return a + 10 * b
...
>>> spark.udf.register("test_udf", test_udf)
>>> spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))).show()
+---------------------------------+
|test_udf(b => (id * 10), a => id)|
+---------------------------------+
| 0|
| 101|
+---------------------------------+
>>> spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
+---------------------------------+
|test_udf(b => (id * 10), a => id)|
+---------------------------------+
| 0|
| 101|
+---------------------------------+
```
or:
```py
>>> @pandas_udf("int")
... def test_udf(a, b):
...     return a + 10 * b
...
>>> spark.udf.register("test_udf", test_udf)
>>> spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))).show()
+---------------------------------+
|test_udf(b => (id * 10), a => id)|
+---------------------------------+
| 0|
| 101|
+---------------------------------+
>>> spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
+---------------------------------+
|test_udf(b => (id * 10), a => id)|
+---------------------------------+
| 0|
| 101|
+---------------------------------+
```
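UDFs that accept `**kwargs` or declare default values work the same way; a minimal sketch based on the tests added in this patch (`test_kwargs` and `test_named_arguments_and_defaults`):
```py
>>> @pandas_udf("int")
... def test_udf(a, **kwargs):
...     return a + 10 * kwargs["b"]
...
>>> spark.udf.register("test_udf", test_udf)
>>> spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)).show()
>>> spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
```
Both queries return the rows 0 and 101, the same as the positional form.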
### Why are the changes needed?
Now that named argument support has been added
(https://github.com/apache/spark/pull/41796,
https://github.com/apache/spark/pull/42020),
scalar Python/Pandas UDFs can support it as well.
### Does this PR introduce _any_ user-facing change?
Yes, named arguments will be available for scalar Python/Pandas UDFs.
### How was this patch tested?
Added related tests.
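For reference, the new negative tests check that misuse is rejected; a sketch of the cases exercised by `test_named_arguments_negative` (using the two-argument `test_udf` registered above):
```py
# Duplicate named argument -> AnalysisException
# (DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE)
spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()

# Positional argument after a named one -> AnalysisException (UNEXPECTED_POSITIONAL_ARGUMENT)
spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()

# Unknown keyword -> PythonException: test_udf() got an unexpected keyword argument 'c'
spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
```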
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42617 from ueshin/issues/SPARK-44918/kwargs.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
python/pyspark/sql/connect/udf.py | 24 +++---
python/pyspark/sql/functions.py | 17 ++++
python/pyspark/sql/pandas/functions.py | 17 ++++
.../sql/tests/pandas/test_pandas_udf_scalar.py | 94 ++++++++++++++++++++-
python/pyspark/sql/tests/test_udf.py | 98 +++++++++++++++++++++-
python/pyspark/sql/tests/test_udtf.py | 4 +-
python/pyspark/sql/udf.py | 24 ++++--
python/pyspark/worker.py | 89 ++++++++++++++------
.../ApplyInPandasWithStatePythonRunner.scala | 4 +
.../sql/execution/python/ArrowEvalPythonExec.scala | 9 +-
.../execution/python/ArrowEvalPythonUDTFExec.scala | 2 +-
.../sql/execution/python/ArrowPythonRunner.scala | 51 ++++++++++-
.../execution/python/ArrowPythonUDTFRunner.scala | 11 +--
.../sql/execution/python/BatchEvalPythonExec.scala | 7 +-
.../execution/python/BatchEvalPythonUDTFExec.scala | 20 ++---
.../python/EvalPythonEvaluatorFactory.scala | 25 ++++--
.../sql/execution/python/EvalPythonExec.scala | 10 +++
.../sql/execution/python/EvalPythonUDTFExec.scala | 17 +---
.../sql/execution/python/PythonArrowInput.scala | 11 +--
.../sql/execution/python/PythonUDFRunner.scala | 85 +++++++++++++------
.../python/UserDefinedPythonFunction.scala | 15 +++-
21 files changed, 497 insertions(+), 137 deletions(-)
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 2636777e5f6..90cea26e56f 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -25,13 +25,15 @@ import sys
import functools
import warnings
from inspect import getfullargspec
-from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
+from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union
from pyspark.rdd import PythonEvalType
from pyspark.sql.connect.expressions import (
ColumnReference,
- PythonUDF,
CommonInlineUserDefinedFunction,
+ Expression,
+ NamedArgumentExpression,
+ PythonUDF,
)
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.types import UnparsedDataType
@@ -155,12 +157,14 @@ class UserDefinedFunction:
self.deterministic = deterministic
def _build_common_inline_user_defined_function(
- self, *cols: "ColumnOrName"
+ self, *args: "ColumnOrName", **kwargs: "ColumnOrName"
) -> CommonInlineUserDefinedFunction:
- arg_cols = [
- col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols
+ def to_expr(col: "ColumnOrName") -> Expression:
+ return col._expr if isinstance(col, Column) else ColumnReference(col)
+
+ arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [
+ NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items()
]
- arg_exprs = [col._expr for col in arg_cols]
py_udf = PythonUDF(
output_type=self.returnType,
@@ -175,8 +179,8 @@ class UserDefinedFunction:
arguments=arg_exprs,
)
- def __call__(self, *cols: "ColumnOrName") -> Column:
- return Column(self._build_common_inline_user_defined_function(*cols))
+ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
+ return Column(self._build_common_inline_user_defined_function(*args, **kwargs))
# This function is for improving the online help system in the interactive interpreter.
# For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and
@@ -196,8 +200,8 @@ class UserDefinedFunction:
)
@functools.wraps(self.func, assigned=assignments)
- def wrapper(*args: "ColumnOrName") -> Column:
- return self(*args)
+ def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
+ return self(*args, **kwargs)
wrapper.__name__ = self._name
wrapper.__module__ = (
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b5e64c7a039..e580d2aba12 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -15422,6 +15422,9 @@ def udf(
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.0.0
+ Supports keyword-arguments.
+
Parameters
----------
f : function
@@ -15455,6 +15458,20 @@ def udf(
| 8| JOHN DOE| 22|
+----------+--------------+------------+
+ UDF can use keyword arguments:
+
+ >>> @udf(returnType=IntegerType())
+ ... def calc(a, b):
+ ... return a + 10 * b
+ ...
+ >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show()
+ +-----------------------------+
+ |calc(b => (id * 10), a => id)|
+ +-----------------------------+
+ | 0|
+ | 101|
+ +-----------------------------+
+
Notes
-----
The user-defined functions are considered deterministic by default. Due to
diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index b7d381f04c7..ad9fdac9706 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -56,6 +56,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
.. versionchanged:: 3.4.0
Supports Spark Connect.
+ .. versionchanged:: 4.0.0
+ Supports keyword-arguments in SCALAR type.
+
Parameters
----------
f : function, optional
@@ -153,6 +156,20 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| [John, Doe]|
+------------------+
+ This type of Pandas UDF can use keyword arguments:
+
+ >>> @pandas_udf(returnType=IntegerType())
+ ... def calc(a: pd.Series, b: pd.Series) -> pd.Series:
+ ... return a + 10 * b
+ ...
+ >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show()
+ +-----------------------------+
+ |calc(b => (id * 10), a => id)|
+ +-----------------------------+
+ | 0|
+ | 101|
+ +-----------------------------+
+
.. note:: The length of the input is not that of the whole input column, but is the
length of an internal batch used for each call to the function.
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 7a80547b3fc..8cb397ab95d 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -50,7 +50,7 @@ from pyspark.sql.types import (
BinaryType,
YearMonthIntervalType,
)
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
test_compiled,
@@ -1467,6 +1467,98 @@ class ScalarPandasUDFTestsMixin:
finally:
shutil.rmtree(path)
+ def test_named_arguments(self):
+ @pandas_udf("int")
+ def test_udf(a, b):
+ return a + 10 * b
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
+
+ def test_named_arguments_negative(self):
+ @pandas_udf("int")
+ def test_udf(a, b):
+ return a + b
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+ ):
+ self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+ def test_kwargs(self):
+ @pandas_udf("int")
+ def test_udf(a, **kwargs):
+ return a + 10 * kwargs["b"]
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
+
+ def test_named_arguments_and_defaults(self):
+ @pandas_udf("int")
+ def test_udf(a, b=0):
+ return a + 10 * b
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ # without "b"
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"))),
+ self.spark.range(2).select(test_udf(a=col("id"))),
+ self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=False, query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(1)])
+
+ # with "b"
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=True, query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
+
class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 239ff27813b..f72bf288230 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@ import datetime
from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf, assert_true, lit, rand
+from pyspark.sql.functions import col, udf, assert_true, lit, rand
from pyspark.sql.udf import UserDefinedFunction
from pyspark.sql.types import (
StringType,
@@ -38,9 +38,9 @@ from pyspark.sql.types import (
TimestampNTZType,
DayTimeIntervalType,
)
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
class BaseUDFTestsMixin(object):
@@ -898,6 +898,98 @@ class BaseUDFTestsMixin(object):
self.assertEquals(row[1], {"a": "b"})
self.assertEquals(row[2], Row(col1=1, col2=2))
+ def test_named_arguments(self):
+ @udf("int")
+ def test_udf(a, b):
+ return a + 10 * b
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
+
+ def test_named_arguments_negative(self):
+ @udf("int")
+ def test_udf(a, b):
+ return a + b
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ with self.assertRaisesRegex(
+ AnalysisException,
+ "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+
+ with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'"
+ ):
+ self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+ def test_kwargs(self):
+ @udf("int")
+ def test_udf(**kwargs):
+ return kwargs["a"] + 10 * kwargs["b"]
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
+
+ def test_named_arguments_and_defaults(self):
+ @udf("int")
+ def test_udf(a, b=0):
+ return a + 10 * b
+
+ self.spark.udf.register("test_udf", test_udf)
+
+ # without "b"
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"))),
+ self.spark.range(2).select(test_udf(a=col("id"))),
+ self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=False, query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(1)])
+
+ # with "b"
+ for i, df in enumerate(
+ [
+ self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)),
+ self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))),
+ self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"),
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"),
+ ]
+ ):
+ with self.subTest(with_b=True, query_no=i):
+ assertDataFrameEqual(df, [Row(0), Row(101)])
+
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 9a80f8fd73c..63743de5e03 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1947,7 +1947,7 @@ class BaseUDTFTestsMixin:
TestUDTF(a=lit(10)),
]
):
- with self.subTest(query_no=i):
+ with self.subTest(with_b=False, query_no=i):
assertDataFrameEqual(df, [Row(a=10, b=100)])
# with "b"
@@ -1961,7 +1961,7 @@ class BaseUDTFTestsMixin:
TestUDTF(b=lit("z"), a=lit(10)),
]
):
- with self.subTest(query_no=i):
+ with self.subTest(with_b=True, query_no=i):
assertDataFrameEqual(df, [Row(a=10, b="z")])
def test_udtf_with_table_argument_and_partition_by(self):
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 7d7784dd522..029293ab70f 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -30,7 +30,7 @@ from py4j.java_gateway import JavaObject
from pyspark import SparkContext
from pyspark.profiler import Profiler
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
+from pyspark.sql.column import Column, _to_java_expr, _to_seq
from pyspark.sql.types import (
DataType,
StringType,
@@ -336,8 +336,17 @@ class UserDefinedFunction:
)
return judf
- def __call__(self, *cols: "ColumnOrName") -> Column:
+ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
sc = get_active_spark_context()
+
+ assert sc._jvm is not None
+ jexprs = [_to_java_expr(arg) for arg in args] + [
+ sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression(
+ key, _to_java_expr(value)
+ )
+ for key, value in kwargs.items()
+ ]
+
profiler: Optional[Profiler] = None
memory_profiler: Optional[Profiler] = None
if sc.profiler_collector:
@@ -376,7 +385,7 @@ class UserDefinedFunction:
func.__signature__ = inspect.signature(f) # type: ignore[attr-defined]
judf = self._create_judf(func)
- jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
+ jUDFExpr = judf.builder(_to_seq(sc, jexprs))
jPythonUDF = judf.fromUDFExpr(jUDFExpr)
id = jUDFExpr.resultId().id()
sc.profiler_collector.add_profiler(id, profiler)
@@ -394,13 +403,14 @@ class UserDefinedFunction:
func.__signature__ = inspect.signature(f) # type: ignore[attr-defined]
judf = self._create_judf(func)
- jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
+ jUDFExpr = judf.builder(_to_seq(sc, jexprs))
jPythonUDF = judf.fromUDFExpr(jUDFExpr)
id = jUDFExpr.resultId().id()
sc.profiler_collector.add_profiler(id, memory_profiler)
else:
judf = self._judf
- jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
+ jUDFExpr = judf.builder(_to_seq(sc, jexprs))
+ jPythonUDF = judf.fromUDFExpr(jUDFExpr)
return Column(jPythonUDF)
# This function is for improving the online help system in the interactive interpreter.
@@ -421,8 +431,8 @@ class UserDefinedFunction:
)
@functools.wraps(self.func, assigned=assignments)
- def wrapper(*args: "ColumnOrName") -> Column:
- return self(*args)
+ def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
+ return self(*args, **kwargs)
wrapper.__name__ = self._name
wrapper.__module__ = (
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0dd281ea91f..19c8c9c897b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -84,9 +84,9 @@ def chain(f, g):
def wrap_udf(f, return_type):
if return_type.needConversion():
toInternal = return_type.toInternal
- return lambda *a: toInternal(f(*a))
+ return lambda *a, **kw: toInternal(f(*a, **kw))
else:
- return lambda *a: f(*a)
+ return lambda *a, **kw: f(*a, **kw)
def wrap_scalar_pandas_udf(f, return_type):
@@ -115,8 +115,10 @@ def wrap_scalar_pandas_udf(f, return_type):
)
return result
- return lambda *a: (
- verify_result_length(verify_result_type(f(*a)), len(a[0])),
+ return lambda *a, **kw: (
+ verify_result_length(
+ verify_result_type(f(*a, **kw)), len((list(a) + list(kw.values()))[0])
+ ),
arrow_return_type,
)
@@ -137,8 +139,17 @@ def wrap_arrow_batch_udf(f, return_type):
elif type(return_type) == BinaryType:
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731
- def evaluate(*args: pd.Series) -> pd.Series:
- return pd.Series(result_func(f(*a)) for a in zip(*args))
+ def evaluate(*args: pd.Series, **kwargs: pd.Series) -> pd.Series:
+ keys = list(kwargs.keys())
+ len_args = len(args)
+ return pd.Series(
+ [
+ result_func(
+ f(*row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)})
+ )
+ for row in zip(*args, *[kwargs[key] for key in keys])
+ ]
+ )
def verify_result_type(result):
if not hasattr(result, "__len__"):
@@ -163,8 +174,10 @@ def wrap_arrow_batch_udf(f, return_type):
)
return result
- return lambda *a: (
- verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
+ return lambda *a, **kw: (
+ verify_result_length(
+ verify_result_type(evaluate(*a, **kw)), len((list(a) + list(kw.values()))[0])
+ ),
arrow_return_type,
)
@@ -517,7 +530,27 @@ def wrap_bounded_window_agg_pandas_udf(f, return_type):
def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
num_arg = read_int(infile)
- arg_offsets = [read_int(infile) for i in range(num_arg)]
+
+ if eval_type in (
+ PythonEvalType.SQL_BATCHED_UDF,
+ PythonEvalType.SQL_ARROW_BATCHED_UDF,
+ PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+ # The below doesn't support named argument, but shares the same protocol.
+ PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+ ):
+ args_offsets = []
+ kwargs_offsets = {}
+ for _ in range(num_arg):
+ offset = read_int(infile)
+ if read_bool(infile):
+ name = utf8_deserializer.loads(infile)
+ kwargs_offsets[name] = offset
+ else:
+ args_offsets.append(offset)
+ else:
+ args_offsets = [read_int(infile) for i in range(num_arg)]
+ kwargs_offsets = {}
+
chained_func = None
for i in range(read_int(infile)):
f, return_type = read_command(pickleSer, infile)
@@ -535,31 +568,32 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
- return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
+ udf = wrap_scalar_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
- return arg_offsets, wrap_arrow_batch_udf(func, return_type)
+ udf = wrap_arrow_batch_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
- return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+ udf = wrap_pandas_batch_iter_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
- return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+ udf = wrap_pandas_batch_iter_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
- return arg_offsets, wrap_arrow_batch_iter_udf(func, return_type)
+ udf = wrap_arrow_batch_iter_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func) # signature was lost when wrapping it
- return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+ udf = wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
- return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
+ udf = wrap_grouped_map_pandas_udf_with_state(func, return_type)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func) # signature was lost when wrapping it
- return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+ udf = wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
- return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
+ udf = wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
- return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
+ udf = wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
- return arg_offsets, wrap_udf(func, return_type)
+ udf = wrap_udf(func, return_type)
else:
raise ValueError("Unknown eval type: {}".format(eval_type))
+ return args_offsets, kwargs_offsets, udf
# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when
@@ -984,7 +1018,9 @@ def read_udfs(pickleSer, infile, eval_type):
if is_map_arrow_iter:
assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
- arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+ arg_offsets, _, udf = read_single_udf(
+ pickleSer, infile, eval_type, runner_conf, udf_index=0
+ )
def func(_, iterator):
num_input_rows = 0
@@ -1074,7 +1110,7 @@ def read_udfs(pickleSer, infile, eval_type):
# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
- arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+ arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
parsed_offsets = extract_key_value_indexes(arg_offsets)
# Create function like this:
@@ -1091,7 +1127,7 @@ def read_udfs(pickleSer, infile, eval_type):
# See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
- arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+ arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
parsed_offsets = extract_key_value_indexes(arg_offsets)
def mapper(a):
@@ -1125,7 +1161,7 @@ def read_udfs(pickleSer, infile, eval_type):
# We assume there is only one UDF here because cogrouped map doesn't
# support combining multiple UDFs.
assert num_udfs == 1
- arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+ arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
parsed_offsets = extract_key_value_indexes(arg_offsets)
@@ -1142,7 +1178,10 @@ def read_udfs(pickleSer, infile, eval_type):
udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))
def mapper(a):
- result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
+ result = tuple(
+ f(*[a[o] for o in args_offsets], **{k: a[o] for k, o in kwargs_offsets.items()})
+ for args_offsets, kwargs_offsets, f in udfs
+ )
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
if len(result) == 1:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index a60d0beeeed..9fde1814079 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -105,6 +105,10 @@ class ApplyInPandasWithStatePythonRunner(
private val stateRowDeserializer = stateEncoder.createDeserializer()
+ override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+ PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+ }
+
/**
* This method sends out the additional metadata before sending out actual data.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 7db43a34a88..bd91da3bc0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.StructType
/**
@@ -94,11 +95,11 @@ class ArrowEvalPythonEvaluatorFactory(
pythonRunnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
- extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
+ extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
override def evaluate(
funcs: Seq[ChainedPythonFunctions],
- argOffsets: Array[Array[Int]],
+ argMetas: Array[Array[ArgumentMetadata]],
iter: Iterator[InternalRow],
schema: StructType,
context: TaskContext): Iterator[InternalRow] = {
@@ -108,10 +109,10 @@ class ArrowEvalPythonEvaluatorFactory(
// DO NOT use iter.grouped(). See BatchIterator.
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
- val columnarBatchIter = new ArrowPythonRunner(
+ val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner(
funcs,
evalType,
- argOffsets,
+ argMetas,
schema,
sessionLocalTimeZone,
largeVarTypes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 8ebd8a3a106..9d5bac0c600 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 0f26d8f21f8..251e682c9e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -17,17 +17,17 @@
package org.apache.spark.sql.execution.python
+import java.io.DataOutputStream
+
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
-/**
- * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
- */
-class ArrowPythonRunner(
+abstract class BaseArrowPythonRunner(
funcs: Seq[ChainedPythonFunctions],
evalType: Int,
argOffsets: Array[Array[Int]],
@@ -61,6 +61,49 @@ class ArrowPythonRunner(
s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
}
+/**
+ * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
+ */
+class ArrowPythonRunner(
+ funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argOffsets: Array[Array[Int]],
+ _schema: StructType,
+ _timeZoneId: String,
+ largeVarTypes: Boolean,
+ workerConf: Map[String, String],
+ pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String])
+ extends BaseArrowPythonRunner(
+ funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+ pythonMetrics, jobArtifactUUID) {
+
+ override protected def writeUDF(dataOut: DataOutputStream): Unit =
+ PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+}
+
+/**
+ * Similar to `PythonUDFWithNamedArgumentsRunner`, but exchange data with Python worker
+ * via Arrow stream.
+ */
+class ArrowPythonWithNamedArgumentRunner(
+ funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argMetas: Array[Array[ArgumentMetadata]],
+ _schema: StructType,
+ _timeZoneId: String,
+ largeVarTypes: Boolean,
+ workerConf: Map[String, String],
+ pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String])
+ extends BaseArrowPythonRunner(
+ funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
+ pythonMetrics, jobArtifactUUID) {
+
+ override protected def writeUDF(dataOut: DataOutputStream): Unit =
+ PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
+}
+
object ArrowPythonRunner {
/** Return Map with conf settings to be used in ArrowPythonRunner */
def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index c0fa8b58bee..690947b4129 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -23,7 +23,7 @@ import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.PythonUDTF
import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -42,15 +42,12 @@ class ArrowPythonUDTFRunner(
val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
- Seq(ChainedPythonFunctions(Seq(udtf.func))),
- evalType, Array(argMetas.map(_.offset)), jobArtifactUUID)
+ Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(argMetas.map(_.offset)),
+ jobArtifactUUID)
with BasicPythonArrowInput
with BasicPythonArrowOutput {
- override protected def writeUDF(
- dataOut: DataOutputStream,
- funcs: Seq[ChainedPythonFunctions],
- argOffsets: Array[Array[Int]]): Unit = {
+ override protected def writeUDF(dataOut: DataOutputStream): Unit = {
PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 1de8f55d84b..a0e7789b281 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.{StructField, StructType}
/**
@@ -60,7 +61,7 @@ class BatchEvalPythonEvaluatorFactory(
override def evaluate(
funcs: Seq[ChainedPythonFunctions],
- argOffsets: Array[Array[Int]],
+ argMetas: Array[Array[ArgumentMetadata]],
iter: Iterator[InternalRow],
schema: StructType,
context: TaskContext): Iterator[InternalRow] = {
@@ -71,8 +72,8 @@ class BatchEvalPythonEvaluatorFactory(
// Output iterator for results from Python.
val outputIterator =
- new PythonUDFRunner(
- funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics, jobArtifactUUID)
+ new PythonUDFWithNamedArgumentsRunner(
+ funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, jobArtifactUUID)
.compute(inputIterator, context.partitionId(), context)
val unpickle = new Unpickler
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index 46cc0f2ab50..342e0723194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -23,14 +23,14 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle.Unpickler
-import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorker, PythonWorkerUtils}
+import org.apache.spark.{JobArtifactSet, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorkerUtils}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.StructType
/**
@@ -101,18 +101,8 @@ class PythonUDTFRunner(
Seq(ChainedPythonFunctions(Seq(udtf.func))),
PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)),
pythonMetrics, jobArtifactUUID) {
- protected override def newWriter(
- env: SparkEnv,
- worker: PythonWorker,
- inputIterator: Iterator[Array[Byte]],
- partitionIndex: Int,
- context: TaskContext): Writer = {
- new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) {
-
- protected override def writeCommand(dataOut: DataOutputStream): Unit = {
- PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
- }
- }
+ override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+ PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
index 373e17c0aa3..d5142f58eab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv
import org.apache.spark.api.python.ChainedPythonFunctions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -32,11 +33,11 @@ abstract class EvalPythonEvaluatorFactory(
childOutput: Seq[Attribute],
udfs: Seq[PythonUDF],
output: Seq[Attribute])
- extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
+ extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
protected def evaluate(
funcs: Seq[ChainedPythonFunctions],
- argOffsets: Array[Array[Int]],
+ argMetas: Array[Array[ArgumentMetadata]],
iter: Iterator[InternalRow],
schema: StructType,
context: TaskContext): Iterator[InternalRow]
@@ -78,14 +79,20 @@ abstract class EvalPythonEvaluatorFactory(
// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
- val argOffsets = inputs.map { input =>
+ val argMetas = inputs.map { input =>
input.map { e =>
- if (allInputs.exists(_.semanticEquals(e))) {
- allInputs.indexWhere(_.semanticEquals(e))
+ val (key, value) = e match {
+ case NamedArgumentExpression(key, value) =>
+ (Some(key), value)
+ case _ =>
+ (None, e)
+ }
+ if (allInputs.exists(_.semanticEquals(value))) {
+ ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
} else {
- allInputs += e
- dataTypes += e.dataType
- allInputs.length - 1
+ allInputs += value
+ dataTypes += value.dataType
+ ArgumentMetadata(allInputs.length - 1, key)
}
}.toArray
}.toArray
@@ -102,7 +109,7 @@ abstract class EvalPythonEvaluatorFactory(
}
val outputRowIterator =
- evaluate(pyFuncs, argOffsets, projectedRowIter, schema, context)
+ evaluate(pyFuncs, argMetas, projectedRowIter, schema, context)
val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index 1c8b0f2228f..af6769cfbb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -22,6 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.UnaryExecNode
+object EvalPythonExec {
+ /**
+ * Metadata for arguments of Python UDTF.
+ *
+ * @param offset the offset of the argument
+ * @param name the name of the argument if it's a `NamedArgumentExpression`
+ */
+ case class ArgumentMetadata(offset: Int, name: Option[String])
+}
+
/**
* A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
index 410209e0ada..41a99693443 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
@@ -21,25 +21,15 @@ import java.io.File
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.UnaryExecNode
-import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
-object EvalPythonUDTFExec {
- /**
- * Metadata for arguments of Python UDTF.
- *
- * @param offset the offset of the argument
- * @param name the name of the argument if it's a `NamedArgumentExpression`
- */
- case class ArgumentMetadata(offset: Int, name: Option[String])
-}
-
/**
* A physical plan that evaluates a [[PythonUDTF]], one partition of tuples at a time.
* This is similar to [[EvalPythonExec]].
@@ -66,7 +56,6 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
inputRDD.mapPartitions { iter =>
val context = TaskContext.get()
- val contextAwareIterator = new ContextAwareIterator(context, iter)
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
@@ -104,7 +93,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
// Also keep track of the number rows added to the queue.
// This is needed to process extra output rows from the `terminate()` call of the UDTF.
var count = 0L
- val projectedRowIter = contextAwareIterator.map { inputRow =>
+ val projectedRowIter = iter.map { inputRow =>
queue.add(inputRow.asInstanceOf[UnsafeRow])
count += 1
projection(inputRow)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 00ee3a17563..1e075cab922 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -22,7 +22,7 @@ import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow
import org.apache.spark.sql.execution.arrow.ArrowWriter
@@ -54,11 +54,7 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
dataOut: DataOutputStream,
inputIterator: Iterator[IN]): Boolean
- protected def writeUDF(
- dataOut: DataOutputStream,
- funcs: Seq[ChainedPythonFunctions],
- argOffsets: Array[Array[Int]]): Unit =
- PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+ protected def writeUDF(dataOut: DataOutputStream): Unit
protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
// Write config for the worker as a number of key -> value pairs of strings
@@ -94,9 +90,10 @@ protected def close(): Unit = {
partitionIndex: Int,
context: TaskContext): Writer = {
new Writer(env, worker, inputIterator, partitionIndex, context) {
+
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
handleMetadataBeforeExec(dataOut)
- writeUDF(dataOut, funcs, argOffsets)
+ writeUDF(dataOut)
}
override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index bc27ee6919d..b99517f544d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
import org.apache.spark.sql.internal.SQLConf
/**
@@ -43,24 +44,31 @@ abstract class BasePythonUDFRunner(
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
- abstract class PythonUDFWriter(
+ protected def writeUDF(dataOut: DataOutputStream): Unit
+
+ protected override def newWriter(
env: SparkEnv,
worker: PythonWorker,
inputIterator: Iterator[Array[Byte]],
partitionIndex: Int,
- context: TaskContext)
- extends Writer(env, worker, inputIterator, partitionIndex, context) {
-
- override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
- val startData = dataOut.size()
- val wroteData = PythonRDD.writeNextElementToStream(inputIterator, dataOut)
- if (!wroteData) {
- // Reached the end of input.
- dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ context: TaskContext): Writer = {
+ new Writer(env, worker, inputIterator, partitionIndex, context) {
+
+ protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+ writeUDF(dataOut)
+ }
+
+ override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
+ val startData = dataOut.size()
+ val wroteData = PythonRDD.writeNextElementToStream(inputIterator, dataOut)
+ if (!wroteData) {
+ // Reached the end of input.
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ }
+ val deltaData = dataOut.size() - startData
+ pythonMetrics("pythonDataSent") += deltaData
+ wroteData
}
- val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
- wroteData
}
}
@@ -111,19 +119,22 @@ class PythonUDFRunner(
jobArtifactUUID: Option[String])
extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) {
- protected override def newWriter(
- env: SparkEnv,
- worker: PythonWorker,
- inputIterator: Iterator[Array[Byte]],
- partitionIndex: Int,
- context: TaskContext): Writer = {
- new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) {
+ override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+ PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+ }
+}
- protected override def writeCommand(dataOut: DataOutputStream): Unit = {
- PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
- }
+class PythonUDFWithNamedArgumentsRunner(
+ funcs: Seq[ChainedPythonFunctions],
+ evalType: Int,
+ argMetas: Array[Array[ArgumentMetadata]],
+ pythonMetrics: Map[String, SQLMetric],
+ jobArtifactUUID: Option[String])
+ extends BasePythonUDFRunner(
+ funcs, evalType, argMetas.map(_.map(_.offset)), pythonMetrics, jobArtifactUUID) {
- }
+ override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+ PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
}
}
@@ -146,4 +157,30 @@ object PythonUDFRunner {
}
}
}
+
+ def writeUDFs(
+ dataOut: DataOutputStream,
+ funcs: Seq[ChainedPythonFunctions],
+ argMetas: Array[Array[ArgumentMetadata]]): Unit = {
+ dataOut.writeInt(funcs.length)
+ funcs.zip(argMetas).foreach { case (chained, metas) =>
+ dataOut.writeInt(metas.length)
+ metas.foreach {
+ case ArgumentMetadata(offset, name) =>
+ dataOut.writeInt(offset)
+ name match {
+ case Some(name) =>
+ dataOut.writeBoolean(true)
+ PythonWorkerUtils.writeUTF(name, dataOut)
+ case _ =>
+ dataOut.writeBoolean(false)
+ }
+ }
+ dataOut.writeInt(chained.funcs.length)
+ chained.funcs.foreach { f =>
+ dataOut.writeInt(f.command.length)
+ dataOut.write(f.command.toArray)
+ }
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 38d521c16d5..f576637aa25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -50,6 +50,19 @@ case class UserDefinedPythonFunction(
udfDeterministic: Boolean) {
def builder(e: Seq[Expression]): Expression = {
+ if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF
+ || pythonEvalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
+ || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) {
+ /*
+ * Check if the named arguments:
+ * - don't have duplicated names
+ * - don't contain positional arguments after named arguments
+ */
+ NamedParametersSupport.splitAndCheckNamedArguments(e, name)
+ } else if (e.exists(_.isInstanceOf[NamedArgumentExpression])) {
+ throw QueryCompilationErrors.namedArgumentsNotSupported(name)
+ }
+
if (pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
PythonUDAF(name, func, dataType, e, udfDeterministic)
} else {
@@ -104,7 +117,7 @@ case class UserDefinedPythonTableFunction(
/*
* Check if the named arguments:
* - don't have duplicated names
- * - don't contain positional arguments
+ * - don't contain positional arguments after named arguments
*/
NamedParametersSupport.splitAndCheckNamedArguments(exprs, name)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]