allisonwang-db commented on code in PR #41948:
URL: https://github.com/apache/spark/pull/41948#discussion_r1267359161


##########
core/src/main/scala/org/apache/spark/SparkEnv.scala:
##########
@@ -73,7 +73,10 @@ class SparkEnv (
     val conf: SparkConf) extends Logging {
 
   @volatile private[spark] var isStopped = false
-  private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), 
PythonWorkerFactory]()
+
+  private case class PythonWorkersKey(
+      pythonExec: String, workerModule: String, daemonModule: String, envVars: 
Map[String, String])

Review Comment:
   Nice! This is so much more readable. Nit: can we add a comment to explain 
the parameters here (workerModule and daemonModule) briefly?



##########
common/utils/src/main/resources/error/error-classes.json:
##########
@@ -2446,6 +2446,11 @@
     ],
     "sqlState" : "42P01"
   },
+  "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON" : {
+    "message" : [
+      "The user defined Python table function failed to analyze in Python, 
<msg>"

Review Comment:
   ```suggestion
         "Failed to analyze the Python user defined table function: <msg>"
   ```
   Is it possible to add the function name?



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -748,6 +769,442 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        expected = [Row(c1="hello", c2="world")]
+        self.assertEqual(func().collect(), expected)
+        self.assertEqual(self.spark.sql("SELECT * FROM 
test_udtf()").collect(), expected)
+
+    def test_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                assert isinstance(a, AnalyzeArgument)
+                assert isinstance(a.data_type, DataType)
+                assert a.value is not None
+                assert a.is_table is False
+                return AnalyzeResult(StructType().add("a", a.data_type))
+
+            def eval(self, a):
+                yield a,
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        for i, (df, expected_schema, expected_results) in enumerate(
+            [
+                (func(lit(1)), StructType().add("a", IntegerType()), 
[Row(a=1)]),
+                # another data type
+                (func(lit("x")), StructType().add("a", StringType()), 
[Row(a="x")]),
+                # array type
+                (
+                    func(array(lit(1), lit(2), lit(3))),
+                    StructType().add("a", ArrayType(IntegerType(), 
containsNull=False)),
+                    [Row(a=[1, 2, 3])],
+                ),
+                # map type
+                (
+                    func(create_map(lit("x"), lit(1), lit("y"), lit(2))),
+                    StructType().add(
+                        "a", MapType(StringType(), IntegerType(), 
valueContainsNull=False)
+                    ),
+                    [Row(a={"x": 1, "y": 2})],
+                ),
+                # struct type
+                (
+                    func(named_struct(lit("x"), lit(1), lit("y"), lit(2))),
+                    StructType().add(
+                        "a",
+                        StructType()
+                        .add("x", IntegerType(), nullable=False)
+                        .add("y", IntegerType(), nullable=False),
+                    ),
+                    [Row(a=Row(x=1, y=2))],
+                ),
+                # use SQL
+                (
+                    self.spark.sql("SELECT * from test_udtf(1)"),
+                    StructType().add("a", IntegerType()),
+                    [Row(a=1)],
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)

Review Comment:
   Can we use the new test utility methods (`assertSchemaEqual` and 
`assertDataFrameEqual`) here?



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -748,6 +769,442 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        expected = [Row(c1="hello", c2="world")]
+        self.assertEqual(func().collect(), expected)
+        self.assertEqual(self.spark.sql("SELECT * FROM 
test_udtf()").collect(), expected)
+
+    def test_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                assert isinstance(a, AnalyzeArgument)
+                assert isinstance(a.data_type, DataType)
+                assert a.value is not None
+                assert a.is_table is False
+                return AnalyzeResult(StructType().add("a", a.data_type))
+
+            def eval(self, a):
+                yield a,
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        for i, (df, expected_schema, expected_results) in enumerate(
+            [
+                (func(lit(1)), StructType().add("a", IntegerType()), 
[Row(a=1)]),
+                # another data type
+                (func(lit("x")), StructType().add("a", StringType()), 
[Row(a="x")]),
+                # array type
+                (
+                    func(array(lit(1), lit(2), lit(3))),
+                    StructType().add("a", ArrayType(IntegerType(), 
containsNull=False)),
+                    [Row(a=[1, 2, 3])],
+                ),
+                # map type
+                (
+                    func(create_map(lit("x"), lit(1), lit("y"), lit(2))),
+                    StructType().add(
+                        "a", MapType(StringType(), IntegerType(), 
valueContainsNull=False)
+                    ),
+                    [Row(a={"x": 1, "y": 2})],
+                ),
+                # struct type
+                (
+                    func(named_struct(lit("x"), lit(1), lit("y"), lit(2))),
+                    StructType().add(
+                        "a",
+                        StructType()
+                        .add("x", IntegerType(), nullable=False)
+                        .add("y", IntegerType(), nullable=False),
+                    ),
+                    [Row(a=Row(x=1, y=2))],
+                ),
+                # use SQL
+                (
+                    self.spark.sql("SELECT * from test_udtf(1)"),
+                    StructType().add("a", IntegerType()),
+                    [Row(a=1)],
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)
+                self.assertEqual(df.collect(), expected_results)
+
+    def test_udtf_with_analyze_decorator(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        expected = [Row(c1="hello", c2="world")]
+        self.assertEqual(TestUDTF().collect(), expected)
+        self.assertEqual(self.spark.sql("SELECT * FROM 
test_udtf()").collect(), expected)
+
+    def test_udtf_with_analyze_decorator_parens(self):
+        @udtf()
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        expected = [Row(c1="hello", c2="world")]
+        self.assertEqual(TestUDTF().collect(), expected)
+        self.assertEqual(self.spark.sql("SELECT * FROM 
test_udtf()").collect(), expected)
+
+    def test_udtf_with_analyze_multiple_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> 
AnalyzeResult:
+                return AnalyzeResult(StructType().add("a", 
a.data_type).add("b", b.data_type))
+
+            def eval(self, a, b):
+                yield a, b
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        for i, (df, expected_schema, expected_results) in enumerate(
+            [
+                (
+                    func(lit(1), lit("x")),
+                    StructType().add("a", IntegerType()).add("b", 
StringType()),
+                    [Row(a=1, b="x")],
+                ),
+                (
+                    self.spark.sql("SELECT * FROM test_udtf(1, 'x')"),
+                    StructType().add("a", IntegerType()).add("b", 
StringType()),
+                    [Row(a=1, b="x")],
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)
+                self.assertEqual(df.collect(), expected_results)
+
+    def test_udtf_with_analyze_arbitary_number_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
+                return AnalyzeResult(
+                    StructType([StructField(f"col{i}", a.data_type) for i, a 
in enumerate(args)])
+                )
+
+            def eval(self, *args):
+                yield args
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        for i, (df, expected_schema, expected_results) in enumerate(
+            [
+                (
+                    func(lit(1)),
+                    StructType().add("col0", IntegerType()),
+                    [Row(a=1)],
+                ),
+                (
+                    self.spark.sql("SELECT * FROM test_udtf(1, 'x')"),
+                    StructType().add("col0", IntegerType()).add("col1", 
StringType()),
+                    [Row(a=1, b="x")],
+                ),
+                # TODO(SPARK-44479): Support Python UDTFs with empty schema
+                # (func(), StructType(), [Row()]),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)
+                self.assertEqual(df.collect(), expected_results)
+
+    def test_udtf_with_analyze_table_argument(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                assert isinstance(a, AnalyzeArgument)
+                assert isinstance(a.data_type, StructType)
+                assert a.value is None
+                assert a.is_table is True
+                return AnalyzeResult(StructType().add("a", 
a.data_type[0].dataType))
+
+            def eval(self, a: Row):
+                if a["id"] > 5:
+                    yield a["id"],
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM 
range(0, 8)))")
+        self.assertEqual(df.schema, StructType().add("a", LongType()))
+        self.assertEqual(df.collect(), [Row(a=6), Row(a=7)])
+
+    def test_udtf_with_analyze_table_argument_adding_columns(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                assert isinstance(a.data_type, StructType)
+                assert a.is_table is True
+                return AnalyzeResult(a.data_type.add("is_even", BooleanType()))
+
+            def eval(self, a: Row):
+                yield a["id"], a["id"] % 2 == 0
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM 
range(0, 4)))")
+        self.assertEqual(
+            df.schema,
+            StructType().add("id", LongType(), nullable=False).add("is_even", 
BooleanType()),
+        )
+        self.assertEqual(
+            df.collect(),
+            [
+                Row(a=0, is_even=True),
+                Row(a=1, is_even=False),
+                Row(a=2, is_even=True),
+                Row(a=3, is_even=False),
+            ],
+        )
+
+    def test_udtf_with_analyze_table_argument_repeating_rows(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(n, row) -> AnalyzeResult:
+                if n.value is None or not isinstance(n.value, int) or (n.value 
< 1 or n.value > 10):
+                    raise Exception("The first argument must be a scalar 
integer between 1 and 10")
+
+                if row.is_table is False:
+                    raise Exception("The second argument must be a table 
argument")
+
+                assert isinstance(row.data_type, StructType)
+                return AnalyzeResult(row.data_type)
+
+            def eval(self, n: int, row: Row):
+                for _ in range(n):
+                    yield row
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        expected_schema = StructType().add("id", LongType(), nullable=False)
+        expected_results = [
+            Row(a=0),
+            Row(a=0),
+            Row(a=1),
+            Row(a=1),
+            Row(a=2),
+            Row(a=2),
+            Row(a=3),
+            Row(a=3),
+        ]
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(2, TABLE (SELECT id 
FROM range(0, 4)))"),
+                self.spark.sql(
+                    "SELECT * FROM test_udtf(1 + 1, TABLE (SELECT id FROM 
range(0, 4)))"

Review Comment:
   Since we are accessing `.value` for the analyze argument, what if the input 
value here is not a foldable one? For example, `SELECT * FROM values (0, 1) 
t(c1, c2), lateral test_udtf(c1, ...)`.



##########
python/pyspark/sql/udtf.py:
##########
@@ -27,20 +29,32 @@
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import _to_java_column, _to_seq
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, 
require_minimum_pyarrow_version
-from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.types import DataType, StructType, _parse_datatype_string
 from pyspark.sql.udf import _wrap_function
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.session import SparkSession
 
-__all__ = ["UDTFRegistration"]
+__all__ = ["AnalyzeArgument", "AnalyzeResult", "UDTFRegistration"]
+
+
+@dataclass(frozen=True)
+class AnalyzeArgument:

Review Comment:
   Nit: let's add some comments documenting this class and its fields.
   



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -748,6 +769,442 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        expected = [Row(c1="hello", c2="world")]
+        self.assertEqual(func().collect(), expected)
+        self.assertEqual(self.spark.sql("SELECT * FROM 
test_udtf()").collect(), expected)
+
+    def test_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                assert isinstance(a, AnalyzeArgument)
+                assert isinstance(a.data_type, DataType)
+                assert a.value is not None
+                assert a.is_table is False
+                return AnalyzeResult(StructType().add("a", a.data_type))
+
+            def eval(self, a):
+                yield a,
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        for i, (df, expected_schema, expected_results) in enumerate(
+            [
+                (func(lit(1)), StructType().add("a", IntegerType()), 
[Row(a=1)]),
+                # another data type
+                (func(lit("x")), StructType().add("a", StringType()), 
[Row(a="x")]),
+                # array type
+                (
+                    func(array(lit(1), lit(2), lit(3))),
+                    StructType().add("a", ArrayType(IntegerType(), 
containsNull=False)),
+                    [Row(a=[1, 2, 3])],
+                ),
+                # map type
+                (
+                    func(create_map(lit("x"), lit(1), lit("y"), lit(2))),
+                    StructType().add(
+                        "a", MapType(StringType(), IntegerType(), 
valueContainsNull=False)
+                    ),
+                    [Row(a={"x": 1, "y": 2})],
+                ),
+                # struct type
+                (
+                    func(named_struct(lit("x"), lit(1), lit("y"), lit(2))),
+                    StructType().add(
+                        "a",
+                        StructType()
+                        .add("x", IntegerType(), nullable=False)
+                        .add("y", IntegerType(), nullable=False),
+                    ),
+                    [Row(a=Row(x=1, y=2))],
+                ),
+                # use SQL
+                (
+                    self.spark.sql("SELECT * from test_udtf(1)"),
+                    StructType().add("a", IntegerType()),
+                    [Row(a=1)],
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)
+                self.assertEqual(df.collect(), expected_results)
+
+    def test_udtf_with_analyze_decorator(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        expected = [Row(c1="hello", c2="world")]
+        self.assertEqual(TestUDTF().collect(), expected)
+        self.assertEqual(self.spark.sql("SELECT * FROM 
test_udtf()").collect(), expected)
+
+    def test_udtf_with_analyze_decorator_parens(self):
+        @udtf()
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        expected = [Row(c1="hello", c2="world")]
+        self.assertEqual(TestUDTF().collect(), expected)
+        self.assertEqual(self.spark.sql("SELECT * FROM 
test_udtf()").collect(), expected)
+
+    def test_udtf_with_analyze_multiple_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> 
AnalyzeResult:
+                return AnalyzeResult(StructType().add("a", 
a.data_type).add("b", b.data_type))
+
+            def eval(self, a, b):
+                yield a, b
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        for i, (df, expected_schema, expected_results) in enumerate(
+            [
+                (
+                    func(lit(1), lit("x")),
+                    StructType().add("a", IntegerType()).add("b", 
StringType()),
+                    [Row(a=1, b="x")],
+                ),
+                (
+                    self.spark.sql("SELECT * FROM test_udtf(1, 'x')"),
+                    StructType().add("a", IntegerType()).add("b", 
StringType()),
+                    [Row(a=1, b="x")],
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)
+                self.assertEqual(df.collect(), expected_results)
+
+    def test_udtf_with_analyze_arbitary_number_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
+                return AnalyzeResult(
+                    StructType([StructField(f"col{i}", a.data_type) for i, a 
in enumerate(args)])
+                )
+
+            def eval(self, *args):
+                yield args
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        for i, (df, expected_schema, expected_results) in enumerate(
+            [
+                (
+                    func(lit(1)),
+                    StructType().add("col0", IntegerType()),
+                    [Row(a=1)],
+                ),
+                (
+                    self.spark.sql("SELECT * FROM test_udtf(1, 'x')"),
+                    StructType().add("col0", IntegerType()).add("col1", 
StringType()),
+                    [Row(a=1, b="x")],
+                ),
+                # TODO(SPARK-44479): Support Python UDTFs with empty schema
+                # (func(), StructType(), [Row()]),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)
+                self.assertEqual(df.collect(), expected_results)
+
+    def test_udtf_with_analyze_table_argument(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                assert isinstance(a, AnalyzeArgument)
+                assert isinstance(a.data_type, StructType)
+                assert a.value is None
+                assert a.is_table is True
+                return AnalyzeResult(StructType().add("a", 
a.data_type[0].dataType))
+
+            def eval(self, a: Row):
+                if a["id"] > 5:
+                    yield a["id"],
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM 
range(0, 8)))")
+        self.assertEqual(df.schema, StructType().add("a", LongType()))
+        self.assertEqual(df.collect(), [Row(a=6), Row(a=7)])
+
+    def test_udtf_with_analyze_table_argument_adding_columns(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                assert isinstance(a.data_type, StructType)
+                assert a.is_table is True
+                return AnalyzeResult(a.data_type.add("is_even", BooleanType()))
+
+            def eval(self, a: Row):
+                yield a["id"], a["id"] % 2 == 0
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM 
range(0, 4)))")
+        self.assertEqual(
+            df.schema,
+            StructType().add("id", LongType(), nullable=False).add("is_even", 
BooleanType()),
+        )
+        self.assertEqual(
+            df.collect(),
+            [
+                Row(a=0, is_even=True),
+                Row(a=1, is_even=False),
+                Row(a=2, is_even=True),
+                Row(a=3, is_even=False),
+            ],
+        )
+
+    def test_udtf_with_analyze_table_argument_repeating_rows(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(n, row) -> AnalyzeResult:
+                if n.value is None or not isinstance(n.value, int) or (n.value 
< 1 or n.value > 10):
+                    raise Exception("The first argument must be a scalar 
integer between 1 and 10")
+
+                if row.is_table is False:
+                    raise Exception("The second argument must be a table 
argument")
+
+                assert isinstance(row.data_type, StructType)
+                return AnalyzeResult(row.data_type)
+
+            def eval(self, n: int, row: Row):
+                for _ in range(n):
+                    yield row
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        expected_schema = StructType().add("id", LongType(), nullable=False)
+        expected_results = [
+            Row(a=0),
+            Row(a=0),
+            Row(a=1),
+            Row(a=1),
+            Row(a=2),
+            Row(a=2),
+            Row(a=3),
+            Row(a=3),
+        ]
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(2, TABLE (SELECT id 
FROM range(0, 4)))"),
+                self.spark.sql(
+                    "SELECT * FROM test_udtf(1 + 1, TABLE (SELECT id FROM 
range(0, 4)))"
+                ),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                self.assertEqual(df.schema, expected_schema)
+                self.assertEqual(df.collect(), expected_results)
+
+        with self.assertRaisesRegex(
+            AnalysisException, "The first argument must be a scalar integer 
between 1 and 10"
+        ):
+            self.spark.sql(
+                "SELECT * FROM test_udtf(0, TABLE (SELECT id FROM range(0, 
4)))"
+            ).collect()
+
+        with self.sql_conf(
+            {"spark.sql.tvf.allowMultipleTableArguments.enabled": True}
+        ), self.assertRaisesRegex(
+            AnalysisException, "The first argument must be a scalar integer 
between 1 and 10"
+        ):
+            self.spark.sql(
+                """
+                SELECT * FROM test_udtf(
+                  TABLE (SELECT id FROM range(0, 1)),
+                  TABLE (SELECT id FROM range(0, 4)))
+                """
+            ).collect()
+
+        with self.assertRaisesRegex(
+            AnalysisException, "The second argument must be a table argument"
+        ):
+            self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect()
+
+    def test_udtf_with_both_return_type_and_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF, returnType="c1: string, c2: string")
+
+        self.check_error(
+            exception=e.exception,
+            
error_class="INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE_STATICMETHOD",
+            message_parameters={"name": "TestUDTF"},
+        )
+
+    def test_udtf_with_neither_return_type_nor_analyze(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF)
+
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_RETURN_TYPE",
+            message_parameters={"name": "TestUDTF"},
+        )
+
+    def test_udtf_with_analyze_non_staticmethod(self):
+        class TestUDTF:
+            def analyze(self) -> AnalyzeResult:
+                return AnalyzeResult(StructType().add("c1", 
StringType()).add("c2", StringType()))
+
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF)
+
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_RETURN_TYPE",

Review Comment:
   Instead of throwing this `INVALID_UDTF_RETURN_TYPE` error, shall we check if 
the UDTF class has an `analyze` method and throw an error if it is not static? 
Or we could update the error message for INVALID_UDTF_RETURN_TYPE to suggest 
making `analyze` a staticmethod.



##########
python/pyspark/sql/worker/analyze_udtf.py:
##########
@@ -0,0 +1,142 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import inspect
+import os
+import sys
+import traceback
+from typing import List, IO
+
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_bool,
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.types import _parse_datatype_json_string
+from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
+from pyspark.util import try_simplify_traceback
+from pyspark.worker import check_python_version, read_command, pickleSer, 
utf8_deserializer
+
+
+def read_udtf(infile: IO) -> type:
+    """Reads the Python UDTF and checks if its valid or not."""
+    # Receive Python UDTF
+    handler = read_command(pickleSer, infile)
+    if not isinstance(handler, type):
+        raise PySparkRuntimeError(
+            f"Invalid UDTF handler type. Expected a class (type 'type'), but "
+            f"got an instance of {type(handler).__name__}."
+        )
+
+    if not hasattr(handler, "analyze") or not isinstance(
+        inspect.getattr_static(handler, "analyze"), staticmethod
+    ):
+        raise PySparkRuntimeError(
+            "Failed to execute the user defined table function because it has 
not "
+            "implemented the 'analyze' static method or specified a fixed "
+            "return type during registration time. "
+            "Please add the 'analyze' static method or specify the return 
type, "
+            "and try the query again."
+        )
+    return handler
+
+
+def read_arguments(infile: IO) -> List[AnalyzeArgument]:
+    """Reads the arguments for `analyze` static method."""
+    # Receive arguments
+    num_args = read_int(infile)
+    args: List[AnalyzeArgument] = []
+    for _ in range(num_args):
+        dt = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+        if read_bool(infile):  # is foldable
+            value = pickleSer._read_with_length(infile)
+            if dt.needConversion():
+                value = dt.fromInternal(value)
+        else:
+            value = None

Review Comment:
   Let's add this as a comment in the AnalyzeArgument data class: The value 
field will be None when the input value is not foldable.



##########
python/pyspark/sql/udtf.py:
##########
@@ -27,20 +29,32 @@
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import _to_java_column, _to_seq
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, 
require_minimum_pyarrow_version
-from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.types import DataType, StructType, _parse_datatype_string
 from pyspark.sql.udf import _wrap_function
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.session import SparkSession
 
-__all__ = ["UDTFRegistration"]
+__all__ = ["AnalyzeArgument", "AnalyzeResult", "UDTFRegistration"]
+
+
+@dataclass(frozen=True)
+class AnalyzeArgument:
+    data_type: DataType
+    value: Optional[Any]
+    is_table: bool
+
+
+@dataclass(frozen=True)
+class AnalyzeResult:

Review Comment:
   Ditto: let's add a comment documenting the `AnalyzeResult` data class as well.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala:
##########
@@ -176,6 +177,28 @@ case class PythonUDTF(
     copy(children = newChildren)
 }
 
+/**
+ * A place holder of a Polymorphic Python table-valued function.

Review Comment:
   ```suggestion
    * A placeholder of a polymorphic Python table-valued function.
   ```



##########
python/pyspark/sql/udtf.py:
##########
@@ -103,6 +104,14 @@ class VectorizedUDTF:
         def __init__(self) -> None:
             self.func = cls()
 
+        if hasattr(cls, "analyze") and isinstance(
+            inspect.getattr_static(cls, "analyze"), staticmethod
+        ):
+
+            @staticmethod

Review Comment:
   :+1:



##########
python/pyspark/sql/udtf.py:
##########
@@ -153,6 +175,19 @@ def _validate_udtf_handler(cls: Any) -> None:
             error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": 
cls.__name__}
         )
 
+    has_analyze_staticmethod = hasattr(cls, "analyze") and isinstance(
+        inspect.getattr_static(cls, "analyze"), staticmethod

Review Comment:
   I wonder if we should emit a warning when the UDTF handler class 
contains `analyze`, but it is not a static method.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -63,18 +77,51 @@ case class UserDefinedPythonFunction(
 case class UserDefinedPythonTableFunction(
     name: String,
     func: PythonFunction,
-    returnType: StructType,
+    returnType: Option[StructType],
     pythonEvalType: Int,
     udfDeterministic: Boolean) {
 
-  def builder(e: Seq[Expression]): LogicalPlan = {
-    val udtf = PythonUDTF(
-      name = name,
-      func = func,
-      elementSchema = returnType,
-      children = e,
-      evalType = pythonEvalType,
-      udfDeterministic = udfDeterministic)
+  def this(
+      name: String,
+      func: PythonFunction,
+      returnType: StructType,
+      pythonEvalType: Int,
+      udfDeterministic: Boolean) = {
+    this(name, func, Some(returnType), pythonEvalType, udfDeterministic)
+  }
+
+  def this(
+      name: String,
+      func: PythonFunction,
+      pythonEvalType: Int,
+      udfDeterministic: Boolean) = {
+    this(name, func, None, pythonEvalType, udfDeterministic)
+  }
+
+  def builder(exprs: Seq[Expression]): LogicalPlan = {
+    val udtf = returnType match {
+      case Some(rt) =>
+        PythonUDTF(
+          name = name,
+          func = func,
+          elementSchema = rt,
+          children = exprs,
+          evalType = pythonEvalType,
+          udfDeterministic = udfDeterministic)
+      case _ =>
+        UnresolvedPolymorphicPythonUDTF(
+          name = name,
+          func = func,
+          children = exprs,
+          evalType = pythonEvalType,
+          udfDeterministic = udfDeterministic,
+          resolveElementSchema = 
UserDefinedPythonTableFunction.analyzeInPython(
+            exprs.map {
+              case _: FunctionTableSubqueryArgumentExpression => true
+              case NamedArgumentExpression(_, _: 
FunctionTableSubqueryArgumentExpression) => true
+              case _ => false

Review Comment:
   Can we add a comment here to explain this logic?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to