ueshin commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1389997538
##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
[Row(count=20, buffer="abc")],
)
+ def test_udtf_with_skip_rest_of_input_table_exception(self):
+ @udtf
+ class TestUDTF:
+ def __init__(self):
+ self._total = 0
+
+ @staticmethod
+ def analyze(_):
+ return AnalyzeResult(
+ schema=StructType().add("total", IntegerType()), withSinglePartition=True
+ )
+
+ def eval(self, _: Row):
+ self._total += 1
+ if self._total >= 4:
+ raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+ def terminate(self):
+ yield self._total,
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ assertDataFrameEqual(
+ self.spark.sql(
+ """
+ WITH t AS (
+ SELECT id FROM range(1, 21)
+ )
+ SELECT total
+ FROM test_udtf(TABLE(t))
+ """
+ ),
+ [Row(total=4)],
+ )
Review Comment:
What happens when the query uses `PARTITION BY`? Could you add a test for that case?
##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
[Row(count=20, buffer="abc")],
)
+ def test_udtf_with_skip_rest_of_input_table_exception(self):
+ @udtf
+ class TestUDTF:
+ def __init__(self):
+ self._total = 0
+
+ @staticmethod
+ def analyze(_):
+ return AnalyzeResult(
+ schema=StructType().add("total", IntegerType()), withSinglePartition=True
+ )
Review Comment:
nit: I guess we can use `@udtf(returnType=...)` for the schema and `TABLE(t)
WITH SINGLE PARTITION` to simplify the test.
##########
python/pyspark/sql/udtf.py:
##########
@@ -118,6 +125,13 @@ class AnalyzeResult:
orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
+# This represents an exception that the 'eval' method may raise to indicate that it is done
+# consuming rows from the current partition of the input table. Then the UDTF's 'terminate' method
+# runs (if any).
+class SkipRestOfInputTableException(Exception):
+ pass
Review Comment:
Should the comment be inside the class definition?
```py
class SkipRestOfInputTableException(Exception):
# This represents ...
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]