This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ce6b5f3751e [SPARK-44918][SQL][PYTHON] Support named arguments in scalar Python/Pandas UDFs
ce6b5f3751e is described below

commit ce6b5f3751e9ea5d1cb4b63c8e14235914817766
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Aug 24 10:55:54 2023 -0700

    [SPARK-44918][SQL][PYTHON] Support named arguments in scalar Python/Pandas UDFs

    ### What changes were proposed in this pull request?

    Supports named arguments in scalar Python/Pandas UDFs.

    For example:

    ```py
    >>> @udf("int")
    ... def test_udf(a, b):
    ...     return a + 10 * b
    ...
    >>> spark.udf.register("test_udf", test_udf)
    >>> spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))).show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    >>> spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    ```

    or:

    ```py
    >>> @pandas_udf("int")
    ... def test_udf(a, b):
    ...     return a + 10 * b
    ...
    >>> spark.udf.register("test_udf", test_udf)
    >>> spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))).show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    >>> spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    ```

    ### Why are the changes needed?

    Now that named argument support has been added (https://github.com/apache/spark/pull/41796, https://github.com/apache/spark/pull/42020), scalar Python/Pandas UDFs can support it as well.

    ### Does this PR introduce _any_ user-facing change?

    Yes, named arguments will be available for scalar Python/Pandas UDFs.

    ### How was this patch tested?

    Added related tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #42617 from ueshin/issues/SPARK-44918/kwargs.
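    Named arguments also compose with ordinary Python parameter defaults, as exercised by the new `test_named_arguments_and_defaults` tests in this patch. The snippet below is a minimal sketch of that behavior; it assumes an active `SparkSession` bound to `spark`:

    ```py
    from pyspark.sql.functions import col, udf

    @udf("int")
    def test_udf(a, b=0):
        # "a" may be passed positionally or by name; "b" falls back to its
        # default when the caller does not supply it.
        return a + 10 * b

    spark.udf.register("test_udf", test_udf)

    # Without "b": both queries yield rows 0 and 1 (b defaults to 0).
    spark.range(2).select(test_udf(a=col("id"))).show()
    spark.sql("SELECT test_udf(a => id) FROM range(2)").show()

    # With "b": both queries yield rows 0 and 101.
    spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))).show()
    spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
    ```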
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/sql/connect/udf.py | 24 +++--- python/pyspark/sql/functions.py | 17 ++++ python/pyspark/sql/pandas/functions.py | 17 ++++ .../sql/tests/pandas/test_pandas_udf_scalar.py | 94 ++++++++++++++++++++- python/pyspark/sql/tests/test_udf.py | 98 +++++++++++++++++++++- python/pyspark/sql/tests/test_udtf.py | 4 +- python/pyspark/sql/udf.py | 24 ++++-- python/pyspark/worker.py | 89 ++++++++++++++------ .../ApplyInPandasWithStatePythonRunner.scala | 4 + .../sql/execution/python/ArrowEvalPythonExec.scala | 9 +- .../execution/python/ArrowEvalPythonUDTFExec.scala | 2 +- .../sql/execution/python/ArrowPythonRunner.scala | 51 ++++++++++- .../execution/python/ArrowPythonUDTFRunner.scala | 11 +-- .../sql/execution/python/BatchEvalPythonExec.scala | 7 +- .../execution/python/BatchEvalPythonUDTFExec.scala | 20 ++--- .../python/EvalPythonEvaluatorFactory.scala | 25 ++++-- .../sql/execution/python/EvalPythonExec.scala | 10 +++ .../sql/execution/python/EvalPythonUDTFExec.scala | 17 +--- .../sql/execution/python/PythonArrowInput.scala | 11 +-- .../sql/execution/python/PythonUDFRunner.scala | 85 +++++++++++++------ .../python/UserDefinedPythonFunction.scala | 15 +++- 21 files changed, 497 insertions(+), 137 deletions(-) diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 2636777e5f6..90cea26e56f 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -25,13 +25,15 @@ import sys import functools import warnings from inspect import getfullargspec -from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union +from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union from pyspark.rdd import PythonEvalType from pyspark.sql.connect.expressions import ( ColumnReference, - PythonUDF, CommonInlineUserDefinedFunction, + Expression, + NamedArgumentExpression, + PythonUDF, ) from pyspark.sql.connect.column import Column from pyspark.sql.connect.types import UnparsedDataType @@ -155,12 +157,14 @@ class UserDefinedFunction: self.deterministic = deterministic def _build_common_inline_user_defined_function( - self, *cols: "ColumnOrName" + self, *args: "ColumnOrName", **kwargs: "ColumnOrName" ) -> CommonInlineUserDefinedFunction: - arg_cols = [ - col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols + def to_expr(col: "ColumnOrName") -> Expression: + return col._expr if isinstance(col, Column) else ColumnReference(col) + + arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [ + NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items() ] - arg_exprs = [col._expr for col in arg_cols] py_udf = PythonUDF( output_type=self.returnType, @@ -175,8 +179,8 @@ class UserDefinedFunction: arguments=arg_exprs, ) - def __call__(self, *cols: "ColumnOrName") -> Column: - return Column(self._build_common_inline_user_defined_function(*cols)) + def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column: + return Column(self._build_common_inline_user_defined_function(*args, **kwargs)) # This function is for improving the online help system in the interactive interpreter. # For example, the built-in help / pydoc.help. 
It wraps the UDF with the docstring and @@ -196,8 +200,8 @@ class UserDefinedFunction: ) @functools.wraps(self.func, assigned=assignments) - def wrapper(*args: "ColumnOrName") -> Column: - return self(*args) + def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column: + return self(*args, **kwargs) wrapper.__name__ = self._name wrapper.__module__ = ( diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b5e64c7a039..e580d2aba12 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -15422,6 +15422,9 @@ def udf( .. versionchanged:: 3.4.0 Supports Spark Connect. + .. versionchanged:: 4.0.0 + Supports keyword-arguments. + Parameters ---------- f : function @@ -15455,6 +15458,20 @@ def udf( | 8| JOHN DOE| 22| +----------+--------------+------------+ + UDF can use keyword arguments: + + >>> @udf(returnType=IntegerType()) + ... def calc(a, b): + ... return a + 10 * b + ... + >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show() + +-----------------------------+ + |calc(b => (id * 10), a => id)| + +-----------------------------+ + | 0| + | 101| + +-----------------------------+ + Notes ----- The user-defined functions are considered deterministic by default. Due to diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index b7d381f04c7..ad9fdac9706 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -56,6 +56,9 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. versionchanged:: 3.4.0 Supports Spark Connect. + .. versionchanged:: 4.0.0 + Supports keyword-arguments in SCALAR type. + Parameters ---------- f : function, optional @@ -153,6 +156,20 @@ def pandas_udf(f=None, returnType=None, functionType=None): | [John, Doe]| +------------------+ + This type of Pandas UDF can use keyword arguments: + + >>> @pandas_udf(returnType=IntegerType()) + ... def calc(a: pd.Series, b: pd.Series) -> pd.Series: + ... return a + 10 * b + ... + >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show() + +-----------------------------+ + |calc(b => (id * 10), a => id)| + +-----------------------------+ + | 0| + | 101| + +-----------------------------+ + .. note:: The length of the input is not that of the whole input column, but is the length of an internal batch used for each call to the function. 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index 7a80547b3fc..8cb397ab95d 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -50,7 +50,7 @@ from pyspark.sql.types import ( BinaryType, YearMonthIntervalType, ) -from pyspark.errors import AnalysisException +from pyspark.errors import AnalysisException, PythonException from pyspark.testing.sqlutils import ( ReusedSQLTestCase, test_compiled, @@ -1467,6 +1467,98 @@ class ScalarPandasUDFTestsMixin: finally: shutil.rmtree(path) + def test_named_arguments(self): + @pandas_udf("int") + def test_udf(a, b): + return a + 10 * b + + self.spark.udf.register("test_udf", test_udf) + + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))), + self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(0), Row(101)]) + + def test_named_arguments_negative(self): + @pandas_udf("int") + def test_udf(a, b): + return a + b + + self.spark.udf.register("test_udf", test_udf) + + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show() + + with self.assertRaisesRegex( + PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'" + ): + self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show() + + def test_kwargs(self): + @pandas_udf("int") + def test_udf(a, **kwargs): + return a + 10 * kwargs["b"] + + self.spark.udf.register("test_udf", test_udf) + + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))), + self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(0), Row(101)]) + + def test_named_arguments_and_defaults(self): + @pandas_udf("int") + def test_udf(a, b=0): + return a + 10 * b + + self.spark.udf.register("test_udf", test_udf) + + # without "b" + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(col("id"))), + self.spark.range(2).select(test_udf(a=col("id"))), + self.spark.sql("SELECT test_udf(id) FROM range(2)"), + self.spark.sql("SELECT test_udf(a => id) FROM range(2)"), + ] + ): + with self.subTest(with_b=False, query_no=i): + assertDataFrameEqual(df, [Row(0), Row(1)]) + + # with "b" + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))), + self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"), 
+ self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"), + ] + ): + with self.subTest(with_b=True, query_no=i): + assertDataFrameEqual(df, [Row(0), Row(101)]) + class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 239ff27813b..f72bf288230 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -24,7 +24,7 @@ import datetime from pyspark import SparkContext, SQLContext from pyspark.sql import SparkSession, Column, Row -from pyspark.sql.functions import udf, assert_true, lit, rand +from pyspark.sql.functions import col, udf, assert_true, lit, rand from pyspark.sql.udf import UserDefinedFunction from pyspark.sql.types import ( StringType, @@ -38,9 +38,9 @@ from pyspark.sql.types import ( TimestampNTZType, DayTimeIntervalType, ) -from pyspark.errors import AnalysisException, PySparkTypeError +from pyspark.errors import AnalysisException, PythonException, PySparkTypeError from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message -from pyspark.testing.utils import QuietTest +from pyspark.testing.utils import QuietTest, assertDataFrameEqual class BaseUDFTestsMixin(object): @@ -898,6 +898,98 @@ class BaseUDFTestsMixin(object): self.assertEquals(row[1], {"a": "b"}) self.assertEquals(row[2], Row(col1=1, col2=2)) + def test_named_arguments(self): + @udf("int") + def test_udf(a, b): + return a + 10 * b + + self.spark.udf.register("test_udf", test_udf) + + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))), + self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(0), Row(101)]) + + def test_named_arguments_negative(self): + @udf("int") + def test_udf(a, b): + return a + b + + self.spark.udf.register("test_udf", test_udf) + + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show() + + with self.assertRaisesRegex( + PythonException, r"test_udf\(\) got an unexpected keyword argument 'c'" + ): + self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show() + + def test_kwargs(self): + @udf("int") + def test_udf(**kwargs): + return kwargs["a"] + 10 * kwargs["b"] + + self.spark.udf.register("test_udf", test_udf) + + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))), + self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(0), Row(101)]) + + def test_named_arguments_and_defaults(self): + @udf("int") + def test_udf(a, b=0): + return a + 10 * b + + self.spark.udf.register("test_udf", 
test_udf) + + # without "b" + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(col("id"))), + self.spark.range(2).select(test_udf(a=col("id"))), + self.spark.sql("SELECT test_udf(id) FROM range(2)"), + self.spark.sql("SELECT test_udf(a => id) FROM range(2)"), + ] + ): + with self.subTest(with_b=False, query_no=i): + assertDataFrameEqual(df, [Row(0), Row(1)]) + + # with "b" + for i, df in enumerate( + [ + self.spark.range(2).select(test_udf(col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 10)), + self.spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))), + self.spark.sql("SELECT test_udf(id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM range(2)"), + self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)"), + ] + ): + with self.subTest(with_b=True, query_no=i): + assertDataFrameEqual(df, [Row(0), Row(101)]) + class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 9a80f8fd73c..63743de5e03 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -1947,7 +1947,7 @@ class BaseUDTFTestsMixin: TestUDTF(a=lit(10)), ] ): - with self.subTest(query_no=i): + with self.subTest(with_b=False, query_no=i): assertDataFrameEqual(df, [Row(a=10, b=100)]) # with "b" @@ -1961,7 +1961,7 @@ class BaseUDTFTestsMixin: TestUDTF(b=lit("z"), a=lit(10)), ] ): - with self.subTest(query_no=i): + with self.subTest(with_b=True, query_no=i): assertDataFrameEqual(df, [Row(a=10, b="z")]) def test_udtf_with_table_argument_and_partition_by(self): diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 7d7784dd522..029293ab70f 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -30,7 +30,7 @@ from py4j.java_gateway import JavaObject from pyspark import SparkContext from pyspark.profiler import Profiler from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType -from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq +from pyspark.sql.column import Column, _to_java_expr, _to_seq from pyspark.sql.types import ( DataType, StringType, @@ -336,8 +336,17 @@ class UserDefinedFunction: ) return judf - def __call__(self, *cols: "ColumnOrName") -> Column: + def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column: sc = get_active_spark_context() + + assert sc._jvm is not None + jexprs = [_to_java_expr(arg) for arg in args] + [ + sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression( + key, _to_java_expr(value) + ) + for key, value in kwargs.items() + ] + profiler: Optional[Profiler] = None memory_profiler: Optional[Profiler] = None if sc.profiler_collector: @@ -376,7 +385,7 @@ class UserDefinedFunction: func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] judf = self._create_judf(func) - jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr)) + jUDFExpr = judf.builder(_to_seq(sc, jexprs)) jPythonUDF = judf.fromUDFExpr(jUDFExpr) id = jUDFExpr.resultId().id() sc.profiler_collector.add_profiler(id, profiler) @@ -394,13 +403,14 @@ class UserDefinedFunction: func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] judf = self._create_judf(func) - jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr)) + jUDFExpr = judf.builder(_to_seq(sc, jexprs)) jPythonUDF = judf.fromUDFExpr(jUDFExpr) id = jUDFExpr.resultId().id() 
sc.profiler_collector.add_profiler(id, memory_profiler) else: judf = self._judf - jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column)) + jUDFExpr = judf.builder(_to_seq(sc, jexprs)) + jPythonUDF = judf.fromUDFExpr(jUDFExpr) return Column(jPythonUDF) # This function is for improving the online help system in the interactive interpreter. @@ -421,8 +431,8 @@ class UserDefinedFunction: ) @functools.wraps(self.func, assigned=assignments) - def wrapper(*args: "ColumnOrName") -> Column: - return self(*args) + def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column: + return self(*args, **kwargs) wrapper.__name__ = self._name wrapper.__module__ = ( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0dd281ea91f..19c8c9c897b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -84,9 +84,9 @@ def chain(f, g): def wrap_udf(f, return_type): if return_type.needConversion(): toInternal = return_type.toInternal - return lambda *a: toInternal(f(*a)) + return lambda *a, **kw: toInternal(f(*a, **kw)) else: - return lambda *a: f(*a) + return lambda *a, **kw: f(*a, **kw) def wrap_scalar_pandas_udf(f, return_type): @@ -115,8 +115,10 @@ def wrap_scalar_pandas_udf(f, return_type): ) return result - return lambda *a: ( - verify_result_length(verify_result_type(f(*a)), len(a[0])), + return lambda *a, **kw: ( + verify_result_length( + verify_result_type(f(*a, **kw)), len((list(a) + list(kw.values()))[0]) + ), arrow_return_type, ) @@ -137,8 +139,17 @@ def wrap_arrow_batch_udf(f, return_type): elif type(return_type) == BinaryType: result_func = lambda r: bytes(r) if r is not None else r # noqa: E731 - def evaluate(*args: pd.Series) -> pd.Series: - return pd.Series(result_func(f(*a)) for a in zip(*args)) + def evaluate(*args: pd.Series, **kwargs: pd.Series) -> pd.Series: + keys = list(kwargs.keys()) + len_args = len(args) + return pd.Series( + [ + result_func( + f(*row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)}) + ) + for row in zip(*args, *[kwargs[key] for key in keys]) + ] + ) def verify_result_type(result): if not hasattr(result, "__len__"): @@ -163,8 +174,10 @@ def wrap_arrow_batch_udf(f, return_type): ) return result - return lambda *a: ( - verify_result_length(verify_result_type(evaluate(*a)), len(a[0])), + return lambda *a, **kw: ( + verify_result_length( + verify_result_type(evaluate(*a, **kw)), len((list(a) + list(kw.values()))[0]) + ), arrow_return_type, ) @@ -517,7 +530,27 @@ def wrap_bounded_window_agg_pandas_udf(f, return_type): def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): num_arg = read_int(infile) - arg_offsets = [read_int(infile) for i in range(num_arg)] + + if eval_type in ( + PythonEvalType.SQL_BATCHED_UDF, + PythonEvalType.SQL_ARROW_BATCHED_UDF, + PythonEvalType.SQL_SCALAR_PANDAS_UDF, + # The below doesn't support named argument, but shares the same protocol. 
+ PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + ): + args_offsets = [] + kwargs_offsets = {} + for _ in range(num_arg): + offset = read_int(infile) + if read_bool(infile): + name = utf8_deserializer.loads(infile) + kwargs_offsets[name] = offset + else: + args_offsets.append(offset) + else: + args_offsets = [read_int(infile) for i in range(num_arg)] + kwargs_offsets = {} + chained_func = None for i in range(read_int(infile)): f, return_type = read_command(pickleSer, infile) @@ -535,31 +568,32 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(func, return_type) + udf = wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF: - return arg_offsets, wrap_arrow_batch_udf(func, return_type) + udf = wrap_arrow_batch_udf(func, return_type) elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: - return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type) + udf = wrap_pandas_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: - return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type) + udf = wrap_pandas_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: - return arg_offsets, wrap_arrow_batch_iter_udf(func, return_type) + udf = wrap_arrow_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) + udf = wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) + udf = wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it - return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf) + udf = wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: - return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) + udf = wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: - return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index) + udf = wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: - return arg_offsets, wrap_udf(func, return_type) + udf = wrap_udf(func, return_type) else: raise ValueError("Unknown eval type: {}".format(eval_type)) + return args_offsets, kwargs_offsets, udf # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when @@ -984,7 +1018,9 @@ def read_udfs(pickleSer, infile, eval_type): if is_map_arrow_iter: assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." 
- arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, _, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0 + ) def func(_, iterator): num_input_rows = 0 @@ -1074,7 +1110,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandasExec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) parsed_offsets = extract_key_value_indexes(arg_offsets) # Create function like this: @@ -1091,7 +1127,7 @@ def read_udfs(pickleSer, infile, eval_type): # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to # distinguish between grouping attributes and data attributes - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) parsed_offsets = extract_key_value_indexes(arg_offsets) def mapper(a): @@ -1125,7 +1161,7 @@ def read_udfs(pickleSer, infile, eval_type): # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. assert num_udfs == 1 - arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) parsed_offsets = extract_key_value_indexes(arg_offsets) @@ -1142,7 +1178,10 @@ def read_udfs(pickleSer, infile, eval_type): udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)) def mapper(a): - result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs) + result = tuple( + f(*[a[o] for o in args_offsets], **{k: a[o] for k, o in kwargs_offsets.items()}) + for args_offsets, kwargs_offsets, f in udfs + ) # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. if len(result) == 1: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index a60d0beeeed..9fde1814079 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -105,6 +105,10 @@ class ApplyInPandasWithStatePythonRunner( private val stateRowDeserializer = stateEncoder.createDeserializer() + override protected def writeUDF(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + /** * This method sends out the additional metadata before sending out actual data. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 7db43a34a88..bd91da3bc0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.StructType /** @@ -94,11 +95,11 @@ class ArrowEvalPythonEvaluatorFactory( pythonRunnerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) - extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { + extends EvalPythonEvaluatorFactory(childOutput, udfs, output) { override def evaluate( funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]], + argMetas: Array[Array[ArgumentMetadata]], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] = { @@ -108,10 +109,10 @@ class ArrowEvalPythonEvaluatorFactory( // DO NOT use iter.grouped(). See BatchIterator. val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) - val columnarBatchIter = new ArrowPythonRunner( + val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner( funcs, evalType, - argOffsets, + argMetas, schema, sessionLocalTimeZone, largeVarTypes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala index 8ebd8a3a106..9d5bac0c600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 0f26d8f21f8..251e682c9e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.execution.python +import java.io.DataOutputStream + import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -/** - * Similar to `PythonUDFRunner`, but exchange 
data with Python worker via Arrow stream. - */ -class ArrowPythonRunner( +abstract class BaseArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, argOffsets: Array[Array[Int]], @@ -61,6 +61,49 @@ class ArrowPythonRunner( s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") } +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. + */ +class ArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + _schema: StructType, + _timeZoneId: String, + largeVarTypes: Boolean, + workerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends BaseArrowPythonRunner( + funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf, + pythonMetrics, jobArtifactUUID) { + + override protected def writeUDF(dataOut: DataOutputStream): Unit = + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) +} + +/** + * Similar to `PythonUDFWithNamedArgumentsRunner`, but exchange data with Python worker + * via Arrow stream. + */ +class ArrowPythonWithNamedArgumentRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argMetas: Array[Array[ArgumentMetadata]], + _schema: StructType, + _timeZoneId: String, + largeVarTypes: Boolean, + workerConf: Map[String, String], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends BaseArrowPythonRunner( + funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf, + pythonMetrics, jobArtifactUUID) { + + override protected def writeUDF(dataOut: DataOutputStream): Unit = + PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas) +} + object ArrowPythonRunner { /** Return Map with conf settings to be used in ArrowPythonRunner */ def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index c0fa8b58bee..690947b4129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -23,7 +23,7 @@ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.PythonUDTF import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -42,15 +42,12 @@ class ArrowPythonUDTFRunner( val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - Seq(ChainedPythonFunctions(Seq(udtf.func))), - evalType, Array(argMetas.map(_.offset)), jobArtifactUUID) + Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(argMetas.map(_.offset)), + jobArtifactUUID) with BasicPythonArrowInput with BasicPythonArrowOutput { - override protected def writeUDF( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = { + override protected def writeUDF(dataOut: DataOutputStream): Unit = { PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) } 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 1de8f55d84b..a0e7789b281 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.{StructField, StructType} /** @@ -60,7 +61,7 @@ class BatchEvalPythonEvaluatorFactory( override def evaluate( funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]], + argMetas: Array[Array[ArgumentMetadata]], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] = { @@ -71,8 +72,8 @@ class BatchEvalPythonEvaluatorFactory( // Output iterator for results from Python. val outputIterator = - new PythonUDFRunner( - funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics, jobArtifactUUID) + new PythonUDFWithNamedArgumentsRunner( + funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, jobArtifactUUID) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index 46cc0f2ab50..342e0723194 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -23,14 +23,14 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.Unpickler -import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorker, PythonWorkerUtils} +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorkerUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.StructType /** @@ -101,18 +101,8 @@ class PythonUDTFRunner( Seq(ChainedPythonFunctions(Seq(udtf.func))), PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), pythonMetrics, jobArtifactUUID) { - protected override def newWriter( - env: SparkEnv, - worker: PythonWorker, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext): Writer = { - new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) - } - } + override protected def writeUDF(dataOut: DataOutputStream): Unit = { + PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) } } diff 
--git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala index 373e17c0aa3..d5142f58eab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala @@ -25,6 +25,7 @@ import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -32,11 +33,11 @@ abstract class EvalPythonEvaluatorFactory( childOutput: Seq[Attribute], udfs: Seq[PythonUDF], output: Seq[Attribute]) - extends PartitionEvaluatorFactory[InternalRow, InternalRow] { + extends PartitionEvaluatorFactory[InternalRow, InternalRow] { protected def evaluate( funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]], + argMetas: Array[Array[ArgumentMetadata]], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] @@ -78,14 +79,20 @@ abstract class EvalPythonEvaluatorFactory( // flatten all the arguments val allInputs = new ArrayBuffer[Expression] val dataTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => + val argMetas = inputs.map { input => input.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) + val (key, value) = e match { + case NamedArgumentExpression(key, value) => + (Some(key), value) + case _ => + (None, e) + } + if (allInputs.exists(_.semanticEquals(value))) { + ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key) } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 + allInputs += value + dataTypes += value.dataType + ArgumentMetadata(allInputs.length - 1, key) } }.toArray }.toArray @@ -102,7 +109,7 @@ abstract class EvalPythonEvaluatorFactory( } val outputRowIterator = - evaluate(pyFuncs, argOffsets, projectedRowIter, schema, context) + evaluate(pyFuncs, argMetas, projectedRowIter, schema, context) val joined = new JoinedRow val resultProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala index 1c8b0f2228f..af6769cfbb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala @@ -22,6 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.UnaryExecNode +object EvalPythonExec { + /** + * Metadata for arguments of Python UDTF. + * + * @param offset the offset of the argument + * @param name the name of the argument if it's a `NamedArgumentExpression` + */ + case class ArgumentMetadata(offset: Int, name: Option[String]) +} + /** * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala index 410209e0ada..41a99693443 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala @@ -21,25 +21,15 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.UnaryExecNode -import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils -object EvalPythonUDTFExec { - /** - * Metadata for arguments of Python UDTF. - * - * @param offset the offset of the argument - * @param name the name of the argument if it's a `NamedArgumentExpression` - */ - case class ArgumentMetadata(offset: Int, name: Option[String]) -} - /** * A physical plan that evaluates a [[PythonUDTF]], one partition of tuples at a time. * This is similar to [[EvalPythonExec]]. @@ -66,7 +56,6 @@ trait EvalPythonUDTFExec extends UnaryExecNode { inputRDD.mapPartitions { iter => val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, iter) // The queue used to buffer input rows so we can drain it to // combine input with output from Python. @@ -104,7 +93,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode { // Also keep track of the number rows added to the queue. // This is needed to process extra output rows from the `terminate()` call of the UDTF. 
var count = 0L - val projectedRowIter = contextAwareIterator.map { inputRow => + val projectedRowIter = iter.map { inputRow => queue.add(inputRow.asInstanceOf[UnsafeRow]) count += 1 projection(inputRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 00ee3a17563..1e075cab922 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -22,7 +22,7 @@ import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker} +import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow import org.apache.spark.sql.execution.arrow.ArrowWriter @@ -54,11 +54,7 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => dataOut: DataOutputStream, inputIterator: Iterator[IN]): Boolean - protected def writeUDF( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + protected def writeUDF(dataOut: DataOutputStream): Unit protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { // Write config for the worker as a number of key -> value pairs of strings @@ -94,9 +90,10 @@ protected def close(): Unit = { partitionIndex: Int, context: TaskContext): Writer = { new Writer(env, worker, inputIterator, partitionIndex, context) { + protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) - writeUDF(dataOut, funcs, argOffsets) + writeUDF(dataOut) } override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index bc27ee6919d..b99517f544d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.internal.SQLConf /** @@ -43,24 +44,31 @@ abstract class BasePythonUDFRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - abstract class PythonUDFWriter( + protected def writeUDF(dataOut: DataOutputStream): Unit + + protected override def newWriter( env: SparkEnv, worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, - context: TaskContext) - extends Writer(env, worker, inputIterator, partitionIndex, context) { - - override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { - val startData = dataOut.size() - val wroteData = PythonRDD.writeNextElementToStream(inputIterator, dataOut) - if (!wroteData) { - // Reached the end of input. 
- dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + context: TaskContext): Writer = { + new Writer(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + writeUDF(dataOut) + } + + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { + val startData = dataOut.size() + val wroteData = PythonRDD.writeNextElementToStream(inputIterator, dataOut) + if (!wroteData) { + // Reached the end of input. + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData + wroteData } - val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData - wroteData } } @@ -111,19 +119,22 @@ class PythonUDFRunner( jobArtifactUUID: Option[String]) extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { - protected override def newWriter( - env: SparkEnv, - worker: PythonWorker, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext): Writer = { - new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { + override protected def writeUDF(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } +} - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } +class PythonUDFWithNamedArgumentsRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argMetas: Array[Array[ArgumentMetadata]], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends BasePythonUDFRunner( + funcs, evalType, argMetas.map(_.map(_.offset)), pythonMetrics, jobArtifactUUID) { - } + override protected def writeUDF(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas) } } @@ -146,4 +157,30 @@ object PythonUDFRunner { } } } + + def writeUDFs( + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argMetas: Array[Array[ArgumentMetadata]]): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argMetas).foreach { case (chained, metas) => + dataOut.writeInt(metas.length) + metas.foreach { + case ArgumentMetadata(offset, name) => + dataOut.writeInt(offset) + name match { + case Some(name) => + dataOut.writeBoolean(true) + PythonWorkerUtils.writeUTF(name, dataOut) + case _ => + dataOut.writeBoolean(false) + } + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command.toArray) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 38d521c16d5..f576637aa25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -50,6 +50,19 @@ case class UserDefinedPythonFunction( udfDeterministic: Boolean) { def builder(e: Seq[Expression]): Expression = { + if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF + || pythonEvalType ==PythonEvalType.SQL_ARROW_BATCHED_UDF + || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) { + /* + * Check if the named arguments: + * - don't have duplicated names + * - don't contain positional arguments after named arguments + */ + 
NamedParametersSupport.splitAndCheckNamedArguments(e, name) + } else if (e.exists(_.isInstanceOf[NamedArgumentExpression])) { + throw QueryCompilationErrors.namedArgumentsNotSupported(name) + } + if (pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) { PythonUDAF(name, func, dataType, e, udfDeterministic) } else { @@ -104,7 +117,7 @@ case class UserDefinedPythonTableFunction( /* * Check if the named arguments: * - don't have duplicated names - * - don't contain positional arguments + * - don't contain positional arguments after named arguments */ NamedParametersSupport.splitAndCheckNamedArguments(exprs, name)
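On the wire, each UDF argument is now serialized as its column offset followed by a boolean flag and, when the flag is set, the argument name (see `PythonUDFRunner.writeUDFs` taking `ArgumentMetadata`, and `read_single_udf` in `python/pyspark/worker.py`, in the diff above). Below is a simplified sketch of the worker-side read path, using the existing `pyspark.serializers` helpers; `read_arg_offsets` itself is an illustrative name, not a function added by this patch:

```py
from pyspark.serializers import UTF8Deserializer, read_bool, read_int

utf8_deserializer = UTF8Deserializer()

def read_arg_offsets(infile):
    """Split one UDF's argument offsets into positional and named groups."""
    num_arg = read_int(infile)
    args_offsets, kwargs_offsets = [], {}
    for _ in range(num_arg):
        offset = read_int(infile)       # index into the projected input row
        if read_bool(infile):           # True => an argument name follows
            name = utf8_deserializer.loads(infile)
            kwargs_offsets[name] = offset
        else:
            args_offsets.append(offset)
    return args_offsets, kwargs_offsets

# The worker's mapper then invokes the wrapped UDF as:
#   f(*[row[o] for o in args_offsets],
#     **{k: row[o] for k, o in kwargs_offsets.items()})
```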