Yicong-Huang commented on code in PR #52317:
URL: https://github.com/apache/spark/pull/52317#discussion_r2345904286


##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
         expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int")
         assertDataFrameEqual(sql_result_df2, expected_df2)
 
+    def test_arrow_udtf_with_partition_by(self):
+        @arrow_udtf(returnType="partition_key int, sum_value int")
+        class SumUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+                partition_key = pc.unique(table["partition_key"]).to_pylist()
+                assert (
+                    len(partition_key) == 1
+                ), f"Expected exactly one partition key, got {partition_key}"
+                sum_value = pc.sum(table["value"]).as_py()
+                result_table = pa.table(
+                    {
+                        "partition_key": pa.array([partition_key[0]], type=pa.int32()),
+                        "sum_value": pa.array([sum_value], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        test_data = [
+            (1, 10),
+            (2, 5),
+            (1, 20),
+            (2, 15),
+            (1, 30),
+            (3, 100),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("sum_udtf", SumUDTF)
+        input_df.createOrReplaceTempView("test_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key)
+        """
+        )
+
+        expected_data = [
+            (1, 60),
+            (2, 20),
+            (3, 100),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "partition_key int, sum_value int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_terminate(self):
+        @arrow_udtf(returnType="partition_key int, count int, sum_value int")
+        class TerminateUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                # Track partition key
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Accumulate stats but don't yield here
+                self._count += table.num_rows
+                self._sum += pc.sum(table["value"]).as_py()
+                # Return empty iterator - results come from terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                # Yield accumulated results for this partition
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "count": pa.array([self._count], type=pa.int32()),
+                            "sum_value": pa.array([self._sum], type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (3, 50),
+            (1, 10),
+            (2, 40),
+            (1, 20),
+            (2, 30),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("terminate_udtf", TerminateUDTF)
+        input_df.createOrReplaceTempView("test_data_terminate")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM terminate_udtf(TABLE(test_data_terminate) PARTITION BY partition_key)
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 2, 30),  # partition 1: 2 rows, sum = 30
+            (2, 2, 70),  # partition 2: 2 rows, sum = 70
+            (3, 1, 50),  # partition 3: 1 row, sum = 50
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count int, sum_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_order_by(self):
+        @arrow_udtf(returnType="partition_key int, first_value int, last_value int")
+        class OrderByUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._first_value = None
+                self._last_value = None
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Track first and last values (should be ordered)
+                values = table["value"].to_pylist()
+                if values:
+                    if self._first_value is None:
+                        self._first_value = values[0]
+                    self._last_value = values[-1]
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "first_value": pa.array([self._first_value], type=pa.int32()),
+                            "last_value": pa.array([self._last_value], type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (1, 30),
+            (1, 10),
+            (1, 20),
+            (2, 60),
+            (2, 40),
+            (2, 50),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("order_by_udtf", OrderByUDTF)
+        input_df.createOrReplaceTempView("test_data_order")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM order_by_udtf(
+                TABLE(test_data_order)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 10, 30),  # partition 1: first=10 (min), last=30 (max) after ordering
+            (2, 40, 60),  # partition 2: first=40 (min), last=60 (max) after ordering
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, first_value int, last_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_partition_column_removal(self):
+        @arrow_udtf(returnType="col1_sum int, col2_sum int")
+        class PartitionColumnTestUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+
+                # When partitioning by an expression like "col1 + col2",
+                # Catalyst adds the expression result as a new column at the beginning.
+                # The ArrowUDTFWithPartition._remove_partition_by_exprs method should
+                # remove this added column, leaving only the original table columns.
+                column_names = table.column_names
+
+                # Verify we only have the original columns, not the partition expression
+                assert "col1" in column_names, f"Expected col1 in columns: {column_names}"
+                assert "col2" in column_names, f"Expected col2 in columns: {column_names}"
+                # The partition expression column should have been removed
+                assert len(column_names) == 2, (
+                    f"Expected only col1 and col2 after partition column removal, "
+                    f"but got: {column_names}"
+                )
+
+                col1_sum = pc.sum(table["col1"]).as_py()
+                col2_sum = pc.sum(table["col2"]).as_py()
+
+                result_table = pa.table(
+                    {
+                        "col1_sum": pa.array([col1_sum], type=pa.int32()),
+                        "col2_sum": pa.array([col2_sum], type=pa.int32()),
+                    }
+                )
+                yield result_table

Review Comment:
   I think https://github.com/apache/spark/pull/52317#discussion_r2345574269 might be the same issue
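
   For context, the test above relies on the PARTITION BY expression column being stripped before the batch reaches `eval`. Below is a minimal standalone sketch of how one could observe the columns an Arrow UDTF actually receives; it is not part of this PR, and the `arrow_udtf` import path, the class name `ColumnsSeenUDTF`, and the view name `t` are assumptions for illustration only.

   ```python
   from typing import Iterator

   import pyarrow as pa
   from pyspark.sql import SparkSession

   # Assumption: arrow_udtf is exposed from pyspark.sql.functions in this PR.
   from pyspark.sql.functions import arrow_udtf

   spark = SparkSession.builder.getOrCreate()


   @arrow_udtf(returnType="cols string")
   class ColumnsSeenUDTF:
       # Emit the column names of each incoming batch, so the output shows whether
       # the PARTITION BY expression column was removed before reaching eval().
       def eval(self, batch: "pa.RecordBatch") -> Iterator["pa.Table"]:
           yield pa.table({"cols": pa.array([", ".join(batch.schema.names)])})


   spark.udtf.register("columns_seen_udtf", ColumnsSeenUDTF)
   spark.createDataFrame([(1, 2), (3, 4)], "col1 int, col2 int").createOrReplaceTempView("t")
   spark.sql("SELECT * FROM columns_seen_udtf(TABLE(t) PARTITION BY col1 + col2)").show(truncate=False)
   ```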



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

