dtenedor commented on code in PR #41316:
URL: https://github.com/apache/spark/pull/41316#discussion_r1221943311


##########
python/pyspark/sql/functions.py:
##########
@@ -10403,6 +10405,84 @@ def udf(
         return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
+def udtf(
+    cls: Optional[Type] = None,
+    *,
+    returnType: Union[StructType, str],
+) -> Union[UserDefinedTableFunction, functools.partial]:
+    """Creates a user defined table function (UDTF).
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    cls : class
+        the Python user-defined table function handler class.
+    returnType : :class:`pyspark.sql.types.StructType` or str
+        the return type of the user-defined table function. The value can be 
either a
+        :class:`pyspark.sql.types.StructType` object or a DDL-formatted struct 
type string.
+
+    Examples
+    --------
+    Implement the UDTF class.
+    >>> class TestUDTF:
+    ...     def eval(self, *args: Any):
+    ...         yield "hello", "world"
+
+    Create the UDTF
+    >>> from pyspark.sql.functions import udtf
+    >>> test_udtf = udtf(TestUDTF, returnType="c1: string, c2: string")
+
+    Create the UDTF using the decorator
+    >>> @udtf(returnType="c1: int, c2: int")
+    ... class PlusOne:
+    ...     def eval(self, x: int):
+    ...         yield x, x + 1
+
+    Invoke the UDTF
+    >>> test_udtf().show()
+    +-----+-----+
+    |   c1|   c2|
+    +-----+-----+
+    |hello|world|
+    +-----+-----+
+
+    Invoke the UDTF with parameters
+    >>> from pyspark.sql.functions import lit
+    >>> PlusOne(lit(1)).show()
+    +---+---+
+    | c1| c2|
+    +---+---+
+    |  1|  2|
+    +---+---+
+
+    Notes
+    -----
+    User-defined table functions are considered deterministic by default.
+    Use `asNondeterministic()` to mark a function as non-deterministic. E.g.:
+
+    >>> import random
+    >>> class RandomUDTF:
+    ...     def eval(self, a: int):
+    ...         yield a * int(random.random() * 100),
+    >>> random_udtf = udtf(RandomUDTF, returnType="r: 
int").asNondeterministic()
+

Review Comment:
   ```suggestion
   
       Use "yield" to produce one row for the UDTF result relation, as many 
times
       as needed. In the event of a lateral join, each such result row will be
       associated with the most recent input row consumed by the "eval"
method.
       Or, use "return" to produce multiple rows for the UDTF result relation at
       once and immediately end execution of the current "eval" method call.
   
   ```



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -0,0 +1,366 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from typing import Iterator
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.errors import PythonException, AnalysisException
+from pyspark.sql.functions import lit, udtf
+from pyspark.sql.types import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UDTFTestsMixin(ReusedSQLTestCase):
+    def test_simple_udtf(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF, returnType="c1: string, c2: string")
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_yield_single_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1)])
+
+    def test_udtf_yield_multi_cols(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_yield_multi_rows(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield a + 1,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1), Row(a=2)])
+
+    def test_udtf_yield_multi_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
+
+    def test_udtf_decorator(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_registration(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
+        self.assertEqual(
+            df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, 
b=2, c=1)]
+        )
+
+    def test_udtf_with_lateral_join(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int) -> Iterator:
+                yield a, b, a + b
+                yield a, b, a - b
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql(
+            "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL 
testUDTF(a, b) f"
+        )
+        expected = self.spark.createDataFrame(
+            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", 
"c"]
+        )
+        self.assertEqual(df.collect(), expected.collect())
+
+    def test_udtf_eval_with_return_stmt(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                return [(a, a + 1), (b, b + 1)]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+
+    def test_udtf_eval_returning_non_tuple(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):

Review Comment:
   This error message is rather confusing, and this case might come up frequently
for UDTFs that return just one column. Is there a way to check the type of the
value from the "yield" or "return" call and, when it is not already a tuple,
internally wrap it in a one-element tuple? (Feel free to file a Jira and leave a
TODO comment for this if needed.)



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -0,0 +1,366 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from typing import Iterator
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.errors import PythonException, AnalysisException
+from pyspark.sql.functions import lit, udtf
+from pyspark.sql.types import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UDTFTestsMixin(ReusedSQLTestCase):
+    def test_simple_udtf(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF, returnType="c1: string, c2: string")
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_yield_single_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1)])
+
+    def test_udtf_yield_multi_cols(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_yield_multi_rows(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield a + 1,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1), Row(a=2)])
+
+    def test_udtf_yield_multi_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
+
+    def test_udtf_decorator(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_registration(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
+        self.assertEqual(
+            df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, 
b=2, c=1)]
+        )
+
+    def test_udtf_with_lateral_join(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int) -> Iterator:
+                yield a, b, a + b
+                yield a, b, a - b
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql(
+            "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL 
testUDTF(a, b) f"
+        )
+        expected = self.spark.createDataFrame(
+            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", 
"c"]
+        )
+        self.assertEqual(df.collect(), expected.collect())
+
+    def test_udtf_eval_with_return_stmt(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                return [(a, a + 1), (b, b + 1)]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+
+    def test_udtf_eval_returning_non_tuple(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_returning_non_generator(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                return (a,)
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_with_no_return(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                ...
+
+        # TODO(SPARK-43967): Support Python UDTFs with empty return values
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                return
+
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_conditional_return(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                if a > 5:
+                    yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL 
test_udtf(id)").collect(),
+            [Row(id=6, a=6), Row(id=7, a=7)],
+        )
+
+    def test_udtf_with_empty_yield(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield
+
+        with self.assertRaisesRegex(Py4JJavaError, 
"java.lang.NullPointerException"):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_none_output(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield None,
+
+        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
+        df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
+        self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), 
[Row(a=1, b=2)])
+        self.assertEqual(
+            TestUDTF(lit(1)).join(df, "a", "left").collect(), [Row(a=None, 
b=None), Row(a=1, b=2)]
+        )
+
+    def test_udtf_with_none_input(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)])
+        self.spark.udtf.register("testUDTF", TestUDTF)
+        df = self.spark.sql("SELECT * FROM testUDTF(null)")
+        self.assertEqual(df.collect(), [Row(a=None)])
+
+    def test_udtf_with_wrong_num_output(self):
+        err_msg = (
+            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "

Review Comment:
   Can we add a test like this one, but where the number of output values from
the "yield" or "return" doesn't match the output schema?



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -0,0 +1,366 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from typing import Iterator
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.errors import PythonException, AnalysisException
+from pyspark.sql.functions import lit, udtf
+from pyspark.sql.types import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UDTFTestsMixin(ReusedSQLTestCase):
+    def test_simple_udtf(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF, returnType="c1: string, c2: string")
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_yield_single_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1)])
+
+    def test_udtf_yield_multi_cols(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_yield_multi_rows(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield a + 1,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1), Row(a=2)])
+
+    def test_udtf_yield_multi_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
+
+    def test_udtf_decorator(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_registration(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
+        self.assertEqual(
+            df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, 
b=2, c=1)]
+        )
+
+    def test_udtf_with_lateral_join(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int) -> Iterator:
+                yield a, b, a + b
+                yield a, b, a - b
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql(
+            "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL 
testUDTF(a, b) f"
+        )
+        expected = self.spark.createDataFrame(
+            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", 
"c"]
+        )
+        self.assertEqual(df.collect(), expected.collect())
+
+    def test_udtf_eval_with_return_stmt(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                return [(a, a + 1), (b, b + 1)]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+
+    def test_udtf_eval_returning_non_tuple(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_returning_non_generator(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                return (a,)
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_with_no_return(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                ...
+
+        # TODO(SPARK-43967): Support Python UDTFs with empty return values
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                return
+
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_conditional_return(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                if a > 5:
+                    yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL 
test_udtf(id)").collect(),
+            [Row(id=6, a=6), Row(id=7, a=7)],
+        )
+
+    def test_udtf_with_empty_yield(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield
+
+        with self.assertRaisesRegex(Py4JJavaError, 
"java.lang.NullPointerException"):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_none_output(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield None,
+
+        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
+        df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
+        self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), 
[Row(a=1, b=2)])
+        self.assertEqual(
+            TestUDTF(lit(1)).join(df, "a", "left").collect(), [Row(a=None, 
b=None), Row(a=1, b=2)]
+        )
+
+    def test_udtf_with_none_input(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)])
+        self.spark.udtf.register("testUDTF", TestUDTF)
+        df = self.spark.sql("SELECT * FROM testUDTF(null)")
+        self.assertEqual(df.collect(), [Row(a=None)])
+
+    def test_udtf_with_wrong_num_output(self):
+        err_msg = (
+            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
+            + "values required by the schema."
+        )
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_init(self):
+        @udtf(returnType="a: int, b: int, c: string")
+        class TestUDTF:
+            def __init__(self):
+                self.key = "test"
+
+            def eval(self, a: int):
+                yield a, a + 1, self.key
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c="test")])
+
+    def test_udtf_terminate(self):
+        @udtf(returnType="key: string, value: float")
+        class TestUDTF:
+            def __init__(self):
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, x: int):
+                if x > 0:
+                    self._count += 1
+                    self._sum += x
+                    yield "input", float(x)
+
+            def terminate(self):
+                yield "count", float(self._count)
+                yield "avg", self._sum / self._count
+
+        self.assertEqual(
+            TestUDTF(lit(1)).collect(),
+            [Row(key="input", value=1), Row(key="count", value=1.0), 
Row(key="avg", value=1.0)],
+        )
+
+        with self.assertRaisesRegex(PythonException, "division by zero"):
+            TestUDTF(lit(0)).collect()
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+        df = self.spark.sql(
+            "SELECT id, key, value FROM range(0, 10, 1, 2), "
+            "LATERAL test_udtf(id) WHERE key != 'input'"
+        )
+        self.assertEqual(
+            df.collect(),
+            [
+                Row(id=0, key="count", value=4.0),
+                Row(id=0, key="avg", value=2.5),
+                Row(id=0, key="count", value=5.0),
+                Row(id=0, key="avg", value=7.0),
+            ],
+        )
+
+    def test_terminate_with_exceptions(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+            def terminate(self):
+                raise ValueError("terminate error")
+
+        with self.assertRaisesRegex(
+            PythonException, "Failed to terminate the user defined table 
function: terminate error"

Review Comment:
   Can we include the UDTF name in this message? And maybe, instead of "Failed
to terminate", we could say that the UDTF failed in the "terminate" method? E.g.:
   
   ```
   User defined table function TestUDTF encountered an error in the "terminate"
   method: terminate error
   ```



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -0,0 +1,366 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from typing import Iterator
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.errors import PythonException, AnalysisException
+from pyspark.sql.functions import lit, udtf
+from pyspark.sql.types import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UDTFTestsMixin(ReusedSQLTestCase):
+    def test_simple_udtf(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF, returnType="c1: string, c2: string")
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_yield_single_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1)])
+
+    def test_udtf_yield_multi_cols(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_yield_multi_rows(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield a + 1,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1), Row(a=2)])
+
+    def test_udtf_yield_multi_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
+
+    def test_udtf_decorator(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_registration(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
+        self.assertEqual(
+            df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, 
b=2, c=1)]
+        )
+
+    def test_udtf_with_lateral_join(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int) -> Iterator:
+                yield a, b, a + b
+                yield a, b, a - b
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql(
+            "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL 
testUDTF(a, b) f"
+        )
+        expected = self.spark.createDataFrame(
+            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", 
"c"]
+        )
+        self.assertEqual(df.collect(), expected.collect())
+
+    def test_udtf_eval_with_return_stmt(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                return [(a, a + 1), (b, b + 1)]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+
+    def test_udtf_eval_returning_non_tuple(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_returning_non_generator(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                return (a,)
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_with_no_return(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                ...
+
+        # TODO(SPARK-43967): Support Python UDTFs with empty return values
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                return
+
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_conditional_return(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                if a > 5:
+                    yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL 
test_udtf(id)").collect(),
+            [Row(id=6, a=6), Row(id=7, a=7)],
+        )
+
+    def test_udtf_with_empty_yield(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield
+
+        with self.assertRaisesRegex(Py4JJavaError, 
"java.lang.NullPointerException"):

Review Comment:
   Same here: can we catch this and throw a better error message? (Feel free to
leave a TODO and file a Jira to fix it in a follow-up.)



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -0,0 +1,366 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from typing import Iterator
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark.errors import PythonException, AnalysisException
+from pyspark.sql.functions import lit, udtf
+from pyspark.sql.types import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class UDTFTestsMixin(ReusedSQLTestCase):
+    def test_simple_udtf(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF, returnType="c1: string, c2: string")
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_yield_single_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1)])
+
+    def test_udtf_yield_multi_cols(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_yield_multi_rows(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield a + 1,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        rows = func(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1), Row(a=2)])
+
+    def test_udtf_yield_multi_row_col(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), 
Row(a=1, b=2, c=1)])
+
+    def test_udtf_decorator(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2)])
+
+    def test_udtf_registration(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                yield a, b, a + b
+                yield a, b, a - b
+                yield a, b, b - a
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql("SELECT * FROM testUDTF(1, 2)")
+        self.assertEqual(
+            df.collect(), [Row(a=1, b=2, c=3), Row(a=1, b=2, c=-1), Row(a=1, 
b=2, c=1)]
+        )
+
+    def test_udtf_with_lateral_join(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int) -> Iterator:
+                yield a, b, a + b
+                yield a, b, a - b
+
+        func = udtf(TestUDTF, returnType="a: int, b: int, c: int")
+        self.spark.udtf.register("testUDTF", func)
+        df = self.spark.sql(
+            "SELECT f.* FROM values (0, 1), (1, 2) t(a, b), LATERAL 
testUDTF(a, b) f"
+        )
+        expected = self.spark.createDataFrame(
+            [(0, 1, 1), (0, 1, -1), (1, 2, 3), (1, 2, -1)], schema=["a", "b", 
"c"]
+        )
+        self.assertEqual(df.collect(), expected.collect())
+
+    def test_udtf_eval_with_return_stmt(self):
+        class TestUDTF:
+            def eval(self, a: int, b: int):
+                return [(a, a + 1), (b, b + 1)]
+
+        func = udtf(TestUDTF, returnType="a: int, b: int")
+        rows = func(lit(1), lit(2)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)])
+
+    def test_udtf_eval_returning_non_tuple(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_returning_non_generator(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                return (a,)
+
+        func = udtf(TestUDTF, returnType="a: int")
+        with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with 
StructType"):
+            func(lit(1)).collect()
+
+    def test_udtf_eval_with_no_return(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                ...
+
+        # TODO(SPARK-43967): Support Python UDTFs with empty return values
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                return
+
+        with self.assertRaisesRegex(
+            PythonException, "TypeError: 'NoneType' object is not iterable"
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_conditional_return(self):
+        class TestUDTF:
+            def eval(self, a: int):
+                if a > 5:
+                    yield a,
+
+        func = udtf(TestUDTF, returnType="a: int")
+        self.spark.udtf.register("test_udtf", func)
+        self.assertEqual(
+            self.spark.sql("SELECT * FROM range(0, 8) JOIN LATERAL 
test_udtf(id)").collect(),
+            [Row(id=6, a=6), Row(id=7, a=7)],
+        )
+
+    def test_udtf_with_empty_yield(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield
+
+        with self.assertRaisesRegex(Py4JJavaError, 
"java.lang.NullPointerException"):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_none_output(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+                yield None,
+
+        self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
+        df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
+        self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), 
[Row(a=1, b=2)])
+        self.assertEqual(
+            TestUDTF(lit(1)).join(df, "a", "left").collect(), [Row(a=None, 
b=None), Row(a=1, b=2)]
+        )
+
+    def test_udtf_with_none_input(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        self.assertEqual(TestUDTF(lit(None)).collect(), [Row(a=None)])
+        self.spark.udtf.register("testUDTF", TestUDTF)
+        df = self.spark.sql("SELECT * FROM testUDTF(null)")
+        self.assertEqual(df.collect(), [Row(a=None)])
+
+    def test_udtf_with_wrong_num_output(self):
+        err_msg = (
+            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
+            + "values required by the schema."
+        )
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_init(self):
+        @udtf(returnType="a: int, b: int, c: string")
+        class TestUDTF:
+            def __init__(self):
+                self.key = "test"
+
+            def eval(self, a: int):
+                yield a, a + 1, self.key
+
+        rows = TestUDTF(lit(1)).collect()
+        self.assertEqual(rows, [Row(a=1, b=2, c="test")])
+
+    def test_udtf_terminate(self):
+        @udtf(returnType="key: string, value: float")
+        class TestUDTF:
+            def __init__(self):
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, x: int):
+                if x > 0:
+                    self._count += 1
+                    self._sum += x
+                    yield "input", float(x)
+
+            def terminate(self):
+                yield "count", float(self._count)
+                yield "avg", self._sum / self._count
+
+        self.assertEqual(
+            TestUDTF(lit(1)).collect(),
+            [Row(key="input", value=1), Row(key="count", value=1.0), 
Row(key="avg", value=1.0)],
+        )
+
+        with self.assertRaisesRegex(PythonException, "division by zero"):
+            TestUDTF(lit(0)).collect()
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+        df = self.spark.sql(
+            "SELECT id, key, value FROM range(0, 10, 1, 2), "
+            "LATERAL test_udtf(id) WHERE key != 'input'"
+        )
+        self.assertEqual(
+            df.collect(),
+            [
+                Row(id=0, key="count", value=4.0),
+                Row(id=0, key="avg", value=2.5),
+                Row(id=0, key="count", value=5.0),
+                Row(id=0, key="avg", value=7.0),
+            ],
+        )
+
+    def test_terminate_with_exceptions(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+            def terminate(self):
+                raise ValueError("terminate error")
+
+        with self.assertRaisesRegex(
+            PythonException, "Failed to terminate the user defined table 
function: terminate error"
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_nondeterministic_udtf(self):
+        import random
+
+        class RandomUDTF:
+            def eval(self, a: int):
+                yield a * int(random.random() * 100),
+
+        random_udtf = udtf(RandomUDTF, returnType="x: 
int").asNondeterministic()
+        # TODO(SPARK-43966): support non-deterministic UDTFs
+        with self.assertRaisesRegex(AnalysisException, "nondeterministic 
expressions"):
+            random_udtf(lit(1)).collect()
+
+    def test_udtf_with_nondeterministic_input(self):
+        from pyspark.sql.functions import rand
+
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a + 1,
+
+        # TODO(SPARK-43966): support non-deterministic UDTFs
+        with self.assertRaisesRegex(AnalysisException, "nondeterministic 
expressions"):
+            TestUDTF(rand(0) * 100).collect()
+
+    def test_udtf_no_eval(self):
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def run(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(PythonException, "Python UDTF must 
implement the eval method"):

Review Comment:
   Same, can we include the UDTF name in the error message, e.g.
   
   ```
   Failed to execute the user defined table function TestUDTF because
   it has not yet implemented the "eval" method; please add the "eval"
   method and try the query again
   ```



##########
python/pyspark/worker.py:
##########
@@ -456,6 +456,73 @@ def assign_cols_by_name(runner_conf):
     )
 
 
+# Read and process a serialized user-defined table function (UDTF) from a 
socket.
+# It expects the UDTF to be in a specific format and performs various checks to
+# ensure the UDTF is valid. This function also prepares a mapper function for 
applying
+# the UDTF logic to input rows.
+def read_udtf(pickleSer, infile, eval_type):
+    num_udtfs = read_int(infile)
+    if num_udtfs != 1:
+        raise RuntimeError("Got more than 1 UDTF")

Review Comment:
   Can we include the names of the UDTFs in the error message (possibly
   truncating the list if there are too many, to keep the error message from
   getting too long)? Same for L472 below, and L477, L483, L487, L518.



##########
python/pyspark/sql/udtf.py:
##########
@@ -0,0 +1,227 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined table function related classes and functions
+"""
+import sys
+from typing import Type, TYPE_CHECKING, Optional, Union
+
+from py4j.java_gateway import JavaObject
+
+from pyspark.sql.column import _to_java_column, _to_seq
+from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.udf import _wrap_function
+
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName
+    from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.session import SparkSession
+
+__all__ = ["UDTFRegistration"]
+
+
+def _create_udtf(
+    cls: Type,
+    returnType: Union[StructType, str],
+    name: Optional[str] = None,
+    deterministic: bool = True,
+) -> "UserDefinedTableFunction":
+    """Create a Python UDTF."""
+    udtf_obj = UserDefinedTableFunction(
+        cls, returnType=returnType, name=name, deterministic=deterministic
+    )
+    return udtf_obj
+
+
+class UserDefinedTableFunction:
+    """
+    User-defined table function in Python
+
+    .. versionadded:: 3.5.0
+
+    Notes
+    -----
+    The constructor of this class is not supposed to be directly called.
+    Use :meth:`pyspark.sql.functions.udtf` to create this instance.
+
+    This API is evolving.
+    """
+
+    def __init__(
+        self,
+        func: Type,
+        returnType: Union[StructType, str],
+        name: Optional[str] = None,
+        deterministic: bool = True,
+    ):
+        if not isinstance(func, type):
+            raise TypeError(
+                f"Invalid user-defined table function: the function handler "
+                f"must be a class, but got {type(func)}."
+            )
+
+        # TODO(SPARK-43968): add more compile time checks for UDTFs
+
+        self.func = func
+        self._returnType = returnType
+        self._returnType_placeholder: Optional[StructType] = None
+        self._inputTypes_placeholder = None
+        self._judtf_placeholder = None
+        self._name = name or func.__name__
+        self.deterministic = deterministic
+
+    @property
+    def returnType(self) -> StructType:
+        # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted 
string.
+        # This makes sure this is called after SparkContext is initialized.
+        if self._returnType_placeholder is None:
+            if isinstance(self._returnType, StructType):
+                self._returnType_placeholder = self._returnType
+            else:
+                assert isinstance(self._returnType, str)
+                parsed = _parse_datatype_string(self._returnType)
+                if not isinstance(parsed, StructType):
+                    raise TypeError(
+                        f"Invalid function return type string: 
{self._returnType}. "
+                        f"The return type of a UDTF must be a struct type."

Review Comment:
   Same, can we include the UDTF name in the error message?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to