This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 21b6693b8ea3 [SPARK-50940][ML][PYTHON][CONNECT] Adds support for CrossValidator/CrossValidatorModel on connect
21b6693b8ea3 is described below
commit 21b6693b8ea3c6732bbc71b63c629849a20fed84
Author: Bobby Wang <[email protected]>
AuthorDate: Sun Jan 26 14:41:15 2025 +0800
[SPARK-50940][ML][PYTHON][CONNECT] Adds support for CrossValidator/CrossValidatorModel on connect
### What changes were proposed in this pull request?
Support CrossValidator/CrossValidatorModel on Spark Connect, so that fitting, transforming, and save/load work through a Connect session.
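For context, a minimal sketch of the round trip this enables over Spark Connect; the Connect URL and save path below are illustrative, not part of this patch:

```python
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel

# Hypothetical Connect endpoint; any sc:// remote works the same way.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.createDataFrame(
    [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"],
)
lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                    evaluator=BinaryClassificationEvaluator())

model = cv.fit(df)  # fitting now runs through the Connect session

# write()/read() are routed to the remote writer/reader added in this PR
model.write().save("/tmp/cv_model")  # illustrative path
loaded = CrossValidatorModel.load("/tmp/cv_model")
```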
### Why are the changes needed?
For feature parity between Spark Connect and Spark Classic ML.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
The newly added tests pass.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49644 from wbo4958/cv.
Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 762599cc64df36ebff2b794dd6c77f48c428749f)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
python/pyspark/ml/connect/readwrite.py | 48 ++++++++++-
.../pyspark/ml/tests/connect/test_parity_tuning.py | 56 +++++++++++++
python/pyspark/ml/tests/test_tuning.py | 98 ++++++++++++++++++++++
python/pyspark/ml/tuning.py | 32 ++++---
python/pyspark/ml/util.py | 3 +-
6 files changed, 224 insertions(+), 15 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f76427f79761..8e518b0febe3 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -675,6 +675,7 @@ pyspark_ml = Module(
"pyspark.ml.tests.test_param",
"pyspark.ml.tests.test_persistence",
"pyspark.ml.tests.test_pipeline",
+ "pyspark.ml.tests.test_tuning",
"pyspark.ml.tests.test_stat",
"pyspark.ml.tests.test_training_summary",
"pyspark.ml.tests.tuning.test_tuning",
@@ -1127,6 +1128,7 @@ pyspark_ml_connect = Module(
"pyspark.ml.tests.connect.test_parity_evaluation",
"pyspark.ml.tests.connect.test_parity_feature",
"pyspark.ml.tests.connect.test_parity_pipeline",
+ "pyspark.ml.tests.connect.test_parity_tuning",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy,
pandas, and pyarrow and
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index eea966270871..3bf2031538d9 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -14,12 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any
+from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
import pyspark.sql.connect.proto as pb2
from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
+from pyspark.ml.tuning import CrossValidatorModelWriter, CrossValidatorModel
from pyspark.ml.util import MLWriter, MLReader, RL
from pyspark.ml.wrapper import JavaWrapper
@@ -29,6 +29,19 @@ if TYPE_CHECKING:
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+class RemoteCrossValidatorModelWriter(CrossValidatorModelWriter):
+ def __init__(
+ self,
+ instance: "CrossValidatorModel",
+ optionMap: Dict[str, Any] = {},
+ session: Optional["SparkSession"] = None,
+ ):
+ super(RemoteCrossValidatorModelWriter, self).__init__(instance)
+ self.instance = instance
+ self.optionMap = optionMap
+ self.session(session) # type: ignore[arg-type]
+
+
class RemoteMLWriter(MLWriter):
def __init__(self, instance: "JavaMLWritable") -> None:
super().__init__()
@@ -63,6 +76,7 @@ class RemoteMLWriter(MLWriter):
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
+ from pyspark.ml.tuning import CrossValidator
# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -126,6 +140,21 @@ class RemoteMLWriter(MLWriter):
path,
)
+ elif isinstance(instance, CrossValidator):
+ from pyspark.ml.tuning import CrossValidatorWriter
+
+ if shouldOverwrite:
+ # TODO(SPARK-50954): Support client side model path overwrite
+ warnings.warn("Overwrite doesn't take effect for
CrossValidator")
+ cv_writer = CrossValidatorWriter(instance)
+ cv_writer.session(session) # type: ignore[arg-type]
+ cv_writer.save(path)
+ elif isinstance(instance, CrossValidatorModel):
+ if shouldOverwrite:
+ # TODO(SPARK-50954): Support client side model path overwrite
+ warnings.warn("Overwrite doesn't take effect for
CrossValidatorModel")
+ cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
+ cvm_writer.save(path)
else:
raise NotImplementedError(f"Unsupported write for {instance.__class__}")
@@ -153,6 +182,7 @@ class RemoteMLReader(MLReader[RL]):
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
+ from pyspark.ml.tuning import CrossValidator
if (
issubclass(clazz, JavaModel)
@@ -217,5 +247,19 @@ class RemoteMLReader(MLReader[RL]):
else:
return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
+ elif issubclass(clazz, CrossValidator):
+ from pyspark.ml.tuning import CrossValidatorReader
+
+ cv_reader = CrossValidatorReader(CrossValidator)
+ cv_reader.session(session)
+ return cv_reader.load(path)
+
+ elif issubclass(clazz, CrossValidatorModel):
+ from pyspark.ml.tuning import CrossValidatorModelReader
+
+ cvm_reader = CrossValidatorModelReader(CrossValidator)
+ cvm_reader.session(session)
+ return cvm_reader.load(path)
+
else:
raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/connect/test_parity_tuning.py b/python/pyspark/ml/tests/connect/test_parity_tuning.py
new file mode 100644
index 000000000000..473661d6c288
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_tuning.py
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.connect.readwrite import RemoteCrossValidatorModelWriter
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.tests.test_tuning import TuningTestsMixin
+from pyspark.ml.tuning import CrossValidatorModel
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
+ def test_remote_cross_validator_model_writer(self):
+ df = self.spark.createDataFrame(
+ [
+ (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+ (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+ (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+ (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+ ],
+ ["label", "weight", "features"],
+ )
+
+ lor = LogisticRegression()
+ lor_model = lor.fit(df)
+ cv_model = CrossValidatorModel(lor_model)
+ writer = RemoteCrossValidatorModelWriter(cv_model, {"a": "b"}, self.spark)
+ self.assertEqual(writer.optionMap["a"], "b")
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_parity_tuning import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
new file mode 100644
index 000000000000..2bc0e22c1209
--- /dev/null
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class TuningTestsMixin:
+ def test_cross_validator(self):
+ dataset = self.spark.createDataFrame(
+ [
+ (Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0),
+ ]
+ * 10,
+ ["features", "label"],
+ )
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ cv = CrossValidator(
+ estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=1
+ )
+
+ self.assertEqual(cv.getEstimator(), lr)
+ self.assertEqual(cv.getEvaluator(), evaluator)
+ self.assertEqual(cv.getParallelism(), 1)
+ self.assertEqual(cv.getEstimatorParamMaps(), grid)
+
+ model = cv.fit(dataset)
+ self.assertEqual(model.getEstimator(), lr)
+ self.assertEqual(model.getEvaluator(), evaluator)
+ self.assertEqual(model.getEstimatorParamMaps(), grid)
+ self.assertTrue(np.isclose(model.avgMetrics[0], 0.5, atol=1e-4))
+
+ output = model.transform(dataset)
+ self.assertEqual(
+ output.columns, ["features", "label", "rawPrediction", "probability", "prediction"]
+ )
+ self.assertEqual(output.count(), 50)
+
+ # save & load
+ with tempfile.TemporaryDirectory(prefix="cv_lr") as d:
+ path1 = os.path.join(d, "cv")
+ cv.write().save(path1)
+ cv2 = CrossValidator.load(path1)
+ self.assertEqual(str(cv), str(cv2))
+ self.assertEqual(str(cv.getEstimator()), str(cv2.getEstimator()))
+ self.assertEqual(str(cv.getEvaluator()), str(cv2.getEvaluator()))
+
+ path2 = os.path.join(d, "cv_model")
+ model.write().save(path2)
+ model2 = CrossValidatorModel.load(path2)
+ self.assertEqual(str(model), str(model2))
+ self.assertEqual(str(model.getEstimator()), str(model2.getEstimator()))
+ self.assertEqual(str(model.getEvaluator()), str(model2.getEvaluator()))
+
+
+class TuningTests(TuningTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_tuning import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 695bbf98517c..06d3837e1ae4 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -53,6 +53,8 @@ from pyspark.ml.util import (
MLWriter,
JavaMLReader,
JavaMLWriter,
+ try_remote_write,
+ try_remote_read,
)
from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
@@ -386,7 +388,7 @@ class _ValidatorSharedReadWrite:
def saveImpl(
path: str,
instance: _ValidatorParams,
- sc: "SparkContext",
+ sc: Union["SparkContext", "SparkSession"],
extraMetadata: Optional[Dict[str, Any]] = None,
) -> None:
numParamsNotJson = 0
@@ -430,7 +432,7 @@ class _ValidatorSharedReadWrite:
@staticmethod
def load(
- path: str, sc: "SparkContext", metadata: Dict[str, Any]
+ path: str, sc: Union["SparkContext", "SparkSession"], metadata: Dict[str, Any]
) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]:
evaluatorPath = os.path.join(path, "evaluator")
evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
@@ -513,12 +515,12 @@ class CrossValidatorReader(MLReader["CrossValidator"]):
self.cls = cls
def load(self, path: str) -> "CrossValidator":
- metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
- path, self.sc, metadata
+ path, self.sparkSession, metadata
)
cv = CrossValidator(
estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
@@ -536,7 +538,7 @@ class CrossValidatorWriter(MLWriter):
def saveImpl(self, path: str) -> None:
_ValidatorSharedReadWrite.validateParams(self.instance)
- _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
+ _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sparkSession)
@inherit_doc
@@ -546,16 +548,18 @@ class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
self.cls = cls
def load(self, path: str) -> "CrossValidatorModel":
- metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+ metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
if not DefaultParamsReader.isPythonParamsInstance(metadata):
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
else:
metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
- path, self.sc, metadata
+ path, self.sparkSession, metadata
)
numFolds = metadata["paramMap"]["numFolds"]
bestModelPath = os.path.join(path, "bestModel")
- bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
+ bestModel: Model = DefaultParamsReader.loadParamsInstance(
+ bestModelPath, self.sparkSession
+ )
avgMetrics = metadata["avgMetrics"]
if "stdMetrics" in metadata:
stdMetrics = metadata["stdMetrics"]
@@ -571,7 +575,7 @@ class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
path, "subModels", f"fold{splitIndex}",
f"{paramIndex}"
)
subModels[splitIndex][paramIndex] = DefaultParamsReader.loadParamsInstance(
- modelPath, self.sc
+ modelPath, self.sparkSession
)
else:
subModels = None
@@ -608,7 +612,9 @@ class CrossValidatorModelWriter(MLWriter):
if instance.stdMetrics:
extraMetadata["stdMetrics"] = instance.stdMetrics
- _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
+ _ValidatorSharedReadWrite.saveImpl(
+ path, instance, self.sparkSession, extraMetadata=extraMetadata
+ )
bestModelPath = os.path.join(path, "bestModel")
cast(MLWritable, instance.bestModel).save(bestModelPath)
if persistSubModels:
@@ -845,7 +851,7 @@ class CrossValidator(
train = datasets[i][0].cache()
tasks = map(
- inheritable_thread_target,
+ inheritable_thread_target(dataset.sparkSession),
_parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
)
for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
@@ -939,6 +945,7 @@ class CrossValidator(
return newCV
@since("2.3.0")
+ @try_remote_write
def write(self) -> MLWriter:
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -947,6 +954,7 @@ class CrossValidator(
@classmethod
@since("2.3.0")
+ @try_remote_read
def read(cls) -> CrossValidatorReader:
"""Returns an MLReader instance for this class."""
return CrossValidatorReader(cls)
@@ -1077,6 +1085,7 @@ class CrossValidatorModel(
)
@since("2.3.0")
+ @try_remote_write
def write(self) -> MLWriter:
"""Returns an MLWriter instance for this ML instance."""
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1085,6 +1094,7 @@ class CrossValidatorModel(
@classmethod
@since("2.3.0")
+ @try_remote_read
def read(cls) -> CrossValidatorModelReader:
"""Returns an MLReader instance for this class."""
return CrossValidatorModelReader(cls)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6006d131b5c0..309f8452ac79 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -149,11 +149,10 @@ def try_remote_transform_relation(f: FuncT) -> FuncT:
def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
from pyspark.ml import Model, Transformer
- from pyspark.sql.connect.session import SparkSession
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.ml.connect.serialize import serialize_ml_params
- session = SparkSession.getActiveSession()
+ session = dataset.sparkSession
assert session is not None
# Model is also a Transformer, so we must match Model first
if isinstance(self, Model):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]