This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new afa7f3d1bb8 [SPARK-43323][SQL][PYTHON] Fix DataFrame.toPandas with Arrow enabled to handle exceptions properly
afa7f3d1bb8 is described below

commit afa7f3d1bb865e319b0ca7e295a9c8bf4106ae0a
Author: Takuya UESHIN <[email protected]>
AuthorDate: Tue May 2 08:16:51 2023 +0900

    [SPARK-43323][SQL][PYTHON] Fix DataFrame.toPandas with Arrow enabled to handle exceptions properly
    
    ### What changes were proposed in this pull request?
    
    Fixes `DataFrame.toPandas` with Arrow enabled to handle exceptions properly.
    
    ```py
    >>> spark.conf.set("spark.sql.ansi.enabled", True)
    >>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
    >>> spark.sql("select 1/0").toPandas()
    ...
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    == SQL(line 1, position 8) ==
    select 1/0
           ^^^
    
    ```
    
    ### Why are the changes needed?
    
    Currently `DataFrame.toPandas` doesn't properly capture exceptions raised in Spark.
    
    ```py
    >>> spark.conf.set("spark.sql.ansi.enabled", True)
    >>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
    >>> spark.sql("select 1/0").toPandas()
    ...
      An error occurred while calling o53.getResult.
    : org.apache.spark.SparkException: Exception thrown in awaitResult:
            at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:322)
    ...
    ```
    
    This happens because `jsocket_auth_server.getResult()` always wraps the thrown exceptions in a `SparkException`, which won't be captured and converted on the Python side.
    
    Whereas without Arrow:
    
    ```py
    >>> spark.conf.set("spark.sql.ansi.enabled", True)
    >>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
    >>> spark.sql("select 1/0").toPandas()
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    == SQL(line 1, position 8) ==
    select 1/0
           ^^^
    ```
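    
    The fix wraps the `jsocket_auth_server.getResult()` call in a small context manager that unwraps the `SparkException` and re-raises its converted cause. A condensed sketch of that pattern (the full version is the `unwrap_spark_exception` helper added to `captured.py` below; `convert_exception`, `UnknownException`, and py4j's `is_instance_of` are existing helpers):
    
    ```py
    from contextlib import contextmanager
    from typing import Any, Iterator
    
    from py4j.protocol import Py4JJavaError
    from py4j.java_gateway import is_instance_of
    
    from pyspark import SparkContext
    from pyspark.errors.exceptions.captured import UnknownException, convert_exception
    
    @contextmanager
    def unwrap_spark_exception() -> Iterator[Any]:
        gw = SparkContext._gateway
        assert gw is not None  # a JVM gateway must already be running
        try:
            yield
        except Py4JJavaError as e:
            je = e.java_exception
            # getResult() surfaces the real error wrapped in org.apache.spark.SparkException;
            # convert its cause and raise that instead of the opaque wrapper.
            if je is not None and is_instance_of(gw, je, "org.apache.spark.SparkException"):
                converted = convert_exception(je.getCause())
                if not isinstance(converted, UnknownException):
                    raise converted from None
            raise
    ```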
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `DataFrame.toPandas` with Arrow enabled will now raise the proper exception instead of a wrapped `SparkException`.
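    
    A minimal usage sketch, mirroring the added test (assumes an active `SparkSession` named `spark`):
    
    ```py
    from pyspark.errors import ArithmeticException
    
    spark.conf.set("spark.sql.ansi.enabled", True)
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", True)
    
    try:
        spark.sql("select 1/0").toPandas()
    except ArithmeticException as e:
        # Previously this surfaced as an unhandled Py4JJavaError wrapping
        # org.apache.spark.SparkException; now the DIVIDE_BY_ZERO error is catchable.
        print("caught:", e)
    ```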
    
    ### How was this patch tested?
    
    Added the related test.
    
    Closes #40998 from ueshin/issues/SPARK-43323/getResult.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/errors/exceptions/captured.py         | 20 ++++++++++++++++++--
 python/pyspark/sql/pandas/conversion.py              |  6 ++++--
 .../pyspark/sql/tests/connect/test_parity_arrow.py   |  3 +++
 python/pyspark/sql/tests/test_arrow.py               | 17 ++++++++++++++++-
 4 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index d1b57997f99..5b008f4ab00 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -14,8 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from typing import Any, Callable, Dict, Optional, cast
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Iterator, Optional, cast
 
 import py4j
 from py4j.protocol import Py4JJavaError
@@ -186,6 +186,22 @@ def capture_sql_exception(f: Callable[..., Any]) -> Callable[..., Any]:
     return deco
 
 
+@contextmanager
+def unwrap_spark_exception() -> Iterator[Any]:
+    assert SparkContext._gateway is not None
+
+    gw = SparkContext._gateway
+    try:
+        yield
+    except Py4JJavaError as e:
+        je: Py4JJavaError = e.java_exception
+        if je is not None and is_instance_of(gw, je, "org.apache.spark.SparkException"):
+            converted = convert_exception(je.getCause())
+            if not isinstance(converted, UnknownException):
+                raise converted from None
+        raise
+
+
 def install_exception_handler() -> None:
     """
     Hook an exception handler into Py4j, which could capture some SQL exceptions in Java.
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index a5f0664ed75..ce0143d1851 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -19,6 +19,7 @@ from collections import Counter
 from typing import List, Optional, Type, Union, no_type_check, overload, TYPE_CHECKING
 from warnings import catch_warnings, simplefilter, warn
 
+from pyspark.errors.exceptions.captured import unwrap_spark_exception
 from pyspark.rdd import _load_from_socket
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.types import (
@@ -357,8 +358,9 @@ class PandasConversionMixin:
             else:
                 results = list(batch_stream)
         finally:
-            # Join serving thread and raise any exceptions from collectAsArrowToPython
-            jsocket_auth_server.getResult()
+            with unwrap_spark_exception():
+                # Join serving thread and raise any exceptions from collectAsArrowToPython
+                jsocket_auth_server.getResult()
 
         # Separate RecordBatches from batch order indices in results
         batches = results[:-1]
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index fd05821f052..f2fa9ece4df 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -103,6 +103,9 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_timestamp_nat(self):
         self.check_timestamp_nat(True)
 
+    def test_toPandas_error(self):
+        self.check_toPandas_error(True)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_arrow import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 518e17d57b6..84c782e8d95 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -55,7 +55,7 @@ from pyspark.testing.sqlutils import (
     pyarrow_requirement_message,
 )
 from pyspark.testing.utils import QuietTest
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import ArithmeticException, PySparkTypeError
 
 if have_pandas:
     import pandas as pd
@@ -873,6 +873,21 @@ class ArrowTestsMixin:
         self.assertEqual([Row(c1=1, c2="string")], df.collect())
         self.assertGreater(self.spark.sparkContext.defaultParallelism, len(pdf))
 
+    def test_toPandas_error(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_toPandas_error(arrow_enabled)
+
+    def check_toPandas_error(self, arrow_enabled):
+        with self.sql_conf(
+            {
+                "spark.sql.ansi.enabled": True,
+                "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled,
+            }
+        ):
+            with self.assertRaises(ArithmeticException):
+                self.spark.sql("select 1/0").toPandas()
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,

