This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new b1998bd4145d  [SPARK-47725][INFRA] Set up the CI for pyspark-connect package
b1998bd4145d is described below

commit b1998bd4145d60d9ea3b569b64604a0881335b17
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Apr 8 09:39:06 2024 -0700

    [SPARK-47725][INFRA] Set up the CI for pyspark-connect package

    ### What changes were proposed in this pull request?

    This PR proposes to set up a scheduled CI job for the `pyspark-connect` package. The CI:

    1. Builds Spark.
    2. Packages `pyspark-connect` together with its test cases.
    3. Removes `python/lib/pyspark.zip` and `python/lib/py4j.zip` to make sure the tests cannot use a JVM.
    4. Runs the test cases packaged within `pyspark-connect`.

    ### Why are the changes needed?

    To verify the feature coverage of `pyspark-connect`.

    ### Does this PR introduce _any_ user-facing change?

    No, test-only.

    ### How was this patch tested?

    Manually tested in my fork: https://github.com/HyukjinKwon/spark/actions/runs/8598881063

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #45870 from HyukjinKwon/do-not-merge-ci.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
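Most of the hunks below apply one recurring change: every Connect doctest session is built against an endpoint taken from the environment rather than a hard-coded `local[4]`. As a standalone sketch of that pattern (the app name is illustrative, not from the patch):

```python
import os

from pyspark.sql import SparkSession

# Falls back to a local JVM ("local[4]") when the variable is unset; under
# the new CI it resolves to "sc://localhost", the Spark Connect server
# started by sbin/start-connect-server.sh before the tests run.
spark = (
    SparkSession.builder.appName("pyspark-connect smoke test")
    .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
    .getOrCreate()
)
```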
---
 .github/workflows/build_python_connect.yml          | 96 ++++++++++++++++++++++
 python/packaging/connect/setup.py                   | 16 +++-
 .../pyspark/pandas/data_type_ops/datetime_ops.py    |  3 +-
 python/pyspark/sql/connect/avro/functions.py        |  8 +-
 python/pyspark/sql/connect/catalog.py               |  5 +-
 python/pyspark/sql/connect/column.py                |  5 +-
 python/pyspark/sql/connect/conf.py                  |  5 +-
 python/pyspark/sql/connect/dataframe.py             |  2 +-
 python/pyspark/sql/connect/functions/builtin.py     |  3 +-
 .../pyspark/sql/connect/functions/partitioning.py   |  3 +-
 python/pyspark/sql/connect/group.py                 |  5 +-
 python/pyspark/sql/connect/observation.py           |  3 +-
 python/pyspark/sql/connect/protobuf/functions.py    |  8 +-
 python/pyspark/sql/connect/readwriter.py            |  3 +-
 python/pyspark/sql/connect/session.py               |  5 +-
 python/pyspark/sql/connect/streaming/query.py       |  2 +-
 python/pyspark/sql/connect/streaming/readwriter.py  |  3 +-
 python/pyspark/sql/connect/window.py                |  5 +-
 python/pyspark/sql/session.py                       |  2 +-
 .../sql/tests/connect/client/test_artifact.py       |  6 +-
 .../sql/tests/connect/client/test_reattach.py       |  2 +
 .../sql/tests/connect/test_connect_basic.py         |  2 +
 .../sql/tests/connect/test_connect_function.py      |  2 +
 .../sql/tests/connect/test_connect_session.py       |  6 +-
 .../sql/tests/connect/test_parity_udf_profiler.py   |  3 +
 .../pyspark/sql/tests/connect/test_parity_udtf.py   | 11 ++-
 python/pyspark/sql/tests/connect/test_resources.py  | 14 ++--
 python/pyspark/sql/tests/test_arrow.py              |  4 +-
 python/pyspark/sql/tests/test_udf.py                | 11 ++-
 python/pyspark/sql/tests/test_udtf.py               |  9 +-
 python/pyspark/tests/test_memory_profiler.py        |  4 +-
 31 files changed, 223 insertions(+), 33 deletions(-)

diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
new file mode 100644
index 000000000000..2f80eac9624f
--- /dev/null
+++ b/.github/workflows/build_python_connect.yml
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+name: Build / Spark Connect Python-only (master, Python 3.11)
+
+on:
+  schedule:
+    - cron: '0 19 * * *'
+
+jobs:
+  # Build: build Spark and run the tests for specified modules using SBT
+  build:
+    name: "Build modules: pyspark-connect"
+    runs-on: ubuntu-latest
+    timeout-minutes: 300
+    steps:
+      - name: Checkout Spark repository
+        uses: actions/checkout@v4
+      - name: Cache Scala, SBT and Maven
+        uses: actions/cache@v4
+        with:
+          path: |
+            build/apache-maven-*
+            build/scala-*
+            build/*.jar
+            ~/.sbt
+          key: build-spark-connect-python-only-${{ hashFiles('**/pom.xml', 'project/build.properties', 'build/mvn', 'build/sbt', 'build/sbt-launch-lib.bash', 'build/spark-build-info') }}
+          restore-keys: |
+            build-spark-connect-python-only-
+      - name: Cache Coursier local repository
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/coursier
+          key: coursier-build-spark-connect-python-only-${{ hashFiles('**/pom.xml') }}
+          restore-keys: |
+            coursier-build-spark-connect-python-only-
+      - name: Install Java 17
+        uses: actions/setup-java@v4
+        with:
+          distribution: zulu
+          java-version: 17
+      - name: Install Python 3.11
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+          architecture: x64
+      - name: Build Spark
+        run: |
+          ./build/sbt -Phive test:package
+      - name: Install pure Python package (pyspark-connect)
+        env:
+          SPARK_TESTING: 1
+        run: |
+          cd python
+          python packaging/connect/setup.py sdist
+          cd dist
+          pip install pyspark-connect-*.tar.gz
+      - name: Run tests
+        env:
+          SPARK_CONNECT_TESTING_REMOTE: sc://localhost
+          SPARK_TESTING: 1
+        run: |
+          # Start a Spark Connect server
+          ./sbin/start-connect-server.sh --jars `find connector/connect/server/target -name spark-connect*SNAPSHOT.jar`
+          # Remove Py4J and PySpark zipped library to make sure there is no JVM connection
+          rm python/lib/*
+          rm -r python/pyspark
+          ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect
+      - name: Upload test results to report
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: test-results-spark-connect-python-only
+          path: "**/target/test-reports/*.xml"
+      - name: Upload unit tests log files
+        if: failure()
+        uses: actions/upload-artifact@v4
+        with:
+          name: unit-tests-log-spark-connect-python-only
+          path: "**/target/unit-tests.log"
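The test-side counterpart of the `rm python/lib/*` step above is `pyspark.util.is_remote_only()`: with the zipped JVM-backed libraries gone, it reports that only the pure-Python distribution is present, and the hunks below use it to skip anything that needs py4j or a local JVM. A minimal sketch of that gating (the test class is hypothetical):

```python
import unittest

from pyspark.util import is_remote_only

# Pattern used throughout the patch: entire test classes are skipped when
# only the pure-Python pyspark-connect package (no py4j, no JVM) is installed.
@unittest.skipIf(is_remote_only(), "Requires JVM access")
class SomeJvmBackedTests(unittest.TestCase):  # hypothetical example class
    def test_needs_jvm(self) -> None:
        from pyspark import SparkContext  # this import would fail under pyspark-connect
```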
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
index 01c5518d4451..782c55fff241 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -65,6 +65,20 @@ in_spark = os.path.isfile("../core/src/main/scala/org/apache/spark/SparkContext.
     os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1
 )
 
+test_packages = []
+if "SPARK_TESTING" in os.environ:
+    test_packages = [
+        "pyspark.tests",  # for Memory profiler parity tests
+        "pyspark.testing",
+        "pyspark.sql.tests",
+        "pyspark.sql.tests.connect",
+        "pyspark.sql.tests.connect.streaming",
+        "pyspark.sql.tests.connect.client",
+        "pyspark.sql.tests.connect.shell",
+        "pyspark.sql.tests.pandas",
+        "pyspark.sql.tests.streaming",
+    ]
+
 try:
     if in_spark:
         copyfile("packaging/connect/setup.py", "setup.py")
@@ -136,7 +150,7 @@ try:
         author="Spark Developers",
         author_email="d...@spark.apache.org",
         url="https://github.com/apache/spark/tree/master/python",
-        packages=connect_packages,
+        packages=connect_packages + test_packages,
         license="http://www.apache.org/licenses/LICENSE-2.0",
         # Don't forget to update python/docs/source/getting_started/install.rst
         # if you're updating the versions or dependencies.
diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py
index 8d5853b68246..9b4cc72fa2e4 100644
--- a/python/pyspark/pandas/data_type_ops/datetime_ops.py
+++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py
@@ -23,7 +23,6 @@ import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype
 
-from pyspark import SparkContext
 from pyspark.sql import Column, functions as F
 from pyspark.sql.types import (
     BooleanType,
@@ -151,6 +150,8 @@ class DatetimeNTZOps(DatetimeOps):
     """
 
     def _cast_spark_column_timestamp_to_long(self, scol: Column) -> Column:
+        from pyspark import SparkContext
+
         jvm = SparkContext._active_spark_context._jvm
         return Column(jvm.PythonSQLUtils.castTimestampNTZToLong(scol._jc))
diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py
index 1d28fd077b18..43088333b108 100644
--- a/python/pyspark/sql/connect/avro/functions.py
+++ b/python/pyspark/sql/connect/avro/functions.py
@@ -80,12 +80,18 @@ def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
     import pyspark.sql.connect.avro.functions
+    from pyspark.util import is_remote_only
 
     globs = pyspark.sql.connect.avro.functions.__dict__.copy()
 
+    # TODO(SPARK-47760): Reenable Avro function doctests
+    if is_remote_only():
+        del pyspark.sql.connect.avro.functions.from_avro
+        del pyspark.sql.connect.avro.functions.to_avro
+
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.avro.functions tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py
index ef1bff9d28c6..f9e31bdc7724 100644
--- a/python/pyspark/sql/connect/catalog.py
+++ b/python/pyspark/sql/connect/catalog.py
@@ -316,6 +316,7 @@ Catalog.__doc__ = PySparkCatalog.__doc__
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -323,7 +324,9 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.catalog.__dict__.copy()
     globs["spark"] = (
-        PySparkSession.builder.appName("sql.connect.catalog tests").remote("local[4]").getOrCreate()
+        PySparkSession.builder.appName("sql.connect.catalog tests")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+        .getOrCreate()
     )
 
     (failure_count, test_count) = doctest.testmod(
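The `datetime_ops.py` hunk above is one instance of a second recurring change: imports that only resolve against the JVM-backed core (`SparkContext`, `SQLContext`, `SparkFiles`, `py4j`) move from module level into the bodies of the functions that use them, so importing the module itself stays safe under pyspark-connect. Sketched with a hypothetical free function:

```python
from pyspark.sql import Column  # importable in both distributions

def cast_timestamp_ntz_to_long(scol: Column) -> Column:  # hypothetical helper
    # Deferred import: pyspark-connect ships no py4j-backed core, so this
    # import may only execute on a code path known to have a JVM available.
    from pyspark import SparkContext

    jvm = SparkContext._active_spark_context._jvm
    return Column(jvm.PythonSQLUtils.castTimestampNTZToLong(scol._jc))
```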
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 719d592924ad..4436b36907a9 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -503,6 +503,7 @@ Column.__doc__ = PySparkColumn.__doc__
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -510,7 +511,9 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.column.__dict__.copy()
     globs["spark"] = (
-        PySparkSession.builder.appName("sql.connect.column tests").remote("local[4]").getOrCreate()
+        PySparkSession.builder.appName("sql.connect.column tests")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+        .getOrCreate()
     )
 
     (failure_count, test_count) = doctest.testmod(
diff --git a/python/pyspark/sql/connect/conf.py b/python/pyspark/sql/connect/conf.py
index 57a669aca889..2dc382da8143 100644
--- a/python/pyspark/sql/connect/conf.py
+++ b/python/pyspark/sql/connect/conf.py
@@ -122,6 +122,7 @@ RuntimeConf.__doc__ = PySparkRuntimeConfig.__doc__
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -129,7 +130,9 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.conf.__dict__.copy()
     globs["spark"] = (
-        PySparkSession.builder.appName("sql.connect.conf tests").remote("local[4]").getOrCreate()
+        PySparkSession.builder.appName("sql.connect.conf tests")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+        .getOrCreate()
     )
 
     (failure_count, test_count) = doctest.testmod(
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 576c196dbd2b..1dddcc078810 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -2279,7 +2279,7 @@ def _test() -> None:
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.dataframe tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index c2bf02023282..0a4733aac32d 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -4065,6 +4065,7 @@ call_function.__doc__ = pysparkfuncs.call_function.__doc__
 
 def _test() -> None:
     import sys
+    import os
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
     import pyspark.sql.connect.functions.builtin
@@ -4073,7 +4074,7 @@ def _test() -> None:
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.functions tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/functions/partitioning.py b/python/pyspark/sql/connect/functions/partitioning.py
index ef319cad2e72..bfeddad7d568 100644
--- a/python/pyspark/sql/connect/functions/partitioning.py
+++ b/python/pyspark/sql/connect/functions/partitioning.py
@@ -81,6 +81,7 @@ hours.__doc__ = pysparkfuncs.partitioning.hours.__doc__
 
 def _test() -> None:
     import sys
+    import os
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
     import pyspark.sql.connect.functions.partitioning
@@ -89,7 +90,7 @@ def _test() -> None:
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.functions tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 088bb000a344..b866f61efe4a 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -448,6 +448,7 @@ PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -456,7 +457,9 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.group.__dict__.copy()
     globs["spark"] = (
-        PySparkSession.builder.appName("sql.connect.group tests").remote("local[4]").getOrCreate()
+        PySparkSession.builder.appName("sql.connect.group tests")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+        .getOrCreate()
     )
 
     (failure_count, test_count) = doctest.testmod(
diff --git a/python/pyspark/sql/connect/observation.py b/python/pyspark/sql/connect/observation.py
index d88a62009995..4fefb8aac41f 100644
--- a/python/pyspark/sql/connect/observation.py
+++ b/python/pyspark/sql/connect/observation.py
@@ -82,6 +82,7 @@ Observation.__doc__ = PySparkObservation.__doc__
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -90,7 +91,7 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.observation.__dict__.copy()
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.observation tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py
index 8bcc2218f06c..fcf1ed1ee02e 100644
--- a/python/pyspark/sql/connect/protobuf/functions.py
+++ b/python/pyspark/sql/connect/protobuf/functions.py
@@ -120,6 +120,7 @@ def _read_descriptor_set_file(filePath: str) -> bytes:
 def _test() -> None:
     import os
     import sys
+    from pyspark.util import is_remote_only
     from pyspark.testing.utils import search_jar
 
     protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
@@ -142,9 +143,14 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.protobuf.functions.__dict__.copy()
 
+    # TODO(SPARK-47763): Reenable Protobuf function doctests
+    if is_remote_only():
+        del pyspark.sql.connect.protobuf.functions.from_protobuf
+        del pyspark.sql.connect.protobuf.functions.to_protobuf
+
     globs["spark"] = (
         PySparkSession.builder.appName("sql.protobuf.functions tests")
-        .remote("local[2]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 0e9c9128bdbf..bf7dc4d36905 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -934,6 +934,7 @@ class DataFrameWriterV2(OptionUtils):
 
 def _test() -> None:
     import sys
+    import os
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
     import pyspark.sql.connect.readwriter
@@ -942,7 +943,7 @@ def _test() -> None:
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.readwriter tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
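In the Avro and Protobuf hunks above, doctests are disabled by deleting the functions from their module before `doctest.testmod` collects examples; `testmod` only finds docstrings reachable through module attributes, so a deleted attribute contributes no tests. The mechanism in isolation (a sketch mirroring the temporary state tracked by SPARK-47760 and SPARK-47763, not the committed `_test()` bodies):

```python
import doctest

import pyspark.sql.connect.avro.functions as avro_functions
from pyspark.util import is_remote_only

if is_remote_only():
    # Removing the attribute removes its docstring examples from the set
    # that doctest.testmod() collects below.
    del avro_functions.from_avro
    del avro_functions.to_avro

failure_count, test_count = doctest.testmod(avro_functions)
```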
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index b19c420c3833..40a8076698bf 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -1023,6 +1023,7 @@ SparkSession.__doc__ = PySparkSession.__doc__
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -1030,7 +1031,9 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.session.__dict__.copy()
     globs["spark"] = (
-        PySparkSession.builder.appName("sql.connect.session tests").remote("local[4]").getOrCreate()
+        PySparkSession.builder.appName("sql.connect.session tests")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+        .getOrCreate()
     )
 
     # Uses PySpark session to test builder.
diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py
index 65b480993636..c1940921c631 100644
--- a/python/pyspark/sql/connect/streaming/query.py
+++ b/python/pyspark/sql/connect/streaming/query.py
@@ -285,7 +285,7 @@ def _test() -> None:
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.streaming.query tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py
index 11f230473fcf..ac0aca6d4b19 100644
--- a/python/pyspark/sql/connect/streaming/readwriter.py
+++ b/python/pyspark/sql/connect/streaming/readwriter.py
@@ -650,6 +650,7 @@ class DataStreamWriter:
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -659,7 +660,7 @@ def _test() -> None:
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.streaming.readwriter tests")
-        .remote("local[4]")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
         .getOrCreate()
     )
diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py
index bab476e4782d..e30a5b7d7a9e 100644
--- a/python/pyspark/sql/connect/window.py
+++ b/python/pyspark/sql/connect/window.py
@@ -234,6 +234,7 @@ Window.__doc__ = PySparkWindow.__doc__
 
 
 def _test() -> None:
+    import os
     import sys
     import doctest
     from pyspark.sql import SparkSession as PySparkSession
@@ -241,7 +242,9 @@ def _test() -> None:
 
     globs = pyspark.sql.connect.window.__dict__.copy()
     globs["spark"] = (
-        PySparkSession.builder.appName("sql.connect.window tests").remote("local[4]").getOrCreate()
+        PySparkSession.builder.appName("sql.connect.window tests")
+        .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
+        .getOrCreate()
     )
 
     (failure_count, test_count) = doctest.testmod(
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 11e0ef43b59f..c187122cdb40 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -861,7 +861,7 @@ class SparkSession(SparkConversionMixin):
         Create a temp view, show the list, and drop it.
 
         >>> spark.range(1).createTempView("test_view")
-        >>> spark.catalog.listTables()
+        >>> spark.catalog.listTables()  # doctest: +SKIP
         [Table(name='test_view', catalog=None, namespace=[], description=None, ...
 
         >>> _ = spark.catalog.dropTempView("test_view")
         """
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index f1cbf637b92a..f4f49ab25126 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -20,11 +20,11 @@ import tempfile
 import unittest
 import os
 
+from pyspark.util import is_remote_only
 from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
 from pyspark.testing.utils import SPARK_HOME
-from pyspark import SparkFiles
 from pyspark.sql.functions import udf
 
 if should_test_connect:
@@ -174,9 +174,12 @@ class ArtifactTestsMixin:
         )
 
 
+@unittest.skipIf(is_remote_only(), "Requires JVM access")
 class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
     @classmethod
     def root(cls):
+        from pyspark.core.files import SparkFiles
+
         # In local mode, the file location is the same as Driver
         # The executors are running in a thread.
         jvm = SparkSession._instantiatedSession._jvm
@@ -424,6 +427,7 @@ class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
         )
 
 
+@unittest.skipIf(is_remote_only(), "Requires local cluster to run")
 class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
     @classmethod
     def conf(cls):
diff --git a/python/pyspark/sql/tests/connect/client/test_reattach.py b/python/pyspark/sql/tests/connect/client/test_reattach.py
index cea0be7008cc..64c81529ec14 100644
--- a/python/pyspark/sql/tests/connect/client/test_reattach.py
+++ b/python/pyspark/sql/tests/connect/client/test_reattach.py
@@ -18,6 +18,7 @@ import os
 import unittest
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
@@ -25,6 +26,7 @@ from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.testing.utils import eventually
 
 
+@unittest.skipIf(is_remote_only(), "Requires JVM access")
 class SparkConnectReattachTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
     @classmethod
     def setUpClass(cls):
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 16e9a577451f..d6a498c1bfff 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -20,6 +20,7 @@ import unittest
 import shutil
 import tempfile
 
+from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession, Row
 from pyspark.sql.types import (
@@ -52,6 +53,7 @@ if should_test_connect:
     from pyspark.sql.connect import functions as CF
 
 
+@unittest.skipIf(is_remote_only(), "Requires JVM access")
 class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils):
     """Parent test fixture class for all Spark Connect related test cases."""
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index cb0d1bab7ffa..581fde3e6293 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -18,6 +18,7 @@ import os
 import unittest
 from inspect import getmembers, isfunction
 
+from pyspark.util import is_remote_only
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
@@ -37,6 +38,7 @@ if should_test_connect:
     from pyspark.sql.connect.dataframe import DataFrame as CDF
 
 
+@unittest.skipIf(is_remote_only(), "Requires JVM access")
 class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
     """These test cases exercise the interface to the proto plan generation but do not call
     Spark."""
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index 186580046ef0..4d6127b5be8b 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -15,11 +15,12 @@
 # limitations under the License.
 #
 
+import os
 import unittest
 import uuid
 from collections import defaultdict
-
+from pyspark.util import is_remote_only
 from pyspark.errors import (
     PySparkException,
     PySparkValueError,
@@ -46,6 +47,7 @@ if should_test_connect:
     from pyspark.sql.connect.client.core import Retrying, SparkConnectClient
 
 
+@unittest.skipIf(is_remote_only(), "Session creation different from local mode")
 class SparkConnectSessionTests(ReusedConnectTestCase):
     def setUp(self) -> None:
         self.spark = (
@@ -248,7 +250,7 @@ class SparkConnectSessionWithOptionsTest(unittest.TestCase):
             .config("integer", 1)
             .config("boolean", False)
             .appName(self.__class__.__name__)
-            .remote("local[4]")
+            .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
             .getOrCreate()
         )
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
index dfa56ff0bb88..e682e46ca185 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
@@ -18,10 +18,13 @@ import inspect
 import os
 import unittest
 
+from pyspark.util import is_remote_only
 from pyspark.sql.tests.test_udf_profiler import UDFProfiler2TestsMixin, _do_computation
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
+# TODO(SPARK-47756): Reenable UDFProfilerParityTests for pyspark-connect
+@unittest.skipIf(is_remote_only(), "Skipped for now")
 class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase):
     def setUp(self) -> None:
         super().setUp()
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index e12e697e582d..02570ac9efa7 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -23,8 +23,9 @@ if should_test_connect:
     from pyspark.sql.connect.udtf import UserDefinedTableFunction
 
     sql.udtf.UserDefinedTableFunction = UserDefinedTableFunction
+    from pyspark.sql.connect.functions import lit, udtf
 
-from pyspark.sql.connect.functions import lit, udtf
+from pyspark.util import is_remote_only
 from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.errors.exceptions.connect import SparkConnectGrpcException
@@ -67,6 +68,14 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
     def test_udtf_with_analyze_using_accumulator(self):
         super().test_udtf_with_analyze_using_accumulator()
 
+    @unittest.skipIf(is_remote_only(), "pyspark-connect does not have SparkFiles")
+    def test_udtf_with_analyze_using_archive(self):
+        super().test_udtf_with_analyze_using_archive()
+
+    @unittest.skipIf(is_remote_only(), "pyspark-connect does not have SparkFiles")
+    def test_udtf_with_analyze_using_file(self):
+        super().test_udtf_with_analyze_using_file()
+
     def _add_pyfile(self, path):
         self.spark.addArtifacts(path, pyfile=True)
diff --git a/python/pyspark/sql/tests/connect/test_resources.py b/python/pyspark/sql/tests/connect/test_resources.py
index b4cc138c4df8..931acd929804 100644
--- a/python/pyspark/sql/tests/connect/test_resources.py
+++ b/python/pyspark/sql/tests/connect/test_resources.py
@@ -16,14 +16,18 @@
 #
 import unittest
 
+from pyspark.util import is_remote_only
 from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.sql.tests.test_resources import ResourceProfileTestsMixin
 
-class ResourceProfileTests(ResourceProfileTestsMixin, ReusedConnectTestCase):
-    @classmethod
-    def master(cls):
-        return "local-cluster[1, 4, 1024]"
+# TODO(SPARK-47757): Reenable ResourceProfileTests for pyspark-connect
+if not is_remote_only():
+    from pyspark.sql.tests.test_resources import ResourceProfileTestsMixin
+
+    class ResourceProfileTests(ResourceProfileTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def master(cls):
+            return "local-cluster[1, 4, 1024]"
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index d1462b7f3987..5235e021bae9 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -25,7 +25,7 @@ from typing import cast
 from collections import namedtuple
 import sys
 
-from pyspark import SparkContext, SparkConf
+from pyspark import SparkConf
 from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import rand, udf, assert_true, lit
 from pyspark.sql.types import (
@@ -1202,6 +1202,8 @@ class MaxResultArrowTests(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
+        from pyspark import SparkContext
+
         cls.spark = SparkSession(
             SparkContext(
                 "local[4]", cls.__name__, conf=SparkConf().set("spark.driver.maxResultSize", "10k")
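The `test_resources.py` hunk above takes a different route from `skipIf`: the test class is defined only when a JVM can be assumed, so under pyspark-connect the module contributes no tests and never imports the JVM-backed mixin. The shape of that pattern, with a hypothetical class name:

```python
import unittest

from pyspark.util import is_remote_only

if not is_remote_only():
    from pyspark.sql.tests.test_resources import ResourceProfileTestsMixin

    # Defining the class inside the guard keeps unittest discovery from
    # ever seeing it when only pyspark-connect is installed.
    class JvmClusterTests(ResourceProfileTestsMixin, unittest.TestCase):  # hypothetical name
        pass

if __name__ == "__main__":
    unittest.main()
```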
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 84c6089bab36..d76572531b73 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -23,7 +23,6 @@ import tempfile
 import unittest
 import datetime
 
-from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
 from pyspark.sql.functions import col, udf, assert_true, lit, rand
 from pyspark.sql.udf import UserDefinedFunction
@@ -80,6 +79,8 @@ class BaseUDFTestsMixin(object):
         self.assertEqual(row[0], 5)
 
     def test_udf_on_sql_context(self):
+        from pyspark import SQLContext
+
         # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
         sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
         sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
@@ -394,6 +395,8 @@ class BaseUDFTestsMixin(object):
         )
 
     def test_udf_registration_returns_udf_on_sql_context(self):
+        from pyspark import SQLContext
+
         df = self.spark.range(10)
 
         # This is to check if a 'SQLContext.udf' can call its alias.
@@ -458,6 +461,8 @@ class BaseUDFTestsMixin(object):
         )
 
     def test_non_existed_udf_with_sql_context(self):
+        from pyspark import SQLContext
+
         # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
         sqlContext = SQLContext.getOrCreate(self.spark.sparkContext)
 
         self.assertRaisesRegex(
@@ -1096,6 +1101,8 @@ class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
 
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
+        from pyspark import SparkContext
+
         if SparkSession._instantiatedSession is not None:
             SparkSession._instantiatedSession.stop()
 
@@ -1103,6 +1110,8 @@
             SparkContext._active_spark_context.stop()
 
     def test_udf_init_should_not_initialize_context(self):
+        from pyspark import SparkContext
+
         UserDefinedFunction(lambda x: x, StringType())
 
         self.assertIsNone(
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 0d2582b51fe1..923fe4a2a8e8 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -21,8 +21,6 @@ import unittest
 from dataclasses import dataclass
 from typing import Iterator, Optional
 
-from py4j.protocol import Py4JJavaError
-
 from pyspark.errors import (
     PySparkAttributeError,
     PythonException,
@@ -30,7 +28,6 @@ from pyspark.errors import (
     AnalysisException,
     PySparkPicklingError,
 )
-from pyspark.core.files import SparkFiles
 from pyspark.util import PythonEvalType
 from pyspark.sql.functions import (
     array,
@@ -858,6 +855,8 @@ class BaseUDTFTestsMixin:
         self._check_result_or_exception(TestUDTF, ret_type, expected)
 
     def test_struct_output_type_casting_row(self):
+        from py4j.protocol import Py4JJavaError
+
         self.check_struct_output_type_casting_row(Py4JJavaError)
 
     def check_struct_output_type_casting_row(self, error_type):
@@ -1800,6 +1799,8 @@ class BaseUDTFTestsMixin:
         self.sc.addArchive(path)
 
     def test_udtf_with_analyze_using_archive(self):
+        from pyspark.core.files import SparkFiles
+
         with tempfile.TemporaryDirectory(prefix="test_udtf_with_analyze_using_archive") as d:
             archive_path = os.path.join(d, "my_archive")
             os.mkdir(archive_path)
@@ -1847,6 +1848,8 @@ class BaseUDTFTestsMixin:
         self.sc.addFile(path)
 
     def test_udtf_with_analyze_using_file(self):
+        from pyspark.core.files import SparkFiles
+
         with tempfile.TemporaryDirectory(prefix="test_udtf_with_analyze_using_file") as d:
             file_path = os.path.join(d, "my_file.txt")
             with open(file_path, "w") as f:
diff --git a/python/pyspark/tests/test_memory_profiler.py b/python/pyspark/tests/test_memory_profiler.py
index 046dd3621c42..0e921b48afc4 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -26,7 +26,7 @@ from io import StringIO
 from typing import cast, Iterator
 from unittest import mock
 
-from pyspark import SparkConf, SparkContext
+from pyspark import SparkConf
 from pyspark.profiler import has_memory_profiler
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import col, pandas_udf, udf
@@ -61,6 +61,8 @@ def _do_computation(spark, *, action=lambda df: df.collect(), use_arrow=False):
 @unittest.skipIf(not have_pandas, pandas_requirement_message)
 class MemoryProfilerTests(PySparkTestCase):
     def setUp(self):
+        from pyspark import SparkContext
+
         self._old_sys_path = list(sys.path)
         class_name = self.__class__.__name__
         conf = SparkConf().set("spark.python.profile.memory", "true")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org