shujingyang-db commented on code in PR #52140:
URL: https://github.com/apache/spark/pull/52140#discussion_r2317239489


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -227,6 +246,58 @@ def load_stream(self, stream):
                     result_batches.append(batch.column(i))
             yield result_batches
 
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for ArrowUDTF outputs.
+        ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) 
tuples.
+        """
+        import pyarrow as pa
+        from pyspark.serializers import write_int, SpecialLengths
+
+        def wrap_and_init_stream():
+            should_write_start_length = True
+            for packed in iterator:
+                batch, arrow_return_type = packed
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle empty struct case specially
+                if batch.num_columns == 0:
+                    # When batch has no column, it should still create
+                    # an empty batch with the number of rows set.
+                    struct = pa.array([{}] * batch.num_rows)
+                    coerced_batch = pa.RecordBatch.from_arrays([struct], 
["_0"])
+                else:
+                    # Apply type coercion to each column if needed
+                    coerced_arrays = []
+                    for i, field in enumerate(arrow_return_type):
+                        if i < batch.num_columns:
+                            original_array = batch.column(i)
+                            coerced_array = self._create_array(original_array, 
field.type)
+                            coerced_arrays.append(coerced_array)
+                        else:
+                            raise PySparkRuntimeError(
+                                errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
+                                messageParameters={
+                                    "expected": str(len(arrow_return_type)),
+                                    "actual": str(batch.num_columns),
+                                    "func": "ArrowUDTF",
+                                },
+                            )
+
+                    struct = pa.StructArray.from_arrays(coerced_arrays, 
fields=arrow_return_type)
+                    coerced_batch = pa.RecordBatch.from_arrays([struct], 
["_0"])
+
+                # Write the first record batch with initialization
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+
+                yield coerced_batch

Review Comment:
   Makes sense! I changed it to `RecordBatch.cast`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to