This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 62d4bab3f2b3 [SPARK-46386][PYTHON] Improve assertions of observation (pyspark.sql.observation)
62d4bab3f2b3 is described below

commit 62d4bab3f2b30cbd5c87f0bb475f8b57e230e02e
Author: Xinrong Meng <[email protected]>
AuthorDate: Sat Dec 16 09:45:50 2023 -0800

    [SPARK-46386][PYTHON] Improve assertions of observation (pyspark.sql.observation)
    
    ### What changes were proposed in this pull request?
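    Replace the bare `assert` statements in `Observation` (both `pyspark.sql.observation` and `pyspark.sql.connect.observation`) with `PySparkAssertionError` backed by two new error classes, `NO_OBSERVE_BEFORE_GET` and `REUSE_OBSERVATION`, and add tests for them.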
    
    ### Why are the changes needed?
    Better error handling.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, PySparkAssertionError is raised in the cases below:
    
    ```py
    >>> observation = Observation()
    >>> observation.get()
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.base.PySparkAssertionError: [NO_OBSERVE_BEFORE_GET] Should observe by calling `DataFrame.observe` before `get`.
    >>> df.observe(observation, count(lit(1)))
    DataFrame[id: bigint, val: double, label: string]
    >>> df.observe(observation, count(lit(1)))
    Traceback (most recent call last):
    ...
        raise PySparkAssertionError(error_class="REUSE_OBSERVATION", message_parameters={})
    pyspark.errors.exceptions.base.PySparkAssertionError: [REUSE_OBSERVATION] An Observation can be used with a DataFrame only once.
    ```
    
    ### How was this patch tested?
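    Added unit tests in `python/pyspark/sql/tests/test_dataframe.py` that check both new error classes (`NO_OBSERVE_BEFORE_GET` and `REUSE_OBSERVATION`).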
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44324 from xinrong-meng/test_observe.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/errors/error_classes.py     | 10 ++++++++++
 python/pyspark/sql/connect/observation.py  |  8 ++++++--
 python/pyspark/sql/observation.py          |  9 ++++++---
 python/pyspark/sql/tests/test_dataframe.py | 20 ++++++++++++++++++++
 4 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index ffe5d692001c..bb2784812627 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -712,6 +712,11 @@ ERROR_CLASSES_JSON = """
       "No active Spark session found. Please create a new Spark session before 
running the code."
     ]
   },
+  "NO_OBSERVE_BEFORE_GET" : {
+    "message" : [
+      "Should observe by calling `DataFrame.observe` before `get`."
+    ]
+  },
   "NO_SCHEMA_AND_DRIVER_DEFAULT_SCHEME" : {
     "message" : [
       "Only allows <arg_name> to be a path without scheme, and Spark Driver 
should use the default scheme to determine the destination file system."
@@ -828,6 +833,11 @@ ERROR_CLASSES_JSON = """
       "The maximum number of retries has been exceeded."
     ]
   },
+  "REUSE_OBSERVATION" : {
+    "message" : [
+      "An Observation can be used with a DataFrame only once."
+    ]
+  },
   "SCHEMA_MISMATCH_FOR_PANDAS_UDF" : {
     "message" : [
       "Result vector from pandas_udf was not the required length: expected 
<expected>, got <actual>."
diff --git a/python/pyspark/sql/connect/observation.py 
b/python/pyspark/sql/connect/observation.py
index 174ab74c2506..d88a62009995 100644
--- a/python/pyspark/sql/connect/observation.py
+++ b/python/pyspark/sql/connect/observation.py
@@ -21,6 +21,7 @@ from pyspark.errors import (
     PySparkTypeError,
     PySparkValueError,
     IllegalArgumentException,
+    PySparkAssertionError,
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.dataframe import DataFrame
@@ -50,7 +51,8 @@ class Observation:
     __init__.__doc__ = PySparkObservation.__init__.__doc__
 
     def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
-        assert self._result is None, "an Observation can be used with a 
DataFrame only once"
+        if self._result is not None:
+            raise PySparkAssertionError(error_class="REUSE_OBSERVATION", 
message_parameters={})
 
         if self._name is None:
             self._name = str(uuid.uuid4())
@@ -68,7 +70,9 @@ class Observation:
 
     @property
     def get(self) -> Dict[str, Any]:
-        assert self._result is not None
+        if self._result is None:
+            raise PySparkAssertionError(error_class="NO_OBSERVE_BEFORE_GET", 
message_parameters={})
+
         return self._result
 
     get.__doc__ = PySparkObservation.get.__doc__
diff --git a/python/pyspark/sql/observation.py 
b/python/pyspark/sql/observation.py
index ecb21e8d9084..f12d1250cba2 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional
 
 from py4j.java_gateway import JavaObject, JVMView
 
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkTypeError, PySparkValueError, 
PySparkAssertionError
 from pyspark.sql import column
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
@@ -114,7 +114,8 @@ class Observation:
         :class:`DataFrame`
             the observed :class:`DataFrame`.
         """
-        assert self._jo is None, "an Observation can be used with a DataFrame 
only once"
+        if self._jo is not None:
+            raise PySparkAssertionError(error_class="REUSE_OBSERVATION", 
message_parameters={})
 
         self._jvm = df._sc._jvm
         assert self._jvm is not None
@@ -137,7 +138,9 @@ class Observation:
         dict
             the observed metrics
         """
-        assert self._jo is not None, "call DataFrame.observe"
+        if self._jo is None:
+            raise PySparkAssertionError(error_class="NO_OBSERVE_BEFORE_GET", 
message_parameters={})
+
         jmap = self._jo.getAsJava()
         # return a pure Python dict, not jmap which is a py4j JavaMap
         return {k: v for k, v in jmap.items()}
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index e1df01116e18..692cf77d9afb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -47,6 +47,7 @@ from pyspark.storagelevel import StorageLevel
 from pyspark.errors import (
     AnalysisException,
     IllegalArgumentException,
+    PySparkAssertionError,
     PySparkTypeError,
     PySparkValueError,
 )
@@ -977,6 +978,16 @@ class DataFrameTestsMixin:
 
         unnamed_observation = Observation()
         named_observation = Observation("metric")
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            unnamed_observation.get()
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NO_OBSERVE_BEFORE_GET",
+            message_parameters={},
+        )
+
         observed = (
             df.orderBy("id")
             .observe(
@@ -1003,6 +1014,15 @@ class DataFrameTestsMixin:
         self.assertEqual(named_observation.get, dict(cnt=3, sum=6, mean=2.0))
         self.assertEqual(unnamed_observation.get, dict(rows=3))
 
+        with self.assertRaises(PySparkAssertionError) as pe:
+            df.observe(named_observation, count(lit(1)).alias("count"))
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="REUSE_OBSERVATION",
+            message_parameters={},
+        )
+
         # observation requires name (if given) to be non empty string
         with self.assertRaisesRegex(TypeError, "`name` should be a str, got 
int"):
             Observation(123)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to