This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 68615c4275c6 [SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect
68615c4275c6 is described below
commit 68615c4275c678f7cc72f5caa0fa1eb8e39e89fb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 28 09:40:06 2025 +0900
[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect
### What changes were proposed in this pull request?
Support `OneVsRest` and `OneVsRestModel` on Spark Connect: fitting, transform, and read/write are routed through the Connect client, reusing the plain-Python `OneVsRestWriter`/`OneVsRestReader` (and their model counterparts) instead of the JVM-backed writers.
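As a quick illustration (not part of this patch; the `sc://localhost` endpoint and the toy data are placeholders adapted from the test added below), fitting now works against a Connect session:

```python
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LinearSVC, OneVsRest

# Illustrative endpoint; any Spark Connect server works.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.createDataFrame(
    [
        (1.0, Vectors.dense(0.0, 5.0)),
        (0.0, Vectors.dense(1.0, 2.0)),
        (1.0, Vectors.dense(2.0, 1.0)),
        (2.0, Vectors.dense(3.0, 3.0)),
    ],
    ["label", "features"],
)

# One binary LinearSVC is trained per class; training now runs over Connect.
ovr = OneVsRest(classifier=LinearSVC(maxIter=1, regParam=1.0), parallelism=1)
model = ovr.fit(df)
model.transform(df).select("label", "prediction").show()
```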
### Why are the changes needed?
Feature parity: `OneVsRest` was previously usable only on Spark Classic; this brings it to Spark Connect.
### Does this PR introduce _any_ user-facing change?
Yes. `OneVsRest` and `OneVsRestModel` can now be fit, applied, saved, and loaded on Spark Connect.
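For instance, persistence round-trips on Connect. A minimal sketch mirroring the new test (the path is illustrative; in local testing the temp directory is visible to both client and server):

```python
import os
import tempfile

from pyspark.ml.classification import OneVsRestModel

# Assumes `model` from the sketch above.
with tempfile.TemporaryDirectory(prefix="ovr_demo") as d:
    path = os.path.join(d, "ovr_model")
    # Note: .overwrite() currently warns and is a no-op on Connect (SPARK-50954).
    model.write().overwrite().save(path)
    model2 = OneVsRestModel.load(path)
```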
### How was this patch tested?
New tests: `pyspark.ml.tests.test_ovr` (shared test bodies) and the Connect parity suite `pyspark.ml.tests.connect.test_parity_ovr`.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49693 from zhengruifeng/ml_connect_ovr.
Lead-authored-by: Ruifeng Zheng <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 22bac2eeb5901d48eb10ba36b52e5ccfdaac9d35)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
dev/sparktestsupport/modules.py | 2 +
python/pyspark/ml/classification.py | 59 ++++++---
python/pyspark/ml/connect/readwrite.py | 37 ++++++
python/pyspark/ml/tests/connect/test_parity_ovr.py | 37 ++++++
python/pyspark/ml/tests/test_ovr.py | 135 +++++++++++++++++++++
5 files changed, 251 insertions(+), 19 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 8e518b0febe3..dd7c387cd5f1 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -676,6 +676,7 @@ pyspark_ml = Module(
"pyspark.ml.tests.test_persistence",
"pyspark.ml.tests.test_pipeline",
"pyspark.ml.tests.test_tuning",
+ "pyspark.ml.tests.test_ovr",
"pyspark.ml.tests.test_stat",
"pyspark.ml.tests.test_training_summary",
"pyspark.ml.tests.tuning.test_tuning",
@@ -1129,6 +1130,7 @@ pyspark_ml_connect = Module(
"pyspark.ml.tests.connect.test_parity_feature",
"pyspark.ml.tests.connect.test_parity_pipeline",
"pyspark.ml.tests.connect.test_parity_tuning",
+ "pyspark.ml.tests.connect.test_parity_ovr",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy,
pandas, and pyarrow and
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index d8ed51a82abe..1ed2f1fd4b8a 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -85,6 +85,8 @@ from pyspark.ml.util import (
MLWriter,
MLWritable,
HasTrainingSummary,
+ try_remote_read,
+ try_remote_write,
try_remote_attribute_relation,
)
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -94,6 +96,7 @@ from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel
+from pyspark.sql.utils import is_remote
if TYPE_CHECKING:
from pyspark.ml._typing import P, ParamMap
@@ -3572,31 +3575,45 @@ class OneVsRest(
if handlePersistence:
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
- def trainSingleClass(index: int) -> CM:
- binaryLabelCol = "mc2b$" + str(index)
- trainingDataset = multiclassLabeled.withColumn(
- binaryLabelCol,
- when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
- )
- paramMap = dict(
- [
- (classifier.labelCol, binaryLabelCol),
- (classifier.featuresCol, featuresCol),
- (classifier.predictionCol, predictionCol),
- ]
- )
- if weightCol:
- paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
- return classifier.fit(trainingDataset, paramMap)
+ def _oneClassFitTasks(numClasses: int):
+ indices = iter(range(numClasses))
+
+ def trainSingleClass() -> CM:
+ index = next(indices)
+
+ binaryLabelCol = "mc2b$" + str(index)
+ trainingDataset = multiclassLabeled.withColumn(
+ binaryLabelCol,
+ when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+ )
+ paramMap = dict(
+ [
+ (classifier.labelCol, binaryLabelCol),
+ (classifier.featuresCol, featuresCol),
+ (classifier.predictionCol, predictionCol),
+ ]
+ )
+ if weightCol:
+ paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+ return index, classifier.fit(trainingDataset, paramMap)
+ return [trainSingleClass] * numClasses
+
+ tasks = map(
+ inheritable_thread_target(dataset.sparkSession),
+ _oneClassFitTasks(numClasses),
+ )
pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
- models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
+ subModels = [None] * numClasses
+ for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+ assert subModels is not None
+ subModels[j] = subModel
if handlePersistence:
multiclassLabeled.unpersist()
- return self._copyValues(OneVsRestModel(models=models))
+ return self._copyValues(OneVsRestModel(models=subModels))
def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
"""
@@ -3671,9 +3688,11 @@ class OneVsRest(
return _java_obj
@classmethod
+ @try_remote_read
def read(cls) -> "OneVsRestReader":
return OneVsRestReader(cls)
+ @try_remote_write
def write(self) -> MLWriter:
if isinstance(self.getClassifier(), JavaMLWritable):
return JavaMLWriter(self) # type: ignore[arg-type]
@@ -3787,7 +3806,7 @@ class OneVsRestModel(
from pyspark.core.context import SparkContext
self.models = models
- if not isinstance(models[0], JavaMLWritable):
+ if is_remote() or not isinstance(models[0], JavaMLWritable):
return
# set java instance
java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
@@ -3955,9 +3974,11 @@ class OneVsRestModel(
return _java_obj
@classmethod
+ @try_remote_read
def read(cls) -> "OneVsRestModelReader":
return OneVsRestModelReader(cls)
+ @try_remote_write
def write(self) -> MLWriter:
if all(
map(
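Worth noting in the hunks above: instead of `pool.map` over class indices, fitting is now expressed as a list of zero-argument closures that pull their index from a shared iterator and return `(index, model)` pairs, so `imap_unordered` can consume completions in any order while results still land in the right slot (the `inheritable_thread_target(session)` wrapper additionally propagates the Connect session into the pool threads, and is omitted here). A stripped-down, Spark-free sketch of the same pattern, with a toy stand-in for `classifier.fit`:

```python
from multiprocessing.pool import ThreadPool


def one_class_fit_tasks(num_classes: int):
    indices = iter(range(num_classes))

    def train_single_class():
        index = next(indices)  # each invocation claims the next class index
        return index, f"model-{index}"  # toy stand-in for classifier.fit(...)

    # num_classes references to the same closure; each call consumes one index.
    return [train_single_class] * num_classes


num_classes = 3
pool = ThreadPool(processes=num_classes)
sub_models = [None] * num_classes
# Results may arrive out of order; the returned index restores the ordering.
for j, sub_model in pool.imap_unordered(lambda f: f(), one_class_fit_tasks(num_classes)):
    sub_models[j] = sub_model
pool.close()
print(sub_models)  # ['model-0', 'model-1', 'model-2']
```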
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 6392e988c067..6e364afb7dbf 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -95,6 +95,7 @@ class RemoteMLWriter(MLWriter):
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+ from pyspark.ml.classification import OneVsRest, OneVsRestModel
# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -187,6 +188,27 @@ class RemoteMLWriter(MLWriter):
warnings.warn("Overwrite doesn't take effect for
TrainValidationSplitModel")
tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
tvsm_writer.save(path)
+ elif isinstance(instance, OneVsRest):
+ from pyspark.ml.classification import OneVsRestWriter
+
+ if shouldOverwrite:
+ # TODO(SPARK-50954): Support client side model path overwrite
+ warnings.warn("Overwrite doesn't take effect for OneVsRest")
+
+ writer = OneVsRestWriter(instance)
+ writer.session(session)
+ writer.save(path)
+ # _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sparkSession, path)
+ elif isinstance(instance, OneVsRestModel):
+ from pyspark.ml.classification import OneVsRestModelWriter
+
+ if shouldOverwrite:
+ # TODO(SPARK-50954): Support client side model path overwrite
+ warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
+
+ writer = OneVsRestModelWriter(instance)
+ writer.session(session)
+ writer.save(path)
else:
raise NotImplementedError(f"Unsupported write for {instance.__class__}")
@@ -215,6 +237,7 @@ class RemoteMLReader(MLReader[RL]):
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+ from pyspark.ml.classification import OneVsRest, OneVsRestModel
if (
issubclass(clazz, JavaModel)
@@ -307,5 +330,19 @@ class RemoteMLReader(MLReader[RL]):
tvs_reader.session(session)
return tvs_reader.load(path)
+ elif issubclass(clazz, OneVsRest):
+ from pyspark.ml.classification import OneVsRestReader
+
+ ovr_reader = OneVsRestReader(OneVsRest)
+ ovr_reader.session(session)
+ return ovr_reader.load(path)
+
+ elif issubclass(clazz, OneVsRestModel):
+ from pyspark.ml.classification import OneVsRestModelReader
+
+ ovr_reader = OneVsRestModelReader(OneVsRestModel)
+ ovr_reader.session(session)
+ return ovr_reader.load(path)
+
else:
raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/connect/test_parity_ovr.py b/python/pyspark/ml/tests/connect/test_parity_ovr.py
new file mode 100644
index 000000000000..3ad3fec8cf10
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_ovr.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.tests.test_ovr import OneVsRestTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class OneVsRestParityTests(OneVsRestTestsMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_parity_ovr import * # noqa: F401
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
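The parity module above is deliberately tiny: every test body lives in `OneVsRestTestsMixin` (defined in `test_ovr.py` below), and the concrete class just pairs that mixin with a base supplying `self.spark` — `ReusedSQLTestCase` for Spark Classic, `ReusedConnectTestCase` for Connect. A minimal, Spark-free sketch of the same mixin pattern:

```python
import unittest


class GreetingTestsMixin:
    # Shared test bodies; the concrete base class supplies self.backend.
    def test_greeting(self):
        self.assertEqual(self.backend(), "hello")


class ClassicBase(unittest.TestCase):
    def backend(self):
        return "hello"  # stand-in for a classic SparkSession fixture


class ConnectBase(unittest.TestCase):
    def backend(self):
        return "hello"  # stand-in for a Spark Connect session fixture


class GreetingClassicTests(GreetingTestsMixin, ClassicBase):
    pass


class GreetingConnectTests(GreetingTestsMixin, ConnectBase):
    pass


if __name__ == "__main__":
    unittest.main(verbosity=2)
```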
diff --git a/python/pyspark/ml/tests/test_ovr.py b/python/pyspark/ml/tests/test_ovr.py
new file mode 100644
index 000000000000..aa003be53f5d
--- /dev/null
+++ b/python/pyspark/ml/tests/test_ovr.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import (
+ LinearSVC,
+ LinearSVCModel,
+ OneVsRest,
+ OneVsRestModel,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class OneVsRestTestsMixin:
+ def test_one_vs_rest(self):
+ spark = self.spark
+ df = (
+ spark.createDataFrame(
+ [
+ (0, 1.0, Vectors.dense(0.0, 5.0)),
+ (1, 0.0, Vectors.dense(1.0, 2.0)),
+ (2, 1.0, Vectors.dense(2.0, 1.0)),
+ (3, 2.0, Vectors.dense(3.0, 3.0)),
+ ],
+ ["index", "label", "features"],
+ )
+ .coalesce(1)
+ .sortWithinPartitions("index")
+ .select("label", "features")
+ )
+
+ svc = LinearSVC(maxIter=1, regParam=1.0)
+ self.assertEqual(svc.getMaxIter(), 1)
+ self.assertEqual(svc.getRegParam(), 1.0)
+
+ ovr = OneVsRest(classifier=svc, parallelism=1)
+ self.assertEqual(ovr.getParallelism(), 1)
+
+ model = ovr.fit(df)
+ self.assertIsInstance(model, OneVsRestModel)
+ self.assertEqual(len(model.models), 3)
+ for submodel in model.models:
+ self.assertIsInstance(submodel, LinearSVCModel)
+
+ self.assertTrue(
+ np.allclose(model.models[0].intercept, 0.06279247869226989, atol=1e-4),
+ model.models[0].intercept,
+ )
+ self.assertTrue(
+ np.allclose(
+ model.models[0].coefficients.toArray(),
+ [-0.1198765502306968, -0.1027513287691687],
+ atol=1e-4,
+ ),
+ model.models[0].coefficients,
+ )
+
+ self.assertTrue(
+ np.allclose(model.models[1].intercept, 0.025877458475338313, atol=1e-4),
+ model.models[1].intercept,
+ )
+ self.assertTrue(
+ np.allclose(
+ model.models[1].coefficients.toArray(),
+ [-0.0362284418654736, 0.010350983390135305],
+ atol=1e-4,
+ ),
+ model.models[1].coefficients,
+ )
+
+ self.assertTrue(
+ np.allclose(model.models[2].intercept, -0.37024065419409624, atol=1e-4),
+ model.models[2].intercept,
+ )
+ self.assertTrue(
+ np.allclose(
+ model.models[2].coefficients.toArray(),
+ [0.12886829400126, 0.012273170857262873],
+ atol=1e-4,
+ ),
+ model.models[2].coefficients,
+ )
+
+ output = model.transform(df)
+ expected_cols = ["label", "features", "rawPrediction", "prediction"]
+ self.assertEqual(output.columns, expected_cols)
+ self.assertEqual(output.count(), 4)
+
+ # Model save & load
+ with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
+ path1 = os.path.join(d, "ovr")
+ ovr.write().overwrite().save(path1)
+ ovr2 = OneVsRest.load(path1)
+ self.assertEqual(str(ovr), str(ovr2))
+
+ path2 = os.path.join(d, "ovr_model")
+ model.write().overwrite().save(path2)
+ model2 = OneVsRestModel.load(path2)
+ self.assertEqual(str(model), str(model2))
+
+
+class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.test_ovr import * # noqa: F401,F403
+
+ try:
+ import xmlrunner # type: ignore[import]
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)