This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 62d4bab3f2b3 [SPARK-46386][PYTHON] Improve assertions of observation
(pyspark.sql.observation)
62d4bab3f2b3 is described below
commit 62d4bab3f2b30cbd5c87f0bb475f8b57e230e02e
Author: Xinrong Meng <[email protected]>
AuthorDate: Sat Dec 16 09:45:50 2023 -0800
[SPARK-46386][PYTHON] Improve assertions of observation
(pyspark.sql.observation)
### What changes were proposed in this pull request?
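Replace the bare `assert` statements in `pyspark.sql.observation` and `pyspark.sql.connect.observation` with `PySparkAssertionError`s that carry dedicated error classes (`NO_OBSERVE_BEFORE_GET` and `REUSE_OBSERVATION`), and add tests for both cases.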
### Why are the changes needed?
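Better error handling: a `PySparkAssertionError` with a named error class is easier for users to identify and handle than a bare `AssertionError`, and, unlike an `assert` statement, the check is not removed when Python runs with `-O`.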
### Does this PR introduce _any_ user-facing change?
Yes. `PySparkAssertionError` is now raised, instead of a plain `AssertionError`, in the cases below:
```py
>>> observation = Observation()
>>> observation.get()
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkAssertionError: [NO_OBSERVE_BEFORE_GET] Should observe by calling `DataFrame.observe` before `get`.
>>> df.observe(observation, count(lit(1)))
DataFrame[id: bigint, val: double, label: string]
>>> df.observe(observation, count(lit(1)))
Traceback (most recent call last):
...
raise PySparkAssertionError(error_class="REUSE_OBSERVATION", message_parameters={})
pyspark.errors.exceptions.base.PySparkAssertionError: [REUSE_OBSERVATION] An Observation can be used with a DataFrame only once.
```
### How was this patch tested?
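By the unit tests added in this change (`python/pyspark/sql/tests/test_dataframe.py`), which cover both the `NO_OBSERVE_BEFORE_GET` and `REUSE_OBSERVATION` error classes.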
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44324 from xinrong-meng/test_observe.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/errors/error_classes.py | 10 ++++++++++
python/pyspark/sql/connect/observation.py | 8 ++++++--
python/pyspark/sql/observation.py | 9 ++++++---
python/pyspark/sql/tests/test_dataframe.py | 20 ++++++++++++++++++++
4 files changed, 42 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py
b/python/pyspark/errors/error_classes.py
index ffe5d692001c..bb2784812627 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -712,6 +712,11 @@ ERROR_CLASSES_JSON = """
"No active Spark session found. Please create a new Spark session before
running the code."
]
},
+ "NO_OBSERVE_BEFORE_GET" : {
+ "message" : [
+ "Should observe by calling `DataFrame.observe` before `get`."
+ ]
+ },
"NO_SCHEMA_AND_DRIVER_DEFAULT_SCHEME" : {
"message" : [
"Only allows <arg_name> to be a path without scheme, and Spark Driver
should use the default scheme to determine the destination file system."
@@ -828,6 +833,11 @@ ERROR_CLASSES_JSON = """
"The maximum number of retries has been exceeded."
]
},
+ "REUSE_OBSERVATION" : {
+ "message" : [
+ "An Observation can be used with a DataFrame only once."
+ ]
+ },
"SCHEMA_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Result vector from pandas_udf was not the required length: expected
<expected>, got <actual>."
diff --git a/python/pyspark/sql/connect/observation.py
b/python/pyspark/sql/connect/observation.py
index 174ab74c2506..d88a62009995 100644
--- a/python/pyspark/sql/connect/observation.py
+++ b/python/pyspark/sql/connect/observation.py
@@ -21,6 +21,7 @@ from pyspark.errors import (
PySparkTypeError,
PySparkValueError,
IllegalArgumentException,
+ PySparkAssertionError,
)
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.dataframe import DataFrame
@@ -50,7 +51,8 @@ class Observation:
__init__.__doc__ = PySparkObservation.__init__.__doc__
def _on(self, df: DataFrame, *exprs: Column) -> DataFrame:
- assert self._result is None, "an Observation can be used with a
DataFrame only once"
+ if self._result is not None:
+ raise PySparkAssertionError(error_class="REUSE_OBSERVATION",
message_parameters={})
if self._name is None:
self._name = str(uuid.uuid4())
@@ -68,7 +70,9 @@ class Observation:
@property
def get(self) -> Dict[str, Any]:
- assert self._result is not None
+ if self._result is None:
+ raise PySparkAssertionError(error_class="NO_OBSERVE_BEFORE_GET",
message_parameters={})
+
return self._result
get.__doc__ = PySparkObservation.get.__doc__
diff --git a/python/pyspark/sql/observation.py
b/python/pyspark/sql/observation.py
index ecb21e8d9084..f12d1250cba2 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional
from py4j.java_gateway import JavaObject, JVMView
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkTypeError, PySparkValueError,
PySparkAssertionError
from pyspark.sql import column
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
@@ -114,7 +114,8 @@ class Observation:
:class:`DataFrame`
the observed :class:`DataFrame`.
"""
- assert self._jo is None, "an Observation can be used with a DataFrame
only once"
+ if self._jo is not None:
+ raise PySparkAssertionError(error_class="REUSE_OBSERVATION",
message_parameters={})
self._jvm = df._sc._jvm
assert self._jvm is not None
@@ -137,7 +138,9 @@ class Observation:
dict
the observed metrics
"""
- assert self._jo is not None, "call DataFrame.observe"
+ if self._jo is None:
+ raise PySparkAssertionError(error_class="NO_OBSERVE_BEFORE_GET",
message_parameters={})
+
jmap = self._jo.getAsJava()
# return a pure Python dict, not jmap which is a py4j JavaMap
return {k: v for k, v in jmap.items()}
diff --git a/python/pyspark/sql/tests/test_dataframe.py
b/python/pyspark/sql/tests/test_dataframe.py
index e1df01116e18..692cf77d9afb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -47,6 +47,7 @@ from pyspark.storagelevel import StorageLevel
from pyspark.errors import (
AnalysisException,
IllegalArgumentException,
+ PySparkAssertionError,
PySparkTypeError,
PySparkValueError,
)
@@ -977,6 +978,16 @@ class DataFrameTestsMixin:
unnamed_observation = Observation()
named_observation = Observation("metric")
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ unnamed_observation.get()
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NO_OBSERVE_BEFORE_GET",
+ message_parameters={},
+ )
+
observed = (
df.orderBy("id")
.observe(
@@ -1003,6 +1014,15 @@ class DataFrameTestsMixin:
self.assertEqual(named_observation.get, dict(cnt=3, sum=6, mean=2.0))
self.assertEqual(unnamed_observation.get, dict(rows=3))
+ with self.assertRaises(PySparkAssertionError) as pe:
+ df.observe(named_observation, count(lit(1)).alias("count"))
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="REUSE_OBSERVATION",
+ message_parameters={},
+ )
+
# observation requires name (if given) to be non empty string
with self.assertRaisesRegex(TypeError, "`name` should be a str, got
int"):
Observation(123)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]