allisonwang-db commented on code in PR #52140:
URL: https://github.com/apache/spark/pull/52140#discussion_r2323951651
##########
python/pyspark/errors/error-conditions.json:
##########
@@ -967,6 +967,11 @@
       "Column names of the returned pyarrow.Table do not match specified schema.<missing><extra>"
     ]
   },
+  "RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF": {
+    "message": [
+      "Column names of the returned pyarrow.Table do not match specified schema. Expected: <expected> Actual: <actual>"

Review Comment:
   This is not necessarily a pyarrow.Table (it can also be a columnar batch). How about we just say "Column names of the returned table do not match ..."?

##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -227,6 +227,63 @@ def load_stream(self, stream):
                 result_batches.append(batch.column(i))
             yield result_batches

+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+        if arr.type == arrow_type:
+            return arr
+        else:
+            try:
+                # When safe is True, the cast will fail if there's an overflow or
+                # other unsafe conversion.
+                return arr.cast(target_type=arrow_type, safe=True)
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
+                raise PySparkTypeError(
+                    "Arrow UDTFs require the return type to match the expected Arrow type. "
+                    f"Expected: {arrow_type}, but got: {arr.type}."
+                )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for Arrow UDTF outputs.
+        An Arrow UDTF returns an iterator of (pa.RecordBatch, arrow_return_type) tuples.
+        """
+        import pyarrow as pa
+
+        def apply_type_coercion():
+            for batch, arrow_return_type in iterator:
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle the empty-struct case specially.
+                if batch.num_columns == 0:
+                    coerced_batch = batch  # skip type coercion
+                else:
+                    expected_field_names = arrow_return_type.names
+                    actual_field_names = batch.schema.names
+
+                    if expected_field_names != actual_field_names:
+                        raise PySparkTypeError(
+                            "Target schema's field names do not match the record batch's "
+                            "field names. "
+                            f"Expected: {expected_field_names}, but got: {actual_field_names}."
+                        )
+
+                    coerced_arrays = []
+                    for i, field in enumerate(arrow_return_type):
+                        original_array = batch.column(i)
+                        coerced_array = self._create_array(original_array, field.type)
+                        coerced_arrays.append(coerced_array)

Review Comment:
   Got it. Can we add a comment here mentioning why we don't use RecordBatch.cast and what minimum pyarrow version supports it?
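   For reference, a rough, untested sketch contrasting the two approaches: the
   per-column Array.cast this PR uses versus a single whole-batch cast. The
   schema and data below are invented for illustration, and my recollection is
   that RecordBatch.cast only landed in pyarrow 16.0.0 (Array.cast has been
   around much longer), so the exact version is worth double-checking against
   the release notes:

       import pyarrow as pa

       # A batch whose column type doesn't exactly match the target schema.
       batch = pa.RecordBatch.from_arrays(
           [pa.array([1, 2, 3], type=pa.int32())], names=["x"]
       )
       target = pa.schema([pa.field("x", pa.int64())])

       # Per-column cast, as in this PR: works on older pyarrow versions.
       cols = [
           batch.column(i).cast(target_type=field.type, safe=True)
           for i, field in enumerate(target)
       ]
       coerced = pa.RecordBatch.from_arrays(cols, schema=target)
       assert coerced.schema == target

       # Whole-batch cast: one call, but only available on newer pyarrow
       # (16.0.0, if memory serves).
       if hasattr(pa.RecordBatch, "cast"):
           coerced_too = batch.cast(target, safe=True)
           assert coerced_too.schema == target

   If Spark's minimum supported pyarrow is new enough, the single batch.cast
   call would simplify this code path; if not, a short comment on the
   per-column approach (as requested above) seems like the right call.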