This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new aa568354725c [SPARK-47811][PYTHON][CONNECT][TESTS] Run ML tests for pyspark-connect package
aa568354725c is described below

commit aa568354725ce44fc0261973b97597ab0986edb1
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Apr 12 09:02:47 2024 +0900

    [SPARK-47811][PYTHON][CONNECT][TESTS] Run ML tests for pyspark-connect package
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to extend the `pyspark-connect` scheduled job so that it also runs the ML tests (the `pyspark-ml-connect` module).
    
    ### Why are the changes needed?
    
    To make sure that the pure Python `pyspark-connect` library works with ML.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only.
    
    ### How was this patch tested?
    
    Tested in my fork: https://github.com/HyukjinKwon/spark/actions/runs/8643632135/job/23697401430
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45941 from HyukjinKwon/test-ps-ci.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .github/workflows/build_python_connect.yml         |   3 +-
 python/packaging/connect/setup.py                  |   1 +
 python/pyspark/ml/connect/classification.py        |   1 -
 python/pyspark/ml/param/__init__.py                |   7 +-
 .../tests/connect/test_connect_classification.py   |  10 +-
 .../ml/tests/connect/test_connect_evaluation.py    |   5 +-
 .../ml/tests/connect/test_connect_feature.py       |   5 +-
 .../ml/tests/connect/test_connect_function.py      |   2 +
 .../ml/tests/connect/test_connect_pipeline.py      |  11 +-
 .../ml/tests/connect/test_connect_summarizer.py    |   5 +-
 .../ml/tests/connect/test_connect_tuning.py        |   9 +-
 .../connect/test_legacy_mode_classification.py     |   8 +-
 .../tests/connect/test_legacy_mode_evaluation.py   |   9 +-
 .../ml/tests/connect/test_legacy_mode_feature.py   |   6 +-
 .../ml/tests/connect/test_legacy_mode_pipeline.py  |   6 +-
 .../tests/connect/test_legacy_mode_summarizer.py   |   6 +-
 .../ml/tests/connect/test_legacy_mode_tuning.py    |   9 +-
 .../tests/connect/test_parity_torch_data_loader.py |  28 ++-
 .../tests/connect/test_parity_torch_distributor.py | 232 +++++++++++----------
 19 files changed, 218 insertions(+), 145 deletions(-)

diff --git a/.github/workflows/build_python_connect.yml 
b/.github/workflows/build_python_connect.yml
index ec7103e5dbeb..8deee026131e 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -72,6 +72,7 @@ jobs:
           python packaging/connect/setup.py sdist
           cd dist
           pip install pyspark-connect-*.tar.gz
+          pip install scikit-learn torch torchvision torcheval
       - name: Run tests
         env:
           SPARK_CONNECT_TESTING_REMOTE: sc://localhost
@@ -82,7 +83,7 @@ jobs:
           # Remove Py4J and PySpark zipped library to make sure there is no 
JVM connection
           rm python/lib/*
           rm -r python/pyspark
-          ./python/run-tests --parallelism=1 --python-executables=python3 
--modules pyspark-connect
+          ./python/run-tests --parallelism=1 --python-executables=python3 
--modules pyspark-connect,pyspark-ml-connect
       - name: Upload test results to report
         if: always()
         uses: actions/upload-artifact@v4
diff --git a/python/packaging/connect/setup.py 
b/python/packaging/connect/setup.py
index 3514e5cdc422..419ed36b4236 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -77,6 +77,7 @@ if "SPARK_TESTING" in os.environ:
         "pyspark.sql.tests.connect.shell",
         "pyspark.sql.tests.pandas",
         "pyspark.sql.tests.streaming",
+        "pyspark.ml.tests.connect",
     ]
 
 try:
diff --git a/python/pyspark/ml/connect/classification.py 
b/python/pyspark/ml/connect/classification.py
index 8d8c6227eac3..fc7b5cda88a2 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -320,7 +320,6 @@ class LogisticRegressionModel(
 
     def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
         import torch
-
         import torch.nn as torch_nn
 
         model_state_dict = self.torch_model.state_dict()
diff --git a/python/pyspark/ml/param/__init__.py 
b/python/pyspark/ml/param/__init__.py
index 345b7f7a5964..f32ead2a580c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -30,8 +30,8 @@ from typing import (
 )
 
 import numpy as np
-from py4j.java_gateway import JavaObject
 
+from pyspark.util import is_remote_only
 from pyspark.ml.linalg import DenseVector, Vector, Matrix
 from pyspark.ml.util import Identifiable
 
@@ -516,9 +516,12 @@ class Params(Identifiable, metaclass=ABCMeta):
         """
         Sets default params.
         """
+        if not is_remote_only():
+            from py4j.java_gateway import JavaObject
+
         for param, value in kwargs.items():
             p = getattr(self, param)
-            if value is not None and not isinstance(value, JavaObject):
+            if value is not None and (is_remote_only() or not 
isinstance(value, JavaObject)):
                 try:
                     value = p.typeConverter(value)
                 except TypeError as e:
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py 
b/python/pyspark/ml/tests/connect/test_connect_classification.py
index ebc1745874d9..8083090523a0 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -17,7 +17,9 @@
 #
 
 import unittest
+import os
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
@@ -33,13 +35,15 @@ if should_test_connect:
 
 
 @unittest.skipIf(
-    not should_test_connect or not have_torch,
-    connect_requirement_message or torch_requirement_message,
+    not should_test_connect or not have_torch or is_remote_only(),
+    connect_requirement_message
+    or torch_requirement_message
+    or "Requires PySpark core library in Spark Connect server",
 )
 class ClassificationTestsOnConnect(ClassificationTestsMixin, 
unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
-            SparkSession.builder.remote("local[2]")
+            
SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", 
"local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", 
"true")
             .getOrCreate()
         )
diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py 
b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 7f3b6bd0198c..359a77bbcb20 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import os
 import unittest
 
 from pyspark.sql import SparkSession
@@ -36,7 +37,9 @@ if should_test_connect:
 )
 class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()
 
     def tearDown(self) -> None:
         self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py 
b/python/pyspark/ml/tests/connect/test_connect_feature.py
index 04b1744c4995..c786ce2f87d0 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import os
 import unittest
 
 from pyspark.sql import SparkSession
@@ -38,7 +39,9 @@ if should_test_connect:
 )
 class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()
 
     def tearDown(self) -> None:
         self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py 
b/python/pyspark/ml/tests/connect/test_connect_function.py
index b38d415e2bb2..f50376110660 100644
--- a/python/pyspark/ml/tests/connect/test_connect_function.py
+++ b/python/pyspark/ml/tests/connect/test_connect_function.py
@@ -17,6 +17,7 @@
 import os
 import unittest
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.dataframe import DataFrame as SDF
 from pyspark.ml import functions as SF
@@ -32,6 +33,7 @@ if should_test_connect:
     from pyspark.ml.connect import functions as CF
 
 
+@unittest.skipIf(is_remote_only(), "Requires JVM access")
 class SparkConnectMLFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, SQLTestUtils):
     """These test cases exercise the interface to the proto plan
     generation but do not call Spark."""
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py 
b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 45d19f2bcdde..4105f593f170 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -15,9 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 import unittest
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
@@ -34,13 +35,15 @@ except ImportError:
 
 
 @unittest.skipIf(
-    not should_test_connect or not have_torch,
-    connect_requirement_message or torch_requirement_message,
+    not should_test_connect or not have_torch or is_remote_only(),
+    connect_requirement_message
+    or torch_requirement_message
+    or "Requires PySpark core library in Spark Connect server",
 )
 class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
-            SparkSession.builder.remote("local[2]")
+            
SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", 
"local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", 
"true")
             .getOrCreate()
         )
diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py 
b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
index 866a3468388d..1cfd2ed229e5 100644
--- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
@@ -16,6 +16,7 @@
 #
 
 import unittest
+import os
 
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
@@ -27,7 +28,9 @@ if should_test_connect:
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
+        self.spark = SparkSession.builder.remote(
+            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
+        ).getOrCreate()
 
     def tearDown(self) -> None:
         self.spark.stop()
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py 
b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index 7b10d91da064..d5fcb93099b6 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -17,7 +17,9 @@
 #
 
 import unittest
+import os
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
@@ -25,11 +27,14 @@ if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_tuning import 
CrossValidatorTestsMixin
 
 
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "Requires PySpark core library in Spark 
Connect server",
+)
 class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, 
unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
-            SparkSession.builder.remote("local[2]")
+            
SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", 
"local[2]"))
             .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", 
"true")
             .getOrCreate()
         )
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
index db9a29804808..dc2642a42d66 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
@@ -21,14 +21,17 @@ import unittest
 
 import numpy as np
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 have_torch = True
+torch_requirement_message = None
 try:
     import torch  # noqa: F401
 except ImportError:
     have_torch = False
+    torch_requirement_message = "No torch found"
 
 if should_test_connect:
     from pyspark.ml.connect.classification import (
@@ -228,7 +231,10 @@ class ClassificationTestsMixin:
 
 
 @unittest.skipIf(
-    not should_test_connect or not have_torch, connect_requirement_message or 
"No torch found"
+    not should_test_connect or not have_torch or is_remote_only(),
+    connect_requirement_message
+    or torch_requirement_message
+    or "pyspark-connect cannot test classic Spark",
 )
 class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index ae01031ff462..11c1f9aeee51 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -20,14 +20,17 @@ import tempfile
 
 import numpy as np
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 have_torcheval = True
+torcheval_requirement_message = None
 try:
     import torcheval  # noqa: F401
 except ImportError:
     have_torcheval = False
+    torcheval_requirement_message = "torcheval is required"
 
 if should_test_connect:
     from pyspark.ml.connect.evaluation import (
@@ -177,8 +180,10 @@ class EvaluationTestsMixin:
 
 
 @unittest.skipIf(
-    not should_test_connect or not have_torcheval,
-    connect_requirement_message or "torcheval is required",
+    not should_test_connect or not have_torcheval or is_remote_only(),
+    connect_requirement_message
+    or torcheval_requirement_message
+    or "pyspark-connect cannot test classic Spark",
 )
 class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index 9565b3a09a5b..4915d4706b87 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -23,6 +23,7 @@ import unittest
 
 import numpy as np
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
@@ -194,7 +195,10 @@ class FeatureTestsMixin:
             assembler2.transform(pandas_df)["out"].tolist()
 
 
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
 class FeatureTests(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
index 104aff17e0b2..692144148af0 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
@@ -21,6 +21,7 @@ import unittest
 
 import numpy as np
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
@@ -167,7 +168,10 @@ class PipelineTestsMixin:
         assert lorv2.getOrDefault(lorv2.maxIter) == 200
 
 
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
 class PipelineTests(PipelineTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
index 7f09eb9f0742..253632a74c97 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_summarizer.py
@@ -20,6 +20,7 @@ import unittest
 
 import numpy as np
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
@@ -62,7 +63,10 @@ class SummarizerTestsMixin:
         assert_dict_allclose(result_local, expected_result)
 
 
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or is_remote_only(),
+    connect_requirement_message or "pyspark-connect cannot test classic Spark",
+)
 class SummarizerTests(SummarizerTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.master("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py 
b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index 7f26788c137f..14f52d75e6d6 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -22,6 +22,7 @@ import sys
 
 import numpy as np
 
+from pyspark.util import is_remote_only
 from pyspark.ml.param import Param, Params
 from pyspark.ml.tuning import ParamGridBuilder
 from pyspark.sql import SparkSession
@@ -29,10 +30,13 @@ from pyspark.sql.functions import rand
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
 have_sklearn = True
+sklearn_requirement_message = None
 try:
     from sklearn.datasets import load_breast_cancer  # noqa: F401
 except ImportError:
     have_sklearn = False
+    sklearn_requirement_message = "No sklearn found"
+
 
 if should_test_connect:
     import pandas as pd
@@ -279,7 +283,10 @@ class CrossValidatorTestsMixin:
 
 
 @unittest.skipIf(
-    not should_test_connect or not have_sklearn, connect_requirement_message 
or "No sklearn found"
+    not should_test_connect or not have_sklearn or is_remote_only(),
+    connect_requirement_message
+    or sklearn_requirement_message
+    or "pyspark-connect cannot test classic Spark",
 )
 class CrossValidatorTests(CrossValidatorTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py 
b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
index 1984efdc6c6e..462fe3822141 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
@@ -17,24 +17,30 @@
 
 import unittest
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
-from pyspark.ml.torch.tests.test_data_loader import 
TorchDistributorDataLoaderUnitTests
 
+torch_requirement_message = None
 have_torch = True
 try:
     import torch  # noqa: F401
 except ImportError:
     have_torch = False
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class 
TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
-    def setUp(self) -> None:
-        self.spark = (
-            SparkSession.builder.remote("local[1]")
-            .config("spark.default.parallelism", "1")
-            .getOrCreate()
-        )
+    torch_requirement_message = "torch is required"
+
+if not is_remote_only():
+    from pyspark.ml.torch.tests.test_data_loader import 
TorchDistributorDataLoaderUnitTests
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or 
"Requires JVM access"
+    )
+    class 
TorchDistributorBaselineUnitTestsOnConnect(TorchDistributorDataLoaderUnitTests):
+        def setUp(self) -> None:
+            self.spark = (
+                SparkSession.builder.remote("local[1]")
+                .config("spark.default.parallelism", "1")
+                .getOrCreate()
+            )
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py 
b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
index 70aa80ba6d11..e40303ae9ce2 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -19,124 +19,134 @@ import os
 import shutil
 import unittest
 
+torch_requirement_message = None
 have_torch = True
 try:
     import torch  # noqa: F401
 except ImportError:
     have_torch = False
+    torch_requirement_message = "torch is required"
 
+from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
-from pyspark.ml.torch.tests.test_distributor import (
-    TorchDistributorBaselineUnitTestsMixin,
-    TorchDistributorLocalUnitTestsMixin,
-    TorchDistributorDistributedUnitTestsMixin,
-    TorchWrapperUnitTestsMixin,
-    set_up_test_dirs,
-    get_local_mode_conf,
-    get_distributed_mode_conf,
-)
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorBaselineUnitTestsOnConnect(
-    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        cls.spark = SparkSession.builder.remote("local[4]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        cls.spark.stop()
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTestsOnConnect(
-    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = 
set_up_test_dirs()
-        builder = SparkSession.builder.appName(cls.__name__)
-        for k, v in get_local_mode_conf().items():
-            builder = builder.config(k, v)
-        builder = builder.config(
-            "spark.driver.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
-        )
-        cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.mnist_dir_path)
-        os.unlink(cls.gpu_discovery_script_file_name)
-        cls.spark.stop()
-
-    def _get_inputs_for_test_local_training_succeeds(self):
-        return [
-            ("0,1,2", 1, True, "0,1,2"),
-            ("0,1,2", 3, True, "0,1,2"),
-            ("0,1,2", 2, False, "0,1,2"),
-            (None, 3, False, "NONE"),
-        ]
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTestsIIOnConnect(
-    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = 
set_up_test_dirs()
-        builder = SparkSession.builder.appName(cls.__name__)
-        for k, v in get_local_mode_conf().items():
-            builder = builder.config(k, v)
-
-        builder = builder.config(
-            "spark.driver.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
-        )
-        cls.spark = builder.remote("local[4]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.mnist_dir_path)
-        os.unlink(cls.gpu_discovery_script_file_name)
-        cls.spark.stop()
-
-    def _get_inputs_for_test_local_training_succeeds(self):
-        return [
-            ("0,1,2", 1, True, "0,1,2"),
-            ("0,1,2", 3, True, "0,1,2"),
-            ("0,1,2", 2, False, "0,1,2"),
-            (None, 3, False, "NONE"),
-        ]
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorDistributedUnitTestsOnConnect(
-    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
-):
-    @classmethod
-    def setUpClass(cls):
-        (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = 
set_up_test_dirs()
-        builder = SparkSession.builder.appName(cls.__name__)
-        for k, v in get_distributed_mode_conf().items():
-            builder = builder.config(k, v)
-
-        builder = builder.config(
-            "spark.worker.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
-        )
-        cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
-
-    @classmethod
-    def tearDownClass(cls):
-        shutil.rmtree(cls.mnist_dir_path)
-        os.unlink(cls.gpu_discovery_script_file_name)
-        cls.spark.stop()
-
-
-@unittest.skipIf(not have_torch, "torch is required")
-class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, 
unittest.TestCase):
-    pass
+
+if not is_remote_only():
+    from pyspark.ml.torch.tests.test_distributor import (
+        TorchDistributorBaselineUnitTestsMixin,
+        TorchDistributorLocalUnitTestsMixin,
+        TorchDistributorDistributedUnitTestsMixin,
+        TorchWrapperUnitTestsMixin,
+        set_up_test_dirs,
+        get_local_mode_conf,
+        get_distributed_mode_conf,
+    )
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or 
"Requires JVM access"
+    )
+    class TorchDistributorBaselineUnitTestsOnConnect(
+        TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            cls.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            cls.spark.stop()
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or 
"Requires JVM access"
+    )
+    class TorchDistributorLocalUnitTestsOnConnect(
+        TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = 
set_up_test_dirs()
+            builder = SparkSession.builder.appName(cls.__name__)
+            for k, v in get_local_mode_conf().items():
+                builder = builder.config(k, v)
+            builder = builder.config(
+                "spark.driver.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
+            )
+            cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            shutil.rmtree(cls.mnist_dir_path)
+            os.unlink(cls.gpu_discovery_script_file_name)
+            cls.spark.stop()
+
+        def _get_inputs_for_test_local_training_succeeds(self):
+            return [
+                ("0,1,2", 1, True, "0,1,2"),
+                ("0,1,2", 3, True, "0,1,2"),
+                ("0,1,2", 2, False, "0,1,2"),
+                (None, 3, False, "NONE"),
+            ]
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or 
"Requires JVM access"
+    )
+    class TorchDistributorLocalUnitTestsIIOnConnect(
+        TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = 
set_up_test_dirs()
+            builder = SparkSession.builder.appName(cls.__name__)
+            for k, v in get_local_mode_conf().items():
+                builder = builder.config(k, v)
+
+            builder = builder.config(
+                "spark.driver.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
+            )
+            cls.spark = builder.remote("local[4]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            shutil.rmtree(cls.mnist_dir_path)
+            os.unlink(cls.gpu_discovery_script_file_name)
+            cls.spark.stop()
+
+        def _get_inputs_for_test_local_training_succeeds(self):
+            return [
+                ("0,1,2", 1, True, "0,1,2"),
+                ("0,1,2", 3, True, "0,1,2"),
+                ("0,1,2", 2, False, "0,1,2"),
+                (None, 3, False, "NONE"),
+            ]
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or 
"Requires JVM access"
+    )
+    class TorchDistributorDistributedUnitTestsOnConnect(
+        TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+    ):
+        @classmethod
+        def setUpClass(cls):
+            (cls.gpu_discovery_script_file_name, cls.mnist_dir_path) = 
set_up_test_dirs()
+            builder = SparkSession.builder.appName(cls.__name__)
+            for k, v in get_distributed_mode_conf().items():
+                builder = builder.config(k, v)
+
+            builder = builder.config(
+                "spark.worker.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
+            )
+            cls.spark = builder.remote("local-cluster[2,2,512]").getOrCreate()
+
+        @classmethod
+        def tearDownClass(cls):
+            shutil.rmtree(cls.mnist_dir_path)
+            os.unlink(cls.gpu_discovery_script_file_name)
+            cls.spark.stop()
+
+    @unittest.skipIf(
+        not have_torch or is_remote_only(), torch_requirement_message or 
"Requires JVM access"
+    )
+    class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, 
unittest.TestCase):
+        pass
 
 
 if __name__ == "__main__":


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to