This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 68615c4275c6 [SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect
68615c4275c6 is described below

commit 68615c4275c678f7cc72f5caa0fa1eb8e39e89fb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Jan 28 09:40:06 2025 +0900

    [SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect
    
    ### What changes were proposed in this pull request?
    Support OneVsRest on Connect
    
    ### Why are the changes needed?
    feature parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    new tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49693 from zhengruifeng/ml_connect_ovr.
    
    Lead-authored-by: Ruifeng Zheng <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 22bac2eeb5901d48eb10ba36b52e5ccfdaac9d35)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   2 +
 python/pyspark/ml/classification.py                |  59 ++++++---
 python/pyspark/ml/connect/readwrite.py             |  37 ++++++
 python/pyspark/ml/tests/connect/test_parity_ovr.py |  37 ++++++
 python/pyspark/ml/tests/test_ovr.py                | 135 +++++++++++++++++++++
 5 files changed, 251 insertions(+), 19 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 8e518b0febe3..dd7c387cd5f1 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -676,6 +676,7 @@ pyspark_ml = Module(
         "pyspark.ml.tests.test_persistence",
         "pyspark.ml.tests.test_pipeline",
         "pyspark.ml.tests.test_tuning",
+        "pyspark.ml.tests.test_ovr",
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.tuning.test_tuning",
@@ -1129,6 +1130,7 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_parity_feature",
         "pyspark.ml.tests.connect.test_parity_pipeline",
         "pyspark.ml.tests.connect.test_parity_tuning",
+        "pyspark.ml.tests.connect.test_parity_ovr",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index d8ed51a82abe..1ed2f1fd4b8a 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -85,6 +85,8 @@ from pyspark.ml.util import (
     MLWriter,
     MLWritable,
     HasTrainingSummary,
+    try_remote_read,
+    try_remote_write,
     try_remote_attribute_relation,
 )
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -94,6 +96,7 @@ from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.storagelevel import StorageLevel
+from pyspark.sql.utils import is_remote
 
 if TYPE_CHECKING:
     from pyspark.ml._typing import P, ParamMap
@@ -3572,31 +3575,45 @@ class OneVsRest(
         if handlePersistence:
             multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
 
-        def trainSingleClass(index: int) -> CM:
-            binaryLabelCol = "mc2b$" + str(index)
-            trainingDataset = multiclassLabeled.withColumn(
-                binaryLabelCol,
-                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
-            )
-            paramMap = dict(
-                [
-                    (classifier.labelCol, binaryLabelCol),
-                    (classifier.featuresCol, featuresCol),
-                    (classifier.predictionCol, predictionCol),
-                ]
-            )
-            if weightCol:
-                paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
-            return classifier.fit(trainingDataset, paramMap)
+        def _oneClassFitTasks(numClasses: int):
+            indices = iter(range(numClasses))
+
+            def trainSingleClass() -> CM:
+                index = next(indices)
+
+                binaryLabelCol = "mc2b$" + str(index)
+                trainingDataset = multiclassLabeled.withColumn(
+                    binaryLabelCol,
+                    when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+                )
+                paramMap = dict(
+                    [
+                        (classifier.labelCol, binaryLabelCol),
+                        (classifier.featuresCol, featuresCol),
+                        (classifier.predictionCol, predictionCol),
+                    ]
+                )
+                if weightCol:
+                    paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+                return index, classifier.fit(trainingDataset, paramMap)
 
+            return [trainSingleClass] * numClasses
+
+        tasks = map(
+            inheritable_thread_target(dataset.sparkSession),
+            _oneClassFitTasks(numClasses),
+        )
         pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
 
-        models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
+        subModels = [None] * numClasses
+        for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+            assert subModels is not None
+            subModels[j] = subModel
 
         if handlePersistence:
             multiclassLabeled.unpersist()
 
-        return self._copyValues(OneVsRestModel(models=models))
+        return self._copyValues(OneVsRestModel(models=subModels))
 
     def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
         """
@@ -3671,9 +3688,11 @@ class OneVsRest(
         return _java_obj
 
     @classmethod
+    @try_remote_read
     def read(cls) -> "OneVsRestReader":
         return OneVsRestReader(cls)
 
+    @try_remote_write
     def write(self) -> MLWriter:
         if isinstance(self.getClassifier(), JavaMLWritable):
             return JavaMLWriter(self)  # type: ignore[arg-type]
@@ -3787,7 +3806,7 @@ class OneVsRestModel(
         from pyspark.core.context import SparkContext
 
         self.models = models
-        if not isinstance(models[0], JavaMLWritable):
+        if is_remote() or not isinstance(models[0], JavaMLWritable):
             return
         # set java instance
         java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
@@ -3955,9 +3974,11 @@ class OneVsRestModel(
         return _java_obj
 
     @classmethod
+    @try_remote_read
     def read(cls) -> "OneVsRestModelReader":
         return OneVsRestModelReader(cls)
 
+    @try_remote_write
     def write(self) -> MLWriter:
         if all(
             map(
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 6392e988c067..6e364afb7dbf 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -95,6 +95,7 @@ class RemoteMLWriter(MLWriter):
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+        from pyspark.ml.classification import OneVsRest, OneVsRestModel
 
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -187,6 +188,27 @@ class RemoteMLWriter(MLWriter):
                 warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
             tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
             tvsm_writer.save(path)
+        elif isinstance(instance, OneVsRest):
+            from pyspark.ml.classification import OneVsRestWriter
+
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for OneVsRest")
+
+            # Persist through the Python-side OneVsRestWriter bound to this session.
+            ovr_writer = OneVsRestWriter(instance)
+            ovr_writer.session(session)
+            ovr_writer.save(path)
+        elif isinstance(instance, OneVsRestModel):
+            from pyspark.ml.classification import OneVsRestModelWriter
+
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
+
+            ovrm_writer = OneVsRestModelWriter(instance)
+            ovrm_writer.session(session)
+            ovrm_writer.save(path)
         else:
             raise NotImplementedError(f"Unsupported write for {instance.__class__}")
 
@@ -215,6 +237,7 @@ class RemoteMLReader(MLReader[RL]):
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+        from pyspark.ml.classification import OneVsRest, OneVsRestModel
 
         if (
             issubclass(clazz, JavaModel)
@@ -307,5 +330,19 @@ class RemoteMLReader(MLReader[RL]):
             tvs_reader.session(session)
             return tvs_reader.load(path)
 
+        elif issubclass(clazz, OneVsRest):
+            from pyspark.ml.classification import OneVsRestReader
+
+            ovr_reader = OneVsRestReader(OneVsRest)
+            ovr_reader.session(session)
+            return ovr_reader.load(path)
+
+        elif issubclass(clazz, OneVsRestModel):
+            from pyspark.ml.classification import OneVsRestModelReader
+
+            ovrm_reader = OneVsRestModelReader(OneVsRestModel)
+            ovrm_reader.session(session)
+            return ovrm_reader.load(path)
+
         else:
             raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/connect/test_parity_ovr.py b/python/pyspark/ml/tests/connect/test_parity_ovr.py
new file mode 100644
index 000000000000..3ad3fec8cf10
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_ovr.py
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.tests.test_ovr import OneVsRestTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class OneVsRestParityTests(OneVsRestTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_ovr import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_ovr.py b/python/pyspark/ml/tests/test_ovr.py
new file mode 100644
index 000000000000..aa003be53f5d
--- /dev/null
+++ b/python/pyspark/ml/tests/test_ovr.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import (
+    LinearSVC,
+    LinearSVCModel,
+    OneVsRest,
+    OneVsRestModel,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class OneVsRestTestsMixin:
+    def test_one_vs_rest(self):
+        spark = self.spark
+        df = (
+            spark.createDataFrame(
+                [
+                    (0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (1, 0.0, Vectors.dense(1.0, 2.0)),
+                    (2, 1.0, Vectors.dense(2.0, 1.0)),
+                    (3, 2.0, Vectors.dense(3.0, 3.0)),
+                ],
+                ["index", "label", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("index")
+            .select("label", "features")
+        )
+
+        svc = LinearSVC(maxIter=1, regParam=1.0)
+        self.assertEqual(svc.getMaxIter(), 1)
+        self.assertEqual(svc.getRegParam(), 1.0)
+
+        ovr = OneVsRest(classifier=svc, parallelism=1)
+        self.assertEqual(ovr.getParallelism(), 1)
+
+        model = ovr.fit(df)
+        self.assertIsInstance(model, OneVsRestModel)
+        self.assertEqual(len(model.models), 3)
+        for submodel in model.models:
+            self.assertIsInstance(submodel, LinearSVCModel)
+
+        self.assertTrue(
+            np.allclose(model.models[0].intercept, 0.06279247869226989, atol=1e-4),
+            model.models[0].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[0].coefficients.toArray(),
+                [-0.1198765502306968, -0.1027513287691687],
+                atol=1e-4,
+            ),
+            model.models[0].coefficients,
+        )
+
+        self.assertTrue(
+            np.allclose(model.models[1].intercept, 0.025877458475338313, atol=1e-4),
+            model.models[1].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[1].coefficients.toArray(),
+                [-0.0362284418654736, 0.010350983390135305],
+                atol=1e-4,
+            ),
+            model.models[1].coefficients,
+        )
+
+        self.assertTrue(
+            np.allclose(model.models[2].intercept, -0.37024065419409624, atol=1e-4),
+            model.models[2].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[2].coefficients.toArray(),
+                [0.12886829400126, 0.012273170857262873],
+                atol=1e-4,
+            ),
+            model.models[2].coefficients,
+        )
+
+        output = model.transform(df)
+        expected_cols = ["label", "features", "rawPrediction", "prediction"]
+        self.assertEqual(output.columns, expected_cols)
+        self.assertEqual(output.count(), 4)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="ovr") as d:
+            path1 = os.path.join(d, "ovr")
+            ovr.write().overwrite().save(path1)
+            ovr2 = OneVsRest.load(path1)
+            self.assertEqual(str(ovr), str(ovr2))
+
+            path2 = os.path.join(d, "ovr_model")
+            model.write().overwrite().save(path2)
+            model2 = OneVsRestModel.load(path2)
+            self.assertEqual(str(model), str(model2))
+
+
+class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_ovr import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to