zhengruifeng commented on code in PR #49810:
URL: https://github.com/apache/spark/pull/49810#discussion_r1942931538


##########
python/pyspark/ml/connect/readwrite.py:
##########
@@ -142,75 +142,62 @@ def saveInstance(
         elif isinstance(instance, Pipeline):
             from pyspark.ml.pipeline import PipelineWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for Pipeline")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             pl_writer = PipelineWriter(instance)
             pl_writer.session(session)  # type: ignore[arg-type]
             pl_writer.save(path)
         elif isinstance(instance, PipelineModel):
             from pyspark.ml.pipeline import PipelineModelWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for 
PipelineModel")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             plm_writer = PipelineModelWriter(instance)
             plm_writer.session(session)  # type: ignore[arg-type]
             plm_writer.save(path)
         elif isinstance(instance, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for 
CrossValidator")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             cv_writer = CrossValidatorWriter(instance)
             cv_writer.session(session)  # type: ignore[arg-type]
             cv_writer.save(path)
         elif isinstance(instance, CrossValidatorModel):
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for 
CrossValidatorModel")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, 
session)
             cvm_writer.save(path)
         elif isinstance(instance, TrainValidationSplit):
             from pyspark.ml.tuning import TrainValidationSplitWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for 
TrainValidationSplit")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             tvs_writer = TrainValidationSplitWriter(instance)
             tvs_writer.save(path)
         elif isinstance(instance, TrainValidationSplitModel):
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for 
TrainValidationSplitModel")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, 
optionMap, session)
             tvsm_writer.save(path)
         elif isinstance(instance, OneVsRest):
             from pyspark.ml.classification import OneVsRestWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for OneVsRest")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             ovr_writer = OneVsRestWriter(instance)
             ovr_writer.session(session)  # type: ignore[arg-type]
             ovr_writer.save(path)
         elif isinstance(instance, OneVsRestModel):
             from pyspark.ml.classification import OneVsRestModelWriter
 
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for 
OneVsRestModel")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             ovrm_writer = OneVsRestModelWriter(instance)
             ovrm_writer.session(session)  # type: ignore[arg-type]
             ovrm_writer.save(path)
         else:
             raise NotImplementedError(f"Unsupported write for 
{instance.__class__}")
 
+    @staticmethod
+    def handleOverwrite(path: str, shouldOverwrite: bool) -> None:
+        from pyspark.ml.util import ML_CONNECT_HELPER_ID
+

Review Comment:
   good catch!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to