This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new f91646871213 [SPARK-50954][ML][PYTHON][CONNECT] Support client side model path overwrite for meta algorithms
f91646871213 is described below

commit f91646871213dd15c23aa36970f16ff7d0081ae2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 23:10:53 2025 +0800

    [SPARK-50954][ML][PYTHON][CONNECT] Support client side model path overwrite for meta algorithms
    
    ### What changes were proposed in this pull request?
    Support client-side model path overwrite for meta algorithms (Pipeline, OneVsRest, CrossValidator, TrainValidationSplit, and their models): when overwrite is requested, the Python Connect client now invokes the server-side `ConnectHelper.handleOverwrite`, which deletes the existing model path before the writer saves.
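
    A minimal sketch of the client-side hook (paraphrasing the `readwrite.py` change in the diff below; the import locations are my assumption, the identifiers come from the diff):

    ```python
    from pyspark.ml.util import ML_CONNECT_HELPER_ID
    from pyspark.ml.wrapper import JavaWrapper

    def handle_overwrite(path: str, should_overwrite: bool) -> None:
        # Mirrors RemoteMLWriter.handleOverwrite: when overwrite is requested,
        # ask the server-side ConnectHelper to delete the target path (via
        # FileSystemOverwrite) before the regular writer saves to it.
        if should_overwrite:
            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
            helper._call_java("handleOverwrite", path, should_overwrite)
    ```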
    
    ### Why are the changes needed?
    For feature parity with classic Spark ML: before this change, `overwrite()` did not take effect for meta algorithms on Connect and only emitted a warning.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. On Spark Connect, `write().overwrite().save(path)` for meta algorithms now actually overwrites an existing model path instead of emitting an "Overwrite doesn't take effect" warning.
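
    For example (a minimal sketch, assuming an active Spark Connect session; the estimator and path are illustrative):

    ```python
    import tempfile
    from pyspark.ml.classification import LinearSVC, OneVsRest

    ovr = OneVsRest(classifier=LinearSVC())
    with tempfile.TemporaryDirectory() as d:
        path = d + "/ovr"
        ovr.write().save(path)  # first save: the path does not exist yet
        # Before this change the second call below only emitted a warning and
        # the old files stayed in place; now the path is deleted and re-saved.
        ovr.write().overwrite().save(path)
    ```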
    
    ### How was this patch tested?
    Updated the existing unit tests (`test_ovr.py`, `test_pipeline.py`, `test_tuning.py`) to save with `overwrite()` into a directory that already exists.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49810 from zhengruifeng/ml_connect_overwrite.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit 6a1d38219428eb48e1f3bbcae1bfe2d487b7afc3)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/ml/util/ConnectHelper.scala   |  8 +++-
 .../scala/org/apache/spark/ml/util/ReadWrite.scala |  2 +-
 python/pyspark/ml/connect/readwrite.py             | 46 ++++++++--------------
 python/pyspark/ml/tests/test_ovr.py                | 11 ++----
 python/pyspark/ml/tests/test_pipeline.py           | 21 ++++------
 python/pyspark/ml/tests/test_tuning.py             | 26 +++++-------
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  1 +
 7 files changed, 48 insertions(+), 67 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index cf1a75518857..d4a0a1301e15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -19,12 +19,18 @@ package org.apache.spark.ml.util
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.feature.{CountVectorizerModel, StringIndexerModel}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.types.StructType
 
 private[spark] class ConnectHelper(override val uid: String) extends Model[ConnectHelper] {
   def this() = this(Identifiable.randomUID("ConnectHelper"))
 
+  def handleOverwrite(path: String, shouldOverwrite: Boolean): Boolean = {
+    val spark = SparkSession.builder().getOrCreate()
+    new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, spark)
+    true
+  }
+
   def stringIndexerModelFromLabels(
       uid: String, labels: Array[String]): StringIndexerModel = {
     new StringIndexerModel(uid, labels)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index f9d9056c801e..dcb337218edc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -754,7 +754,7 @@ private[ml] object MetaAlgorithmReadWrite {
   }
 }
 
-private[ml] class FileSystemOverwrite extends Logging {
+private[spark] class FileSystemOverwrite extends Logging {
 
   def handleOverwrite(path: String, shouldOverwrite: Boolean, session: SparkSession): Unit = {
     val hadoopConf = session.sessionState.newHadoopConf()
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index de70a410dbc7..5c2b850d51b3 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import warnings
+
 from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
@@ -142,75 +142,63 @@ class RemoteMLWriter(MLWriter):
         elif isinstance(instance, Pipeline):
             from pyspark.ml.pipeline import PipelineWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for Pipeline")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             pl_writer = PipelineWriter(instance)
             pl_writer.session(session)  # type: ignore[arg-type]
             pl_writer.save(path)
         elif isinstance(instance, PipelineModel):
             from pyspark.ml.pipeline import PipelineModelWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for PipelineModel")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             plm_writer = PipelineModelWriter(instance)
             plm_writer.session(session)  # type: ignore[arg-type]
             plm_writer.save(path)
         elif isinstance(instance, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for CrossValidator")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             cv_writer = CrossValidatorWriter(instance)
             cv_writer.session(session)  # type: ignore[arg-type]
             cv_writer.save(path)
         elif isinstance(instance, CrossValidatorModel):
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
+            cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
             cvm_writer.save(path)
         elif isinstance(instance, TrainValidationSplit):
             from pyspark.ml.tuning import TrainValidationSplitWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for TrainValidationSplit")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             tvs_writer = TrainValidationSplitWriter(instance)
             tvs_writer.save(path)
         elif isinstance(instance, TrainValidationSplitModel):
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
+            tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
             tvsm_writer.save(path)
         elif isinstance(instance, OneVsRest):
             from pyspark.ml.classification import OneVsRestWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for OneVsRest")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             ovr_writer = OneVsRestWriter(instance)
             ovr_writer.session(session)  # type: ignore[arg-type]
             ovr_writer.save(path)
         elif isinstance(instance, OneVsRestModel):
             from pyspark.ml.classification import OneVsRestModelWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             ovrm_writer = OneVsRestModelWriter(instance)
             ovrm_writer.session(session)  # type: ignore[arg-type]
             ovrm_writer.save(path)
         else:
             raise NotImplementedError(f"Unsupported write for {instance.__class__}")
 
+    @staticmethod
+    def handleOverwrite(path: str, shouldOverwrite: bool) -> None:
+        from pyspark.ml.util import ML_CONNECT_HELPER_ID
+
+        if shouldOverwrite:
+            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+            helper._call_java("handleOverwrite", path, shouldOverwrite)
+
 
 class RemoteMLReader(MLReader[RL]):
     def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
diff --git a/python/pyspark/ml/tests/test_ovr.py b/python/pyspark/ml/tests/test_ovr.py
index aa003be53f5d..42275c852ac1 100644
--- a/python/pyspark/ml/tests/test_ovr.py
+++ b/python/pyspark/ml/tests/test_ovr.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import os
 import tempfile
 import unittest
 
@@ -108,14 +107,12 @@ class OneVsRestTestsMixin:
 
         # Model save & load
         with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
-            path1 = os.path.join(d, "ovr")
-            ovr.write().overwrite().save(path1)
-            ovr2 = OneVsRest.load(path1)
+            ovr.write().overwrite().save(d)
+            ovr2 = OneVsRest.load(d)
             self.assertEqual(str(ovr), str(ovr2))
 
-            path2 = os.path.join(d, "ovr_model")
-            model.write().overwrite().save(path2)
-            model2 = OneVsRestModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = OneVsRestModel.load(d)
             self.assertEqual(str(model), str(model2))
 
 
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
index 9c6ea2126f8d..8318f3bb71c9 100644
--- a/python/pyspark/ml/tests/test_pipeline.py
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import os
 import tempfile
 import unittest
 
@@ -110,15 +109,13 @@ class PipelineTestsMixin:
 
         # save & load
         with tempfile.TemporaryDirectory(prefix="classification_pipeline") as d:
-            path1 = os.path.join(d, "pipeline")
-            pipeline.write().save(path1)
-            pipeline2 = Pipeline.load(path1)
+            pipeline.write().overwrite().save(d)
+            pipeline2 = Pipeline.load(d)
             self.assertEqual(str(pipeline), str(pipeline2))
             self.assertEqual(str(pipeline.getStages()), str(pipeline2.getStages()))
 
-            path2 = os.path.join(d, "pipeline_model")
-            model.write().save(path2)
-            model2 = PipelineModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = PipelineModel.load(d)
             self.assertEqual(str(model), str(model2))
             self.assertEqual(str(model.stages), str(model2.stages))
 
@@ -165,15 +162,13 @@ class PipelineTestsMixin:
 
         # PipelineModel save & load
         with tempfile.TemporaryDirectory(prefix="clustering_pipeline") as d:
-            path1 = os.path.join(d, "pipeline")
-            pipeline.write().save(path1)
-            pipeline2 = Pipeline.load(path1)
+            pipeline.write().overwrite().save(d)
+            pipeline2 = Pipeline.load(d)
             self.assertEqual(str(pipeline), str(pipeline2))
             self.assertEqual(str(pipeline.getStages()), str(pipeline2.getStages()))
 
-            path2 = os.path.join(d, "pipeline_model")
-            model.write().save(path2)
-            model2 = PipelineModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = PipelineModel.load(d)
             self.assertEqual(str(model), str(model2))
             self.assertEqual(str(model.stages), str(model2.stages))
 
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 451c89db5b5a..94081a090bfe 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 
-import os
 import tempfile
 import unittest
 
@@ -73,16 +72,14 @@ class TuningTestsMixin:
 
         # save & load
         with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
-            path1 = os.path.join(d, "tvs")
-            tvs.write().save(path1)
-            tvs2 = TrainValidationSplit.load(path1)
+            tvs.write().overwrite().save(d)
+            tvs2 = TrainValidationSplit.load(d)
             self.assertEqual(str(tvs), str(tvs2))
             self.assertEqual(str(tvs.getEstimator()), str(tvs2.getEstimator()))
             self.assertEqual(str(tvs.getEvaluator()), str(tvs2.getEvaluator()))
 
-            path2 = os.path.join(d, "tvsm")
-            tvs_model.write().save(path2)
-            model2 = TrainValidationSplitModel.load(path2)
+            tvs_model.write().overwrite().save(d)
+            model2 = TrainValidationSplitModel.load(d)
             self.assertEqual(str(tvs_model), str(model2))
             self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
             self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))
@@ -142,24 +139,21 @@ class TuningTestsMixin:
 
         # save & load
         with tempfile.TemporaryDirectory(prefix="cv") as d:
-            path1 = os.path.join(d, "cv")
-            cv.write().save(path1)
-            cv2 = CrossValidator.load(path1)
+            cv.write().overwrite().save(d)
+            cv2 = CrossValidator.load(d)
             self.assertEqual(str(cv), str(cv2))
             self.assertEqual(str(cv.getEstimator()), str(cv2.getEstimator()))
             self.assertEqual(str(cv.getEvaluator()), str(cv2.getEvaluator()))
 
-            path2 = os.path.join(d, "cv_model")
-            model.write().save(path2)
-            model2 = CrossValidatorModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = CrossValidatorModel.load(d)
             checkSubModels(model2.subModels)
             self.assertEqual(str(model), str(model2))
             self.assertEqual(str(model.getEstimator()), str(model2.getEstimator()))
             self.assertEqual(str(model.getEvaluator()), str(model2.getEvaluator()))
 
-            path2 = os.path.join(d, "cv_model2")
-            model.write().option("persistSubModels", "false").save(path2)
-            cvModel2 = CrossValidatorModel.load(path2)
+            model.write().overwrite().option("persistSubModels", "false").save(d)
+            cvModel2 = CrossValidatorModel.load(d)
             self.assertEqual(cvModel2.subModels, None)
 
             model3 = model2.copy()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index e69f670226e4..2cbae0fa7d9d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -656,6 +656,7 @@ private[ml] object MLUtils {
     (
       classOf[ConnectHelper],
       Set(
+        "handleOverwrite",
         "stringIndexerModelFromLabels",
         "stringIndexerModelFromLabelsArray",
         "countVectorizerModelFromVocabulary")))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
