This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ce6b5f3751e [SPARK-44918][SQL][PYTHON] Support named arguments in scalar Python/Pandas UDFs
ce6b5f3751e is described below

commit ce6b5f3751e9ea5d1cb4b63c8e14235914817766
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Aug 24 10:55:54 2023 -0700

    [SPARK-44918][SQL][PYTHON] Support named arguments in scalar Python/Pandas UDFs
    
    ### What changes were proposed in this pull request?
    
    Supports named arguments in scalar Python/Pandas UDFs.
    
    For example:
    
    ```py
    >>> @udf("int")
    ... def test_udf(a, b):
    ...     return a + 10 * b
    ...
    >>> spark.udf.register("test_udf", test_udf)
    
    >>> spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))).show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    
    >>> spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    ```
    
    or:
    
    ```py
    >>> @pandas_udf("int")
    ... def test_udf(a, b):
    ...     return a + 10 * b
    ...
    >>> spark.udf.register("test_udf", test_udf)
    
    >>> spark.range(2).select(test_udf(b=col("id") * 10, a=col("id"))).show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    
    >>> spark.sql("SELECT test_udf(b => id * 10, a => id) FROM range(2)").show()
    +---------------------------------+
    |test_udf(b => (id * 10), a => id)|
    +---------------------------------+
    |                                0|
    |                              101|
    +---------------------------------+
    ```
    
    ### Why are the changes needed?
    
    Named argument support was recently added to Spark SQL
    (https://github.com/apache/spark/pull/41796,
    https://github.com/apache/spark/pull/42020).
    
    With that in place, scalar Python/Pandas UDFs can now support named arguments as well.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, named arguments will be available for scalar Python/Pandas UDFs.
    
    ### How was this patch tested?
    
    Added related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42617 from ueshin/issues/SPARK-44918/kwargs.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/connect/udf.py                  | 24 +++---
 python/pyspark/sql/functions.py                    | 17 ++++
 python/pyspark/sql/pandas/functions.py             | 17 ++++
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 94 ++++++++++++++++++++-
 python/pyspark/sql/tests/test_udf.py               | 98 +++++++++++++++++++++-
 python/pyspark/sql/tests/test_udtf.py              |  4 +-
 python/pyspark/sql/udf.py                          | 24 ++++--
 python/pyspark/worker.py                           | 89 ++++++++++++++------
 .../ApplyInPandasWithStatePythonRunner.scala       |  4 +
 .../sql/execution/python/ArrowEvalPythonExec.scala |  9 +-
 .../execution/python/ArrowEvalPythonUDTFExec.scala |  2 +-
 .../sql/execution/python/ArrowPythonRunner.scala   | 51 ++++++++++-
 .../execution/python/ArrowPythonUDTFRunner.scala   | 11 +--
 .../sql/execution/python/BatchEvalPythonExec.scala |  7 +-
 .../execution/python/BatchEvalPythonUDTFExec.scala | 20 ++---
 .../python/EvalPythonEvaluatorFactory.scala        | 25 ++++--
 .../sql/execution/python/EvalPythonExec.scala      | 10 +++
 .../sql/execution/python/EvalPythonUDTFExec.scala  | 17 +---
 .../sql/execution/python/PythonArrowInput.scala    | 11 +--
 .../sql/execution/python/PythonUDFRunner.scala     | 85 +++++++++++++------
 .../python/UserDefinedPythonFunction.scala         | 15 +++-
 21 files changed, 497 insertions(+), 137 deletions(-)

diff --git a/python/pyspark/sql/connect/udf.py 
b/python/pyspark/sql/connect/udf.py
index 2636777e5f6..90cea26e56f 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -25,13 +25,15 @@ import sys
 import functools
 import warnings
 from inspect import getfullargspec
-from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
+from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.connect.expressions import (
     ColumnReference,
-    PythonUDF,
     CommonInlineUserDefinedFunction,
+    Expression,
+    NamedArgumentExpression,
+    PythonUDF,
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
@@ -155,12 +157,14 @@ class UserDefinedFunction:
         self.deterministic = deterministic
 
     def _build_common_inline_user_defined_function(
-        self, *cols: "ColumnOrName"
+        self, *args: "ColumnOrName", **kwargs: "ColumnOrName"
     ) -> CommonInlineUserDefinedFunction:
-        arg_cols = [
-            col if isinstance(col, Column) else Column(ColumnReference(col)) 
for col in cols
+        def to_expr(col: "ColumnOrName") -> Expression:
+            return col._expr if isinstance(col, Column) else 
ColumnReference(col)
+
+        arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [
+            NamedArgumentExpression(key, to_expr(value)) for key, value in 
kwargs.items()
         ]
-        arg_exprs = [col._expr for col in arg_cols]
 
         py_udf = PythonUDF(
             output_type=self.returnType,
@@ -175,8 +179,8 @@ class UserDefinedFunction:
             arguments=arg_exprs,
         )
 
-    def __call__(self, *cols: "ColumnOrName") -> Column:
-        return Column(self._build_common_inline_user_defined_function(*cols))
+    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> 
Column:
+        return Column(self._build_common_inline_user_defined_function(*args, 
**kwargs))
 
     # This function is for improving the online help system in the interactive 
interpreter.
     # For example, the built-in help / pydoc.help. It wraps the UDF with the 
docstring and
@@ -196,8 +200,8 @@ class UserDefinedFunction:
         )
 
         @functools.wraps(self.func, assigned=assignments)
-        def wrapper(*args: "ColumnOrName") -> Column:
-            return self(*args)
+        def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
+            return self(*args, **kwargs)
 
         wrapper.__name__ = self._name
         wrapper.__module__ = (
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b5e64c7a039..e580d2aba12 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -15422,6 +15422,9 @@ def udf(
     .. versionchanged:: 3.4.0
         Supports Spark Connect.
 
+    .. versionchanged:: 4.0.0
+        Supports keyword-arguments.
+
     Parameters
     ----------
     f : function
@@ -15455,6 +15458,20 @@ def udf(
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
 
+    UDF can use keyword arguments:
+
+    >>> @udf(returnType=IntegerType())
+    ... def calc(a, b):
+    ...     return a + 10 * b
+    ...
+    >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show()
+    +-----------------------------+
+    |calc(b => (id * 10), a => id)|
+    +-----------------------------+
+    |                            0|
+    |                          101|
+    +-----------------------------+
+
     Notes
     -----
     The user-defined functions are considered deterministic by default. Due to
diff --git a/python/pyspark/sql/pandas/functions.py 
b/python/pyspark/sql/pandas/functions.py
index b7d381f04c7..ad9fdac9706 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -56,6 +56,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     .. versionchanged:: 3.4.0
         Supports Spark Connect.
 
+    .. versionchanged:: 4.0.0
+        Supports keyword-arguments in SCALAR type.
+
     Parameters
     ----------
     f : function, optional
@@ -153,6 +156,20 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         |       [John, Doe]|
         +------------------+
 
+        This type of Pandas UDF can use keyword arguments:
+
+        >>> @pandas_udf(returnType=IntegerType())
+        ... def calc(a: pd.Series, b: pd.Series) -> pd.Series:
+        ...     return a + 10 * b
+        ...
+        >>> spark.range(2).select(calc(b=col("id") * 10, a=col("id"))).show()
+        +-----------------------------+
+        |calc(b => (id * 10), a => id)|
+        +-----------------------------+
+        |                            0|
+        |                          101|
+        +-----------------------------+
+
         .. note:: The length of the input is not that of the whole input 
column, but is the
             length of an internal batch used for each call to the function.
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 7a80547b3fc..8cb397ab95d 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -50,7 +50,7 @@ from pyspark.sql.types import (
     BinaryType,
     YearMonthIntervalType,
 )
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     test_compiled,
@@ -1467,6 +1467,98 @@ class ScalarPandasUDFTestsMixin:
         finally:
             shutil.rmtree(path)
 
+    def test_named_arguments(self):
+        @pandas_udf("int")
+        def test_udf(a, b):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_negative(self):
+        @pandas_udf("int")
+        def test_udf(a, b):
+            return a + b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM 
range(2)").show()
+
+        with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM 
range(2)").show()
+
+        with self.assertRaisesRegex(
+            PythonException, r"test_udf\(\) got an unexpected keyword argument 
'c'"
+        ):
+            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+    def test_kwargs(self):
+        @pandas_udf("int")
+        def test_udf(a, **kwargs):
+            return a + 10 * kwargs["b"]
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_and_defaults(self):
+        @pandas_udf("int")
+        def test_udf(a, b=0):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        # without "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"))),
+                self.spark.range(2).select(test_udf(a=col("id"))),
+                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=False, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(1)])
+
+        # with "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=True, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
 
 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/test_udf.py 
b/python/pyspark/sql/tests/test_udf.py
index 239ff27813b..f72bf288230 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@ import datetime
 
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf, assert_true, lit, rand
+from pyspark.sql.functions import col, udf, assert_true, lit, rand
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import (
     StringType,
@@ -38,9 +38,9 @@ from pyspark.sql.types import (
     TimestampNTZType,
     DayTimeIntervalType,
 )
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, 
test_not_compiled_message
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 
 class BaseUDFTestsMixin(object):
@@ -898,6 +898,98 @@ class BaseUDFTestsMixin(object):
         self.assertEquals(row[1], {"a": "b"})
         self.assertEquals(row[2], Row(col1=1, col2=2))
 
+    def test_named_arguments(self):
+        @udf("int")
+        def test_udf(a, b):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_negative(self):
+        @udf("int")
+        def test_udf(a, b):
+            return a + b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM 
range(2)").show()
+
+        with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM 
range(2)").show()
+
+        with self.assertRaisesRegex(
+            PythonException, r"test_udf\(\) got an unexpected keyword argument 
'c'"
+        ):
+            self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
+
+    def test_kwargs(self):
+        @udf("int")
+        def test_udf(**kwargs):
+            return kwargs["a"] + 10 * kwargs["b"]
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
+    def test_named_arguments_and_defaults(self):
+        @udf("int")
+        def test_udf(a, b=0):
+            return a + 10 * b
+
+        self.spark.udf.register("test_udf", test_udf)
+
+        # without "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"))),
+                self.spark.range(2).select(test_udf(a=col("id"))),
+                self.spark.sql("SELECT test_udf(id) FROM range(2)"),
+                self.spark.sql("SELECT test_udf(a => id) FROM range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=False, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(1)])
+
+        # with "b"
+        for i, df in enumerate(
+            [
+                self.spark.range(2).select(test_udf(col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(a=col("id"), b=col("id") * 
10)),
+                self.spark.range(2).select(test_udf(b=col("id") * 10, 
a=col("id"))),
+                self.spark.sql("SELECT test_udf(id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(a => id, b => id * 10) FROM 
range(2)"),
+                self.spark.sql("SELECT test_udf(b => id * 10, a => id) FROM 
range(2)"),
+            ]
+        ):
+            with self.subTest(with_b=True, query_no=i):
+                assertDataFrameEqual(df, [Row(0), Row(101)])
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 9a80f8fd73c..63743de5e03 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1947,7 +1947,7 @@ class BaseUDTFTestsMixin:
                 TestUDTF(a=lit(10)),
             ]
         ):
-            with self.subTest(query_no=i):
+            with self.subTest(with_b=False, query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b=100)])
 
         # with "b"
@@ -1961,7 +1961,7 @@ class BaseUDTFTestsMixin:
                 TestUDTF(b=lit("z"), a=lit(10)),
             ]
         ):
-            with self.subTest(query_no=i):
+            with self.subTest(with_b=True, query_no=i):
                 assertDataFrameEqual(df, [Row(a=10, b="z")])
 
     def test_udtf_with_table_argument_and_partition_by(self):
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 7d7784dd522..029293ab70f 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -30,7 +30,7 @@ from py4j.java_gateway import JavaObject
 from pyspark import SparkContext
 from pyspark.profiler import Profiler
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
+from pyspark.sql.column import Column, _to_java_expr, _to_seq
 from pyspark.sql.types import (
     DataType,
     StringType,
@@ -336,8 +336,17 @@ class UserDefinedFunction:
         )
         return judf
 
-    def __call__(self, *cols: "ColumnOrName") -> Column:
+    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> 
Column:
         sc = get_active_spark_context()
+
+        assert sc._jvm is not None
+        jexprs = [_to_java_expr(arg) for arg in args] + [
+            
sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression(
+                key, _to_java_expr(value)
+            )
+            for key, value in kwargs.items()
+        ]
+
         profiler: Optional[Profiler] = None
         memory_profiler: Optional[Profiler] = None
         if sc.profiler_collector:
@@ -376,7 +385,7 @@ class UserDefinedFunction:
 
                 func.__signature__ = inspect.signature(f)  # type: 
ignore[attr-defined]
                 judf = self._create_judf(func)
-                jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
+                jUDFExpr = judf.builder(_to_seq(sc, jexprs))
                 jPythonUDF = judf.fromUDFExpr(jUDFExpr)
                 id = jUDFExpr.resultId().id()
                 sc.profiler_collector.add_profiler(id, profiler)
@@ -394,13 +403,14 @@ class UserDefinedFunction:
 
                 func.__signature__ = inspect.signature(f)  # type: 
ignore[attr-defined]
                 judf = self._create_judf(func)
-                jUDFExpr = judf.builder(_to_seq(sc, cols, _to_java_expr))
+                jUDFExpr = judf.builder(_to_seq(sc, jexprs))
                 jPythonUDF = judf.fromUDFExpr(jUDFExpr)
                 id = jUDFExpr.resultId().id()
                 sc.profiler_collector.add_profiler(id, memory_profiler)
         else:
             judf = self._judf
-            jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column))
+            jUDFExpr = judf.builder(_to_seq(sc, jexprs))
+            jPythonUDF = judf.fromUDFExpr(jUDFExpr)
         return Column(jPythonUDF)
 
     # This function is for improving the online help system in the interactive 
interpreter.
@@ -421,8 +431,8 @@ class UserDefinedFunction:
         )
 
         @functools.wraps(self.func, assigned=assignments)
-        def wrapper(*args: "ColumnOrName") -> Column:
-            return self(*args)
+        def wrapper(*args: "ColumnOrName", **kwargs: "ColumnOrName") -> Column:
+            return self(*args, **kwargs)
 
         wrapper.__name__ = self._name
         wrapper.__module__ = (
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0dd281ea91f..19c8c9c897b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -84,9 +84,9 @@ def chain(f, g):
 def wrap_udf(f, return_type):
     if return_type.needConversion():
         toInternal = return_type.toInternal
-        return lambda *a: toInternal(f(*a))
+        return lambda *a, **kw: toInternal(f(*a, **kw))
     else:
-        return lambda *a: f(*a)
+        return lambda *a, **kw: f(*a, **kw)
 
 
 def wrap_scalar_pandas_udf(f, return_type):
@@ -115,8 +115,10 @@ def wrap_scalar_pandas_udf(f, return_type):
             )
         return result
 
-    return lambda *a: (
-        verify_result_length(verify_result_type(f(*a)), len(a[0])),
+    return lambda *a, **kw: (
+        verify_result_length(
+            verify_result_type(f(*a, **kw)), len((list(a) + 
list(kw.values()))[0])
+        ),
         arrow_return_type,
     )
 
@@ -137,8 +139,17 @@ def wrap_arrow_batch_udf(f, return_type):
     elif type(return_type) == BinaryType:
         result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
 
-    def evaluate(*args: pd.Series) -> pd.Series:
-        return pd.Series(result_func(f(*a)) for a in zip(*args))
+    def evaluate(*args: pd.Series, **kwargs: pd.Series) -> pd.Series:
+        keys = list(kwargs.keys())
+        len_args = len(args)
+        return pd.Series(
+            [
+                result_func(
+                    f(*row[:len_args], **{key: row[len_args + i] for i, key in 
enumerate(keys)})
+                )
+                for row in zip(*args, *[kwargs[key] for key in keys])
+            ]
+        )
 
     def verify_result_type(result):
         if not hasattr(result, "__len__"):
@@ -163,8 +174,10 @@ def wrap_arrow_batch_udf(f, return_type):
             )
         return result
 
-    return lambda *a: (
-        verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
+    return lambda *a, **kw: (
+        verify_result_length(
+            verify_result_type(evaluate(*a, **kw)), len((list(a) + 
list(kw.values()))[0])
+        ),
         arrow_return_type,
     )
 
@@ -517,7 +530,27 @@ def wrap_bounded_window_agg_pandas_udf(f, return_type):
 
 def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     num_arg = read_int(infile)
-    arg_offsets = [read_int(infile) for i in range(num_arg)]
+
+    if eval_type in (
+        PythonEvalType.SQL_BATCHED_UDF,
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
+        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+        # The below doesn't support named argument, but shares the same 
protocol.
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+    ):
+        args_offsets = []
+        kwargs_offsets = {}
+        for _ in range(num_arg):
+            offset = read_int(infile)
+            if read_bool(infile):
+                name = utf8_deserializer.loads(infile)
+                kwargs_offsets[name] = offset
+            else:
+                args_offsets.append(offset)
+    else:
+        args_offsets = [read_int(infile) for i in range(num_arg)]
+        kwargs_offsets = {}
+
     chained_func = None
     for i in range(read_int(infile)):
         f, return_type = read_command(pickleSer, infile)
@@ -535,31 +568,32 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index):
 
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
+        udf = wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
-        return arg_offsets, wrap_arrow_batch_udf(func, return_type)
+        udf = wrap_arrow_batch_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+        udf = wrap_pandas_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
+        udf = wrap_pandas_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
-        return arg_offsets, wrap_arrow_batch_iter_udf(func, return_type)
+        udf = wrap_arrow_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when 
wrapping it
-        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, 
argspec, runner_conf)
+        udf = wrap_grouped_map_pandas_udf(func, return_type, argspec, 
runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
-        return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, 
return_type)
+        udf = wrap_grouped_map_pandas_udf_with_state(func, return_type)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when 
wrapping it
-        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, 
argspec, runner_conf)
+        udf = wrap_cogrouped_map_pandas_udf(func, return_type, argspec, 
runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
+        udf = wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, 
runner_conf, udf_index)
+        udf = wrap_window_agg_pandas_udf(func, return_type, runner_conf, 
udf_index)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
-        return arg_offsets, wrap_udf(func, return_type)
+        udf = wrap_udf(func, return_type)
     else:
         raise ValueError("Unknown eval type: {}".format(eval_type))
+    return args_offsets, kwargs_offsets, udf
 
 
 # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and 
SQL_ARROW_BATCHED_UDF when
@@ -984,7 +1018,9 @@ def read_udfs(pickleSer, infile, eval_type):
         if is_map_arrow_iter:
             assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
 
-        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, _, udf = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0
+        )
 
         def func(_, iterator):
             num_input_rows = 0
@@ -1074,7 +1110,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         # Create function like this:
@@ -1091,7 +1127,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are 
used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
         def mapper(a):
@@ -1125,7 +1161,7 @@ def read_udfs(pickleSer, infile, eval_type):
         # We assume there is only one UDF here because cogrouped map doesn't
         # support combining multiple UDFs.
         assert num_udfs == 1
-        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
+        arg_offsets, _, f = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
 
         parsed_offsets = extract_key_value_indexes(arg_offsets)
 
@@ -1142,7 +1178,10 @@ def read_udfs(pickleSer, infile, eval_type):
             udfs.append(read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=i))
 
         def mapper(a):
-            result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, 
f) in udfs)
+            result = tuple(
+                f(*[a[o] for o in args_offsets], **{k: a[o] for k, o in 
kwargs_offsets.items()})
+                for args_offsets, kwargs_offsets, f in udfs
+            )
             # In the special case of a single UDF this will return a single 
result rather
             # than a tuple of results; this is the format that the JVM side 
expects.
             if len(result) == 1:
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index a60d0beeeed..9fde1814079 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -105,6 +105,10 @@ class ApplyInPandasWithStatePythonRunner(
 
   private val stateRowDeserializer = stateEncoder.createDeserializer()
 
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+  }
+
   /**
    * This method sends out the additional metadata before sending out actual 
data.
    *
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 7db43a34a88..bd91da3bc0f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -94,11 +95,11 @@ class ArrowEvalPythonEvaluatorFactory(
     pythonRunnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
-    extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
+  extends EvalPythonEvaluatorFactory(childOutput, udfs, output) {
 
   override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]],
+      argMetas: Array[Array[ArgumentMetadata]],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[InternalRow] = {
@@ -108,10 +109,10 @@ class ArrowEvalPythonEvaluatorFactory(
     // DO NOT use iter.grouped(). See BatchIterator.
     val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else 
Iterator(iter)
 
-    val columnarBatchIter = new ArrowPythonRunner(
+    val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner(
       funcs,
       evalType,
-      argOffsets,
+      argMetas,
       schema,
       sessionLocalTimeZone,
       largeVarTypes,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 8ebd8a3a106..9d5bac0c600 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -23,7 +23,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 0f26d8f21f8..251e682c9e1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -17,17 +17,17 @@
 
 package org.apache.spark.sql.execution.python
 
+import java.io.DataOutputStream
+
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
-/**
- * Similar to `PythonUDFRunner`, but exchange data with Python worker via 
Arrow stream.
- */
-class ArrowPythonRunner(
+abstract class BaseArrowPythonRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
     argOffsets: Array[Array[Int]],
@@ -61,6 +61,49 @@ class ArrowPythonRunner(
       s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
 }
 
+/**
+ * Similar to `PythonUDFRunner`, but exchange data with Python worker via 
Arrow stream.
+ */
+class ArrowPythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    _schema: StructType,
+    _timeZoneId: String,
+    largeVarTypes: Boolean,
+    workerConf: Map[String, String],
+    pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String])
+  extends BaseArrowPythonRunner(
+    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, 
workerConf,
+    pythonMetrics, jobArtifactUUID) {
+
+  override protected def writeUDF(dataOut: DataOutputStream): Unit =
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+}
+
+/**
+ * Similar to `PythonUDFWithNamedArgumentsRunner`, but exchange data with 
Python worker
+ * via Arrow stream.
+ */
+class ArrowPythonWithNamedArgumentRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argMetas: Array[Array[ArgumentMetadata]],
+    _schema: StructType,
+    _timeZoneId: String,
+    largeVarTypes: Boolean,
+    workerConf: Map[String, String],
+    pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String])
+  extends BaseArrowPythonRunner(
+    funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, 
largeVarTypes, workerConf,
+    pythonMetrics, jobArtifactUUID) {
+
+  override protected def writeUDF(dataOut: DataOutputStream): Unit =
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
+}
+
 object ArrowPythonRunner {
   /** Return Map with conf settings to be used in ArrowPythonRunner */
   def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index c0fa8b58bee..690947b4129 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -23,7 +23,7 @@ import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.PythonUDTF
 import org.apache.spark.sql.execution.metric.SQLMetric
-import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -42,15 +42,12 @@ class ArrowPythonUDTFRunner(
     val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
-      Seq(ChainedPythonFunctions(Seq(udtf.func))),
-      evalType, Array(argMetas.map(_.offset)), jobArtifactUUID)
+      Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, 
Array(argMetas.map(_.offset)),
+      jobArtifactUUID)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput {
 
-  override protected def writeUDF(
-      dataOut: DataOutputStream,
-      funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]]): Unit = {
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
     PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 1de8f55d84b..a0e7789b281 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
@@ -60,7 +61,7 @@ class BatchEvalPythonEvaluatorFactory(
 
   override def evaluate(
       funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]],
+      argMetas: Array[Array[ArgumentMetadata]],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[InternalRow] = {
@@ -71,8 +72,8 @@ class BatchEvalPythonEvaluatorFactory(
 
     // Output iterator for results from Python.
     val outputIterator =
-      new PythonUDFRunner(
-        funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics, 
jobArtifactUUID)
+      new PythonUDFWithNamedArgumentsRunner(
+        funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, 
jobArtifactUUID)
       .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index 46cc0f2ab50..342e0723194 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -23,14 +23,14 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.Unpickler
 
-import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, 
PythonWorker, PythonWorkerUtils}
+import org.apache.spark.{JobArtifactSet, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, 
PythonWorkerUtils}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
-import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -101,18 +101,8 @@ class PythonUDTFRunner(
     Seq(ChainedPythonFunctions(Seq(udtf.func))),
     PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), 
pythonMetrics, jobArtifactUUID) {
 
-  protected override def newWriter(
-      env: SparkEnv,
-      worker: PythonWorker,
-      inputIterator: Iterator[Array[Byte]],
-      partitionIndex: Int,
-      context: TaskContext): Writer = {
-    new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
-      }
-    }
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
   }
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
index 373e17c0aa3..d5142f58eab 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{PartitionEvaluator, 
PartitionEvaluatorFactory, SparkEnv
 import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -32,11 +33,11 @@ abstract class EvalPythonEvaluatorFactory(
     childOutput: Seq[Attribute],
     udfs: Seq[PythonUDF],
     output: Seq[Attribute])
-    extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
+  extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   protected def evaluate(
       funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]],
+      argMetas: Array[Array[ArgumentMetadata]],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[InternalRow]
@@ -78,14 +79,20 @@ abstract class EvalPythonEvaluatorFactory(
       // flatten all the arguments
       val allInputs = new ArrayBuffer[Expression]
       val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
+      val argMetas = inputs.map { input =>
         input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
+          val (key, value) = e match {
+            case NamedArgumentExpression(key, value) =>
+              (Some(key), value)
+            case _ =>
+              (None, e)
+          }
+          if (allInputs.exists(_.semanticEquals(value))) {
+            ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), 
key)
           } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
+            allInputs += value
+            dataTypes += value.dataType
+            ArgumentMetadata(allInputs.length - 1, key)
           }
         }.toArray
       }.toArray
@@ -102,7 +109,7 @@ abstract class EvalPythonEvaluatorFactory(
       }
 
       val outputRowIterator =
-        evaluate(pyFuncs, argOffsets, projectedRowIter, schema, context)
+        evaluate(pyFuncs, argMetas, projectedRowIter, schema, context)
 
       val joined = new JoinedRow
       val resultProj = UnsafeProjection.create(output, output)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
index 1c8b0f2228f..af6769cfbb9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -22,6 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.UnaryExecNode
 
+object EvalPythonExec {
+  /**
+   * Metadata for arguments of Python UDTF.
+   *
+   * @param offset the offset of the argument
+   * @param name the name of the argument if it's a `NamedArgumentExpression`
+   */
+  case class ArgumentMetadata(offset: Int, name: Option[String])
+}
+
 /**
  * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at 
a time.
  *
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
index 410209e0ada..41a99693443 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
@@ -21,25 +21,15 @@ import java.io.File
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.UnaryExecNode
-import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
-object EvalPythonUDTFExec {
-  /**
-   * Metadata for arguments of Python UDTF.
-   *
-   * @param offset the offset of the argument
-   * @param name the name of the argument if it's a `NamedArgumentExpression`
-   */
-  case class ArgumentMetadata(offset: Int, name: Option[String])
-}
-
 /**
  * A physical plan that evaluates a [[PythonUDTF]], one partition of tuples at 
a time.
  * This is similar to [[EvalPythonExec]].
@@ -66,7 +56,6 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
 
     inputRDD.mapPartitions { iter =>
       val context = TaskContext.get()
-      val contextAwareIterator = new ContextAwareIterator(context, iter)
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
@@ -104,7 +93,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
       // Also keep track of the number rows added to the queue.
       // This is needed to process extra output rows from the `terminate()` 
call of the UDTF.
       var count = 0L
-      val projectedRowIter = contextAwareIterator.map { inputRow =>
+      val projectedRowIter = iter.map { inputRow =>
         queue.add(inputRow.asInstanceOf[UnsafeRow])
         count += 1
         projection(inputRow)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 00ee3a17563..1e075cab922 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -22,7 +22,7 @@ import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
 import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
@@ -54,11 +54,7 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
       dataOut: DataOutputStream,
       inputIterator: Iterator[IN]): Boolean
 
-  protected def writeUDF(
-      dataOut: DataOutputStream,
-      funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]]): Unit =
-    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+  protected def writeUDF(dataOut: DataOutputStream): Unit
 
   protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
     // Write config for the worker as a number of key -> value pairs of strings
@@ -94,9 +90,10 @@ protected def close(): Unit = {
       partitionIndex: Int,
       context: TaskContext): Writer = {
     new Writer(env, worker, inputIterator, partitionIndex, context) {
+
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
         handleMetadataBeforeExec(dataOut)
-        writeUDF(dataOut, funcs, argOffsets)
+        writeUDF(dataOut)
       }
 
       override def writeNextInputToStream(dataOut: DataOutputStream): Boolean 
= {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index bc27ee6919d..b99517f544d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import org.apache.spark._
 import org.apache.spark.api.python._
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -43,24 +44,31 @@ abstract class BasePythonUDFRunner(
 
   override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
 
-  abstract class PythonUDFWriter(
+  protected def writeUDF(dataOut: DataOutputStream): Unit
+
+  protected override def newWriter(
       env: SparkEnv,
       worker: PythonWorker,
       inputIterator: Iterator[Array[Byte]],
       partitionIndex: Int,
-      context: TaskContext)
-    extends Writer(env, worker, inputIterator, partitionIndex, context) {
-
-    override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
-      val startData = dataOut.size()
-      val wroteData = PythonRDD.writeNextElementToStream(inputIterator, 
dataOut)
-      if (!wroteData) {
-        // Reached the end of input.
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      context: TaskContext): Writer = {
+    new Writer(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        writeUDF(dataOut)
+      }
+
+      override def writeNextInputToStream(dataOut: DataOutputStream): Boolean 
= {
+        val startData = dataOut.size()
+        val wroteData = PythonRDD.writeNextElementToStream(inputIterator, 
dataOut)
+        if (!wroteData) {
+          // Reached the end of input.
+          dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+        }
+        val deltaData = dataOut.size() - startData
+        pythonMetrics("pythonDataSent") += deltaData
+        wroteData
       }
-      val deltaData = dataOut.size() - startData
-      pythonMetrics("pythonDataSent") += deltaData
-      wroteData
     }
   }
 
@@ -111,19 +119,22 @@ class PythonUDFRunner(
     jobArtifactUUID: Option[String])
   extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, 
jobArtifactUUID) {
 
-  protected override def newWriter(
-      env: SparkEnv,
-      worker: PythonWorker,
-      inputIterator: Iterator[Array[Byte]],
-      partitionIndex: Int,
-      context: TaskContext): Writer = {
-    new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) {
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+  }
+}
 
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
-      }
+class PythonUDFWithNamedArgumentsRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argMetas: Array[Array[ArgumentMetadata]],
+    pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String])
+  extends BasePythonUDFRunner(
+    funcs, evalType, argMetas.map(_.map(_.offset)), pythonMetrics, 
jobArtifactUUID) {
 
-    }
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
   }
 }
 
@@ -146,4 +157,30 @@ object PythonUDFRunner {
       }
     }
   }
+
+  def writeUDFs(
+      dataOut: DataOutputStream,
+      funcs: Seq[ChainedPythonFunctions],
+      argMetas: Array[Array[ArgumentMetadata]]): Unit = {
+    dataOut.writeInt(funcs.length)
+    funcs.zip(argMetas).foreach { case (chained, metas) =>
+      dataOut.writeInt(metas.length)
+      metas.foreach {
+        case ArgumentMetadata(offset, name) =>
+          dataOut.writeInt(offset)
+          name match {
+            case Some(name) =>
+              dataOut.writeBoolean(true)
+              PythonWorkerUtils.writeUTF(name, dataOut)
+            case _ =>
+              dataOut.writeBoolean(false)
+          }
+      }
+      dataOut.writeInt(chained.funcs.length)
+      chained.funcs.foreach { f =>
+        dataOut.writeInt(f.command.length)
+        dataOut.write(f.command.toArray)
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 38d521c16d5..f576637aa25 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -50,6 +50,19 @@ case class UserDefinedPythonFunction(
     udfDeterministic: Boolean) {
 
   def builder(e: Seq[Expression]): Expression = {
+    if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF
+        || pythonEvalType ==PythonEvalType.SQL_ARROW_BATCHED_UDF
+        || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) {
+      /*
+       * Check if the named arguments:
+       * - don't have duplicated names
+       * - don't contain positional arguments after named arguments
+       */
+      NamedParametersSupport.splitAndCheckNamedArguments(e, name)
+    } else if (e.exists(_.isInstanceOf[NamedArgumentExpression])) {
+      throw QueryCompilationErrors.namedArgumentsNotSupported(name)
+    }
+
     if (pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
       PythonUDAF(name, func, dataType, e, udfDeterministic)
     } else {
@@ -104,7 +117,7 @@ case class UserDefinedPythonTableFunction(
     /*
      * Check if the named arguments:
      * - don't have duplicated names
-     * - don't contain positional arguments
+     * - don't contain positional arguments after named arguments
      */
     NamedParametersSupport.splitAndCheckNamedArguments(exprs, name)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to