dtenedor commented on code in PR #41948: URL: https://github.com/apache/spark/pull/41948#discussion_r1264230197
########## python/pyspark/sql/worker/analyze_udtf.py: ########## @@ -0,0 +1,140 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +import os +import sys +import traceback +from typing import Any, Dict, List, IO + +from pyspark.errors import PySparkRuntimeError, PySparkValueError +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_bool, + read_int, + write_int, + write_with_length, + SpecialLengths, +) +from pyspark.sql.types import StructType, _parse_datatype_json_string +from pyspark.util import try_simplify_traceback +from pyspark.worker import check_python_version, read_command, pickleSer, utf8_deserializer + + +def read_udtf(infile: IO) -> type: + """Reads the Python UDTF and checks if its valid or not.""" + # Receive Python UDTF + handler = read_command(pickleSer, infile) + if not isinstance(handler, type): + raise PySparkRuntimeError( + f"Invalid UDTF handler type. Expected a class (type 'type'), but " + f"got an instance of {type(handler).__name__}." 
+ ) + + if not hasattr(handler, "analyze") or not isinstance( + inspect.getattr_static(handler, "analyze"), staticmethod + ): + raise PySparkRuntimeError( + "Failed to execute the user defined table function because it has not " + "implemented the 'analyze' static method. " Review Comment: ```suggestion "implemented the 'analyze' static method or specified a fixed " "return type during registration time. " ``` ########## python/pyspark/sql/functions.py: ########## @@ -50,6 +50,7 @@ # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401 from pyspark.sql.udtf import UserDefinedTableFunction, _create_py_udtf +from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401 Review Comment: Please sort these imports alphabetically. Also, `# noqa: F401` suppresses the "module imported but unused" lint warning: do we actually need to import `AnalyzeArgument` and `AnalyzeResult` here, or can we skip the import? Or is it needed so that users can import these symbols directly from `functions.py` as well? 
########## python/pyspark/sql/tests/test_udtf.py: ########## @@ -753,6 +773,376 @@ def terminate(self): self.assertIn("Evaluate the input row", cls.eval.__doc__) self.assertIn("Terminate the UDTF", cls.terminate.__doc__) + def test_simple_udtf_with_analyze(self): + class TestUDTF: + @staticmethod + def analyze() -> AnalyzeResult: + return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) + + def eval(self): + yield "hello", "world" + + func = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", func) + + self.assertEqual(func().collect(), [Row(c1="hello", c2="world")]) + self.assertEqual( + self.spark.sql("SELECT * FROM test_udtf()").collect(), [Row(c1="hello", c2="world")] + ) + + def test_udtf_with_analyze(self): + class TestUDTF: + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + assert isinstance(a, AnalyzeArgument) + assert isinstance(a.data_type, DataType) + assert a.value is not None + assert a.is_table is False + return AnalyzeResult(StructType().add("a", a.data_type)) + + def eval(self, a): + yield a, + + func = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", func) + + for i, (df, expected_schema, expected_results) in enumerate( + [ + (func(lit(1)), StructType().add("a", IntegerType()), [Row(a=1)]), + # another data type + (func(lit("x")), StructType().add("a", StringType()), [Row(a="x")]), + # array type + ( + func(array(lit(1), lit(2), lit(3))), + StructType().add("a", ArrayType(IntegerType(), containsNull=False)), + [Row(a=[1, 2, 3])], + ), + # map type + ( + func(create_map(lit("x"), lit(1), lit("y"), lit(2))), + StructType().add( + "a", MapType(StringType(), IntegerType(), valueContainsNull=False) + ), + [Row(a={"x": 1, "y": 2})], + ), + # struct type + ( + func(named_struct(lit("x"), lit(1), lit("y"), lit(2))), + StructType().add( + "a", + StructType() + .add("x", IntegerType(), nullable=False) + .add("y", IntegerType(), nullable=False), + ), + [Row(a=Row(x=1, y=2))], + ), + # use SQL + ( + 
self.spark.sql("SELECT * from test_udtf(1)"), + StructType().add("a", IntegerType()), + [Row(a=1)], + ), + ] + ): + with self.subTest(query_no=i): + self.assertEqual(df.schema, expected_schema) + self.assertEqual(df.collect(), expected_results) + + def test_udtf_with_analyze_decorator(self): + @udtf + class TestUDTF: + @staticmethod + def analyze() -> AnalyzeResult: + return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) + + def eval(self): + yield "hello", "world" + + self.spark.udtf.register("test_udtf", TestUDTF) + + self.assertEqual(TestUDTF().collect(), [Row(c1="hello", c2="world")]) + self.assertEqual( + self.spark.sql("SELECT * FROM test_udtf()").collect(), [Row(c1="hello", c2="world")] + ) + + def test_udtf_with_analyze_decorator_parens(self): + @udtf() + class TestUDTF: + @staticmethod + def analyze() -> AnalyzeResult: + return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) + + def eval(self): + yield "hello", "world" + + self.spark.udtf.register("test_udtf", TestUDTF) + + self.assertEqual(TestUDTF().collect(), [Row(c1="hello", c2="world")]) + self.assertEqual( + self.spark.sql("SELECT * FROM test_udtf()").collect(), [Row(c1="hello", c2="world")] + ) + + def test_udtf_with_analyze_multiple_arguments(self): + class TestUDTF: + @staticmethod + def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult(StructType().add("a", a.data_type).add("b", b.data_type)) + + def eval(self, a, b): + yield a, b + + func = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", func) + + for i, (df, expected_schema, expected_results) in enumerate( + [ + ( + func(lit(1), lit("x")), + StructType().add("a", IntegerType()).add("b", StringType()), + [Row(a=1, b="x")], + ), + ( + self.spark.sql("SELECT * FROM test_udtf(1, 'x')"), + StructType().add("a", IntegerType()).add("b", StringType()), + [Row(a=1, b="x")], + ), + ] + ): + with self.subTest(query_no=i): + self.assertEqual(df.schema, 
expected_schema) + self.assertEqual(df.collect(), expected_results) + + def test_udtf_with_analyze_table_argument(self): + class TestUDTF: + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + assert isinstance(a, AnalyzeArgument) + assert isinstance(a.data_type, StructType) + assert a.value is None + assert a.is_table is True + return AnalyzeResult(StructType().add("a", a.data_type[0].dataType)) + + def eval(self, a: Row): + if a["id"] > 5: + yield a["id"], + + func = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", func) + + df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))") + self.assertEqual(df.schema, StructType().add("a", LongType())) + self.assertEqual(df.collect(), [Row(a=6), Row(a=7)]) + + def test_udtf_with_analyze_table_argument_adding_columns(self): + class TestUDTF: + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + assert isinstance(a.data_type, StructType) + assert a.is_table is True + return AnalyzeResult(a.data_type.add("is_even", BooleanType())) + + def eval(self, a: Row): + yield a["id"], a["id"] % 2 == 0 + + func = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", func) + + df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 4)))") + self.assertEqual( + df.schema, + StructType().add("id", LongType(), nullable=False).add("is_even", BooleanType()), + ) + self.assertEqual( + df.collect(), + [ + Row(a=0, is_even=True), + Row(a=1, is_even=False), + Row(a=2, is_even=True), + Row(a=3, is_even=False), + ], + ) + + def test_udtf_with_analyze_table_argument_repeating_rows(self): + class TestUDTF: + @staticmethod + def analyze(n, row) -> AnalyzeResult: + if n.value is None or not isinstance(n.value, int) or (n.value < 1 or n.value > 10): + raise Exception("The first argument must be a scalar integer between 1 and 10") + + if row.is_table is False: + raise Exception("The second argument must be a table argument") + + assert isinstance(row.data_type, 
StructType) + return AnalyzeResult(row.data_type) + + def eval(self, n: int, row: Row): + for _ in range(n): + yield row + + func = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", func) + + df1 = self.spark.sql("SELECT * FROM test_udtf(2, TABLE (SELECT id FROM range(0, 4)))") + self.assertEqual(df1.schema, StructType().add("id", LongType(), nullable=False)) + self.assertEqual( + df1.collect(), + [Row(a=0), Row(a=0), Row(a=1), Row(a=1), Row(a=2), Row(a=2), Row(a=3), Row(a=3)], + ) + + df2 = self.spark.sql("SELECT * FROM test_udtf(1 + 1, TABLE (SELECT id FROM range(0, 4)))") + self.assertEqual(df2.schema, StructType().add("id", LongType(), nullable=False)) + self.assertEqual( + df2.collect(), + [Row(a=0), Row(a=0), Row(a=1), Row(a=1), Row(a=2), Row(a=2), Row(a=3), Row(a=3)], + ) + + with self.assertRaisesRegex( + AnalysisException, "The first argument must be a scalar integer between 1 and 10" + ): + self.spark.sql( + "SELECT * FROM test_udtf(0, TABLE (SELECT id FROM range(0, 4)))" + ).collect() + + with self.sql_conf( + {"spark.sql.tvf.allowMultipleTableArguments.enabled": True} + ), self.assertRaisesRegex( + AnalysisException, "The first argument must be a scalar integer between 1 and 10" + ): + self.spark.sql( + """ + SELECT * FROM test_udtf( + TABLE (SELECT id FROM range(0, 1)), + TABLE (SELECT id FROM range(0, 4))) + """ + ).collect() + + with self.assertRaisesRegex( + AnalysisException, "The second argument must be a table argument" + ): + self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect() + + def test_udtf_with_both_return_type_and_analyze(self): + class TestUDTF: + @staticmethod + def analyze() -> AnalyzeResult: + return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) + + def eval(self): + yield "hello", "world" + + with self.assertRaises(PySparkAttributeError) as e: + udtf(TestUDTF, returnType="c1: string, c2: string") + + self.check_error( + exception=e.exception, + 
error_class="INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE_STATICMETHOD", + message_parameters={"name": "TestUDTF"}, + ) + + def test_udtf_with_neither_return_type_nor_analyze(self): + class TestUDTF: + def eval(self): + yield "hello", "world" + + with self.assertRaises(PySparkAttributeError) as e: + udtf(TestUDTF) + + self.check_error( + exception=e.exception, + error_class="INVALID_UDTF_RETURN_TYPE", + message_parameters={"name": "TestUDTF"}, + ) + + def test_udtf_with_analyze_non_staticmethod(self): + class TestUDTF: + def analyze(self) -> AnalyzeResult: + return AnalyzeResult(StructType().add("c1", StringType()).add("c2", StringType())) + + def eval(self): + yield "hello", "world" + + with self.assertRaises(PySparkAttributeError) as e: + udtf(TestUDTF) + + self.check_error( + exception=e.exception, + error_class="INVALID_UDTF_RETURN_TYPE", + message_parameters={"name": "TestUDTF"}, + ) + + def test_udtf_with_analyze_returning_non_struct(self): + class TestUDTF: + @staticmethod + def analyze(): + return StringType() + + def eval(self): + yield "hello", "world" + + func = udtf(TestUDTF) + + with self.assertRaisesRegex( + AnalysisException, + "Output of `analyze` static method of Python UDTFs expects " + "a pyspark.sql.udtf.AnalyzeResult but got: <class 'pyspark.sql.types.StringType'>", + ): + func().collect() + + def test_udtf_with_analyze_raising_an_exception(self): + class TestUDTF: + @staticmethod + def analyze() -> AnalyzeResult: + raise Exception("Failed to analyze.") + + def eval(self): + yield "hello", "world" + + func = udtf(TestUDTF) + + with self.assertRaisesRegex(AnalysisException, "Failed to analyze."): + func().collect() + + def test_udtf_with_analyze_null_literal(self): + class TestUDTF: + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult(StructType().add("a", a.data_type)) + + def eval(self, a): + yield a, + + func = udtf(TestUDTF) + + df = func(lit(None)) + self.assertEqual(df.schema, StructType().add("a", 
NullType())) + self.assertEqual(df.collect(), [Row(a=None)]) + + def test_udtf_with_analyze_taking_wrong_number_of_arguments(self): Review Comment: As we discussed offline, it seems possible to write a test UDTF whose `analyze` method accepts variadic positional arguments: `def analyze(*args) -> AnalyzeResult:`. Can we please add a test case for this? Could we also add another test that exercises `def analyze(**kwargs) -> AnalyzeResult`? In the future, we could return a reasonable error message in this case, such as "Failed to evaluate table-valued function XXXXX because keyword arguments are not supported for the 'analyze' static method." ########## sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala: ########## @@ -91,3 +122,122 @@ case class UserDefinedPythonTableFunction( Dataset.ofRows(session, udtf) } } + +object UserDefinedPythonTableFunction { + + private[this] val workerModule = "pyspark.sql.worker.analyze_udtf" + + /** + * Runs the Python UDTF's `analyze` static method. + * + * When the Python UDTF is defined without a static return type, + * the analyze will call this while resolving table-valued functions. Review Comment: ```suggestion * the analyzer will call this while resolving table-valued functions. ``` ########## connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala: ########## @@ -2482,20 +2482,24 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private def createPythonUserDefinedTableFunction( fun: proto.CommonInlineUserDefinedTableFunction): UserDefinedPythonTableFunction = { val udtf = fun.getPythonUdtf - // Currently return type is required for Python UDTFs. - // TODO(SPARK-44380): support `analyze` in Python UDTFs - assert(udtf.hasReturnType) - val returnType = transformDataType(udtf.getReturnType) - if (!returnType.isInstanceOf[StructType]) { - throw InvalidPlanInput( - "Invalid Python user-defined table function return type. 
" + - s"Expect a struct type, but got ${returnType.typeName}.") + val returnType = if (udtf.hasReturnType) { + val returnType = transformDataType(udtf.getReturnType) + + if (!returnType.isInstanceOf[StructType]) { Review Comment: optional: should we do ``` returnType match { case s: StructType => s case _ => throw InvalidPlanInput(...) } ``` ########## python/pyspark/sql/udtf.py: ########## @@ -103,6 +104,14 @@ class VectorizedUDTF: def __init__(self) -> None: self.func = cls() + if hasattr(cls, "analyze") and isinstance( + inspect.getattr_static(cls, "analyze"), staticmethod + ): + + @staticmethod Review Comment: 👍 ########## python/pyspark/sql/udtf.py: ########## @@ -103,6 +104,14 @@ class VectorizedUDTF: def __init__(self) -> None: self.func = cls() + if hasattr(cls, "analyze") and isinstance( + inspect.getattr_static(cls, "analyze"), staticmethod + ): + + @staticmethod Review Comment: 👍 ########## python/pyspark/errors/error_classes.py: ########## @@ -268,6 +268,11 @@ "Eval type for UDF must be <eval_type>." ] }, + "INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE_STATICMETHOD" : { + "message" : [ + "The UDTF '<name>' is invalid. It has both its return type and the required 'analyze' static method. Please make it have one of either the return type or the 'analyze' static method in '<name>' and try again." Review Comment: 👍 this seems helpful to avoid confusion. 
########## python/pyspark/sql/tests/test_udtf.py: ########## @@ -719,6 +730,394 @@ def terminate(self): self.assertIn("Evaluate the input row", cls.eval.__doc__) self.assertIn("Terminate the UDTF", cls.terminate.__doc__) + def test_simple_udtf_with_analyze(self): + class TestUDTF: + @staticmethod + def analyze() -> StructType: + return StructType().add("c1", StringType()).add("c2", StringType()) + + def eval(self): + yield "hello", "world" + + func = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", func) + + self.assertEqual(func().collect(), [Row(c1="hello", c2="world")]) Review Comment: maybe assign these expected result rows to a variable in each case, so we don't have to repeat it again for the SQL case? ########## python/pyspark/sql/functions.py: ########## @@ -15564,6 +15566,36 @@ def udtf( | 1| 2| +---+---+ + UDTF can also have `analyze` static method instead of a static return type: + + The `analyze` static method should take arguments: + + - The number and order of arguments are the same as the UDTF inputs + - Each argument is a :class:`pyspark.sql.udtf.AnalyzeArgument`, containing: + - data_type: DataType + - value: Any: if the argument is foldable; otherwise None + - is_table: bool: True if the argument is TABLE + + and return a :class:`pyspark.sql.udtf.AnalyzeResult`, containing. Review Comment: optional: maybe also add the "import" statement we'd need to include at the top of the file to gain access to `@udtf`, `AnalyzeArgument`, and `AnalyzeResult` symbols? 
########## python/pyspark/sql/connect/udtf.py: ########## @@ -33,8 +33,8 @@ ) from pyspark.sql.connect.types import UnparsedDataType from pyspark.sql.connect.utils import get_python_ver -from pyspark.sql.udtf import UDTFRegistration as PySparkUDTFRegistration -from pyspark.sql.udtf import _validate_udtf_handler +from pyspark.sql.udtf import UDTFRegistration as PySparkUDTFRegistration, _validate_udtf_handler +from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401 Review Comment: F401 is for "Module imported but unused", do we need to import the `AnalyzeArgument` and `AnalyzeResult` here, or can we just skip the import? ########## python/pyspark/errors/error_classes.py: ########## @@ -268,6 +268,11 @@ "Eval type for UDF must be <eval_type>." ] }, + "INVALID_UDTF_BOTH_RETURN_TYPE_AND_ANALYZE_STATICMETHOD" : { + "message" : [ + "The UDTF '<name>' is invalid. It has both its return type and the required 'analyze' static method. Please make it have one of either the return type or the 'analyze' static method in '<name>' and try again." Review Comment: 👍 this seems helpful to avoid confusion. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
