This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 04816474bfc [SPARK-43261][PYTHON] Migrate `TypeError` from Spark SQL types into error class
04816474bfc is described below
commit 04816474bfcc05c7d90f7b7e8d35184d95c78cbd
Author: itholic <[email protected]>
AuthorDate: Thu Apr 27 16:55:52 2023 +0800
[SPARK-43261][PYTHON] Migrate `TypeError` from Spark SQL types into error class
### What changes were proposed in this pull request?
This PR proposes to migrate `TypeError` raised from Spark SQL types into error classes.
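A minimal sketch of the pattern applied throughout `python/pyspark/sql/types.py` (the function name `_merge_type_sketch` is illustrative only; the error class and parameters are ones added in this PR, with the message template centralized in `python/pyspark/errors/error_classes.py`):

```python
from pyspark.errors import PySparkTypeError
from pyspark.sql.types import DataType


def _merge_type_sketch(a: DataType, b: DataType) -> DataType:
    # Simplified illustration of the change in _merge_type.
    if type(a) is not type(b):
        # Before: raise TypeError("Can not merge type %s and %s" % (type(a), type(b)))
        # After: raise a PySparkTypeError that carries a structured error class
        # and parameters instead of a hand-written message.
        raise PySparkTypeError(
            error_class="CANNOT_MERGE_TYPE",
            message_parameters={"data_type1": type(a).__name__, "data_type2": type(b).__name__},
        )
    return a
```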
### Why are the changes needed?
To improve PySpark error messages.
### Does this PR introduce _any_ user-facing change?
No API change, only error improvement.
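For example, callers can now inspect the error class and its parameters programmatically; a rough sketch (the exact rendering of the message may vary slightly):

```python
from pyspark.errors import PySparkTypeError
from pyspark.sql.types import ArrayType, DoubleType, LongType, _merge_type

try:
    _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
except PySparkTypeError as e:
    # Roughly: "[CANNOT_MERGE_TYPE] Can not merge type `LongType` and `DoubleType`."
    print(e.getErrorClass())         # CANNOT_MERGE_TYPE
    print(e.getMessageParameters())  # {'data_type1': 'LongType', 'data_type2': 'DoubleType'}
```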
### How was this patch tested?
The existing CI should pass.
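The updated tests assert on error classes via `PySparkErrorTestUtils.check_error` rather than on message substrings; a self-contained sketch mirroring the diff below (the test class name here is hypothetical):

```python
import unittest

from pyspark.errors import PySparkTypeError
from pyspark.sql.types import IntegerType, StructField, StructType, _make_type_verifier
from pyspark.testing.utils import PySparkErrorTestUtils


class TypeVerifierErrorTest(unittest.TestCase, PySparkErrorTestUtils):
    def test_cannot_accept_object_in_type(self):
        schema = StructType([StructField("a", StructType([StructField("b", IntegerType())]))])
        with self.assertRaises(PySparkTypeError) as pe:
            _make_type_verifier(schema)([["data"]])
        self.check_error(
            exception=pe.exception,
            error_class="CANNOT_ACCEPT_OBJECT_IN_TYPE",
            message_parameters={"data_type": "IntegerType()", "obj_name": "data", "obj_type": "str"},
        )
```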
Closes #40926 from itholic/error_sql_types.
Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/errors/error_classes.py | 35 +++++++++++++
python/pyspark/sql/tests/test_dataframe.py | 2 +-
python/pyspark/sql/tests/test_functions.py | 12 +++--
python/pyspark/sql/tests/test_types.py | 61 +++++++++++++++++++----
python/pyspark/sql/types.py | 80 +++++++++++++++++++++++-------
5 files changed, 158 insertions(+), 32 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 34efd471707..f35971c4a94 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -44,6 +44,11 @@ ERROR_CLASSES_JSON = """
"Not supported to call `<func_name>` before initialize <object>."
]
},
+ "CANNOT_ACCEPT_OBJECT_IN_TYPE": {
+ "message": [
+ "`<data_type>` can not accept object `<obj_name>` in type `<obj_type>`."
+ ]
+ },
"CANNOT_ACCESS_TO_DUNDER": {
"message": [
"Dunder(double underscore) attribute is for internal use only."
@@ -69,11 +74,31 @@ ERROR_CLASSES_JSON = """
"Cannot convert column into bool: please use '&' for 'and', '|' for
'or', '~' for 'not' when building DataFrame boolean expressions."
]
},
+ "CANNOT_CONVERT_TYPE": {
+ "message": [
+ "Cannot convert <from_type> into <to_type>."
+ ]
+ },
"CANNOT_INFER_ARRAY_TYPE": {
"message": [
"Can not infer Array Type from an list with None as the first element."
]
},
+ "CANNOT_INFER_SCHEMA_FOR_TYPE": {
+ "message": [
+ "Can not infer schema for type: `<data_type>`."
+ ]
+ },
+ "CANNOT_INFER_TYPE_FOR_FIELD": {
+ "message": [
+ "Unable to infer the type of the field `<field_name>`."
+ ]
+ },
+ "CANNOT_MERGE_TYPE": {
+ "message": [
+ "Can not merge type `<data_type1>` and `<data_type2>`."
+ ]
+ },
"CANNOT_OPEN_SOCKET": {
"message": [
"Can not open socket: <errors>."
@@ -155,6 +180,11 @@ ERROR_CLASSES_JSON = """
"Timeout timestamp (<timestamp>) cannot be earlier than the current
watermark (<watermark>)."
]
},
+ "INVALID_TYPENAME_CALL" : {
+ "message" : [
+ "StructField does not have typeName. Use typeName on its type explicitly
instead."
+ ]
+ },
"INVALID_UDF_EVAL_TYPE" : {
"message" : [
"Eval type for UDF must be <eval_type>."
@@ -335,6 +365,11 @@ ERROR_CLASSES_JSON = """
"Argument `<arg_name>` should be an int, got <arg_type>."
]
},
+ "NOT_INT_OR_SLICE_OR_STR" : {
+ "message" : [
+ "Argument `<arg_name>` should be an int, slice or str, got <arg_type>."
+ ]
+ },
"NOT_IN_BARRIER_STAGE" : {
"message" : [
"It is not in a barrier stage."
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 96b31dfee7b..27e12568b28 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1010,7 +1010,7 @@ class DataFrameTestsMixin:
# field types mismatch will cause exception at runtime.
self.assertRaisesRegex(
Exception,
- "FloatType\\(\\) can not accept",
+ "CANNOT_ACCEPT_OBJECT_IN_TYPE",
lambda: rdd.toDF("key: float, value: string").collect(),
)
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 38de87b0e72..9067de34633 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1136,11 +1136,17 @@ class FunctionsTestsMixin:
expected_spark_dtypes, self.spark.range(1).select(F.lit(arr).alias("b")).dtypes
)
arr = np.array([1, 2]).astype(np.uint)
- with self.assertRaisesRegex(
- TypeError, "The type of array scalar '%s' is not supported" %
arr.dtype
- ):
+ with self.assertRaises(PySparkTypeError) as pe:
self.spark.range(1).select(F.lit(arr).alias("b"))
+ self.check_error(
+ exception=pe.exception,
+ error_class="UNSUPPORTED_NUMPY_ARRAY_SCALAR",
+ message_parameters={
+ "dtype": "uint64",
+ },
+ )
+
def test_binary_math_function(self):
funcs, expected = zip(
*[(F.atan2, 0.13664), (F.hypot, 8.07527), (F.pow, 2.14359), (F.pmod, 1.1)]
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index cd1ae1f2964..49952c2c135 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -26,7 +26,7 @@ import unittest
from pyspark.sql import Row
from pyspark.sql import functions as F
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PySparkTypeError
from pyspark.sql.types import (
ByteType,
ShortType,
@@ -66,6 +66,7 @@ from pyspark.testing.sqlutils import (
PythonOnlyPoint,
MyObject,
)
+from pyspark.testing.utils import PySparkErrorTestUtils
class TypesTestsMixin:
@@ -906,8 +907,13 @@ class TypesTestsMixin:
self.assertEqual(
_merge_type(ArrayType(LongType()), ArrayType(LongType())), ArrayType(LongType())
)
- with self.assertRaisesRegex(TypeError, "element in array"):
+ with self.assertRaises(PySparkTypeError) as pe:
_merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_MERGE_TYPE",
+ message_parameters={"data_type1": "LongType", "data_type2":
"DoubleType"},
+ )
self.assertEqual(
_merge_type(MapType(StringType(), LongType()), MapType(StringType(), LongType())),
@@ -919,8 +925,13 @@ class TypesTestsMixin:
MapType(StringType(), LongType()),
)
- with self.assertRaisesRegex(TypeError, "value of map"):
+ with self.assertRaises(PySparkTypeError) as pe:
_merge_type(MapType(StringType(), LongType()), MapType(StringType(), DoubleType()))
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_MERGE_TYPE",
+ message_parameters={"data_type1": "LongType", "data_type2":
"DoubleType"},
+ )
self.assertEqual(
_merge_type(
@@ -929,11 +940,16 @@ class TypesTestsMixin:
),
StructType([StructField("f1", LongType()), StructField("f2",
StringType())]),
)
- with self.assertRaisesRegex(TypeError, "field f1"):
+ with self.assertRaises(PySparkTypeError) as pe:
_merge_type(
StructType([StructField("f1", LongType()), StructField("f2",
StringType())]),
StructType([StructField("f1", DoubleType()), StructField("f2",
StringType())]),
)
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_MERGE_TYPE",
+ message_parameters={"data_type1": "LongType", "data_type2":
"DoubleType"},
+ )
self.assertEqual(
_merge_type(
@@ -961,7 +977,7 @@ class TypesTestsMixin:
),
StructType([StructField("f1", ArrayType(LongType())),
StructField("f2", StringType())]),
)
- with self.assertRaisesRegex(TypeError, "element in array field f1"):
+ with self.assertRaises(PySparkTypeError) as pe:
_merge_type(
StructType(
[StructField("f1", ArrayType(LongType())),
StructField("f2", StringType())]
@@ -970,6 +986,11 @@ class TypesTestsMixin:
[StructField("f1", ArrayType(DoubleType())),
StructField("f2", StringType())]
),
)
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_MERGE_TYPE",
+ message_parameters={"data_type1": "LongType", "data_type2":
"DoubleType"},
+ )
self.assertEqual(
_merge_type(
@@ -993,7 +1014,7 @@ class TypesTestsMixin:
]
),
)
- with self.assertRaisesRegex(TypeError, "value of map field f1"):
+ with self.assertRaises(PySparkTypeError) as pe:
_merge_type(
StructType(
[
@@ -1008,6 +1029,11 @@ class TypesTestsMixin:
]
),
)
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_MERGE_TYPE",
+ message_parameters={"data_type1": "LongType", "data_type2":
"DoubleType"},
+ )
self.assertEqual(
_merge_type(
@@ -1110,10 +1136,16 @@ class TypesTestsMixin:
unsupported_types = all_types - set(supported_types)
# test unsupported types
for t in unsupported_types:
- with self.assertRaisesRegex(TypeError, "infer the type of the
field myarray"):
+ with self.assertRaises(PySparkTypeError) as pe:
a = array.array(t)
self.spark.createDataFrame([Row(myarray=a)]).collect()
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_INFER_TYPE_FOR_FIELD",
+ message_parameters={"field_name": "myarray"},
+ )
+
def test_repr(self):
instances = [
NullType(),
@@ -1304,7 +1336,7 @@ class DataTypeTests(unittest.TestCase):
self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
-class DataTypeVerificationTests(unittest.TestCase):
+class DataTypeVerificationTests(unittest.TestCase, PySparkErrorTestUtils):
def test_verify_type_exception_msg(self):
self.assertRaisesRegex(
ValueError,
@@ -1313,8 +1345,17 @@ class DataTypeVerificationTests(unittest.TestCase):
)
schema = StructType([StructField("a", StructType([StructField("b",
IntegerType())]))])
- self.assertRaisesRegex(
- TypeError, "field b in field a", lambda:
_make_type_verifier(schema)([["data"]])
+ with self.assertRaises(PySparkTypeError) as pe:
+ _make_type_verifier(schema)([["data"]])
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_ACCEPT_OBJECT_IN_TYPE",
+ message_parameters={
+ "data_type": "IntegerType()",
+ "obj_name": "data",
+ "obj_type": "str",
+ },
)
def test_verify_type_ok_nullable(self):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 721be76e8ba..5876d55e426 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -50,6 +50,7 @@ from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject,
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.utils import has_numpy, get_active_spark_context
+from pyspark.errors import PySparkTypeError
if has_numpy:
import numpy as np
@@ -718,8 +719,9 @@ class StructField(DataType):
return self.dataType.fromInternal(obj)
def typeName(self) -> str: # type: ignore[override]
- raise TypeError(
- "StructField does not have typeName. " "Use typeName on its type
explicitly instead."
+ raise PySparkTypeError(
+ error_class="INVALID_TYPENAME_CALL",
+ message_parameters={},
)
@@ -898,7 +900,10 @@ class StructType(DataType):
elif isinstance(key, slice):
return StructType(self.fields[key])
else:
- raise TypeError("StructType keys should be strings, integers or
slices")
+ raise PySparkTypeError(
+ error_class="NOT_INT_OR_SLICE_OR_STR",
+ message_parameters={"arg_name": "key", "arg_type":
type(key).__name__},
+ )
def simpleString(self) -> str:
return "struct<%s>" % (",".join(f.simpleString() for f in self))
@@ -1584,7 +1589,10 @@ def _infer_type(
if obj.typecode in _array_type_mappings:
return ArrayType(_array_type_mappings[obj.typecode](), False)
else:
- raise TypeError("not supported type: array(%s)" % obj.typecode)
+ raise PySparkTypeError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": f"array({obj.typecode})"},
+ )
else:
try:
return _infer_schema(
@@ -1593,7 +1601,10 @@ def _infer_type(
infer_array_from_first_element=infer_array_from_first_element,
)
except TypeError:
- raise TypeError("not supported type: %s" % type(obj))
+ raise PySparkTypeError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(obj).__name__},
+ )
def _infer_schema(
@@ -1624,7 +1635,10 @@ def _infer_schema(
items = sorted(row.__dict__.items())
else:
- raise TypeError("Can not infer schema for type: %s" % type(row))
+ raise PySparkTypeError(
+ error_class="CANNOT_INFER_SCHEMA_FOR_TYPE",
+ message_parameters={"data_type": type(row).__name__},
+ )
fields = []
for k, v in items:
@@ -1641,8 +1655,11 @@ def _infer_schema(
True,
)
)
- except TypeError as e:
- raise TypeError("Unable to infer the type of the field
{}.".format(k)) from e
+ except TypeError:
+ raise PySparkTypeError(
+ error_class="CANNOT_INFER_TYPE_FOR_FIELD",
+ message_parameters={"field_name": k},
+ )
return StructType(fields)
@@ -1713,7 +1730,10 @@ def _merge_type(
return a
elif type(a) is not type(b):
# TODO: type cast (such as int -> long)
- raise TypeError(new_msg("Can not merge type %s and %s" % (type(a),
type(b))))
+ raise PySparkTypeError(
+ error_class="CANNOT_MERGE_TYPE",
+ message_parameters={"data_type1": type(a).__name__, "data_type2":
type(b).__name__},
+ )
# same type
if isinstance(a, StructType):
@@ -1801,7 +1821,10 @@ def _create_converter(dataType: DataType) -> Callable:
elif hasattr(obj, "__dict__"): # object
d = obj.__dict__
else:
- raise TypeError("Unexpected obj type: %s" % type(obj))
+ raise PySparkTypeError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(obj).__name__},
+ )
if convert_fields:
return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
@@ -1860,7 +1883,7 @@ def _make_type_verifier(
>>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- TypeError:...
+ pyspark.errors.exceptions.base.PySparkTypeError:...
>>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
>>> _make_type_verifier(StructType([]))(())
>>> _make_type_verifier(StructType([]))([])
@@ -1883,7 +1906,9 @@ def _make_type_verifier(
Traceback (most recent call last):
...
ValueError:...
- >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
+ >>> _make_type_verifier( # doctest: +IGNORE_EXCEPTION_DETAIL
+ ... MapType(StringType(), IntegerType())
+ ... )({None: 1})
Traceback (most recent call last):
...
ValueError:...
@@ -1929,8 +1954,13 @@ def _make_type_verifier(
def verify_acceptable_types(obj: Any) -> None:
# subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
- raise TypeError(
- new_msg("%s can not accept object %r in type %s" % (dataType,
obj, type(obj)))
+ raise PySparkTypeError(
+ error_class="CANNOT_ACCEPT_OBJECT_IN_TYPE",
+ message_parameters={
+ "data_type": str(dataType),
+ "obj_name": str(obj),
+ "obj_type": type(obj).__name__,
+ },
)
if isinstance(dataType, (StringType, CharType, VarcharType)):
@@ -2043,8 +2073,13 @@ def _make_type_verifier(
for f, verifier in verifiers:
verifier(d.get(f))
else:
- raise TypeError(
- new_msg("StructType can not accept object %r in type %s" %
(obj, type(obj)))
+ raise PySparkTypeError(
+ error_class="CANNOT_ACCEPT_OBJECT_IN_TYPE",
+ message_parameters={
+ "data_type": "StructType",
+ "obj_name": str(obj),
+ "obj_type": type(obj).__name__,
+ },
)
verify_value = verify_struct
@@ -2183,7 +2218,13 @@ class Row(tuple):
True
"""
if not hasattr(self, "__fields__"):
- raise TypeError("Cannot convert a Row class into dict")
+ raise PySparkTypeError(
+ error_class="CANNOT_CONVERT_TYPE",
+ message_parameters={
+ "from_type": "Row",
+ "to_type": "dict",
+ },
+ )
if recursive:
@@ -2368,7 +2409,10 @@ class NumpyArrayConverter:
else:
jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
if jtpe is None:
- raise TypeError("The type of array scalar '%s' is not
supported" % (obj.dtype))
+ raise PySparkTypeError(
+ error_class="UNSUPPORTED_NUMPY_ARRAY_SCALAR",
+ message_parameters={"dtype": str(obj.dtype)},
+ )
jarr = gateway.new_array(jtpe, len(obj))
for i in range(len(plist)):
jarr[i] = plist[i]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]