gaogaotiantian commented on code in PR #54125:
URL: https://github.com/apache/spark/pull/54125#discussion_r2776439160


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -452,122 +434,50 @@ def arrow_to_pandas(
             ndarray_as_list=ndarray_as_list,
         )
 
-    def _create_array(self, series, spark_type, *, arrow_cast=False, 
prefers_large_types=False):
+    def dump_stream(self, iterator, stream):
         """
-        Create an Arrow Array from the given pandas.Series and Spark type.
-
-        Parameters
-        ----------
-        series : pandas.Series
-            A single series
-        spark_type : DataType, optional
-            The Spark return type. For UDF return types, this should always be 
provided
-            and should never be None. If None, pyarrow's inferred type will be 
used
-            (for backward compatibility).
-        arrow_cast : bool, optional
-            Whether to apply Arrow casting when the user-specified return type 
mismatches the
-            actual return values.
-        prefers_large_types : bool, optional
-            Whether to prefer large Arrow types (e.g., large_string instead of 
string).
-
-        Returns
-        -------
-        pyarrow.Array
+        Make ArrowRecordBatches from Pandas Series and serialize.
+        Each element in iterator is:
+        - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1), 
(s2, t2), ...)
+        - For iterator UDFs: single (series, spark_type) tuple directly
         """
-        import pyarrow as pa
-        import pandas as pd
-
-        if isinstance(series.dtype, pd.CategoricalDtype):
-            series = series.astype(series.dtypes.categories.dtype)
 
-        # Derive arrow_type from spark_type
-        arrow_type = (
-            to_arrow_type(
-                spark_type, timezone=self._timezone, 
prefers_large_types=prefers_large_types
-            )
-            if spark_type is not None
-            else None
-        )
+        def create_batch(
+            packed: Union[
+                Tuple["pd.Series", DataType],
+                Tuple[Tuple["pd.Series", DataType], ...],
+            ],
+        ) -> "pa.RecordBatch":
+            """
+            Create batch from UDF output.
 
-        if spark_type is not None:
-            conv = _create_converter_from_pandas(
-                spark_type,
+            Parameters
+            ----------
+            packed : tuple
+                - For iterator UDFs: single (series, spark_type) tuple
+                - For batched UDFs: tuple of tuples ((s1, t1), (s2, t2), ...)
+            """
+            # Normalize: iterator UDFs yield (series, spark_type) directly,
+            # batched UDFs return tuple of tuples ((s1, t1), (s2, t2), ...)
+            if len(packed) == 2 and isinstance(packed[1], DataType):
+                # single UDF result: wrap in list
+                series_tuples: List[Tuple["pd.Series", DataType]] = [packed]
+            else:
+                # multiple UDF results: already iterable of tuples
+                series_tuples = list(packed)

Review Comment:
   We have a lot of seemingly arbitrary conversions to `list` - why is that preferred? I think `tuple` should be used when possible (or keep the value as it is if a conversion is unnecessary). Immutable objects are always better - including the input data - so we should at least accept either.



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -619,170 +531,71 @@ def __init__(
             arrow_cast,
         )
         self._assign_cols_by_name = assign_cols_by_name
+        self._ignore_unexpected_complex_type_values = 
ignore_unexpected_complex_type_values
+        self._is_udtf = is_udtf
 
-    def _create_struct_array(
-        self,
-        df: "pd.DataFrame",
-        return_type: StructType,
-        *,
-        prefers_large_types: bool = False,
-    ):
-        """
-        Create an Arrow StructArray from the given pandas.DataFrame and Spark 
StructType.
-
-        Parameters
-        ----------
-        df : pandas.DataFrame
-            A pandas DataFrame
-        return_type : StructType
-            The Spark return type (StructType) to use
-        prefers_large_types : bool, optional
-            Whether to prefer large Arrow types (e.g., large_string instead of 
string).
-
-        Returns
-        -------
-        pyarrow.Array
-        """
-        import pyarrow as pa
-
-        # Derive arrow_struct_type from return_type
-        arrow_struct_type = to_arrow_type(
-            return_type, timezone=self._timezone, 
prefers_large_types=prefers_large_types
-        )
-
-        if len(df.columns) == 0:
-            return pa.array([{}] * len(df), arrow_struct_type)
-        # Assign result columns by schema name if user labeled with strings
-        if self._assign_cols_by_name and any(isinstance(name, str) for name in 
df.columns):
-            struct_arrs = [
-                self._create_array(
-                    df[spark_field.name],
-                    spark_field.dataType,
-                    arrow_cast=self._arrow_cast,
-                    prefers_large_types=prefers_large_types,
-                )
-                for spark_field in return_type
-            ]
-        # Assign result columns by position
-        else:
-            struct_arrs = [
-                # the selected series has name '1', so we rename it to 
spark_field.name
-                # as the name is used by _create_array to provide a meaningful 
error message
-                self._create_array(
-                    df[df.columns[i]].rename(spark_field.name),
-                    spark_field.dataType,
-                    arrow_cast=self._arrow_cast,
-                    prefers_large_types=prefers_large_types,
-                )
-                for i, spark_field in enumerate(return_type)
-            ]
-
-        return pa.StructArray.from_arrays(struct_arrs, 
fields=list(arrow_struct_type))
-
-    def _create_batch(
-        self, series, *, arrow_cast=False, prefers_large_types=False, 
struct_in_pandas="dict"
-    ):
+    def dump_stream(self, iterator, stream):
         """
-        Create an Arrow record batch from the given pandas.Series, 
pandas.DataFrame,
-        or list of Series/DataFrame, with optional Spark type.
-
-        Parameters
-        ----------
-        series : pandas.Series or pandas.DataFrame or list
-            A single series or dataframe, list of series or dataframe,
-            or list of (series or dataframe, spark_type) tuples.
-        arrow_cast : bool, optional
-            If True, use Arrow's cast method for type conversion.
-        prefers_large_types : bool, optional
-            Whether to prefer large Arrow types (e.g., large_string instead of 
string).
-        struct_in_pandas : str, optional
-            How to represent struct types in pandas: "dict" or "row".
-            Default is "dict".
+        Override because Pandas UDFs require a START_ARROW_STREAM before the 
Arrow stream is sent.
+        This should be sent after creating the first record batch so in case 
of an error, it can
+        be sent back to the JVM before the Arrow stream starts.
 
-        Returns
-        -------
-        pyarrow.RecordBatch
-            Arrow RecordBatch
+        Each element in iterator is:
+        - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1), 
(s2, t2), ...)
+        - For iterator UDFs: single (series, spark_type) tuple directly
         """
         import pandas as pd
-        import pyarrow as pa
 
-        # Normalize input to list of (data, spark_type) tuples
-        # Handle: single series, (series, type) tuple, or list of tuples
-        if not isinstance(series, (list, tuple)) or (
-            len(series) == 2 and isinstance(series[1], DataType)
-        ):
-            series = [series]
-        # Ensure each element is a (data, spark_type) tuple
-        series = [(s, None) if not isinstance(s, (list, tuple)) else s for s 
in series]
-
-        arrs = []
-        for s, spark_type in series:
-            # Convert spark_type to arrow_type for type checking (similar to 
master branch)
-            arrow_type = (
-                to_arrow_type(
-                    spark_type, timezone=self._timezone, 
prefers_large_types=prefers_large_types
-                )
-                if spark_type is not None
-                else None
-            )
+        def create_batch(
+            packed: Union[
+                Tuple[Union["pd.Series", "pd.DataFrame"], DataType],
+                Tuple[Tuple[Union["pd.Series", "pd.DataFrame"], DataType], 
...],
+            ],
+        ) -> "pa.RecordBatch":
+            """
+            Create batch from UDF output.
 
-            # Variants are represented in arrow as structs with additional 
metadata (checked by
-            # is_variant). If the data type is Variant, return a VariantVal 
atomic type instead of
-            # a dict of two binary values.
-            if (
-                struct_in_pandas == "dict"
-                and arrow_type is not None
-                and pa.types.is_struct(arrow_type)
-                and not is_variant(arrow_type)
-            ):
-                # A pandas UDF should return pd.DataFrame when the return type 
is a struct type.
-                # If it returns a pd.Series, it should throw an error.
-                if not isinstance(s, pd.DataFrame):
-                    raise PySparkValueError(
-                        "Invalid return type. Please make sure that the UDF 
returns a "
-                        "pandas.DataFrame when the specified return type is 
StructType."
-                    )
-                arrs.append(
-                    self._create_struct_array(
-                        s, spark_type, prefers_large_types=prefers_large_types
-                    )
-                )
-            elif isinstance(s, pd.DataFrame):
-                # If data is a DataFrame (e.g., from df_for_struct), use 
_create_struct_array
-                arrs.append(
-                    self._create_struct_array(
-                        s, spark_type, prefers_large_types=prefers_large_types
-                    )
-                )
+            Parameters
+            ----------
+            packed : tuple
+                - For iterator UDFs: single (series, spark_type) tuple
+                - For batched UDFs: tuple of tuples ((s1, t1), (s2, t2), ...)
+            """
+            # Normalize: iterator UDFs yield (series, spark_type) directly,
+            # batched UDFs return tuple of tuples ((s1, t1), (s2, t2), ...)
+            if len(packed) == 2 and isinstance(packed[1], DataType):
+                # single UDF result: wrap in list
+                series_tuples: List[Tuple[Union["pd.Series", "pd.DataFrame"], 
DataType]] = [packed]

Review Comment:
   BTW, if we need to type-hint a local variable that is assigned in more than one place, we normally declare the annotation separately, before the branches:
   
   ```python
   series_tuples: List[...]
   if xxx:
       series_tuples = ...
   else:
       series_tuples = ...
   ```



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -619,170 +531,71 @@ def __init__(
             arrow_cast,
         )
         self._assign_cols_by_name = assign_cols_by_name
+        self._ignore_unexpected_complex_type_values = 
ignore_unexpected_complex_type_values
+        self._is_udtf = is_udtf
 
-    def _create_struct_array(
-        self,
-        df: "pd.DataFrame",
-        return_type: StructType,
-        *,
-        prefers_large_types: bool = False,
-    ):
-        """
-        Create an Arrow StructArray from the given pandas.DataFrame and Spark 
StructType.
-
-        Parameters
-        ----------
-        df : pandas.DataFrame
-            A pandas DataFrame
-        return_type : StructType
-            The Spark return type (StructType) to use
-        prefers_large_types : bool, optional
-            Whether to prefer large Arrow types (e.g., large_string instead of 
string).
-
-        Returns
-        -------
-        pyarrow.Array
-        """
-        import pyarrow as pa
-
-        # Derive arrow_struct_type from return_type
-        arrow_struct_type = to_arrow_type(
-            return_type, timezone=self._timezone, 
prefers_large_types=prefers_large_types
-        )
-
-        if len(df.columns) == 0:
-            return pa.array([{}] * len(df), arrow_struct_type)
-        # Assign result columns by schema name if user labeled with strings
-        if self._assign_cols_by_name and any(isinstance(name, str) for name in 
df.columns):
-            struct_arrs = [
-                self._create_array(
-                    df[spark_field.name],
-                    spark_field.dataType,
-                    arrow_cast=self._arrow_cast,
-                    prefers_large_types=prefers_large_types,
-                )
-                for spark_field in return_type
-            ]
-        # Assign result columns by position
-        else:
-            struct_arrs = [
-                # the selected series has name '1', so we rename it to 
spark_field.name
-                # as the name is used by _create_array to provide a meaningful 
error message
-                self._create_array(
-                    df[df.columns[i]].rename(spark_field.name),
-                    spark_field.dataType,
-                    arrow_cast=self._arrow_cast,
-                    prefers_large_types=prefers_large_types,
-                )
-                for i, spark_field in enumerate(return_type)
-            ]
-
-        return pa.StructArray.from_arrays(struct_arrs, 
fields=list(arrow_struct_type))
-
-    def _create_batch(
-        self, series, *, arrow_cast=False, prefers_large_types=False, 
struct_in_pandas="dict"
-    ):
+    def dump_stream(self, iterator, stream):
         """
-        Create an Arrow record batch from the given pandas.Series, 
pandas.DataFrame,
-        or list of Series/DataFrame, with optional Spark type.
-
-        Parameters
-        ----------
-        series : pandas.Series or pandas.DataFrame or list
-            A single series or dataframe, list of series or dataframe,
-            or list of (series or dataframe, spark_type) tuples.
-        arrow_cast : bool, optional
-            If True, use Arrow's cast method for type conversion.
-        prefers_large_types : bool, optional
-            Whether to prefer large Arrow types (e.g., large_string instead of 
string).
-        struct_in_pandas : str, optional
-            How to represent struct types in pandas: "dict" or "row".
-            Default is "dict".
+        Override because Pandas UDFs require a START_ARROW_STREAM before the 
Arrow stream is sent.
+        This should be sent after creating the first record batch so in case 
of an error, it can
+        be sent back to the JVM before the Arrow stream starts.
 
-        Returns
-        -------
-        pyarrow.RecordBatch
-            Arrow RecordBatch
+        Each element in iterator is:
+        - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1), 
(s2, t2), ...)
+        - For iterator UDFs: single (series, spark_type) tuple directly
         """
         import pandas as pd
-        import pyarrow as pa
 
-        # Normalize input to list of (data, spark_type) tuples
-        # Handle: single series, (series, type) tuple, or list of tuples
-        if not isinstance(series, (list, tuple)) or (
-            len(series) == 2 and isinstance(series[1], DataType)
-        ):
-            series = [series]
-        # Ensure each element is a (data, spark_type) tuple
-        series = [(s, None) if not isinstance(s, (list, tuple)) else s for s 
in series]
-
-        arrs = []
-        for s, spark_type in series:
-            # Convert spark_type to arrow_type for type checking (similar to 
master branch)
-            arrow_type = (
-                to_arrow_type(
-                    spark_type, timezone=self._timezone, 
prefers_large_types=prefers_large_types
-                )
-                if spark_type is not None
-                else None
-            )
+        def create_batch(

Review Comment:
   Did you name all the to-Arrow-batch functions `create_batch`? I think `batch` is a bit too ambiguous. `arrowbatch`, `recordbatch`, or even `pabatch` (perhaps with an `_` somewhere) would be more descriptive.



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -452,122 +434,50 @@ def arrow_to_pandas(
             ndarray_as_list=ndarray_as_list,
         )
 
-    def _create_array(self, series, spark_type, *, arrow_cast=False, 
prefers_large_types=False):
+    def dump_stream(self, iterator, stream):
         """
-        Create an Arrow Array from the given pandas.Series and Spark type.
-
-        Parameters
-        ----------
-        series : pandas.Series
-            A single series
-        spark_type : DataType, optional
-            The Spark return type. For UDF return types, this should always be 
provided
-            and should never be None. If None, pyarrow's inferred type will be 
used
-            (for backward compatibility).
-        arrow_cast : bool, optional
-            Whether to apply Arrow casting when the user-specified return type 
mismatches the
-            actual return values.
-        prefers_large_types : bool, optional
-            Whether to prefer large Arrow types (e.g., large_string instead of 
string).
-
-        Returns
-        -------
-        pyarrow.Array
+        Make ArrowRecordBatches from Pandas Series and serialize.
+        Each element in iterator is:
+        - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1), 
(s2, t2), ...)
+        - For iterator UDFs: single (series, spark_type) tuple directly
         """
-        import pyarrow as pa
-        import pandas as pd
-
-        if isinstance(series.dtype, pd.CategoricalDtype):
-            series = series.astype(series.dtypes.categories.dtype)
 
-        # Derive arrow_type from spark_type
-        arrow_type = (
-            to_arrow_type(
-                spark_type, timezone=self._timezone, 
prefers_large_types=prefers_large_types
-            )
-            if spark_type is not None
-            else None
-        )
+        def create_batch(
+            packed: Union[
+                Tuple["pd.Series", DataType],
+                Tuple[Tuple["pd.Series", DataType], ...],
+            ],
+        ) -> "pa.RecordBatch":
+            """
+            Create batch from UDF output.
 
-        if spark_type is not None:
-            conv = _create_converter_from_pandas(
-                spark_type,
+            Parameters
+            ----------
+            packed : tuple
+                - For iterator UDFs: single (series, spark_type) tuple
+                - For batched UDFs: tuple of tuples ((s1, t1), (s2, t2), ...)
+            """
+            # Normalize: iterator UDFs yield (series, spark_type) directly,
+            # batched UDFs return tuple of tuples ((s1, t1), (s2, t2), ...)
+            if len(packed) == 2 and isinstance(packed[1], DataType):
+                # single UDF result: wrap in list
+                series_tuples: List[Tuple["pd.Series", DataType]] = [packed]
+            else:
+                # multiple UDF results: already iterable of tuples
+                series_tuples = list(packed)
+
+            series_data, types = map(list, list(zip(*series_tuples)) or [(), 
()])

Review Comment:
   I think this could be written more simply - it took me 5 minutes to figure out what's happening on this line.



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -829,17 +626,22 @@ def dump_stream(self, iterator, stream):
         """
         import pyarrow as pa
 
-        def create_batches():
-            for packed in iterator:
-                if len(packed) == 2 and isinstance(packed[1], pa.DataType):
-                    # single array UDF in a projection
-                    arrs = [self._create_array(packed[0], packed[1], 
self._arrow_cast)]
-                else:
-                    # multiple array UDFs in a projection
-                    arrs = [self._create_array(t[0], t[1], self._arrow_cast) 
for t in packed]
-                yield pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in 
range(len(arrs))])
+        def create_batch(
+            packed: Union[
+                Tuple["pa.Array", "pa.DataType"],
+                List[Tuple["pa.Array", "pa.DataType"]],
+            ],
+        ) -> "pa.RecordBatch":
+            # Normalize: single UDF (arr, type) -> [(arr, type)]
+            if len(packed) == 2 and isinstance(packed[1], pa.DataType):
+                packed = [packed]  # type: ignore[list-item]
+            arrs = [
+                cast_arrow_array(arr, arrow_type, safe=self._safecheck, 
allow_cast=self._arrow_cast)

Review Comment:
   As `cast_arrow_array` is something we defined ourselves, I think it would be better to use the same parameter names as its callers. `safecheck` is a good name, and we probably don't need to shorten it to `safe`. `allow_cast` vs `arrow_cast` is definitely a nightmare, especially for people who have trouble pronouncing `r` vs `l` :)



##########
python/pyspark/sql/conversion.py:
##########
@@ -159,6 +159,246 @@ def to_pandas(
         ]
 
 
+def cast_arrow_array(
+    arr: "pa.Array",
+    target_type: "pa.DataType",
+    *,
+    safe: bool = True,
+    allow_cast: bool = True,
+    error_class: Optional[str] = None,
+) -> "pa.Array":
+    """
+    Cast an Arrow Array to a target type.
+
+    Parameters
+    ----------
+    arr : pa.Array
+        Input Arrow array
+    target_type : pa.DataType
+        Target Arrow type
+    safe : bool
+        Whether to use safe casting (default True)
+    allow_cast : bool
+        Whether to allow casting when types don't match (default True)
+    error_class : str, optional
+        Custom error class for type mismatch errors
+
+    Returns
+    -------
+    pa.Array
+    """
+    import pyarrow as pa
+
+    from pyspark.errors import PySparkRuntimeError, PySparkTypeError
+
+    if arr.type == target_type:
+        return arr
+
+    if not allow_cast:
+        raise PySparkTypeError(
+            "Arrow UDFs require the return type to match the expected Arrow 
type. "
+            f"Expected: {target_type}, but got: {arr.type}."
+        )
+
+    try:
+        return arr.cast(target_type=target_type, safe=safe)
+    except (pa.ArrowInvalid, pa.ArrowTypeError):
+        if error_class:
+            raise PySparkRuntimeError(
+                errorClass=error_class,
+                messageParameters={
+                    "expected": str(target_type),
+                    "actual": str(arr.type),

Review Comment:
   There's a strong assumption here that the error class's message template has `expected` and `actual` parameters, which feels a bit odd to me.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to