This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3e9033a63c5f [SPARK-45913][PYTHON] Make the internal attributes private from PySpark errors
3e9033a63c5f is described below
commit 3e9033a63c5f717296e212323b0d45f9d2afb059
Author: Haejoon Lee <[email protected]>
AuthorDate: Wed Nov 15 10:57:18 2023 +0900
[SPARK-45913][PYTHON] Make the internal attributes private from PySpark errors
### What changes were proposed in this pull request?
This PR proposes to hide the internal attributes from errors.
### Why are the changes needed?
We should expose only user-facing APIs, e.g. `getErrorClass`,
`getMessageParameters` and `getSqlState`, but currently some internal
attributes such as `error_reader`, `message`, `error_class` and
`message_parameters` are also exposed as if they were user APIs, as shown below:
<img width="680" alt="Screenshot 2023-11-14 at 9 09 54 AM"
src="https://github.com/apache/spark/assets/44108233/86f9a8c7-b11e-4935-ba62-930b39ffd2c0">
We should hide them by prefixing them with an underscore.
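For illustration, here is a minimal sketch of the pattern this PR applies, simplified from `PySparkException` in `python/pyspark/errors/exceptions/base.py` (the error-class reader and argument validation are omitted):

```python
from typing import Dict, Optional


class PySparkException(Exception):
    # Simplified sketch; the real class also wires up an ErrorClassesReader
    # and asserts that `message` and `error_class`/`message_parameters`
    # are mutually exclusive.
    def __init__(
        self,
        message: Optional[str] = None,
        error_class: Optional[str] = None,
        message_parameters: Optional[Dict[str, str]] = None,
    ):
        # Underscore prefixes mark these as internal, so they no longer
        # show up as public attributes (e.g. in IDE autocompletion).
        self._message = message
        self._error_class = error_class
        self._message_parameters = message_parameters

    def getErrorClass(self) -> Optional[str]:
        # The public accessors remain the supported API surface.
        return self._error_class

    def getMessageParameters(self) -> Optional[Dict[str, str]]:
        return self._message_parameters
```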
### Does this PR introduce _any_ user-facing change?
No API changes, but the internal attributes are now explicitly marked as
private and hidden from user space:
<img width="676" alt="Screenshot 2023-11-14 at 9 17 41 AM"
src="https://github.com/apache/spark/assets/44108233/0c342bcc-14cc-4836-93cf-5ed0ba2cc380">
### How was this patch tested?
The existing CI should pass
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43790 from itholic/SPARK-45913.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/errors/exceptions/base.py | 18 +++++++++---------
python/pyspark/errors/exceptions/captured.py | 14 +++++++-------
python/pyspark/errors/exceptions/connect.py | 10 +++++-----
python/pyspark/sql/connect/udf.py | 2 +-
python/pyspark/sql/connect/udtf.py | 2 +-
python/pyspark/sql/tests/connect/test_connect_basic.py | 8 ++++----
python/pyspark/sql/tests/streaming/test_streaming.py | 10 +++++-----
python/pyspark/sql/tests/test_utils.py | 4 ++--
8 files changed, 34 insertions(+), 34 deletions(-)
diff --git a/python/pyspark/errors/exceptions/base.py
b/python/pyspark/errors/exceptions/base.py
index 5ab73b63d362..f37a621d3010 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -40,17 +40,17 @@ class PySparkException(Exception):
message is None and (error_class is not None and message_parameters is not None)
)
- self.error_reader = ErrorClassesReader()
+ self._error_reader = ErrorClassesReader()
if message is None:
- self.message = self.error_reader.get_error_message(
+ self._message = self._error_reader.get_error_message(
cast(str, error_class), cast(Dict[str, str], message_parameters)
)
else:
- self.message = message
+ self._message = message
- self.error_class = error_class
- self.message_parameters = message_parameters
+ self._error_class = error_class
+ self._message_parameters = message_parameters
def getErrorClass(self) -> Optional[str]:
"""
@@ -63,7 +63,7 @@ class PySparkException(Exception):
:meth:`PySparkException.getMessageParameters`
:meth:`PySparkException.getSqlState`
"""
- return self.error_class
+ return self._error_class
def getMessageParameters(self) -> Optional[Dict[str, str]]:
"""
@@ -76,7 +76,7 @@ class PySparkException(Exception):
:meth:`PySparkException.getErrorClass`
:meth:`PySparkException.getSqlState`
"""
- return self.message_parameters
+ return self._message_parameters
def getSqlState(self) -> Optional[str]:
"""
@@ -95,9 +95,9 @@ class PySparkException(Exception):
def __str__(self) -> str:
if self.getErrorClass() is not None:
- return f"[{self.getErrorClass()}] {self.message}"
+ return f"[{self.getErrorClass()}] {self._message}"
else:
- return self.message
+ return self._message
class AnalysisException(PySparkException):
diff --git a/python/pyspark/errors/exceptions/captured.py
b/python/pyspark/errors/exceptions/captured.py
index 55ed7ab3a6d5..8eba174eea3c 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -55,16 +55,16 @@ class CapturedException(PySparkException):
origin is None and desc is not None and stackTrace is not None
)
- self.desc = desc if desc is not None else cast(Py4JJavaError, origin).getMessage()
+ self._desc = desc if desc is not None else cast(Py4JJavaError, origin).getMessage()
assert SparkContext._jvm is not None
- self.stackTrace = (
+ self._stackTrace = (
stackTrace
if stackTrace is not None
else (SparkContext._jvm.org.apache.spark.util.Utils.exceptionString(origin))
)
- self.cause = convert_exception(cause) if cause is not None else None
- if self.cause is None and origin is not None and origin.getCause() is not None:
- self.cause = convert_exception(origin.getCause())
+ self._cause = convert_exception(cause) if cause is not None else None
+ if self._cause is None and origin is not None and origin.getCause() is not None:
+ self._cause = convert_exception(origin.getCause())
self._origin = origin
def __str__(self) -> str:
@@ -80,9 +80,9 @@ class CapturedException(PySparkException):
except BaseException:
pass
- desc = self.desc
+ desc = self._desc
if debug_enabled:
- desc = desc + "\n\nJVM stacktrace:\n%s" % self.stackTrace
+ desc = desc + "\n\nJVM stacktrace:\n%s" % self._stackTrace
return str(desc)
def getErrorClass(self) -> Optional[str]:
diff --git a/python/pyspark/errors/exceptions/connect.py
b/python/pyspark/errors/exceptions/connect.py
index 2558c425469a..bb0faa7d344b 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -239,9 +239,9 @@ class SparkConnectGrpcException(SparkConnectException):
server_stacktrace: Optional[str] = None,
display_server_stacktrace: bool = False,
) -> None:
- self.message = message # type: ignore[assignment]
+ self._message = message # type: ignore[assignment]
if reason is not None:
- self.message = f"({reason}) {self.message}"
+ self._message = f"({reason}) {self._message}"
# PySparkException has the assumption that error_class and message_parameters are
# only occurring together. If only one is set, we assume the message to be fully
@@ -254,11 +254,11 @@ class SparkConnectGrpcException(SparkConnectException):
tmp_message_parameters = None
super().__init__(
- message=self.message,
+ message=self._message,
error_class=tmp_error_class,
message_parameters=tmp_message_parameters,
)
- self.error_class = error_class
+ self._error_class = error_class
self._sql_state: Optional[str] = sql_state
self._stacktrace: Optional[str] = server_stacktrace
self._display_stacktrace: bool = display_server_stacktrace
@@ -273,7 +273,7 @@ class SparkConnectGrpcException(SparkConnectException):
return self._stacktrace
def __str__(self) -> str:
- desc = self.message
+ desc = self._message
if self._display_stacktrace:
desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
return desc
diff --git a/python/pyspark/sql/connect/udf.py
b/python/pyspark/sql/connect/udf.py
index 90cea26e56f5..5386398bdca8 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -67,7 +67,7 @@ def _create_py_udf(
== "true"
)
except PySparkRuntimeError as e:
- if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION":
+ if e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION":
pass # Just uses the default if no session found.
else:
raise e
diff --git a/python/pyspark/sql/connect/udtf.py
b/python/pyspark/sql/connect/udtf.py
index 16f9b990760d..f0facbd1a702 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -78,7 +78,7 @@ def _create_py_udtf(
== "true"
)
except PySparkRuntimeError as e:
- if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION":
+ if e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION":
pass # Just uses the default if no session found.
else:
raise e
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d2febcd6b089..d4fb2d92fbb4 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3368,8 +3368,8 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
name = "test" * 10000
with self.assertRaises(AnalysisException) as e:
self.spark.sql("select " + name).collect()
- self.assertTrue(name in e.exception.message)
- self.assertFalse("JVM stacktrace" in e.exception.message)
+ self.assertTrue(name in e.exception._message)
+ self.assertFalse("JVM stacktrace" in e.exception._message)
def test_error_enrichment_jvm_stacktrace(self):
with self.sql_conf(
@@ -3384,7 +3384,7 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
"""select from_json(
'{"d": "02-29"}', 'd date', map('dateFormat',
'MM-dd'))"""
).collect()
- self.assertFalse("JVM stacktrace" in e.exception.message)
+ self.assertFalse("JVM stacktrace" in e.exception._message)
with self.sql_conf({"spark.sql.connect.serverStacktrace.enabled": True}):
with self.assertRaises(SparkUpgradeException) as e:
@@ -3655,7 +3655,7 @@ class ChannelBuilderTests(unittest.TestCase):
chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
with self.assertRaises(SparkConnectException) as err:
chan.userAgent
- self.assertRegex(err.exception.message, "'user_agent' parameter should not exceed")
+ self.assertRegex(err.exception._message, "'user_agent' parameter should not exceed")
user_agent = "%C3%A4" * 341 # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
expected = "ä" * 341
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py
b/python/pyspark/sql/tests/streaming/test_streaming.py
index 0eea86dc7375..b51b058ca974 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -303,16 +303,16 @@ class StreamingTestsMixin:
def _assert_exception_tree_contains_msg_connect(self, exception, msg):
self.assertTrue(
- msg in exception.message,
+ msg in exception._message,
"Exception tree doesn't contain the expected message: %s" % msg,
)
def _assert_exception_tree_contains_msg_default(self, exception, msg):
e = exception
- contains = msg in e.desc
- while e.cause is not None and not contains:
- e = e.cause
- contains = msg in e.desc
+ contains = msg in e._desc
+ while e._cause is not None and not contains:
+ e = e._cause
+ contains = msg in e._desc
self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)
def test_query_manager_get(self):
diff --git a/python/pyspark/sql/tests/test_utils.py
b/python/pyspark/sql/tests/test_utils.py
index ebdab31ec207..ccfadf0b8dbd 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -1739,8 +1739,8 @@ class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
try:
df.select(sha2(df.a, 1024)).collect()
except IllegalArgumentException as e:
- self.assertRegex(e.desc, "1024 is not in the permitted values")
- self.assertRegex(e.stackTrace, "org.apache.spark.sql.functions")
+ self.assertRegex(e._desc, "1024 is not in the permitted values")
+ self.assertRegex(e._stackTrace, "org.apache.spark.sql.functions")
def test_get_error_class_state(self):
# SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]