This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e9033a63c5f [SPARK-45913][PYTHON] Make the internal attributes private from PySpark errors
3e9033a63c5f is described below

commit 3e9033a63c5f717296e212323b0d45f9d2afb059
Author: Haejoon Lee <[email protected]>
AuthorDate: Wed Nov 15 10:57:18 2023 +0900

    [SPARK-45913][PYTHON] Make the internal attributes private from PySpark errors
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to hide the internal attributes from errors.
    
    ### Why are the changes needed?
    
    We should expose only user APIs to the surface, e.g. `getErrorClass`, `getMessageParameters` and `getSqlState`, but currently some internal attributes such as `error_reader`, `message`, `error_class` and `message_parameters` are also exposed as if they were user APIs, as below:
    <img width="680" alt="Screenshot 2023-11-14 at 9 09 54 AM" 
src="https://github.com/apache/spark/assets/44108233/86f9a8c7-b11e-4935-ba62-930b39ffd2c0";>
    
    We should hide them by prefixing them with an underscore.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No API changes, but the internal attributes are explicitly marked as private from user space:
    
    <img width="676" alt="Screenshot 2023-11-14 at 9 17 41 AM" 
src="https://github.com/apache/spark/assets/44108233/0c342bcc-14cc-4836-93cf-5ed0ba2cc380";>
    
    ### How was this patch tested?
    
    The existing CI should pass.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43790 from itholic/SPARK-45913.
    
    Authored-by: Haejoon Lee <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/errors/exceptions/base.py               | 18 +++++++++---------
 python/pyspark/errors/exceptions/captured.py           | 14 +++++++-------
 python/pyspark/errors/exceptions/connect.py            | 10 +++++-----
 python/pyspark/sql/connect/udf.py                      |  2 +-
 python/pyspark/sql/connect/udtf.py                     |  2 +-
 python/pyspark/sql/tests/connect/test_connect_basic.py |  8 ++++----
 python/pyspark/sql/tests/streaming/test_streaming.py   | 10 +++++-----
 python/pyspark/sql/tests/test_utils.py                 |  4 ++--
 8 files changed, 34 insertions(+), 34 deletions(-)

diff --git a/python/pyspark/errors/exceptions/base.py b/python/pyspark/errors/exceptions/base.py
index 5ab73b63d362..f37a621d3010 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -40,17 +40,17 @@ class PySparkException(Exception):
             message is None and (error_class is not None and message_parameters is not None)
         )
 
-        self.error_reader = ErrorClassesReader()
+        self._error_reader = ErrorClassesReader()
 
         if message is None:
-            self.message = self.error_reader.get_error_message(
+            self._message = self._error_reader.get_error_message(
                 cast(str, error_class), cast(Dict[str, str], message_parameters)
             )
         else:
-            self.message = message
+            self._message = message
 
-        self.error_class = error_class
-        self.message_parameters = message_parameters
+        self._error_class = error_class
+        self._message_parameters = message_parameters
 
     def getErrorClass(self) -> Optional[str]:
         """
@@ -63,7 +63,7 @@ class PySparkException(Exception):
         :meth:`PySparkException.getMessageParameters`
         :meth:`PySparkException.getSqlState`
         """
-        return self.error_class
+        return self._error_class
 
     def getMessageParameters(self) -> Optional[Dict[str, str]]:
         """
@@ -76,7 +76,7 @@ class PySparkException(Exception):
         :meth:`PySparkException.getErrorClass`
         :meth:`PySparkException.getSqlState`
         """
-        return self.message_parameters
+        return self._message_parameters
 
     def getSqlState(self) -> Optional[str]:
         """
@@ -95,9 +95,9 @@ class PySparkException(Exception):
 
     def __str__(self) -> str:
         if self.getErrorClass() is not None:
-            return f"[{self.getErrorClass()}] {self.message}"
+            return f"[{self.getErrorClass()}] {self._message}"
         else:
-            return self.message
+            return self._message
 
 
 class AnalysisException(PySparkException):
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 55ed7ab3a6d5..8eba174eea3c 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -55,16 +55,16 @@ class CapturedException(PySparkException):
             origin is None and desc is not None and stackTrace is not None
         )
 
-        self.desc = desc if desc is not None else cast(Py4JJavaError, origin).getMessage()
+        self._desc = desc if desc is not None else cast(Py4JJavaError, origin).getMessage()
         assert SparkContext._jvm is not None
-        self.stackTrace = (
+        self._stackTrace = (
             stackTrace
             if stackTrace is not None
             else (SparkContext._jvm.org.apache.spark.util.Utils.exceptionString(origin))
         )
-        self.cause = convert_exception(cause) if cause is not None else None
-        if self.cause is None and origin is not None and origin.getCause() is not None:
-            self.cause = convert_exception(origin.getCause())
+        self._cause = convert_exception(cause) if cause is not None else None
+        if self._cause is None and origin is not None and origin.getCause() is not None:
+            self._cause = convert_exception(origin.getCause())
         self._origin = origin
 
     def __str__(self) -> str:
@@ -80,9 +80,9 @@ class CapturedException(PySparkException):
         except BaseException:
             pass
 
-        desc = self.desc
+        desc = self._desc
         if debug_enabled:
-            desc = desc + "\n\nJVM stacktrace:\n%s" % self.stackTrace
+            desc = desc + "\n\nJVM stacktrace:\n%s" % self._stackTrace
         return str(desc)
 
     def getErrorClass(self) -> Optional[str]:
diff --git a/python/pyspark/errors/exceptions/connect.py b/python/pyspark/errors/exceptions/connect.py
index 2558c425469a..bb0faa7d344b 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -239,9 +239,9 @@ class SparkConnectGrpcException(SparkConnectException):
         server_stacktrace: Optional[str] = None,
         display_server_stacktrace: bool = False,
     ) -> None:
-        self.message = message  # type: ignore[assignment]
+        self._message = message  # type: ignore[assignment]
         if reason is not None:
-            self.message = f"({reason}) {self.message}"
+            self._message = f"({reason}) {self._message}"
 
         # PySparkException has the assumption that error_class and message_parameters are
         # only occurring together. If only one is set, we assume the message to be fully
@@ -254,11 +254,11 @@ class SparkConnectGrpcException(SparkConnectException):
             tmp_message_parameters = None
 
         super().__init__(
-            message=self.message,
+            message=self._message,
             error_class=tmp_error_class,
             message_parameters=tmp_message_parameters,
         )
-        self.error_class = error_class
+        self._error_class = error_class
         self._sql_state: Optional[str] = sql_state
         self._stacktrace: Optional[str] = server_stacktrace
         self._display_stacktrace: bool = display_server_stacktrace
@@ -273,7 +273,7 @@ class SparkConnectGrpcException(SparkConnectException):
         return self._stacktrace
 
     def __str__(self) -> str:
-        desc = self.message
+        desc = self._message
         if self._display_stacktrace:
             desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
         return desc
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 90cea26e56f5..5386398bdca8 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -67,7 +67,7 @@ def _create_py_udf(
                 == "true"
             )
         except PySparkRuntimeError as e:
-            if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION":
+            if e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION":
                 pass  # Just uses the default if no session found.
             else:
                 raise e
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
index 16f9b990760d..f0facbd1a702 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -78,7 +78,7 @@ def _create_py_udtf(
                 == "true"
             )
         except PySparkRuntimeError as e:
-            if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION":
+            if e.getErrorClass() == "NO_ACTIVE_OR_DEFAULT_SESSION":
                 pass  # Just uses the default if no session found.
             else:
                 raise e
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d2febcd6b089..d4fb2d92fbb4 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3368,8 +3368,8 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
             name = "test" * 10000
             with self.assertRaises(AnalysisException) as e:
                 self.spark.sql("select " + name).collect()
-            self.assertTrue(name in e.exception.message)
-            self.assertFalse("JVM stacktrace" in e.exception.message)
+            self.assertTrue(name in e.exception._message)
+            self.assertFalse("JVM stacktrace" in e.exception._message)
 
     def test_error_enrichment_jvm_stacktrace(self):
         with self.sql_conf(
@@ -3384,7 +3384,7 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
                         """select from_json(
                             '{"d": "02-29"}', 'd date', map('dateFormat', 
'MM-dd'))"""
                     ).collect()
-                self.assertFalse("JVM stacktrace" in e.exception.message)
+                self.assertFalse("JVM stacktrace" in e.exception._message)
 
             with self.sql_conf({"spark.sql.connect.serverStacktrace.enabled": 
True}):
                 with self.assertRaises(SparkUpgradeException) as e:
@@ -3655,7 +3655,7 @@ class ChannelBuilderTests(unittest.TestCase):
         chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
         with self.assertRaises(SparkConnectException) as err:
             chan.userAgent
-        self.assertRegex(err.exception.message, "'user_agent' parameter should not exceed")
+        self.assertRegex(err.exception._message, "'user_agent' parameter should not exceed")
 
         user_agent = "%C3%A4" * 341  # "%C3%A4" -> "ä"; (341 * 6 = 2046) < 2048
         expected = "ä" * 341
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index 0eea86dc7375..b51b058ca974 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -303,16 +303,16 @@ class StreamingTestsMixin:
 
     def _assert_exception_tree_contains_msg_connect(self, exception, msg):
         self.assertTrue(
-            msg in exception.message,
+            msg in exception._message,
             "Exception tree doesn't contain the expected message: %s" % msg,
         )
 
     def _assert_exception_tree_contains_msg_default(self, exception, msg):
         e = exception
-        contains = msg in e.desc
-        while e.cause is not None and not contains:
-            e = e.cause
-            contains = msg in e.desc
+        contains = msg in e._desc
+        while e._cause is not None and not contains:
+            e = e._cause
+            contains = msg in e._desc
         self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)
 
     def test_query_manager_get(self):
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index ebdab31ec207..ccfadf0b8dbd 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -1739,8 +1739,8 @@ class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
         try:
             df.select(sha2(df.a, 1024)).collect()
         except IllegalArgumentException as e:
-            self.assertRegex(e.desc, "1024 is not in the permitted values")
-            self.assertRegex(e.stackTrace, "org.apache.spark.sql.functions")
+            self.assertRegex(e._desc, "1024 is not in the permitted values")
+            self.assertRegex(e._stackTrace, "org.apache.spark.sql.functions")
 
     def test_get_error_class_state(self):
         # SPARK-36953: test CapturedException.getErrorClass and getSqlState (from SparkThrowable)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
