This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new b6b450927ec8 [SPARK-46317][PYTHON][CONNECT] Match minor behaviour matching in SparkSession with full test coverage
b6b450927ec8 is described below

commit b6b450927ec8139ab9b19442023178f308ada9cb
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Dec 8 15:10:52 2023 +0900

    [SPARK-46317][PYTHON][CONNECT] Match minor behaviour matching in SparkSession with full test coverage

    ### What changes were proposed in this pull request?

    This PR matches the corner-case behaviours of `SparkSession` between Spark Connect and non-Spark Connect, adding unit tests for full test coverage of `pyspark.sql.session`.

    ### Why are the changes needed?

    - For feature parity.
    - To improve the test coverage. See https://app.codecov.io/gh/apache/spark/blob/master/python%2Fpyspark%2Fsql%2Fsession.py - this is not currently tested.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Manually ran the new unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44247 from HyukjinKwon/SPARK-46317.

    Lead-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py     |  5 --
 python/pyspark/sql/connect/session.py      |  9 ++++
 python/pyspark/sql/session.py              |  4 +-
 python/pyspark/sql/tests/test_dataframe.py | 29 +++++++++++
 python/pyspark/sql/tests/test_session.py   | 84 +++++++++++++++++++++++++++++-
 python/pyspark/sql/tests/test_types.py     | 10 ++++
 6 files changed, 133 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index cc8400270967..d2d7f3148f4c 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -848,11 +848,6 @@ ERROR_CLASSES_JSON = """
       "SparkContext or SparkSession should be created first.."
     ]
   },
-  "SHOULD_NOT_DATAFRAME": {
-    "message": [
-      "Argument `<arg_name>` should not be a DataFrame."
-    ]
-  },
   "SLICE_WITH_STEP" : {
     "message" : [
       "Slice with step is not supported."
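[Editor's note, not part of the patch: together with the `session.py` change below, removing `SHOULD_NOT_DATAFRAME` means that passing a `DataFrame` as `data` to `createDataFrame` is now reported under the generic `INVALID_TYPE` error class. A minimal sketch of the resulting behaviour, assuming a local PySpark build that includes this commit:

    from pyspark.errors import PySparkTypeError
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local").getOrCreate()

    try:
        # A DataFrame is not valid input data for createDataFrame; the
        # failure now surfaces as INVALID_TYPE rather than SHOULD_NOT_DATAFRAME.
        spark.createDataFrame(spark.range(1))
    except PySparkTypeError as e:
        print(e.getErrorClass())  # INVALID_TYPE
]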
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 0fcd85c033cf..a27e6fa4b729 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -370,6 +370,15 @@ class SparkSession:
             _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
             _num_cols = len(_cols)
 
+        elif schema is not None:
+            raise PySparkTypeError(
+                error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+                message_parameters={
+                    "arg_name": "schema",
+                    "arg_type": type(schema).__name__,
+                },
+            )
+
         if isinstance(data, np.ndarray) and data.ndim not in [1, 2]:
             raise PySparkValueError(
                 error_class="INVALID_NDARRAY_DIMENSION",
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 86aacfa54c6e..7615491a1778 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1417,8 +1417,8 @@ class SparkSession(SparkConversionMixin):
         self._jvm.SparkSession.setActiveSession(self._jsparkSession)
         if isinstance(data, DataFrame):
             raise PySparkTypeError(
-                error_class="SHOULD_NOT_DATAFRAME",
-                message_parameters={"arg_name": "data"},
+                error_class="INVALID_TYPE",
+                message_parameters={"arg_name": "data", "data_type": "DataFrame"},
             )
 
         if isinstance(schema, str):
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index c25fe60ad174..e1df01116e18 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1913,6 +1913,35 @@ class DataFrameTestsMixin:
         self.assertEqual(df.schema, schema)
         self.assertEqual(df.collect(), data)
 
+    def test_partial_inference_failure(self):
+        with self.assertRaises(PySparkValueError) as pe:
+            self.spark.createDataFrame([(None, 1)])
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_DETERMINE_TYPE",
+            message_parameters={},
+        )
+
+    def test_invalid_argument_create_dataframe(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.createDataFrame([(1, 2)], schema=123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+            message_parameters={"arg_name": "schema", "arg_type": "int"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.createDataFrame(self.spark.range(1))
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "data", "data_type": "DataFrame"},
+        )
+
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index ba1d999ff7ba..f857e827895e 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -18,6 +18,8 @@
 import os
 import unittest
 import unittest.mock
+from io import StringIO
+from lxml import etree
 
 from pyspark import SparkConf, SparkContext
 from pyspark.errors import PySparkRuntimeError
@@ -74,8 +76,17 @@ class SparkSessionTests2(PySparkTestCase):
             spark.stop()
 
 
-class SparkSessionTests3(unittest.TestCase):
+class SparkSessionTests3(unittest.TestCase, PySparkErrorTestUtils):
     def test_active_session(self):
+        with self.assertRaises(PySparkRuntimeError) as pe1:
+            SparkSession.active()
+
+        self.check_error(
+            exception=pe1.exception,
+            error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
+            message_parameters={},
+        )
+
         spark = SparkSession.builder.master("local").getOrCreate()
         try:
             activeSession = SparkSession.getActiveSession()
@@ -109,6 +120,11 @@ class SparkSessionTests3(unittest.TestCase):
         self.assertEqual(spark.table("table1").columns, ["name", "age"])
         self.assertEqual(spark.range(3).count(), 3)
 
+        try:
+            etree.parse(StringIO(spark._repr_html_()), etree.HTMLParser(recover=False))
+        except Exception as e:
+            self.fail(f"Generated HTML from `_repr_html_` was invalid: {e}")
+
         # SPARK-37516: Only plain column references work as variable in SQL.
         self.assertEqual(
             spark.sql("select {c} from range(1)", c=col("id")).first(), spark.range(1).first()
@@ -163,6 +179,10 @@ class SparkSessionTests3(unittest.TestCase):
         finally:
             newSession.stop()
 
+    def test_create_new_session_with_statement(self):
+        with SparkSession.builder.master("local").getOrCreate() as session:
+            session.range(5).collect()
+
     def test_active_session_with_None_and_not_None_context(self):
         from pyspark.context import SparkContext
         from pyspark.conf import SparkConf
@@ -194,6 +214,30 @@ class SparkSessionTests3(unittest.TestCase):
         with self.assertRaisesRegex(RuntimeError, "Cannot create a Spark Connect session"):
             SparkSession.builder.appName("test").getOrCreate()
 
+    def test_unsupported_api(self):
+        with SparkSession.builder.master("local").getOrCreate() as session:
+            unsupported = [
+                (lambda: session.client, "client"),
+                (session.addArtifacts, "addArtifact(s)"),
+                (lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
+                (lambda: session.interruptTag(""), "interruptTag"),
+                (lambda: session.interruptOperation(""), "interruptOperation"),
+                (lambda: session.addTag(""), "addTag"),
+                (lambda: session.removeTag(""), "removeTag"),
+                (session.getTags, "getTags"),
+                (session.clearTags, "clearTags"),
+            ]
+
+            for func, name in unsupported:
+                with self.assertRaises(PySparkRuntimeError) as pe1:
+                    func()
+
+                self.check_error(
+                    exception=pe1.exception,
+                    error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
+                    message_parameters={"feature": f"SparkSession.{name}"},
+                )
+
 
 class SparkSessionTests4(ReusedSQLTestCase):
     def test_get_active_session_after_create_dataframe(self):
@@ -378,6 +422,44 @@ class SparkSessionBuilderTests(unittest.TestCase, PySparkErrorTestUtils):
             },
         )
 
+    def test_master_remote_conflicts(self):
+        with self.assertRaises(PySparkRuntimeError) as pe2:
+            SparkSession.builder.config("spark.master", "1").config("spark.remote", "2")
+
+        self.check_error(
+            exception=pe2.exception,
+            error_class="CANNOT_CONFIGURE_SPARK_CONNECT_MASTER",
+            message_parameters={"connect_url": "2", "master_url": "1"},
+        )
+
+        try:
+            os.environ["SPARK_REMOTE"] = "2"
+            os.environ["SPARK_LOCAL_REMOTE"] = "2"
+            with self.assertRaises(PySparkRuntimeError) as pe2:
+                SparkSession.builder.config("spark.remote", "1")
+
+            self.check_error(
+                exception=pe2.exception,
+                error_class="CANNOT_CONFIGURE_SPARK_CONNECT",
+                message_parameters={
+                    "new_url": "1",
+                    "existing_url": "2",
+                },
+            )
+        finally:
+            del os.environ["SPARK_REMOTE"]
+            del os.environ["SPARK_LOCAL_REMOTE"]
+
+    def test_invalid_create(self):
+        with self.assertRaises(PySparkRuntimeError) as pe2:
+            SparkSession.builder.config("spark.remote", "local").create()
+
+        self.check_error(
+            exception=pe2.exception,
+            error_class="UNSUPPORTED_LOCAL_CONNECTION_STRING",
+            message_parameters={},
+        )
+
 
 class SparkExtensionsTest(unittest.TestCase):
     # These tests are separate because it uses 'spark.sql.extensions' which is
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 992abc8e82d9..4316e4962c9d 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -122,6 +122,16 @@ class TypesTestsMixin:
     def test_infer_schema(self):
         d = [Row(l=[], d={}, s=None), Row(l=[Row(a=1, b="s")], d={"key": Row(c=1.0, d="2")}, s="")]
         rdd = self.sc.parallelize(d)
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.createDataFrame(rdd, schema=123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+            message_parameters={"arg_name": "schema", "arg_type": "int"},
+        )
+
         df = self.spark.createDataFrame(rdd)
         self.assertEqual([], df.rdd.map(lambda r: r.l).first())
         self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
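[Editor's note, not part of the patch: a minimal usage sketch of the parity the new tests pin down, assuming a local session (the same code behaves identically against a Spark Connect session). It exercises the with-statement support checked by `test_create_new_session_with_statement` and the `NOT_LIST_OR_NONE_OR_STRUCT` error now raised for an invalid `schema` argument on both session implementations:

    from pyspark.errors import PySparkTypeError
    from pyspark.sql import SparkSession

    # SparkSession works as a context manager; stop() runs on exit.
    with SparkSession.builder.master("local").getOrCreate() as spark:
        try:
            # schema must be a StructType, a list/tuple of names, a DDL string, or None.
            spark.createDataFrame([(1, 2)], schema=123)
        except PySparkTypeError as e:
            print(e.getErrorClass())  # NOT_LIST_OR_NONE_OR_STRUCT
]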