This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 69cf80d25f0e  [SPARK-45402][SQL][PYTHON] Add UDTF API for 'eval' and 'terminate' methods to consume previous 'analyze' result
69cf80d25f0e is described below

commit 69cf80d25f0e4ed46ec38a63e063471988c31732
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Wed Oct 11 18:52:06 2023 -0700

    [SPARK-45402][SQL][PYTHON] Add UDTF API for 'eval' and 'terminate' methods to consume previous 'analyze' result

    ### What changes were proposed in this pull request?

    This PR adds a Python UDTF API for the `eval` and `terminate` methods to consume the previous `analyze` result. This also works for subclasses of the `AnalyzeResult` class, allowing the UDTF to return custom state from `analyze` to be consumed later.

    For example, we can now define a UDTF that performs complex initialization in the `analyze` method and then returns the result of that work from the `terminate` method:

    ```
    @dataclass
    class AnalyzeResultWithBuffer(AnalyzeResult):
        buffer: str = ""

    @udtf
    class TestUDTF:
        def __init__(self, analyze_result):
            self._total = 0
            self._buffer = do_complex_initialization(analyze_result.buffer)

        @staticmethod
        def analyze(argument, _):
            return AnalyzeResultWithBuffer(
                schema=StructType()
                    .add("total", IntegerType())
                    .add("buffer", StringType()),
                with_single_partition=True,
                buffer=argument.value,
            )

        def eval(self, argument, row: Row):
            self._total += 1

        def terminate(self):
            yield self._total, self._buffer

    self.spark.udtf.register("test_udtf", TestUDTF)
    ```

    Then the results might look like:

    ```
    sql(
        """
        WITH t AS (
          SELECT id FROM range(1, 21)
        )
        SELECT total, buffer
        FROM test_udtf("abc", TABLE(t))
        """
    ).collect()

    > 20, "complex_initialization_result"
    ```

    ### Why are the changes needed?

    In this way, the UDTF can perform potentially expensive initialization logic in the `analyze` method just once, then reuse the result of that initialization in later method calls, rather than repeating it in `eval`.

    ### Does this PR introduce _any_ user-facing change?

    Yes, see above.

    ### How was this patch tested?

    This PR adds new unit test coverage.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #43204 from dtenedor/prepare-string.
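    The dispatch between the new one-argument constructor and the existing
    zero-argument constructor happens in the Python worker when it creates each
    UDTF instance. Below is a minimal sketch of that fallback behavior,
    simplified from the `python/pyspark/worker.py` change in this patch (the
    names `handler` and `analyze_result` here are illustrative; the real code
    also caches the result of the first attempt):

    ```python
    def construct_udtf(handler, analyze_result):
        # Prefer the new-style constructor that accepts the AnalyzeResult.
        try:
            return handler(analyze_result)
        except TypeError:
            # The UDTF's __init__ takes no arguments; fall back to the
            # original zero-argument constructor.
            return handler()
    ```

    This keeps the change backward compatible: existing UDTFs whose `__init__`
    accepts no arguments continue to work unchanged.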
Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/docs/source/user_guide/sql/python_udtf.rst  | 124 ++++++++++++++++++++-
 python/pyspark/sql/tests/test_udtf.py              |  53 +++++++++
 python/pyspark/sql/udtf.py                         |   5 +-
 python/pyspark/sql/worker/analyze_udtf.py          |   2 +
 python/pyspark/worker.py                           |  34 +++++-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   5 +-
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  20 +++-
 .../execution/python/BatchEvalPythonUDTFExec.scala |   8 ++
 .../python/UserDefinedPythonFunction.scala         |   7 +-
 .../sql-tests/analyzer-results/udtf/udtf.sql.out   |  26 +++--
 .../test/resources/sql-tests/inputs/udtf/udtf.sql  |   9 +-
 .../resources/sql-tests/results/udtf/udtf.sql.out  |  28 +++--
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  |  64 ++++++++++-
 .../sql/execution/python/PythonUDTFSuite.scala     |  42 +++++--
 14 files changed, 374 insertions(+), 53 deletions(-)

diff --git a/python/docs/source/user_guide/sql/python_udtf.rst b/python/docs/source/user_guide/sql/python_udtf.rst
index 74d8eb889861..fb42644dc702 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -50,10 +50,108 @@ To implement a Python UDTF, you first need to define a class implementing the me
         Notes
         -----
-        - This method does not accept any extra arguments. Only the default
-          constructor is supported.
         - You cannot create or reference the Spark session within the UDTF. Any
           attempt to do so will result in a serialization error.
+        - If the below `analyze` method is implemented, it is also possible to define this
+          method as: `__init__(self, analyze_result: AnalyzeResult)`. In this case, the result
+          of the `analyze` method is passed into all future instantiations of this UDTF class.
+          In this way, the UDTF may inspect the schema and metadata of the output table as
+          needed during execution of other methods in this class. Note that it is possible to
+          create a subclass of the `AnalyzeResult` class if desired for purposes of passing
+          custom information generated just once during UDTF analysis to other method calls;
+          this can be especially useful if this initialization is expensive.
+        """
+        ...
+
+    def analyze(self, *args: Any) -> AnalyzeResult:
+        """
+        Computes the output schema of a particular call to this function in response to the
+        arguments provided.
+
+        This method is optional and only needed if the registration of the UDTF did not provide
+        a static output schema to be used for all calls to the function. In this context,
+        `output schema` refers to the ordered list of the names and types of the columns in the
+        function's result table.
+
+        This method accepts zero or more parameters mapping 1:1 with the arguments provided to
+        the particular UDTF call under consideration. Each parameter is an instance of the
+        `AnalyzeArgument` class, which contains fields including the provided argument's data
+        type and value (in the case of literal scalar arguments only). For table arguments, the
+        `is_table` field is set to true and the `data_type` field is a StructType representing
+        the table's column types:
+
+            data_type: DataType
+            value: Optional[Any]
+            is_table: bool
+
+        This method returns an instance of the `AnalyzeResult` class which includes the result
+        table's schema as a StructType. If the UDTF accepts an input table argument, then the
+        `AnalyzeResult` can also include a requested way to partition the rows of the input
+        table across several UDTF calls.
+        If `with_single_partition` is set to True, the query
+        planner will arrange a repartitioning operation from the previous execution stage such
+        that all rows of the input table are consumed by the `eval` method from exactly one
+        instance of the UDTF class. On the other hand, if the `partition_by` list is non-empty,
+        the query planner will arrange a repartitioning such that all rows with each unique
+        combination of values of the partitioning columns are consumed by a separate unique
+        instance of the UDTF class. If `order_by` is non-empty, this specifies the requested
+        ordering of rows within each partition.
+
+            schema: StructType
+            with_single_partition: bool = False
+            partition_by: Sequence[PartitioningColumn] = field(default_factory=tuple)
+            order_by: Sequence[OrderingColumn] = field(default_factory=tuple)
+
+        Examples
+        --------
+        analyze implementation that returns one output column for each word in the input string
+        argument.
+
+        >>> def analyze(self, text: AnalyzeArgument) -> AnalyzeResult:
+        ...     schema = StructType()
+        ...     for index, word in enumerate(text.value.split(" ")):
+        ...         schema = schema.add(f"word_{index}", StringType())
+        ...     return AnalyzeResult(schema=schema)
+
+        Same as above, but using *args to accept the arguments.
+
+        >>> def analyze(self, *args) -> AnalyzeResult:
+        ...     assert len(args) == 1, "This function accepts one argument only"
+        ...     assert args[0].data_type == StringType(), "Only string arguments are supported"
+        ...     text = args[0].value
+        ...     schema = StructType()
+        ...     for index, word in enumerate(text.split(" ")):
+        ...         schema = schema.add(f"word_{index}", StringType())
+        ...     return AnalyzeResult(schema=schema)
+
+        Same as above, but using **kwargs to accept the arguments.
+
+        >>> def analyze(self, **kwargs) -> AnalyzeResult:
+        ...     assert len(kwargs) == 1, "This function accepts one argument only"
+        ...     assert "text" in kwargs, "An argument named 'text' is required"
+        ...     assert kwargs["text"].data_type == StringType(), "Only strings are supported"
+        ...     text = kwargs["text"].value
+        ...     schema = StructType()
+        ...     for index, word in enumerate(text.split(" ")):
+        ...         schema = schema.add(f"word_{index}", StringType())
+        ...     return AnalyzeResult(schema=schema)
+
+        analyze implementation that returns a constant output schema, but adds custom information
+        in the result metadata to be consumed by future __init__ method calls:
+
+        >>> def analyze(self, text: AnalyzeArgument) -> AnalyzeResult:
+        ...     @dataclass
+        ...     class AnalyzeResultWithOtherMetadata(AnalyzeResult):
+        ...         num_words: int = 0
+        ...         num_articles: int = 0
+        ...     words = text.value.split(" ")
+        ...     return AnalyzeResultWithOtherMetadata(
+        ...         schema=StructType()
+        ...             .add("word", StringType())
+        ...             .add("total", IntegerType()),
+        ...         num_words=len(words),
+        ...         num_articles=len([
+        ...             word for word in words
+        ...             if word == 'a' or word == 'an' or word == 'the']))
         """
         ...
@@ -89,7 +187,9 @@ To implement a Python UDTF, you first need to define a class implementing the me
         -----
         - The result of the function must be a tuple representing a single row in the UDTF
           result table.
-        - UDTFs currently do not accept keyword arguments during the function call.
+        - It is also possible for UDTFs to accept the exact arguments expected, along with
+          their types.
+        - UDTFs can instead accept keyword arguments during the function call if needed.

         Examples
         --------
@@ -103,6 +203,24 @@ To implement a Python UDTF, you first need to define a class implementing the me
         >>> def eval(self, x: int, y: int):
         ...     yield (x + y, x - y)
         ...     yield (y + x, y - x)
+
+        Same as above, but using *args to accept the arguments:
+
+        >>> def eval(self, *args):
+        ...     assert len(args) == 2, "This function accepts two integer arguments only"
+        ...     x = args[0]
+        ...     y = args[1]
+        ...     yield (x + y, x - y)
+        ...     yield (y + x, y - x)
+
+        Same as above, but using **kwargs to accept the arguments:
+
+        >>> def eval(self, **kwargs):
+        ...     assert len(kwargs) == 2, "This function accepts two integer arguments only"
+        ...     x = kwargs["x"]
+        ...     y = kwargs["y"]
+        ...     yield (x + y, x - y)
+        ...     yield (y + x, y - x)
         """
         ...
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 9c821f4bde9c..98676bd7be49 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -18,6 +18,7 @@ import os
 import shutil
 import tempfile
 import unittest
+from dataclasses import dataclass
 from typing import Iterator

 from py4j.protocol import Py4JJavaError
@@ -2365,6 +2366,58 @@ class BaseUDTFTestsMixin:
             + [Row(partition_col=42, count=3, total=3, last=None)],
         )

+    def test_udtf_with_prepare_string_from_analyze(self):
+        @dataclass
+        class AnalyzeResultWithBuffer(AnalyzeResult):
+            buffer: str = ""
+
+        @udtf
+        class TestUDTF:
+            def __init__(self, analyze_result=None):
+                self._total = 0
+                if analyze_result is not None:
+                    self._buffer = analyze_result.buffer
+                else:
+                    self._buffer = ""
+
+            @staticmethod
+            def analyze(argument, _):
+                if (
+                    argument.value is None
+                    or argument.is_table
+                    or not isinstance(argument.value, str)
+                    or len(argument.value) == 0
+                ):
+                    raise Exception("The first argument must be a non-empty string")
+                assert argument.data_type == StringType()
+                assert not argument.is_table
+                return AnalyzeResultWithBuffer(
+                    schema=StructType().add("total", IntegerType()).add("buffer", StringType()),
+                    with_single_partition=True,
+                    buffer=argument.value,
+                )
+
+            def eval(self, argument, row: Row):
+                self._total += 1
+
+            def terminate(self):
+                yield self._total, self._buffer
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total, buffer
+                FROM test_udtf("abc", TABLE(t))
+                """
+            ).collect(),
+            [Row(total=20, buffer="abc")],
+        )
+

 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index ba4bac2ffdfa..26ce68111db8 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -85,7 +85,10 @@ class OrderingColumn:
     overrideNullsFirst: Optional[bool] = None


-@dataclass(frozen=True)
+# Note: this class is a "dataclass" for purposes of convenience, but it is not marked "frozen"
+# because the intention is that users may create subclasses of it for purposes of returning custom
+# information from the "analyze" method.
+@dataclass
 class AnalyzeResult:
     """
     The return of Python UDTF's analyze static method.
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 6fb3ca995e5d..a6aa381eb14a 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -126,6 +126,8 @@ def main(infile: IO, outfile: IO) -> None:

     # Return the analyzed schema.
     write_with_length(result.schema.json().encode("utf-8"), outfile)
+    # Return the pickled 'AnalyzeResult' class instance.
+    pickleSer._write_with_length(result, outfile)
     # Return whether the "with single partition" property is requested.
     write_int(1 if result.with_single_partition else 0, outfile)
     # Return the list of partitioning columns, if any.
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3d08f6c4baea..df7dd1bc2f73 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -20,6 +20,7 @@ Worker that receives input from Piped RDD.
 """
 import os
 import sys
+import dataclasses
 import time
 from inspect import getfullargspec
 import json
@@ -666,7 +667,7 @@ def read_udtf(pickleSer, infile, eval_type):
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)

-    # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
+    # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
     num_arg = read_int(infile)
     args_offsets = []
     kwargs_offsets = {}
@@ -679,6 +680,14 @@ def read_udtf(pickleSer, infile, eval_type):
             args_offsets.append(offset)
     num_partition_child_indexes = read_int(infile)
     partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
+    has_pickled_analyze_result = read_bool(infile)
+    if has_pickled_analyze_result:
+        pickled_analyze_result = pickleSer._read_with_length(infile)
+    else:
+        pickled_analyze_result = None
+    # Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult,
+    # although we may set this to false later if we find otherwise.
+    udtf_init_method_accepts_analyze_result = True
     handler = read_command(pickleSer, infile)
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
@@ -692,6 +701,29 @@ def read_udtf(pickleSer, infile, eval_type):
             f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
         )

+    # Update the handler that creates a new UDTF instance to first try calling the UDTF constructor
+    # with one argument containing the previous AnalyzeResult. If that fails, then try a constructor
+    # with no arguments. In this way each UDTF class instance can decide if it wants to inspect the
+    # AnalyzeResult.
+    if has_pickled_analyze_result:
+        prev_handler = handler
+
+        def construct_udtf():
+            nonlocal udtf_init_method_accepts_analyze_result
+            if not udtf_init_method_accepts_analyze_result:
+                return prev_handler()
+            else:
+                try:
+                    # Here we pass the AnalyzeResult to the UDTF's __init__ method.
+                    return prev_handler(dataclasses.replace(pickled_analyze_result))
+                except TypeError:
+                    # This means that the UDTF handler does not accept an AnalyzeResult object in
+                    # its __init__ method.
+                    udtf_init_method_accepts_analyze_result = False
+                    return prev_handler()
+
+        handler = construct_udtf
+
     class UDTFWithPartitions:
         """
         This implements the logic of a UDTF that accepts an input TABLE argument with one or more
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cc0bfd3fc31b..18a0aec8fc61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2229,8 +2229,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
               analyzeResult.applyToTableArgument(u.name, t)
             case c => c
           }
-          PythonUDTF(u.name, u.func, analyzeResult.schema, newChildren,
-            u.evalType, u.udfDeterministic, u.resultId)
+          PythonUDTF(
+            u.name, u.func, analyzeResult.schema, Some(analyzeResult.pickledAnalyzeResult),
+            newChildren, u.evalType, u.udfDeterministic, u.resultId)
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 539505543a40..f886b50e8a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -159,6 +159,10 @@ abstract class UnevaluableGenerator extends Generator {
  * @param name name of the Python UDTF being called
  * @param func string contents of the Python code in the UDTF, along with other environment state
  * @param elementSchema result schema of the function call
+ * @param pickledAnalyzeResult if the UDTF defined an 'analyze' method, this contains the pickled
+ *                             'AnalyzeResult' instance from that method, which contains all
+ *                             metadata returned including the result schema of the function call as
+ *                             well as optional other information
  * @param children input arguments to the UDTF call; for scalar arguments these are the expressions
  *                 themselves, and for TABLE arguments, these are instances of
  *                 [[FunctionTableSubqueryArgumentExpression]]
@@ -167,15 +171,15 @@ abstract class UnevaluableGenerator extends Generator {
  * @param udfDeterministic true if this function is deterministic wherein it returns the same result
  *                         rows for every call with the same input arguments
  * @param resultId unique expression ID for this function invocation
- * @param pythonUDTFPartitionColumnIndexes holds the indexes of the TABLE argument to the Python
- *                                         UDTF call, if applicable
- * @param analyzeResult holds the result of the polymorphic Python UDTF 'analze' method, if the UDTF
- *                      defined one
+ * @param pythonUDTFPartitionColumnIndexes holds the zero-based indexes of the projected results of
+ *                                         all PARTITION BY expressions within the TABLE argument of
+ *                                         the Python UDTF call, if applicable
  */
 case class PythonUDTF(
     name: String,
     func: PythonFunction,
     elementSchema: StructType,
+    pickledAnalyzeResult: Option[Array[Byte]],
     children: Seq[Expression],
     evalType: Int,
     udfDeterministic: Boolean,
@@ -224,6 +228,7 @@ case class UnresolvedPolymorphicPythonUDTF(
 /**
  * Represents the result of invoking the polymorphic 'analyze' method on a Python user-defined table
  * function. This returns the table function's output schema in addition to other optional metadata.
+ * * @param schema result schema of this particular function call in response to the particular * arguments provided, including the types of any provided scalar arguments (and * their values, in the case of literals) as well as the names and types of columns of @@ -241,12 +246,17 @@ case class UnresolvedPolymorphicPythonUDTF( * @param orderByExpressions if non-empty, this contains the list of ordering items that the * 'analyze' method explicitly indicated that the UDTF call should consume * the input table rows by + * @param pickledAnalyzeResult this is the pickled 'AnalyzeResult' instance from the UDTF, which + * contains all metadata returned by the Python UDTF 'analyze' method + * including the result schema of the function call as well as optional + * other information */ case class PythonUDTFAnalyzeResult( schema: StructType, withSinglePartition: Boolean, partitionByExpressions: Seq[Expression], - orderByExpressions: Seq[SortOrder]) { + orderByExpressions: Seq[SortOrder], + pickledAnalyzeResult: Array[Byte]) { /** * Applies the requested properties from this analysis result to the target TABLE argument * expression of a UDTF call, throwing an error if any properties of the UDTF call are diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index a70d16dc7e89..40993f96e7a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -112,6 +112,7 @@ object PythonUDTFRunner { dataOut: DataOutputStream, udtf: PythonUDTF, argMetas: Array[ArgumentMetadata]): Unit = { + // Write the argument types of the UDTF. dataOut.writeInt(argMetas.length) argMetas.foreach { case ArgumentMetadata(offset, name) => @@ -124,6 +125,8 @@ object PythonUDTFRunner { dataOut.writeBoolean(false) } } + // Write the zero-based indexes of the projected results of all PARTITION BY expressions within + // the TABLE argument of the Python UDTF call, if applicable. udtf.pythonUDTFPartitionColumnIndexes match { case Some(partitionColumnIndexes) => dataOut.writeInt(partitionColumnIndexes.partitionChildIndexes.length) @@ -132,7 +135,12 @@ object PythonUDTFRunner { case None => dataOut.writeInt(0) } + // Write the pickled AnalyzeResult buffer from the UDTF "analyze" method, if any. + dataOut.writeBoolean(udtf.pickledAnalyzeResult.nonEmpty) + udtf.pickledAnalyzeResult.foreach(PythonWorkerUtils.writeBytes(_, dataOut)) + // Write the contents of the Python script itself. PythonWorkerUtils.writePythonFunction(udtf.func, dataOut) + // Write the result schema of the UDTF call. 
PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index b03942cdf43c..d8d3cc9b7fc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -129,6 +129,7 @@ case class UserDefinedPythonTableFunction( name = name, func = func, elementSchema = rt, + pickledAnalyzeResult = None, children = exprs, evalType = pythonEvalType, udfDeterministic = udfDeterministic) @@ -283,6 +284,9 @@ object UserDefinedPythonTableFunction { val schema = DataType.fromJson( PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType] + // Receive the pickled AnalyzeResult buffer, if any. + val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn) + // Receive whether the "with single partition" property is requested. val withSinglePartition = dataIn.readInt() == 1 // Receive the list of requested partitioning columns, if any. @@ -324,7 +328,8 @@ object UserDefinedPythonTableFunction { schema = schema, withSinglePartition = withSinglePartition, partitionByExpressions = partitionByColumns.toSeq, - orderByExpressions = orderBy.toSeq) + orderByExpressions = orderBy.toSeq, + pickledAnalyzeResult = pickledAnalyzeResult) } catch { case eof: EOFException => throw new SparkException("Python worker exited unexpectedly (crashed)", eof) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out index f7b2bada26ec..1b923442207e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out @@ -123,13 +123,19 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query -SELECT * FROM UDTFWithSinglePartition(TABLE(t2)) +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2)) -- !query analysis [Analyzer test output redacted due to nondeterminism] -- !query -SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION) +SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2)) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION) -- !query analysis org.apache.spark.sql.AnalysisException { @@ -144,14 +150,14 @@ org.apache.spark.sql.AnalysisException "objectType" : "", "objectName" : "", "startIndex" : 15, - "stopIndex" : 70, - "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)" + "stopIndex" : 73, + "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)" } ] } -- !query -SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col) -- !query analysis org.apache.spark.sql.AnalysisException { @@ -166,8 +172,8 @@ org.apache.spark.sql.AnalysisException "objectType" : "", "objectName" : "", "startIndex" : 15, - "stopIndex" : 75, - "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + "stopIndex" : 78, + "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)" } ] } @@ -176,7 +182,7 @@ org.apache.spark.sql.AnalysisException SELECT * FROM VALUES (0), (1) AS 
t(col) JOIN LATERAL - UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) + UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col) -- !query analysis org.apache.spark.sql.AnalysisException { @@ -191,8 +197,8 @@ org.apache.spark.sql.AnalysisException "objectType" : "", "objectName" : "", "startIndex" : 66, - "stopIndex" : 126, - "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + "stopIndex" : 129, + "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)" } ] } diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql index 6d49177c4f6a..6d34b91e2f16 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql @@ -47,13 +47,14 @@ SELECT * FROM -- order_by=[ -- OrderingColumn("input"), -- OrderingColumn("partition_col")]) -SELECT * FROM UDTFWithSinglePartition(TABLE(t2)); -SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION); -SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col); +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2)); +SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2)); +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION); +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col); SELECT * FROM VALUES (0), (1) AS t(col) JOIN LATERAL - UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col); + UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col); -- As a reminder, the UDTFPartitionByOrderBy function returns this analyze result: -- AnalyzeResult( -- schema=StructType() diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out index a93aac945015..11295c43d8cb 100644 --- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out @@ -161,7 +161,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException -- !query -SELECT * FROM UDTFWithSinglePartition(TABLE(t2)) +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2)) -- !query schema struct<count:int,total:int,last:int> -- !query output @@ -169,7 +169,15 @@ struct<count:int,total:int,last:int> -- !query -SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION) +SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2)) +-- !query schema +struct<count:int,total:int,last:int> +-- !query output +3 6 3 + + +-- !query +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION) -- !query schema struct<> -- !query output @@ -186,14 +194,14 @@ org.apache.spark.sql.AnalysisException "objectType" : "", "objectName" : "", "startIndex" : 15, - "stopIndex" : 70, - "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)" + "stopIndex" : 73, + "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)" } ] } -- !query -SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) +SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col) -- !query schema struct<> -- !query output @@ -210,8 +218,8 @@ org.apache.spark.sql.AnalysisException "objectType" : "", "objectName" : "", "startIndex" : 15, - "stopIndex" : 75, - "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + "stopIndex" : 78, + "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY 
partition_col)" } ] } @@ -220,7 +228,7 @@ org.apache.spark.sql.AnalysisException SELECT * FROM VALUES (0), (1) AS t(col) JOIN LATERAL - UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col) + UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col) -- !query schema struct<> -- !query output @@ -237,8 +245,8 @@ org.apache.spark.sql.AnalysisException "objectType" : "", "objectName" : "", "startIndex" : 66, - "stopIndex" : 126, - "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)" + "stopIndex" : 129, + "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)" } ] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index ef4606b70cae..3c30c414f81f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -524,8 +524,15 @@ object IntegratedUDFTestUtils extends SQLHelper { val name: String = "UDTFWithSinglePartition" val pythonScript: String = s""" + |import json + |from dataclasses import dataclass |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn |from pyspark.sql.types import IntegerType, Row, StructType + | + |@dataclass + |class AnalyzeResultWithBuffer(AnalyzeResult): + | buffer: str = "" + | |class $name: | def __init__(self): | self._count = 0 @@ -533,8 +540,14 @@ object IntegratedUDFTestUtils extends SQLHelper { | self._last = None | | @staticmethod - | def analyze(self): - | return AnalyzeResult( + | def analyze(initial_count, input_table): + | buffer = "" + | if initial_count.value is not None: + | assert(not initial_count.is_table) + | assert(initial_count.data_type == IntegerType()) + | count = initial_count.value + | buffer = json.dumps({"initial_count": count}) + | return AnalyzeResultWithBuffer( | schema=StructType() | .add("count", IntegerType()) | .add("total", IntegerType()) @@ -542,9 +555,10 @@ object IntegratedUDFTestUtils extends SQLHelper { | with_single_partition=True, | order_by=[ | OrderingColumn("input"), - | OrderingColumn("partition_col")]) + | OrderingColumn("partition_col")], + | buffer=buffer) | - | def eval(self, row: Row): + | def eval(self, initial_count, row): | self._count += 1 | self._last = row["input"] | self._sum += row["input"] @@ -693,6 +707,48 @@ object IntegratedUDFTestUtils extends SQLHelper { "without a corresponding partitioning table requirement" } + object TestPythonUDTFForwardStateFromAnalyze extends TestUDTF { + val name: String = "TestPythonUDTFForwardStateFromAnalyze" + val pythonScript: String = + s""" + |from dataclasses import dataclass + |from pyspark.sql.functions import AnalyzeResult + |from pyspark.sql.types import StringType, StructType + | + |@dataclass + |class AnalyzeResultWithBuffer(AnalyzeResult): + | buffer: str = "" + | + |class $name: + | def __init__(self, analyze_result): + | self._analyze_result = analyze_result + | + | @staticmethod + | def analyze(argument): + | assert(argument.data_type == StringType()) + | return AnalyzeResultWithBuffer( + | schema=StructType() + | .add("result", StringType()), + | buffer=argument.value) + | + | def eval(self, argument): + | pass + | + | def terminate(self): + | yield self._analyze_result.buffer, + |""".stripMargin + + val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( + name = name, + pythonScript = pythonScript, + returnType = None) + + 
def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = "Python UDTF whose 'analyze' method sets state and reads it later" + } + /** * A Scalar Pandas UDF that takes one column, casts into string, executes the * Python native function, and casts back to the type of input column. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index cdc3ef9e4178..efab685236de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -48,15 +48,15 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { private val pythonUDTFCountSumLast: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - "UDTFCountSumLast", TestPythonUDTFCountSumLast.pythonScript, None) + TestPythonUDTFCountSumLast.name, TestPythonUDTFCountSumLast.pythonScript, None) private val pythonUDTFWithSinglePartition: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - "UDTFWithSinglePartition", TestPythonUDTFWithSinglePartition.pythonScript, None) + TestPythonUDTFWithSinglePartition.name, TestPythonUDTFWithSinglePartition.pythonScript, None) private val pythonUDTFPartitionByOrderBy: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( - "UDTFPartitionByOrderBy", TestPythonUDTFPartitionBy.pythonScript, None) + TestPythonUDTFPartitionBy.name, TestPythonUDTFPartitionBy.pythonScript, None) private val arrowPythonUDTF: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction( @@ -65,6 +65,11 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { Some(returnType), evalType = PythonEvalType.SQL_ARROW_TABLE_UDF) + private val pythonUDTFForwardStateFromAnalyze: UserDefinedPythonTableFunction = + createUserDefinedPythonTableFunction( + TestPythonUDTFForwardStateFromAnalyze.name, + TestPythonUDTFForwardStateFromAnalyze.pythonScript, None) + test("Simple PythonUDTF") { assume(shouldTestPythonUDFs) val df = pythonUDTF(spark, lit(1), lit(2)) @@ -200,14 +205,14 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { stop = 29)) } - spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast) + spark.udtf.registerPython(TestPythonUDTFCountSumLast.name, pythonUDTFCountSumLast) var plan = sql( - """ + s""" |WITH t AS ( | VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input) |) |SELECT count, total, last - |FROM UDTFCountSumLast(TABLE(t) WITH SINGLE PARTITION) + |FROM ${TestPythonUDTFCountSumLast.name}(TABLE(t) WITH SINGLE PARTITION) |ORDER BY 1, 2 |""".stripMargin).queryExecution.analyzed plan.collectFirst { case r: Repartition => r } match { @@ -216,16 +221,16 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { failure(plan) } - spark.udtf.registerPython("UDTFWithSinglePartition", pythonUDTFWithSinglePartition) + spark.udtf.registerPython(TestPythonUDTFWithSinglePartition.name, pythonUDTFWithSinglePartition) plan = sql( - """ + s""" |WITH t AS ( | SELECT id AS partition_col, 1 AS input FROM range(1, 21) | UNION ALL | SELECT id AS partition_col, 2 AS input FROM range(1, 21) |) |SELECT count, total, last - |FROM UDTFWithSinglePartition(TABLE(t)) + |FROM ${TestPythonUDTFWithSinglePartition.name}(0, TABLE(t)) |ORDER BY 1, 2 |""".stripMargin).queryExecution.analyzed plan.collectFirst { case r: Repartition 
=> r } match { @@ -234,16 +239,16 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { failure(plan) } - spark.udtf.registerPython("UDTFPartitionByOrderBy", pythonUDTFPartitionByOrderBy) + spark.udtf.registerPython(TestPythonUDTFPartitionBy.name, pythonUDTFPartitionByOrderBy) plan = sql( - """ + s""" |WITH t AS ( | SELECT id AS partition_col, 1 AS input FROM range(1, 21) | UNION ALL | SELECT id AS partition_col, 2 AS input FROM range(1, 21) |) |SELECT partition_col, count, total, last - |FROM UDTFPartitionByOrderBy(TABLE(t)) + |FROM ${TestPythonUDTFPartitionBy.name}(TABLE(t)) |ORDER BY 1, 2 |""".stripMargin).queryExecution.analyzed plan.collectFirst { case r: RepartitionByExpression => r } match { @@ -345,4 +350,17 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { Literal("abc"))) == Seq(2, 3)) } + + test("SPARK-45402: Add UDTF API for 'analyze' to return a buffer to consume on class creation") { + spark.udtf.registerPython( + TestPythonUDTFForwardStateFromAnalyze.name, + pythonUDTFForwardStateFromAnalyze) + withTable("t") { + sql("create table t(col array<int>) using parquet") + val query = s"select * from ${TestPythonUDTFForwardStateFromAnalyze.name}('abc')" + checkAnswer( + sql(query), + Row("abc")) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org