This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d76e0bbc30f [SPARK-51275][PYTHON][ML][CONNECT] Session propagation in python readwrite
3d76e0bbc30f is described below

commit 3d76e0bbc30f735665cc4d84659c5737fb8af08a
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Feb 21 15:56:48 2025 +0800

    [SPARK-51275][PYTHON][ML][CONNECT] Session propagation in python readwrite
    
    ### What changes were proposed in this pull request?
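    Propagate the Spark session through the Python ML read/write helpers. `PipelineSharedReadWrite` now passes the session to each stage's writer and reader. `BaseReadWrite`, `DefaultParamsWriter.saveMetadata` and `DefaultParamsReader.loadMetadata` now use `SparkSession.active()` instead of `SparkSession._getActiveSessionOrCreate()`.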
    
    ### Why are the changes needed?
    To avoid recreating the Spark session during ML save/load.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    Existing tests should cover this change.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #50035 from zhengruifeng/py_ml_sc_session.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/ml/pipeline.py | 8 +++++---
 python/pyspark/ml/util.py     | 8 ++++----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 18f537cf197a..b77392a50c7f 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -420,10 +420,11 @@ class PipelineSharedReadWrite:
         """
         stageUids = [stage.uid for stage in stages]
         jsonParams = {"stageUids": stageUids, "language": "Python"}
-        DefaultParamsWriter.saveMetadata(instance, path, sc, 
paramMap=jsonParams)
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else 
SparkSession.active()
+        DefaultParamsWriter.saveMetadata(instance, path, spark, 
paramMap=jsonParams)
         stagesDir = os.path.join(path, "stages")
         for index, stage in enumerate(stages):
-            cast(MLWritable, stage).write().save(
+            cast(MLWritable, stage).write().session(spark).save(
                 PipelineSharedReadWrite.getStagePath(stage.uid, index, 
len(stages), stagesDir)
             )
 
@@ -443,12 +444,13 @@ class PipelineSharedReadWrite:
         """
         stagesDir = os.path.join(path, "stages")
         stageUids = metadata["paramMap"]["stageUids"]
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else 
SparkSession.active()
         stages = []
         for index, stageUid in enumerate(stageUids):
             stagePath = PipelineSharedReadWrite.getStagePath(
                 stageUid, index, len(stageUids), stagesDir
             )
-            stage: "PipelineStage" = 
DefaultParamsReader.loadParamsInstance(stagePath, sc)
+            stage: "PipelineStage" = 
DefaultParamsReader.loadParamsInstance(stagePath, spark)
             stages.append(stage)
         return (metadata["uid"], stages)
 
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 4919b828a35c..6b3d6101c249 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -462,7 +462,7 @@ class BaseReadWrite:
         Returns the user-specified Spark Session or the default.
         """
         if self._sparkSession is None:
-            self._sparkSession = SparkSession._getActiveSessionOrCreate()
+            self._sparkSession = SparkSession.active()
         assert self._sparkSession is not None
         return self._sparkSession
 
@@ -809,10 +809,10 @@ class DefaultParamsWriter(MLWriter):
             If given, this is saved in the "paramMap" field.
         """
         metadataPath = os.path.join(path, "metadata")
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else 
SparkSession.active()
         metadataJson = DefaultParamsWriter._get_metadata_to_save(
-            instance, sc, extraMetadata, paramMap
+            instance, spark, extraMetadata, paramMap
         )
-        spark = sc if isinstance(sc, SparkSession) else 
SparkSession._getActiveSessionOrCreate()
         spark.createDataFrame([(metadataJson,)], 
schema=["value"]).coalesce(1).write.text(
             metadataPath
         )
@@ -932,7 +932,7 @@ class DefaultParamsReader(MLReader[RL]):
             If non empty, this is checked against the loaded metadata.
         """
         metadataPath = os.path.join(path, "metadata")
-        spark = sc if isinstance(sc, SparkSession) else 
SparkSession._getActiveSessionOrCreate()
+        spark = cast(SparkSession, sc) if hasattr(sc, "createDataFrame") else 
SparkSession.active()
         metadataStr = spark.read.text(metadataPath).first()[0]  # type: 
ignore[index]
         loadedVals = DefaultParamsReader._parseMetaData(metadataStr, 
expectedClassName)
         return loadedVals


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

Reply via email to