This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 09876bb8fbad [SPARK-50918][ML][PYTHON][CONNECT] Refactor read/write for Pipeline
09876bb8fbad is described below

commit 09876bb8fbad7d708739b29db6a7e2c573b1f685
Author: Bobby Wang <[email protected]>
AuthorDate: Tue Jan 28 18:55:56 2025 +0800

    [SPARK-50918][ML][PYTHON][CONNECT] Refactor read/write for Pipeline
    
    ### What changes were proposed in this pull request?
    
    We can use the built-in Pipeline/PipelineModel readers and writers (PipelineReader/PipelineWriter and PipelineModelReader/PipelineModelWriter) to support read/write on Spark Connect.
    
    ### Why are the changes needed?
    
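    To reuse the existing Pipeline/PipelineModel reader and writer code instead of re-implementing the save/load logic in the Connect reader and writer.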
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI passes
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49706 from wbo4958/pipeline-read-write.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit dd51f0e7592de569ce0e4db9c0eca3f05c160cba)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/ml/connect/readwrite.py | 53 +++++++++++++++++-----------------
 1 file changed, 27 insertions(+), 26 deletions(-)

diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 9e5587ebc5d9..de70a410dbc7 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
+from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, 
deserialize_param
@@ -139,26 +139,26 @@ class RemoteMLWriter(MLWriter):
             command.ml_command.write.CopyFrom(writer)
             session.client.execute_command(command)
 
-        elif isinstance(instance, (Pipeline, PipelineModel)):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
+        elif isinstance(instance, Pipeline):
+            from pyspark.ml.pipeline import PipelineWriter
 
             if shouldOverwrite:
                 # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for Pipeline and 
PipelineModel")
+                warnings.warn("Overwrite doesn't take effect for Pipeline")
 
-            if isinstance(instance, Pipeline):
-                stages = instance.getStages()  # type: ignore[attr-defined]
-            else:
-                stages = instance.stages
-
-            PipelineSharedReadWrite.validateStages(stages)
-            PipelineSharedReadWrite.saveImpl(
-                instance,  # type: ignore[arg-type]
-                stages,
-                session,  # type: ignore[arg-type]
-                path,
-            )
+            pl_writer = PipelineWriter(instance)
+            pl_writer.session(session)  # type: ignore[arg-type]
+            pl_writer.save(path)
+        elif isinstance(instance, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelWriter
 
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for 
PipelineModel")
+
+            plm_writer = PipelineModelWriter(instance)
+            plm_writer.session(session)  # type: ignore[arg-type]
+            plm_writer.save(path)
         elif isinstance(instance, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorWriter
 
@@ -231,7 +231,6 @@ class RemoteMLReader(MLReader[RL]):
         path: str,
         session: "SparkSession",
     ) -> RL:
-        from pyspark.ml.base import Transformer
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, 
JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
@@ -289,17 +288,19 @@ class RemoteMLReader(MLReader[RL]):
             else:
                 raise RuntimeError(f"Unsupported python type {py_type}")
 
-        elif issubclass(clazz, Pipeline) or issubclass(clazz, PipelineModel):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
-            from pyspark.ml.util import DefaultParamsReader
+        elif issubclass(clazz, Pipeline):
+            from pyspark.ml.pipeline import PipelineReader
 
-            metadata = DefaultParamsReader.loadMetadata(path, session)
-            uid, stages = PipelineSharedReadWrite.load(metadata, session, path)
+            pl_reader = PipelineReader(Pipeline)
+            pl_reader.session(session)
+            return pl_reader.load(path)
 
-            if issubclass(clazz, Pipeline):
-                return Pipeline(stages=stages)._resetUid(uid)
-            else:
-                return PipelineModel(stages=cast(List[Transformer], 
stages))._resetUid(uid)
+        elif issubclass(clazz, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelReader
+
+            plm_reader = PipelineModelReader(PipelineModel)
+            plm_reader.session(session)
+            return plm_reader.load(path)
 
         elif issubclass(clazz, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorReader


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to