This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7e937322773 [SPARK-43212][SS][PYTHON] Migrate Structured Streaming errors into error class
7e937322773 is described below

commit 7e9373227731a8c22a0be3ea47850e313d6f05e1
Author: itholic <haejoon....@databricks.com>
AuthorDate: Mon Apr 24 09:17:05 2023 +0900

    [SPARK-43212][SS][PYTHON] Migrate Structured Streaming errors into error class
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to migrate the built-in `TypeError` and `ValueError` exceptions raised from Structured Streaming into the PySpark error framework.
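
    For illustration, the call sites follow the before/after pattern below (a minimal sketch; `check_output_mode` is a hypothetical helper, while the error class and message parameters mirror the diff in this patch):

    ```python
    from pyspark.errors import PySparkValueError

    def check_output_mode(outputMode) -> None:
        # Before: raise ValueError("The output mode must be a non-empty string. Got: %s" % outputMode)
        # After: raise an error that carries an error class plus structured message parameters.
        if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0:
            raise PySparkValueError(
                error_class="VALUE_NOT_NON_EMPTY_STR",
                message_parameters={"arg_name": "outputMode", "arg_value": str(outputMode)},
            )
    ```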
    
    ### Why are the changes needed?
    
    To improve error messages by raising them through the PySpark error framework with proper error classes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No API changes; only error messages are improved.
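
    For example, a caller or test can now branch on the error class instead of matching free-form message text. A hedged sketch, assuming the `getErrorClass()` and `getMessageParameters()` accessors of the PySpark error framework:

    ```python
    from pyspark.errors import PySparkTypeError

    # Illustrative only; this snippet is not part of the patch.
    try:
        raise PySparkTypeError(
            error_class="NOT_CALLABLE",
            message_parameters={"arg_name": "func", "arg_type": "int"},
        )
    except PySparkTypeError as e:
        assert e.getErrorClass() == "NOT_CALLABLE"
        assert e.getMessageParameters() == {"arg_name": "func", "arg_type": "int"}
    ```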
    
    ### How was this patch tested?
    
    The existing CI should pass.
    
    Closes #40880 from itholic/streaming_type_error.
    
    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py             | 40 +++++++++
 python/pyspark/sql/streaming/query.py              | 12 ++-
 python/pyspark/sql/streaming/readwriter.py         | 96 ++++++++++++++++------
 python/pyspark/sql/streaming/state.py              | 44 ++++++++--
 .../sql/tests/streaming/test_streaming_foreach.py  | 12 +--
 5 files changed, 161 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 3ad8e5d9703..5ee03430c89 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -24,6 +24,11 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` is required when <condition>."
     ]
   },
+  "ATTRIBUTE_NOT_CALLABLE" : {
+    "message" : [
+      "Attribute `<attr_name>` in provided object `<obj_name>` is not 
callable."
+    ]
+  },
   "CANNOT_ACCESS_TO_DUNDER": {
     "message": [
       "Dunder(double underscore) attribute is for internal use only."
@@ -39,6 +44,11 @@ ERROR_CLASSES_JSON = """
       "At least one <item> must be specified."
     ]
   },
+  "CANNOT_BE_NONE": {
+    "message": [
+      "Argument `<arg_name>` can not be None."
+    ]
+  },
   "CANNOT_CONVERT_COLUMN_INTO_BOOL": {
     "message": [
       "Cannot convert column into bool: please use '&' for 'and', '|' for 
'or', '~' for 'not' when building DataFrame boolean expressions."
@@ -74,6 +84,11 @@ ERROR_CLASSES_JSON = """
       "All items in `<arg_name>` should be in <allowed_types>, got 
<item_type>."
     ]
   },
+  "INVALID_TIMEOUT_TIMESTAMP" : {
+    "message" : [
+      "Timeout timestamp (<timestamp>) cannot be earlier than the current 
watermark (<watermark>)."
+    ]
+  },
   "INVALID_WHEN_USAGE": {
     "message": [
       "when() can only be applied on a Column previously generated by when() 
function, and cannot be applied once otherwise() is applied."
@@ -139,6 +154,11 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` should be a bool or str, got <arg_type>."
     ]
   },
+  "NOT_CALLABLE" : {
+    "message" : [
+      "Argument `<arg_name>` should be a callable, got <arg_type>."
+    ]
+  },
   "NOT_COLUMN" : {
     "message" : [
       "Argument `<arg_name>` should be a Column, got <arg_type>."
@@ -279,11 +299,21 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` can only be provided for a single column."
     ]
   },
+  "ONLY_ALLOW_SINGLE_TRIGGER" : {
+    "message" : [
+      "Only a single trigger is allowed."
+    ]
+  },
   "SLICE_WITH_STEP" : {
     "message" : [
       "Slice with step is not supported."
     ]
   },
+  "STATE_NOT_EXISTS" : {
+    "message" : [
+      "State is either not defined or has already been removed."
+    ]
+  },
   "UNSUPPORTED_NUMPY_ARRAY_SCALAR" : {
     "message" : [
       "The type of array scalar '<dtype>' is not supported."
@@ -299,6 +329,11 @@ ERROR_CLASSES_JSON = """
       "Value for `<arg_name>` must be 'any' or 'all', got '<arg_value>'."
     ]
   },
+  "VALUE_NOT_NON_EMPTY_STR" : {
+    "message" : [
+      "Value for `<arg_name>` must be a non empty string, got '<arg_value>'."
+    ]
+  },
   "VALUE_NOT_PEARSON" : {
     "message" : [
       "Value for `<arg_name>` only supports the 'pearson', got '<arg_value>'."
@@ -309,6 +344,11 @@ ERROR_CLASSES_JSON = """
       "Value for `<arg_name>` must be positive, got '<arg_value>'."
     ]
   },
+  "VALUE_NOT_TRUE" : {
+    "message" : [
+      "Value for `<arg_name>` must be True, got '<arg_value>'."
+    ]
+  },
   "WRONG_NUM_ARGS_FOR_HIGHER_ORDER_FUNCTION" : {
     "message" : [
       "Function `<func_name>` should take between 1 and 3 arguments, but 
provided function takes <num_args>."
diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py
index d2ce95da224..b6268dcdb18 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional
 
 from py4j.java_gateway import JavaObject, java_import
 
-from pyspark.errors import StreamingQueryException
+from pyspark.errors import StreamingQueryException, PySparkValueError
 from pyspark.errors.exceptions.captured import (
     StreamingQueryException as CapturedStreamingQueryException,
 )
@@ -197,7 +197,10 @@ class StreamingQuery:
         """
         if timeout is not None:
             if not isinstance(timeout, (int, float)) or timeout <= 0:
-                raise ValueError("timeout must be a positive integer or float. 
Got %s" % timeout)
+                raise PySparkValueError(
+                    error_class="VALUE_NOT_POSITIVE",
+                    message_parameters={"arg_name": "timeout", "arg_value": 
type(timeout).__name__},
+                )
             return self._jsq.awaitTermination(int(timeout * 1000))
         else:
             return self._jsq.awaitTermination()
@@ -532,7 +535,10 @@ class StreamingQueryManager:
         """
         if timeout is not None:
             if not isinstance(timeout, (int, float)) or timeout < 0:
-                raise ValueError("timeout must be a positive integer or float. 
Got %s" % timeout)
+                raise PySparkValueError(
+                    error_class="VALUE_NOT_POSITIVE",
+                    message_parameters={"arg_name": "timeout", "arg_value": 
type(timeout).__name__},
+                )
             return self._jsqm.awaitAnyTermination(int(timeout * 1000))
         else:
             return self._jsqm.awaitAnyTermination()
diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py
index 529e3aeb60d..f805d0cc152 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -26,6 +26,7 @@ from pyspark.sql.readwriter import OptionUtils, to_str
 from pyspark.sql.streaming.query import StreamingQuery
 from pyspark.sql.types import Row, StructType
 from pyspark.sql.utils import ForeachBatchFunction
+from pyspark.errors import PySparkTypeError, PySparkValueError
 
 if TYPE_CHECKING:
     from pyspark.sql.session import SparkSession
@@ -160,7 +161,10 @@ class DataStreamReader(OptionUtils):
         elif isinstance(schema, str):
             self._jreader = self._jreader.schema(schema)
         else:
-            raise TypeError("schema should be StructType or string")
+            raise PySparkTypeError(
+                error_class="NOT_STR_OR_STRUCT",
+                message_parameters={"arg_name": "schema", "arg_type": 
type(schema).__name__},
+            )
         return self
 
    def option(self, key: str, value: "OptionalPrimitiveType") -> "DataStreamReader":
@@ -271,9 +275,9 @@ class DataStreamReader(OptionUtils):
         self.options(**options)
         if path is not None:
             if type(path) != str or len(path.strip()) == 0:
-                raise ValueError(
-                    "If the path is provided for stream, it needs to be a "
-                    + "non-empty string. List of paths are not supported."
+                raise PySparkValueError(
+                    error_class="VALUE_NOT_NON_EMPTY_STR",
+                    message_parameters={"arg_name": "path", "arg_value": 
str(path)},
                 )
             return self._df(self._jreader.load(path))
         else:
@@ -382,7 +386,10 @@ class DataStreamReader(OptionUtils):
         if isinstance(path, str):
             return self._df(self._jreader.json(path))
         else:
-            raise TypeError("path can be only a single string")
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "path", "arg_type": 
type(path).__name__},
+            )
 
     def orc(
         self,
@@ -427,7 +434,10 @@ class DataStreamReader(OptionUtils):
         if isinstance(path, str):
             return self._df(self._jreader.orc(path))
         else:
-            raise TypeError("path can be only a single string")
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "path", "arg_type": 
type(path).__name__},
+            )
 
     def parquet(
         self,
@@ -483,7 +493,10 @@ class DataStreamReader(OptionUtils):
         if isinstance(path, str):
             return self._df(self._jreader.parquet(path))
         else:
-            raise TypeError("path can be only a single string")
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "path", "arg_type": 
type(path).__name__},
+            )
 
     def text(
         self,
@@ -546,7 +559,10 @@ class DataStreamReader(OptionUtils):
         if isinstance(path, str):
             return self._df(self._jreader.text(path))
         else:
-            raise TypeError("path can be only a single string")
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "path", "arg_type": 
type(path).__name__},
+            )
 
     def csv(
         self,
@@ -663,7 +679,10 @@ class DataStreamReader(OptionUtils):
         if isinstance(path, str):
             return self._df(self._jreader.csv(path))
         else:
-            raise TypeError("path can be only a single string")
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "path", "arg_type": 
type(path).__name__},
+            )
 
     def table(self, tableName: str) -> "DataFrame":
         """Define a Streaming DataFrame on a Table. The DataSource 
corresponding to the table should
@@ -706,7 +725,10 @@ class DataStreamReader(OptionUtils):
         if isinstance(tableName, str):
             return self._df(self._jreader.table(tableName))
         else:
-            raise TypeError("tableName can be only a single string")
+            raise PySparkTypeError(
+                error_class="NOT_STR",
+                message_parameters={"arg_name": "tableName", "arg_type": 
type(tableName).__name__},
+            )
 
 
 class DataStreamWriter:
@@ -779,7 +801,10 @@ class DataStreamWriter:
         >>> q.stop()
         """
        if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0:
-            raise ValueError("The output mode must be a non-empty string. Got: 
%s" % outputMode)
+            raise PySparkValueError(
+                error_class="VALUE_NOT_NON_EMPTY_STR",
+                message_parameters={"arg_name": "outputMode", "arg_value": 
str(outputMode)},
+            )
         self._jwrite = self._jwrite.outputMode(outputMode)
         return self
 
@@ -957,7 +982,10 @@ class DataStreamWriter:
         'streaming_query'
         """
        if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
-            raise ValueError("The queryName must be a non-empty string. Got: 
%s" % queryName)
+            raise PySparkValueError(
+                error_class="VALUE_NOT_NON_EMPTY_STR",
+                message_parameters={"arg_name": "queryName", "arg_value": 
str(queryName)},
+            )
         self._jwrite = self._jwrite.queryName(queryName)
         return self
 
@@ -1033,16 +1061,26 @@ class DataStreamWriter:
         params = [processingTime, once, continuous, availableNow]
 
         if params.count(None) == 4:
-            raise ValueError("No trigger provided")
+            raise PySparkValueError(
+                error_class="ONLY_ALLOW_SINGLE_TRIGGER",
+                message_parameters={},
+            )
         elif params.count(None) < 3:
-            raise ValueError("Multiple triggers not allowed.")
+            raise PySparkValueError(
+                error_class="ONLY_ALLOW_SINGLE_TRIGGER",
+                message_parameters={},
+            )
 
         jTrigger = None
         assert self._spark._sc._jvm is not None
         if processingTime is not None:
             if type(processingTime) != str or len(processingTime.strip()) == 0:
-                raise ValueError(
-                    "Value for processingTime must be a non empty string. Got: 
%s" % processingTime
+                raise PySparkValueError(
+                    error_class="VALUE_NOT_NON_EMPTY_STR",
+                    message_parameters={
+                        "arg_name": "processingTime",
+                        "arg_value": str(processingTime),
+                    },
                 )
             interval = processingTime.strip()
            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.ProcessingTime(
@@ -1051,13 +1089,18 @@ class DataStreamWriter:
 
         elif once is not None:
             if once is not True:
-                raise ValueError("Value for once must be True. Got: %s" % once)
+                raise PySparkValueError(
+                    error_class="VALUE_NOT_TRUE",
+                    message_parameters={"arg_name": "once", "arg_value": 
str(once)},
+                )
+
            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Once()
 
         elif continuous is not None:
             if type(continuous) != str or len(continuous.strip()) == 0:
-                raise ValueError(
-                    "Value for continuous must be a non empty string. Got: %s" 
% continuous
+                raise PySparkValueError(
+                    error_class="VALUE_NOT_NON_EMPTY_STR",
+                    message_parameters={"arg_name": "continuous", "arg_value": 
str(continuous)},
                 )
             interval = continuous.strip()
            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.Continuous(
@@ -1065,7 +1108,10 @@ class DataStreamWriter:
             )
         else:
             if availableNow is not True:
-                raise ValueError("Value for availableNow must be True. Got: 
%s" % availableNow)
+                raise PySparkValueError(
+                    error_class="VALUE_NOT_TRUE",
+                    message_parameters={"arg_name": "availableNow", 
"arg_value": str(availableNow)},
+                )
            jTrigger = self._spark._sc._jvm.org.apache.spark.sql.streaming.Trigger.AvailableNow()
 
         self._jwrite = self._jwrite.trigger(jTrigger)
@@ -1208,13 +1254,17 @@ class DataStreamWriter:
                 raise AttributeError("Provided object does not have a 
'process' method")
 
             if not callable(getattr(f, "process")):
-                raise TypeError("Attribute 'process' in provided object is not 
callable")
+                raise PySparkTypeError(
+                    error_class="ATTRIBUTE_NOT_CALLABLE",
+                    message_parameters={"attr_name": "process", "obj_name": 
"f"},
+                )
 
             def doesMethodExist(method_name: str) -> bool:
                 exists = hasattr(f, method_name)
                 if exists and not callable(getattr(f, method_name)):
-                    raise TypeError(
-                        "Attribute '%s' in provided object is not callable" % 
method_name
+                    raise PySparkTypeError(
+                        error_class="ATTRIBUTE_NOT_CALLABLE",
+                        message_parameters={"attr_name": method_name, 
"obj_name": "f"},
                     )
                 return exists
 
diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py
index f0ac427cbea..8bf01b3ebd9 100644
--- a/python/pyspark/sql/streaming/state.py
+++ b/python/pyspark/sql/streaming/state.py
@@ -20,6 +20,7 @@ from typing import Tuple, Optional
 
 from pyspark.sql.types import DateType, Row, StructType
 from pyspark.sql.utils import has_numpy
+from pyspark.errors import PySparkTypeError, PySparkValueError
 
 __all__ = ["GroupState", "GroupStateTimeout"]
 
@@ -98,7 +99,10 @@ class GroupState:
         if self.exists:
             return tuple(self._value)
         else:
-            raise ValueError("State is either not defined or has already been 
removed")
+            raise PySparkValueError(
+                error_class="STATE_NOT_EXISTS",
+                message_parameters={},
+            )
 
     @property
     def getOption(self) -> Optional[Tuple]:
@@ -129,7 +133,10 @@ class GroupState:
         Update the value of the state. The value of the state cannot be null.
         """
         if newValue is None:
-            raise ValueError("'None' is not a valid state value")
+            raise PySparkTypeError(
+                error_class="CANNOT_BE_NONE",
+                message_parameters={"arg_name": "newValue"},
+            )
 
         converted = []
         if has_numpy:
@@ -169,7 +176,13 @@ class GroupState:
         """
         if isinstance(durationMs, str):
             # TODO(SPARK-40437): Support string representation of durationMs.
-            raise ValueError("durationMs should be int but get :%s" % 
type(durationMs))
+            raise PySparkTypeError(
+                error_class="NOT_INT",
+                message_parameters={
+                    "arg_name": "durationMs",
+                    "arg_type": type(durationMs).__name__,
+                },
+            )
 
         if self._timeout_conf != GroupStateTimeout.ProcessingTimeTimeout:
             raise RuntimeError(
@@ -178,7 +191,13 @@ class GroupState:
             )
 
         if durationMs <= 0:
-            raise ValueError("Timeout duration must be positive")
+            raise PySparkValueError(
+                error_class="VALUE_NOT_POSITIVE",
+                message_parameters={
+                    "arg_name": "durationMs",
+                    "arg_type": type(durationMs).__name__,
+                },
+            )
         self._timeout_timestamp = durationMs + self._batch_processing_time_ms
 
     # TODO(SPARK-40438): Implement additionalDuration parameter.
@@ -198,15 +217,24 @@ class GroupState:
             timestampMs = DateType().toInternal(timestampMs)
 
         if timestampMs <= 0:
-            raise ValueError("Timeout timestamp must be positive")
+            raise PySparkValueError(
+                error_class="VALUE_NOT_POSITIVE",
+                message_parameters={
+                    "arg_name": "timestampMs",
+                    "arg_type": type(timestampMs).__name__,
+                },
+            )
 
         if (
             self._event_time_watermark_ms != GroupState.NO_TIMESTAMP
             and timestampMs < self._event_time_watermark_ms
         ):
-            raise ValueError(
-                "Timeout timestamp (%s) cannot be earlier than the "
-                "current watermark (%s)" % (timestampMs, 
self._event_time_watermark_ms)
+            raise PySparkValueError(
+                error_class="INVALID_TIMEOUT_TIMESTAMP",
+                message_parameters={
+                    "timestamp": str(timestampMs),
+                    "watermark": str(self._event_time_watermark_ms),
+                },
             )
 
         self._timeout_timestamp = timestampMs
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
index 8bd36020c9a..9ad3fee0972 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
@@ -248,9 +248,7 @@ class StreamingTestsForeach(ReusedSQLTestCase):
         class WriterWithNonCallableProcess:
             process = True
 
-        tester.assert_invalid_writer(
-            WriterWithNonCallableProcess(), "'process' in provided object is not callable"
-        )
+        tester.assert_invalid_writer(WriterWithNonCallableProcess(), "ATTRIBUTE_NOT_CALLABLE")
 
         class WriterWithNoParamProcess:
             def process(self):
@@ -266,9 +264,7 @@ class StreamingTestsForeach(ReusedSQLTestCase):
         class WriterWithNonCallableOpen(WithProcess):
             open = True
 
-        tester.assert_invalid_writer(
-            WriterWithNonCallableOpen(), "'open' in provided object is not callable"
-        )
+        tester.assert_invalid_writer(WriterWithNonCallableOpen(), "ATTRIBUTE_NOT_CALLABLE")
 
         class WriterWithNoParamOpen(WithProcess):
             def open(self):
@@ -279,9 +275,7 @@ class StreamingTestsForeach(ReusedSQLTestCase):
         class WriterWithNonCallableClose(WithProcess):
             close = True
 
-        tester.assert_invalid_writer(
-            WriterWithNonCallableClose(), "'close' in provided object is not callable"
-        )
+        tester.assert_invalid_writer(WriterWithNonCallableClose(), "ATTRIBUTE_NOT_CALLABLE")
 
 
 if __name__ == "__main__":
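
For reference, each entry added to error_classes.py above is a message template whose `<placeholders>` are filled from `message_parameters`. A simplified, hypothetical illustration of that substitution (not PySpark's actual renderer):

```python
import re

def render(template: str, params: dict) -> str:
    # Replace every <placeholder> with the corresponding message parameter.
    return re.sub(r"<(\w+)>", lambda m: params[m.group(1)], template)

print(render(
    "Attribute `<attr_name>` in provided object `<obj_name>` is not callable.",
    {"attr_name": "process", "obj_name": "f"},
))
# Attribute `process` in provided object `f` is not callable.
```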

