This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d2a47b9c208c [SPARK-55159][PYTHON] Consolidate pandas-to-Arrow
conversion utilities in serializers
d2a47b9c208c is described below
commit d2a47b9c208ccbbdd59310e3f8cba37cb5f65163
Author: Yicong-Huang <[email protected]>
AuthorDate: Thu Feb 19 09:55:44 2026 +0900
[SPARK-55159][PYTHON] Consolidate pandas-to-Arrow conversion utilities in
serializers
### What changes were proposed in this pull request?
Introduce `PandasToArrowConversion.convert()` in `conversion.py` to
centralize the pandas-to-Arrow conversion logic previously duplicated across
multiple serializers. Also extract `coerce_arrow_array()` as a standalone utility
for Arrow array type casting.
Serializers (`ArrowStreamPandasSerializer`,
`ArrowStreamPandasUDFSerializer`, `ArrowStreamPandasUDTFSerializer`, etc.) now
delegate to these shared utilities instead of maintaining their own
`_create_array`, `_create_batch`, and `_create_struct_array` methods.
### Why are the changes needed?
Part of [SPARK-55159](https://issues.apache.org/jira/browse/SPARK-55159).
The same conversion logic was duplicated across 5+ serializer classes, making
it hard to maintain. This reduces ~450 lines of duplication.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit tests in `test_conversion.py` for `PandasToArrowConversion`, plus
existing UDF/UDTF tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54125 from
Yicong-Huang/SPARK-55159/refactor/consolidate-pandas-to-arrow.
Lead-authored-by: Yicong-Huang
<[email protected]>
Co-authored-by: Yicong Huang
<[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/conversion.py | 240 ++++++++++++
python/pyspark/sql/pandas/serializers.py | 585 +++++++---------------------
python/pyspark/sql/tests/test_conversion.py | 208 +++++++++-
3 files changed, 579 insertions(+), 454 deletions(-)
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index a6a983c940e8..7b5fd9747a69 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -162,6 +162,246 @@ class ArrowBatchTransformer:
]
+# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema
+# instead of per-column coercion.
+def coerce_arrow_array(
+ arr: "pa.Array",
+ target_type: "pa.DataType",
+ *,
+ safecheck: bool = True,
+ arrow_cast: bool = True,
+) -> "pa.Array":
+ """
+ Coerce an Arrow Array to a target type, with optional type-mismatch
enforcement.
+
+ When ``arrow_cast`` is True (default), mismatched types are cast to the
+ target type. When False, a type mismatch raises an error instead.
+
+ Parameters
+ ----------
+ arr : pa.Array
+ Input Arrow array
+ target_type : pa.DataType
+ Target Arrow type
+ safecheck : bool
+ Whether to use safe casting (default True)
+ arrow_cast : bool
+ Whether to allow casting when types don't match (default True)
+
+ Returns
+ -------
+ pa.Array
+ """
+ from pyspark.errors import PySparkTypeError
+
+ if arr.type == target_type:
+ return arr
+
+ if not arrow_cast:
+ raise PySparkTypeError(
+ "Arrow UDFs require the return type to match the expected Arrow
type. "
+ f"Expected: {target_type}, but got: {arr.type}."
+ )
+
+    # when safe is True, the cast will fail if there's an overflow or other
+ # unsafe conversion.
+ # RecordBatch.cast(...) isn't used as minimum PyArrow version
+ # required for RecordBatch.cast(...) is v16.0
+ return arr.cast(target_type=target_type, safe=safecheck)
+
+
+class PandasToArrowConversion:
+ """
+ Conversion utilities from pandas data to Arrow.
+ """
+
+ @classmethod
+ def convert(
+ cls,
+ data: Union["pd.DataFrame", Sequence[Union["pd.Series",
"pd.DataFrame"]]],
+ schema: StructType,
+ *,
+ timezone: Optional[str] = None,
+ safecheck: bool = True,
+ arrow_cast: bool = False,
+ prefers_large_types: bool = False,
+ assign_cols_by_name: bool = False,
+ int_to_decimal_coercion_enabled: bool = False,
+ ignore_unexpected_complex_type_values: bool = False,
+ is_udtf: bool = False,
+ ) -> "pa.RecordBatch":
+ """
+ Convert a pandas DataFrame or list of Series/DataFrames to an Arrow
RecordBatch.
+
+ Parameters
+ ----------
+ data : pd.DataFrame or list of pd.Series/pd.DataFrame
+ Input data - either a single DataFrame, or a list of
Series/DataFrames
+ (one per schema field). A list of DataFrames is used when UDFs
return struct
+ types as DataFrames (e.g., applyInPandas with state).
+ schema : StructType
+ Spark schema defining the types for each column
+ timezone : str, optional
+ Timezone for timestamp conversion
+ safecheck : bool
+ Whether to use safe Arrow conversion (default True)
+ arrow_cast : bool
+ Whether to allow Arrow casting on type mismatch (default False)
+ prefers_large_types : bool
+ Whether to prefer large Arrow types (default False)
+ assign_cols_by_name : bool
+ Whether to reorder DataFrame columns by name to match schema
(default False)
+ int_to_decimal_coercion_enabled : bool
+ Whether to enable int to decimal coercion (default False)
+ ignore_unexpected_complex_type_values : bool
+ Whether to ignore unexpected complex type values in converter
(default False)
+ is_udtf : bool
+ Whether this conversion is for a UDTF. UDTFs use broader Arrow
exception
+ handling to allow more type coercions (e.g., struct field casting
via
+ ArrowTypeError), and convert errors to UDTF_ARROW_TYPE_CAST_ERROR.
+ # TODO(SPARK-55502): Unify UDTF and regular UDF conversion paths to
+ # eliminate the is_udtf flag.
+ Regular UDFs only catch ArrowInvalid to preserve legacy behavior
where
+ e.g. string→decimal must raise an error. (default False)
+
+ Returns
+ -------
+ pa.RecordBatch
+ """
+ import pyarrow as pa
+ import pandas as pd
+
+ from pyspark.errors import PySparkTypeError, PySparkValueError,
PySparkRuntimeError
+ from pyspark.sql.pandas.types import to_arrow_type,
_create_converter_from_pandas
+
+ # Handle empty schema (0 columns)
+ # Use dummy column + select([]) to preserve row count (PyArrow
limitation workaround)
+ if len(schema.fields) == 0:
+ num_rows = len(data[0]) if isinstance(data, list) and data else
len(data)
+ return pa.RecordBatch.from_pydict({"_": [None] *
num_rows}).select([])
+
+ # Handle empty DataFrame (0 columns) with non-empty schema
+ # This happens when user returns pd.DataFrame() for struct types
+ if isinstance(data, pd.DataFrame) and len(data.columns) == 0:
+ arrow_type = to_arrow_type(
+ schema, timezone=timezone,
prefers_large_types=prefers_large_types
+ )
+ return pa.RecordBatch.from_struct_array(pa.array([{}] * len(data),
arrow_type))
+
+ # Normalize input: reorder DataFrame columns by schema names if needed,
+ # then extract columns as a list for uniform iteration.
+ columns: List[Union["pd.Series", "pd.DataFrame"]]
+ if isinstance(data, pd.DataFrame):
+ if assign_cols_by_name and any(isinstance(c, str) for c in
data.columns):
+ data = data[schema.names]
+ columns = [data.iloc[:, i] for i in range(len(schema.fields))]
+ else:
+ columns = list(data)
+
+ def convert_column(
+ col: Union["pd.Series", "pd.DataFrame"], field: StructField
+ ) -> "pa.Array":
+ """Convert a single column (Series or DataFrame) to an Arrow Array.
+
+ Uses field.name for error messages instead of series.name to avoid
+ copying the Series via rename() — a ~20% overhead on the hot path.
+ """
+ if isinstance(col, pd.DataFrame):
+ assert isinstance(field.dataType, StructType)
+ nested_batch = cls.convert(
+ col,
+ field.dataType,
+ timezone=timezone,
+ safecheck=safecheck,
+ arrow_cast=arrow_cast,
+ prefers_large_types=prefers_large_types,
+ assign_cols_by_name=assign_cols_by_name,
+
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+
ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
+ is_udtf=is_udtf,
+ )
+ # Wrap the nested RecordBatch as a single StructArray column
+ return
ArrowBatchTransformer.wrap_struct(nested_batch).column(0)
+
+ series = col
+ field_name = field.name
+ ret_type = field.dataType
+
+ if isinstance(series.dtype, pd.CategoricalDtype):
+ series = series.astype(series.dtype.categories.dtype)
+
+ arrow_type = to_arrow_type(
+ ret_type, timezone=timezone,
prefers_large_types=prefers_large_types
+ )
+ series = _create_converter_from_pandas(
+ ret_type,
+ timezone=timezone,
+ error_on_duplicated_field_names=False,
+
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+
ignore_unexpected_complex_type_values=ignore_unexpected_complex_type_values,
+ )(series)
+
+ mask = None if hasattr(series.array, "__arrow_array__") else
series.isnull()
+
+ if is_udtf:
+ # UDTF path: broad ArrowException catch so that both
ArrowInvalid
+ # AND ArrowTypeError (e.g. dict→struct) trigger the cast
fallback.
+ try:
+ try:
+ return pa.Array.from_pandas(
+ series, mask=mask, type=arrow_type, safe=safecheck
+ )
+ except pa.lib.ArrowException: # broad: includes
ArrowTypeError
+ if arrow_cast:
+ return pa.Array.from_pandas(series,
mask=mask).cast(
+ target_type=arrow_type, safe=safecheck
+ )
+ raise
+ except pa.lib.ArrowException: # convert any Arrow error to
user-friendly message
+ raise PySparkRuntimeError(
+ errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
+ messageParameters={
+ "col_name": field_name,
+ "col_type": str(series.dtype),
+ "arrow_type": str(arrow_type),
+ },
+ ) from None
+ else:
+ # UDF path: only ArrowInvalid triggers the cast fallback.
+ # ArrowTypeError (e.g. string→decimal) must NOT be silently
cast.
+ try:
+ try:
+ return pa.Array.from_pandas(
+ series, mask=mask, type=arrow_type, safe=safecheck
+ )
+ except pa.lib.ArrowInvalid: # narrow: skip ArrowTypeError
+ if arrow_cast:
+ return pa.Array.from_pandas(series,
mask=mask).cast(
+ target_type=arrow_type, safe=safecheck
+ )
+ raise
+ except TypeError as e: # includes pa.lib.ArrowTypeError
+ raise PySparkTypeError(
+ f"Exception thrown when converting pandas.Series
({series.dtype}) "
+ f"with name '{field_name}' to Arrow Array
({arrow_type})."
+ ) from e
+ except ValueError as e: # includes pa.lib.ArrowInvalid
+ error_msg = (
+ f"Exception thrown when converting pandas.Series
({series.dtype}) "
+ f"with name '{field_name}' to Arrow Array
({arrow_type})."
+ )
+ if safecheck:
+ error_msg += (
+ " It can be caused by overflows or other unsafe
conversions "
+ "warned by Arrow. Arrow safe type check can be
disabled by using "
+ "SQL config
`spark.sql.execution.pandas.convertToArrowArraySafely`."
+ )
+ raise PySparkValueError(error_msg) from e
+
+ arrays = [convert_column(col, field) for col, field in zip(columns,
schema.fields)]
+ return pa.RecordBatch.from_arrays(arrays, schema.names)
+
+
class LocalDataToArrowConversion:
"""
Conversion from local data (except pandas DataFrame and numpy ndarray) to
Arrow.
diff --git a/python/pyspark/sql/pandas/serializers.py
b/python/pyspark/sql/pandas/serializers.py
index 55cedd6f02dc..aac6df47a3b8 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -20,7 +20,7 @@ Serializers for PyArrow and pandas conversions. See
`pyspark.serializers` for mo
"""
from itertools import groupby
-from typing import TYPE_CHECKING, Any, Iterator, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple
import pyspark
from pyspark.errors import PySparkRuntimeError, PySparkTypeError,
PySparkValueError
@@ -36,12 +36,12 @@ from pyspark.sql.conversion import (
LocalDataToArrowConversion,
ArrowTableToRowsConversion,
ArrowBatchTransformer,
+ PandasToArrowConversion,
+ coerce_arrow_array,
)
from pyspark.sql.pandas.types import (
from_arrow_schema,
- is_variant,
to_arrow_type,
- _create_converter_from_pandas,
)
from pyspark.sql.types import (
DataType,
@@ -58,6 +58,19 @@ if TYPE_CHECKING:
import pyarrow as pa
+def _normalize_packed(packed):
+ """
+ Normalize UDF output to a uniform tuple-of-tuples form.
+
+ Iterator UDFs yield a single (series, spark_type) tuple directly,
+ while batched UDFs return a tuple of tuples ((s1, t1), (s2, t2), ...).
+ This function normalizes both forms to a tuple of tuples.
+ """
+ if len(packed) == 2 and isinstance(packed[1], DataType):
+ return (packed,)
+ return tuple(packed)
+
+
class SpecialLengths:
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
@@ -272,29 +285,6 @@ class
ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
for i in range(batch.num_columns)
]
- def _create_array(self, arr, arrow_type):
- import pyarrow as pa
-
- assert isinstance(arr, pa.Array)
- assert isinstance(arrow_type, pa.DataType)
- if arr.type == arrow_type:
- return arr
- else:
- try:
- # when safe is True, the cast will fail if there's a overflow
or other
- # unsafe conversion.
- # RecordBatch.cast(...) isn't used as minimum PyArrow version
- # required for RecordBatch.cast(...) is v16.0
- return arr.cast(target_type=arrow_type, safe=True)
- except (pa.ArrowInvalid, pa.ArrowTypeError):
- raise PySparkRuntimeError(
- errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
- messageParameters={
- "expected": str(arrow_type),
- "actual": str(arr.type),
- },
- )
-
def dump_stream(self, iterator, stream):
"""
Override to handle type coercion for ArrowUDTF outputs.
@@ -324,9 +314,22 @@ class
ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
coerced_arrays = []
for i, field in enumerate(arrow_return_type):
- original_array = batch.column(i)
- coerced_array = self._create_array(original_array,
field.type)
- coerced_arrays.append(coerced_array)
+ try:
+ coerced_arrays.append(
+ coerce_arrow_array(
+ batch.column(i),
+ field.type,
+ safecheck=True,
+ )
+ )
+ except (pa.ArrowInvalid, pa.ArrowTypeError):
+ raise PySparkRuntimeError(
+
errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",
+ messageParameters={
+ "expected": str(field.type),
+ "actual": str(batch.column(i).type),
+ },
+ )
coerced_batch = pa.RecordBatch.from_arrays(
coerced_arrays, names=expected_field_names
)
@@ -445,122 +448,32 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
self._input_type = input_type
self._arrow_cast = arrow_cast
- def _create_array(self, series, spark_type, *, arrow_cast=False,
prefers_large_types=False):
- """
- Create an Arrow Array from the given pandas.Series and Spark type.
-
- Parameters
- ----------
- series : pandas.Series
- A single series
- spark_type : DataType, optional
- The Spark return type. For UDF return types, this should always be
provided
- and should never be None. If None, pyarrow's inferred type will be
used
- (for backward compatibility).
- arrow_cast : bool, optional
- Whether to apply Arrow casting when the user-specified return type
mismatches the
- actual return values.
- prefers_large_types : bool, optional
- Whether to prefer large Arrow types (e.g., large_string instead of
string).
-
- Returns
- -------
- pyarrow.Array
+ def dump_stream(self, iterator, stream):
"""
- import pyarrow as pa
- import pandas as pd
-
- if isinstance(series.dtype, pd.CategoricalDtype):
- series = series.astype(series.dtypes.categories.dtype)
-
- # Derive arrow_type from spark_type
- arrow_type = (
- to_arrow_type(
- spark_type, timezone=self._timezone,
prefers_large_types=prefers_large_types
- )
- if spark_type is not None
- else None
- )
-
- if spark_type is not None:
- conv = _create_converter_from_pandas(
- spark_type,
+ Make ArrowRecordBatches from Pandas Series and serialize.
+ Each element in iterator is:
+ - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1),
(s2, t2), ...)
+ - For iterator UDFs: single (series, spark_type) tuple directly
+ """
+
+ def create_batch(
+ series_tuples: Tuple[Tuple["pd.Series", DataType], ...],
+ ) -> "pa.RecordBatch":
+ series_data = [s for s, _ in series_tuples]
+ types = [t for _, t in series_tuples]
+ schema = StructType([StructField(f"_{i}", t) for i, t in
enumerate(types)])
+ return PandasToArrowConversion.convert(
+ series_data,
+ schema,
timezone=self._timezone,
- error_on_duplicated_field_names=False,
+ safecheck=self._safecheck,
+ prefers_large_types=self._prefers_large_types,
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
)
- series = conv(series)
-
- if hasattr(series.array, "__arrow_array__"):
- mask = None
- else:
- mask = series.isnull()
- try:
- try:
- return pa.Array.from_pandas(
- series, mask=mask, type=arrow_type, safe=self._safecheck
- )
- except pa.lib.ArrowInvalid:
- if arrow_cast:
- return pa.Array.from_pandas(series, mask=mask).cast(
- target_type=arrow_type, safe=self._safecheck
- )
- else:
- raise
- except TypeError as e:
- error_msg = (
- "Exception thrown when converting pandas.Series (%s) "
- "with name '%s' to Arrow Array (%s)."
- )
- raise PySparkTypeError(error_msg % (series.dtype, series.name,
arrow_type)) from e
- except ValueError as e:
- error_msg = (
- "Exception thrown when converting pandas.Series (%s) "
- "with name '%s' to Arrow Array (%s)."
- )
- if self._safecheck:
- error_msg = error_msg + (
- " It can be caused by overflows or other "
- "unsafe conversions warned by Arrow. Arrow safe type check
"
- "can be disabled by using SQL config "
- "`spark.sql.execution.pandas.convertToArrowArraySafely`."
- )
- raise PySparkValueError(error_msg % (series.dtype, series.name,
arrow_type)) from e
-
- def _create_batch(self, series, *, prefers_large_types=False):
- """
- Create an Arrow record batch from the given iterable of (series,
spark_type) tuples.
-
- Parameters
- ----------
- series : iterable
- Iterable of (series, spark_type) tuples.
- prefers_large_types : bool, optional
- Whether to prefer large Arrow types (e.g., large_string instead of
string).
- Returns
- -------
- pyarrow.RecordBatch
- Arrow RecordBatch
- """
- import pyarrow as pa
-
- arrs = [
- self._create_array(s, spark_type,
prefers_large_types=prefers_large_types)
- for s, spark_type in series
- ]
- return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in
range(len(arrs))])
-
- def dump_stream(self, iterator, stream):
- """
- Make ArrowRecordBatches from Pandas Series and serialize.
- Each element in iterator is an iterable of (series, spark_type) tuples.
- """
- batches = (
- self._create_batch(series,
prefers_large_types=self._prefers_large_types)
- for series in iterator
+ super().dump_stream(
+ (create_batch(_normalize_packed(packed)) for packed in iterator),
stream
)
- super().dump_stream(batches, stream)
def load_stream(self, stream):
"""
@@ -599,6 +512,8 @@ class
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
input_type: Optional[StructType] = None,
int_to_decimal_coercion_enabled: bool = False,
prefers_large_types: bool = False,
+ ignore_unexpected_complex_type_values: bool = False,
+ is_udtf: bool = False,
):
super().__init__(
timezone,
@@ -612,170 +527,51 @@ class
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
arrow_cast,
)
self._assign_cols_by_name = assign_cols_by_name
+ self._ignore_unexpected_complex_type_values =
ignore_unexpected_complex_type_values
+ self._is_udtf = is_udtf
- def _create_struct_array(
- self,
- df: "pd.DataFrame",
- return_type: StructType,
- *,
- prefers_large_types: bool = False,
- ):
- """
- Create an Arrow StructArray from the given pandas.DataFrame and Spark
StructType.
-
- Parameters
- ----------
- df : pandas.DataFrame
- A pandas DataFrame
- return_type : StructType
- The Spark return type (StructType) to use
- prefers_large_types : bool, optional
- Whether to prefer large Arrow types (e.g., large_string instead of
string).
-
- Returns
- -------
- pyarrow.Array
- """
- import pyarrow as pa
-
- # Derive arrow_struct_type from return_type
- arrow_struct_type = to_arrow_type(
- return_type, timezone=self._timezone,
prefers_large_types=prefers_large_types
- )
-
- if len(df.columns) == 0:
- return pa.array([{}] * len(df), arrow_struct_type)
- # Assign result columns by schema name if user labeled with strings
- if self._assign_cols_by_name and any(isinstance(name, str) for name in
df.columns):
- struct_arrs = [
- self._create_array(
- df[spark_field.name],
- spark_field.dataType,
- arrow_cast=self._arrow_cast,
- prefers_large_types=prefers_large_types,
- )
- for spark_field in return_type
- ]
- # Assign result columns by position
- else:
- struct_arrs = [
- # the selected series has name '1', so we rename it to
spark_field.name
- # as the name is used by _create_array to provide a meaningful
error message
- self._create_array(
- df[df.columns[i]].rename(spark_field.name),
- spark_field.dataType,
- arrow_cast=self._arrow_cast,
- prefers_large_types=prefers_large_types,
- )
- for i, spark_field in enumerate(return_type)
- ]
-
- return pa.StructArray.from_arrays(struct_arrs,
fields=list(arrow_struct_type))
-
- def _create_batch(
- self, series, *, arrow_cast=False, prefers_large_types=False,
struct_in_pandas="dict"
- ):
+ def dump_stream(self, iterator, stream):
"""
- Create an Arrow record batch from the given pandas.Series,
pandas.DataFrame,
- or list of Series/DataFrame, with optional Spark type.
-
- Parameters
- ----------
- series : pandas.Series or pandas.DataFrame or list
- A single series or dataframe, list of series or dataframe,
- or list of (series or dataframe, spark_type) tuples.
- arrow_cast : bool, optional
- If True, use Arrow's cast method for type conversion.
- prefers_large_types : bool, optional
- Whether to prefer large Arrow types (e.g., large_string instead of
string).
- struct_in_pandas : str, optional
- How to represent struct types in pandas: "dict" or "row".
- Default is "dict".
+ Override because Pandas UDFs require a START_ARROW_STREAM before the
Arrow stream is sent.
+ This should be sent after creating the first record batch so in case
of an error, it can
+ be sent back to the JVM before the Arrow stream starts.
- Returns
- -------
- pyarrow.RecordBatch
- Arrow RecordBatch
+ Each element in iterator is:
+ - For batched UDFs: tuple of (series, spark_type) tuples: ((s1, t1),
(s2, t2), ...)
+ - For iterator UDFs: single (series, spark_type) tuple directly
"""
import pandas as pd
- import pyarrow as pa
-
- # Normalize input to list of (data, spark_type) tuples
- # Handle: single series, (series, type) tuple, or list of tuples
- if not isinstance(series, (list, tuple)) or (
- len(series) == 2 and isinstance(series[1], DataType)
- ):
- series = [series]
- # Ensure each element is a (data, spark_type) tuple
- series = [(s, None) if not isinstance(s, (list, tuple)) else s for s
in series]
-
- arrs = []
- for s, spark_type in series:
- # Convert spark_type to arrow_type for type checking (similar to
master branch)
- arrow_type = (
- to_arrow_type(
- spark_type, timezone=self._timezone,
prefers_large_types=prefers_large_types
- )
- if spark_type is not None
- else None
- )
- # Variants are represented in arrow as structs with additional
metadata (checked by
- # is_variant). If the data type is Variant, return a VariantVal
atomic type instead of
- # a dict of two binary values.
- if (
- struct_in_pandas == "dict"
- and arrow_type is not None
- and pa.types.is_struct(arrow_type)
- and not is_variant(arrow_type)
- ):
- # A pandas UDF should return pd.DataFrame when the return type
is a struct type.
- # If it returns a pd.Series, it should throw an error.
- if not isinstance(s, pd.DataFrame):
- raise PySparkValueError(
- "Invalid return type. Please make sure that the UDF
returns a "
- "pandas.DataFrame when the specified return type is
StructType."
- )
- arrs.append(
- self._create_struct_array(
- s, spark_type, prefers_large_types=prefers_large_types
- )
- )
- elif isinstance(s, pd.DataFrame):
- # If data is a DataFrame (e.g., from df_for_struct), use
_create_struct_array
- arrs.append(
- self._create_struct_array(
- s, spark_type, prefers_large_types=prefers_large_types
- )
- )
- else:
- arrs.append(
- self._create_array(
- s,
- spark_type,
- arrow_cast=arrow_cast,
- prefers_large_types=prefers_large_types,
- )
- )
+ def create_batch(
+ series_tuples: Tuple[Tuple["pd.Series", DataType], ...],
+ ) -> "pa.RecordBatch":
+ # When struct_in_pandas="dict", UDF must return DataFrame for
struct types
+ if self._struct_in_pandas == "dict":
+ for s, spark_type in series_tuples:
+ if isinstance(spark_type, StructType) and not
isinstance(s, pd.DataFrame):
+ raise PySparkValueError(
+ "Invalid return type. Please make sure that the
UDF returns a "
+ "pandas.DataFrame when the specified return type
is StructType."
+ )
- return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in
range(len(arrs))])
+ series_data = [s for s, _ in series_tuples]
+ types = [t for _, t in series_tuples]
+ schema = StructType([StructField(f"_{i}", t) for i, t in
enumerate(types)])
+ return PandasToArrowConversion.convert(
+ series_data,
+ schema,
+ timezone=self._timezone,
+ safecheck=self._safecheck,
+ arrow_cast=self._arrow_cast,
+ prefers_large_types=self._prefers_large_types,
+ assign_cols_by_name=self._assign_cols_by_name,
+
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
+
ignore_unexpected_complex_type_values=self._ignore_unexpected_complex_type_values,
+ is_udtf=self._is_udtf,
+ )
- def dump_stream(self, iterator, stream):
- """
- Override because Pandas UDFs require a START_ARROW_STREAM before the
Arrow stream is sent.
- This should be sent after creating the first record batch so in case
of an error, it can
- be sent back to the JVM before the Arrow stream starts.
- """
batches = self._write_stream_start(
- (
- self._create_batch(
- series,
- arrow_cast=self._arrow_cast,
- prefers_large_types=self._prefers_large_types,
- struct_in_pandas=self._struct_in_pandas,
- )
- for series in iterator
- ),
+ (create_batch(_normalize_packed(packed)) for packed in iterator),
stream,
)
return ArrowStreamSerializer.dump_stream(self, batches, stream)
@@ -798,22 +594,6 @@ class ArrowStreamArrowUDFSerializer(ArrowStreamSerializer):
self._safecheck = safecheck
self._arrow_cast = arrow_cast
- def _create_array(self, arr, arrow_type, arrow_cast):
- import pyarrow as pa
-
- assert isinstance(arr, pa.Array)
- assert isinstance(arrow_type, pa.DataType)
-
- if arr.type == arrow_type:
- return arr
- elif arrow_cast:
- return arr.cast(target_type=arrow_type, safe=self._safecheck)
- else:
- raise PySparkTypeError(
- "Arrow UDFs require the return type to match the expected
Arrow type. "
- f"Expected: {arrow_type}, but got: {arr.type}."
- )
-
def dump_stream(self, iterator, stream):
"""
Override because Arrow UDFs require a START_ARROW_STREAM before the
Arrow stream is sent.
@@ -822,17 +602,25 @@ class
ArrowStreamArrowUDFSerializer(ArrowStreamSerializer):
"""
import pyarrow as pa
- def create_batches():
- for packed in iterator:
- if len(packed) == 2 and isinstance(packed[1], pa.DataType):
- # single array UDF in a projection
- arrs = [self._create_array(packed[0], packed[1],
self._arrow_cast)]
- else:
- # multiple array UDFs in a projection
- arrs = [self._create_array(t[0], t[1], self._arrow_cast)
for t in packed]
- yield pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in
range(len(arrs))])
+ def create_batch(
+ arr_tuples: List[Tuple["pa.Array", "pa.DataType"]],
+ ) -> "pa.RecordBatch":
+ arrs = [
+ coerce_arrow_array(
+ arr, arrow_type, safecheck=self._safecheck,
arrow_cast=self._arrow_cast
+ )
+ for arr, arrow_type in arr_tuples
+ ]
+ return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in
range(len(arrs))])
- batches = self._write_stream_start(create_batches(), stream)
+ def normalize(packed):
+ if len(packed) == 2 and isinstance(packed[1], pa.DataType):
+ return [packed]
+ return list(packed)
+
+ batches = self._write_stream_start(
+ (create_batch(normalize(packed)) for packed in iterator), stream
+ )
return ArrowStreamSerializer.dump_stream(self, batches, stream)
def __repr__(self):
@@ -912,8 +700,9 @@ class
ArrowBatchUDFSerializer(ArrowStreamArrowUDFSerializer):
Parameters
----------
iterator : iterator
- Iterator yielding tuples of (data, arrow_type, spark_type) for
single UDF
- or list of tuples for multiple UDFs in a projection
+ Iterator yielding tuple of (data, arrow_type, spark_type) tuples.
+ Single UDF: ((results, arrow_type, spark_type),)
+ Multiple UDFs: ((r1, t1, s1), (r2, t2, s2), ...)
stream : object
Output stream to write the Arrow record batches
@@ -976,134 +765,13 @@ class
ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
input_type=input_type,
# Enable additional coercions for UDTF serialization
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+ # UDTF-specific: ignore unexpected complex type values in converter
+ ignore_unexpected_complex_type_values=True,
+ # UDTF-specific: enables broader Arrow exception handling and
+ # converts errors to UDTF_ARROW_TYPE_CAST_ERROR
+ is_udtf=True,
)
- def _create_batch(
- self, series, *, arrow_cast=False, prefers_large_types=False,
struct_in_pandas="dict"
- ):
- """
- Create an Arrow record batch from the given iterable of (dataframe,
spark_type) tuples.
-
- Parameters
- ----------
- series : iterable
- Iterable of (dataframe, spark_type) tuples.
- arrow_cast : bool, optional
- Unused, kept for compatibility with parent class signature.
- prefers_large_types : bool, optional
- Whether to prefer large Arrow types (e.g., large_string instead of
string).
- struct_in_pandas : str, optional
- Unused, kept for compatibility with parent class signature.
-
- Returns
- -------
- pyarrow.RecordBatch
- Arrow RecordBatch
- """
- import pandas as pd
- import pyarrow as pa
-
- # Normalize input to list of (data, spark_type) tuples
- # Handle: single dataframe, (dataframe, type) tuple, or list of tuples
- if not isinstance(series, (list, tuple)) or (
- len(series) == 2 and isinstance(series[1], DataType)
- ):
- series = [series]
-
- arrs = []
- for s, spark_type in series:
- if not isinstance(s, pd.DataFrame):
- raise PySparkValueError(
- "Output of an arrow-optimized Python UDTFs expects "
- f"a pandas.DataFrame but got: {type(s)}"
- )
-
- arrs.append(
- self._create_struct_array(s, spark_type,
prefers_large_types=prefers_large_types)
- )
-
- return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in
range(len(arrs))])
-
- def _create_array(self, series, spark_type, *, arrow_cast=False,
prefers_large_types=False):
- """
- Override the `_create_array` method in the superclass to create an
Arrow Array
- from a given pandas.Series and Spark type. The difference here is that
we always
- use arrow cast when creating the arrow array. Also, the error messages
are specific
- to arrow-optimized Python UDTFs.
-
- Parameters
- ----------
- series : pandas.Series
- A single series
- spark_type : DataType, optional
- The Spark return type. For UDF return types, this should always be
provided
- and should never be None. If None, pyarrow's inferred type will be
used
- (for backward compatibility).
- arrow_cast : bool, optional
- Whether to apply Arrow casting when the user-specified return type
mismatches the
- actual return values.
- prefers_large_types : bool, optional
- Whether to prefer large Arrow types (e.g., large_string instead of
string).
-
- Returns
- -------
- pyarrow.Array
- """
- import pyarrow as pa
- import pandas as pd
-
- if isinstance(series.dtype, pd.CategoricalDtype):
- series = series.astype(series.dtypes.categories.dtype)
-
- # Derive arrow_type from spark_type
- arrow_type = (
- to_arrow_type(
- spark_type, timezone=self._timezone,
prefers_large_types=prefers_large_types
- )
- if spark_type is not None
- else None
- )
-
- if spark_type is not None:
- conv = _create_converter_from_pandas(
- spark_type,
- timezone=self._timezone,
- error_on_duplicated_field_names=False,
- ignore_unexpected_complex_type_values=True,
-
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
- )
- series = conv(series)
-
- if hasattr(series.array, "__arrow_array__"):
- mask = None
- else:
- mask = series.isnull()
-
- try:
- try:
- return pa.Array.from_pandas(
- series, mask=mask, type=arrow_type, safe=self._safecheck
- )
- except pa.lib.ArrowException:
- if arrow_cast:
- return pa.Array.from_pandas(series, mask=mask).cast(
- target_type=arrow_type, safe=self._safecheck
- )
- else:
- raise
- except pa.lib.ArrowException:
- # Display the most user-friendly error messages instead of showing
- # arrow's error message. This also works better with Spark Connect
- # where the exception messages are by default truncated.
- raise PySparkRuntimeError(
- errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
- messageParameters={
- "col_name": series.name,
- "col_type": str(series.dtype),
- "arrow_type": arrow_type,
- },
- ) from None
-
def __repr__(self):
return "ArrowStreamPandasUDTFSerializer"
@@ -1583,13 +1251,26 @@ class
ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
merged_pdf = pd.concat(pdfs, ignore_index=True)
merged_state_pdf = pd.concat(state_pdfs, ignore_index=True)
- return self._create_batch(
+ # Create batch from list of DataFrames, each wrapped as a
StructArray.
+ # Schema fields map to: _0=count, _1=output data, _2=state data
+ # (types defined in __init__: result_count_df_type, pdf_schema,
result_state_df_type)
+ data = [count_pdf, merged_pdf, merged_state_pdf]
+ schema = StructType(
[
- (count_pdf, self.result_count_df_type),
- (merged_pdf, pdf_schema),
- (merged_state_pdf, self.result_state_df_type),
+ StructField("_0", self.result_count_df_type),
+ StructField("_1", pdf_schema),
+ StructField("_2", self.result_state_df_type),
]
)
+ return PandasToArrowConversion.convert(
+ data,
+ schema,
+ timezone=self._timezone,
+ safecheck=self._safecheck,
+ arrow_cast=self._arrow_cast,
+ assign_cols_by_name=self._assign_cols_by_name,
+
int_to_decimal_coercion_enabled=self._int_to_decimal_coercion_enabled,
+ )
def serialize_batches():
"""
diff --git a/python/pyspark/sql/tests/test_conversion.py
b/python/pyspark/sql/tests/test_conversion.py
index 6b9d81bfa0f5..0560ac983250 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -18,30 +18,40 @@ import datetime
import unittest
from zoneinfo import ZoneInfo
-from pyspark.errors import PySparkValueError
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError,
PySparkValueError
from pyspark.sql.conversion import (
ArrowArrayToPandasConversion,
ArrowTableToRowsConversion,
LocalDataToArrowConversion,
ArrowArrayConversion,
ArrowBatchTransformer,
+ PandasToArrowConversion,
)
from pyspark.sql.types import (
ArrayType,
BinaryType,
+ DecimalType,
+ DoubleType,
GeographyType,
GeometryType,
IntegerType,
+ LongType,
MapType,
NullType,
Row,
StringType,
StructField,
StructType,
+ TimestampType,
UserDefinedType,
)
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT,
PythonOnlyPoint, PythonOnlyUDT
-from pyspark.testing.utils import have_pyarrow, pyarrow_requirement_message
+from pyspark.testing.utils import (
+ have_pandas,
+ have_pyarrow,
+ pandas_requirement_message,
+ pyarrow_requirement_message,
+)
class ScoreUDT(UserDefinedType):
@@ -145,6 +155,200 @@ class ArrowBatchTransformerTests(unittest.TestCase):
self.assertEqual(wrapped.num_columns, 1)
[email protected](not have_pyarrow, pyarrow_requirement_message)
[email protected](not have_pandas, pandas_requirement_message)
+class PandasToArrowConversionTests(unittest.TestCase):
+ def test_convert(self):
+ """Test basic DataFrame/Series to Arrow RecordBatch conversion."""
+ import pandas as pd
+ import pyarrow as pa
+
+ # Basic DataFrame conversion
+ df = pd.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]})
+ schema = StructType([StructField("a", IntegerType()), StructField("b",
DoubleType())])
+ result = PandasToArrowConversion.convert(df, schema)
+ self.assertIsInstance(result, pa.RecordBatch)
+ self.assertEqual(result.num_rows, 3)
+ self.assertEqual(result.num_columns, 2)
+ self.assertEqual(result.schema.names, ["a", "b"])
+
+ # List of Series input
+ series_list = [pd.Series([1, 2, 3]), pd.Series([1.0, 2.0, 3.0])]
+ result = PandasToArrowConversion.convert(series_list, schema)
+ self.assertEqual(result.num_rows, 3)
+
+ # With nulls
+ df = pd.DataFrame({"a": [1, None, 3], "b": [1.0, 2.0, None]})
+ result = PandasToArrowConversion.convert(df, schema)
+ self.assertEqual(result.column(0).to_pylist(), [1, None, 3])
+
+ # Empty DataFrame (0 rows)
+ df = pd.DataFrame({"a": pd.Series([], dtype=int), "b": pd.Series([],
dtype=float)})
+ result = PandasToArrowConversion.convert(df, schema)
+ self.assertEqual(result.num_rows, 0)
+
+ # Empty schema (0 columns)
+ # TODO(SPARK-55350): Pandas -> PyArrow should preserve row count with
0 columns. This is a bug.
+ result = PandasToArrowConversion.convert(df, StructType([]))
+ self.assertEqual(result.num_columns, 0)
+ self.assertEqual(result.num_rows, 0)
+
+ def test_convert_assign_cols_by_name(self):
+ """Test assign_cols_by_name reorders columns to match schema."""
+ import pandas as pd
+
+ # DataFrame columns in different order than schema
+ df = pd.DataFrame({"b": ["x", "y", "z"], "a": [1, 2, 3]})
+ schema = StructType([StructField("a", IntegerType()), StructField("b",
StringType())])
+
+ # With assign_cols_by_name=True - reorders columns to match schema
field names
+ result = PandasToArrowConversion.convert(df, schema,
assign_cols_by_name=True)
+ self.assertEqual(result.column(0).to_pylist(), [1, 2, 3]) # a
+ self.assertEqual(result.column(1).to_pylist(), ["x", "y", "z"]) # b
+
+ # Without assign_cols_by_name - uses positional order (b first, a
second)
+ df = pd.DataFrame({"b": [10, 20, 30], "a": [1.0, 2.0, 3.0]})
+ schema = StructType([StructField("x", IntegerType()), StructField("y",
DoubleType())])
+ result = PandasToArrowConversion.convert(df, schema,
assign_cols_by_name=False)
+ self.assertEqual(result.column(0).to_pylist(), [10, 20, 30]) #
positional: b -> x
+ self.assertEqual(result.column(1).to_pylist(), [1.0, 2.0, 3.0]) #
positional: a -> y
+
+ def test_convert_timezone(self):
+ """Test timezone handling for timestamp conversion."""
+ import pandas as pd
+
+ # Create DataFrame with timezone-naive timestamps
+ df = pd.DataFrame({"ts": pd.to_datetime(["2023-01-01 12:00:00",
"2023-01-02 12:00:00"])})
+ schema = StructType([StructField("ts", TimestampType())])
+
+ # Convert with timezone
+ result = PandasToArrowConversion.convert(df, schema, timezone="UTC")
+ self.assertEqual(result.num_rows, 2)
+ self.assertEqual(result.num_columns, 1)
+
+ def test_convert_arrow_cast(self):
+ """Test arrow_cast allows type coercion on mismatch."""
+ import pandas as pd
+
+ # DataFrame with int32, schema expects int64
+ df = pd.DataFrame({"a": pd.array([1, 2, 3], dtype="int32")})
+ schema = StructType([StructField("a", LongType())])
+
+ # With arrow_cast=True, should allow the conversion
+ result = PandasToArrowConversion.convert(df, schema, arrow_cast=True)
+ self.assertEqual(result.column(0).to_pylist(), [1, 2, 3])
+
+ def test_convert_decimal(self):
+ """Test int to decimal coercion."""
+ import pandas as pd
+ from decimal import Decimal
+
+ # DataFrame with integers, schema expects decimal
+ df = pd.DataFrame({"a": [1, 2, 3]})
+ schema = StructType([StructField("a", DecimalType(10, 2))])
+
+ # With int_to_decimal_coercion_enabled=True
+ result = PandasToArrowConversion.convert(df, schema,
int_to_decimal_coercion_enabled=True)
+ self.assertEqual(result.num_rows, 3)
+ # Values should be converted to decimal
+ values = result.column(0).to_pylist()
+ self.assertEqual(values, [Decimal("1.00"), Decimal("2.00"),
Decimal("3.00")])
+
+ def test_convert_struct(self):
+ """Test struct type conversion via nested DataFrame columns."""
+ import pandas as pd
+ import pyarrow as pa
+
+ schema = StructType(
+ [
+ StructField("id", IntegerType()),
+ StructField(
+ "info",
+ StructType([StructField("x", IntegerType()),
StructField("y", DoubleType())]),
+ ),
+ ]
+ )
+ # List input: second element is a DataFrame (struct column)
+ data = [pd.Series([1, 2]), pd.DataFrame({"x": [10, 20], "y": [1.1,
2.2]})]
+ result = PandasToArrowConversion.convert(data, schema)
+ self.assertEqual(result.num_rows, 2)
+ self.assertEqual(result.num_columns, 2)
+ # Struct column should be a StructArray
+ self.assertTrue(pa.types.is_struct(result.column(1).type))
+
+ # Empty DataFrame for struct type
+ data = [
+ pd.Series([], dtype=int),
+ pd.DataFrame({"x": pd.Series([], dtype=int), "y": pd.Series([],
dtype=float)}),
+ ]
+ result = PandasToArrowConversion.convert(data, schema)
+ self.assertEqual(result.num_rows, 0)
+
+ def test_convert_error_messages(self):
+ """Test error messages include series name from schema field."""
+ import pandas as pd
+
+ schema = StructType([StructField("age", IntegerType()),
StructField("name", StringType())])
+
+ # Type mismatch: string data for integer column
+ data = [pd.Series(["not_int", "bad"]), pd.Series(["a", "b"])]
+ with self.assertRaises((PySparkValueError, PySparkTypeError)) as ctx:
+ PandasToArrowConversion.convert(data, schema)
+ # Error message should reference the schema field name, not the
positional index
+ self.assertIn("age", str(ctx.exception))
+
+ def test_convert_is_udtf(self):
+ """Test is_udtf=True produces PySparkRuntimeError with
UDTF_ARROW_TYPE_CAST_ERROR."""
+ import pandas as pd
+
+ schema = StructType([StructField("val", DoubleType())])
+ data = [pd.Series(["not_a_number", "bad"])]
+
+ # ValueError path (string -> double)
+ with self.assertRaises(PySparkRuntimeError) as ctx:
+ PandasToArrowConversion.convert(data, schema, is_udtf=True)
+ self.assertIn("UDTF_ARROW_TYPE_CAST_ERROR", str(ctx.exception))
+
+ # TypeError path (int -> struct): ArrowTypeError inherits from
TypeError.
+ # ignore_unexpected_complex_type_values=True lets the bad value pass
through
+ # to Arrow, which raises ArrowTypeError (a TypeError subclass).
+ struct_schema = StructType(
+ [StructField("x", StructType([StructField("a", IntegerType())]))]
+ )
+ data = [pd.Series([0, 1])]
+ with self.assertRaises(PySparkRuntimeError) as ctx:
+ PandasToArrowConversion.convert(
+ data,
+ struct_schema,
+ is_udtf=True,
+ ignore_unexpected_complex_type_values=True,
+ )
+ self.assertIn("UDTF_ARROW_TYPE_CAST_ERROR", str(ctx.exception))
+
+ def test_convert_prefers_large_types(self):
+ """Test prefers_large_types produces large Arrow types."""
+ import pandas as pd
+ import pyarrow as pa
+
+ df = pd.DataFrame({"s": ["hello", "world"]})
+ schema = StructType([StructField("s", StringType())])
+
+ result = PandasToArrowConversion.convert(df, schema,
prefers_large_types=True)
+ self.assertEqual(result.column(0).type, pa.large_string())
+
+ result = PandasToArrowConversion.convert(df, schema,
prefers_large_types=False)
+ self.assertEqual(result.column(0).type, pa.string())
+
+ def test_convert_categorical(self):
+ """Test CategoricalDtype series is correctly converted."""
+ import pandas as pd
+
+ cat_series = pd.Series(pd.Categorical(["a", "b", "a", "c"]))
+ schema = StructType([StructField("cat", StringType())])
+ result = PandasToArrowConversion.convert([cat_series], schema)
+ self.assertEqual(result.column(0).to_pylist(), ["a", "b", "a", "c"])
+
+
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
class ConversionTests(unittest.TestCase):
def test_conversion(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]