This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3eca6e4d354 [SPARK-43306][PYTHON] Migrate `ValueError` from Spark SQL types into error class
3eca6e4d354 is described below

commit 3eca6e4d3547010f521af029c265e64ae79c3a82
Author: itholic <[email protected]>
AuthorDate: Fri May 5 10:38:39 2023 +0800

    [SPARK-43306][PYTHON] Migrate `ValueError` from Spark SQL types into error class
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to migrate all `ValueError` from Spark SQL types into error class.
    
    ### Why are the changes needed?
    
    To improve PySpark errors by migrating them to error classes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No API changes, only error improvement.
    
    ### How was this patch tested?
    
    This existing CI should pass.
    
    Closes #40975 from itholic/verror_sql_types.
    
    Authored-by: itholic <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/errors/error_classes.py             |  24 +++-
 python/pyspark/sql/connect/expressions.py          |   6 +-
 .../sql/tests/connect/test_connect_column.py       |   4 +-
 .../sql/tests/connect/test_connect_function.py     |   2 +-
 python/pyspark/sql/tests/test_dataframe.py         |   2 +-
 python/pyspark/sql/tests/test_types.py             |  15 ++-
 python/pyspark/sql/types.py                        | 127 ++++++++++++++++-----
 7 files changed, 137 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index c6f335bf8ac..27f047d0e36 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -111,7 +111,7 @@ ERROR_CLASSES_JSON = """
   },
   "CANNOT_PARSE_DATATYPE": {
     "message": [
-      "Unable to parse datatype from schema. <error>."
+      "Unable to parse datatype. <msg>."
     ]
   },
   "CANNOT_PROVIDE_METADATA": {
@@ -395,6 +395,11 @@ ERROR_CLASSES_JSON = """
       "<feature> is not implemented."
     ]
   },
+  "NOT_INSTANCE_OF" : {
+    "message" : [
+      "<value> is not an instance of type <data_type>."
+    ]
+  },
   "NOT_INT" : {
     "message" : [
       "Argument `<arg_name>` should be an int, got <arg_type>."
@@ -566,11 +571,21 @@ ERROR_CLASSES_JSON = """
       "pandas iterator UDF should exhaust the input iterator."
     ]
   },
+  "TOO_MANY_VALUES" : {
+    "message" : [
+      "Expected <expected> values for `<item>`, got <actual>."
+    ]
+  },
   "UNEXPECTED_RESPONSE_FROM_SERVER" : {
     "message" : [
       "Unexpected response from iterator server."
     ]
   },
+  "UNEXPECTED_TUPLE_WITH_STRUCT" : {
+    "message" : [
+      "Unexpected tuple <tuple> with StructType."
+    ]
+  },
   "UNKNOWN_EXPLAIN_MODE" : {
     "message" : [
       "Unknown explain mode: '<explain_mode>'. Accepted explain modes are 
'simple', 'extended', 'codegen', 'cost', 'formatted'."
@@ -646,6 +661,11 @@ ERROR_CLASSES_JSON = """
       "Value for `<arg_name>` must be 'any' or 'all', got '<arg_value>'."
     ]
   },
+  "VALUE_NOT_BETWEEN" : {
+    "message" : [
+      "Value for `<arg_name>` must be between <min> and <max>."
+    ]
+  },
   "VALUE_NOT_NON_EMPTY_STR" : {
     "message" : [
       "Value for `<arg_name>` must be a non empty string, got '<arg_value>'."
@@ -668,7 +688,7 @@ ERROR_CLASSES_JSON = """
   },
   "VALUE_OUT_OF_BOUND" : {
     "message" : [
-      "Value for `<arg_name>` must be between <min> and <max>."
+      "Value for `<arg_name>` must be greater than <lower_bound> or less than 
<upper_bound>, got <actual>"
     ]
   },
   "WRONG_NUM_ARGS_FOR_HIGHER_ORDER_FUNCTION" : {
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index 4fc0147d29b..e1b648c7bb8 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -282,7 +282,7 @@ class LiteralExpression(Expression):
                 return LongType()
             else:
                 raise PySparkValueError(
-                    error_class="VALUE_OUT_OF_BOUND",
+                    error_class="VALUE_NOT_BETWEEN",
                     message_parameters={
                         "arg_name": "value",
                         "min": str(JVM_LONG_MIN),
@@ -968,7 +968,7 @@ class WindowExpression(Expression):
                     expr.window.frame_spec.lower.value.literal.integer = start
                 else:
                     raise PySparkValueError(
-                        error_class="VALUE_OUT_OF_BOUND",
+                        error_class="VALUE_NOT_BETWEEN",
                         message_parameters={
                             "arg_name": "start",
                             "min": str(JVM_INT_MIN),
@@ -985,7 +985,7 @@ class WindowExpression(Expression):
                     expr.window.frame_spec.upper.value.literal.integer = end
                 else:
                     raise PySparkValueError(
-                        error_class="VALUE_OUT_OF_BOUND",
+                        error_class="VALUE_NOT_BETWEEN",
                         message_parameters={
                             "arg_name": "end",
                             "min": str(JVM_INT_MIN),
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py 
b/python/pyspark/sql/tests/connect/test_connect_column.py
index a62f4dcfebf..d838260a26f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -519,7 +519,7 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
 
         self.check_error(
             exception=pe.exception,
-            error_class="VALUE_OUT_OF_BOUND",
+            error_class="VALUE_NOT_BETWEEN",
             message_parameters={"arg_name": "value", "min": 
"-9223372036854775808", "max": "32767"},
         )
 
@@ -528,7 +528,7 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
 
         self.check_error(
             exception=pe.exception,
-            error_class="VALUE_OUT_OF_BOUND",
+            error_class="VALUE_NOT_BETWEEN",
             message_parameters={"arg_name": "value", "min": 
"-9223372036854775808", "max": "32767"},
         )
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py 
b/python/pyspark/sql/tests/connect/test_connect_function.py
index 38a7ed4df62..e274635d3c6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -867,7 +867,7 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, S
 
         self.check_error(
             exception=pe.exception,
-            error_class="VALUE_OUT_OF_BOUND",
+            error_class="VALUE_NOT_BETWEEN",
             message_parameters={"arg_name": "end", "min": "-2147483648", 
"max": "2147483647"},
         )
 
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index df17e13e7f0..a9921d3063e 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1004,7 +1004,7 @@ class DataFrameTestsMixin:
 
         # number of fields must match.
         self.assertRaisesRegex(
-            Exception, "Length of object", lambda: rdd.toDF("key: 
int").collect()
+            Exception, "LENGTH_SHOULD_BE_THE_SAME", lambda: rdd.toDF("key: 
int").collect()
         )
 
         # field types mismatch will cause exception at runtime.
diff --git a/python/pyspark/sql/tests/test_types.py 
b/python/pyspark/sql/tests/test_types.py
index 49952c2c135..083aa151d0d 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -26,7 +26,7 @@ import unittest
 
 from pyspark.sql import Row
 from pyspark.sql import functions as F
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PySparkTypeError, 
PySparkValueError
 from pyspark.sql.types import (
     ByteType,
     ShortType,
@@ -1338,10 +1338,15 @@ class DataTypeTests(unittest.TestCase):
 
 class DataTypeVerificationTests(unittest.TestCase, PySparkErrorTestUtils):
     def test_verify_type_exception_msg(self):
-        self.assertRaisesRegex(
-            ValueError,
-            "test_name",
-            lambda: _make_type_verifier(StringType(), nullable=False, 
name="test_name")(None),
+        with self.assertRaises(PySparkValueError) as pe:
+            _make_type_verifier(StringType(), nullable=False, 
name="test_name")(None)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_BE_NONE",
+            message_parameters={
+                "arg_name": "obj",
+            },
         )
 
         schema = StructType([StructField("a", StructType([StructField("b", 
IntegerType())]))])
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 70d90a03c10..3600bc49ea7 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -50,7 +50,7 @@ from py4j.java_gateway import GatewayClient, JavaClass, 
JavaGateway, JavaObject,
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.utils import has_numpy, get_active_spark_context
-from pyspark.errors import PySparkNotImplementedError, PySparkTypeError
+from pyspark.errors import PySparkNotImplementedError, PySparkTypeError, 
PySparkValueError
 
 if has_numpy:
     import numpy as np
@@ -864,7 +864,13 @@ class StructType(DataType):
             self.names.append(field.name)
         else:
             if isinstance(field, str) and data_type is None:
-                raise ValueError("Must specify DataType if passing name of 
struct_field to create.")
+                raise PySparkValueError(
+                    error_class="ARGUMENT_REQUIRED",
+                    message_parameters={
+                        "arg_name": "data_type",
+                        "condition": "passing name of struct_field to create",
+                    },
+                )
 
             if isinstance(data_type, str):
                 data_type_f = _parse_datatype_json_value(data_type)
@@ -1049,7 +1055,10 @@ class StructType(DataType):
                     for n, f, c in zip(self.names, self.fields, 
self._needConversion)
                 )
             else:
-                raise ValueError("Unexpected tuple %r with StructType" % obj)
+                raise PySparkValueError(
+                    error_class="UNEXPECTED_TUPLE_WITH_STRUCT",
+                    message_parameters={"tuple": str(obj)},
+                )
         else:
             if isinstance(obj, dict):
                 return tuple(obj.get(n) for n in self.names)
@@ -1059,7 +1068,10 @@ class StructType(DataType):
                 d = obj.__dict__
                 return tuple(d.get(n) for n in self.names)
             else:
-                raise ValueError("Unexpected tuple %r with StructType" % obj)
+                raise PySparkValueError(
+                    error_class="UNEXPECTED_TUPLE_WITH_STRUCT",
+                    message_parameters={"tuple": str(obj)},
+                )
 
     def fromInternal(self, obj: Tuple) -> "Row":
         if obj is None:
@@ -1395,7 +1407,10 @@ def _parse_datatype_json_value(json_value: Union[dict, 
str]) -> DataType:
             m = _LENGTH_VARCHAR.match(json_value)
             return VarcharType(int(m.group(1)))  # type: ignore[union-attr]
         else:
-            raise ValueError("Could not parse datatype: %s" % json_value)
+            raise PySparkValueError(
+                error_class="CANNOT_PARSE_DATATYPE",
+                message_parameters={"msg": str(json_value)},
+            )
     else:
         tpe = json_value["type"]
         if tpe in _all_complex_types:
@@ -1403,7 +1418,10 @@ def _parse_datatype_json_value(json_value: Union[dict, 
str]) -> DataType:
         elif tpe == "udt":
             return UserDefinedType.fromJson(json_value)
         else:
-            raise ValueError("not supported type: %s" % tpe)
+            raise PySparkValueError(
+                error_class="UNSUPPORTED_DATA_TYPE",
+                message_parameters={"data_type": str(tpe)},
+            )
 
 
 # Mapping Python types to Spark SQL DataType
@@ -1904,7 +1922,7 @@ def _make_type_verifier(
     >>> _make_type_verifier(LongType())(1 << 64) # doctest: 
+IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    pyspark.errors.exceptions.base.PySparkValueError:...
     >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
     >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: 
+IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
@@ -1916,33 +1934,33 @@ def _make_type_verifier(
     >>> _make_type_verifier(StructType([]))([1]) # doctest: 
+IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    pyspark.errors.exceptions.base.PySparkValueError:...
     >>> # Check if numeric values are within the allowed range.
     >>> _make_type_verifier(ByteType())(12)
     >>> _make_type_verifier(ByteType())(1234) # doctest: 
+IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    pyspark.errors.exceptions.base.PySparkValueError:...
     >>> _make_type_verifier(ByteType(), False)(None) # doctest: 
+IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    pyspark.errors.exceptions.base.PySparkValueError:...
     >>> _make_type_verifier(
     ...     ArrayType(ShortType(), False))([1, None]) # doctest: 
+IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    pyspark.errors.exceptions.base.PySparkValueError:...
     >>> _make_type_verifier(  # doctest: +IGNORE_EXCEPTION_DETAIL
     ...     MapType(StringType(), IntegerType())
     ...     )({None: 1})
     Traceback (most recent call last):
         ...
-    ValueError:...
+    pyspark.errors.exceptions.base.PySparkValueError:...
     >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), 
False)
     >>> _make_type_verifier(schema)((1, None)) # doctest: 
+IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    pyspark.errors.exceptions.base.PySparkValueError:...
     """
 
     if name is None:
@@ -1966,7 +1984,10 @@ def _make_type_verifier(
             if nullable:
                 return True
             else:
-                raise ValueError(new_msg("This field is not nullable, but got 
None"))
+                raise PySparkValueError(
+                    error_class="CANNOT_BE_NONE",
+                    message_parameters={"arg_name": "obj"},
+                )
         else:
             return False
 
@@ -1999,7 +2020,13 @@ def _make_type_verifier(
 
         def verify_udf(obj: Any) -> None:
             if not (hasattr(obj, "__UDT__") and obj.__UDT__ == dataType):
-                raise ValueError(new_msg("%r is not an instance of type %r" % 
(obj, dataType)))
+                raise PySparkValueError(
+                    error_class="NOT_INSTANCE_OF",
+                    message_parameters={
+                        "value": str(obj),
+                        "data_type": str(dataType),
+                    },
+                )
             verifier(dataType.toInternal(obj))
 
         verify_value = verify_udf
@@ -2010,7 +2037,15 @@ def _make_type_verifier(
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -128 or obj > 127:
-                raise ValueError(new_msg("object of ByteType out of range, 
got: %s" % obj))
+                raise PySparkValueError(
+                    error_class="VALUE_OUT_OF_BOUND",
+                    message_parameters={
+                        "arg_name": "obj",
+                        "lower_bound": "-128",
+                        "upper_bound": "127",
+                        "actual": str(obj),
+                    },
+                )
 
         verify_value = verify_byte
 
@@ -2020,7 +2055,15 @@ def _make_type_verifier(
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -32768 or obj > 32767:
-                raise ValueError(new_msg("object of ShortType out of range, 
got: %s" % obj))
+                raise PySparkValueError(
+                    error_class="VALUE_OUT_OF_BOUND",
+                    message_parameters={
+                        "arg_name": "obj",
+                        "lower_bound": "-32768",
+                        "upper_bound": "32767",
+                        "actual": str(obj),
+                    },
+                )
 
         verify_value = verify_short
 
@@ -2030,7 +2073,15 @@ def _make_type_verifier(
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -2147483648 or obj > 2147483647:
-                raise ValueError(new_msg("object of IntegerType out of range, 
got: %s" % obj))
+                raise PySparkValueError(
+                    error_class="VALUE_OUT_OF_BOUND",
+                    message_parameters={
+                        "arg_name": "obj",
+                        "lower_bound": "-2147483648",
+                        "upper_bound": "2147483647",
+                        "actual": str(obj),
+                    },
+                )
 
         verify_value = verify_integer
 
@@ -2040,7 +2091,15 @@ def _make_type_verifier(
             assert_acceptable_types(obj)
             verify_acceptable_types(obj)
             if obj < -9223372036854775808 or obj > 9223372036854775807:
-                raise ValueError(new_msg("object of LongType out of range, 
got: %s" % obj))
+                raise PySparkValueError(
+                    error_class="VALUE_OUT_OF_BOUND",
+                    message_parameters={
+                        "arg_name": "obj",
+                        "lower_bound": "-9223372036854775808",
+                        "upper_bound": "9223372036854775807",
+                        "actual": str(obj),
+                    },
+                )
 
         verify_value = verify_long
 
@@ -2086,11 +2145,14 @@ def _make_type_verifier(
                     verifier(obj.get(f))
             elif isinstance(obj, (tuple, list)):
                 if len(obj) != len(verifiers):
-                    raise ValueError(
-                        new_msg(
-                            "Length of object (%d) does not match with "
-                            "length of fields (%d)" % (len(obj), 
len(verifiers))
-                        )
+                    raise PySparkValueError(
+                        error_class="LENGTH_SHOULD_BE_THE_SAME",
+                        message_parameters={
+                            "arg1": "obj",
+                            "arg2": "fields",
+                            "arg1_length": str(len(obj)),
+                            "arg2_length": str(len(verifiers)),
+                        },
                     )
                 for v, (_, verifier) in zip(obj, verifiers):
                     verifier(v)
@@ -2205,7 +2267,10 @@ class Row(tuple):
 
     def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row":
         if args and kwargs:
-            raise ValueError("Can not use both args " "and kwargs to create 
Row")
+            raise PySparkValueError(
+                error_class="CANNOT_SET_TOGETHER",
+                message_parameters={"arg_list": "args and kwargs"},
+            )
         if kwargs:
             # create row objects
             row = tuple.__new__(cls, list(kwargs.values()))
@@ -2278,9 +2343,13 @@ class Row(tuple):
     def __call__(self, *args: Any) -> "Row":
         """create new Row object"""
         if len(args) > len(self):
-            raise ValueError(
-                "Can not create Row with fields %s, expected %d values "
-                "but got %s" % (self, len(self), args)
+            raise PySparkValueError(
+                error_class="TOO_MANY_VALUES",
+                message_parameters={
+                    "expected": str(len(self)),
+                    "item": "fields",
+                    "actual": str(len(args)),
+                },
             )
         return _create_row(self, args)
 
@@ -2295,7 +2364,7 @@ class Row(tuple):
         except IndexError:
             raise KeyError(item)
         except ValueError:
-            raise ValueError(item)
+            raise PySparkValueError(item)
 
     def __getattr__(self, item: str) -> Any:
         if item.startswith("__"):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to