Yicong-Huang commented on code in PR #54125:
URL: https://github.com/apache/spark/pull/54125#discussion_r2800799777


##########
python/pyspark/sql/conversion.py:
##########
@@ -162,6 +162,239 @@ def to_pandas(
         ]
 
 
+# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema
+#       instead of per-column coercion.
+def coerce_arrow_array(
+    arr: "pa.Array",
+    target_type: "pa.DataType",
+    *,
+    safecheck: bool = True,
+    arrow_cast: bool = True,
+) -> "pa.Array":
+    """
+    Coerce an Arrow Array to a target type, with optional type-mismatch enforcement.
+
+    When ``arrow_cast`` is True (default), mismatched types are cast to the
+    target type.  When False, a type mismatch raises an error instead.
+
+    Parameters
+    ----------
+    arr : pa.Array
+        Input Arrow array
+    target_type : pa.DataType
+        Target Arrow type
+    safecheck : bool
+        Whether to use safe casting (default True). Passed as ``safe`` to
+        ``pa.Array.cast``, so an overflow or other unsafe conversion fails
+        instead of silently producing wrong values.
+    arrow_cast : bool
+        Whether to allow casting when types don't match (default True)
+
+    Returns
+    -------
+    pa.Array
+        ``arr`` itself when its type already equals ``target_type``;
+        otherwise the result of casting ``arr`` to ``target_type``.
+
+    Raises
+    ------
+    PySparkTypeError
+        If the types differ and ``arrow_cast`` is False.
+    """
+    from pyspark.errors import PySparkTypeError
+
+    # Fast path: no conversion needed when the types already match.
+    if arr.type == target_type:
+        return arr
+
+    # Strict mode: a mismatch is reported to the user rather than cast away.
+    if not arrow_cast:
+        raise PySparkTypeError(
+            "Arrow UDFs require the return type to match the expected Arrow 
type. "
+            f"Expected: {target_type}, but got: {arr.type}."
+        )
+
+    # When safe is True, the cast fails if there's an overflow or other
+    # unsafe conversion.
+    # Arrays are cast one at a time because RecordBatch.cast(...) would
+    # require PyArrow >= 16.0, above the minimum supported version.
+    return arr.cast(target_type=target_type, safe=safecheck)
+
+
+class PandasToArrowConversion:
+    """
+    Conversion utilities from pandas data to Arrow.
+    """
+
+    @classmethod
+    def convert(
+        cls,
+        data: Union["pd.DataFrame", List[Union["pd.Series", "pd.DataFrame"]]],
+        schema: StructType,
+        *,
+        timezone: Optional[str] = None,
+        safecheck: bool = True,
+        arrow_cast: bool = False,
+        prefers_large_types: bool = False,
+        assign_cols_by_name: bool = False,
+        int_to_decimal_coercion_enabled: bool = False,
+        ignore_unexpected_complex_type_values: bool = False,
+        is_udtf: bool = False,
+    ) -> "pa.RecordBatch":
+        """
+        Convert a pandas DataFrame or list of Series/DataFrames to an Arrow 
RecordBatch.
+
+        Parameters
+        ----------
+        data : pd.DataFrame or list of pd.Series/pd.DataFrame
+            Input data - either a DataFrame or a list of Series/DataFrames.
+        schema : StructType
+            Spark schema defining the types for each column
+        timezone : str, optional
+            Timezone for timestamp conversion
+        safecheck : bool
+            Whether to use safe Arrow conversion (default True)
+        arrow_cast : bool
+            Whether to allow Arrow casting on type mismatch (default False)
+        prefers_large_types : bool
+            Whether to prefer large Arrow types (default False)
+        assign_cols_by_name : bool
+            Whether to reorder DataFrame columns by name to match schema 
(default False)
+        int_to_decimal_coercion_enabled : bool
+            Whether to enable int to decimal coercion (default False)
+        ignore_unexpected_complex_type_values : bool
+            Whether to ignore unexpected complex type values in converter 
(default False)
+        is_udtf : bool
+            Whether this conversion is for a UDTF. UDTFs use broader Arrow 
exception
+            handling to allow more type coercions (e.g., struct field casting 
via
+            ArrowTypeError), and convert errors to UDTF_ARROW_TYPE_CAST_ERROR.
+            Regular UDFs only catch ArrowInvalid to preserve legacy behavior 
where
+            e.g. string→decimal must raise an error. (default False)
+
+        Returns
+        -------
+        pa.RecordBatch
+        """
+        import pyarrow as pa
+        import pandas as pd
+
+        from pyspark.errors import PySparkTypeError, PySparkValueError, 
PySparkRuntimeError
+        from pyspark.sql.pandas.types import to_arrow_type, 
_create_converter_from_pandas
+
+        # Handle empty schema (0 columns)
+        # Use dummy column + select([]) to preserve row count (PyArrow 
limitation workaround)
+        if not schema.fields:
+            num_rows = len(data[0]) if isinstance(data, list) and data else 
len(data)
+            return pa.RecordBatch.from_pydict({"_": [None] * 
num_rows}).select([])
+
+        # Handle empty DataFrame (0 columns) with non-empty schema
+        # This happens when user returns pd.DataFrame() for struct types
+        if isinstance(data, pd.DataFrame) and len(data.columns) == 0:
+            arrow_type = to_arrow_type(
+                schema, timezone=timezone, 
prefers_large_types=prefers_large_types
+            )
+            return pa.RecordBatch.from_struct_array(pa.array([{}] * len(data), 
arrow_type))
+
+        # Normalize input: reorder DataFrame columns by schema names if needed,
+        # then extract columns as a list for uniform iteration.
+        if isinstance(data, list):
+            columns = data
+        else:
+            if assign_cols_by_name and any(isinstance(c, str) for c in 
data.columns):
+                data = data[schema.names]
+            columns = [data.iloc[:, i] for i in range(len(schema.fields))]
+
+        def series_to_array(series: "pd.Series", ret_type: DataType, 
field_name: str) -> "pa.Array":
+            """Convert a pandas Series to an Arrow Array (closure over 
conversion params).
+
+            Uses field_name for error messages instead of series.name to avoid
+            copying the Series via rename() — a ~20% overhead on the hot path.
+            """
+            if isinstance(series.dtype, pd.CategoricalDtype):
+                series = series.astype(series.dtype.categories.dtype)
+
+            arrow_type = to_arrow_type(
+                ret_type, timezone=timezone, 
prefers_large_types=prefers_large_types
+            )
+            series = _create_converter_from_pandas(
+                ret_type,
+                timezone=timezone,
+                error_on_duplicated_field_names=False,
+                
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+                
ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
+            )(series)
+
+            mask = None if hasattr(series.array, "__arrow_array__") else 
series.isnull()
+
+            if is_udtf:

Review Comment:
   The `is_udtf` special handling belongs to the `from_pandas` stage (it catches
   the broader `ArrowException` instead of just `ArrowInvalid`), not to the
   `.cast()` stage that `coerce_arrow_array` handles, so it fits better in
   `series_to_array`. This flag will be removed by SPARK-55502 once the UDTF and
   regular UDF conversion paths are unified.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to