This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 1348482d8129 Revert "[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect"
1348482d8129 is described below

commit 1348482d8129bfff7d7bc05ff9e3e227fdfa1a45
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jan 28 10:28:46 2025 +0900

    Revert "[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect"
    
    This reverts commit 68615c4275c678f7cc72f5caa0fa1eb8e39e89fb.
---
 dev/sparktestsupport/modules.py                    |   2 -
 python/pyspark/ml/classification.py                |  59 +++------
 python/pyspark/ml/connect/readwrite.py             |  37 ------
 python/pyspark/ml/tests/connect/test_parity_ovr.py |  37 ------
 python/pyspark/ml/tests/test_ovr.py                | 135 ---------------------
 5 files changed, 19 insertions(+), 251 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index dd7c387cd5f1..8e518b0febe3 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -676,7 +676,6 @@ pyspark_ml = Module(
         "pyspark.ml.tests.test_persistence",
         "pyspark.ml.tests.test_pipeline",
         "pyspark.ml.tests.test_tuning",
-        "pyspark.ml.tests.test_ovr",
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.tuning.test_tuning",
@@ -1130,7 +1129,6 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_parity_feature",
         "pyspark.ml.tests.connect.test_parity_pipeline",
         "pyspark.ml.tests.connect.test_parity_tuning",
-        "pyspark.ml.tests.connect.test_parity_ovr",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1ed2f1fd4b8a..d8ed51a82abe 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -85,8 +85,6 @@ from pyspark.ml.util import (
     MLWriter,
     MLWritable,
     HasTrainingSummary,
-    try_remote_read,
-    try_remote_write,
     try_remote_attribute_relation,
 )
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -96,7 +94,6 @@ from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.storagelevel import StorageLevel
-from pyspark.sql.utils import is_remote
 
 if TYPE_CHECKING:
     from pyspark.ml._typing import P, ParamMap
@@ -3575,45 +3572,31 @@ class OneVsRest(
         if handlePersistence:
             multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
 
-        def _oneClassFitTasks(numClasses: int):
-            indices = iter(range(numClasses))
-
-            def trainSingleClass() -> CM:
-                index = next(indices)
-
-                binaryLabelCol = "mc2b$" + str(index)
-                trainingDataset = multiclassLabeled.withColumn(
-                    binaryLabelCol,
-                    when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
-                )
-                paramMap = dict(
-                    [
-                        (classifier.labelCol, binaryLabelCol),
-                        (classifier.featuresCol, featuresCol),
-                        (classifier.predictionCol, predictionCol),
-                    ]
-                )
-                if weightCol:
-                    paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
-                return index, classifier.fit(trainingDataset, paramMap)
-
-            return [trainSingleClass] * numClasses
+        def trainSingleClass(index: int) -> CM:
+            binaryLabelCol = "mc2b$" + str(index)
+            trainingDataset = multiclassLabeled.withColumn(
+                binaryLabelCol,
+                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+            )
+            paramMap = dict(
+                [
+                    (classifier.labelCol, binaryLabelCol),
+                    (classifier.featuresCol, featuresCol),
+                    (classifier.predictionCol, predictionCol),
+                ]
+            )
+            if weightCol:
+                paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+            return classifier.fit(trainingDataset, paramMap)
 
-        tasks = map(
-            inheritable_thread_target(dataset.sparkSession),
-            _oneClassFitTasks(numClasses),
-        )
         pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
 
-        subModels = [None] * numClasses
-        for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
-            assert subModels is not None
-            subModels[j] = subModel
+        models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
 
         if handlePersistence:
             multiclassLabeled.unpersist()
 
-        return self._copyValues(OneVsRestModel(models=subModels))
+        return self._copyValues(OneVsRestModel(models=models))
 
     def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
         """
@@ -3688,11 +3671,9 @@ class OneVsRest(
         return _java_obj
 
     @classmethod
-    @try_remote_read
     def read(cls) -> "OneVsRestReader":
         return OneVsRestReader(cls)
 
-    @try_remote_write
     def write(self) -> MLWriter:
         if isinstance(self.getClassifier(), JavaMLWritable):
             return JavaMLWriter(self)  # type: ignore[arg-type]
@@ -3806,7 +3787,7 @@ class OneVsRestModel(
         from pyspark.core.context import SparkContext
 
         self.models = models
-        if is_remote() or not isinstance(models[0], JavaMLWritable):
+        if not isinstance(models[0], JavaMLWritable):
             return
         # set java instance
         java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
@@ -3974,11 +3955,9 @@ class OneVsRestModel(
         return _java_obj
 
     @classmethod
-    @try_remote_read
     def read(cls) -> "OneVsRestModelReader":
         return OneVsRestModelReader(cls)
 
-    @try_remote_write
     def write(self) -> MLWriter:
         if all(
             map(
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 6e364afb7dbf..6392e988c067 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -95,7 +95,6 @@ class RemoteMLWriter(MLWriter):
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
-        from pyspark.ml.classification import OneVsRest, OneVsRestModel
 
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -188,27 +187,6 @@ class RemoteMLWriter(MLWriter):
                 warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
             tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
             tvsm_writer.save(path)
-        elif isinstance(instance, OneVsRest):
-            from pyspark.ml.classification import OneVsRestWriter
-
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for OneVsRest")
-
-            writer = OneVsRestWriter(instance)
-            writer.session(session)
-            writer.save(path)
-            # _OneVsRestSharedReadWrite.saveImpl(self.instance, self.sparkSession, path)
-        elif isinstance(instance, OneVsRestModel):
-            from pyspark.ml.classification import OneVsRestModelWriter
-
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
-
-            writer = OneVsRestModelWriter(instance)
-            writer.session(session)
-            writer.save(path)
         else:
             raise NotImplementedError(f"Unsupported write for {instance.__class__}")
 
@@ -237,7 +215,6 @@ class RemoteMLReader(MLReader[RL]):
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
         from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
-        from pyspark.ml.classification import OneVsRest, OneVsRestModel
 
         if (
             issubclass(clazz, JavaModel)
@@ -330,19 +307,5 @@ class RemoteMLReader(MLReader[RL]):
             tvs_reader.session(session)
             return tvs_reader.load(path)
 
-        elif issubclass(clazz, OneVsRest):
-            from pyspark.ml.classification import OneVsRestReader
-
-            ovr_reader = OneVsRestReader(OneVsRest)
-            ovr_reader.session(session)
-            return ovr_reader.load(path)
-
-        elif issubclass(clazz, OneVsRestModel):
-            from pyspark.ml.classification import OneVsRestModelReader
-
-            ovr_reader = OneVsRestModelReader(OneVsRestModel)
-            ovr_reader.session(session)
-            return ovr_reader.load(path)
-
         else:
             raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/connect/test_parity_ovr.py b/python/pyspark/ml/tests/connect/test_parity_ovr.py
deleted file mode 100644
index 3ad3fec8cf10..000000000000
--- a/python/pyspark/ml/tests/connect/test_parity_ovr.py
+++ /dev/null
@@ -1,37 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from pyspark.ml.tests.test_ovr import OneVsRestTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-
-
-class OneVsRestParityTests(OneVsRestTestsMixin, ReusedConnectTestCase):
-    pass
-
-
-if __name__ == "__main__":
-    from pyspark.ml.tests.connect.test_parity_ovr import *  # noqa: F401
-
-    try:
-        import xmlrunner  # type: ignore[import]
-
-        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
-    except ImportError:
-        testRunner = None
-    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_ovr.py b/python/pyspark/ml/tests/test_ovr.py
deleted file mode 100644
index aa003be53f5d..000000000000
--- a/python/pyspark/ml/tests/test_ovr.py
+++ /dev/null
@@ -1,135 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import os
-import tempfile
-import unittest
-
-import numpy as np
-
-from pyspark.ml.linalg import Vectors
-from pyspark.ml.classification import (
-    LinearSVC,
-    LinearSVCModel,
-    OneVsRest,
-    OneVsRestModel,
-)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
-
-
-class OneVsRestTestsMixin:
-    def test_one_vs_rest(self):
-        spark = self.spark
-        df = (
-            spark.createDataFrame(
-                [
-                    (0, 1.0, Vectors.dense(0.0, 5.0)),
-                    (1, 0.0, Vectors.dense(1.0, 2.0)),
-                    (2, 1.0, Vectors.dense(2.0, 1.0)),
-                    (3, 2.0, Vectors.dense(3.0, 3.0)),
-                ],
-                ["index", "label", "features"],
-            )
-            .coalesce(1)
-            .sortWithinPartitions("index")
-            .select("label", "features")
-        )
-
-        svc = LinearSVC(maxIter=1, regParam=1.0)
-        self.assertEqual(svc.getMaxIter(), 1)
-        self.assertEqual(svc.getRegParam(), 1.0)
-
-        ovr = OneVsRest(classifier=svc, parallelism=1)
-        self.assertEqual(ovr.getParallelism(), 1)
-
-        model = ovr.fit(df)
-        self.assertIsInstance(model, OneVsRestModel)
-        self.assertEqual(len(model.models), 3)
-        for submodel in model.models:
-            self.assertIsInstance(submodel, LinearSVCModel)
-
-        self.assertTrue(
-            np.allclose(model.models[0].intercept, 0.06279247869226989, atol=1e-4),
-            model.models[0].intercept,
-        )
-        self.assertTrue(
-            np.allclose(
-                model.models[0].coefficients.toArray(),
-                [-0.1198765502306968, -0.1027513287691687],
-                atol=1e-4,
-            ),
-            model.models[0].coefficients,
-        )
-
-        self.assertTrue(
-            np.allclose(model.models[1].intercept, 0.025877458475338313, atol=1e-4),
-            model.models[1].intercept,
-        )
-        self.assertTrue(
-            np.allclose(
-                model.models[1].coefficients.toArray(),
-                [-0.0362284418654736, 0.010350983390135305],
-                atol=1e-4,
-            ),
-            model.models[1].coefficients,
-        )
-
-        self.assertTrue(
-            np.allclose(model.models[2].intercept, -0.37024065419409624, atol=1e-4),
-            model.models[2].intercept,
-        )
-        self.assertTrue(
-            np.allclose(
-                model.models[2].coefficients.toArray(),
-                [0.12886829400126, 0.012273170857262873],
-                atol=1e-4,
-            ),
-            model.models[2].coefficients,
-        )
-
-        output = model.transform(df)
-        expected_cols = ["label", "features", "rawPrediction", "prediction"]
-        self.assertEqual(output.columns, expected_cols)
-        self.assertEqual(output.count(), 4)
-
-        # Model save & load
-        with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
-            path1 = os.path.join(d, "ovr")
-            ovr.write().overwrite().save(path1)
-            ovr2 = OneVsRest.load(path1)
-            self.assertEqual(str(ovr), str(ovr2))
-
-            path2 = os.path.join(d, "ovr_model")
-            model.write().overwrite().save(path2)
-            model2 = OneVsRestModel.load(path2)
-            self.assertEqual(str(model), str(model2))
-
-
-class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
-    pass
-
-
-if __name__ == "__main__":
-    from pyspark.ml.tests.test_ovr import *  # noqa: F401,F403
-
-    try:
-        import xmlrunner  # type: ignore[import]
-
-        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
-    except ImportError:
-        testRunner = None
-    unittest.main(testRunner=testRunner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to