This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a40919912f5c [SPARK-50376][PYTHON][ML][TESTS] Centralize the dependency check in ML tests
a40919912f5c is described below

commit a40919912f5ce7f63fff2907b30e473dd4155227
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Nov 21 14:57:55 2024 +0900

    [SPARK-50376][PYTHON][ML][TESTS] Centralize the dependency check in ML tests
    
    ### What changes were proposed in this pull request?
    Centralize the dependency check in ML tests
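    
    As a rough sketch of the new pattern (the `pyspark.testing.utils` imports
    below are the ones this patch centralizes; the test class is a hypothetical
    example), a test module now reuses the shared flag and skip message instead
    of probing the package locally:
    
    ```python
    import unittest
    
    # Flags and requirement messages centralized in python/pyspark/testing/utils.py
    from pyspark.testing.utils import have_torch, torch_requirement_message
    
    
    @unittest.skipIf(not have_torch, torch_requirement_message)
    class TorchSketchTests(unittest.TestCase):  # hypothetical class, for illustration
        def test_tensor_size(self) -> None:
            import torch  # safe: the whole class is skipped when torch is missing
    
            self.assertEqual(torch.tensor([1.0, 2.0]).numel(), 2)
    ```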
    
    ### Why are the changes needed?
    Deduplicate code. Many ML test modules re-implemented the same try/except import probes for optional dependencies (torch, torcheval, sklearn, deepspeed); the flags and requirement messages now live in python/pyspark/testing/utils.py.
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #48911 from zhengruifeng/py_centralize_ml_dep.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../deepspeed/tests/test_deepspeed_distributor.py  | 11 ++---
 .../tests/connect/test_connect_classification.py   |  8 +---
 .../ml/tests/connect/test_connect_evaluation.py    |  9 +---
 .../ml/tests/connect/test_connect_feature.py       |  9 +---
 .../ml/tests/connect/test_connect_pipeline.py      |  8 +---
 .../ml/tests/connect/test_connect_tuning.py        |  5 +--
 .../connect/test_legacy_mode_classification.py     | 11 ++---
 .../tests/connect/test_legacy_mode_evaluation.py   |  9 +---
 .../ml/tests/connect/test_legacy_mode_feature.py   |  5 +--
 .../ml/tests/connect/test_legacy_mode_pipeline.py  |  5 +--
 .../ml/tests/connect/test_legacy_mode_tuning.py    | 37 ++++++----------
 .../tests/connect/test_parity_torch_data_loader.py |  9 +---
 .../tests/connect/test_parity_torch_distributor.py |  9 +---
 python/pyspark/ml/torch/tests/test_data_loader.py  | 13 ++----
 python/pyspark/ml/torch/tests/test_distributor.py  | 18 +++-----
 python/pyspark/testing/utils.py                    | 49 ++++++++++++++--------
 16 files changed, 71 insertions(+), 144 deletions(-)

diff --git a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
index 590e541c3842..66a9b553cc75 100644
--- a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
@@ -30,12 +30,7 @@ from pyspark.ml.torch.tests.test_distributor import (
     set_up_test_dirs,
     get_distributed_mode_conf,
 )
-
-have_deepspeed = True
-try:
-    import deepspeed  # noqa: F401
-except ImportError:
-    have_deepspeed = False
+from pyspark.testing.utils import have_deepspeed, deepspeed_requirement_message
 
 
 class DeepspeedTorchDistributorUnitTests(unittest.TestCase):
@@ -219,7 +214,7 @@ def _create_pytorch_training_test_file():
 # and inference, the hope is to switch out the training
 # and file for the tests with more realistic testing
 # that use Deepspeed constructs.
[email protected](not have_deepspeed, "deepspeed is required for these tests")
[email protected](not have_deepspeed, deepspeed_requirement_message)
 class DeepspeedTorchDistributorDistributedEndToEnd(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
@@ -259,7 +254,7 @@ class DeepspeedTorchDistributorDistributedEndToEnd(unittest.TestCase):
             dist.run(cp_path, 2, 5)
 
 
[email protected](not have_deepspeed, "deepspeed is required for these tests")
[email protected](not have_deepspeed, deepspeed_requirement_message)
 class DeepspeedDistributorLocalEndToEndTests(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index 910d2d2ec42f..d3e86a3fb9df 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -23,13 +23,7 @@ from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-torch_requirement_message = "torch is required"
-have_torch = True
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 
 @unittest.skipIf(
diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 9acf5ae0ac44..cabd8b5b50df 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -20,19 +20,14 @@ import unittest
 
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_torcheval = True
-try:
-    import torcheval  # noqa: F401
-except ImportError:
-    have_torcheval = False
+from pyspark.testing.utils import have_torcheval, torcheval_requirement_message
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin
 
     @unittest.skipIf(
         not should_test_connect or not have_torcheval,
-        connect_requirement_message or "torcheval is required",
+        connect_requirement_message or torcheval_requirement_message,
     )
     class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
         def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index c1d02050097b..879cbff6d0cc 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -20,14 +20,7 @@ import unittest
 
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_sklearn = True
-sklearn_requirement_message = None
-try:
-    from sklearn.datasets import load_breast_cancer  # noqa: F401
-except ImportError:
-    have_sklearn = False
-    sklearn_requirement_message = "No sklearn found"
+from pyspark.testing.utils import have_sklearn, sklearn_requirement_message
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 7733af7631e9..f8576d0cb09d 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -21,14 +21,8 @@ import unittest
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
-torch_requirement_message = None
-have_torch = True
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
-    torch_requirement_message = "torch is required"
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index fee7113e1ae5..d737dd5767db 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -19,13 +19,10 @@
 import unittest
 import os
 
-from pyspark.ml.tests.connect.test_connect_classification import (
-    have_torch,
-    torch_requirement_message,
-)
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
index dc2642a42d66..fdae31077002 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
@@ -24,14 +24,7 @@ import numpy as np
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_torch = True
-torch_requirement_message = None
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
-    torch_requirement_message = "No torch found"
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 if should_test_connect:
     from pyspark.ml.connect.classification import (
@@ -135,6 +128,8 @@ class ClassificationTestsMixin:
         self._check_result(local_transform_result, expected_predictions, expected_probabilities)
 
     def test_save_load(self):
+        import torch
+
         with tempfile.TemporaryDirectory(prefix="test_save_load") as tmp_dir:
             estimator = LORV2(maxIter=2, numTrainWorkers=2, learningRate=0.001)
             local_path = os.path.join(tmp_dir, "estimator")
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index 11c1f9aeee51..3a5417dadf50 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -23,14 +23,7 @@ import numpy as np
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_torcheval = True
-torcheval_requirement_message = None
-try:
-    import torcheval  # noqa: F401
-except ImportError:
-    have_torcheval = False
-    torcheval_requirement_message = "torcheval is required"
+from pyspark.testing.utils import have_torcheval, torcheval_requirement_message
 
 if should_test_connect:
     from pyspark.ml.connect.evaluation import (
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index d90e4a4315d5..6812db778450 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -26,10 +26,7 @@ import numpy as np
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-from pyspark.ml.tests.connect.test_connect_classification import (
-    have_torch,
-    torch_requirement_message,
-)
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 if should_test_connect:
     from pyspark.ml.connect.feature import (
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
index 9165034718d7..8b19f5931d20 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
@@ -24,10 +24,7 @@ import numpy as np
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-from pyspark.ml.tests.connect.test_connect_classification import (
-    have_torch,
-    torch_requirement_message,
-)
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 if should_test_connect:
     from pyspark.ml.connect.feature import StandardScaler
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index 33cc39d3319d..06c3ad93d92d 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -28,31 +28,14 @@ from pyspark.ml.tuning import ParamGridBuilder
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import rand
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_sklearn = True
-sklearn_requirement_message = None
-try:
-    from sklearn.datasets import load_breast_cancer  # noqa: F401
-except ImportError:
-    have_sklearn = False
-    sklearn_requirement_message = "No sklearn found"
-
-have_torch = True
-torch_requirement_message = None
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
-    torch_requirement_message = "torch is required"
-
-have_torcheval = True
-torcheval_requirement_message = None
-try:
-    import torcheval  # noqa: F401
-except ImportError:
-    have_torcheval = False
-    torcheval_requirement_message = "torcheval is required"
-
+from pyspark.testing.utils import (
+    have_sklearn,
+    sklearn_requirement_message,
+    have_torch,
+    torch_requirement_message,
+    have_torcheval,
+    torcheval_requirement_message,
+)
 
 if should_test_connect:
     import pandas as pd
@@ -205,6 +188,8 @@ class CrossValidatorTestsMixin:
             )
 
     def test_crossvalidator_on_pipeline(self):
+        from sklearn.datasets import load_breast_cancer
+
         sk_dataset = load_breast_cancer()
 
         train_dataset = self.spark.createDataFrame(
@@ -270,6 +255,8 @@ class CrossValidatorTestsMixin:
         sys.version_info > (3, 12), "SPARK-46078: Fails with dev torch with Python 3.12"
     )
     def test_crossvalidator_with_fold_col(self):
+        from sklearn.datasets import load_breast_cancer
+
         sk_dataset = load_breast_cancer()
 
         train_dataset = self.spark.createDataFrame(
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
index 462fe3822141..de05927138d4 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
@@ -19,14 +19,7 @@ import unittest
 
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
-
-torch_requirement_message = None
-have_torch = True
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
-    torch_requirement_message = "torch is required"
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 if not is_remote_only():
     from pyspark.ml.torch.tests.test_data_loader import TorchDistributorDataLoaderUnitTests
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
index e40303ae9ce2..3cd8abfc6e4e 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -19,16 +19,9 @@ import os
 import shutil
 import unittest
 
-torch_requirement_message = None
-have_torch = True
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
-    torch_requirement_message = "torch is required"
-
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 if not is_remote_only():
     from pyspark.ml.torch.tests.test_distributor import (
diff --git a/python/pyspark/ml/torch/tests/test_data_loader.py b/python/pyspark/ml/torch/tests/test_data_loader.py
index 00f5f0a8c8d8..a47a5f163b68 100644
--- a/python/pyspark/ml/torch/tests/test_data_loader.py
+++ b/python/pyspark/ml/torch/tests/test_data_loader.py
@@ -17,23 +17,16 @@
 
 import unittest
 
-import numpy as np
-
-have_torch = True
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
-
 from pyspark.ml.torch.distributor import (
     TorchDistributor,
     _get_spark_partition_data_loader,
 )
 from pyspark.sql import SparkSession
 from pyspark.ml.linalg import Vectors
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
 class TorchDistributorDataLoaderUnitTests(unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (
@@ -46,6 +39,8 @@ class TorchDistributorDataLoaderUnitTests(unittest.TestCase):
         self.spark.stop()
 
     def _check_data_loader_result_correctness(self, result, expected):
+        import numpy as np
+
         assert len(result) == len(expected)
 
         for res_row, exp_row in zip(result, expected):
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index d16e60588482..7b6a93afbff7 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -29,17 +29,11 @@ from typing import Callable, Dict, Any
 import unittest
 from unittest.mock import patch
 
-have_torch = True
-try:
-    import torch  # noqa: F401
-except ImportError:
-    have_torch = False
-
 from pyspark import SparkConf, SparkContext
 from pyspark.ml.torch.distributor import TorchDistributor, _get_gpus_owned
 from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
 from pyspark.sql import SparkSession
-from pyspark.testing.utils import SPARK_HOME
+from pyspark.testing.utils import SPARK_HOME, have_torch, torch_requirement_message
 
 
 @contextlib.contextmanager
@@ -312,7 +306,7 @@ class TorchDistributorBaselineUnitTestsMixin:
         self.delete_env_vars(input_env_vars)
 
 
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
 class TorchDistributorBaselineUnitTests(TorchDistributorBaselineUnitTestsMixin, unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -409,7 +403,7 @@ class TorchDistributorLocalUnitTestsMixin:
         self.assertEqual(output, "success" * 4096)
 
 
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
 class TorchDistributorLocalUnitTests(TorchDistributorLocalUnitTestsMixin, unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -431,7 +425,7 @@ class TorchDistributorLocalUnitTests(TorchDistributorLocalUnitTestsMixin, unitte
         cls.spark.stop()
 
 
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
 class TorchDistributorLocalUnitTestsII(TorchDistributorLocalUnitTestsMixin, unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -502,7 +496,7 @@ class TorchDistributorDistributedUnitTestsMixin:
         self.assertEqual(output, "success" * 4096)
 
 
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
 class TorchDistributorDistributedUnitTests(
     TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
 ):
@@ -549,7 +543,7 @@ class TorchWrapperUnitTestsMixin:
         self.assertEqual(mock_clean_and_terminate.call_count, 0)
 
 
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
 class TorchWrapperUnitTests(TorchWrapperUnitTestsMixin, unittest.TestCase):
     pass
 
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 1773cdcf0a0a..ca16628fc56f 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -35,23 +35,6 @@ from typing import (
 )
 from itertools import zip_longest
 
-have_scipy = False
-have_numpy = False
-try:
-    import scipy  # noqa: F401
-
-    have_scipy = True
-except ImportError:
-    # No SciPy, but that's okay, we'll skip those tests
-    pass
-try:
-    import numpy as np  # noqa: F401
-
-    have_numpy = True
-except ImportError:
-    # No NumPy, but that's okay, we'll skip those tests
-    pass
-
 from pyspark import SparkConf
 from pyspark.errors import PySparkAssertionError, PySparkException, PySparkTypeError
 from pyspark.errors.exceptions.captured import CapturedException
@@ -68,6 +51,38 @@ __all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
 SPARK_HOME = _find_spark_home()
 
 
+def have_package(name: str) -> bool:
+    try:
+        import importlib
+
+        importlib.import_module(name)
+        return True
+    except Exception:
+        return False
+
+
+have_numpy = have_package("numpy")
+numpy_requirement_message = None if have_numpy else "No module named 'numpy'"
+
+have_scipy = have_package("scipy")
+scipy_requirement_message = None if have_scipy else "No module named 'scipy'"
+
+have_sklearn = have_package("sklearn")
+sklearn_requirement_message = None if have_sklearn else "No module named 'sklearn'"
+
+have_torch = have_package("torch")
+torch_requirement_message = None if have_torch else "No module named 'torch'"
+
+have_torcheval = have_package("torcheval")
+torcheval_requirement_message = None if have_torcheval else "No module named 'torcheval'"
+
+have_deepspeed = have_package("deepspeed")
+deepspeed_requirement_message = None if have_deepspeed else "No module named 'deepspeed'"
+
+have_plotly = have_package("plotly")
+plotly_requirement_message = None if have_plotly else "No module named 'plotly'"
+
+
 def read_int(b):
     return struct.unpack("!i", b)[0]
 

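A quick usage sketch of the `have_package` helper added above (a minimal
example; the module names are illustrative only):

```python
from pyspark.testing.utils import have_package

have_package("math")            # True: the module imports without error
have_package("no_such_module")  # False: any import-time failure is swallowed
```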

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
