This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c303b042966b [SPARK-47808][PYTHON][ML][TESTS] Make pyspark.ml.connect tests running without optional dependencies
c303b042966b is described below
commit c303b042966bb3851da6649fc1d7f03de5db20be
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Apr 11 16:42:23 2024 +0900
[SPARK-47808][PYTHON][ML][TESTS] Make pyspark.ml.connect tests running without optional dependencies
### What changes were proposed in this pull request?
This PR makes the `pyspark.ml.connect` tests run without optional dependencies by deferring the optional imports into the functions that need them and skipping the test classes when the dependency is missing.
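For illustration only, here is a minimal sketch of that pattern outside of Spark: the optional dependency (`torch`) is imported lazily inside the function, and the test class is gated with `unittest.skipIf`. The `fit_model` function below is a hypothetical stand-in, not code from this PR.

```python
import unittest

# Probe the optional dependency once at module import time.
have_torch = True
torch_requirement_message = None
try:
    import torch  # noqa: F401
except ImportError:
    have_torch = False
    torch_requirement_message = "torch is required"


def fit_model(values):
    # Deferred import: torch is only needed when the function actually runs,
    # so importing this module never fails when torch is absent.
    import torch

    return int(torch.tensor(values).sum())


@unittest.skipIf(not have_torch, torch_requirement_message)
class FitModelTests(unittest.TestCase):
    def test_fit_model(self):
        self.assertEqual(fit_model([1, 2, 3]), 6)


if __name__ == "__main__":
    unittest.main()
```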
### Why are the changes needed?
Missing optional dependencies should not prevent the test suites from being imported and run.
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
Will be tested together in https://github.com/apache/spark/pull/45941
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45996 from HyukjinKwon/SPARK-47808.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/ml/connect/classification.py | 16 ++++++++++++++--
.../ml/tests/connect/test_connect_classification.py | 4 +++-
python/pyspark/ml/tests/connect/test_connect_feature.py | 13 ++++++++++++-
python/pyspark/ml/tests/connect/test_connect_pipeline.py | 15 +++++++++++++--
4 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/ml/connect/classification.py b/python/pyspark/ml/connect/classification.py
index 8b816f51ca27..8d8c6227eac3 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -17,8 +17,6 @@
from typing import Any, Dict, Union, List, Tuple, Callable, Optional
import math
-import torch
-import torch.nn as torch_nn
import numpy as np
import pandas as pd
@@ -87,6 +85,8 @@ def _train_logistic_regression_model_worker_fn(
seed: int,
) -> Any:
from pyspark.ml.torch.distributor import _get_spark_partition_data_loader
+ import torch
+ import torch.nn as torch_nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed
import torch.optim as optim
@@ -216,6 +216,9 @@ class LogisticRegression(
self._set(**kwargs)
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "LogisticRegressionModel":
+ import torch
+ import torch.nn as torch_nn
+
if isinstance(dataset, pd.DataFrame):
# TODO: support pandas dataframe fitting
raise NotImplementedError("Fitting pandas dataframe is not supported yet.")
@@ -316,6 +319,10 @@ class LogisticRegressionModel(
return output_cols
def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
+ import torch
+
+ import torch.nn as torch_nn
+
model_state_dict = self.torch_model.state_dict()
num_features = self.num_features
num_classes = self.num_classes
@@ -357,6 +364,9 @@ class LogisticRegressionModel(
return self.__class__.__name__ + ".torch"
def _save_core_model(self, path: str) -> None:
+ import torch
+ import torch.nn as torch_nn
+
lor_torch_model = torch_nn.Sequential(
self.torch_model,
torch_nn.Softmax(dim=1),
@@ -364,6 +374,8 @@ class LogisticRegressionModel(
torch.save(lor_torch_model, path)
def _load_core_model(self, path: str) -> None:
+ import torch
+
lor_torch_model = torch.load(path)
self.torch_model = lor_torch_model[0]
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index 1f811c774cbd..ebc1745874d9 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -21,6 +21,7 @@ import unittest
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+torch_requirement_message = "torch is required"
have_torch = True
try:
import torch # noqa: F401
@@ -32,7 +33,8 @@ if should_test_connect:
@unittest.skipIf(
-    not should_test_connect or not have_torch, connect_requirement_message or "torch is required"
+ not should_test_connect or not have_torch,
+ connect_requirement_message or torch_requirement_message,
)
class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index cf450cc743ae..04b1744c4995 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -20,11 +20,22 @@ import unittest
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+have_sklearn = True
+sklearn_requirement_message = None
+try:
+ from sklearn.datasets import load_breast_cancer # noqa: F401
+except ImportError:
+ have_sklearn = False
+ sklearn_requirement_message = "No sklearn found"
+
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin
[email protected](not should_test_connect, connect_requirement_message)
[email protected](
+ not should_test_connect or not have_sklearn,
+ connect_requirement_message or sklearn_requirement_message,
+)
class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 6a895e892397..45d19f2bcdde 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -24,8 +24,19 @@ from pyspark.testing.connectutils import should_test_connect, connect_requiremen
if should_test_connect:
from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin
-
[email protected](not should_test_connect, connect_requirement_message)
+torch_requirement_message = None
+have_torch = True
+try:
+ import torch # noqa: F401
+except ImportError:
+ have_torch = False
+ torch_requirement_message = "torch is required"
+
+
[email protected](
+ not should_test_connect or not have_torch,
+ connect_requirement_message or torch_requirement_message,
+)
class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (