This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.0 by this push:
new 09876bb8fbad [SPARK-50918][ML][PYTHON][CONNECT] Refactor read/write for Pipeline
09876bb8fbad is described below
commit 09876bb8fbad7d708739b29db6a7e2c573b1f685
Author: Bobby Wang <[email protected]>
AuthorDate: Tue Jan 28 18:55:56 2025 +0800
[SPARK-50918][ML][PYTHON][CONNECT] Refactor read/write for Pipeline
### What changes were proposed in this pull request?
We can reuse the built-in Pipeline/PipelineModel readers and writers to support read/write on Spark Connect.
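In effect, RemoteMLWriter/RemoteMLReader now delegate to the built-in PipelineWriter/PipelineReader instead of re-implementing the logic via PipelineSharedReadWrite. A minimal sketch of that delegation (the Connect URL, Binarizer stage, and path are illustrative assumptions, not part of the patch):

```python
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import Binarizer
from pyspark.ml.pipeline import PipelineReader, PipelineWriter

# Illustrative setup: any Connect session and any Pipeline would do.
session = SparkSession.builder.remote("sc://localhost").getOrCreate()
instance = Pipeline(stages=[Binarizer(threshold=0.5, inputCol="x", outputCol="y")])

pl_writer = PipelineWriter(instance)  # reuse the built-in writer ...
pl_writer.session(session)            # ... bound to the Connect session
pl_writer.save("/tmp/demo_pipeline")

pl_reader = PipelineReader(Pipeline)  # reuse the built-in reader
pl_reader.session(session)
loaded = pl_reader.load("/tmp/demo_pipeline")
```

PipelineModel follows the same pattern via PipelineModelWriter/PipelineModelReader, as the diff below shows.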
### Why are the changes needed?
Reusing code
### Does this PR introduce _any_ user-facing change?
No
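For reference, the public persistence surface is unchanged; a minimal usage sketch, assuming an active Spark Connect session (the path and Tokenizer stage are illustrative only):

```python
from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer

pipeline = Pipeline(stages=[Tokenizer(inputCol="text", outputCol="words")])
pipeline.write().save("/tmp/demo_pipeline")     # routed through RemoteMLWriter on Connect
reloaded = Pipeline.load("/tmp/demo_pipeline")  # routed through RemoteMLReader on Connect
```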
### How was this patch tested?
CI passes
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #49706 from wbo4958/pipeline-read-write.
Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit dd51f0e7592de569ce0e4db9c0eca3f05c160cba)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/connect/readwrite.py | 53 +++++++++++++++++-----------------
1 file changed, 27 insertions(+), 26 deletions(-)
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 9e5587ebc5d9..de70a410dbc7 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
+from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional
import pyspark.sql.connect.proto as pb2
from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
@@ -139,26 +139,26 @@ class RemoteMLWriter(MLWriter):
command.ml_command.write.CopyFrom(writer)
session.client.execute_command(command)
- elif isinstance(instance, (Pipeline, PipelineModel)):
- from pyspark.ml.pipeline import PipelineSharedReadWrite
+ elif isinstance(instance, Pipeline):
+ from pyspark.ml.pipeline import PipelineWriter
if shouldOverwrite:
# TODO(SPARK-50954): Support client side model path overwrite
- warnings.warn("Overwrite doesn't take effect for Pipeline and PipelineModel")
+ warnings.warn("Overwrite doesn't take effect for Pipeline")
- if isinstance(instance, Pipeline):
- stages = instance.getStages() # type: ignore[attr-defined]
- else:
- stages = instance.stages
-
- PipelineSharedReadWrite.validateStages(stages)
- PipelineSharedReadWrite.saveImpl(
- instance, # type: ignore[arg-type]
- stages,
- session, # type: ignore[arg-type]
- path,
- )
+ pl_writer = PipelineWriter(instance)
+ pl_writer.session(session) # type: ignore[arg-type]
+ pl_writer.save(path)
+ elif isinstance(instance, PipelineModel):
+ from pyspark.ml.pipeline import PipelineModelWriter
+ if shouldOverwrite:
+ # TODO(SPARK-50954): Support client side model path overwrite
+ warnings.warn("Overwrite doesn't take effect for PipelineModel")
+
+ plm_writer = PipelineModelWriter(instance)
+ plm_writer.session(session) # type: ignore[arg-type]
+ plm_writer.save(path)
elif isinstance(instance, CrossValidator):
from pyspark.ml.tuning import CrossValidatorWriter
@@ -231,7 +231,6 @@ class RemoteMLReader(MLReader[RL]):
path: str,
session: "SparkSession",
) -> RL:
- from pyspark.ml.base import Transformer
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.pipeline import Pipeline, PipelineModel
@@ -289,17 +288,19 @@ class RemoteMLReader(MLReader[RL]):
else:
raise RuntimeError(f"Unsupported python type {py_type}")
- elif issubclass(clazz, Pipeline) or issubclass(clazz, PipelineModel):
- from pyspark.ml.pipeline import PipelineSharedReadWrite
- from pyspark.ml.util import DefaultParamsReader
+ elif issubclass(clazz, Pipeline):
+ from pyspark.ml.pipeline import PipelineReader
- metadata = DefaultParamsReader.loadMetadata(path, session)
- uid, stages = PipelineSharedReadWrite.load(metadata, session, path)
+ pl_reader = PipelineReader(Pipeline)
+ pl_reader.session(session)
+ return pl_reader.load(path)
- if issubclass(clazz, Pipeline):
- return Pipeline(stages=stages)._resetUid(uid)
- else:
- return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
+ elif issubclass(clazz, PipelineModel):
+ from pyspark.ml.pipeline import PipelineModelReader
+
+ plm_reader = PipelineModelReader(PipelineModel)
+ plm_reader.session(session)
+ return plm_reader.load(path)
elif issubclass(clazz, CrossValidator):
from pyspark.ml.tuning import CrossValidatorReader