This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b6b450927ec8 [SPARK-46317][PYTHON][CONNECT] Match minor behaviours in SparkSession with full test coverage
b6b450927ec8 is described below

commit b6b450927ec8139ab9b19442023178f308ada9cb
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Dec 8 15:10:52 2023 +0900

    [SPARK-46317][PYTHON][CONNECT] Match minor behaviours in SparkSession with full test coverage
    
    ### What changes were proposed in this pull request?
    
    This PR matches corner-case behaviours in `SparkSession` between Spark Connect and non-Spark Connect, and adds unit tests to reach full test coverage of `pyspark.sql.session`.
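    
    For example (an illustrative sketch mirroring the new tests below; the `local` master is an assumption, and a Spark Connect session is expected to behave the same way):
    
    ```python
    from pyspark.errors import PySparkTypeError
    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.master("local").getOrCreate()
    
    try:
        # `schema` must be a str, a list of column names, or a StructType.
        spark.createDataFrame([(1, 2)], schema=123)
    except PySparkTypeError as e:
        print(e.getErrorClass())  # NOT_LIST_OR_NONE_OR_STRUCT
    
    try:
        # Passing a DataFrame as `data` now raises the same error class in both modes.
        spark.createDataFrame(spark.range(1))
    except PySparkTypeError as e:
        print(e.getErrorClass())  # INVALID_TYPE
    ```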
    
    ### Why are the changes needed?
    
    - For feature parity.
    - To improve the test coverage.
        See https://app.codecov.io/gh/apache/spark/blob/master/python%2Fpyspark%2Fsql%2Fsession.py - several paths in this module are not currently covered by tests.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Manually ran the new unit tests.
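    
    A rough sketch of one way to run the touched test modules locally with the plain `unittest` runner (a working PySpark development environment is assumed; the exact invocation may differ):
    
    ```python
    # Illustrative only: load and run the test modules touched by this PR.
    import unittest
    
    suite = unittest.defaultTestLoader.loadTestsFromNames(
        [
            "pyspark.sql.tests.test_session",
            "pyspark.sql.tests.test_dataframe",
            "pyspark.sql.tests.test_types",
        ]
    )
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```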
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44247 from HyukjinKwon/SPARK-46317.
    
    Lead-authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py     |  5 --
 python/pyspark/sql/connect/session.py      |  9 ++++
 python/pyspark/sql/session.py              |  4 +-
 python/pyspark/sql/tests/test_dataframe.py | 29 +++++++++++
 python/pyspark/sql/tests/test_session.py   | 84 +++++++++++++++++++++++++++++-
 python/pyspark/sql/tests/test_types.py     | 10 ++++
 6 files changed, 133 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index cc8400270967..d2d7f3148f4c 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -848,11 +848,6 @@ ERROR_CLASSES_JSON = """
       "SparkContext or SparkSession should be created first.."
     ]
   },
-  "SHOULD_NOT_DATAFRAME": {
-    "message": [
-      "Argument `<arg_name>` should not be a DataFrame."
-    ]
-  },
   "SLICE_WITH_STEP" : {
     "message" : [
       "Slice with step is not supported."
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 0fcd85c033cf..a27e6fa4b729 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -370,6 +370,15 @@ class SparkSession:
             _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
             _num_cols = len(_cols)
 
+        elif schema is not None:
+            raise PySparkTypeError(
+                error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+                message_parameters={
+                    "arg_name": "schema",
+                    "arg_type": type(schema).__name__,
+                },
+            )
+
         if isinstance(data, np.ndarray) and data.ndim not in [1, 2]:
             raise PySparkValueError(
                 error_class="INVALID_NDARRAY_DIMENSION",
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 86aacfa54c6e..7615491a1778 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1417,8 +1417,8 @@ class SparkSession(SparkConversionMixin):
         self._jvm.SparkSession.setActiveSession(self._jsparkSession)
         if isinstance(data, DataFrame):
             raise PySparkTypeError(
-                error_class="SHOULD_NOT_DATAFRAME",
-                message_parameters={"arg_name": "data"},
+                error_class="INVALID_TYPE",
+                message_parameters={"arg_name": "data", "data_type": "DataFrame"},
             )
 
         if isinstance(schema, str):
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index c25fe60ad174..e1df01116e18 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1913,6 +1913,35 @@ class DataFrameTestsMixin:
         self.assertEqual(df.schema, schema)
         self.assertEqual(df.collect(), data)
 
+    def test_partial_inference_failure(self):
+        with self.assertRaises(PySparkValueError) as pe:
+            self.spark.createDataFrame([(None, 1)])
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_DETERMINE_TYPE",
+            message_parameters={},
+        )
+
+    def test_invalid_argument_create_dataframe(self):
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.createDataFrame([(1, 2)], schema=123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+            message_parameters={"arg_name": "schema", "arg_type": "int"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.createDataFrame(self.spark.range(1))
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="INVALID_TYPE",
+            message_parameters={"arg_name": "data", "data_type": "DataFrame"},
+        )
+
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index ba1d999ff7ba..f857e827895e 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -18,6 +18,8 @@
 import os
 import unittest
 import unittest.mock
+from io import StringIO
+from lxml import etree
 
 from pyspark import SparkConf, SparkContext
 from pyspark.errors import PySparkRuntimeError
@@ -74,8 +76,17 @@ class SparkSessionTests2(PySparkTestCase):
             spark.stop()
 
 
-class SparkSessionTests3(unittest.TestCase):
+class SparkSessionTests3(unittest.TestCase, PySparkErrorTestUtils):
     def test_active_session(self):
+        with self.assertRaises(PySparkRuntimeError) as pe1:
+            SparkSession.active()
+
+        self.check_error(
+            exception=pe1.exception,
+            error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
+            message_parameters={},
+        )
+
         spark = SparkSession.builder.master("local").getOrCreate()
         try:
             activeSession = SparkSession.getActiveSession()
@@ -109,6 +120,11 @@ class SparkSessionTests3(unittest.TestCase):
             self.assertEqual(spark.table("table1").columns, ["name", "age"])
             self.assertEqual(spark.range(3).count(), 3)
 
+            try:
+                etree.parse(StringIO(spark._repr_html_()), etree.HTMLParser(recover=False))
+            except Exception as e:
+                self.fail(f"Generated HTML from `_repr_html_` was invalid: {e}")
+
             # SPARK-37516: Only plain column references work as variable in SQL.
             self.assertEqual(
                 spark.sql("select {c} from range(1)", c=col("id")).first(), spark.range(1).first()
@@ -163,6 +179,10 @@ class SparkSessionTests3(unittest.TestCase):
         finally:
             newSession.stop()
 
+    def test_create_new_session_with_statement(self):
+        with SparkSession.builder.master("local").getOrCreate() as session:
+            session.range(5).collect()
+
     def test_active_session_with_None_and_not_None_context(self):
         from pyspark.context import SparkContext
         from pyspark.conf import SparkConf
@@ -194,6 +214,30 @@ class SparkSessionTests3(unittest.TestCase):
             with self.assertRaisesRegex(RuntimeError, "Cannot create a Spark Connect session"):
                 SparkSession.builder.appName("test").getOrCreate()
 
+    def test_unsupported_api(self):
+        with SparkSession.builder.master("local").getOrCreate() as session:
+            unsupported = [
+                (lambda: session.client, "client"),
+                (session.addArtifacts, "addArtifact(s)"),
+                (lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
+                (lambda: session.interruptTag(""), "interruptTag"),
+                (lambda: session.interruptOperation(""), "interruptOperation"),
+                (lambda: session.addTag(""), "addTag"),
+                (lambda: session.removeTag(""), "removeTag"),
+                (session.getTags, "getTags"),
+                (session.clearTags, "clearTags"),
+            ]
+
+            for func, name in unsupported:
+                with self.assertRaises(PySparkRuntimeError) as pe1:
+                    func()
+
+                self.check_error(
+                    exception=pe1.exception,
+                    error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
+                    message_parameters={"feature": f"SparkSession.{name}"},
+                )
+
 
 class SparkSessionTests4(ReusedSQLTestCase):
     def test_get_active_session_after_create_dataframe(self):
@@ -378,6 +422,44 @@ class SparkSessionBuilderTests(unittest.TestCase, PySparkErrorTestUtils):
                 },
             )
 
+    def test_master_remote_conflicts(self):
+        with self.assertRaises(PySparkRuntimeError) as pe2:
+            SparkSession.builder.config("spark.master", "1").config("spark.remote", "2")
+
+        self.check_error(
+            exception=pe2.exception,
+            error_class="CANNOT_CONFIGURE_SPARK_CONNECT_MASTER",
+            message_parameters={"connect_url": "2", "master_url": "1"},
+        )
+
+        try:
+            os.environ["SPARK_REMOTE"] = "2"
+            os.environ["SPARK_LOCAL_REMOTE"] = "2"
+            with self.assertRaises(PySparkRuntimeError) as pe2:
+                SparkSession.builder.config("spark.remote", "1")
+
+            self.check_error(
+                exception=pe2.exception,
+                error_class="CANNOT_CONFIGURE_SPARK_CONNECT",
+                message_parameters={
+                    "new_url": "1",
+                    "existing_url": "2",
+                },
+            )
+        finally:
+            del os.environ["SPARK_REMOTE"]
+            del os.environ["SPARK_LOCAL_REMOTE"]
+
+    def test_invalid_create(self):
+        with self.assertRaises(PySparkRuntimeError) as pe2:
+            SparkSession.builder.config("spark.remote", "local").create()
+
+        self.check_error(
+            exception=pe2.exception,
+            error_class="UNSUPPORTED_LOCAL_CONNECTION_STRING",
+            message_parameters={},
+        )
+
 
 class SparkExtensionsTest(unittest.TestCase):
     # These tests are separate because it uses 'spark.sql.extensions' which is
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 992abc8e82d9..4316e4962c9d 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -122,6 +122,16 @@ class TypesTestsMixin:
     def test_infer_schema(self):
         d = [Row(l=[], d={}, s=None), Row(l=[Row(a=1, b="s")], d={"key": Row(c=1.0, d="2")}, s="")]
         rdd = self.sc.parallelize(d)
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            self.spark.createDataFrame(rdd, schema=123)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+            message_parameters={"arg_name": "schema", "arg_type": "int"},
+        )
+
         df = self.spark.createDataFrame(rdd)
         self.assertEqual([], df.rdd.map(lambda r: r.l).first())
         self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
