This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 15111a05e92 [SPARK-43260][PYTHON] Migrate the Spark SQL pandas arrow type errors into error class 15111a05e92 is described below commit 15111a05e925c0f25949908a9407c3b71e332e5f Author: itholic <haejoon....@databricks.com> AuthorDate: Tue Apr 25 14:22:13 2023 +0800 [SPARK-43260][PYTHON] Migrate the Spark SQL pandas arrow type errors into error class ### What changes were proposed in this pull request? This PR proposes to migrate the Spark SQL pandas arrow type errors into error classes. ### Why are the changes needed? To leverage the PySpark error framework. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? The existing CI should pass. Closes #40924 from itholic/error_pandas_types. Authored-by: itholic <haejoon....@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/errors/error_classes.py | 10 ++++++ python/pyspark/sql/pandas/types.py | 61 ++++++++++++++++++++++++++-------- python/pyspark/sql/tests/test_arrow.py | 17 ++++++++-- 3 files changed, 73 insertions(+), 15 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index e3742441fe4..f6fd0f24a35 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -354,6 +354,16 @@ ERROR_CLASSES_JSON = """ "Unsupported DataType `<data_type>`." ] }, + "UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION" : { + "message" : [ + "<data_type> is not supported in conversion to Arrow." + ] + }, + "UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION" : { + "message" : [ + "<data_type> is only supported with pyarrow 2.0.0 and above." + ] + }, "UNSUPPORTED_LITERAL" : { "message" : [ "Unsupported Literal '<literal>'." 
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 67efdae2b87..70d50ca6e95 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -44,6 +44,7 @@ from pyspark.sql.types import ( NullType, DataType, ) +from pyspark.errors import PySparkTypeError if TYPE_CHECKING: import pyarrow as pa @@ -87,27 +88,43 @@ def to_arrow_type(dt: DataType) -> "pa.DataType": arrow_type = pa.duration("us") elif type(dt) == ArrayType: if type(dt.elementType) == TimestampType: - raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE", + message_parameters={"data_type": str(dt)}, + ) elif type(dt.elementType) == StructType: if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): - raise TypeError( - "Array of StructType is only supported with pyarrow 2.0.0 and above" + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION", + message_parameters={"data_type": "Array of StructType"}, ) if any(type(field.dataType) == StructType for field in dt.elementType): - raise TypeError("Nested StructType not supported in conversion to Arrow") + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": "Nested StructType"}, + ) arrow_type = pa.list_(to_arrow_type(dt.elementType)) elif type(dt) == MapType: if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): - raise TypeError("MapType is only supported with pyarrow 2.0.0 and above") + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION", + message_parameters={"data_type": "MapType"}, + ) if type(dt.keyType) in [StructType, TimestampType] or type(dt.valueType) in [ StructType, TimestampType, ]: - raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": 
str(dt)}, + ) arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType)) elif type(dt) == StructType: if any(type(field.dataType) == StructType for field in dt): - raise TypeError("Nested StructType not supported in conversion to Arrow") + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": "Nested StructType"}, + ) fields = [ pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable) for field in dt @@ -116,7 +133,10 @@ def to_arrow_type(dt: DataType) -> "pa.DataType": elif type(dt) == NullType: arrow_type = pa.null() else: - raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": str(dt)}, + ) return arrow_type @@ -168,17 +188,29 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da spark_type = DayTimeIntervalType() elif types.is_list(at): if types.is_timestamp(at.value_type): - raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": str(at)}, + ) spark_type = ArrayType(from_arrow_type(at.value_type)) elif types.is_map(at): if LooseVersion(pa.__version__) < LooseVersion("2.0.0"): - raise TypeError("MapType is only supported with pyarrow 2.0.0 and above") + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_VERSION", + message_parameters={"data_type": "MapType"}, + ) if types.is_timestamp(at.key_type) or types.is_timestamp(at.item_type): - raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": str(at)}, + ) spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type)) elif 
types.is_struct(at): if any(types.is_struct(field.type) for field in at): - raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at)) + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": "Nested StructType"}, + ) return StructType( [ StructField(field.name, from_arrow_type(field.type), nullable=field.nullable) @@ -190,7 +222,10 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da elif types.is_null(at): spark_type = NullType() else: - raise TypeError("Unsupported type in conversion from Arrow: " + str(at)) + raise PySparkTypeError( + error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION", + message_parameters={"data_type": str(at)}, + ) return spark_type diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index cf28d32c903..518e17d57b6 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -55,6 +55,7 @@ from pyspark.testing.sqlutils import ( pyarrow_requirement_message, ) from pyspark.testing.utils import QuietTest +from pyspark.errors import PySparkTypeError if have_pandas: import pandas as pd @@ -215,9 +216,15 @@ class ArrowTestsMixin: df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): with self.warnings_lock: - with self.assertRaisesRegex(Exception, "Unsupported type"): + with self.assertRaises(PySparkTypeError) as pe: df.toPandas() + self.check_error( + exception=pe.exception, + error_class="UNSUPPORTED_DATA_TYPE", + message_parameters={"data_type": "ArrayType(TimestampType(), True)"}, + ) + def test_toPandas_empty_df_arrow_enabled(self): for arrow_enabled in [True, False]: with self.subTest(arrow_enabled=arrow_enabled): @@ -748,12 +755,18 @@ class ArrowTestsMixin: def test_createDataFrame_fallback_disabled(self): with QuietTest(self.sc): - with self.assertRaisesRegex(TypeError, "Unsupported type"): + with 
self.assertRaises(PySparkTypeError) as pe: self.spark.createDataFrame( pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}), "a: array<timestamp>", ) + self.check_error( + exception=pe.exception, + error_class="UNSUPPORTED_DATA_TYPE", + message_parameters={"data_type": "ArrayType(TimestampType(), True)"}, + ) + # Regression test for SPARK-23314 def test_timestamp_dst(self): # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org