dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391771813


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            @staticmethod
+            def analyze(_):
+                return AnalyzeResult(
+                    schema=StructType().add("total", IntegerType()), withSinglePartition=True
+                )
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t))
+                """
+            ),
+            [Row(total=4)],
+        )

Review Comment:
   The `SkipRestOfInputTableException` stops scanning rows for just the current partition. I added a test case for this as well.
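A plain-Python simulation of the per-partition semantics described above may help: each partition gets a fresh UDTF instance, raising the exception from `eval` skips only that partition's remaining rows, and `terminate` still runs. This is a hypothetical sketch, not how Spark actually drives UDTFs; `run_partition` and `CountUpToFour` are illustrative names, and the exception class is defined locally instead of being imported from PySpark.

```python
# Hypothetical local stand-in for the exception added in this PR;
# real code would import it from PySpark instead.
class SkipRestOfInputTableException(Exception):
    pass


class CountUpToFour:
    """Mirrors TestUDTF from the diff: count rows, stop after the fourth."""

    def __init__(self):
        self._total = 0

    def eval(self, row):
        self._total += 1
        if self._total >= 4:
            raise SkipRestOfInputTableException("Stop at self._total >= 4")

    def terminate(self):
        yield (self._total,)


def run_partition(udtf_class, rows):
    """Hypothetical driver: one fresh UDTF instance per partition."""
    udtf = udtf_class()
    out = []
    for row in rows:
        try:
            result = udtf.eval(row)
            if result is not None:
                # Consume generator-style eval output inside the try block.
                out.extend(result)
        except SkipRestOfInputTableException:
            # Stop consuming rows from this partition only.
            break
    # 'terminate' runs whether or not the exception was raised.
    out.extend(udtf.terminate())
    return out


# Two partitions: the first stops early, the second is unaffected.
partitions = [list(range(1, 21)), list(range(21, 24))]
results = [run_partition(CountUpToFour, p) for p in partitions]
print(results)  # [[(4,)], [(3,)]]
```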



##########
python/pyspark/sql/udtf.py:
##########
@@ -118,6 +125,13 @@ class AnalyzeResult:
     orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
 
 
+# This represents an exception that the 'eval' method may raise to indicate that it is done
+# consuming rows from the current partition of the input table. Then the UDTF's 'terminate' method
+# runs (if any).
+class SkipRestOfInputTableException(Exception):
+    pass

Review Comment:
   Good point, I moved this to a pydoc string inside the class definition itself.
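A minimal sketch of what the docstring-based definition might look like, reusing the wording of the original `#` comment; the text that actually landed in the PR may differ. Because a class body consisting only of a docstring is valid Python, the `pass` statement is no longer needed.

```python
class SkipRestOfInputTableException(Exception):
    """
    This represents an exception that the 'eval' method may raise to indicate that it is
    done consuming rows from the current partition of the input table. Then the UDTF's
    'terminate' method runs (if any).
    """
```

Unlike a `#` comment, the docstring is available at runtime through `SkipRestOfInputTableException.__doc__` and `help()`.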



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            @staticmethod
+            def analyze(_):
+                return AnalyzeResult(
+                    schema=StructType().add("total", IntegerType()), withSinglePartition=True
+                )

Review Comment:
   Good idea, done!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.