This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 1348482d8129 Revert "[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect"
1348482d8129 is described below
commit 1348482d8129bfff7d7bc05ff9e3e227fdfa1a45
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Tue Jan 28 10:28:46 2025 +0900
Revert "[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect"
This reverts commit 68615c4275c678f7cc72f5caa0fa1eb8e39e89fb.
---
dev/sparktestsupport/modules.py | 2 -
python/pyspark/ml/classification.py | 59 +++------
python/pyspark/ml/connect/readwrite.py | 37 ------
python/pyspark/ml/tests/connect/test_parity_ovr.py | 37 ------
python/pyspark/ml/tests/test_ovr.py | 135 ---------------------
5 files changed, 19 insertions(+), 251 deletions(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index dd7c387cd5f1..8e518b0febe3 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -676,7 +676,6 @@ pyspark_ml = Module(
"pyspark.ml.tests.test_persistence",
"pyspark.ml.tests.test_pipeline",
"pyspark.ml.tests.test_tuning",
- "pyspark.ml.tests.test_ovr",
"pyspark.ml.tests.test_stat",
"pyspark.ml.tests.test_training_summary",
"pyspark.ml.tests.tuning.test_tuning",
@@ -1130,7 +1129,6 @@ pyspark_ml_connect = Module(
"pyspark.ml.tests.connect.test_parity_feature",
"pyspark.ml.tests.connect.test_parity_pipeline",
"pyspark.ml.tests.connect.test_parity_tuning",
- "pyspark.ml.tests.connect.test_parity_ovr",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy,
pandas, and pyarrow and
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1ed2f1fd4b8a..d8ed51a82abe 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -85,8 +85,6 @@ from pyspark.ml.util import (
MLWriter,
MLWritable,
HasTrainingSummary,
- try_remote_read,
- try_remote_write,
try_remote_attribute_relation,
)
from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel,
JavaWrapper
@@ -96,7 +94,6 @@ from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import udf, when
from pyspark.sql.types import ArrayType, DoubleType
from pyspark.storagelevel import StorageLevel
-from pyspark.sql.utils import is_remote
if TYPE_CHECKING:
from pyspark.ml._typing import P, ParamMap
@@ -3575,45 +3572,31 @@ class OneVsRest(
if handlePersistence:
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
- def _oneClassFitTasks(numClasses: int):
- indices = iter(range(numClasses))
-
- def trainSingleClass() -> CM:
- index = next(indices)
-
- binaryLabelCol = "mc2b$" + str(index)
- trainingDataset = multiclassLabeled.withColumn(
- binaryLabelCol,
- when(multiclassLabeled[labelCol] == float(index),
1.0).otherwise(0.0),
- )
- paramMap = dict(
- [
- (classifier.labelCol, binaryLabelCol),
- (classifier.featuresCol, featuresCol),
- (classifier.predictionCol, predictionCol),
- ]
- )
- if weightCol:
- paramMap[cast(HasWeightCol, classifier).weightCol] =
weightCol
- return index, classifier.fit(trainingDataset, paramMap)
-
- return [trainSingleClass] * numClasses
+ def trainSingleClass(index: int) -> CM:
+ binaryLabelCol = "mc2b$" + str(index)
+ trainingDataset = multiclassLabeled.withColumn(
+ binaryLabelCol,
+ when(multiclassLabeled[labelCol] == float(index),
1.0).otherwise(0.0),
+ )
+ paramMap = dict(
+ [
+ (classifier.labelCol, binaryLabelCol),
+ (classifier.featuresCol, featuresCol),
+ (classifier.predictionCol, predictionCol),
+ ]
+ )
+ if weightCol:
+ paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+ return classifier.fit(trainingDataset, paramMap)
- tasks = map(
- inheritable_thread_target(dataset.sparkSession),
- _oneClassFitTasks(numClasses),
- )
pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
- subModels = [None] * numClasses
- for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
- assert subModels is not None
- subModels[j] = subModel
+ models = pool.map(inheritable_thread_target(trainSingleClass),
range(numClasses))
if handlePersistence:
multiclassLabeled.unpersist()
- return self._copyValues(OneVsRestModel(models=subModels))
+ return self._copyValues(OneVsRestModel(models=models))
def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
"""
@@ -3688,11 +3671,9 @@ class OneVsRest(
return _java_obj
@classmethod
- @try_remote_read
def read(cls) -> "OneVsRestReader":
return OneVsRestReader(cls)
- @try_remote_write
def write(self) -> MLWriter:
if isinstance(self.getClassifier(), JavaMLWritable):
return JavaMLWriter(self) # type: ignore[arg-type]
@@ -3806,7 +3787,7 @@ class OneVsRestModel(
from pyspark.core.context import SparkContext
self.models = models
- if is_remote() or not isinstance(models[0], JavaMLWritable):
+ if not isinstance(models[0], JavaMLWritable):
return
# set java instance
java_models = [cast(_JavaClassificationModel, model)._to_java() for
model in self.models]
@@ -3974,11 +3955,9 @@ class OneVsRestModel(
return _java_obj
@classmethod
- @try_remote_read
def read(cls) -> "OneVsRestModelReader":
return OneVsRestModelReader(cls)
- @try_remote_write
def write(self) -> MLWriter:
if all(
map(
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 6e364afb7dbf..6392e988c067 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -95,7 +95,6 @@ class RemoteMLWriter(MLWriter):
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
- from pyspark.ml.classification import OneVsRest, OneVsRestModel
# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -188,27 +187,6 @@ class RemoteMLWriter(MLWriter):
warnings.warn("Overwrite doesn't take effect for
TrainValidationSplitModel")
tvsm_writer = RemoteTrainValidationSplitModelWriter(instance,
optionMap, session)
tvsm_writer.save(path)
- elif isinstance(instance, OneVsRest):
- from pyspark.ml.classification import OneVsRestWriter
-
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for OneVsRest")
-
- writer = OneVsRestWriter(instance)
- writer.session(session)
- writer.save(path)
- # _OneVsRestSharedReadWrite.saveImpl(self.instance,
self.sparkSession, path)
- elif isinstance(instance, OneVsRestModel):
- from pyspark.ml.classification import OneVsRestModelWriter
-
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for
OneVsRestModel")
-
- writer = OneVsRestModelWriter(instance)
- writer.session(session)
- writer.save(path)
else:
raise NotImplementedError(f"Unsupported write for
{instance.__class__}")
@@ -237,7 +215,6 @@ class RemoteMLReader(MLReader[RL]):
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
- from pyspark.ml.classification import OneVsRest, OneVsRestModel
if (
issubclass(clazz, JavaModel)
@@ -330,19 +307,5 @@ class RemoteMLReader(MLReader[RL]):
tvs_reader.session(session)
return tvs_reader.load(path)
- elif issubclass(clazz, OneVsRest):
- from pyspark.ml.classification import OneVsRestReader
-
- ovr_reader = OneVsRestReader(OneVsRest)
- ovr_reader.session(session)
- return ovr_reader.load(path)
-
- elif issubclass(clazz, OneVsRestModel):
- from pyspark.ml.classification import OneVsRestModelReader
-
- ovr_reader = OneVsRestModelReader(OneVsRestModel)
- ovr_reader.session(session)
- return ovr_reader.load(path)
-
else:
raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/connect/test_parity_ovr.py b/python/pyspark/ml/tests/connect/test_parity_ovr.py
deleted file mode 100644
index 3ad3fec8cf10..000000000000
--- a/python/pyspark/ml/tests/connect/test_parity_ovr.py
+++ /dev/null
@@ -1,37 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from pyspark.ml.tests.test_ovr import OneVsRestTestsMixin
-from pyspark.testing.connectutils import ReusedConnectTestCase
-
-
-class OneVsRestParityTests(OneVsRestTestsMixin, ReusedConnectTestCase):
- pass
-
-
-if __name__ == "__main__":
- from pyspark.ml.tests.connect.test_parity_ovr import * # noqa: F401
-
- try:
- import xmlrunner # type: ignore[import]
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_ovr.py b/python/pyspark/ml/tests/test_ovr.py
deleted file mode 100644
index aa003be53f5d..000000000000
--- a/python/pyspark/ml/tests/test_ovr.py
+++ /dev/null
@@ -1,135 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import os
-import tempfile
-import unittest
-
-import numpy as np
-
-from pyspark.ml.linalg import Vectors
-from pyspark.ml.classification import (
- LinearSVC,
- LinearSVCModel,
- OneVsRest,
- OneVsRestModel,
-)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
-
-
-class OneVsRestTestsMixin:
- def test_one_vs_rest(self):
- spark = self.spark
- df = (
- spark.createDataFrame(
- [
- (0, 1.0, Vectors.dense(0.0, 5.0)),
- (1, 0.0, Vectors.dense(1.0, 2.0)),
- (2, 1.0, Vectors.dense(2.0, 1.0)),
- (3, 2.0, Vectors.dense(3.0, 3.0)),
- ],
- ["index", "label", "features"],
- )
- .coalesce(1)
- .sortWithinPartitions("index")
- .select("label", "features")
- )
-
- svc = LinearSVC(maxIter=1, regParam=1.0)
- self.assertEqual(svc.getMaxIter(), 1)
- self.assertEqual(svc.getRegParam(), 1.0)
-
- ovr = OneVsRest(classifier=svc, parallelism=1)
- self.assertEqual(ovr.getParallelism(), 1)
-
- model = ovr.fit(df)
- self.assertIsInstance(model, OneVsRestModel)
- self.assertEqual(len(model.models), 3)
- for submodel in model.models:
- self.assertIsInstance(submodel, LinearSVCModel)
-
- self.assertTrue(
- np.allclose(model.models[0].intercept, 0.06279247869226989,
atol=1e-4),
- model.models[0].intercept,
- )
- self.assertTrue(
- np.allclose(
- model.models[0].coefficients.toArray(),
- [-0.1198765502306968, -0.1027513287691687],
- atol=1e-4,
- ),
- model.models[0].coefficients,
- )
-
- self.assertTrue(
- np.allclose(model.models[1].intercept, 0.025877458475338313,
atol=1e-4),
- model.models[1].intercept,
- )
- self.assertTrue(
- np.allclose(
- model.models[1].coefficients.toArray(),
- [-0.0362284418654736, 0.010350983390135305],
- atol=1e-4,
- ),
- model.models[1].coefficients,
- )
-
- self.assertTrue(
- np.allclose(model.models[2].intercept, -0.37024065419409624,
atol=1e-4),
- model.models[2].intercept,
- )
- self.assertTrue(
- np.allclose(
- model.models[2].coefficients.toArray(),
- [0.12886829400126, 0.012273170857262873],
- atol=1e-4,
- ),
- model.models[2].coefficients,
- )
-
- output = model.transform(df)
- expected_cols = ["label", "features", "rawPrediction", "prediction"]
- self.assertEqual(output.columns, expected_cols)
- self.assertEqual(output.count(), 4)
-
- # Model save & load
- with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
- path1 = os.path.join(d, "ovr")
- ovr.write().overwrite().save(path1)
- ovr2 = OneVsRest.load(path1)
- self.assertEqual(str(ovr), str(ovr2))
-
- path2 = os.path.join(d, "ovr_model")
- model.write().overwrite().save(path2)
- model2 = OneVsRestModel.load(path2)
- self.assertEqual(str(model), str(model2))
-
-
-class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
- pass
-
-
-if __name__ == "__main__":
- from pyspark.ml.tests.test_ovr import * # noqa: F401,F403
-
- try:
- import xmlrunner # type: ignore[import]
-
- testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
- except ImportError:
- testRunner = None
- unittest.main(testRunner=testRunner, verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]