allisonwang-db commented on code in PR #52170:
URL: https://github.com/apache/spark/pull/52170#discussion_r2345567271
##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -607,6 +607,244 @@ def eval(self, input_val: int):
         expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
         assertDataFrameEqual(result_df, expected_df)
+    @unittest.skip("SPARK-53387: Support PARTITION BY with Arrow UDTF")
+    def test_arrow_udtf_table_partition_by_single_column(self):
+        @arrow_udtf(returnType="partition_key string, total_value bigint")
+        class PartitionSumUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                # Each partition will have records with the same category
+                if table.num_rows > 0:
+                    category = table.column("category")[0].as_py()
+                    total = pa.compute.sum(table.column("value")).as_py()
+
+                    result_table = pa.table({
+                        "partition_key": pa.array([category], type=pa.string()),
+                        "total_value": pa.array([total], type=pa.int64())
+                    })
+                    yield result_table
+
+        self.spark.udtf.register("partition_sum_udtf", PartitionSumUDTF)
+
+        # Create test data with categories
+        test_data = [
+            ("A", 10), ("A", 20), ("B", 30), ("B", 40), ("C", 50)
+        ]
+        test_df = self.spark.createDataFrame(test_data, "category string, value int")
+        test_df.createOrReplaceTempView("partition_test_data")
+
+        result_df = self.spark.sql("""
+            SELECT * FROM partition_sum_udtf(
+                TABLE(partition_test_data) PARTITION BY category
+            ) ORDER BY partition_key
+        """)
+
+        expected_df = self.spark.createDataFrame([
+            ("A", 30), ("B", 70), ("C", 50)
+        ], "partition_key string, total_value bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+    @unittest.skip("SPARK-53387: Support PARTITION BY with Arrow UDTF")
+    def test_arrow_udtf_table_partition_by_multiple_columns(self):
+        @arrow_udtf(returnType="dept string, status string, count_employees bigint")
+        class DeptStatusCountUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+
+                if table.num_rows > 0:
+                    dept = table.column("department")[0].as_py()
+                    status = table.column("status")[0].as_py()
+                    count = table.num_rows
+
+                    result_table = pa.table({
+                        "dept": pa.array([dept], type=pa.string()),
+                        "status": pa.array([status], type=pa.string()),
+                        "count_employees": pa.array([count], type=pa.int64())
+                    })
+                    yield result_table
+
+        self.spark.udtf.register("dept_status_count_udtf", DeptStatusCountUDTF)
+
+        test_data = [
+            ("IT", "active"), ("IT", "active"), ("IT", "inactive"),
+            ("HR", "active"), ("HR", "inactive"), ("Finance", "active")
+        ]
+        test_df = self.spark.createDataFrame(test_data, "department string, status string")
+        test_df.createOrReplaceTempView("employee_data")
+
+        result_df = self.spark.sql("""
+            SELECT * FROM dept_status_count_udtf(
+                TABLE(SELECT * FROM employee_data)
+                PARTITION BY department, status
+            ) ORDER BY dept, status
+        """)
+
+        expected_df = self.spark.createDataFrame([
+            ("Finance", "active", 1), ("HR", "active", 1), ("HR", "inactive", 1),
+            ("IT", "active", 2), ("IT", "inactive", 1)
+        ], "dept string, status string, count_employees bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_scalar_first_table_second(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class ScalarFirstTableSecondUDTF:
+            def eval(self, threshold: "pa.Array", table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                assert isinstance(
+                    threshold, pa.Array
+                ), f"Expected pa.Array for threshold, got {type(threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+
+                threshold_val = threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+                mask = pa.compute.greater(id_column, pa.scalar(threshold_val))
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar first, table second
+        input_df = self.spark.range(8)
+        result_df = ScalarFirstTableSecondUDTF(lit(4), input_df.asTable())
+        expected_df = self.spark.createDataFrame([(5,), (6,), (7,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_scalar_first_table_second_udtf", ScalarFirstTableSecondUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_scalar_first_table_second_udtf(4, TABLE(SELECT id FROM range(0, 8)))"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_table_argument_in_middle(self):
+        """Test Arrow UDTF with a table argument in the middle of multiple scalar arguments."""
+        @arrow_udtf(returnType="filtered_id bigint")
+        class TableInMiddleUDTF:
+            def eval(
+                self,
+                min_threshold: "pa.Array",
+                table_data: "pa.RecordBatch",
+                max_threshold: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    min_threshold, pa.Array
+                ), f"Expected pa.Array for min_threshold, got {type(min_threshold)}"
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    max_threshold, pa.Array
+                ), f"Expected pa.Array for max_threshold, got {type(max_threshold)}"
+
+                min_val = min_threshold[0].as_py()
+                max_val = max_threshold[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Filter rows where min_val < id < max_val
+                mask = pa.compute.and_(
+                    pa.compute.greater(id_column, pa.scalar(min_val)),
+                    pa.compute.less(id_column, pa.scalar(max_val))
+                )
+                filtered_table = table.filter(mask)
+
+                if filtered_table.num_rows > 0:
+                    result_table = pa.table(
+                        {"filtered_id": filtered_table.column("id")}  # Keep original type
+                    )
+                    yield result_table
+
+        # Test with DataFrame API - scalar, table, scalar
+        input_df = self.spark.range(10)
+        result_df = TableInMiddleUDTF(lit(2), input_df.asTable(), lit(7))
+        expected_df = self.spark.createDataFrame([(3,), (4,), (5,), (6,)], "filtered_id bigint")
+        assertDataFrameEqual(result_df, expected_df)
+
+        # Test SQL registration and usage
+        self.spark.udtf.register("test_table_in_middle_udtf", TableInMiddleUDTF)
+        sql_result_df = self.spark.sql(
+            "SELECT * FROM test_table_in_middle_udtf(2, TABLE(SELECT id FROM range(0, 10)), 7)"
+        )
+        assertDataFrameEqual(sql_result_df, expected_df)
+
+    def test_arrow_udtf_with_named_arguments(self):
+        @arrow_udtf(returnType="result_id bigint, multiplier_used int")
+        class NamedArgsUDTF:
+            def eval(
+                self,
+                table_data: "pa.RecordBatch",
+                multiplier: "pa.Array"
+            ) -> Iterator["pa.Table"]:
+                assert isinstance(
+                    table_data, pa.RecordBatch
+                ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
+                assert isinstance(
+                    multiplier, pa.Array
+                ), f"Expected pa.Array for multiplier, got {type(multiplier)}"
+
+                multiplier_val = multiplier[0].as_py()
+
+                # Convert record batch to table
+                table = pa.table(table_data)
+                id_column = table.column("id")
+
+                # Multiply each id by the multiplier
+                multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
+
+                result_table = pa.table({
+                    "result_id": multiplied_ids,
+                    "multiplier_used": pa.array([multiplier_val] * table.num_rows, type=pa.int32())
+                })
+                yield result_table
+
+        # Test with DataFrame API using named arguments
+        # TODO(SPARK-53426): Support named table argument with DataFrame API
+        # input_df = self.spark.range(3)  # [0, 1, 2]
+        # result_df = NamedArgsUDTF(table_data=input_df.asTable(), multiplier=lit(5))

Review Comment:
   Thanks, the test can pass now!

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org