allisonwang-db commented on code in PR #52317:
URL: https://github.com/apache/spark/pull/52317#discussion_r2356668054
##########
python/pyspark/worker.py:
##########
@@ -1514,10 +1514,293 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:
             else:
                 return arg
 
+    class ArrowUDTFWithPartition:
+        """
+        Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument
+        with one or more PARTITION BY expressions.
+
+        Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row
+        objects.
+
+        Example table:
+            CREATE TABLE t (c1 INT, c2 INT) USING delta;
+
+        Example queries:
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
+            partition_child_indexes: 0, 1.
+
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
+            partition_child_indexes: 0, 2 (adds a projection for "c2 + 4").
+        """
+
+        def __init__(self, create_udtf: Callable, partition_child_indexes: list):
+            """
+            Create a new instance that wraps the provided Arrow UDTF with partitioning
+            logic.
+
+            Parameters
+            ----------
+            create_udtf: function
+                Function that creates a new instance of the Arrow UDTF to invoke.
+            partition_child_indexes: list
+                Zero-based indexes of input-table columns that contain projected
+                partitioning expressions.
+            """
+            self._create_udtf: Callable = create_udtf
+            self._udtf = create_udtf()
+            self._partition_child_indexes: list = partition_child_indexes
+            # Track the last partition key seen in the previous batch.
+            self._last_partition_key: Optional[Tuple[Any, ...]] = None
+            self._eval_raised_skip_rest_of_input_table: bool = False
+
+        def eval(self, *args, **kwargs) -> Iterator:
+            """Handle partitioning logic for Arrow UDTFs that receive RecordBatch objects."""
+            import pyarrow as pa
+
+            # Get the original batch, including the partition columns.
+            original_batch = self._get_table_arg(list(args) + list(kwargs.values()))
+            if not isinstance(original_batch, pa.RecordBatch):
+                # Arrow UDTFs with PARTITION BY must have a TABLE argument that
+                # results in a PyArrow RecordBatch.
+                raise PySparkRuntimeError(
+                    errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT",
+                    messageParameters={
+                        "actual_type": str(type(original_batch))
+                        if original_batch is not None
+                        else "None"
+                    },
+                )
+
+            # Remove partition columns to get the filtered arguments.
+            filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
+            filtered_kwargs = {
+                key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
+            }
+
+            # Get the filtered RecordBatch (without partition columns).
+            filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values()))
+
+            # Process the RecordBatch partition by partition.
+            yield from self._process_arrow_batch_by_partitions(
+                original_batch, filtered_batch, filtered_args, filtered_kwargs
+            )
+
+        def _process_arrow_batch_by_partitions(
+            self, original_batch, filtered_batch, filtered_args, filtered_kwargs
+        ) -> Iterator:
+            """Process an Arrow RecordBatch by splitting it into partitions.
+
+            Since Catalyst guarantees that rows with the same partition key are contiguous,
+            we can use efficient boundary detection instead of group_by.
+
+            Handles two scenarios:
+            1. Multiple partitions within a single RecordBatch (using boundary detection)
+            2. The same partition key continuing from the previous RecordBatch (tracking state)
+            """
+            import pyarrow as pa
+
+            if self._partition_child_indexes:

Review Comment:
   Makes sense.
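For context, the "boundary detection" approach described in the
_process_arrow_batch_by_partitions docstring can be sketched in a few lines of
standalone PyArrow. This is an illustrative sketch only, not the PR's actual
implementation; the helper name split_batch_by_partition and the toy data are
hypothetical:

import pyarrow as pa

def split_batch_by_partition(batch: pa.RecordBatch, key_indexes: list):
    """Yield (partition_key, sub_batch) pairs from a batch whose rows are
    already grouped so that equal partition keys sit in contiguous runs."""
    # Materialize the key columns as Python tuples; a production version
    # would compare Arrow arrays directly instead of converting per row.
    keys = [
        tuple(batch.column(i)[row].as_py() for i in key_indexes)
        for row in range(batch.num_rows)
    ]
    start = 0
    for row in range(1, batch.num_rows + 1):
        # A boundary is the end of the batch or a change in the key tuple.
        if row == batch.num_rows or keys[row] != keys[start]:
            yield keys[start], batch.slice(start, row - start)
            start = row

# Rows arrive contiguous by c1, as Catalyst guarantees for PARTITION BY c1.
batch = pa.RecordBatch.from_pydict({"c1": [1, 1, 2], "c2": [10, 20, 30]})
for key, sub in split_batch_by_partition(batch, [0]):
    print(key, sub.num_rows)  # (1,) 2, then (2,) 1

A single O(n) pass over the pre-sorted keys replaces a hash-based group_by,
which is why the docstring calls boundary detection the cheaper option for
contiguous input.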
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org