This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new f91646871213 [SPARK-50954][ML][PYTHON][CONNECT] Support client side model path overwrite for meta algorithms
f91646871213 is described below
commit f91646871213dd15c23aa36970f16ff7d0081ae2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 5 23:10:53 2025 +0800
[SPARK-50954][ML][PYTHON][CONNECT] Support client side model path overwrite for meta algorithms
### What changes were proposed in this pull request?
Support client-side model path overwrite for meta algorithms (Pipeline, PipelineModel, OneVsRest(Model), CrossValidator(Model), TrainValidationSplit(Model)) under Spark Connect, by routing `overwrite()` through the server-side `ConnectHelper` instead of ignoring it.
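In short, when `overwrite()` is requested, the Connect client now asks the server to clear the target path before delegating to the regular writer. A condensed sketch of that flow, mirroring the `RemoteMLWriter.handleOverwrite` helper added in this patch (shown here as a standalone function; the `JavaWrapper` import path is an assumption based on how PySpark exposes that wrapper):

```python
from pyspark.ml.util import ML_CONNECT_HELPER_ID
from pyspark.ml.wrapper import JavaWrapper


def handle_overwrite(path: str, should_overwrite: bool) -> None:
    # Invoke the server-side ConnectHelper, which deletes `path` through
    # FileSystemOverwrite before the regular writer saves to it.
    if should_overwrite:
        helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
        helper._call_java("handleOverwrite", path, should_overwrite)
```

Each meta-algorithm writer calls this helper once before saving, replacing the previous per-writer warnings.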
### Why are the changes needed?
For feature parity with classic Spark ML, where `overwrite()` deletes an existing model path before saving.
### Does this PR introduce _any_ user-facing change?
Yes. On Spark Connect, `write().overwrite()` now takes effect for meta algorithms; previously the flag was ignored and only a warning was emitted.
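Before this change, a Connect client calling `overwrite()` on one of these writers only emitted a warning such as `Overwrite doesn't take effect for Pipeline`, and an existing path could still cause the save to fail. A small usage sketch of the new behavior (hypothetical path and data; `train_df` is assumed to be a labeled DataFrame on a Connect session):

```python
from pyspark.ml.classification import LogisticRegression, OneVsRest

ovr = OneVsRest(classifier=LogisticRegression())
model = ovr.fit(train_df)

# Saving twice to the same path now succeeds on Connect: overwrite()
# removes the existing model directory on the server before writing.
model.write().overwrite().save("/tmp/ovr_model")
model.write().overwrite().save("/tmp/ovr_model")
```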
### How was this patch tested?
Updated tests in `test_ovr.py`, `test_pipeline.py`, and `test_tuning.py` to save with `overwrite()` into an already-existing directory.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49810 from zhengruifeng/ml_connect_overwrite.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 6a1d38219428eb48e1f3bbcae1bfe2d487b7afc3)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../org/apache/spark/ml/util/ConnectHelper.scala | 8 +++-
.../scala/org/apache/spark/ml/util/ReadWrite.scala | 2 +-
python/pyspark/ml/connect/readwrite.py | 46 ++++++++--------------
python/pyspark/ml/tests/test_ovr.py | 11 ++----
python/pyspark/ml/tests/test_pipeline.py | 21 ++++------
python/pyspark/ml/tests/test_tuning.py | 26 +++++-------
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 1 +
7 files changed, 48 insertions(+), 67 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index cf1a75518857..d4a0a1301e15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -19,12 +19,18 @@ package org.apache.spark.ml.util
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.feature.{CountVectorizerModel, StringIndexerModel}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.types.StructType

 private[spark] class ConnectHelper(override val uid: String) extends Model[ConnectHelper] {
   def this() = this(Identifiable.randomUID("ConnectHelper"))

+  def handleOverwrite(path: String, shouldOverwrite: Boolean): Boolean = {
+    val spark = SparkSession.builder().getOrCreate()
+    new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, spark)
+    true
+  }
+
   def stringIndexerModelFromLabels(
       uid: String, labels: Array[String]): StringIndexerModel = {
     new StringIndexerModel(uid, labels)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index f9d9056c801e..dcb337218edc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -754,7 +754,7 @@ private[ml] object MetaAlgorithmReadWrite {
   }
 }

-private[ml] class FileSystemOverwrite extends Logging {
+private[spark] class FileSystemOverwrite extends Logging {
   def handleOverwrite(path: String, shouldOverwrite: Boolean, session: SparkSession): Unit = {
     val hadoopConf = session.sessionState.newHadoopConf()
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index de70a410dbc7..5c2b850d51b3 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import warnings
+
 from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional

 import pyspark.sql.connect.proto as pb2
@@ -142,75 +142,63 @@ class RemoteMLWriter(MLWriter):
         elif isinstance(instance, Pipeline):
             from pyspark.ml.pipeline import PipelineWriter

-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for Pipeline")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             pl_writer = PipelineWriter(instance)
             pl_writer.session(session)  # type: ignore[arg-type]
             pl_writer.save(path)
         elif isinstance(instance, PipelineModel):
             from pyspark.ml.pipeline import PipelineModelWriter

-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for PipelineModel")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             plm_writer = PipelineModelWriter(instance)
             plm_writer.session(session)  # type: ignore[arg-type]
             plm_writer.save(path)
         elif isinstance(instance, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorWriter

-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for CrossValidator")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             cv_writer = CrossValidatorWriter(instance)
             cv_writer.session(session)  # type: ignore[arg-type]
             cv_writer.save(path)
         elif isinstance(instance, CrossValidatorModel):
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
             cvm_writer.save(path)
         elif isinstance(instance, TrainValidationSplit):
             from pyspark.ml.tuning import TrainValidationSplitWriter

-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for TrainValidationSplit")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             tvs_writer = TrainValidationSplitWriter(instance)
             tvs_writer.save(path)
         elif isinstance(instance, TrainValidationSplitModel):
-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
             tvsm_writer.save(path)
         elif isinstance(instance, OneVsRest):
             from pyspark.ml.classification import OneVsRestWriter

-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for OneVsRest")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             ovr_writer = OneVsRestWriter(instance)
             ovr_writer.session(session)  # type: ignore[arg-type]
             ovr_writer.save(path)
         elif isinstance(instance, OneVsRestModel):
             from pyspark.ml.classification import OneVsRestModelWriter

-            if shouldOverwrite:
-                # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
-
+            RemoteMLWriter.handleOverwrite(path, shouldOverwrite)
             ovrm_writer = OneVsRestModelWriter(instance)
             ovrm_writer.session(session)  # type: ignore[arg-type]
             ovrm_writer.save(path)
         else:
             raise NotImplementedError(f"Unsupported write for {instance.__class__}")

+    @staticmethod
+    def handleOverwrite(path: str, shouldOverwrite: bool) -> None:
+        from pyspark.ml.util import ML_CONNECT_HELPER_ID
+
+        if shouldOverwrite:
+            helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+            helper._call_java("handleOverwrite", path, shouldOverwrite)
+

 class RemoteMLReader(MLReader[RL]):
     def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
diff --git a/python/pyspark/ml/tests/test_ovr.py b/python/pyspark/ml/tests/test_ovr.py
index aa003be53f5d..42275c852ac1 100644
--- a/python/pyspark/ml/tests/test_ovr.py
+++ b/python/pyspark/ml/tests/test_ovr.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
-import os
 import tempfile
 import unittest
@@ -108,14 +107,12 @@ class OneVsRestTestsMixin:
         # Model save & load
         with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
-            path1 = os.path.join(d, "ovr")
-            ovr.write().overwrite().save(path1)
-            ovr2 = OneVsRest.load(path1)
+            ovr.write().overwrite().save(d)
+            ovr2 = OneVsRest.load(d)
             self.assertEqual(str(ovr), str(ovr2))

-            path2 = os.path.join(d, "ovr_model")
-            model.write().overwrite().save(path2)
-            model2 = OneVsRestModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = OneVsRestModel.load(d)
             self.assertEqual(str(model), str(model2))
diff --git a/python/pyspark/ml/tests/test_pipeline.py b/python/pyspark/ml/tests/test_pipeline.py
index 9c6ea2126f8d..8318f3bb71c9 100644
--- a/python/pyspark/ml/tests/test_pipeline.py
+++ b/python/pyspark/ml/tests/test_pipeline.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
-import os
 import tempfile
 import unittest
@@ -110,15 +109,13 @@ class PipelineTestsMixin:
         # save & load
         with tempfile.TemporaryDirectory(prefix="classification_pipeline") as d:
-            path1 = os.path.join(d, "pipeline")
-            pipeline.write().save(path1)
-            pipeline2 = Pipeline.load(path1)
+            pipeline.write().overwrite().save(d)
+            pipeline2 = Pipeline.load(d)
             self.assertEqual(str(pipeline), str(pipeline2))
             self.assertEqual(str(pipeline.getStages()), str(pipeline2.getStages()))

-            path2 = os.path.join(d, "pipeline_model")
-            model.write().save(path2)
-            model2 = PipelineModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = PipelineModel.load(d)
             self.assertEqual(str(model), str(model2))
             self.assertEqual(str(model.stages), str(model2.stages))
@@ -165,15 +162,13 @@ class PipelineTestsMixin:
         # PipelineModel save & load
         with tempfile.TemporaryDirectory(prefix="clustering_pipeline") as d:
-            path1 = os.path.join(d, "pipeline")
-            pipeline.write().save(path1)
-            pipeline2 = Pipeline.load(path1)
+            pipeline.write().overwrite().save(d)
+            pipeline2 = Pipeline.load(d)
             self.assertEqual(str(pipeline), str(pipeline2))
             self.assertEqual(str(pipeline.getStages()), str(pipeline2.getStages()))

-            path2 = os.path.join(d, "pipeline_model")
-            model.write().save(path2)
-            model2 = PipelineModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = PipelineModel.load(d)
             self.assertEqual(str(model), str(model2))
             self.assertEqual(str(model.stages), str(model2.stages))
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
index 451c89db5b5a..94081a090bfe 100644
--- a/python/pyspark/ml/tests/test_tuning.py
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
-import os
 import tempfile
 import unittest
@@ -73,16 +72,14 @@ class TuningTestsMixin:
         # save & load
         with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
-            path1 = os.path.join(d, "tvs")
-            tvs.write().save(path1)
-            tvs2 = TrainValidationSplit.load(path1)
+            tvs.write().overwrite().save(d)
+            tvs2 = TrainValidationSplit.load(d)
             self.assertEqual(str(tvs), str(tvs2))
             self.assertEqual(str(tvs.getEstimator()), str(tvs2.getEstimator()))
             self.assertEqual(str(tvs.getEvaluator()), str(tvs2.getEvaluator()))

-            path2 = os.path.join(d, "tvsm")
-            tvs_model.write().save(path2)
-            model2 = TrainValidationSplitModel.load(path2)
+            tvs_model.write().overwrite().save(d)
+            model2 = TrainValidationSplitModel.load(d)
             self.assertEqual(str(tvs_model), str(model2))
             self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
             self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))
@@ -142,24 +139,21 @@ class TuningTestsMixin:
         # save & load
         with tempfile.TemporaryDirectory(prefix="cv") as d:
-            path1 = os.path.join(d, "cv")
-            cv.write().save(path1)
-            cv2 = CrossValidator.load(path1)
+            cv.write().overwrite().save(d)
+            cv2 = CrossValidator.load(d)
             self.assertEqual(str(cv), str(cv2))
             self.assertEqual(str(cv.getEstimator()), str(cv2.getEstimator()))
             self.assertEqual(str(cv.getEvaluator()), str(cv2.getEvaluator()))

-            path2 = os.path.join(d, "cv_model")
-            model.write().save(path2)
-            model2 = CrossValidatorModel.load(path2)
+            model.write().overwrite().save(d)
+            model2 = CrossValidatorModel.load(d)
             checkSubModels(model2.subModels)
             self.assertEqual(str(model), str(model2))
             self.assertEqual(str(model.getEstimator()), str(model2.getEstimator()))
             self.assertEqual(str(model.getEvaluator()), str(model2.getEvaluator()))

-            path2 = os.path.join(d, "cv_model2")
-            model.write().option("persistSubModels", "false").save(path2)
-            cvModel2 = CrossValidatorModel.load(path2)
+            model.write().overwrite().option("persistSubModels", "false").save(d)
+            cvModel2 = CrossValidatorModel.load(d)
             self.assertEqual(cvModel2.subModels, None)

             model3 = model2.copy()
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index e69f670226e4..2cbae0fa7d9d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -656,6 +656,7 @@ private[ml] object MLUtils {
     (
       classOf[ConnectHelper],
       Set(
+        "handleOverwrite",
         "stringIndexerModelFromLabels",
         "stringIndexerModelFromLabelsArray",
         "countVectorizerModelFromVocabulary")))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]