allisonwang-db commented on code in PR #52317:
URL: https://github.com/apache/spark/pull/52317#discussion_r2353553450


##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
         expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int")
         assertDataFrameEqual(sql_result_df2, expected_df2)
 
+    def test_arrow_udtf_with_partition_by(self):
+        @arrow_udtf(returnType="partition_key int, sum_value int")
+        class SumUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+                partition_key = pc.unique(table["partition_key"]).to_pylist()
+                assert (
+                    len(partition_key) == 1
+                ), f"Expected exactly one partition key, got {partition_key}"
+                sum_value = pc.sum(table["value"]).as_py()
+                result_table = pa.table(
+                    {
+                        "partition_key": pa.array([partition_key[0]], type=pa.int32()),
+                        "sum_value": pa.array([sum_value], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        test_data = [
+            (1, 10),
+            (2, 5),
+            (1, 20),
+            (2, 15),
+            (1, 30),
+            (3, 100),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("sum_udtf", SumUDTF)
+        input_df.createOrReplaceTempView("test_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key)
+        """
+        )
+
+        expected_data = [
+            (1, 60),
+            (2, 20),
+            (3, 100),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "partition_key int, sum_value int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_terminate(self):
+        @arrow_udtf(returnType="partition_key int, count int, sum_value int")
+        class TerminateUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                # Track partition key
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Accumulate stats but don't yield here
+                self._count += table.num_rows
+                self._sum += pc.sum(table["value"]).as_py()
+                # Return empty iterator - results come from terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                # Yield accumulated results for this partition
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "count": pa.array([self._count], type=pa.int32()),
+                            "sum_value": pa.array([self._sum], type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (3, 50),
+            (1, 10),
+            (2, 40),
+            (1, 20),
+            (2, 30),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("terminate_udtf", TerminateUDTF)
+        input_df.createOrReplaceTempView("test_data_terminate")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM terminate_udtf(TABLE(test_data_terminate) PARTITION BY partition_key)
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 2, 30),  # partition 1: 2 rows, sum = 30
+            (2, 2, 70),  # partition 2: 2 rows, sum = 70
+            (3, 1, 50),  # partition 3: 1 row, sum = 50
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count int, sum_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_order_by(self):
+        @arrow_udtf(returnType="partition_key int, first_value int, last_value int")
+        class OrderByUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._first_value = None
+                self._last_value = None
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Track first and last values (should be ordered)
+                values = table["value"].to_pylist()
+                if values:
+                    if self._first_value is None:
+                        self._first_value = values[0]
+                    self._last_value = values[-1]
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "first_value": pa.array([self._first_value], type=pa.int32()),
+                            "last_value": pa.array([self._last_value], type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (1, 30),
+            (1, 10),
+            (1, 20),
+            (2, 60),
+            (2, 40),
+            (2, 50),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("order_by_udtf", OrderByUDTF)
+        input_df.createOrReplaceTempView("test_data_order")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM order_by_udtf(
+                TABLE(test_data_order)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 10, 30),  # partition 1: first=10 (min), last=30 (max) after ordering
+            (2, 40, 60),  # partition 2: first=40 (min), last=60 (max) after ordering
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, first_value int, last_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_partition_column_removal(self):
+        @arrow_udtf(returnType="col1_sum int, col2_sum int")
+        class PartitionColumnTestUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+
+                # When partitioning by an expression like "col1 + col2",
+                # Catalyst adds the expression result as a new column at the beginning.
+                # The ArrowUDTFWithPartition._remove_partition_by_exprs method should
+                # remove this added column, leaving only the original table columns.
+                column_names = table.column_names
+
+                # Verify we only have the original columns, not the partition expression
+                assert "col1" in column_names, f"Expected col1 in columns: {column_names}"
+                assert "col2" in column_names, f"Expected col2 in columns: {column_names}"
+                # The partition expression column should have been removed
+                assert len(column_names) == 2, (
+                    f"Expected only col1 and col2 after partition column removal, "
+                    f"but got: {column_names}"
+                )
+
+                col1_sum = pc.sum(table["col1"]).as_py()
+                col2_sum = pc.sum(table["col2"]).as_py()
+
+                result_table = pa.table(
+                    {
+                        "col1_sum": pa.array([col1_sum], type=pa.int32()),
+                        "col2_sum": pa.array([col2_sum], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        test_data = [
+            (1, 1),  # partition: 1+1=2
+            (1, 2),  # partition: 1+2=3
+            (2, 0),  # partition: 2+0=2
+            (2, 1),  # partition: 2+1=3
+        ]
+        input_df = self.spark.createDataFrame(test_data, "col1 int, col2 int")
+
+        self.spark.udtf.register("partition_column_test_udtf", PartitionColumnTestUDTF)
+        input_df.createOrReplaceTempView("test_partition_removal")
+
+        # Partition by col1 + col2 expression
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM partition_column_test_udtf(
+                TABLE(test_partition_removal)
+                PARTITION BY col1 + col2
+            )
+            ORDER BY col1_sum, col2_sum
+            """
+        )
+
+        expected_data = [
+            (3, 1),  # partition 2: sum of col1s (1+2), sum of col2s (1+0)
+            (3, 3),  # partition 3: sum of col1s (1+2), sum of col2s (2+1)
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "col1_sum int, col2_sum int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_partition_by_single_partition_multiple_input_partitions(self):
+        @arrow_udtf(returnType="partition_key int, count bigint, sum_value bigint")
+        class SinglePartitionUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+
+                # All rows should have the same partition key (constant value 1)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                self._partition_key = partition_keys[0]
+                self._count += table.num_rows
+                self._sum += pc.sum(table["id"]).as_py()
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "count": pa.array([self._count], type=pa.int64()),
+                            "sum_value": pa.array([self._sum], type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        # Create DataFrame with 5 input partitions but all data will map to partition_key=1
+        # range(1, 10, 1, 5) creates ids from 1 to 9 with 5 partitions
+        input_df = self.spark.range(1, 10, 1, 5).selectExpr(
+            "1 as partition_key", "id"  # constant partition key
+        )
+
+        self.spark.udtf.register("single_partition_udtf", SinglePartitionUDTF)
+        input_df.createOrReplaceTempView("test_single_partition")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM single_partition_udtf(
+                TABLE(test_single_partition)
+                PARTITION BY partition_key
+            )
+            """
+        )
+
+        # All 9 rows (1 through 9) should be in a single partition with key=1
+        expected_data = [(1, 9, 45)]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count bigint, sum_value bigint"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_skip_rest_of_input(self):
+        from pyspark.sql.functions import SkipRestOfInputTableException
+
+        @arrow_udtf(returnType="partition_key int, rows_processed int, last_value int")
+        class SkipRestUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._rows_processed = 0
+                self._last_value = None
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Process rows one by one and stop after processing 2 rows per partition
+                values = table["value"].to_pylist()
+                for value in values:
+                    self._rows_processed += 1
+                    self._last_value = value
+
+                    # Skip rest of the partition after processing 2 rows
+                    if self._rows_processed >= 2:
+                        msg = f"Skipping partition {self._partition_key} "
+                        msg += f"after {self._rows_processed} rows"
+                        raise SkipRestOfInputTableException(msg)
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "rows_processed": pa.array([self._rows_processed], type=pa.int32()),
+                            "last_value": pa.array([self._last_value], type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        # Create test data with multiple partitions, each having more than 2 rows
+        test_data = [
+            (1, 10),
+            (1, 20),
+            (1, 30),  # This should be skipped
+            (1, 40),  # This should be skipped
+            (2, 50),
+            (2, 60),
+            (2, 70),  # This should be skipped
+            (3, 80),
+            (3, 90),
+            (3, 100),  # This should be skipped
+            (3, 110),  # This should be skipped
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("skip_rest_udtf", SkipRestUDTF)
+        input_df.createOrReplaceTempView("test_skip_rest")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM skip_rest_udtf(
+                TABLE(test_skip_rest)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key
+            """
+        )
+
+        # Each partition should only process 2 rows before skipping the rest
+        expected_data = [
+            (1, 2, 20),  # Processed rows 10, 20, then skipped 30, 40
+            (2, 2, 60),  # Processed rows 50, 60, then skipped 70
+            (3, 2, 90),  # Processed rows 80, 90, then skipped 100, 110
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, rows_processed int, last_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_null_values(self):
+        @arrow_udtf(returnType="partition_key int, count int, non_null_sum int")
+        class NullPartitionUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._non_null_sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                # Handle null partition keys
+                partition_keys = table["partition_key"]
+                unique_keys = pc.unique(partition_keys).to_pylist()
+
+                # Should have exactly one unique value (either a value or None)
+                assert len(unique_keys) == 1, f"Expected one partition key, got {unique_keys}"
+                self._partition_key = unique_keys[0]
+
+                # Count rows and sum non-null values
+                self._count += table.num_rows
+                values = table["value"]
+                # Use PyArrow compute to handle nulls properly
+                non_null_values = pc.drop_null(values)
+                if len(non_null_values) > 0:
+                    self._non_null_sum += pc.sum(non_null_values).as_py()
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                # Return results even for null partition keys
+                result_table = pa.table(
+                    {
+                        "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                        "count": pa.array([self._count], type=pa.int32()),
+                        "non_null_sum": pa.array([self._non_null_sum], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        # Test data with null partition keys and null values
+        test_data = [
+            (1, 10),
+            (1, None),  # null value in partition 1
+            (None, 20),  # null partition key
+            (None, 30),  # null partition key
+            (2, 40),
+            (2, None),  # null value in partition 2
+            (None, None),  # both null
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("null_partition_udtf", NullPartitionUDTF)
+        input_df.createOrReplaceTempView("test_null_partitions")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM null_partition_udtf(
+                TABLE(test_null_partitions)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key NULLS FIRST
+            """
+        )
+
+        # Expected: null partition gets grouped together, nulls in values are handled
+        expected_data = [
+            (None, 3, 50),  # null partition: 3 rows, sum of non-null values = 20+30 = 50
+            (1, 2, 10),  # partition 1: 2 rows, sum of non-null values = 10
+            (2, 2, 40),  # partition 2: 2 rows, sum of non-null values = 40
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count int, non_null_sum int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_empty_table(self):
+        @arrow_udtf(returnType="result string")
+        class EmptyTableUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow as pa
+
+                # This should not be called for empty tables

Review Comment:
   Good point!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to