Yicong-Huang commented on code in PR #54125:
URL: https://github.com/apache/spark/pull/54125#discussion_r2776493226


##########
python/pyspark/sql/conversion.py:
##########
@@ -159,6 +159,246 @@ def to_pandas(
         ]
 
 
+def cast_arrow_array(
+    arr: "pa.Array",
+    target_type: "pa.DataType",
+    *,
+    safe: bool = True,
+    allow_cast: bool = True,
+    error_class: Optional[str] = None,
+) -> "pa.Array":
+    """
+    Cast an Arrow Array to a target type.
+
+    Parameters
+    ----------
+    arr : pa.Array
+        Input Arrow array
+    target_type : pa.DataType
+        Target Arrow type
+    safe : bool
+        Whether to use safe casting (default True)
+    allow_cast : bool
+        Whether to allow casting when types don't match (default True)
+    error_class : str, optional
+        Custom error class for type mismatch errors
+
+    Returns
+    -------
+    pa.Array
+    """
+    import pyarrow as pa
+
+    from pyspark.errors import PySparkRuntimeError, PySparkTypeError
+
+    if arr.type == target_type:
+        return arr
+
+    if not allow_cast:
+        raise PySparkTypeError(
+            "Arrow UDFs require the return type to match the expected Arrow 
type. "
+            f"Expected: {target_type}, but got: {arr.type}."
+        )
+
+    try:
+        return arr.cast(target_type=target_type, safe=safe)
+    except (pa.ArrowInvalid, pa.ArrowTypeError):
+        if error_class:
+            raise PySparkRuntimeError(
+                errorClass=error_class,
+                messageParameters={
+                    "expected": str(target_type),
+                    "actual": str(arr.type),

Review Comment:
   That is unfortunately the current behavior on master. I would prefer to keep
it the same for this PR; we can definitely improve it later.



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -452,122 +434,50 @@ def arrow_to_pandas(
             ndarray_as_list=ndarray_as_list,
         )
 
-    def _create_array(self, series, spark_type, *, arrow_cast=False, 
prefers_large_types=False):
+    def dump_stream(self, iterator, stream):
         """
-        Create an Arrow Array from the given pandas.Series and Spark type.
-
-        Parameters
-        ----------
-        series : pandas.Series
-            A single series
-        spark_type : DataType, optional
-            The Spark return type. For UDF return types, this should always be 
provided
-            and should never be None. If None, pyarrow's inferred type will be 
used
-            (for backward compatibility).
-        arrow_cast : bool, optional
-            Whether to apply Arrow casting when the user-specified return type 
mismatches the
-            actual return values.
-        prefers_large_types : bool, optional
-            Whether to prefer large Arrow types (e.g., large_string instead of 
string).
-
-        Returns
-        -------
-        pyarrow.Array
+        Make ArrowRecordBatches from Pandas Series and serialize.
+        Each element in iterator is:
+        - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1), 
(s2, t2), ...)
+        - For iterator UDFs: single (series, spark_type) tuple directly
         """
-        import pyarrow as pa
-        import pandas as pd
-
-        if isinstance(series.dtype, pd.CategoricalDtype):
-            series = series.astype(series.dtypes.categories.dtype)
 
-        # Derive arrow_type from spark_type
-        arrow_type = (
-            to_arrow_type(
-                spark_type, timezone=self._timezone, 
prefers_large_types=prefers_large_types
-            )
-            if spark_type is not None
-            else None
-        )
+        def create_batch(
+            packed: Union[
+                Tuple["pd.Series", DataType],
+                Tuple[Tuple["pd.Series", DataType], ...],
+            ],
+        ) -> "pa.RecordBatch":
+            """
+            Create batch from UDF output.
 
-        if spark_type is not None:
-            conv = _create_converter_from_pandas(
-                spark_type,
+            Parameters
+            ----------
+            packed : tuple
+                - For iterator UDFs: single (series, spark_type) tuple
+                - For batched UDFs: tuple of tuples ((s1, t1), (s2, t2), ...)
+            """
+            # Normalize: iterator UDFs yield (series, spark_type) directly,
+            # batched UDFs return tuple of tuples ((s1, t1), (s2, t2), ...)
+            if len(packed) == 2 and isinstance(packed[1], DataType):
+                # single UDF result: wrap in list
+                series_tuples: List[Tuple["pd.Series", DataType]] = [packed]
+            else:
+                # multiple UDF results: already iterable of tuples
+                series_tuples = list(packed)

Review Comment:
   The call sites of this method (i.e., the eval-type wrappers in worker.py)
currently return a `list`. I agree we can refactor this, but the change would be
too large to include in this PR. Could we change it gradually when we move this
logic out of the serializer?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to