allisonwang-db commented on code in PR #52140:
URL: https://github.com/apache/spark/pull/52140#discussion_r2323951651


##########
python/pyspark/errors/error-conditions.json:
##########
@@ -967,6 +967,11 @@
       "Column names of the returned pyarrow.Table do not match specified 
schema.<missing><extra>"
     ]
   },
+  "RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF": {
+    "message": [
+      "Column names of the returned pyarrow.Table do not match specified 
schema. Expected: <expected> Actual: <actual>"

Review Comment:
   This is not necessarily a pyarrow.Table (it can be a columnar batch). How about we just say "Column names of the returned table do not match ..."
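
   For illustration only (not part of the review or the PR): both of the following are valid columnar results, which is why the message shouldn't hard-code `pyarrow.Table`.

```python
import pyarrow as pa

# Illustrative sketch: the returned data may be a pa.Table or a pa.RecordBatch
# ("columnar batch"), so the error text is better phrased as "returned table".
data = {"id": pa.array([1, 2, 3], type=pa.int64())}

as_table = pa.table(data)         # pyarrow.Table
as_batch = pa.record_batch(data)  # pyarrow.RecordBatch

print(type(as_table).__name__, as_table.column_names)   # Table ['id']
print(type(as_batch).__name__, as_batch.schema.names)   # RecordBatch ['id']
```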



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -227,6 +227,63 @@ def load_stream(self, stream):
                     result_batches.append(batch.column(i))
             yield result_batches
 
+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+        if arr.type == arrow_type:
+            return arr
+        else:
+            try:
+                # when safe is True, the cast will fail if there's an overflow or other
+                # unsafe conversion
+                return arr.cast(target_type=arrow_type, safe=True)
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
+                raise PySparkTypeError(
+                    "Arrow UDTFs require the return type to match the expected Arrow type. "
+                    f"Expected: {arrow_type}, but got: {arr.type}."
+                )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for ArrowUDTF outputs.
+        ArrowUDTF returns an iterator of (pa.RecordBatch, arrow_return_type) tuples.
+        """
+        import pyarrow as pa
+
+        def apply_type_coercion():
+            for batch, arrow_return_type in iterator:
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle empty struct case specially
+                if batch.num_columns == 0:
+                    coerced_batch = batch  # skip type coercion
+                else:
+                    expected_field_names = arrow_return_type.names
+                    actual_field_names = batch.schema.names
+
+                    if expected_field_names != actual_field_names:
+                        raise PySparkTypeError(
+                            "Target schema's field names do not match the record batch's "
+                            "field names. "
+                            f"Expected: {expected_field_names}, but got: {actual_field_names}."
+                        )
+
+                    coerced_arrays = []
+                    for i, field in enumerate(arrow_return_type):
+                        original_array = batch.column(i)
+                        coerced_array = self._create_array(original_array, field.type)
+                        coerced_arrays.append(coerced_array)

Review Comment:
   Got it. Can we add a comment here mentioning why we don't use RecordBatch.cast and what the minimum pyarrow version is that supports it?
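
   For context, a hedged sketch (not the PR's code) of the two approaches: the per-column cast used in the diff above, and `pa.RecordBatch.cast`, which only exists in newer pyarrow releases (added around 16.0, if I recall correctly), so it's worth checking against the minimum pyarrow version PySpark supports.

```python
import pyarrow as pa

# Sketch only: coerce a RecordBatch to a target schema column by column,
# mirroring what the diff above does; works on older pyarrow versions.
def coerce_per_column(batch: pa.RecordBatch, target: pa.Schema) -> pa.RecordBatch:
    arrays = [
        col if col.type == field.type else col.cast(field.type, safe=True)
        for col, field in zip(batch.columns, target)
    ]
    return pa.RecordBatch.from_arrays(arrays, schema=target)

# The one-call alternative the comment asks about; RecordBatch.cast is only
# available in newer pyarrow (I believe it was added around 16.0).
def coerce_with_cast(batch: pa.RecordBatch, target: pa.Schema) -> pa.RecordBatch:
    return batch.cast(target, safe=True)
```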



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

