This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c8b1e526b4a [SPARK-41897][CONNECT][TESTS] Enable tests with error mismatch in connect/test_parity_functions.py
c8b1e526b4a is described below

commit c8b1e526b4a980550c4eeab541f7cd3d5aa6e0f2
Author: Sandeep Singh <[email protected]>
AuthorDate: Sat Jan 28 15:02:20 2023 +0900

    [SPARK-41897][CONNECT][TESTS] Enable tests with error mismatch in connect/test_parity_functions.py
    
    ### What changes were proposed in this pull request?
    Fix tests with error mismatch in connect/test_parity_functions.py
    
    ### Why are the changes needed?
    Test coverage
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Enabling the previously skipped tests
    
    Closes #39450 from techaddict/SPARK-41897.
    
    Authored-by: Sandeep Singh <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
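The diff below converts backend-specific assertions into ones that accept either backend's exception type, and it uses assertRaisesRegex to match the error message instead of separate assertIn checks. A minimal standalone sketch of that unittest pattern, with stand-in exception classes rather than Spark's real py4j/connect errors:

    import unittest

    # Stand-in exception types; the real tests use py4j.protocol.Py4JJavaError
    # (classic PySpark) and pyspark.errors.SparkConnectException (Spark Connect).
    class ClassicBackendError(Exception):
        pass

    class ConnectBackendError(Exception):
        pass

    class CrossBackendErrorTest(unittest.TestCase):
        def test_accepts_either_backend_error(self):
            # assertRaisesRegex accepts a tuple of exception classes, so a
            # single assertion passes against either backend as long as the
            # message matches the given regular expression.
            with self.assertRaisesRegex(
                (ClassicBackendError, ConnectBackendError), "too big"
            ):
                raise ConnectBackendError("value was too big")

    if __name__ == "__main__":
        unittest.main()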
---
 .../sql/tests/connect/test_parity_functions.py     | 10 -----
 python/pyspark/sql/tests/test_functions.py         | 44 ++++++++++------------
 2 files changed, 19 insertions(+), 35 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index a7bb987e221..b151986cb24 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -22,11 +22,6 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
-    # TODO(SPARK-41897): Parity in Error types between pyspark and connect functions
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_assert_true(self):
-        super().test_assert_true()
-
     @unittest.skip("Spark Connect does not support Spark Context but the test 
depends on that.")
     def test_basic_functions(self):
         super().test_basic_functions()
@@ -54,11 +49,6 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
     def test_lit_np_scalar(self):
         super().test_lit_np_scalar()
 
-    # TODO(SPARK-41897): Parity in Error types between pyspark and connect functions
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_raise_error(self):
-        super().test_raise_error()
-
     # Comparing column type of connect and pyspark
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_sorting_functions_with_column(self):
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 67d5e6bb739..f7f15d8e2c1 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -25,7 +25,7 @@ import math
 import unittest
 
 from py4j.protocol import Py4JJavaError
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkTypeError, PySparkValueError, SparkConnectException
 from pyspark.sql import Row, Window, types
 from pyspark.sql.functions import (
     udf,
@@ -1015,48 +1015,42 @@ class FunctionsTestsMixin:
             [Row(val=None), Row(val=None), Row(val=None)],
         )
 
-        with self.assertRaises(Py4JJavaError) as cm:
+        with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "too big"):
             df.select(assert_true(df.id < 2, "too big")).toDF("val").collect()
-        self.assertIn("java.lang.RuntimeException", str(cm.exception))
-        self.assertIn("too big", str(cm.exception))
 
-        with self.assertRaises(Py4JJavaError) as cm:
+        with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "2000000"):
             df.select(assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect()
-        self.assertIn("java.lang.RuntimeException", str(cm.exception))
-        self.assertIn("2000000", str(cm.exception))
 
-        with self.assertRaises(PySparkTypeError) as pe:
+        with self.assertRaises((PySparkTypeError, TypeError)) as pe:
             df.select(assert_true(df.id < 2, 5))
 
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_COLUMN_OR_STRING",
-            message_parameters={"arg_name": "errMsg", "arg_type": "int"},
-        )
+        if isinstance(pe.exception, PySparkTypeError):
+            self.check_error(
+                exception=pe.exception,
+                error_class="NOT_COLUMN_OR_STRING",
+                message_parameters={"arg_name": "errMsg", "arg_type": "int"},
+            )
 
     def test_raise_error(self):
         from pyspark.sql.functions import raise_error
 
         df = self.spark.createDataFrame([Row(id="foobar")])
 
-        with self.assertRaises(Py4JJavaError) as cm:
+        with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "foobar"):
             df.select(raise_error(df.id)).collect()
-        self.assertIn("java.lang.RuntimeException", str(cm.exception))
-        self.assertIn("foobar", str(cm.exception))
 
-        with self.assertRaises(Py4JJavaError) as cm:
+        with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "barfoo"):
             df.select(raise_error("barfoo")).collect()
-        self.assertIn("java.lang.RuntimeException", str(cm.exception))
-        self.assertIn("barfoo", str(cm.exception))
 
-        with self.assertRaises(PySparkTypeError) as pe:
+        with self.assertRaises((PySparkTypeError, TypeError)) as pe:
             df.select(raise_error(None))
 
-        self.check_error(
-            exception=pe.exception,
-            error_class="NOT_COLUMN_OR_STRING",
-            message_parameters={"arg_name": "errMsg", "arg_type": "NoneType"},
-        )
+        if isinstance(pe.exception, PySparkTypeError):
+            self.check_error(
+                exception=pe.exception,
+                error_class="NOT_COLUMN_OR_STRING",
+                message_parameters={"arg_name": "errMsg", "arg_type": 
"NoneType"},
+            )
 
     def test_sum_distinct(self):
         self.spark.range(10).select(
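The guarded check_error calls in the diff rely on unittest's assertRaises context manager, which exposes the raised exception on its .exception attribute; the isinstance guard must inspect that attribute, not the context-manager object itself. A minimal standalone sketch of the pattern, using generic exception types rather than Spark's:

    import unittest

    class GuardedCheckTest(unittest.TestCase):
        def test_guarded_detail_check(self):
            # assertRaises with a tuple accepts either exception type and
            # returns a context-manager object.
            with self.assertRaises((ValueError, TypeError)) as ctx:
                raise ValueError("bad input")
            # The raised exception lives on ctx.exception; run the stricter
            # follow-up check only when that stricter type was actually raised.
            if isinstance(ctx.exception, ValueError):
                self.assertIn("bad input", str(ctx.exception))

    if __name__ == "__main__":
        unittest.main()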

