This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c8b1e526b4a [SPARK-41897][CONNECT][TESTS] Enable tests with error mismatch in connect/test_parity_functions.py
c8b1e526b4a is described below
commit c8b1e526b4a980550c4eeab541f7cd3d5aa6e0f2
Author: Sandeep Singh <[email protected]>
AuthorDate: Sat Jan 28 15:02:20 2023 +0900
[SPARK-41897][CONNECT][TESTS] Enable tests with error mismatch in connect/test_parity_functions.py
### What changes were proposed in this pull request?
Fix the tests in connect/test_parity_functions.py that were skipped due to error-type mismatches between classic PySpark and Spark Connect, by letting the shared tests accept either error type.
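The shared tests now accept either the JVM-side `Py4JJavaError` or the client-side `SparkConnectException`, so the same assertion works in classic PySpark and in Spark Connect. A minimal sketch of the pattern, assuming a test case that exposes a `self.spark` session (the DataFrame, message, and helper name below are illustrative only):

```python
from py4j.protocol import Py4JJavaError

from pyspark.errors import SparkConnectException
from pyspark.sql.functions import assert_true


def check_assert_true_errors(self):
    # Hypothetical helper on a test case that provides a `self.spark` session.
    df = self.spark.range(3)
    # assertRaisesRegex accepts a tuple of exception classes, so the test no
    # longer depends on which concrete error type the backend raises, only
    # that the error message matches.
    with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "too big"):
        df.select(assert_true(df.id < 2, "too big")).collect()
```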
### Why are the changes needed?
Test coverage.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Enabled the previously skipped tests.
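Concretely, test_parity_functions.py no longer overrides `test_assert_true` and `test_raise_error` with `@unittest.skip`, so the versions inherited from `FunctionsTestsMixin` now run against Spark Connect. A simplified sketch of that inheritance shape (the `unittest.TestCase` base here stands in for the real `ReusedConnectTestCase`):

```python
import unittest


class FunctionsTestsMixin:
    # Shared test bodies live in the mixin and are reused by both suites.
    def test_assert_true(self):
        ...

    def test_raise_error(self):
        ...


class FunctionsParityTests(FunctionsTestsMixin, unittest.TestCase):
    # With the skip overrides removed, the mixin's tests run as-is against
    # the session provided by the real base class.
    pass
```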
Closes #39450 from techaddict/SPARK-41897.
Authored-by: Sandeep Singh <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/tests/connect/test_parity_functions.py | 10 -----
python/pyspark/sql/tests/test_functions.py | 44 ++++++++++------------
2 files changed, 19 insertions(+), 35 deletions(-)
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index a7bb987e221..b151986cb24 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -22,11 +22,6 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
- # TODO(SPARK-41897): Parity in Error types between pyspark and connect functions
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_assert_true(self):
- super().test_assert_true()
-
@unittest.skip("Spark Connect does not support Spark Context but the test
depends on that.")
def test_basic_functions(self):
super().test_basic_functions()
@@ -54,11 +49,6 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
def test_lit_np_scalar(self):
super().test_lit_np_scalar()
- # TODO(SPARK-41897): Parity in Error types between pyspark and connect functions
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_raise_error(self):
- super().test_raise_error()
-
# Comparing column type of connect and pyspark
@unittest.skip("Fails in Spark Connect, should enable.")
def test_sorting_functions_with_column(self):
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 67d5e6bb739..f7f15d8e2c1 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -25,7 +25,7 @@ import math
import unittest
from py4j.protocol import Py4JJavaError
-from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.errors import PySparkTypeError, PySparkValueError, SparkConnectException
from pyspark.sql import Row, Window, types
from pyspark.sql.functions import (
udf,
@@ -1015,48 +1015,42 @@ class FunctionsTestsMixin:
[Row(val=None), Row(val=None), Row(val=None)],
)
- with self.assertRaises(Py4JJavaError) as cm:
+ with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "too big"):
df.select(assert_true(df.id < 2, "too big")).toDF("val").collect()
- self.assertIn("java.lang.RuntimeException", str(cm.exception))
- self.assertIn("too big", str(cm.exception))
- with self.assertRaises(Py4JJavaError) as cm:
+ with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "2000000"):
df.select(assert_true(df.id < 2, df.id * 1e6)).toDF("val").collect()
- self.assertIn("java.lang.RuntimeException", str(cm.exception))
- self.assertIn("2000000", str(cm.exception))
- with self.assertRaises(PySparkTypeError) as pe:
+ with self.assertRaises((PySparkTypeError, TypeError)) as pe:
df.select(assert_true(df.id < 2, 5))
- self.check_error(
- exception=pe.exception,
- error_class="NOT_COLUMN_OR_STRING",
- message_parameters={"arg_name": "errMsg", "arg_type": "int"},
- )
+ if isinstance(pe, PySparkTypeError):
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_COLUMN_OR_STRING",
+ message_parameters={"arg_name": "errMsg", "arg_type": "int"},
+ )
def test_raise_error(self):
from pyspark.sql.functions import raise_error
df = self.spark.createDataFrame([Row(id="foobar")])
- with self.assertRaises(Py4JJavaError) as cm:
+ with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "foobar"):
df.select(raise_error(df.id)).collect()
- self.assertIn("java.lang.RuntimeException", str(cm.exception))
- self.assertIn("foobar", str(cm.exception))
- with self.assertRaises(Py4JJavaError) as cm:
+ with self.assertRaisesRegex((Py4JJavaError, SparkConnectException), "barfoo"):
df.select(raise_error("barfoo")).collect()
- self.assertIn("java.lang.RuntimeException", str(cm.exception))
- self.assertIn("barfoo", str(cm.exception))
- with self.assertRaises(PySparkTypeError) as pe:
+ with self.assertRaises((PySparkTypeError, TypeError)) as pe:
df.select(raise_error(None))
- self.check_error(
- exception=pe.exception,
- error_class="NOT_COLUMN_OR_STRING",
- message_parameters={"arg_name": "errMsg", "arg_type": "NoneType"},
- )
+ if isinstance(pe, PySparkTypeError):
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_COLUMN_OR_STRING",
+ message_parameters={"arg_name": "errMsg", "arg_type": "NoneType"},
+ )
def test_sum_distinct(self):
self.spark.range(10).select(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]