This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 062ac987063 [SPARK-43296][CONNECT][PYTHON] Migrate Spark Connect session errors into error class
062ac987063 is described below
commit 062ac987063f8814c8d92925ddc6d2c72df2d208
Author: itholic <[email protected]>
AuthorDate: Tue May 9 10:48:34 2023 +0800
[SPARK-43296][CONNECT][PYTHON] Migrate Spark Connect session errors into error class
### What changes were proposed in this pull request?
This PR proposes to migrate Spark Connect session errors into error classes.
### Why are the changes needed?
To improve PySpark error usability.
### Does this PR introduce _any_ user-facing change?
No API changes.
### How was this patch tested?
The existing CI should pass.
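As a usage illustration (not part of this patch), errors migrated to error classes can be matched on a stable identifier instead of a message regex. A minimal sketch, assuming an active SparkSession bound to `spark` and the `getErrorClass`/`getMessageParameters` accessors on `PySparkException`:

```python
# Minimal sketch (not part of this patch): catching one of the migrated
# errors and inspecting its error class and message parameters.
from pyspark.errors import PySparkValueError

try:
    spark.createDataFrame(data=[])  # empty data: schema cannot be inferred
except PySparkValueError as e:
    assert e.getErrorClass() == "CANNOT_INFER_EMPTY_SCHEMA"
    assert e.getMessageParameters() == {}
```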
Closes #40964 from itholic/error_connect_session.
Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/errors/error_classes.py | 30 +++++++++
python/pyspark/sql/connect/session.py | 77 +++++++++++++++-------
python/pyspark/sql/pandas/conversion.py | 6 +-
.../sql/tests/connect/test_connect_basic.py | 44 ++++++++-----
python/pyspark/sql/tests/test_arrow.py | 8 ++-
5 files changed, 123 insertions(+), 42 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 6af8d5bc6ff..c7b00e0736d 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -39,6 +39,11 @@ ERROR_CLASSES_JSON = """
"Attribute `<attr_name>` is not supported."
]
},
+ "AXIS_LENGTH_MISMATCH" : {
+ "message" : [
+ "Length mismatch: Expected axis has <expected_length> element, new
values have <actual_length> elements."
+ ]
+ },
"BROADCAST_VARIABLE_NOT_LOADED": {
"message": [
"Broadcast variable `<variable>` not loaded."
@@ -94,6 +99,11 @@ ERROR_CLASSES_JSON = """
"Can not infer Array Type from an list with None as the first element."
]
},
+ "CANNOT_INFER_EMPTY_SCHEMA": {
+ "message": [
+ "Can not infer schema from empty dataset."
+ ]
+ },
"CANNOT_INFER_SCHEMA_FOR_TYPE": {
"message": [
"Can not infer schema for type: `<data_type>`."
@@ -195,6 +205,11 @@ ERROR_CLASSES_JSON = """
"All items in `<arg_name>` should be in <allowed_types>, got
<item_type>."
]
},
+ "INVALID_NDARRAY_DIMENSION": {
+ "message": [
+ "NumPy array input should be of <dimensions> dimensions."
+ ]
+ },
"INVALID_PANDAS_UDF" : {
"message" : [
"Invalid function: <detail>"
@@ -215,6 +230,11 @@ ERROR_CLASSES_JSON = """
"Timeout timestamp (<timestamp>) cannot be earlier than the current
watermark (<watermark>)."
]
},
+ "INVALID_TYPE" : {
+ "message" : [
+ "Argument `<arg_name>` should not be a <data_type>."
+ ]
+ },
"INVALID_TYPENAME_CALL" : {
"message" : [
"StructField does not have typeName. Use typeName on its type explicitly
instead."
@@ -556,6 +576,11 @@ ERROR_CLASSES_JSON = """
"Result vector from pandas_udf was not the required length: expected
<expected>, got <actual>."
]
},
+ "SESSION_OR_CONTEXT_EXISTS" : {
+ "message" : [
+ "There should not be an existing Spark Session or Spark Context."
+ ]
+ },
"SLICE_WITH_STEP" : {
"message" : [
"Slice with step is not supported."
@@ -611,6 +636,11 @@ ERROR_CLASSES_JSON = """
"Unsupported DataType `<data_type>`."
]
},
+ "UNSUPPORTED_DATA_TYPE_FOR_ARROW" : {
+ "message" : [
+ "Single data type <data_type> is not supported with Arrow."
+ ]
+ },
"UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION" : {
"message" : [
"<data_type> is not supported in conversion to Arrow."
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 4f8fa419119..c23b6c5d11a 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,7 +68,13 @@ from pyspark.sql.types import (
TimestampType,
)
from pyspark.sql.utils import to_str
-from pyspark.errors import PySparkAttributeError, PySparkNotImplementedError
+from pyspark.errors import (
+ PySparkAttributeError,
+ PySparkNotImplementedError,
+ PySparkRuntimeError,
+ PySparkValueError,
+ PySparkTypeError,
+)
if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
@@ -153,8 +159,7 @@ class SparkSession:
def enableHiveSupport(self) -> "SparkSession.Builder":
raise PySparkNotImplementedError(
- error_class="NOT_IMPLEMENTED",
- message_parameters={"feature": "enableHiveSupport"},
+ error_class="NOT_IMPLEMENTED", message_parameters={"feature":
"enableHiveSupport"}
)
def getOrCreate(self) -> "SparkSession":
@@ -233,7 +238,10 @@ class SparkSession:
Infer schema from list of Row, dict, or tuple.
"""
if not data:
- raise ValueError("can not infer schema from empty dataset")
+ raise PySparkValueError(
+ error_class="CANNOT_INFER_EMPTY_SCHEMA",
+ message_parameters={},
+ )
(
infer_dict_as_struct,
@@ -265,7 +273,10 @@ class SparkSession:
) -> "DataFrame":
assert data is not None
if isinstance(data, DataFrame):
- raise TypeError("data is already a DataFrame")
+ raise PySparkTypeError(
+ error_class="INVALID_TYPE",
+ message_parameters={"arg_name": "data", "data_type":
"DataFrame"},
+ )
_schema: Optional[Union[AtomicType, StructType]] = None
_cols: Optional[List[str]] = None
@@ -289,12 +300,18 @@ class SparkSession:
_num_cols = len(_cols)
if isinstance(data, np.ndarray) and data.ndim not in [1, 2]:
- raise ValueError("NumPy array input should be of 1 or 2
dimensions.")
+ raise PySparkValueError(
+ error_class="INVALID_NDARRAY_DIMENSION",
+ message_parameters={"dimensions": "1 or 2"},
+ )
elif isinstance(data, Sized) and len(data) == 0:
if _schema is not None:
return DataFrame.withPlan(LocalRelation(table=None, schema=_schema.json()), self)
else:
- raise ValueError("can not infer schema from empty dataset")
+ raise PySparkValueError(
+ error_class="CANNOT_INFER_EMPTY_SCHEMA",
+ message_parameters={},
+ )
_table: Optional[pa.Table] = None
@@ -317,7 +334,10 @@ class SparkSession:
arrow_types = [field.type for field in arrow_schema]
_cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
elif isinstance(schema, DataType):
- raise ValueError("Single data type %s is not supported with
Arrow" % str(schema))
+ raise PySparkTypeError(
+ error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+ message_parameters={"data_type": str(schema)},
+ )
else:
# Any timestamps must be coerced to be compatible with Spark
arrow_types = [
@@ -354,17 +374,23 @@ class SparkSession:
if data.ndim == 1:
if 1 != len(_cols):
- raise ValueError(
- f"Length mismatch: Expected axis has {len(_cols)}
element, "
- "new values have 1 elements"
+ raise PySparkValueError(
+ error_class="AXIS_LENGTH_MISMATCH",
+ message_parameters={
+ "expected_length": str(len(_cols)),
+ "actual_length": "1",
+ },
)
_table = pa.Table.from_arrays([pa.array(data)], _cols)
else:
if data.shape[1] != len(_cols):
- raise ValueError(
- f"Length mismatch: Expected axis has {len(_cols)}
elements, "
- f"new values have {data.shape[1]} elements"
+ raise PySparkValueError(
+ error_class="AXIS_LENGTH_MISMATCH",
+ message_parameters={
+ "expected_length": str(len(_cols)),
+ "actual_length": str(data.shape[1]),
+ },
)
_table = pa.Table.from_arrays(
@@ -416,9 +442,12 @@ class SparkSession:
# TODO: Besides the validation on the number of columns, we should also check
# whether the Arrow Schema is compatible with the user provided Schema.
if _num_cols is not None and _num_cols != _table.shape[1]:
- raise ValueError(
- f"Length mismatch: Expected axis has {_num_cols} elements, "
- f"new values have {_table.shape[1]} elements"
+ raise PySparkValueError(
+ error_class="AXIS_LENGTH_MISMATCH",
+ message_parameters={
+ "expected_length": str(_num_cols),
+ "actual_length": str(_table.shape[1]),
+ },
)
if _schema is not None:
@@ -517,14 +546,12 @@ class SparkSession:
@classmethod
def getActiveSession(cls) -> Any:
raise PySparkNotImplementedError(
- error_class="NOT_IMPLEMENTED",
- message_parameters={"feature": "getActiveSession()"},
+ error_class="NOT_IMPLEMENTED", message_parameters={"feature":
"getActiveSession()"}
)
def newSession(self) -> Any:
raise PySparkNotImplementedError(
- error_class="NOT_IMPLEMENTED",
- message_parameters={"feature": "newSession()"},
+ error_class="NOT_IMPLEMENTED", message_parameters={"feature":
"newSession()"}
)
@property
@@ -534,8 +561,7 @@ class SparkSession:
@property
def sparkContext(self) -> Any:
raise PySparkNotImplementedError(
- error_class="NOT_IMPLEMENTED",
- message_parameters={"feature": "sparkContext()"},
+ error_class="NOT_IMPLEMENTED", message_parameters={"feature":
"sparkContext()"}
)
@property
@@ -705,7 +731,10 @@ class SparkSession:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
else:
- raise RuntimeError("There should not be an existing Spark Session
or Spark Context.")
+ raise PySparkRuntimeError(
+ error_class="SESSION_OR_CONTEXT_EXISTS",
+ message_parameters={},
+ )
@property
def session_id(self) -> str:
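Taken together, the session.py changes replace bare `ValueError`/`TypeError`/`RuntimeError` with typed exceptions carrying stable error classes. A hedged sketch of the new behavior (not part of this patch; assumes a Spark Connect session bound to `spark`):

```python
# Sketch of the new typed exceptions raised by the Connect session.
import numpy as np
from pyspark.errors import PySparkTypeError, PySparkValueError

try:
    spark.createDataFrame(np.zeros((2, 2, 2)))  # 3-D ndarray input is rejected
except PySparkValueError as e:
    assert e.getErrorClass() == "INVALID_NDARRAY_DIMENSION"

df = spark.range(1)
try:
    spark.createDataFrame(df)  # a DataFrame is not valid input data
except PySparkTypeError as e:
    assert e.getErrorClass() == "INVALID_TYPE"
```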
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 0c29dcceed0..a4503661cad 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -31,6 +31,7 @@ from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import TimestampType, StructType, DataType
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.errors import PySparkTypeError
if TYPE_CHECKING:
import numpy as np
@@ -488,7 +489,10 @@ class SparkConversionMixin:
if isinstance(schema, StructType):
arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
elif isinstance(schema, DataType):
- raise ValueError("Single data type %s is not supported with Arrow"
% str(schema))
+ raise PySparkTypeError(
+ error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+ message_parameters={"data_type": str(schema)},
+ )
else:
# Any timestamps must be coerced to be compatible with Spark
arrow_types = [
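The same error class now covers the classic (non-Connect) path. A sketch of the resulting behavior, mirroring the test_arrow.py change below (assumes Arrow is enabled via `spark.sql.execution.arrow.pyspark.enabled=true` and a session bound to `spark`):

```python
# Sketch (not part of this patch): a bare DataType schema with Arrow
# enabled now raises a typed error instead of a plain ValueError.
import pandas as pd
from pyspark.errors import PySparkTypeError
from pyspark.sql.types import IntegerType

try:
    spark.createDataFrame(pd.DataFrame({"a": [1]}), schema=IntegerType())
except PySparkTypeError as e:
    assert e.getErrorClass() == "UNSUPPORTED_DATA_TYPE_FOR_ARROW"
    assert e.getMessageParameters() == {"data_type": "IntegerType()"}
```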
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 45dbe182f12..b0bc2cba78e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -552,21 +552,27 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
- with self.assertRaisesRegex(
- ValueError,
- "Length mismatch: Expected axis has 5 elements, new values have 4
elements",
- ):
+ with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
+ self.check_error(
+ exception=pe.exception,
+ error_class="AXIS_LENGTH_MISMATCH",
+ message_parameters={"expected_length": "5", "actual_length": "4"},
+ )
+
with self.assertRaises(ParseException):
self.connect.createDataFrame(data, "col1 magic_type, col2 int,
col3 int, col4 int")
- with self.assertRaisesRegex(
- ValueError,
- "Length mismatch: Expected axis has 3 elements, new values have 4
elements",
- ):
+ with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
+ self.check_error(
+ exception=pe.exception,
+ error_class="AXIS_LENGTH_MISMATCH",
+ message_parameters={"expected_length": "3", "actual_length": "4"},
+ )
+
# test 1 dim ndarray
data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
self.assertEqual(data.ndim, 1)
@@ -599,12 +605,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assertEqual(sdf.schema, cdf.schema)
self.assert_eq(sdf.toPandas(), cdf.toPandas())
- with self.assertRaisesRegex(
- ValueError,
- "Length mismatch: Expected axis has 5 elements, new values have 4
elements",
- ):
+ with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
+ self.check_error(
+ exception=pe.exception,
+ error_class="AXIS_LENGTH_MISMATCH",
+ message_parameters={"expected_length": "5", "actual_length": "4"},
+ )
+
with self.assertRaises(ParseException):
self.connect.createDataFrame(data, "col1 magic_type, col2 int,
col3 int, col4 int")
@@ -765,12 +774,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assert_eq(cdf.toPandas(), sdf.toPandas())
# check error
- with self.assertRaisesRegex(
- ValueError,
- "can not infer schema from empty dataset",
- ):
+ with self.assertRaises(PySparkValueError) as pe:
self.connect.createDataFrame(data=[])
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_INFER_EMPTY_SCHEMA",
+ message_parameters={},
+ )
+
def test_create_dataframe_from_arrays(self):
# SPARK-42021: createDataFrame support array.array
data1 = [Row(a=1, b=array.array("i", [1, 2, 3]), c=array.array("d", [4, 5, 6]))]
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 52e13782199..91fc6969185 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -542,9 +542,15 @@ class ArrowTestsMixin:
def check_createDataFrame_with_single_data_type(self):
for schema in ["int", IntegerType()]:
with self.subTest(schema=schema):
- with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
+ with self.assertRaises(PySparkTypeError) as pe:
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema=schema).collect()
+ self.check_error(
+ exception=pe.exception,
+ error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+ message_parameters={"data_type": "IntegerType()"},
+ )
+
def test_createDataFrame_does_not_modify_input(self):
# Some series get converted for Spark to consume; this makes sure input is unchanged
pdf = self.create_pandas_data_frame()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]