Yicong-Huang commented on code in PR #52317: URL: https://github.com/apache/spark/pull/52317#discussion_r2345635600
########## python/pyspark/sql/tests/arrow/test_arrow_udtf.py: ########## @@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]: expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int") assertDataFrameEqual(sql_result_df2, expected_df2) + def test_arrow_udtf_with_partition_by(self): + @arrow_udtf(returnType="partition_key int, sum_value int") + class SumUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + table = pa.table(table_data) + partition_key = pc.unique(table["partition_key"]).to_pylist() + assert ( + len(partition_key) == 1 + ), f"Expected exactly one partition key, got {partition_key}" + sum_value = pc.sum(table["value"]).as_py() + result_table = pa.table( + { + "partition_key": pa.array([partition_key[0]], type=pa.int32()), + "sum_value": pa.array([sum_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 10), + (2, 5), + (1, 20), + (2, 15), + (1, 30), + (3, 100), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("sum_udtf", SumUDTF) + input_df.createOrReplaceTempView("test_data") + + result_df = self.spark.sql( + """ + SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key) + """ + ) + + expected_data = [ + (1, 60), + (2, 20), + (3, 100), + ] + expected_df = self.spark.createDataFrame(expected_data, "partition_key int, sum_value int") + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_and_terminate(self): + @arrow_udtf(returnType="partition_key int, count int, sum_value int") + class TerminateUDTF: + def __init__(self): + self._partition_key = None + self._count = 0 + self._sum = 0 + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + # Track partition key + partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert 
len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}" + self._partition_key = partition_keys[0] + + # Accumulate stats but don't yield here + self._count += table.num_rows + self._sum += pc.sum(table["value"]).as_py() + # Return empty iterator - results come from terminate + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + # Yield accumulated results for this partition + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "count": pa.array([self._count], type=pa.int32()), + "sum_value": pa.array([self._sum], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (3, 50), + (1, 10), + (2, 40), + (1, 20), + (2, 30), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("terminate_udtf", TerminateUDTF) + input_df.createOrReplaceTempView("test_data_terminate") + + result_df = self.spark.sql( + """ + SELECT * FROM terminate_udtf(TABLE(test_data_terminate) PARTITION BY partition_key) + ORDER BY partition_key + """ + ) + + expected_data = [ + (1, 2, 30), # partition 1: 2 rows, sum = 30 + (2, 2, 70), # partition 2: 2 rows, sum = 70 + (3, 1, 50), # partition 3: 1 row, sum = 50 + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, count int, sum_value int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_and_order_by(self): + @arrow_udtf(returnType="partition_key int, first_value int, last_value int") + class OrderByUDTF: + def __init__(self): + self._partition_key = None + self._first_value = None + self._last_value = None + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert len(partition_keys) == 1, f"Expected one partition key, got 
{partition_keys}" + self._partition_key = partition_keys[0] + + # Track first and last values (should be ordered) + values = table["value"].to_pylist() + if values: + if self._first_value is None: + self._first_value = values[0] + self._last_value = values[-1] + + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "first_value": pa.array([self._first_value], type=pa.int32()), + "last_value": pa.array([self._last_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 30), + (1, 10), + (1, 20), + (2, 60), + (2, 40), + (2, 50), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("order_by_udtf", OrderByUDTF) + input_df.createOrReplaceTempView("test_data_order") + + result_df = self.spark.sql( + """ + SELECT * FROM order_by_udtf( + TABLE(test_data_order) + PARTITION BY partition_key + ORDER BY value + ) + ORDER BY partition_key + """ + ) + + expected_data = [ + (1, 10, 30), # partition 1: first=10 (min), last=30 (max) after ordering + (2, 40, 60), # partition 2: first=40 (min), last=60 (max) after ordering + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, first_value int, last_value int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_partition_column_removal(self): + @arrow_udtf(returnType="col1_sum int, col2_sum int") + class PartitionColumnTestUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + + # When partitioning by an expression like "col1 + col2", + # Catalyst adds the expression result as a new column at the beginning. + # The ArrowUDTFWithPartition._remove_partition_by_exprs method should + # remove this added column, leaving only the original table columns. 
+ column_names = table.column_names + + # Verify we only have the original columns, not the partition expression + assert "col1" in column_names, f"Expected col1 in columns: {column_names}" + assert "col2" in column_names, f"Expected col2 in columns: {column_names}" + # The partition expression column should have been removed + assert len(column_names) == 2, ( + f"Expected only col1 and col2 after partition column removal, " + f"but got: {column_names}" + ) + + col1_sum = pc.sum(table["col1"]).as_py() + col2_sum = pc.sum(table["col2"]).as_py() + + result_table = pa.table( + { + "col1_sum": pa.array([col1_sum], type=pa.int32()), + "col2_sum": pa.array([col2_sum], type=pa.int32()), + } + ) + yield result_table Review Comment: For each batch, this performs an aggregation and yields a sum. So if a partition's input spans multiple batches, it would produce multiple output rows for that partition, one per batch, right? Is this intentional? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke.
+ partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. 
Review Comment: Consider adding a check that this contiguity guarantee actually holds? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions.
+ """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" Review Comment: maybe also add the expected type in the messageParameters? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. 
+ + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) Review Comment: I checked the implementation of `self._get_table_arg`, but this way of filtering the batch does not seem error-proof. How do we ensure that exactly one batch exists in the args? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). Review Comment: It is not obvious to me why the index of the second child becomes 2. Is it because the pre-projection partition column has index 1?
########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. 
+ """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. 
+ Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + for table in terminate_result: + yield table + # Create new UDTF instance for new partition + self._udtf = self._create_udtf() + self._eval_raised_skip_rest_of_input_table = False + + # Slice the filtered batch for this partition + partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx) + + # Update the last partition key + self._last_partition_key = partition_key + + # Update filtered args to use the partition batch + partition_filtered_args = [] + for arg in filtered_args: + if isinstance(arg, pa.RecordBatch): + partition_filtered_args.append(partition_batch) + else: + partition_filtered_args.append(arg) + + partition_filtered_kwargs = {} + for key, value in filtered_kwargs.items(): + if isinstance(value, pa.RecordBatch): + partition_filtered_kwargs[key] = partition_batch + else: + partition_filtered_kwargs[key] = value Review Comment: It seems we again have a strong assumption that exactly one RecordBatch exists among all args. I wonder whether we can make this assumption more explicit and robust.
########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. 
+ """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) Review Comment: Do we guarantee that there will be exactly one filtered batch? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects.
+ + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = 
self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): Review Comment: Have we considered providing a no-op base `terminate()` and calling it unconditionally? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions.
+ + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in 
kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + for table in terminate_result: + yield table + # Create new UDTF instance for new partition + self._udtf = self._create_udtf() + self._eval_raised_skip_rest_of_input_table = False + + # Slice the filtered batch for this partition + partition_batch = filtered_batch.slice(start_idx, end_idx 
- start_idx) + + # Update the last partition key + self._last_partition_key = partition_key + + # Update filtered args to use the partition batch + partition_filtered_args = [] + for arg in filtered_args: + if isinstance(arg, pa.RecordBatch): + partition_filtered_args.append(partition_batch) + else: + partition_filtered_args.append(arg) + + partition_filtered_kwargs = {} + for key, value in filtered_kwargs.items(): + if isinstance(value, pa.RecordBatch): + partition_filtered_kwargs[key] = partition_batch + else: + partition_filtered_kwargs[key] = value + + # Call the UDTF with this partition's data + if not self._eval_raised_skip_rest_of_input_table: + try: + result = self._udtf.eval( + *partition_filtered_args, **partition_filtered_kwargs + ) + if result is not None: + for table in result: + yield table + except SkipRestOfInputTableException: + # Skip remaining rows in this partition + self._eval_raised_skip_rest_of_input_table = True + + # Don't terminate here - let the next batch or final terminate handle it + else: + # No partitions, process the entire batch as one group + try: + result = self._udtf.eval(*filtered_args, **filtered_kwargs) + if result is not None: + # result is an iterator of PyArrow Tables (for Arrow UDTFs) + for table in result: + yield table + except SkipRestOfInputTableException: + pass Review Comment: add a log entry for this case? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. 
+ + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + 
original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + for table in terminate_result: + yield table Review Comment: yield from? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. 
+ + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = 
self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + for table in terminate_result: + yield table + # Create new UDTF instance for new partition + self._udtf = self._create_udtf() + self._eval_raised_skip_rest_of_input_table = False + + # Slice the filtered batch for this partition + partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx) + + # Update the last partition key + self._last_partition_key = partition_key + + # 
Update filtered args to use the partition batch + partition_filtered_args = [] + for arg in filtered_args: + if isinstance(arg, pa.RecordBatch): + partition_filtered_args.append(partition_batch) + else: + partition_filtered_args.append(arg) + + partition_filtered_kwargs = {} + for key, value in filtered_kwargs.items(): + if isinstance(value, pa.RecordBatch): + partition_filtered_kwargs[key] = partition_batch + else: + partition_filtered_kwargs[key] = value + + # Call the UDTF with this partition's data + if not self._eval_raised_skip_rest_of_input_table: + try: + result = self._udtf.eval( + *partition_filtered_args, **partition_filtered_kwargs + ) + if result is not None: + for table in result: + yield table + except SkipRestOfInputTableException: + # Skip remaining rows in this partition + self._eval_raised_skip_rest_of_input_table = True + + # Don't terminate here - let the next batch or final terminate handle it + else: + # No partitions, process the entire batch as one group + try: + result = self._udtf.eval(*filtered_args, **filtered_kwargs) + if result is not None: + # result is an iterator of PyArrow Tables (for Arrow UDTFs) + for table in result: + yield table + except SkipRestOfInputTableException: + pass + + def terminate(self) -> Iterator: + if hasattr(self._udtf, "terminate"): + return self._udtf.terminate() + return iter(()) + + def cleanup(self) -> None: + if hasattr(self._udtf, "cleanup"): + self._udtf.cleanup() + + def _get_table_arg(self, inputs: list): + """Get the table argument (RecordBatch) from the inputs list. + + For Arrow UDTFs with TABLE arguments, we can guarantee the table argument + will be a pa.RecordBatch, not a Row. + """ + import pyarrow as pa + + # Find the RecordBatch in the arguments + for arg in inputs: + if isinstance(arg, pa.RecordBatch): + return arg + + # This shouldn't happen for Arrow UDTFs with TABLE arguments + return None Review Comment: this is quite weak. also there is no handling of such a case. 
########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. 
+ """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. 
Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + for table in terminate_result: + yield table + # Create new UDTF instance for new partition + self._udtf = self._create_udtf() + self._eval_raised_skip_rest_of_input_table = False + + # Slice the filtered batch for this partition + partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx) + + # Update the last partition key + self._last_partition_key = partition_key + + # Update filtered args to use the partition batch + partition_filtered_args = [] + for arg in filtered_args: + if isinstance(arg, pa.RecordBatch): + partition_filtered_args.append(partition_batch) + else: + partition_filtered_args.append(arg) + + partition_filtered_kwargs = {} + for key, value in filtered_kwargs.items(): + if isinstance(value, pa.RecordBatch): + partition_filtered_kwargs[key] = partition_batch + else: + partition_filtered_kwargs[key] = value + + # Call the UDTF with this partition's data + if not self._eval_raised_skip_rest_of_input_table: + try: + result = self._udtf.eval( + *partition_filtered_args, **partition_filtered_kwargs + ) + if result is not None: + for table in 
result: + yield table + except SkipRestOfInputTableException: + # Skip remaining rows in this partition + self._eval_raised_skip_rest_of_input_table = True + + # Don't terminate here - let the next batch or final terminate handle it + else: + # No partitions, process the entire batch as one group + try: + result = self._udtf.eval(*filtered_args, **filtered_kwargs) + if result is not None: + # result is an iterator of PyArrow Tables (for Arrow UDTFs) + for table in result: + yield table + except SkipRestOfInputTableException: + pass + + def terminate(self) -> Iterator: + if hasattr(self._udtf, "terminate"): + return self._udtf.terminate() + return iter(()) + + def cleanup(self) -> None: + if hasattr(self._udtf, "cleanup"): + self._udtf.cleanup() + + def _get_table_arg(self, inputs: list): + """Get the table argument (RecordBatch) from the inputs list. + + For Arrow UDTFs with TABLE arguments, we can guarantee the table argument + will be a pa.RecordBatch, not a Row. + """ + import pyarrow as pa + + # Find the RecordBatch in the arguments + for arg in inputs: + if isinstance(arg, pa.RecordBatch): + return arg + + # This shouldn't happen for Arrow UDTFs with TABLE arguments + return None + + def _detect_partition_boundaries(self, batch) -> list: + """ + Efficiently detect partition boundaries in a batch with contiguous partitions. + + Since Catalyst ensures rows with the same partition key are contiguous, + we only need to find where partition values change. Review Comment: Is the value here "row index"? because seems you also assume: 1. Discrete values; 2. ASEC order. Just want to check if this is the correct understanding! 
########## python/pyspark/sql/tests/arrow/test_arrow_udtf.py: ########## @@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]: expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int") assertDataFrameEqual(sql_result_df2, expected_df2) + def test_arrow_udtf_with_partition_by(self): + @arrow_udtf(returnType="partition_key int, sum_value int") + class SumUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + table = pa.table(table_data) + partition_key = pc.unique(table["partition_key"]).to_pylist() + assert ( + len(partition_key) == 1 + ), f"Expected exactly one partition key, got {partition_key}" + sum_value = pc.sum(table["value"]).as_py() + result_table = pa.table( + { + "partition_key": pa.array([partition_key[0]], type=pa.int32()), + "sum_value": pa.array([sum_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 10), + (2, 5), + (1, 20), + (2, 15), + (1, 30), + (3, 100), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("sum_udtf", SumUDTF) + input_df.createOrReplaceTempView("test_data") + + result_df = self.spark.sql( + """ + SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key) + """ + ) + + expected_data = [ + (1, 60), + (2, 20), + (3, 100), + ] + expected_df = self.spark.createDataFrame(expected_data, "partition_key int, sum_value int") + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_and_terminate(self): + @arrow_udtf(returnType="partition_key int, count int, sum_value int") + class TerminateUDTF: + def __init__(self): + self._partition_key = None + self._count = 0 + self._sum = 0 + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + # Track partition key + partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert 
len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}" + self._partition_key = partition_keys[0] + + # Accumulate stats but don't yield here + self._count += table.num_rows + self._sum += pc.sum(table["value"]).as_py() + # Return empty iterator - results come from terminate + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + # Yield accumulated results for this partition + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "count": pa.array([self._count], type=pa.int32()), + "sum_value": pa.array([self._sum], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (3, 50), + (1, 10), + (2, 40), + (1, 20), + (2, 30), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("terminate_udtf", TerminateUDTF) + input_df.createOrReplaceTempView("test_data_terminate") + + result_df = self.spark.sql( + """ + SELECT * FROM terminate_udtf(TABLE(test_data_terminate) PARTITION BY partition_key) + ORDER BY partition_key + """ + ) + + expected_data = [ + (1, 2, 30), # partition 1: 2 rows, sum = 30 + (2, 2, 70), # partition 2: 2 rows, sum = 70 + (3, 1, 50), # partition 3: 1 row, sum = 50 + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, count int, sum_value int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_and_order_by(self): + @arrow_udtf(returnType="partition_key int, first_value int, last_value int") + class OrderByUDTF: + def __init__(self): + self._partition_key = None + self._first_value = None + self._last_value = None + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert len(partition_keys) == 1, f"Expected one partition key, got 
{partition_keys}" + self._partition_key = partition_keys[0] + + # Track first and last values (should be ordered) + values = table["value"].to_pylist() + if values: + if self._first_value is None: + self._first_value = values[0] + self._last_value = values[-1] + + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "first_value": pa.array([self._first_value], type=pa.int32()), + "last_value": pa.array([self._last_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 30), + (1, 10), + (1, 20), + (2, 60), + (2, 40), + (2, 50), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("order_by_udtf", OrderByUDTF) + input_df.createOrReplaceTempView("test_data_order") + + result_df = self.spark.sql( + """ + SELECT * FROM order_by_udtf( + TABLE(test_data_order) + PARTITION BY partition_key + ORDER BY value + ) + ORDER BY partition_key + """ + ) + + expected_data = [ + (1, 10, 30), # partition 1: first=10 (min), last=30 (max) after ordering + (2, 40, 60), # partition 2: first=40 (min), last=60 (max) after ordering + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, first_value int, last_value int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_partition_column_removal(self): + @arrow_udtf(returnType="col1_sum int, col2_sum int") + class PartitionColumnTestUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + + # When partitioning by an expression like "col1 + col2", + # Catalyst adds the expression result as a new column at the beginning. + # The ArrowUDTFWithPartition._remove_partition_by_exprs method should + # remove this added column, leaving only the original table columns. 
+ column_names = table.column_names + + # Verify we only have the original columns, not the partition expression + assert "col1" in column_names, f"Expected col1 in columns: {column_names}" + assert "col2" in column_names, f"Expected col2 in columns: {column_names}" + # The partition expression column should have been removed + assert len(column_names) == 2, ( + f"Expected only col1 and col2 after partition column removal, " + f"but got: {column_names}" + ) + + col1_sum = pc.sum(table["col1"]).as_py() + col2_sum = pc.sum(table["col2"]).as_py() + + result_table = pa.table( + { + "col1_sum": pa.array([col1_sum], type=pa.int32()), + "col2_sum": pa.array([col2_sum], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 1), # partition: 1+1=2 + (1, 2), # partition: 1+2=3 + (2, 0), # partition: 2+0=2 + (2, 1), # partition: 2+1=3 + ] + input_df = self.spark.createDataFrame(test_data, "col1 int, col2 int") + + self.spark.udtf.register("partition_column_test_udtf", PartitionColumnTestUDTF) + input_df.createOrReplaceTempView("test_partition_removal") + + # Partition by col1 + col2 expression + result_df = self.spark.sql( + """ + SELECT * FROM partition_column_test_udtf( + TABLE(test_partition_removal) + PARTITION BY col1 + col2 + ) + ORDER BY col1_sum, col2_sum + """ + ) + + expected_data = [ + (3, 1), # partition 2: sum of col1s (1+2), sum of col2s (1+0) + (3, 3), # partition 3: sum of col1s (1+2), sum of col2s (2+1) + ] + expected_df = self.spark.createDataFrame(expected_data, "col1_sum int, col2_sum int") + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_partition_by_single_partition_multiple_input_partitions(self): + @arrow_udtf(returnType="partition_key int, count bigint, sum_value bigint") + class SinglePartitionUDTF: + def __init__(self): + self._partition_key = None + self._count = 0 + self._sum = 0 + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = 
pa.table(table_data) + + # All rows should have the same partition key (constant value 1) + partition_keys = pc.unique(table["partition_key"]).to_pylist() + self._partition_key = partition_keys[0] + self._count += table.num_rows + self._sum += pc.sum(table["id"]).as_py() + + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "count": pa.array([self._count], type=pa.int64()), + "sum_value": pa.array([self._sum], type=pa.int64()), + } + ) + yield result_table + + # Create DataFrame with 5 input partitions but all data will map to partition_key=1 + # range(1, 10, 1, 5) creates ids from 1 to 9 with 5 partitions + input_df = self.spark.range(1, 10, 1, 5).selectExpr( + "1 as partition_key", "id" # constant partition key + ) + + self.spark.udtf.register("single_partition_udtf", SinglePartitionUDTF) + input_df.createOrReplaceTempView("test_single_partition") + + result_df = self.spark.sql( + """ + SELECT * FROM single_partition_udtf( + TABLE(test_single_partition) + PARTITION BY partition_key + ) + """ + ) + + # All 9 rows (1 through 9) should be in a single partition with key=1 + expected_data = [(1, 9, 45)] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, count bigint, sum_value bigint" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_skip_rest_of_input(self): + from pyspark.sql.functions import SkipRestOfInputTableException + + @arrow_udtf(returnType="partition_key int, rows_processed int, last_value int") + class SkipRestUDTF: + def __init__(self): + self._partition_key = None + self._rows_processed = 0 + self._last_value = None + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert 
len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}" + self._partition_key = partition_keys[0] + + # Process rows one by one and stop after processing 2 rows per partition + values = table["value"].to_pylist() + for value in values: + self._rows_processed += 1 + self._last_value = value + + # Skip rest of the partition after processing 2 rows + if self._rows_processed >= 2: + msg = f"Skipping partition {self._partition_key} " + msg += f"after {self._rows_processed} rows" + raise SkipRestOfInputTableException(msg) Review Comment: I feel this pattern is not ideal for users to write. I wonder if we can instead use a plain `return` inside the generator (i.e. ending `eval` early) for this purpose. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org