This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 69cf80d25f0e [SPARK-45402][SQL][PYTHON] Add UDTF API for 'eval' and
'terminate' methods to consume previous 'analyze' result
69cf80d25f0e is described below
commit 69cf80d25f0e4ed46ec38a63e063471988c31732
Author: Daniel Tenedorio <[email protected]>
AuthorDate: Wed Oct 11 18:52:06 2023 -0700
[SPARK-45402][SQL][PYTHON] Add UDTF API for 'eval' and 'terminate' methods
to consume previous 'analyze' result
### What changes were proposed in this pull request?
This PR adds a Python UDTF API for the `eval` and `terminate` methods to
consume the previous `analyze` result.
This also works for subclasses of the `AnalyzeResult` class, allowing the
UDTF to return custom state from `analyze` to be consumed later.
For example, we can now define a UDTF that performs complex initialization
in the `analyze` method and then returns the result of that initialization in
the `terminate` method:
```
@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str

@udtf
class TestUDTF:
    def __init__(self, analyze_result):
        self._total = 0
        self._buffer = do_complex_initialization(analyze_result.buffer)

    @staticmethod
    def analyze(argument, _):
        return AnalyzeResultWithBuffer(
            schema=StructType()
                .add("total", IntegerType())
                .add("buffer", StringType()),
            with_single_partition=True,
            buffer=argument.value,
        )

    def eval(self, argument, row: Row):
        self._total += 1

    def terminate(self):
        yield self._total, self._buffer

spark.udtf.register("test_udtf", TestUDTF)
```
Then the results might look like:
```
sql(
    """
    WITH t AS (
        SELECT id FROM range(1, 21)
    )
    SELECT total, buffer
    FROM test_udtf("abc", TABLE(t))
    """
).collect()
> 20, "complex_initialization_result"
```
### Why are the changes needed?
In this way, the UDTF can perform potentially expensive initialization
logic in the `analyze` method just once and return the result of that
initialization rather than repeating it in `eval`.
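A UDTF can adopt this without breaking callers when no `analyze` result is
available: a constructor with a defaulted parameter covers both cases (and the
worker also falls back to the zero-argument constructor if the one-argument
call raises `TypeError`). A minimal sketch, where the `buffer` field assumes
an `AnalyzeResult` subclass like the one above:
```
class BackwardCompatibleUDTF:
    def __init__(self, analyze_result=None):
        # 'analyze_result' is the AnalyzeResult (or subclass) instance returned
        # by 'analyze', or None when no such result is provided.
        self._total = 0
        self._buffer = analyze_result.buffer if analyze_result is not None else ""
```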
### Does this PR introduce _any_ user-facing change?
Yes, see above.
### How was this patch tested?
This PR adds new unit test coverage.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43204 from dtenedor/prepare-string.
Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
python/docs/source/user_guide/sql/python_udtf.rst | 124 ++++++++++++++++++++-
python/pyspark/sql/tests/test_udtf.py | 53 +++++++++
python/pyspark/sql/udtf.py | 5 +-
python/pyspark/sql/worker/analyze_udtf.py | 2 +
python/pyspark/worker.py | 34 +++++-
.../spark/sql/catalyst/analysis/Analyzer.scala | 5 +-
.../spark/sql/catalyst/expressions/PythonUDF.scala | 20 +++-
.../execution/python/BatchEvalPythonUDTFExec.scala | 8 ++
.../python/UserDefinedPythonFunction.scala | 7 +-
.../sql-tests/analyzer-results/udtf/udtf.sql.out | 26 +++--
.../test/resources/sql-tests/inputs/udtf/udtf.sql | 9 +-
.../resources/sql-tests/results/udtf/udtf.sql.out | 28 +++--
.../apache/spark/sql/IntegratedUDFTestUtils.scala | 64 ++++++++++-
.../sql/execution/python/PythonUDTFSuite.scala | 42 +++++--
14 files changed, 374 insertions(+), 53 deletions(-)
diff --git a/python/docs/source/user_guide/sql/python_udtf.rst
b/python/docs/source/user_guide/sql/python_udtf.rst
index 74d8eb889861..fb42644dc702 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -50,10 +50,108 @@ To implement a Python UDTF, you first need to define a
class implementing the me
Notes
-----
- - This method does not accept any extra arguments. Only the default
- constructor is supported.
- You cannot create or reference the Spark session within the
UDTF. Any
attempt to do so will result in a serialization error.
+ - If the `analyze` method below is implemented, it is also possible to define this
+ method as: `__init__(self, analyze_result: AnalyzeResult)`. In
this case, the result
+ of the `analyze` method is passed into all future instantiations
of this UDTF class.
+ In this way, the UDTF may inspect the schema and metadata of the
output table as
+ needed during execution of other methods in this class. Note
that it is possible to
+ create a subclass of the `AnalyzeResult` class if desired for
purposes of passing
+ custom information generated just once during UDTF analysis to
other method calls;
+ this can be especially useful if this initialization is
expensive.
+ """
+ ...
+
+ def analyze(self, *args: Any) -> AnalyzeResult:
+ """
+ Computes the output schema of a particular call to this function
in response to the
+ arguments provided.
+
+ This method is optional and only needed if the registration of the
UDTF did not provide
+ a static output schema to be used for all calls to the function. In
this context,
+ `output schema` refers to the ordered list of the names and types
of the columns in the
+ function's result table.
+
+ This method accepts zero or more parameters mapping 1:1 with the
arguments provided to
+ the particular UDTF call under consideration. Each parameter is an
instance of the
+ `AnalyzeArgument` class, which contains fields including the
provided argument's data
+ type and value (in the case of literal scalar arguments only). For
table arguments, the
+ `is_table` field is set to true and the `data_type` field is a
StructType representing
+ the table's column types:
+
+ data_type: DataType
+ value: Optional[Any]
+ is_table: bool
+
+ This method returns an instance of the `AnalyzeResult` class which
includes the result
+ table's schema as a StructType. If the UDTF accepts an input table
argument, then the
+ `AnalyzeResult` can also include a requested way to partition the
rows of the input
+ table across several UDTF calls. If `with_single_partition` is set
to True, the query
+ planner will arrange a repartitioning operation from the previous
execution stage such
+ that all rows of the input table are consumed by the `eval` method
from exactly one
+ instance of the UDTF class. On the other hand, if the
`partition_by` list is non-empty,
+ the query planner will arrange a repartitioning such that all rows
with each unique
+ combination of values of the partitioning columns are consumed by
a separate unique
+ instance of the UDTF class. If `order_by` is non-empty, this
specifies the requested
+ ordering of rows within each partition.
+
+ schema: StructType
+ with_single_partition: bool = False
+ partition_by: Sequence[PartitioningColumn] =
field(default_factory=tuple)
+ order_by: Sequence[OrderingColumn] =
field(default_factory=tuple)
+
+ Examples
+ --------
+ analyze implementation that returns one output column for each
word in the input string
+ argument.
+
+ >>> def analyze(self, text: str) -> AnalyzeResult:
+ ... schema = StructType()
+ ... for index, word in enumerate(text.split(" ")):
+ ... schema = schema.add(f"word_{index}")
+ ... return AnalyzeResult(schema=schema)
+
+ Same as above, but using *args to accept the arguments.
+
+ >>> def analyze(self, *args) -> AnalyzeResult:
+ ... assert len(args) == 1, "This function accepts one argument
only"
+ ... assert args[0].data_type == StringType(), "Only string
arguments are supported"
+ ... text = args[0].value
+ ... schema = StructType()
+ ... for index, word in enumerate(text.split(" ")):
+ ... schema = schema.add(f"word_{index}")
+ ... return AnalyzeResult(schema=schema)
+
+ Same as above, but using **kwargs to accept the arguments.
+
+ >>> def analyze(self, **kwargs) -> AnalyzeResult:
+ ... assert len(kwargs) == 1, "This function accepts one
argument only"
+ ... assert "text" in kwargs, "An argument named 'text' is
required"
+ ... assert kwargs["text"].data_type == StringType(), "Only
strings are supported"
+ ... text = args["text"]
+ ... schema = StructType()
+ ... for index, word in enumerate(text.split(" ")):
+ ... schema = schema.add(f"word_{index}")
+ ... return AnalyzeResult(schema=schema)
+
+ analyze implementation that returns a constant output schema, but
adds custom information
+ in the result metadata to be consumed by future __init__ method
calls:
+
+ >>> def analyze(self, text: str) -> AnalyzeResult:
+ ... @dataclass
+ ... class AnalyzeResultWithOtherMetadata(AnalyzeResult):
+ ... num_words: int
+ ... num_articles: int
+ ... words = text.split(" ")
+ ... return AnalyzeResultWithOtherMetadata(
+ ... schema=StructType()
+ ... .add("word", StringType())
+ ... .add("total", IntegerType()),
+ ... num_words=len(words),
+ ... num_articles=len([
+ ... word for word in words
+ ... if word == 'a' or word == 'an' or word == 'the']))
"""
...
@@ -89,7 +187,9 @@ To implement a Python UDTF, you first need to define a class
implementing the me
-----
- The result of the function must be a tuple representing a single
row
in the UDTF result table.
- - UDTFs currently do not accept keyword arguments during the
function call.
+ - It is also possible for UDTFs to accept the exact arguments
expected, along with
+ their types.
+ - UDTFs can instead accept keyword arguments during the function
call if needed.
Examples
--------
@@ -103,6 +203,24 @@ To implement a Python UDTF, you first need to define a
class implementing the me
>>> def eval(self, x: int, y: int):
... yield (x + y, x - y)
... yield (y + x, y - x)
+
+ Same as above, but using *args to accept the arguments:
+
+ >>> def eval(self, *args):
+ ... assert len(args) == 2, "This function accepts two integer
arguments only"
+ ... x = args[0]
+ ... y = args[1]
+ ... yield (x + y, x - y)
+ ... yield (y + x, y - x)
+
+ Same as above, but using **kwargs to accept the arguments:
+
+ >>> def eval(self, **kwargs):
+ ... assert len(kwargs) == 2, "This function accepts two
integer arguments only"
+ ... x = kwargs["x"]
+ ... y = kwargs["y"]
+ ... yield (x + y, x - y)
+ ... yield (y + x, y - x)
"""
...
diff --git a/python/pyspark/sql/tests/test_udtf.py
b/python/pyspark/sql/tests/test_udtf.py
index 9c821f4bde9c..98676bd7be49 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -18,6 +18,7 @@ import os
import shutil
import tempfile
import unittest
+from dataclasses import dataclass
from typing import Iterator
from py4j.protocol import Py4JJavaError
@@ -2365,6 +2366,58 @@ class BaseUDTFTestsMixin:
+ [Row(partition_col=42, count=3, total=3, last=None)],
)
+ def test_udtf_with_prepare_string_from_analyze(self):
+ @dataclass
+ class AnalyzeResultWithBuffer(AnalyzeResult):
+ buffer: str = ""
+
+ @udtf
+ class TestUDTF:
+ def __init__(self, analyze_result=None):
+ self._total = 0
+ if analyze_result is not None:
+ self._buffer = analyze_result.buffer
+ else:
+ self._buffer = ""
+
+ @staticmethod
+ def analyze(argument, _):
+ if (
+ argument.value is None
+ or argument.is_table
+ or not isinstance(argument.value, str)
+ or len(argument.value) == 0
+ ):
+ raise Exception("The first argument must be non-empty
string")
+ assert argument.data_type == StringType()
+ assert not argument.is_table
+ return AnalyzeResultWithBuffer(
+ schema=StructType().add("total",
IntegerType()).add("buffer", StringType()),
+ with_single_partition=True,
+ buffer=argument.value,
+ )
+
+ def eval(self, argument, row: Row):
+ self._total += 1
+
+ def terminate(self):
+ yield self._total, self._buffer
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ assertDataFrameEqual(
+ self.spark.sql(
+ """
+ WITH t AS (
+ SELECT id FROM range(1, 21)
+ )
+ SELECT total, buffer
+ FROM test_udtf("abc", TABLE(t))
+ """
+ ).collect(),
+ [Row(total=20, buffer="abc")],
+ )
+
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index ba4bac2ffdfa..26ce68111db8 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -85,7 +85,10 @@ class OrderingColumn:
overrideNullsFirst: Optional[bool] = None
-@dataclass(frozen=True)
+# Note: this class is a "dataclass" for purposes of convenience, but it is not
marked "frozen"
+# because the intention is that users may create subclasses of it for purposes
of returning custom
+# information from the "analyze" method.
+@dataclass
class AnalyzeResult:
"""
The return of Python UDTF's analyze static method.
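Since the base dataclass declares fields with defaults (for example
`with_single_partition: bool = False`), any field added by a subclass generally
needs a default as well, or dataclass inheritance fails with "non-default
argument follows default argument". A minimal sketch of the subclassing
pattern used by the tests in this change:
```
from dataclasses import dataclass

from pyspark.sql.functions import AnalyzeResult

@dataclass
class AnalyzeResultWithBuffer(AnalyzeResult):
    # The default is required because the base dataclass already declares
    # fields with default values.
    buffer: str = ""
```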
diff --git a/python/pyspark/sql/worker/analyze_udtf.py
b/python/pyspark/sql/worker/analyze_udtf.py
index 6fb3ca995e5d..a6aa381eb14a 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -126,6 +126,8 @@ def main(infile: IO, outfile: IO) -> None:
# Return the analyzed schema.
write_with_length(result.schema.json().encode("utf-8"), outfile)
+ # Return the pickled 'AnalyzeResult' class instance.
+ pickleSer._write_with_length(result, outfile)
# Return whether the "with single partition" property is requested.
write_int(1 if result.with_single_partition else 0, outfile)
# Return the list of partitioning columns, if any.
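The `_write_with_length` call frames the pickled `AnalyzeResult` so that the
JVM can read it as an opaque byte buffer and later hand it back to the worker
unchanged. An illustrative sketch of this kind of framing, assuming the
serializer's usual 4-byte big-endian length prefix (not the actual serializer
code):
```
import pickle
import struct

def write_pickled_with_length(obj, outfile):
    # Pickle the object, then write a 4-byte big-endian length prefix
    # followed by the payload bytes.
    payload = pickle.dumps(obj)
    outfile.write(struct.pack("!i", len(payload)))
    outfile.write(payload)

def read_pickled_with_length(infile):
    (length,) = struct.unpack("!i", infile.read(4))
    return pickle.loads(infile.read(length))
```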
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3d08f6c4baea..df7dd1bc2f73 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -20,6 +20,7 @@ Worker that receives input from Piped RDD.
"""
import os
import sys
+import dataclasses
import time
from inspect import getfullargspec
import json
@@ -666,7 +667,7 @@ def read_udtf(pickleSer, infile, eval_type):
# Each row is a group so do not batch but send one by one.
ser = BatchedSerializer(CPickleSerializer(), 1)
- # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
+ # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
num_arg = read_int(infile)
args_offsets = []
kwargs_offsets = {}
@@ -679,6 +680,14 @@ def read_udtf(pickleSer, infile, eval_type):
args_offsets.append(offset)
num_partition_child_indexes = read_int(infile)
partition_child_indexes = [read_int(infile) for i in
range(num_partition_child_indexes)]
+ has_pickled_analyze_result = read_bool(infile)
+ if has_pickled_analyze_result:
+ pickled_analyze_result = pickleSer._read_with_length(infile)
+ else:
+ pickled_analyze_result = None
+ # Initially we assume that the UDTF __init__ method accepts the pickled
AnalyzeResult,
+ # although we may set this to false later if we find otherwise.
+ udtf_init_method_accepts_analyze_result = True
handler = read_command(pickleSer, infile)
if not isinstance(handler, type):
raise PySparkRuntimeError(
@@ -692,6 +701,29 @@ def read_udtf(pickleSer, infile, eval_type):
f"The return type of a UDTF must be a struct type, but got
{type(return_type)}."
)
+ # Update the handler that creates a new UDTF instance to first try calling
the UDTF constructor
+ # with one argument containing the previous AnalyzeResult. If that fails,
then try a constructor
+ # with no arguments. In this way each UDTF class instance can decide if it
wants to inspect the
+ # AnalyzeResult.
+ if has_pickled_analyze_result:
+ prev_handler = handler
+
+ def construct_udtf():
+ nonlocal udtf_init_method_accepts_analyze_result
+ if not udtf_init_method_accepts_analyze_result:
+ return prev_handler()
+ else:
+ try:
+ # Here we pass the AnalyzeResult to the UDTF's __init__
method.
+ return
prev_handler(dataclasses.replace(pickled_analyze_result))
+ except TypeError:
+ # This means that the UDTF handler does not accept an
AnalyzeResult object in
+ # its __init__ method.
+ udtf_init_method_accepts_analyze_result = False
+ return prev_handler()
+
+ handler = construct_udtf
+
class UDTFWithPartitions:
"""
This implements the logic of a UDTF that accepts an input TABLE
argument with one or more
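Note that `construct_udtf` above passes
`dataclasses.replace(pickled_analyze_result)` rather than the unpickled object
itself: called with no field overrides, `dataclasses.replace` returns a shallow
copy, so every UDTF instantiation receives its own `AnalyzeResult` instance
instead of sharing one mutable object. A small illustration:
```
import dataclasses

@dataclasses.dataclass
class Example:
    buffer: str = ""

original = Example(buffer="abc")
copy = dataclasses.replace(original)  # shallow copy with the same field values
assert copy == original and copy is not original
```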
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cc0bfd3fc31b..18a0aec8fc61 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2229,8 +2229,9 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
analyzeResult.applyToTableArgument(u.name, t)
case c => c
}
- PythonUDTF(u.name, u.func, analyzeResult.schema, newChildren,
- u.evalType, u.udfDeterministic, u.resultId)
+ PythonUDTF(
+ u.name, u.func, analyzeResult.schema,
Some(analyzeResult.pickledAnalyzeResult),
+ newChildren, u.evalType, u.udfDeterministic, u.resultId)
}
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 539505543a40..f886b50e8a23 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -159,6 +159,10 @@ abstract class UnevaluableGenerator extends Generator {
* @param name name of the Python UDTF being called
* @param func string contents of the Python code in the UDTF, along with
other environment state
* @param elementSchema result schema of the function call
+ * @param pickledAnalyzeResult if the UDTF defined an 'analyze' method, this
contains the pickled
+ * 'AnalyzeResult' instance from that method,
which contains all
+ * metadata returned including the result schema
of the function call as
+ * well as optional other information
* @param children input arguments to the UDTF call; for scalar arguments
these are the expressions
* themselves, and for TABLE arguments, these are instances of
* [[FunctionTableSubqueryArgumentExpression]]
@@ -167,15 +171,15 @@ abstract class UnevaluableGenerator extends Generator {
* @param udfDeterministic true if this function is deterministic wherein it
returns the same result
* rows for every call with the same input arguments
* @param resultId unique expression ID for this function invocation
- * @param pythonUDTFPartitionColumnIndexes holds the indexes of the TABLE
argument to the Python
- * UDTF call, if applicable
- * @param analyzeResult holds the result of the polymorphic Python UDTF
'analze' method, if the UDTF
- * defined one
+ * @param pythonUDTFPartitionColumnIndexes holds the zero-based indexes of the
projected results of
+ * all PARTITION BY expressions within
the TABLE argument of
+ * the Python UDTF call, if applicable
*/
case class PythonUDTF(
name: String,
func: PythonFunction,
elementSchema: StructType,
+ pickledAnalyzeResult: Option[Array[Byte]],
children: Seq[Expression],
evalType: Int,
udfDeterministic: Boolean,
@@ -224,6 +228,7 @@ case class UnresolvedPolymorphicPythonUDTF(
/**
* Represents the result of invoking the polymorphic 'analyze' method on a
Python user-defined table
* function. This returns the table function's output schema in addition to
other optional metadata.
+ *
* @param schema result schema of this particular function call in response to
the particular
* arguments provided, including the types of any provided
scalar arguments (and
* their values, in the case of literals) as well as the names
and types of columns of
@@ -241,12 +246,17 @@ case class UnresolvedPolymorphicPythonUDTF(
* @param orderByExpressions if non-empty, this contains the list of ordering
items that the
* 'analyze' method explicitly indicated that the
UDTF call should consume
* the input table rows by
+ * @param pickledAnalyzeResult this is the pickled 'AnalyzeResult' instance
from the UDTF, which
+ * contains all metadata returned by the Python
UDTF 'analyze' method
+ * including the result schema of the function
call as well as optional
+ * other information
*/
case class PythonUDTFAnalyzeResult(
schema: StructType,
withSinglePartition: Boolean,
partitionByExpressions: Seq[Expression],
- orderByExpressions: Seq[SortOrder]) {
+ orderByExpressions: Seq[SortOrder],
+ pickledAnalyzeResult: Array[Byte]) {
/**
* Applies the requested properties from this analysis result to the target
TABLE argument
* expression of a UDTF call, throwing an error if any properties of the
UDTF call are
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index a70d16dc7e89..40993f96e7a0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -112,6 +112,7 @@ object PythonUDTFRunner {
dataOut: DataOutputStream,
udtf: PythonUDTF,
argMetas: Array[ArgumentMetadata]): Unit = {
+ // Write the argument types of the UDTF.
dataOut.writeInt(argMetas.length)
argMetas.foreach {
case ArgumentMetadata(offset, name) =>
@@ -124,6 +125,8 @@ object PythonUDTFRunner {
dataOut.writeBoolean(false)
}
}
+ // Write the zero-based indexes of the projected results of all PARTITION
BY expressions within
+ // the TABLE argument of the Python UDTF call, if applicable.
udtf.pythonUDTFPartitionColumnIndexes match {
case Some(partitionColumnIndexes) =>
dataOut.writeInt(partitionColumnIndexes.partitionChildIndexes.length)
@@ -132,7 +135,12 @@ object PythonUDTFRunner {
case None =>
dataOut.writeInt(0)
}
+ // Write the pickled AnalyzeResult buffer from the UDTF "analyze" method,
if any.
+ dataOut.writeBoolean(udtf.pickledAnalyzeResult.nonEmpty)
+ udtf.pickledAnalyzeResult.foreach(PythonWorkerUtils.writeBytes(_, dataOut))
+ // Write the contents of the Python script itself.
PythonWorkerUtils.writePythonFunction(udtf.func, dataOut)
+ // Write the result schema of the UDTF call.
PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index b03942cdf43c..d8d3cc9b7fc4 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -129,6 +129,7 @@ case class UserDefinedPythonTableFunction(
name = name,
func = func,
elementSchema = rt,
+ pickledAnalyzeResult = None,
children = exprs,
evalType = pythonEvalType,
udfDeterministic = udfDeterministic)
@@ -283,6 +284,9 @@ object UserDefinedPythonTableFunction {
val schema = DataType.fromJson(
PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
+ // Receive the pickled AnalyzeResult buffer, if any.
+ val pickledAnalyzeResult: Array[Byte] =
PythonWorkerUtils.readBytes(dataIn)
+
// Receive whether the "with single partition" property is requested.
val withSinglePartition = dataIn.readInt() == 1
// Receive the list of requested partitioning columns, if any.
@@ -324,7 +328,8 @@ object UserDefinedPythonTableFunction {
schema = schema,
withSinglePartition = withSinglePartition,
partitionByExpressions = partitionByColumns.toSeq,
- orderByExpressions = orderBy.toSeq)
+ orderByExpressions = orderBy.toSeq,
+ pickledAnalyzeResult = pickledAnalyzeResult)
} catch {
case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly
(crashed)", eof)
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
index f7b2bada26ec..1b923442207e 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
@@ -123,13 +123,19 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2))
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2))
-- !query analysis
[Analyzer test output redacted due to nondeterminism]
-- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)
+SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -144,14 +150,14 @@ org.apache.spark.sql.AnalysisException
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
- "stopIndex" : 70,
- "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)"
+ "stopIndex" : 73,
+ "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)"
} ]
}
-- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -166,8 +172,8 @@ org.apache.spark.sql.AnalysisException
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
- "stopIndex" : 75,
- "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY
partition_col)"
+ "stopIndex" : 78,
+ "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY
partition_col)"
} ]
}
@@ -176,7 +182,7 @@ org.apache.spark.sql.AnalysisException
SELECT * FROM
VALUES (0), (1) AS t(col)
JOIN LATERAL
- UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+ UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
-- !query analysis
org.apache.spark.sql.AnalysisException
{
@@ -191,8 +197,8 @@ org.apache.spark.sql.AnalysisException
"objectType" : "",
"objectName" : "",
"startIndex" : 66,
- "stopIndex" : 126,
- "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY
partition_col)"
+ "stopIndex" : 129,
+ "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY
partition_col)"
} ]
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
index 6d49177c4f6a..6d34b91e2f16 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
@@ -47,13 +47,14 @@ SELECT * FROM
-- order_by=[
-- OrderingColumn("input"),
-- OrderingColumn("partition_col")])
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2));
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION);
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col);
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2));
+SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2));
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION);
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col);
SELECT * FROM
VALUES (0), (1) AS t(col)
JOIN LATERAL
- UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col);
+ UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col);
-- As a reminder, the UDTFPartitionByOrderBy function returns this analyze
result:
-- AnalyzeResult(
-- schema=StructType()
diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
index a93aac945015..11295c43d8cb 100644
--- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
@@ -161,7 +161,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2))
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2))
-- !query schema
struct<count:int,total:int,last:int>
-- !query output
@@ -169,7 +169,15 @@ struct<count:int,total:int,last:int>
-- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)
+SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2))
+-- !query schema
+struct<count:int,total:int,last:int>
+-- !query output
+3 6 3
+
+
+-- !query
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)
-- !query schema
struct<>
-- !query output
@@ -186,14 +194,14 @@ org.apache.spark.sql.AnalysisException
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
- "stopIndex" : 70,
- "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)"
+ "stopIndex" : 73,
+ "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)"
} ]
}
-- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
-- !query schema
struct<>
-- !query output
@@ -210,8 +218,8 @@ org.apache.spark.sql.AnalysisException
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
- "stopIndex" : 75,
- "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY
partition_col)"
+ "stopIndex" : 78,
+ "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY
partition_col)"
} ]
}
@@ -220,7 +228,7 @@ org.apache.spark.sql.AnalysisException
SELECT * FROM
VALUES (0), (1) AS t(col)
JOIN LATERAL
- UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+ UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
-- !query schema
struct<>
-- !query output
@@ -237,8 +245,8 @@ org.apache.spark.sql.AnalysisException
"objectType" : "",
"objectName" : "",
"startIndex" : 66,
- "stopIndex" : 126,
- "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY
partition_col)"
+ "stopIndex" : 129,
+ "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY
partition_col)"
} ]
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index ef4606b70cae..3c30c414f81f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -524,8 +524,15 @@ object IntegratedUDFTestUtils extends SQLHelper {
val name: String = "UDTFWithSinglePartition"
val pythonScript: String =
s"""
+ |import json
+ |from dataclasses import dataclass
|from pyspark.sql.functions import AnalyzeResult, OrderingColumn,
PartitioningColumn
|from pyspark.sql.types import IntegerType, Row, StructType
+ |
+ |@dataclass
+ |class AnalyzeResultWithBuffer(AnalyzeResult):
+ | buffer: str = ""
+ |
|class $name:
| def __init__(self):
| self._count = 0
@@ -533,8 +540,14 @@ object IntegratedUDFTestUtils extends SQLHelper {
| self._last = None
|
| @staticmethod
- | def analyze(self):
- | return AnalyzeResult(
+ | def analyze(initial_count, input_table):
+ | buffer = ""
+ | if initial_count.value is not None:
+ | assert(not initial_count.is_table)
+ | assert(initial_count.data_type == IntegerType())
+ | count = initial_count.value
+ | buffer = json.dumps({"initial_count": count})
+ | return AnalyzeResultWithBuffer(
| schema=StructType()
| .add("count", IntegerType())
| .add("total", IntegerType())
@@ -542,9 +555,10 @@ object IntegratedUDFTestUtils extends SQLHelper {
| with_single_partition=True,
| order_by=[
| OrderingColumn("input"),
- | OrderingColumn("partition_col")])
+ | OrderingColumn("partition_col")],
+ | buffer=buffer)
|
- | def eval(self, row: Row):
+ | def eval(self, initial_count, row):
| self._count += 1
| self._last = row["input"]
| self._sum += row["input"]
@@ -693,6 +707,48 @@ object IntegratedUDFTestUtils extends SQLHelper {
"without a corresponding partitioning table requirement"
}
+ object TestPythonUDTFForwardStateFromAnalyze extends TestUDTF {
+ val name: String = "TestPythonUDTFForwardStateFromAnalyze"
+ val pythonScript: String =
+ s"""
+ |from dataclasses import dataclass
+ |from pyspark.sql.functions import AnalyzeResult
+ |from pyspark.sql.types import StringType, StructType
+ |
+ |@dataclass
+ |class AnalyzeResultWithBuffer(AnalyzeResult):
+ | buffer: str = ""
+ |
+ |class $name:
+ | def __init__(self, analyze_result):
+ | self._analyze_result = analyze_result
+ |
+ | @staticmethod
+ | def analyze(argument):
+ | assert(argument.data_type == StringType())
+ | return AnalyzeResultWithBuffer(
+ | schema=StructType()
+ | .add("result", StringType()),
+ | buffer=argument.value)
+ |
+ | def eval(self, argument):
+ | pass
+ |
+ | def terminate(self):
+ | yield self._analyze_result.buffer,
+ |""".stripMargin
+
+ val udtf: UserDefinedPythonTableFunction =
createUserDefinedPythonTableFunction(
+ name = name,
+ pythonScript = pythonScript,
+ returnType = None)
+
+ def apply(session: SparkSession, exprs: Column*): DataFrame =
+ udtf.apply(session, exprs: _*)
+
+ val prettyName: String = "Python UDTF whose 'analyze' method sets state
and reads it later"
+ }
+
/**
* A Scalar Pandas UDF that takes one column, casts into string, executes the
* Python native function, and casts back to the type of input column.
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index cdc3ef9e4178..efab685236de 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -48,15 +48,15 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
private val pythonUDTFCountSumLast: UserDefinedPythonTableFunction =
createUserDefinedPythonTableFunction(
- "UDTFCountSumLast", TestPythonUDTFCountSumLast.pythonScript, None)
+ TestPythonUDTFCountSumLast.name,
TestPythonUDTFCountSumLast.pythonScript, None)
private val pythonUDTFWithSinglePartition: UserDefinedPythonTableFunction =
createUserDefinedPythonTableFunction(
- "UDTFWithSinglePartition",
TestPythonUDTFWithSinglePartition.pythonScript, None)
+ TestPythonUDTFWithSinglePartition.name,
TestPythonUDTFWithSinglePartition.pythonScript, None)
private val pythonUDTFPartitionByOrderBy: UserDefinedPythonTableFunction =
createUserDefinedPythonTableFunction(
- "UDTFPartitionByOrderBy", TestPythonUDTFPartitionBy.pythonScript, None)
+ TestPythonUDTFPartitionBy.name, TestPythonUDTFPartitionBy.pythonScript,
None)
private val arrowPythonUDTF: UserDefinedPythonTableFunction =
createUserDefinedPythonTableFunction(
@@ -65,6 +65,11 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
Some(returnType),
evalType = PythonEvalType.SQL_ARROW_TABLE_UDF)
+ private val pythonUDTFForwardStateFromAnalyze:
UserDefinedPythonTableFunction =
+ createUserDefinedPythonTableFunction(
+ TestPythonUDTFForwardStateFromAnalyze.name,
+ TestPythonUDTFForwardStateFromAnalyze.pythonScript, None)
+
test("Simple PythonUDTF") {
assume(shouldTestPythonUDFs)
val df = pythonUDTF(spark, lit(1), lit(2))
@@ -200,14 +205,14 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
stop = 29))
}
- spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast)
+ spark.udtf.registerPython(TestPythonUDTFCountSumLast.name,
pythonUDTFCountSumLast)
var plan = sql(
- """
+ s"""
|WITH t AS (
| VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input)
|)
|SELECT count, total, last
- |FROM UDTFCountSumLast(TABLE(t) WITH SINGLE PARTITION)
+ |FROM ${TestPythonUDTFCountSumLast.name}(TABLE(t) WITH SINGLE
PARTITION)
|ORDER BY 1, 2
|""".stripMargin).queryExecution.analyzed
plan.collectFirst { case r: Repartition => r } match {
@@ -216,16 +221,16 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
failure(plan)
}
- spark.udtf.registerPython("UDTFWithSinglePartition",
pythonUDTFWithSinglePartition)
+ spark.udtf.registerPython(TestPythonUDTFWithSinglePartition.name,
pythonUDTFWithSinglePartition)
plan = sql(
- """
+ s"""
|WITH t AS (
| SELECT id AS partition_col, 1 AS input FROM range(1, 21)
| UNION ALL
| SELECT id AS partition_col, 2 AS input FROM range(1, 21)
|)
|SELECT count, total, last
- |FROM UDTFWithSinglePartition(TABLE(t))
+ |FROM ${TestPythonUDTFWithSinglePartition.name}(0, TABLE(t))
|ORDER BY 1, 2
|""".stripMargin).queryExecution.analyzed
plan.collectFirst { case r: Repartition => r } match {
@@ -234,16 +239,16 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
failure(plan)
}
- spark.udtf.registerPython("UDTFPartitionByOrderBy",
pythonUDTFPartitionByOrderBy)
+ spark.udtf.registerPython(TestPythonUDTFPartitionBy.name,
pythonUDTFPartitionByOrderBy)
plan = sql(
- """
+ s"""
|WITH t AS (
| SELECT id AS partition_col, 1 AS input FROM range(1, 21)
| UNION ALL
| SELECT id AS partition_col, 2 AS input FROM range(1, 21)
|)
|SELECT partition_col, count, total, last
- |FROM UDTFPartitionByOrderBy(TABLE(t))
+ |FROM ${TestPythonUDTFPartitionBy.name}(TABLE(t))
|ORDER BY 1, 2
|""".stripMargin).queryExecution.analyzed
plan.collectFirst { case r: RepartitionByExpression => r } match {
@@ -345,4 +350,17 @@ class PythonUDTFSuite extends QueryTest with
SharedSparkSession {
Literal("abc"))) ==
Seq(2, 3))
}
+
+ test("SPARK-45402: Add UDTF API for 'analyze' to return a buffer to consume
on class creation") {
+ spark.udtf.registerPython(
+ TestPythonUDTFForwardStateFromAnalyze.name,
+ pythonUDTFForwardStateFromAnalyze)
+ withTable("t") {
+ sql("create table t(col array<int>) using parquet")
+ val query = s"select * from
${TestPythonUDTFForwardStateFromAnalyze.name}('abc')"
+ checkAnswer(
+ sql(query),
+ Row("abc"))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]