This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new aa568354725c [SPARK-47811][PYTHON][CONNECT][TESTS] Run ML tests for pyspark-connect package
aa568354725c is described below

commit aa568354725ce44fc0261973b97597ab0986edb1
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Apr 12 09:02:47 2024 +0900

    [SPARK-47811][PYTHON][CONNECT][TESTS] Run ML tests for pyspark-connect package

    ### What changes were proposed in this pull request?

    This PR proposes to extend the `pyspark-connect` scheduled job to run ML tests as well.

    ### Why are the changes needed?

    To make sure the pure Python library works with ML.

    ### Does this PR introduce _any_ user-facing change?

    No, test-only.

    ### How was this patch tested?

    Tested in my fork: https://github.com/HyukjinKwon/spark/actions/runs/8643632135/job/23697401430

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #45941 from HyukjinKwon/test-ps-ci.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .github/workflows/build_python_connect.yml          |   3 +-
 python/packaging/connect/setup.py                   |   1 +
 python/pyspark/ml/connect/classification.py         |   1 -
 python/pyspark/ml/param/__init__.py                 |   7 +-
 .../tests/connect/test_connect_classification.py    |  10 +-
 .../ml/tests/connect/test_connect_evaluation.py     |   5 +-
 .../ml/tests/connect/test_connect_feature.py        |   5 +-
 .../ml/tests/connect/test_connect_function.py       |   2 +
 .../ml/tests/connect/test_connect_pipeline.py       |  11 +-
 .../ml/tests/connect/test_connect_summarizer.py     |   5 +-
 .../ml/tests/connect/test_connect_tuning.py         |   9 +-
 .../connect/test_legacy_mode_classification.py      |   8 +-
 .../tests/connect/test_legacy_mode_evaluation.py    |   9 +-
 .../ml/tests/connect/test_legacy_mode_feature.py    |   6 +-
 .../ml/tests/connect/test_legacy_mode_pipeline.py   |   6 +-
 .../tests/connect/test_legacy_mode_summarizer.py    |   6 +-
 .../ml/tests/connect/test_legacy_mode_tuning.py     |   9 +-
 .../tests/connect/test_parity_torch_data_loader.py  |  28 ++-
 .../tests/connect/test_parity_torch_distributor.py  | 232 +++++++++++----------
 19 files changed, 218 insertions(+), 145 deletions(-)

diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
index ec7103e5dbeb..8deee026131e 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -72,6 +72,7 @@ jobs:
           python packaging/connect/setup.py sdist
           cd dist
           pip install pyspark-connect-*.tar.gz
+          pip install scikit-learn torch torchvision torcheval
       - name: Run tests
         env:
           SPARK_CONNECT_TESTING_REMOTE: sc://localhost
@@ -82,7 +83,7 @@ jobs:
           # Remove Py4J and PySpark zipped library to make sure there is no JVM connection
           rm python/lib/*
           rm -r python/pyspark
-          ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect
+          ./python/run-tests --parallelism=1 --python-executables=python3 --modules pyspark-connect,pyspark-ml-connect
       - name: Upload test results to report
         if: always()
         uses: actions/upload-artifact@v4
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
index 3514e5cdc422..419ed36b4236 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -77,6 +77,7 @@ if "SPARK_TESTING" in os.environ:
         "pyspark.sql.tests.connect.shell",
         "pyspark.sql.tests.pandas",
         "pyspark.sql.tests.streaming",
+        "pyspark.ml.tests.connect",
     ]

 try:
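The hunks that follow repeatedly guard JVM-dependent code with `pyspark.util.is_remote_only()`, which reports whether only the pure Python `pyspark-connect` distribution is installed; the workflow above deletes Py4J and the zipped PySpark sources precisely so that this holds in CI. A minimal sketch of the pattern, where `SomeJvmBackedTest` is a hypothetical placeholder rather than a class from this diff:

```python
import unittest

from pyspark.util import is_remote_only


# Under a pure pyspark-connect install there is no py4j/JVM to talk to,
# so any test that needs the classic JVM code path must be skipped.
@unittest.skipIf(is_remote_only(), "Requires JVM access")
class SomeJvmBackedTest(unittest.TestCase):  # hypothetical placeholder
    def test_something_jvm_backed(self) -> None:
        pass


if __name__ == "__main__":
    unittest.main()
```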
diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py
index 8d8c6227eac3..fc7b5cda88a2 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -320,7 +320,6 @@ class LogisticRegressionModel(

     def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
         import torch
-        import torch.nn as torch_nn

         model_state_dict = self.torch_model.state_dict()

diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 345b7f7a5964..f32ead2a580c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -30,8 +30,8 @@ from typing import (
 )

 import numpy as np
-from py4j.java_gateway import JavaObject

+from pyspark.util import is_remote_only
 from pyspark.ml.linalg import DenseVector, Vector, Matrix
 from pyspark.ml.util import Identifiable

@@ -516,9 +516,12 @@ class Params(Identifiable, metaclass=ABCMeta):
         """
         Sets default params.
         """
+        if not is_remote_only():
+            from py4j.java_gateway import JavaObject
+
         for param, value in kwargs.items():
             p = getattr(self, param)
-            if value is not None and not isinstance(value, JavaObject):
+            if value is not None and (is_remote_only() or not isinstance(value, JavaObject)):
                 try:
                     value = p.typeConverter(value)
                 except TypeError as e:
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index ebc1745874d9..8083090523a0 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -17,7 +17,9 @@
 #
 import unittest
+import os

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

@@ -33,13 +35,15 @@ if should_test_connect:


 @unittest.skipIf(
-    not should_test_connect or not have_torch,
-    connect_requirement_message or torch_requirement_message,
+    not should_test_connect or not have_torch or is_remote_only(),
+    connect_requirement_message
+    or torch_requirement_message
+    or "Requires PySpark core library in Spark Connect server",
 )
 class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
-            SparkSession.builder.remote("local[2]")
+            SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
             .getOrCreate()
         )
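The `Params._setDefault` hunk above makes Py4J optional in two steps: import `JavaObject` only when a JVM exists, and let `or` short-circuiting keep the `isinstance` check from ever evaluating (and raising `NameError`) in remote-only mode. A standalone sketch of the same idea; `coerce` and the `str` conversion are illustrative stand-ins, not Spark code:

```python
from pyspark.util import is_remote_only


def coerce(value):
    # Import py4j lazily: it is not installed with pure pyspark-connect.
    if not is_remote_only():
        from py4j.java_gateway import JavaObject

    # When is_remote_only() is True, the right-hand operand is never
    # evaluated, so the unimported JavaObject name is never looked up.
    if value is not None and (is_remote_only() or not isinstance(value, JavaObject)):
        value = str(value)  # stand-in for p.typeConverter(value)
    return value
```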
diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 7f3b6bd0198c..359a77bbcb20 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
+import os
 import unittest

 from pyspark.sql import SparkSession
@@ -36,7 +37,9 @@ if should_test_connect:
 )
 class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()

     def tearDown(self) -> None:
         self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index 04b1744c4995..c786ce2f87d0 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
+import os
 import unittest

 from pyspark.sql import SparkSession
@@ -38,7 +39,9 @@ if should_test_connect:
 )
 class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()

     def tearDown(self) -> None:
         self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py b/python/pyspark/ml/tests/connect/test_connect_function.py
index b38d415e2bb2..f50376110660 100644
--- a/python/pyspark/ml/tests/connect/test_connect_function.py
+++ b/python/pyspark/ml/tests/connect/test_connect_function.py
@@ -17,6 +17,7 @@
 import os
 import unittest

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.dataframe import DataFrame as SDF
 from pyspark.ml import functions as SF
@@ -32,6 +33,7 @@ if should_test_connect:
     from pyspark.ml.connect import functions as CF


+@unittest.skipIf(is_remote_only(), "Requires JVM access")
 class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
     """These test cases exercise the interface to the proto plan
     generation but do not call Spark."""
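The `setUp` changes in these files all follow one pattern: default to spinning up a local Connect server, but honor `SPARK_CONNECT_TESTING_REMOTE`, which the workflow sets to `sc://localhost` so the suite reuses the standalone server it started. Roughly:

```python
import os

from pyspark.sql import SparkSession

# Use an external Connect endpoint when SPARK_CONNECT_TESTING_REMOTE is set
# (e.g. sc://localhost in CI); otherwise start a local server via "local[2]".
remote = os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
spark = SparkSession.builder.remote(remote).getOrCreate()
```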
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 45d19f2bcdde..4105f593f170 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -15,9 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 import unittest

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

@@ -34,13 +35,15 @@ except ImportError:


 @unittest.skipIf(
-    not should_test_connect or not have_torch,
-    connect_requirement_message or torch_requirement_message,
+    not should_test_connect or not have_torch or is_remote_only(),
+    connect_requirement_message
+    or torch_requirement_message
+    or "Requires PySpark core library in Spark Connect server",
 )
 class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
-            SparkSession.builder.remote("local[2]")
+            SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
             .getOrCreate()
         )
diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
index 866a3468388d..1cfd2ed229e5 100644
--- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
@@ -16,6 +16,7 @@
 #
 import unittest
+import os

 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
@@ -27,7 +28,9 @@ if should_test_connect:
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()

     def tearDown(self) -> None:
         self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index 7b10d91da064..d5fcb93099b6 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -17,7 +17,9 @@
 #
 import unittest
+import os

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

@@ -25,11 +27,14 @@ if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin


-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "Requires PySpark core library in Spark Connect server",
+)
 class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
-            SparkSession.builder.remote("local[2]")
+            SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
            .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
             .getOrCreate()
         )
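The multi-line skip reasons above work because `unittest.skipIf` takes a single message string, and each `*_requirement_message` is `None` when its dependency is available, so `or` chaining picks the first applicable reason. With stand-in values:

```python
# Stand-ins for illustration: Connect dependencies present, torch missing.
connect_requirement_message = None
torch_requirement_message = "No torch found"

reason = (
    connect_requirement_message
    or torch_requirement_message
    or "Requires PySpark core library in Spark Connect server"
)
assert reason == "No torch found"  # the first non-None message wins
```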
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
index db9a29804808..dc2642a42d66 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
@@ -21,14 +21,17 @@ import unittest

 import numpy as np

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

 have_torch = True
+torch_requirement_message = None
 try:
     import torch  # noqa: F401
 except ImportError:
     have_torch = False
+    torch_requirement_message = "No torch found"

 if should_test_connect:
     from pyspark.ml.connect.classification import (
@@ -228,7 +231,10 @@ class ClassificationTestsMixin:


 @unittest.skipIf(
-    not should_test_connect or not have_torch, connect_requirement_message or "No torch found"
+    not should_test_connect or not have_torch or is_remote_only(),
+    connect_requirement_message
+    or torch_requirement_message
+    or "pyspark-connect cannot test classic Spark",
 )
 class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index ae01031ff462..11c1f9aeee51 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -20,14 +20,17 @@ import tempfile

 import numpy as np

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

 have_torcheval = True
+torcheval_requirement_message = None
 try:
     import torcheval  # noqa: F401
 except ImportError:
     have_torcheval = False
+    torcheval_requirement_message = "torcheval is required"

 if should_test_connect:
     from pyspark.ml.connect.evaluation import (
@@ -177,8 +180,10 @@ class EvaluationTestsMixin:


 @unittest.skipIf(
-    not should_test_connect or not have_torcheval,
-    connect_requirement_message or "torcheval is required",
+    not should_test_connect or not have_torcheval or is_remote_only(),
+    connect_requirement_message
+    or torcheval_requirement_message
+    or "pyspark-connect cannot test classic Spark",
 )
 class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index 9565b3a09a5b..4915d4706b87 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -23,6 +23,7 @@ import unittest

 import numpy as np

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

@@ -194,7 +195,10 @@ class FeatureTestsMixin:
         assembler2.transform(pandas_df)["out"].tolist()


-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
 class FeatureTests(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
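These legacy-mode files replace the previously inlined literal skip messages with module-level `*_requirement_message` variables. The variable must be bound to `None` even when the import succeeds, because the `@unittest.skipIf` decorator references it unconditionally at module load time. The shared guard shape:

```python
# Optional-dependency guard: both names are always bound, so decorators can
# reference them whether or not the import succeeded.
have_torch = True
torch_requirement_message = None
try:
    import torch  # noqa: F401
except ImportError:
    have_torch = False
    torch_requirement_message = "No torch found"
```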
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
index 104aff17e0b2..692144148af0 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
@@ -21,6 +21,7 @@ import unittest

 import numpy as np

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

@@ -167,7 +168,10 @@ class PipelineTestsMixin:
         assert lorv2.getOrDefault(lorv2.maxIter) == 200


-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
 class PipelineTests(PipelineTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
index 7f09eb9f0742..253632a74c97 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
@@ -20,6 +20,7 @@ import unittest

 import numpy as np

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

@@ -62,7 +63,10 @@ class SummarizerTestsMixin:
         assert_dict_allclose(result_local, expected_result)


-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
 class SummarizerTests(SummarizerTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index 7f26788c137f..14f52d75e6d6 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -22,6 +22,7 @@ import sys

 import numpy as np

+from pyspark.util import is_remote_only
 from pyspark.ml.param import Param, Params
 from pyspark.ml.tuning import ParamGridBuilder
 from pyspark.sql import SparkSession
@@ -29,10 +30,13 @@ from pyspark.sql.functions import rand
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message

 have_sklearn = True
+sklearn_requirement_message = None
 try:
     from sklearn.datasets import load_breast_cancer  # noqa: F401
 except ImportError:
     have_sklearn = False
+    sklearn_requirement_message = "No sklearn found"
+

 if should_test_connect:
     import pandas as pd
@@ -279,7 +283,10 @@ class CrossValidatorTestsMixin:


 @unittest.skipIf(
-    not should_test_connect or not have_sklearn, connect_requirement_message or "No sklearn found"
+    not should_test_connect or not have_sklearn or is_remote_only(),
+    connect_requirement_message
+    or sklearn_requirement_message
+    or "pyspark-connect cannot test classic Spark",
 )
 class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
index 1984efdc6c6e..462fe3822141 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
@@ -17,24 +17,30 @@

 import unittest

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
-from pyspark.ml.torch.tests.test_data_loader import TorchDistributorDataLoaderUnitTests

+torch_requirement_message = None
 have_torch = True
 try:
     import torch  # noqa: F401
 except ImportError:
     have_torch = False
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
-    def setUp(self) -> None:
-        self.spark = (
-            SparkSession.builder.remote("local[1]")
-            .config("spark.default.parallelism", "1")
-            .getOrCreate()
-        )
+    torch_requirement_message = "torch is required"
+
+if not is_remote_only():
+    from pyspark.ml.torch.tests.test_data_loader import TorchDistributorDataLoaderUnitTests
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+    )
+    class TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
+        def setUp(self) -> None:
+            self.spark = (
+                SparkSession.builder.remote("local[1]")
+                .config("spark.default.parallelism", "1")
+                .getOrCreate()
+            )


 if __name__ == "__main__":
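In the two torch parity files the guard has to move to module level: `pyspark.ml.torch.tests` itself pulls in JVM-backed modules, so under a pure `pyspark-connect` install merely importing the base test class would fail before any `skipIf` could run. A condensed sketch of the hunk above, with `DataLoaderParityTests` shortened for illustration:

```python
from pyspark.util import is_remote_only

if not is_remote_only():
    # Importable only when the full PySpark distribution (with JVM support)
    # is present; the subclass is likewise only defined in that case.
    from pyspark.ml.torch.tests.test_data_loader import (
        TorchDistributorDataLoaderUnitTests,
    )

    class DataLoaderParityTests(TorchDistributorDataLoaderUnitTests):
        """Connect variant, defined only when the JVM code path exists."""
```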
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
index 70aa80ba6d11..e40303ae9ce2 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -19,124 +19,134 @@ import os
 import shutil
 import unittest

+torch_requirement_message = None
 have_torch = True
 try:
     import torch  # noqa: F401
 except ImportError:
     have_torch = False
+    torch_requirement_message = "torch is required"

+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
-from pyspark.ml.torch.tests.test_distributor import (
-    TorchDistributorBaselineUnitTestsMixin,
-    TorchDistributorLocalUnitTestsMixin,
-    TorchDistributorDistributedUnitTestsMixin,
-    TorchWrapperUnitTestsMixin,
-    set_up_test_dirs,
-    get_local_mode_conf,
-    get_distributed_mode_conf,
-)
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorBaselineUnitTestsOnConnect(
-    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        cls.spark = SparkSession.builder.remote("local[4]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        cls.spark.stop()
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTestsOnConnect(
-    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
-        builder = SparkSession.builder.appName(cls.__name__)
-        for k, v in get_local_mode_conf().items():
-            builder = builder.config(k, v)
-        builder = builder.config(
-            "spark.driver.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
-        )
-        cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.mnist_dir_path)
-        os.unlink(cls.gpu_discovery_script_file_name)
-        cls.spark.stop()
-
-    def _get_inputs_for_test_local_training_succeeds(self):
-        return [
-            ("0,1,2", 1, True, "0,1,2"),
-            ("0,1,2", 3, True, "0,1,2"),
-            ("0,1,2", 2, False, "0,1,2"),
-            (None, 3, False, "NONE"),
-        ]
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTestsIIOnConnect(
-    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
-        builder = SparkSession.builder.appName(cls.__name__)
-        for k, v in get_local_mode_conf().items():
-            builder = builder.config(k, v)
-
-        builder = builder.config(
-            "spark.driver.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
-        )
-        cls.spark = builder.remote("local[4]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.mnist_dir_path)
-        os.unlink(cls.gpu_discovery_script_file_name)
-        cls.spark.stop()
-
-    def _get_inputs_for_test_local_training_succeeds(self):
-        return [
-            ("0,1,2", 1, True, "0,1,2"),
-            ("0,1,2", 3, True, "0,1,2"),
-            ("0,1,2", 2, False, "0,1,2"),
-            (None, 3, False, "NONE"),
-        ]
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorDistributedUnitTestsOnConnect(
-    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
-        builder = SparkSession.builder.appName(cls.__name__)
-        for k, v in get_distributed_mode_conf().items():
-            builder = builder.config(k, v)
-
-        builder = builder.config(
-            "spark.worker.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
-        )
-        cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.mnist_dir_path)
-        os.unlink(cls.gpu_discovery_script_file_name)
-        cls.spark.stop()
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, unittest.TestCase):
-    pass
+
+if not is_remote_only():
+    from pyspark.ml.torch.tests.test_distributor import (
+        TorchDistributorBaselineUnitTestsMixin,
+        TorchDistributorLocalUnitTestsMixin,
+        TorchDistributorDistributedUnitTestsMixin,
+        TorchWrapperUnitTestsMixin,
+        set_up_test_dirs,
+        get_local_mode_conf,
+        get_distributed_mode_conf,
+    )
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+    )
+    class TorchDistributorBaselineUnitTestsOnConnect(
+        TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            cls.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            cls.spark.stop()
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+    )
+    class TorchDistributorLocalUnitTestsOnConnect(
+        TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
+            builder = SparkSession.builder.appName(cls.__name__)
+            for k, v in get_local_mode_conf().items():
+                builder = builder.config(k, v)
+            builder = builder.config(
+                "spark.driver.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
+            )
+            cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            shutil.rmtree(cls.mnist_dir_path)
+            os.unlink(cls.gpu_discovery_script_file_name)
+            cls.spark.stop()
+
+        def _get_inputs_for_test_local_training_succeeds(self):
+            return [
+                ("0,1,2", 1, True, "0,1,2"),
+                ("0,1,2", 3, True, "0,1,2"),
+                ("0,1,2", 2, False, "0,1,2"),
+                (None, 3, False, "NONE"),
+            ]
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+    )
+    class TorchDistributorLocalUnitTestsIIOnConnect(
+        TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
+            builder = SparkSession.builder.appName(cls.__name__)
+            for k, v in get_local_mode_conf().items():
+                builder = builder.config(k, v)
+
+            builder = builder.config(
+                "spark.driver.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
+            )
+            cls.spark = builder.remote("local[4]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            shutil.rmtree(cls.mnist_dir_path)
+            os.unlink(cls.gpu_discovery_script_file_name)
+            cls.spark.stop()
+
+        def _get_inputs_for_test_local_training_succeeds(self):
+            return [
+                ("0,1,2", 1, True, "0,1,2"),
+                ("0,1,2", 3, True, "0,1,2"),
+                ("0,1,2", 2, False, "0,1,2"),
+                (None, 3, False, "NONE"),
+            ]
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+    )
+    class TorchDistributorDistributedUnitTestsOnConnect(
+        TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = set_up_test_dirs()
+            builder = SparkSession.builder.appName(cls.__name__)
+            for k, v in get_distributed_mode_conf().items():
+                builder = builder.config(k, v)
+
+            builder = builder.config(
+                "spark.worker.resource.gpu.discoveryScript", cls.gpu_discovery_script_file_name
+            )
+            cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            shutil.rmtree(cls.mnist_dir_path)
+            os.unlink(cls.gpu_discovery_script_file_name)
+            cls.spark.stop()
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or "Requires JVM access"
+    )
+    class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, unittest.TestCase):
+        pass


 if __name__ == "__main__":

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org