This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new b6b450927ec8 [SPARK-46317][PYTHON][CONNECT] Match minor behaviours in SparkSession with full test coverage
b6b450927ec8 is described below
commit b6b450927ec8139ab9b19442023178f308ada9cb
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Dec 8 15:10:52 2023 +0900
[SPARK-46317][PYTHON][CONNECT] Match minor behaviours in SparkSession with full test coverage
### What changes were proposed in this pull request?
This PR matches corner-case behaviours of `SparkSession` between Spark Connect and
non-Spark Connect, and adds unit tests that give full test coverage of
`pyspark.sql.session`.
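For illustration, a minimal sketch of one behaviour that is now consistent across both backends (the local session setup below is an assumption for the example, not part of this change):

    from pyspark.sql import SparkSession
    from pyspark.errors import PySparkTypeError

    spark = SparkSession.builder.master("local").getOrCreate()

    try:
        # An int is not a valid schema; both classic Spark and Spark Connect
        # now raise the same typed error for it.
        spark.createDataFrame([(1, 2)], schema=123)
    except PySparkTypeError as e:
        print(e.getErrorClass())  # NOT_LIST_OR_NONE_OR_STRUCT

Passing a DataFrame as `data` is likewise rejected with INVALID_TYPE on both backends (see the new tests below).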
### Why are the changes needed?
- For feature parity.
- To improve the test coverage. See
https://app.codecov.io/gh/apache/spark/blob/master/python%2Fpyspark%2Fsql%2Fsession.py
(these code paths are not currently tested).
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Manually ran the new unit tests.
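(For reference, the touched modules can also be run locally with Spark's Python test runner, e.g. `python/run-tests --testnames 'pyspark.sql.tests.test_session'`, assuming a standard Spark development setup.)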
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44247 from HyukjinKwon/SPARK-46317.
Lead-authored-by: Hyukjin Kwon <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/errors/error_classes.py | 5 --
python/pyspark/sql/connect/session.py | 9 ++++
python/pyspark/sql/session.py | 4 +-
python/pyspark/sql/tests/test_dataframe.py | 29 +++++++++++
python/pyspark/sql/tests/test_session.py | 84 +++++++++++++++++++++++++++++-
python/pyspark/sql/tests/test_types.py | 10 ++++
6 files changed, 133 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index cc8400270967..d2d7f3148f4c 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -848,11 +848,6 @@ ERROR_CLASSES_JSON = """
"SparkContext or SparkSession should be created first.."
]
},
- "SHOULD_NOT_DATAFRAME": {
- "message": [
- "Argument `<arg_name>` should not be a DataFrame."
- ]
- },
"SLICE_WITH_STEP" : {
"message" : [
"Slice with step is not supported."
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 0fcd85c033cf..a27e6fa4b729 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -370,6 +370,15 @@ class SparkSession:
_cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
_num_cols = len(_cols)
+ elif schema is not None:
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+ message_parameters={
+ "arg_name": "schema",
+ "arg_type": type(schema).__name__,
+ },
+ )
+
if isinstance(data, np.ndarray) and data.ndim not in [1, 2]:
raise PySparkValueError(
error_class="INVALID_NDARRAY_DIMENSION",
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 86aacfa54c6e..7615491a1778 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1417,8 +1417,8 @@ class SparkSession(SparkConversionMixin):
self._jvm.SparkSession.setActiveSession(self._jsparkSession)
if isinstance(data, DataFrame):
raise PySparkTypeError(
- error_class="SHOULD_NOT_DATAFRAME",
- message_parameters={"arg_name": "data"},
+ error_class="INVALID_TYPE",
+ message_parameters={"arg_name": "data", "data_type": "DataFrame"},
)
if isinstance(schema, str):
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index c25fe60ad174..e1df01116e18 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1913,6 +1913,35 @@ class DataFrameTestsMixin:
self.assertEqual(df.schema, schema)
self.assertEqual(df.collect(), data)
+ def test_partial_inference_failure(self):
+ with self.assertRaises(PySparkValueError) as pe:
+ self.spark.createDataFrame([(None, 1)])
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_DETERMINE_TYPE",
+ message_parameters={},
+ )
+
+ def test_invalid_argument_create_dataframe(self):
+ with self.assertRaises(PySparkTypeError) as pe:
+ self.spark.createDataFrame([(1, 2)], schema=123)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+ message_parameters={"arg_name": "schema", "arg_type": "int"},
+ )
+
+ with self.assertRaises(PySparkTypeError) as pe:
+ self.spark.createDataFrame(self.spark.range(1))
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="INVALID_TYPE",
+ message_parameters={"arg_name": "data", "data_type": "DataFrame"},
+ )
+
class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
# These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py
index ba1d999ff7ba..f857e827895e 100644
--- a/python/pyspark/sql/tests/test_session.py
+++ b/python/pyspark/sql/tests/test_session.py
@@ -18,6 +18,8 @@
import os
import unittest
import unittest.mock
+from io import StringIO
+from lxml import etree
from pyspark import SparkConf, SparkContext
from pyspark.errors import PySparkRuntimeError
@@ -74,8 +76,17 @@ class SparkSessionTests2(PySparkTestCase):
spark.stop()
-class SparkSessionTests3(unittest.TestCase):
+class SparkSessionTests3(unittest.TestCase, PySparkErrorTestUtils):
def test_active_session(self):
+ with self.assertRaises(PySparkRuntimeError) as pe1:
+ SparkSession.active()
+
+ self.check_error(
+ exception=pe1.exception,
+ error_class="NO_ACTIVE_OR_DEFAULT_SESSION",
+ message_parameters={},
+ )
+
spark = SparkSession.builder.master("local").getOrCreate()
try:
activeSession = SparkSession.getActiveSession()
@@ -109,6 +120,11 @@ class SparkSessionTests3(unittest.TestCase):
self.assertEqual(spark.table("table1").columns, ["name", "age"])
self.assertEqual(spark.range(3).count(), 3)
+ try:
+ etree.parse(StringIO(spark._repr_html_()), etree.HTMLParser(recover=False))
+ except Exception as e:
+ self.fail(f"Generated HTML from `_repr_html_` was invalid:
{e}")
+
# SPARK-37516: Only plain column references work as variable in SQL.
self.assertEqual(
spark.sql("select {c} from range(1)", c=col("id")).first(),
spark.range(1).first()
@@ -163,6 +179,10 @@ class SparkSessionTests3(unittest.TestCase):
finally:
newSession.stop()
+ def test_create_new_session_with_statement(self):
+ with SparkSession.builder.master("local").getOrCreate() as session:
+ session.range(5).collect()
+
def test_active_session_with_None_and_not_None_context(self):
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
@@ -194,6 +214,30 @@ class SparkSessionTests3(unittest.TestCase):
with self.assertRaisesRegex(RuntimeError, "Cannot create a Spark Connect session"):
SparkSession.builder.appName("test").getOrCreate()
+ def test_unsupported_api(self):
+ with SparkSession.builder.master("local").getOrCreate() as session:
+ unsupported = [
+ (lambda: session.client, "client"),
+ (session.addArtifacts, "addArtifact(s)"),
(lambda: session.copyFromLocalToFs("", ""), "copyFromLocalToFs"),
+ (lambda: session.interruptTag(""), "interruptTag"),
+ (lambda: session.interruptOperation(""), "interruptOperation"),
+ (lambda: session.addTag(""), "addTag"),
+ (lambda: session.removeTag(""), "removeTag"),
+ (session.getTags, "getTags"),
+ (session.clearTags, "clearTags"),
+ ]
+
+ for func, name in unsupported:
+ with self.assertRaises(PySparkRuntimeError) as pe1:
+ func()
+
+ self.check_error(
+ exception=pe1.exception,
+ error_class="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
+ message_parameters={"feature": f"SparkSession.{name}"},
+ )
+
class SparkSessionTests4(ReusedSQLTestCase):
def test_get_active_session_after_create_dataframe(self):
@@ -378,6 +422,44 @@ class SparkSessionBuilderTests(unittest.TestCase, PySparkErrorTestUtils):
},
)
+ def test_master_remote_conflicts(self):
+ with self.assertRaises(PySparkRuntimeError) as pe2:
+ SparkSession.builder.config("spark.master",
"1").config("spark.remote", "2")
+
+ self.check_error(
+ exception=pe2.exception,
+ error_class="CANNOT_CONFIGURE_SPARK_CONNECT_MASTER",
+ message_parameters={"connect_url": "2", "master_url": "1"},
+ )
+
+ try:
+ os.environ["SPARK_REMOTE"] = "2"
+ os.environ["SPARK_LOCAL_REMOTE"] = "2"
+ with self.assertRaises(PySparkRuntimeError) as pe2:
+ SparkSession.builder.config("spark.remote", "1")
+
+ self.check_error(
+ exception=pe2.exception,
+ error_class="CANNOT_CONFIGURE_SPARK_CONNECT",
+ message_parameters={
+ "new_url": "1",
+ "existing_url": "2",
+ },
+ )
+ finally:
+ del os.environ["SPARK_REMOTE"]
+ del os.environ["SPARK_LOCAL_REMOTE"]
+
+ def test_invalid_create(self):
+ with self.assertRaises(PySparkRuntimeError) as pe2:
+ SparkSession.builder.config("spark.remote", "local").create()
+
+ self.check_error(
+ exception=pe2.exception,
+ error_class="UNSUPPORTED_LOCAL_CONNECTION_STRING",
+ message_parameters={},
+ )
+
class SparkExtensionsTest(unittest.TestCase):
# These tests are separate because it uses 'spark.sql.extensions' which is
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 992abc8e82d9..4316e4962c9d 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -122,6 +122,16 @@ class TypesTestsMixin:
def test_infer_schema(self):
d = [Row(l=[], d={}, s=None), Row(l=[Row(a=1, b="s")], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
+
+ with self.assertRaises(PySparkTypeError) as pe:
+ self.spark.createDataFrame(rdd, schema=123)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_LIST_OR_NONE_OR_STRUCT",
+ message_parameters={"arg_name": "schema", "arg_type": "int"},
+ )
+
df = self.spark.createDataFrame(rdd)
self.assertEqual([], df.rdd.map(lambda r: r.l).first())
self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]