ueshin commented on code in PR #52140: URL: https://github.com/apache/spark/pull/52140#discussion_r2317309319
########## python/pyspark/sql/pandas/serializers.py: ########## @@ -227,6 +227,60 @@ def load_stream(self, stream): result_batches.append(batch.column(i)) yield result_batches + def dump_stream(self, iterator, stream): + """ + Override to handle type coercion for ArrowUDTF outputs. + ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples. + """ + import pyarrow as pa + + def apply_type_coercion(): + for batch, arrow_return_type in iterator: + assert isinstance( + arrow_return_type, pa.StructType + ), f"Expected pa.StructType, got {type(arrow_return_type)}" + + # Handle empty struct case specially + if batch.num_columns == 0: + # When batch has no column, it should still create + # an empty batch with the number of rows set. + struct = pa.array([{}] * batch.num_rows) + coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"]) + else: + target_schema = pa.schema(arrow_return_type) + try: + coerced_batch = batch.cast(target_schema) + except (pa.ArrowInvalid, pa.ArrowTypeError) as e: + from pyspark.errors import PySparkTypeError + + # Extract specific type mismatch information + expected_type = None + actual_type = None + + # Find the first mismatched field for better error message + if len(target_schema) == len(batch.schema): + for expected_field, actual_field in zip(target_schema, batch.schema): + if expected_field.type != actual_field.type: + expected_type = expected_field.type + actual_type = actual_field.type + break + + if expected_type and actual_type: + error_msg = f"Expected: {expected_type}, but got: {actual_type} in field '{expected_field.name}'." + else: + error_msg = f"Expected: {target_schema}, but got: {batch.schema}." + + raise PySparkTypeError( + "Arrow UDTFs require the return type to match the expected Arrow type." 
+ + error_msg + ) from e + + yield coerced_batch, arrow_return_type + + return super(ArrowStreamArrowUDTFSerializer, self).dump_stream( Review Comment: nit: this can be simplified to the zero-argument form, `super().dump_stream(...)`, instead of `super(ArrowStreamArrowUDTFSerializer, self).dump_stream(...)`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org