Yicong-Huang commented on code in PR #54125:
URL: https://github.com/apache/spark/pull/54125#discussion_r2800832794
##########
python/pyspark/sql/conversion.py:
##########
@@ -162,6 +162,239 @@ def to_pandas(
]
+# TODO: elevate to ArrowBatchTransformer and operate on full RecordBatch schema
+# instead of per-column coercion.
+def coerce_arrow_array(
+ arr: "pa.Array",
+ target_type: "pa.DataType",
+ *,
+ safecheck: bool = True,
+ arrow_cast: bool = True,
+) -> "pa.Array":
+ """
+    Coerce an Arrow Array to a target type, with optional type-mismatch
+    enforcement.
+
+ When ``arrow_cast`` is True (default), mismatched types are cast to the
+ target type. When False, a type mismatch raises an error instead.
+
+ Parameters
+ ----------
+ arr : pa.Array
+ Input Arrow array
+ target_type : pa.DataType
+ Target Arrow type
+ safecheck : bool
+ Whether to use safe casting (default True)
+ arrow_cast : bool
+ Whether to allow casting when types don't match (default True)
+
+ Returns
+ -------
+ pa.Array
+ """
+ from pyspark.errors import PySparkTypeError
+
+ if arr.type == target_type:
+ return arr
+
+ if not arrow_cast:
+ raise PySparkTypeError(
+ "Arrow UDFs require the return type to match the expected Arrow
type. "
+ f"Expected: {target_type}, but got: {arr.type}."
+ )
+
+    # When safe is True, the cast will fail if there is an overflow or
+    # other unsafe conversion.
+    # RecordBatch.cast(...) isn't used here because the minimum PyArrow
+    # version required for RecordBatch.cast(...) is 16.0.
+ return arr.cast(target_type=target_type, safe=safecheck)
+
+
+class PandasToArrowConversion:
+ """
+ Conversion utilities from pandas data to Arrow.
+ """
+
+ @classmethod
+ def convert(
+ cls,
+ data: Union["pd.DataFrame", List[Union["pd.Series", "pd.DataFrame"]]],
+ schema: StructType,
+ *,
+ timezone: Optional[str] = None,
+ safecheck: bool = True,
+ arrow_cast: bool = False,
+ prefers_large_types: bool = False,
+ assign_cols_by_name: bool = False,
+ int_to_decimal_coercion_enabled: bool = False,
+ ignore_unexpected_complex_type_values: bool = False,
+ is_udtf: bool = False,
+ ) -> "pa.RecordBatch":
+ """
+        Convert a pandas DataFrame or list of Series/DataFrames to an Arrow
+        RecordBatch.
+
+ Parameters
+ ----------
+ data : pd.DataFrame or list of pd.Series/pd.DataFrame
+ Input data - either a DataFrame or a list of Series/DataFrames.
+ schema : StructType
+ Spark schema defining the types for each column
+ timezone : str, optional
+ Timezone for timestamp conversion
+ safecheck : bool
+ Whether to use safe Arrow conversion (default True)
+ arrow_cast : bool
+ Whether to allow Arrow casting on type mismatch (default False)
+ prefers_large_types : bool
+ Whether to prefer large Arrow types (default False)
+ assign_cols_by_name : bool
+            Whether to reorder DataFrame columns by name to match the schema
+            (default False)
+ int_to_decimal_coercion_enabled : bool
+ Whether to enable int to decimal coercion (default False)
+ ignore_unexpected_complex_type_values : bool
+            Whether to ignore unexpected complex type values in the converter
+            (default False)
+ is_udtf : bool
Review Comment:
Added TODO with SPARK-55502.
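
   For reference, a minimal usage sketch of the new helper. This is hypothetical and not part of the diff: it assumes `pyarrow` is installed and that `coerce_arrow_array` is importable from `pyspark.sql.conversion` (matching the file path above).

   ```python
   import pyarrow as pa

   # Assumed import path; the function may not be exported publicly.
   from pyspark.sql.conversion import coerce_arrow_array

   arr = pa.array([1, 2, 3], type=pa.int64())

   # Matching types: the input array is returned unchanged.
   assert coerce_arrow_array(arr, pa.int64()) is arr

   # Mismatched types with arrow_cast=True (the default): a cast is attempted.
   assert coerce_arrow_array(arr, pa.int32()).type == pa.int32()

   # Mismatched types with arrow_cast=False: PySparkTypeError is raised.
   try:
       coerce_arrow_array(arr, pa.int32(), arrow_cast=False)
   except Exception as e:
       print(type(e).__name__)  # PySparkTypeError

   # safecheck=True makes overflowing casts fail instead of silently wrapping.
   big = pa.array([3000], type=pa.int64())
   try:
       coerce_arrow_array(big, pa.int8(), safecheck=True)
   except pa.ArrowInvalid as e:
       print("unsafe cast rejected:", e)
   ```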