Yicong-Huang commented on code in PR #52317:
URL: https://github.com/apache/spark/pull/52317#discussion_r2345904286
########## python/pyspark/sql/tests/arrow/test_arrow_udtf.py: ##########
@@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
         expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int")
         assertDataFrameEqual(sql_result_df2, expected_df2)
+    def test_arrow_udtf_with_partition_by(self):
+        @arrow_udtf(returnType="partition_key int, sum_value int")
+        class SumUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+                partition_key = pc.unique(table["partition_key"]).to_pylist()
+                assert (
+                    len(partition_key) == 1
+                ), f"Expected exactly one partition key, got {partition_key}"
+                sum_value = pc.sum(table["value"]).as_py()
+                result_table = pa.table(
+                    {
+                        "partition_key": pa.array([partition_key[0]], type=pa.int32()),
+                        "sum_value": pa.array([sum_value], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        test_data = [
+            (1, 10),
+            (2, 5),
+            (1, 20),
+            (2, 15),
+            (1, 30),
+            (3, 100),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("sum_udtf", SumUDTF)
+        input_df.createOrReplaceTempView("test_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key)
+            """
+        )
+
+        expected_data = [
+            (1, 60),
+            (2, 20),
+            (3, 100),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "partition_key int, sum_value int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_terminate(self):
+        @arrow_udtf(returnType="partition_key int, count int, sum_value int")
+        class TerminateUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                # Track partition key
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Accumulate stats but don't yield here
+                self._count += table.num_rows
+                self._sum += pc.sum(table["value"]).as_py()
+                # Return empty iterator - results come from terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                # Yield accumulated results for this partition
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "count": pa.array([self._count], type=pa.int32()),
+                            "sum_value": pa.array([self._sum], type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (3, 50),
+            (1, 10),
+            (2, 40),
+            (1, 20),
+            (2, 30),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("terminate_udtf", TerminateUDTF)
+        input_df.createOrReplaceTempView("test_data_terminate")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM terminate_udtf(TABLE(test_data_terminate) PARTITION BY partition_key)
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 2, 30),  # partition 1: 2 rows, sum = 30
+            (2, 2, 70),  # partition 2: 2 rows, sum = 70
+            (3, 1, 50),  # partition 3: 1 row, sum = 50
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count int, sum_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_order_by(self):
+        @arrow_udtf(returnType="partition_key int, first_value int, last_value int")
+        class OrderByUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._first_value = None
+                self._last_value = None
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Track first and last values (should be ordered)
+                values = table["value"].to_pylist()
+                if values:
+                    if self._first_value is None:
+                        self._first_value = values[0]
+                    self._last_value = values[-1]
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], type=pa.int32()),
+                            "first_value": pa.array([self._first_value], type=pa.int32()),
+                            "last_value": pa.array([self._last_value], type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (1, 30),
+            (1, 10),
+            (1, 20),
+            (2, 60),
+            (2, 40),
+            (2, 50),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("order_by_udtf", OrderByUDTF)
+        input_df.createOrReplaceTempView("test_data_order")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM order_by_udtf(
+                TABLE(test_data_order)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 10, 30),  # partition 1: first=10 (min), last=30 (max) after ordering
+            (2, 40, 60),  # partition 2: first=40 (min), last=60 (max) after ordering
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, first_value int, last_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_partition_column_removal(self):
+        @arrow_udtf(returnType="col1_sum int, col2_sum int")
+        class PartitionColumnTestUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+
+                # When partitioning by an expression like "col1 + col2",
+                # Catalyst adds the expression result as a new column at the beginning.
+                # The ArrowUDTFWithPartition._remove_partition_by_exprs method should
+                # remove this added column, leaving only the original table columns.
+                column_names = table.column_names
+
+                # Verify we only have the original columns, not the partition expression
+                assert "col1" in column_names, f"Expected col1 in columns: {column_names}"
+                assert "col2" in column_names, f"Expected col2 in columns: {column_names}"
+                # The partition expression column should have been removed
+                assert len(column_names) == 2, (
+                    f"Expected only col1 and col2 after partition column removal, "
+                    f"but got: {column_names}"
+                )
+
+                col1_sum = pc.sum(table["col1"]).as_py()
+                col2_sum = pc.sum(table["col2"]).as_py()
+
+                result_table = pa.table(
+                    {
+                        "col1_sum": pa.array([col1_sum], type=pa.int32()),
+                        "col2_sum": pa.array([col2_sum], type=pa.int32()),
+                    }
+                )
+                yield result_table

Review Comment:
   I think https://github.com/apache/spark/pull/52317#discussion_r2345574269 might be the same issue