ueshin commented on code in PR #52317: URL: https://github.com/apache/spark/pull/52317#discussion_r2356436933
########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,293 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. 
+ """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. 
Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: Review Comment: Do we need this check? This class seems to be used only when `len(partition_child_indexes) > 0`? ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,293 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke. + partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. 
+ """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. + + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. 
+ Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) Review Comment: Isn't this check only necessary for the first boundary? The following boundaries always start new partitions. Maybe revisit this later, as it is more of an optimization. ########## python/pyspark/worker.py: ########## @@ -1514,10 +1514,293 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any: else: return arg + class ArrowUDTFWithPartition: + """ + Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument + with one or more PARTITION BY expressions. + + Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row + objects. + + Example table: + CREATE TABLE t (c1 INT, c2 INT) USING delta; + + Example queries: + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2); + partition_child_indexes: 0, 1. + + SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4); + partition_child_indexes: 0, 2 (adds a projection for "c2 + 4"). + """ + + def __init__(self, create_udtf: Callable, partition_child_indexes: list): + """ + Create a new instance that wraps the provided Arrow UDTF with partitioning + logic. + + Parameters + ---------- + create_udtf: function + Function that creates a new instance of the Arrow UDTF to invoke.
+ partition_child_indexes: list + Zero-based indexes of input-table columns that contain projected + partitioning expressions. + """ + self._create_udtf: Callable = create_udtf + self._udtf = create_udtf() + self._partition_child_indexes: list = partition_child_indexes + # Track last partition key from previous batch + self._last_partition_key: Optional[Tuple[Any, ...]] = None + self._eval_raised_skip_rest_of_input_table: bool = False + + def eval(self, *args, **kwargs) -> Iterator: + """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects.""" + import pyarrow as pa + + # Get the original batch with partition columns + original_batch = self._get_table_arg(list(args) + list(kwargs.values())) + if not isinstance(original_batch, pa.RecordBatch): + # Arrow UDTFs with PARTITION BY must have a TABLE argument that + # results in a PyArrow RecordBatch + raise PySparkRuntimeError( + errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT", + messageParameters={ + "actual_type": str(type(original_batch)) + if original_batch is not None + else "None" + }, + ) + + # Remove partition columns to get the filtered arguments + filtered_args = [self._remove_partition_by_exprs(arg) for arg in args] + filtered_kwargs = { + key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items() + } + + # Get the filtered RecordBatch (without partition columns) + filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values())) + + # Process the RecordBatch by partitions + yield from self._process_arrow_batch_by_partitions( + original_batch, filtered_batch, filtered_args, filtered_kwargs + ) + + def _process_arrow_batch_by_partitions( + self, original_batch, filtered_batch, filtered_args, filtered_kwargs + ) -> Iterator: + """Process an Arrow RecordBatch by splitting it into partitions. + + Since Catalyst guarantees that rows with the same partition key are contiguous, + we can use efficient boundary detection instead of group_by. 
+ + Handles two scenarios: + 1. Multiple partitions within a single RecordBatch (using boundary detection) + 2. Same partition key continuing from previous RecordBatch (tracking state) + """ + import pyarrow as pa + + if self._partition_child_indexes: + # Detect partition boundaries. + boundaries = self._detect_partition_boundaries(original_batch) + + # Process each contiguous partition + for i in range(len(boundaries) - 1): + start_idx = boundaries[i] + end_idx = boundaries[i + 1] + + # Get the partition key for this segment + partition_key = tuple( + original_batch.column(idx)[start_idx].as_py() + for idx in self._partition_child_indexes + ) + + # Check if this is a continuation of the previous batch's partition + is_new_partition = ( + self._last_partition_key is not None + and partition_key != self._last_partition_key + ) + + if is_new_partition: + # Previous partition ended, call terminate + if hasattr(self._udtf, "terminate"): + terminate_result = self._udtf.terminate() + if terminate_result is not None: + yield from terminate_result + # Create new UDTF instance for new partition + self._udtf = self._create_udtf() + self._eval_raised_skip_rest_of_input_table = False + + # Slice the filtered batch for this partition + partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx) + + # Update the last partition key + self._last_partition_key = partition_key + + # Update filtered args to use the partition batch + partition_filtered_args = [] + for arg in filtered_args: + if isinstance(arg, pa.RecordBatch): + partition_filtered_args.append(partition_batch) + else: + partition_filtered_args.append(arg) + + partition_filtered_kwargs = {} + for key, value in filtered_kwargs.items(): + if isinstance(value, pa.RecordBatch): + partition_filtered_kwargs[key] = partition_batch + else: + partition_filtered_kwargs[key] = value + + # Call the UDTF with this partition's data + if not self._eval_raised_skip_rest_of_input_table: + try: + result = self._udtf.eval( + 
*partition_filtered_args, **partition_filtered_kwargs + ) + if result is not None: + yield from result + except SkipRestOfInputTableException: + # Skip remaining rows in this partition + self._eval_raised_skip_rest_of_input_table = True + + # Don't terminate here - let the next batch or final terminate handle it + else: + # No partitions, process the entire batch as one group + try: + result = self._udtf.eval(*filtered_args, **filtered_kwargs) + if result is not None: + # result is an iterator of PyArrow Tables (for Arrow UDTFs) + yield from result + except SkipRestOfInputTableException: + pass + + def terminate(self) -> Iterator: + if hasattr(self._udtf, "terminate"): + return self._udtf.terminate() + return iter(()) + + def cleanup(self) -> None: + if hasattr(self._udtf, "cleanup"): + self._udtf.cleanup() + + def _get_table_arg(self, inputs: list): + """Get the table argument (RecordBatch) from the inputs list. + + For Arrow UDTFs with TABLE arguments, we can guarantee the table argument + will be a pa.RecordBatch, not a Row. + """ + import pyarrow as pa + + # Find all RecordBatch arguments + batches = [arg for arg in inputs if isinstance(arg, pa.RecordBatch)] + + if len(batches) == 0: + # No RecordBatch found - this shouldn't happen for Arrow UDTFs with TABLE arguments + return None + elif len(batches) == 1: + return batches[0] + else: + # Multiple RecordBatch arguments found - this is unexpected + raise RuntimeError( + f"Expected exactly one pa.RecordBatch argument for TABLE parameter, " + f"but found {len(batches)}. Received types: " + f"{[type(arg).__name__ for arg in inputs]}" + ) + + def _detect_partition_boundaries(self, batch) -> list: + """ + Efficiently detect partition boundaries in a batch with contiguous partitions. + + Since Catalyst ensures rows with the same partition key are contiguous, + we only need to find where partition values change. + + Returns: + List of indices where each partition starts, plus the total row count. 
+ For example: [0, 3, 8, 10] means partitions are rows [0:3), [3:8), [8:10) + """ + boundaries = [0] # First partition starts at index 0 + + if batch.num_rows <= 1 or not self._partition_child_indexes: Review Comment: Ditto: do we need the `self._partition_child_indexes` check here as well? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org