This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 6a3e77e32e2 [SPARK-42187][CONNECT][TESTS] Avoid using RemoteSparkSession.builder.getOrCreate in tests
6a3e77e32e2 is described below
commit 6a3e77e32e23de81c5ebdde9b1e1576af534d249
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Jan 26 13:33:07 2023 +0900
[SPARK-42187][CONNECT][TESTS] Avoid using RemoteSparkSession.builder.getOrCreate in tests
### What changes were proposed in this pull request?
This PR proposes to use `pyspark.sql.SparkSession.builder.getOrCreate` instead of
`pyspark.sql.connect.SparkSession.builder.getOrCreate` in tests.
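
For context, a minimal sketch of the unified entry point the tests now go through; the `sc://localhost` URL is illustrative and assumes a local Spark Connect server is running:

```python
from pyspark.sql import SparkSession

# The public builder dispatches to a Spark Connect session when a remote URL
# is configured, and to a regular JVM-backed session otherwise, so the same
# test code path can serve both modes.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
spark.range(5).show()
```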
### Why are the changes needed?
Because `pyspark.sql.connect.SparkSession.builder.getOrCreate` is supposed to be
internal, and it does not provide the unified handling of Spark sessions that
covers both regular PySpark sessions and Spark Connect sessions.
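
As a rough illustration of that unified handling (and of the `PYSPARK_NO_NAMESPACE_SHARE` escape hatch the tests below rely on), the shared `pyspark.sql` wrappers forward to the Connect implementation only when a remote session is active. A simplified sketch, not the exact implementation; the helper name is hypothetical:

```python
import os
from pyspark.sql.utils import is_remote

def should_dispatch_to_connect() -> bool:
    # Mirrors the gate added to try_remote_functions and friends in
    # python/pyspark/sql/utils.py: forward to pyspark.sql.connect only when
    # running against a remote session and the tests have not opted out.
    return is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ
```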
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
Unit tests fixed.
Closes #39743 from HyukjinKwon/cleanup-test.
Lead-authored-by: Hyukjin Kwon <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 4f60ebcf0f83b767e17199a4ffe1edc24862fcfa)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/tests/connect/test_connect_basic.py | 41 +++++++++++----------
.../sql/tests/connect/test_connect_function.py | 42 +++++++++-------------
python/pyspark/sql/utils.py | 8 ++---
python/pyspark/testing/connectutils.py | 20 ++++-------
python/pyspark/testing/pandasutils.py | 14 ++++----
5 files changed, 58 insertions(+), 67 deletions(-)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3f7494a6385..94aed9fcc30 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -17,12 +17,13 @@
import array
import datetime
+import os
import unittest
import shutil
import tempfile
from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.sql import SparkSession, Row
+from pyspark.sql import SparkSession as PySparkSession, Row
from pyspark.sql.types import (
StructType,
StructField,
@@ -33,9 +34,11 @@ from pyspark.sql.types import (
ArrayType,
Row,
)
-from pyspark.testing.utils import ReusedPySparkTestCase
-from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.connectutils import (
+ should_test_connect,
+ ReusedConnectTestCase,
+)
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
from pyspark.errors import (
SparkConnectException,
SparkConnectAnalysisException,
@@ -57,22 +60,25 @@ if should_test_connect:
from pyspark.sql.connect import functions as CF
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLTestUtils):
+class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
"""Parent test fixture class for all Spark Connect related
test cases."""
@classmethod
def setUpClass(cls):
- ReusedPySparkTestCase.setUpClass()
- cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
- cls.hive_available = True
- # Create the new Spark Session
- cls.spark = SparkSession(cls.sc)
+ super(SparkConnectSQLTestCase, cls).setUpClass()
+ # Disable the shared namespace so pyspark.sql.functions, etc. point to the regular
+ # PySpark libraries.
+ os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+
+ cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
+ cls.spark = PySparkSession._instantiatedSession
+ assert cls.spark is not None
+
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
cls.testDataStr = [Row(key=str(i)) for i in range(100)]
- cls.df = cls.sc.parallelize(cls.testData).toDF()
- cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()
+ cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF()
+ cls.df_text = cls.spark.sparkContext.parallelize(cls.testDataStr).toDF()
cls.tbl_name = "test_connect_basic_table_1"
cls.tbl_name2 = "test_connect_basic_table_2"
@@ -88,12 +94,12 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT
@classmethod
def tearDownClass(cls):
cls.spark_connect_clean_up_test_data()
- ReusedPySparkTestCase.tearDownClass()
+ cls.spark = cls.connect # Stopping Spark Connect closes the session in JVM at the server.
+ super(SparkConnectSQLTestCase, cls).tearDownClass()
+ del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
@classmethod
def spark_connect_load_test_data(cls):
- # Setup Remote Spark Session
- cls.connect = RemoteSparkSession.builder.remote().getOrCreate()
df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
# Since we might create multiple Spark sessions, we need to create global temporary view
# that is specifically maintained in the "global_temp" schema.
@@ -2596,8 +2602,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
getattr(df.write, f)()
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class ChannelBuilderTests(ReusedPySparkTestCase):
+class ChannelBuilderTests(unittest.TestCase):
def test_invalid_connection_strings(self):
invalid = [
"scc://host:12",
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index b74b1a9ee69..7042a7e8e6f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -14,45 +14,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import os
import unittest
-import tempfile
from pyspark.errors import PySparkTypeError
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession as PySparkSession
from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
-from pyspark.testing.pandasutils import PandasOnSparkTestCase
-from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.errors import SparkConnectAnalysisException, SparkConnectException
-if should_test_connect:
- from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
-
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class SparkConnectFuncTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLTestUtils):
- """Parent test fixture class for all Spark Connect related
- test cases."""
+class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
+ """These test cases exercise the interface to the proto plan
+ generation but do not call Spark."""
@classmethod
def setUpClass(cls):
- ReusedPySparkTestCase.setUpClass()
- cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
- cls.hive_available = True
- # Create the new Spark Session
- cls.spark = SparkSession(cls.sc)
- # Setup Remote Spark Session
- cls.connect = RemoteSparkSession.builder.remote().getOrCreate()
+ super(SparkConnectFunctionTests, cls).setUpClass()
+ # Disable the shared namespace so pyspark.sql.functions, etc. point to the regular
+ # PySpark libraries.
+ os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+ cls.connect = cls.spark # Switch Spark Connect session and regular PySpark session.
+ cls.spark = PySparkSession._instantiatedSession
+ assert cls.spark is not None
@classmethod
def tearDownClass(cls):
- ReusedPySparkTestCase.tearDownClass()
-
-
-class SparkConnectFunctionTests(SparkConnectFuncTestCase):
- """These test cases exercise the interface to the proto plan
- generation but do not call Spark."""
+ cls.spark = cls.connect # Stopping Spark Connect closes the session in JVM at the server.
+ super(SparkConnectFunctionTests, cls).tearDownClass()
+ del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
from pyspark.sql.dataframe import DataFrame as SDF
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 4f99a23b82d..b9b045541a6 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -151,7 +151,7 @@ def try_remote_functions(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
- if is_remote():
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
from pyspark.sql.connect import functions
return getattr(functions, f.__name__)(*args, **kwargs)
@@ -167,7 +167,7 @@ def try_remote_window(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
- if is_remote():
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
from pyspark.sql.connect.window import Window
return getattr(Window, f.__name__)(*args, **kwargs)
@@ -183,7 +183,7 @@ def try_remote_windowspec(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
- if is_remote():
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
from pyspark.sql.connect.window import WindowSpec
return getattr(WindowSpec, f.__name__)(*args, **kwargs)
@@ -199,7 +199,7 @@ def try_remote_observation(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
# TODO(SPARK-41527): Add the support of Observation.
- if is_remote():
+ if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
raise NotImplementedError()
return f(*args, **kwargs)
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 64934c763c3..210d525ade7 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -55,7 +55,6 @@ except ImportError as e:
googleapis_common_protos_requirement_message = str(e)
have_googleapis_common_protos = googleapis_common_protos_requirement_message is None
-connect_not_compiled_message = None
if (
have_pandas
and have_pyarrow
@@ -63,19 +62,7 @@ if (
and have_grpc_status
and have_googleapis_common_protos
):
- from pyspark.sql.connect import DataFrame
- from pyspark.sql.connect.plan import Read, Range, SQL
- from pyspark.testing.utils import search_jar
- from pyspark.sql.connect.session import SparkSession
-
- connect_jar = search_jar("connector/connect/server", "spark-connect-assembly-", "spark-connect")
- existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
- connect_url = "--remote sc://localhost"
- jars_args = "--jars %s" % connect_jar
- plugin_args = "--conf
spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
- os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join(
- [connect_url, jars_args, plugin_args, existing_args]
- )
+ connect_not_compiled_message = None
else:
connect_not_compiled_message = (
"Skipping all Spark Connect Python tests as the optional Spark Connect
project was "
@@ -94,6 +81,11 @@ connect_requirement_message = (
)
should_test_connect: str = typing.cast(str, connect_requirement_message is None)
+if should_test_connect:
+ from pyspark.sql.connect import DataFrame
+ from pyspark.sql.connect.plan import Read, Range, SQL
+ from pyspark.sql.connect.session import SparkSession
+
class MockRemoteSession:
def __init__(self):
diff --git a/python/pyspark/testing/pandasutils.py b/python/pyspark/testing/pandasutils.py
index 6a828f10026..202603ca5c0 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -54,12 +54,7 @@ except ImportError as e:
have_plotly = plotly_requirement_message is None
-class PandasOnSparkTestCase(ReusedSQLTestCase):
- @classmethod
- def setUpClass(cls):
- super(PandasOnSparkTestCase, cls).setUpClass()
- cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
-
+class PandasOnSparkTestUtils:
def convert_str_to_lambda(self, func):
"""
This function converts `func` str to lambda call
@@ -248,6 +243,13 @@ class PandasOnSparkTestCase(ReusedSQLTestCase):
return obj
+class PandasOnSparkTestCase(ReusedSQLTestCase, PandasOnSparkTestUtils):
+ @classmethod
+ def setUpClass(cls):
+ super(PandasOnSparkTestCase, cls).setUpClass()
+ cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
+
+
class TestUtils:
@contextmanager
def temp_dir(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]