This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a40919912f5c [SPARK-50376][PYTHON][ML][TESTS] Centralize the dependency check in ML tests
a40919912f5c is described below
commit a40919912f5ce7f63fff2907b30e473dd4155227
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Nov 21 14:57:55 2024 +0900
[SPARK-50376][PYTHON][ML][TESTS] Centralize the dependency check in ML tests
### What changes were proposed in this pull request?
Centralize the optional dependency checks (torch, torcheval, sklearn, deepspeed, etc.) in `pyspark.testing.utils`, so ML test modules import shared `have_*` flags and `*_requirement_message` strings instead of re-declaring their own try/except import checks.
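A minimal sketch of the resulting pattern, using the shared names this diff adds to `pyspark.testing.utils` (the test class and its body are hypothetical, for illustration only):

```python
import unittest

# Shared flag and skip message centralized by this change; each
# *_requirement_message is None when the package is available.
from pyspark.testing.utils import have_torch, torch_requirement_message


@unittest.skipIf(not have_torch, torch_requirement_message)
class MyTorchTests(unittest.TestCase):  # hypothetical example class
    def test_tensor_roundtrip(self) -> None:
        # Import torch lazily, inside the guarded test body.
        import torch

        self.assertEqual(torch.tensor([1]).item(), 1)
```

The removed hunks below show the per-module try/except blocks this replaces.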
### Why are the changes needed?
Deduplicate code: the same try/except import-check boilerplate was copied across many ML test modules, with inconsistent flag names and skip messages.
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #48911 from zhengruifeng/py_centralize_ml_dep.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../deepspeed/tests/test_deepspeed_distributor.py | 11 ++---
.../tests/connect/test_connect_classification.py | 8 +---
.../ml/tests/connect/test_connect_evaluation.py | 9 +---
.../ml/tests/connect/test_connect_feature.py | 9 +---
.../ml/tests/connect/test_connect_pipeline.py | 8 +---
.../ml/tests/connect/test_connect_tuning.py | 5 +--
.../connect/test_legacy_mode_classification.py | 11 ++---
.../tests/connect/test_legacy_mode_evaluation.py | 9 +---
.../ml/tests/connect/test_legacy_mode_feature.py | 5 +--
.../ml/tests/connect/test_legacy_mode_pipeline.py | 5 +--
.../ml/tests/connect/test_legacy_mode_tuning.py | 37 ++++++----------
.../tests/connect/test_parity_torch_data_loader.py | 9 +---
.../tests/connect/test_parity_torch_distributor.py | 9 +---
python/pyspark/ml/torch/tests/test_data_loader.py | 13 ++----
python/pyspark/ml/torch/tests/test_distributor.py | 18 +++-----
python/pyspark/testing/utils.py | 49 ++++++++++++++--------
16 files changed, 71 insertions(+), 144 deletions(-)
diff --git a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
index 590e541c3842..66a9b553cc75 100644
--- a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
@@ -30,12 +30,7 @@ from pyspark.ml.torch.tests.test_distributor import (
set_up_test_dirs,
get_distributed_mode_conf,
)
-
-have_deepspeed = True
-try:
- import deepspeed # noqa: F401
-except ImportError:
- have_deepspeed = False
+from pyspark.testing.utils import have_deepspeed, deepspeed_requirement_message
class DeepspeedTorchDistributorUnitTests(unittest.TestCase):
@@ -219,7 +214,7 @@ def _create_pytorch_training_test_file():
# and inference, the hope is to switch out the training
# and file for the tests with more realistic testing
# that use Deepspeed constructs.
[email protected](not have_deepspeed, "deepspeed is required for these tests")
[email protected](not have_deepspeed, deepspeed_requirement_message)
class DeepspeedTorchDistributorDistributedEndToEnd(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
@@ -259,7 +254,7 @@ class DeepspeedTorchDistributorDistributedEndToEnd(unittest.TestCase):
dist.run(cp_path, 2, 5)
[email protected](not have_deepspeed, "deepspeed is required for these tests")
[email protected](not have_deepspeed, deepspeed_requirement_message)
class DeepspeedDistributorLocalEndToEndTests(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index 910d2d2ec42f..d3e86a3fb9df 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -23,13 +23,7 @@ from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-torch_requirement_message = "torch is required"
-have_torch = True
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
+from pyspark.testing.utils import have_torch, torch_requirement_message
@unittest.skipIf(
diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 9acf5ae0ac44..cabd8b5b50df 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -20,19 +20,14 @@ import unittest
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_torcheval = True
-try:
- import torcheval # noqa: F401
-except ImportError:
- have_torcheval = False
+from pyspark.testing.utils import have_torcheval, torcheval_requirement_message
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin
@unittest.skipIf(
not should_test_connect or not have_torcheval,
- connect_requirement_message or "torcheval is required",
+ connect_requirement_message or torcheval_requirement_message,
)
class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index c1d02050097b..879cbff6d0cc 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -20,14 +20,7 @@ import unittest
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_sklearn = True
-sklearn_requirement_message = None
-try:
- from sklearn.datasets import load_breast_cancer # noqa: F401
-except ImportError:
- have_sklearn = False
- sklearn_requirement_message = "No sklearn found"
+from pyspark.testing.utils import have_sklearn, sklearn_requirement_message
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 7733af7631e9..f8576d0cb09d 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -21,14 +21,8 @@ import unittest
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.utils import have_torch, torch_requirement_message
-torch_requirement_message = None
-have_torch = True
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
- torch_requirement_message = "torch is required"
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index fee7113e1ae5..d737dd5767db 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -19,13 +19,10 @@
import unittest
import os
-from pyspark.ml.tests.connect.test_connect_classification import (
- have_torch,
- torch_requirement_message,
-)
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.utils import have_torch, torch_requirement_message
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
index dc2642a42d66..fdae31077002 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_classification.py
@@ -24,14 +24,7 @@ import numpy as np
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_torch = True
-torch_requirement_message = None
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
- torch_requirement_message = "No torch found"
+from pyspark.testing.utils import have_torch, torch_requirement_message
if should_test_connect:
from pyspark.ml.connect.classification import (
@@ -135,6 +128,8 @@ class ClassificationTestsMixin:
self._check_result(local_transform_result, expected_predictions, expected_probabilities)
def test_save_load(self):
+ import torch
+
with tempfile.TemporaryDirectory(prefix="test_save_load") as tmp_dir:
estimator = LORV2(maxIter=2, numTrainWorkers=2, learningRate=0.001)
local_path = os.path.join(tmp_dir, "estimator")
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index 11c1f9aeee51..3a5417dadf50 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -23,14 +23,7 @@ import numpy as np
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_torcheval = True
-torcheval_requirement_message = None
-try:
- import torcheval # noqa: F401
-except ImportError:
- have_torcheval = False
- torcheval_requirement_message = "torcheval is required"
+from pyspark.testing.utils import have_torcheval, torcheval_requirement_message
if should_test_connect:
from pyspark.ml.connect.evaluation import (
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index d90e4a4315d5..6812db778450 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -26,10 +26,7 @@ import numpy as np
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-from pyspark.ml.tests.connect.test_connect_classification import (
- have_torch,
- torch_requirement_message,
-)
+from pyspark.testing.utils import have_torch, torch_requirement_message
if should_test_connect:
from pyspark.ml.connect.feature import (
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
index 9165034718d7..8b19f5931d20 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_pipeline.py
@@ -24,10 +24,7 @@ import numpy as np
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-from pyspark.ml.tests.connect.test_connect_classification import (
- have_torch,
- torch_requirement_message,
-)
+from pyspark.testing.utils import have_torch, torch_requirement_message
if should_test_connect:
from pyspark.ml.connect.feature import StandardScaler
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
index 33cc39d3319d..06c3ad93d92d 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_tuning.py
@@ -28,31 +28,14 @@ from pyspark.ml.tuning import ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
-
-have_sklearn = True
-sklearn_requirement_message = None
-try:
- from sklearn.datasets import load_breast_cancer # noqa: F401
-except ImportError:
- have_sklearn = False
- sklearn_requirement_message = "No sklearn found"
-
-have_torch = True
-torch_requirement_message = None
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
- torch_requirement_message = "torch is required"
-
-have_torcheval = True
-torcheval_requirement_message = None
-try:
- import torcheval # noqa: F401
-except ImportError:
- have_torcheval = False
- torcheval_requirement_message = "torcheval is required"
-
+from pyspark.testing.utils import (
+ have_sklearn,
+ sklearn_requirement_message,
+ have_torch,
+ torch_requirement_message,
+ have_torcheval,
+ torcheval_requirement_message,
+)
if should_test_connect:
import pandas as pd
@@ -205,6 +188,8 @@ class CrossValidatorTestsMixin:
)
def test_crossvalidator_on_pipeline(self):
+ from sklearn.datasets import load_breast_cancer
+
sk_dataset = load_breast_cancer()
train_dataset = self.spark.createDataFrame(
@@ -270,6 +255,8 @@ class CrossValidatorTestsMixin:
sys.version_info > (3, 12), "SPARK-46078: Fails with dev torch with Python 3.12"
)
def test_crossvalidator_with_fold_col(self):
+ from sklearn.datasets import load_breast_cancer
+
sk_dataset = load_breast_cancer()
train_dataset = self.spark.createDataFrame(
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
index 462fe3822141..de05927138d4 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py
@@ -19,14 +19,7 @@ import unittest
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
-
-torch_requirement_message = None
-have_torch = True
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
- torch_requirement_message = "torch is required"
+from pyspark.testing.utils import have_torch, torch_requirement_message
if not is_remote_only():
from pyspark.ml.torch.tests.test_data_loader import TorchDistributorDataLoaderUnitTests
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
index e40303ae9ce2..3cd8abfc6e4e 100644
--- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -19,16 +19,9 @@ import os
import shutil
import unittest
-torch_requirement_message = None
-have_torch = True
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
- torch_requirement_message = "torch is required"
-
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
+from pyspark.testing.utils import have_torch, torch_requirement_message
if not is_remote_only():
from pyspark.ml.torch.tests.test_distributor import (
diff --git a/python/pyspark/ml/torch/tests/test_data_loader.py b/python/pyspark/ml/torch/tests/test_data_loader.py
index 00f5f0a8c8d8..a47a5f163b68 100644
--- a/python/pyspark/ml/torch/tests/test_data_loader.py
+++ b/python/pyspark/ml/torch/tests/test_data_loader.py
@@ -17,23 +17,16 @@
import unittest
-import numpy as np
-
-have_torch = True
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
-
from pyspark.ml.torch.distributor import (
TorchDistributor,
_get_spark_partition_data_loader,
)
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
+from pyspark.testing.utils import have_torch, torch_requirement_message
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
class TorchDistributorDataLoaderUnitTests(unittest.TestCase):
def setUp(self) -> None:
self.spark = (
@@ -46,6 +39,8 @@ class TorchDistributorDataLoaderUnitTests(unittest.TestCase):
self.spark.stop()
def _check_data_loader_result_correctness(self, result, expected):
+ import numpy as np
+
assert len(result) == len(expected)
for res_row, exp_row in zip(result, expected):
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index d16e60588482..7b6a93afbff7 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -29,17 +29,11 @@ from typing import Callable, Dict, Any
import unittest
from unittest.mock import patch
-have_torch = True
-try:
- import torch # noqa: F401
-except ImportError:
- have_torch = False
-
from pyspark import SparkConf, SparkContext
from pyspark.ml.torch.distributor import TorchDistributor, _get_gpus_owned
from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
from pyspark.sql import SparkSession
-from pyspark.testing.utils import SPARK_HOME
+from pyspark.testing.utils import SPARK_HOME, have_torch, torch_requirement_message
@contextlib.contextmanager
@@ -312,7 +306,7 @@ class TorchDistributorBaselineUnitTestsMixin:
self.delete_env_vars(input_env_vars)
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
class TorchDistributorBaselineUnitTests(TorchDistributorBaselineUnitTestsMixin, unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -409,7 +403,7 @@ class TorchDistributorLocalUnitTestsMixin:
self.assertEqual(output, "success" * 4096)
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
class TorchDistributorLocalUnitTests(TorchDistributorLocalUnitTestsMixin, unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -431,7 +425,7 @@ class TorchDistributorLocalUnitTests(TorchDistributorLocalUnitTestsMixin, unittest.TestCase):
cls.spark.stop()
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
class TorchDistributorLocalUnitTestsII(TorchDistributorLocalUnitTestsMixin, unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -502,7 +496,7 @@ class TorchDistributorDistributedUnitTestsMixin:
self.assertEqual(output, "success" * 4096)
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
class TorchDistributorDistributedUnitTests(
TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
):
@@ -549,7 +543,7 @@ class TorchWrapperUnitTestsMixin:
self.assertEqual(mock_clean_and_terminate.call_count, 0)
[email protected](not have_torch, "torch is required")
[email protected](not have_torch, torch_requirement_message)
class TorchWrapperUnitTests(TorchWrapperUnitTestsMixin, unittest.TestCase):
pass
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 1773cdcf0a0a..ca16628fc56f 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -35,23 +35,6 @@ from typing import (
)
from itertools import zip_longest
-have_scipy = False
-have_numpy = False
-try:
- import scipy # noqa: F401
-
- have_scipy = True
-except ImportError:
- # No SciPy, but that's okay, we'll skip those tests
- pass
-try:
- import numpy as np # noqa: F401
-
- have_numpy = True
-except ImportError:
- # No NumPy, but that's okay, we'll skip those tests
- pass
-
from pyspark import SparkConf
from pyspark.errors import PySparkAssertionError, PySparkException, PySparkTypeError
from pyspark.errors.exceptions.captured import CapturedException
@@ -68,6 +51,38 @@ __all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
SPARK_HOME = _find_spark_home()
+def have_package(name: str) -> bool:
+ try:
+ import importlib
+
+ importlib.import_module(name)
+ return True
+ except Exception:
+ return False
+
+
+have_numpy = have_package("numpy")
+numpy_requirement_message = None if have_numpy else "No module named 'numpy'"
+
+have_scipy = have_package("scipy")
+scipy_requirement_message = None if have_scipy else "No module named 'scipy'"
+
+have_sklearn = have_package("sklearn")
+sklearn_requirement_message = None if have_sklearn else "No module named
'sklearn'"
+
+have_torch = have_package("torch")
+torch_requirement_message = None if have_torch else "No module named 'torch'"
+
+have_torcheval = have_package("torcheval")
+torcheval_requirement_message = None if have_torcheval else "No module named
'torcheval'"
+
+have_deepspeed = have_package("deepspeed")
+deepspeed_requirement_message = None if have_deepspeed else "No module named
'deepspeed'"
+
+have_plotly = have_package("plotly")
+plotly_requirement_message = None if have_plotly else "No module named
'plotly'"
+
+
def read_int(b):
return struct.unpack("!i", b)[0]
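
For reference, a minimal sketch of how the new `have_package` helper and its flag/message pairs behave (an illustrative snippet, not part of the commit; the uninstalled package name is hypothetical):

```python
from pyspark.testing.utils import have_package, have_plotly, plotly_requirement_message

# have_package returns False instead of raising when the import fails.
assert have_package("os") is True
assert have_package("surely_not_installed_pkg") is False  # hypothetical name

# Each flag pairs with a message that is None when the package is present,
# which matches the unittest.skipIf(condition, reason) calling convention.
if not have_plotly:
    print(plotly_requirement_message)  # -> No module named 'plotly'
```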
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]