This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c303b042966b [SPARK-47808][PYTHON][ML][TESTS] Make pyspark.ml.connect tests running without optional dependencies c303b042966b is described below commit c303b042966bb3851da6649fc1d7f03de5db20be Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Thu Apr 11 16:42:23 2024 +0900 [SPARK-47808][PYTHON][ML][TESTS] Make pyspark.ml.connect tests running without optional dependencies ### What changes were proposed in this pull request? This PR makes `pyspark.ml.connect` tests running without optional dependencies. ### Why are the changes needed? Optional dependencies should not stop the tests. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Will be tested together in https://github.com/apache/spark/pull/45941 ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45996 from HyukjinKwon/SPARK-47808. Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/ml/connect/classification.py | 16 ++++++++++++++-- .../ml/tests/connect/test_connect_classification.py | 4 +++- python/pyspark/ml/tests/connect/test_connect_feature.py | 13 ++++++++++++- python/pyspark/ml/tests/connect/test_connect_pipeline.py | 15 +++++++++++++-- 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py index 8b816f51ca27..8d8c6227eac3 100644 --- a/python/pyspark/ml/connect/classification.py +++ b/python/pyspark/ml/connect/classification.py @@ -17,8 +17,6 @@ from typing import Any, Dict, Union, List, Tuple, Callable, Optional import math -import torch -import torch.nn as torch_nn import numpy as np import pandas as pd @@ -87,6 +85,8 @@ def _train_logistic_regression_model_worker_fn( seed: int, ) -> Any: from pyspark.ml.torch.distributor import _get_spark_partition_data_loader + import torch + import torch.nn as torch_nn from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed import torch.optim as optim @@ -216,6 +216,9 @@ class LogisticRegression( self._set(**kwargs) def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegressionModel": + import torch + import torch.nn as torch_nn + if isinstance(dataset, pd.DataFrame): # TODO: support pandas dataframe fitting raise NotImplementedError("Fitting pandas dataframe is not supported yet.") @@ -316,6 +319,10 @@ class LogisticRegressionModel( return output_cols def _get_transform_fn(self) -> Callable[["pd.Series"], Any]: + import torch + + import torch.nn as torch_nn + model_state_dict = self.torch_model.state_dict() num_features = self.num_features num_classes = self.num_classes @@ -357,6 +364,9 @@ class LogisticRegressionModel( return self.__class__.__name__ + ".torch" def _save_core_model(self, path: str) -> None: + import torch + import torch.nn as torch_nn + lor_torch_model = torch_nn.Sequential( self.torch_model, torch_nn.Softmax(dim=1), @@ -364,6 +374,8 @@ class LogisticRegressionModel( torch.save(lor_torch_model, path) def _load_core_model(self, path: str) -> None: + import torch + lor_torch_model = torch.load(path) self.torch_model = lor_torch_model[0] diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py index 1f811c774cbd..ebc1745874d9 100644 --- a/python/pyspark/ml/tests/connect/test_connect_classification.py +++ b/python/pyspark/ml/tests/connect/test_connect_classification.py @@ -21,6 +21,7 @@ import unittest from pyspark.sql import SparkSession from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +torch_requirement_message = "torch is required" have_torch = True try: import torch # noqa: F401 @@ -32,7 +33,8 @@ if should_test_connect: @unittest.skipIf( - not should_test_connect or not have_torch, connect_requirement_message or "torch is required" + not should_test_connect or not have_torch, + connect_requirement_message or torch_requirement_message, ) class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase): def setUp(self) -> None: diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py index cf450cc743ae..04b1744c4995 100644 --- a/python/pyspark/ml/tests/connect/test_connect_feature.py +++ b/python/pyspark/ml/tests/connect/test_connect_feature.py @@ -20,11 +20,22 @@ import unittest from pyspark.sql import SparkSession from pyspark.testing.connectutils import should_test_connect, connect_requirement_message +have_sklearn = True +sklearn_requirement_message = None +try: + from sklearn.datasets import load_breast_cancer # noqa: F401 +except ImportError: + have_sklearn = False + sklearn_requirement_message = "No sklearn found" + if should_test_connect: from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin -@unittest.skipIf(not should_test_connect, connect_requirement_message) +@unittest.skipIf( + not should_test_connect or not have_sklearn, + connect_requirement_message or sklearn_requirement_message, +) class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = SparkSession.builder.remote("local[2]").getOrCreate() diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py index 6a895e892397..45d19f2bcdde 100644 --- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py +++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py @@ -24,8 +24,19 @@ from pyspark.testing.connectutils import should_test_connect, connect_requiremen if should_test_connect: from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin - -@unittest.skipIf(not should_test_connect, connect_requirement_message) +torch_requirement_message = None +have_torch = True +try: + import torch # noqa: F401 +except ImportError: + have_torch = False + torch_requirement_message = "torch is required" + + +@unittest.skipIf( + not should_test_connect or not have_torch, + connect_requirement_message or torch_requirement_message, +) class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org