zhengruifeng commented on code in PR #49810:
URL: https://github.com/apache/spark/pull/49810#discussion_r1942931538
##########
python/pyspark/ml/connect/readwrite.py:
##########
@@ -142,75 +142,62 @@ def saveInstance(
elif isinstance(instance, Pipeline):
from pyspark.ml.pipeline import PipelineWriter
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for Pipeline")
-
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
pl_writer = PipelineWriter(instance)
pl_writer.session(session) # type: ignore[arg-type]
pl_writer.save(path)
elif isinstance(instance, PipelineModel):
from pyspark.ml.pipeline import PipelineModelWriter
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for PipelineModel")
-
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
plm_writer = PipelineModelWriter(instance)
plm_writer.session(session) # type: ignore[arg-type]
plm_writer.save(path)
elif isinstance(instance, CrossValidator):
from pyspark.ml.tuning import CrossValidatorWriter
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for CrossValidator")
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
cv_writer = CrossValidatorWriter(instance)
cv_writer.session(session) # type: ignore[arg-type]
cv_writer.save(path)
elif isinstance(instance, CrossValidatorModel):
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
cvm_writer.save(path)
elif isinstance(instance, TrainValidationSplit):
from pyspark.ml.tuning import TrainValidationSplitWriter
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for TrainValidationSplit")
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
tvs_writer = TrainValidationSplitWriter(instance)
tvs_writer.save(path)
elif isinstance(instance, TrainValidationSplitModel):
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
tvsm_writer.save(path)
elif isinstance(instance, OneVsRest):
from pyspark.ml.classification import OneVsRestWriter
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for OneVsRest")
-
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
ovr_writer = OneVsRestWriter(instance)
ovr_writer.session(session) # type: ignore[arg-type]
ovr_writer.save(path)
elif isinstance(instance, OneVsRestModel):
from pyspark.ml.classification import OneVsRestModelWriter
- if shouldOverwrite:
- # TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
-
+ RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
ovrm_writer = OneVsRestModelWriter(instance)
ovrm_writer.session(session) # type: ignore[arg-type]
ovrm_writer.save(path)
else:
raise NotImplementedError(f"Unsupported write for {instance.__class__}")
+ @staticmethod
+ def handleOverwrite(path: str, shouldOverwrite: bool) -> None:
+ from pyspark.ml.util import ML_CONNECT_HELPER_ID
+
Review Comment:
good catch!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]