ueshin commented on code in PR #52317: URL: https://github.com/apache/spark/pull/52317#discussion_r2345583649
########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. 
+ """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. 
Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + for table in terminate_result: + yield table + # Create new UDTF instance for new partition + self._udtf = self._create_udtf() + self._eval_raised_skip_rest_of_input_table = False + + # Slice the filtered batch for this partition + partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx) + + # Update the last partition key + self._last_partition_key = partition_key + + # Update filtered args to use the partition batch + partition_filtered_args = [] + for arg in filtered_args: + if isinstance(arg, pa.RecordBatch): + partition_filtered_args.append(partition_batch) + else: + partition_filtered_args.append(arg) + + partition_filtered_kwargs = {} + for key, value in filtered_kwargs.items(): + if isinstance(value, pa.RecordBatch): + partition_filtered_kwargs[key] = partition_batch + else: + partition_filtered_kwargs[key] = value + + # Call the UDTF with this partition's data + if not self._eval_raised_skip_rest_of_input_table: + try: + result = self._udtf.eval( + *partition_filtered_args, **partition_filtered_kwargs + ) + if result is not None: + for table in 
result: + yield table Review Comment: nit ```suggestion yield from result ``` ########## python/pyspark/sql/tests/arrow/test_arrow_udtf.py: ########## @@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]: expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int") assertDataFrameEqual(sql_result_df2, expected_df2) + def test_arrow_udtf_with_partition_by(self): + @arrow_udtf(returnType="partition_key int, sum_value int") + class SumUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + table = pa.table(table_data) + partition_key = pc.unique(table["partition_key"]).to_pylist() + assert ( + len(partition_key) == 1 + ), f"Expected exactly one partition key, got {partition_key}" + sum_value = pc.sum(table["value"]).as_py() + result_table = pa.table( + { + "partition_key": pa.array([partition_key[0]], type=pa.int32()), + "sum_value": pa.array([sum_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 10), + (2, 5), + (1, 20), + (2, 15), + (1, 30), + (3, 100), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("sum_udtf", SumUDTF) + input_df.createOrReplaceTempView("test_data") + + result_df = self.spark.sql( + """ + SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key) + """ + ) + + expected_data = [ + (1, 60), + (2, 20), + (3, 100), + ] + expected_df = self.spark.createDataFrame(expected_data, "partition_key int, sum_value int") + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_and_terminate(self): + @arrow_udtf(returnType="partition_key int, count int, sum_value int") + class TerminateUDTF: + def __init__(self): + self._partition_key = None + self._count = 0 + self._sum = 0 + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + # Track partition key + 
partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}" + self._partition_key = partition_keys[0] + + # Accumulate stats but don't yield here + self._count += table.num_rows + self._sum += pc.sum(table["value"]).as_py() + # Return empty iterator - results come from terminate + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + # Yield accumulated results for this partition + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "count": pa.array([self._count], type=pa.int32()), + "sum_value": pa.array([self._sum], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (3, 50), + (1, 10), + (2, 40), + (1, 20), + (2, 30), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("terminate_udtf", TerminateUDTF) + input_df.createOrReplaceTempView("test_data_terminate") + + result_df = self.spark.sql( + """ + SELECT * FROM terminate_udtf(TABLE(test_data_terminate) PARTITION BY partition_key) + ORDER BY partition_key + """ + ) + + expected_data = [ + (1, 2, 30), # partition 1: 2 rows, sum = 30 + (2, 2, 70), # partition 2: 2 rows, sum = 70 + (3, 1, 50), # partition 3: 1 row, sum = 50 + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, count int, sum_value int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_and_order_by(self): + @arrow_udtf(returnType="partition_key int, first_value int, last_value int") + class OrderByUDTF: + def __init__(self): + self._partition_key = None + self._first_value = None + self._last_value = None + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert 
len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}" + self._partition_key = partition_keys[0] + + # Track first and last values (should be ordered) + values = table["value"].to_pylist() + if values: + if self._first_value is None: + self._first_value = values[0] + self._last_value = values[-1] + + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "first_value": pa.array([self._first_value], type=pa.int32()), + "last_value": pa.array([self._last_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 30), + (1, 10), + (1, 20), + (2, 60), + (2, 40), + (2, 50), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("order_by_udtf", OrderByUDTF) + input_df.createOrReplaceTempView("test_data_order") + + result_df = self.spark.sql( + """ + SELECT * FROM order_by_udtf( + TABLE(test_data_order) + PARTITION BY partition_key + ORDER BY value + ) + ORDER BY partition_key + """ + ) + + expected_data = [ + (1, 10, 30), # partition 1: first=10 (min), last=30 (max) after ordering + (2, 40, 60), # partition 2: first=40 (min), last=60 (max) after ordering + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, first_value int, last_value int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_partition_column_removal(self): + @arrow_udtf(returnType="col1_sum int, col2_sum int") + class PartitionColumnTestUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + + # When partitioning by an expression like "col1 + col2", + # Catalyst adds the expression result as a new column at the beginning. 
+ # The ArrowUDTFWithPartition._remove_partition_by_exprs method should + # remove this added column, leaving only the original table columns. + column_names = table.column_names + + # Verify we only have the original columns, not the partition expression + assert "col1" in column_names, f"Expected col1 in columns: {column_names}" + assert "col2" in column_names, f"Expected col2 in columns: {column_names}" + # The partition expression column should have been removed + assert len(column_names) == 2, ( + f"Expected only col1 and col2 after partition column removal, " + f"but got: {column_names}" + ) + + col1_sum = pc.sum(table["col1"]).as_py() + col2_sum = pc.sum(table["col2"]).as_py() + + result_table = pa.table( + { + "col1_sum": pa.array([col1_sum], type=pa.int32()), + "col2_sum": pa.array([col2_sum], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 1), # partition: 1+1=2 + (1, 2), # partition: 1+2=3 + (2, 0), # partition: 2+0=2 + (2, 1), # partition: 2+1=3 + ] + input_df = self.spark.createDataFrame(test_data, "col1 int, col2 int") + + self.spark.udtf.register("partition_column_test_udtf", PartitionColumnTestUDTF) + input_df.createOrReplaceTempView("test_partition_removal") + + # Partition by col1 + col2 expression + result_df = self.spark.sql( + """ + SELECT * FROM partition_column_test_udtf( + TABLE(test_partition_removal) + PARTITION BY col1 + col2 + ) + ORDER BY col1_sum, col2_sum + """ + ) + + expected_data = [ + (3, 1), # partition 2: sum of col1s (1+2), sum of col2s (1+0) + (3, 3), # partition 3: sum of col1s (1+2), sum of col2s (2+1) + ] + expected_df = self.spark.createDataFrame(expected_data, "col1_sum int, col2_sum int") + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_partition_by_single_partition_multiple_input_partitions(self): + @arrow_udtf(returnType="partition_key int, count bigint, sum_value bigint") + class SinglePartitionUDTF: + def __init__(self): + self._partition_key = None + self._count = 0 + 
self._sum = 0 + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + + # All rows should have the same partition key (constant value 1) + partition_keys = pc.unique(table["partition_key"]).to_pylist() + self._partition_key = partition_keys[0] + self._count += table.num_rows + self._sum += pc.sum(table["id"]).as_py() + + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "count": pa.array([self._count], type=pa.int64()), + "sum_value": pa.array([self._sum], type=pa.int64()), + } + ) + yield result_table + + # Create DataFrame with 5 input partitions but all data will map to partition_key=1 + # range(1, 10, 1, 5) creates ids from 1 to 9 with 5 partitions + input_df = self.spark.range(1, 10, 1, 5).selectExpr( + "1 as partition_key", "id" # constant partition key + ) + + self.spark.udtf.register("single_partition_udtf", SinglePartitionUDTF) + input_df.createOrReplaceTempView("test_single_partition") + + result_df = self.spark.sql( + """ + SELECT * FROM single_partition_udtf( + TABLE(test_single_partition) + PARTITION BY partition_key + ) + """ + ) + + # All 9 rows (1 through 9) should be in a single partition with key=1 + expected_data = [(1, 9, 45)] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, count bigint, sum_value bigint" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_skip_rest_of_input(self): + from pyspark.sql.functions import SkipRestOfInputTableException + + @arrow_udtf(returnType="partition_key int, rows_processed int, last_value int") + class SkipRestUDTF: + def __init__(self): + self._partition_key = None + self._rows_processed = 0 + self._last_value = None + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import 
pyarrow.compute as pc + + table = pa.table(table_data) + partition_keys = pc.unique(table["partition_key"]).to_pylist() + assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}" + self._partition_key = partition_keys[0] + + # Process rows one by one and stop after processing 2 rows per partition + values = table["value"].to_pylist() + for value in values: + self._rows_processed += 1 + self._last_value = value + + # Skip rest of the partition after processing 2 rows + if self._rows_processed >= 2: + msg = f"Skipping partition {self._partition_key} " + msg += f"after {self._rows_processed} rows" + raise SkipRestOfInputTableException(msg) + + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + if self._partition_key is not None: + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "rows_processed": pa.array([self._rows_processed], type=pa.int32()), + "last_value": pa.array([self._last_value], type=pa.int32()), + } + ) + yield result_table + + # Create test data with multiple partitions, each having more than 2 rows + test_data = [ + (1, 10), + (1, 20), + (1, 30), # This should be skipped + (1, 40), # This should be skipped + (2, 50), + (2, 60), + (2, 70), # This should be skipped + (3, 80), + (3, 90), + (3, 100), # This should be skipped + (3, 110), # This should be skipped + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("skip_rest_udtf", SkipRestUDTF) + input_df.createOrReplaceTempView("test_skip_rest") + + result_df = self.spark.sql( + """ + SELECT * FROM skip_rest_udtf( + TABLE(test_skip_rest) + PARTITION BY partition_key + ORDER BY value + ) + ORDER BY partition_key + """ + ) + + # Each partition should only process 2 rows before skipping the rest + expected_data = [ + (1, 2, 20), # Processed rows 10, 20, then skipped 30, 40 + (2, 2, 60), # Processed rows 50, 60, then skipped 70 + (3, 2, 90), # Processed 
rows 80, 90, then skipped 100, 110 + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, rows_processed int, last_value int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_partition_by_null_values(self): + @arrow_udtf(returnType="partition_key int, count int, non_null_sum int") + class NullPartitionUDTF: + def __init__(self): + self._partition_key = None + self._count = 0 + self._non_null_sum = 0 + + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow.compute as pc + + table = pa.table(table_data) + # Handle null partition keys + partition_keys = table["partition_key"] + unique_keys = pc.unique(partition_keys).to_pylist() + + # Should have exactly one unique value (either a value or None) + assert len(unique_keys) == 1, f"Expected one partition key, got {unique_keys}" + self._partition_key = unique_keys[0] + + # Count rows and sum non-null values + self._count += table.num_rows + values = table["value"] + # Use PyArrow compute to handle nulls properly + non_null_values = pc.drop_null(values) + if len(non_null_values) > 0: + self._non_null_sum += pc.sum(non_null_values).as_py() + + return iter(()) + + def terminate(self) -> Iterator["pa.Table"]: + # Return results even for null partition keys + result_table = pa.table( + { + "partition_key": pa.array([self._partition_key], type=pa.int32()), + "count": pa.array([self._count], type=pa.int32()), + "non_null_sum": pa.array([self._non_null_sum], type=pa.int32()), + } + ) + yield result_table + + # Test data with null partition keys and null values + test_data = [ + (1, 10), + (1, None), # null value in partition 1 + (None, 20), # null partition key + (None, 30), # null partition key + (2, 40), + (2, None), # null value in partition 2 + (None, None), # both null + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("null_partition_udtf", NullPartitionUDTF) + 
input_df.createOrReplaceTempView("test_null_partitions") + + result_df = self.spark.sql( + """ + SELECT * FROM null_partition_udtf( + TABLE(test_null_partitions) + PARTITION BY partition_key + ORDER BY value + ) + ORDER BY partition_key NULLS FIRST + """ + ) + + # Expected: null partition gets grouped together, nulls in values are handled + expected_data = [ + (None, 3, 50), # null partition: 3 rows, sum of non-null values = 20+30 = 50 + (1, 2, 10), # partition 1: 2 rows, sum of non-null values = 10 + (2, 2, 40), # partition 2: 2 rows, sum of non-null values = 40 + ] + expected_df = self.spark.createDataFrame( + expected_data, "partition_key int, count int, non_null_sum int" + ) + assertDataFrameEqual(result_df, expected_df) + + def test_arrow_udtf_with_empty_table(self): + @arrow_udtf(returnType="result string") + class EmptyTableUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + import pyarrow as pa + + # This should not be called for empty tables Review Comment: How about raising an error to check it won't fail if this should not be called? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. 
+ + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. 
+ + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + for table in terminate_result: + yield table + # Create new UDTF instance for new partition + self._udtf = self._create_udtf() + self._eval_raised_skip_rest_of_input_table = False + + # Slice the filtered batch for this partition + partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx) + + # Update the last partition key + self._last_partition_key = partition_key + + # Update filtered args to use the partition batch + partition_filtered_args = [] + for arg in filtered_args: + if isinstance(arg, pa.RecordBatch): + partition_filtered_args.append(partition_batch) + else: + partition_filtered_args.append(arg) + + partition_filtered_kwargs = {} + for key, value in filtered_kwargs.items(): + if isinstance(value, pa.RecordBatch): + partition_filtered_kwargs[key] = partition_batch + else: + 
partition_filtered_kwargs[key] = value + + # Call the UDTF with this partition's data + if not self._eval_raised_skip_rest_of_input_table: + try: + result = self._udtf.eval( + *partition_filtered_args, **partition_filtered_kwargs + ) + if result is not None: + for table in result: + yield table + except SkipRestOfInputTableException: + # Skip remaining rows in this partition + self._eval_raised_skip_rest_of_input_table = True + + # Don't terminate here - let the next batch or final terminate handle it + else: + # No partitions, process the entire batch as one group + try: + result = self._udtf.eval(*filtered_args, **filtered_kwargs) + if result is not None: + # result is an iterator of PyArrow Tables (for Arrow UDTFs) + for table in result: + yield table Review Comment: ditto. ```suggestion yield from result ``` ########## python/pyspark/sql/tests/arrow/test_arrow_udtf.py: ########## @@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]: expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int") assertDataFrameEqual(sql_result_df2, expected_df2) + def test_arrow_udtf_with_partition_by(self): + @arrow_udtf(returnType="partition_key int, sum_value int") + class SumUDTF: + def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]: + table = pa.table(table_data) + partition_key = pc.unique(table["partition_key"]).to_pylist() + assert ( + len(partition_key) == 1 + ), f"Expected exactly one partition key, got {partition_key}" + sum_value = pc.sum(table["value"]).as_py() + result_table = pa.table( + { + "partition_key": pa.array([partition_key[0]], type=pa.int32()), + "sum_value": pa.array([sum_value], type=pa.int32()), + } + ) + yield result_table + + test_data = [ + (1, 10), + (2, 5), + (1, 20), + (2, 15), + (1, 30), + (3, 100), + ] + input_df = self.spark.createDataFrame(test_data, "partition_key int, value int") + + self.spark.udtf.register("sum_udtf", SumUDTF) + 
input_df.createOrReplaceTempView("test_data") + + result_df = self.spark.sql( + """ + SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key) + """ + ) + + expected_data = [ + (1, 60), + (2, 20), + (3, 100), + ] Review Comment: Could this test be flaky? IIUC, `eval` can be called multiple times per partition, so this implementation can yield multiple tables per partition, and the expected one-row-per-key result would no longer hold. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org