ueshin commented on code in PR #52170:
URL: https://github.com/apache/spark/pull/52170#discussion_r2308852517


##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -607,6 +607,244 @@ def eval(self, input_val: int):
         expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
         assertDataFrameEqual(result_df, expected_df)
 
+    @unittest.skip("SPARK-53387: Support PARTIITON BY with Arrow UDTF")
+    def test_arrow_udtf_table_partition_by_single_column(self):
+        @arrow_udtf(returnType="partition_key string, total_value bigint")
+        class PartitionSumUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+                
+                # Each partition will have records with the same category
+                if table.num_rows > 0:
+                    category = table.column("category")[0].as_py()
+                    total = pa.compute.sum(table.column("value")).as_py()
+                    
+                    result_table = pa.table({
+                        "partition_key": pa.array([category], 
type=pa.string()),
+                        "total_value": pa.array([total], type=pa.int64())
+                    })
+                    yield result_table
+
+        self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
+        
+        # Create test data with categories
+        test_data = [
+            ("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)
+        ]
+        test_df = self.spark.createDataFrame(test_data, "category string, 
value int")
+        test_df.createOrReplaceTempView("partition_test_data")
+
+        result_df = self.spark.sql("""
+            SELECT * FROM partition_sum_udtf(
+                TABLE(partition_test_data) PARTITION BY category
+            ) ORDER BY partition_key
+        """)
+        
+        expected_df = self.spark.createDataFrame([
+            ("A", 30), ("B", 70), ("C", 50)
+        ], "partition_key string, total_value bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+    @unittest.skip("SPARK-53387: Support PARTIITON BY with Arrow UDTF")
+    def test_arrow_udtf_table_partition_by_multiple_columns(self):
+        @arrow_udtf(returnType="dept string, status string, count_employees 
bigint")
+        class DeptStatusCountUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+                
+                if table.num_rows > 0:
+                    dept = table.column("department")[0].as_py()
+                    status = table.column("status")[0].as_py()
+                    count = table.num_rows
+                    
+                    result_table = pa.table({
+                        "dept": pa.array([dept], type=pa.string()),
+                        "status": pa.array([status], type=pa.string()),
+                        "count_employees": pa.array([count], type=pa.int64())
+                    })
+                    yield result_table
+
+        self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)
+        
+        test_data = [
+            ("IT", "active"), ("IT", "active"), ("IT", "inactive"),
+            ("HR", "active"), ("HR", "inactive"), ("Finance", "active")
+        ]
+        test_df = self.spark.createDataFrame(test_data, "department string, 
status string")
+        test_df.createOrReplaceTempView("employee_data")
+        
+        result_df = self.spark.sql("""
+            SELECT * FROM dept_status_count_udtf(
+                TABLE(SELECT * FROM employee_data) 
+                PARTITION BY department, status
+            ) ORDER BY dept, status
+        """)
+        
+        expected_df = self.spark.createDataFrame([
+            ("Finance", "active", 1), ("HR", "active", 1), ("HR", "inactive", 
1), ("IT", "active", 2), ("IT", "inactive", 1)
+        ], "dept string, status string, count_employees bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_scalar_first_table_second(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class ScalarFirstTableSecondUDTF:
+            def eval(self, threshold: "pa.Array", table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                assert isinstance(
+                    threshold, pa.Array
+                ), f"Expected pa.Array for threshold, got {type(threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got 
{type(table_data)}"
+
+                threshold_val = threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep 
original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar first, table second
+        input_df = self.spark.range(8)
+        result_df = ScalarFirstTableSecondUDTF(lit(4), input_df.asTable())
+        expected_df = self.spark.createDataFrame([(5,), (6,), (7,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_scalar_first_table_second_udtf", 
ScalarFirstTableSecondUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_scalar_first_table_second_udtf(4, TABLE(SELECT 
id FROM range(0, 8)))"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_table_argument_in_middle(self):
+        """Test Arrow UDTF with table argument in the middle of multiple 
scalar arguments."""
+        @arrow_udtf(returnType="filtered_id bigint")
+        class TableInMiddleUDTF:
+            def eval(
+                self, 
+                min_threshold: "pa.Array", 
+                table_data: "pa.RecordBatch", 
+                max_threshold: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    min_threshold, pa.Array
+                ), f"Expected pa.Array for min_threshold, got 
{type(min_threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got 
{type(table_data)}"
+                assert isinstance(
+                    max_threshold, pa.Array
+                ), f"Expected pa.Array for max_threshold, got 
{type(max_threshold)}"
+
+                min_val = min_threshold[0].as_py()
+                max_val = max_threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                
+                # Filter rows where min_val < id < max_val
+                mask = pa.compute.and_(
+                    pa.compute.greater(id_column, pa.scalar(min_val)),
+                    pa.compute.less(id_column, pa.scalar(max_val))
+                )
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep 
original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar, table, scalar
+        input_df = self.spark.range(10)
+        result_df = TableInMiddleUDTF(lit(2), input_df.asTable(), lit(7))
+        expected_df = self.spark.createDataFrame([(3,), (4,), (5,), (6,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_table_in_middle_udtf", 
TableInMiddleUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_table_in_middle_udtf(2, TABLE(SELECT id FROM 
range(0, 10)), 7)"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_named_arguments(self):
+        @arrow_udtf(returnType="result_id bigint, multiplier_used int")
+        class NamedArgsUDTF:
+            def eval(
+                self, 
+                table_data: "pa.RecordBatch", 
+                multiplier: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got 
{type(table_data)}"
+                assert isinstance(
+                    multiplier, pa.Array
+                ), f"Expected pa.Array for multiplier, got {type(multiplier)}"
+
+                multiplier_val = multiplier[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                
+                # Multiply each id by the multiplier
+                multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
+                
+                result_table = pa.table({
+                    "result_id": multiplied_ids,
+                    "multiplier_used": pa.array([multiplier_val] * 
table.num_rows, type=pa.int32())
+                })
+                yield result_table
+
+        # Test with DataFrame API using named arguments
+        # TODO(SPARK-53426): Support named table argument with DataFrame API
+        # input_df = self.spark.range(3)  # [0, 1, 2]
+        # result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))

Review Comment:
   Fix: #52171
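   For reference, a rough sketch of how the commented-out DataFrame API path could be exercised once SPARK-53426 is resolved by that PR. This is hypothetical until the fix is merged; the expected rows below are derived from the UDTF logic in this test (ids 0..2 multiplied by 5):
   
   ```python
   input_df = self.spark.range(3)  # ids 0, 1, 2
   # Named table argument via asTable() -- assumed to work only after SPARK-53426 / #52171
   result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))
   expected_df = self.spark.createDataFrame(
       [(0, 5), (5, 5), (10, 5)], "result_id bigint, multiplier_used int"
   )
   assertDataFrameEqual(result_df, expected_df)
   ```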


