This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8597b78271fc [SPARK-48988][ML] Make `DefaultParamsReader/Writer` 
handle metadata with spark session
8597b78271fc is described below

commit 8597b78271fcc29276d611186455695636b7b503
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jul 24 19:13:32 2024 +0900

    [SPARK-48988][ML] Make `DefaultParamsReader/Writer` handle metadata with 
spark session
    
    ### What changes were proposed in this pull request?
    Make `DefaultParamsReader/Writer` handle metadata with `SparkSession`.
    
    ### Why are the changes needed?
    In the existing ML implementations, loading/saving a model loads/saves 
the metadata with `SparkContext` and then loads/saves the coefficients with 
`SparkSession`.
    
    This PR aims to also load/save the metadata with `SparkSession`, by 
introducing new helper functions.
    
    - Note I: third-party libraries (e.g. 
[xgboost](https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala#L38-L53)
 ) likely depend on the existing implementation of saveMetadata/loadMetadata, so 
we cannot simply remove them even though they are `private[ml]`.
    
    - Note II: this PR only handles `loadMetadata` and `saveMetadata`. There 
are similar cases for meta algorithms and param read/write, but I want to 
leave the remaining parts for later, to avoid touching too many files in a single PR.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47467 from zhengruifeng/ml_load_with_spark.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../ml/classification/DecisionTreeClassifier.scala |  4 +-
 .../spark/ml/classification/FMClassifier.scala     |  4 +-
 .../apache/spark/ml/classification/LinearSVC.scala |  4 +-
 .../ml/classification/LogisticRegression.scala     |  4 +-
 .../MultilayerPerceptronClassifier.scala           |  4 +-
 .../spark/ml/classification/NaiveBayes.scala       |  4 +-
 .../spark/ml/clustering/BisectingKMeans.scala      |  4 +-
 .../spark/ml/clustering/GaussianMixture.scala      |  4 +-
 .../org/apache/spark/ml/clustering/KMeans.scala    |  5 +-
 .../scala/org/apache/spark/ml/clustering/LDA.scala | 10 ++--
 .../ml/feature/BucketedRandomProjectionLSH.scala   |  4 +-
 .../apache/spark/ml/feature/ChiSqSelector.scala    |  4 +-
 .../apache/spark/ml/feature/CountVectorizer.scala  |  4 +-
 .../org/apache/spark/ml/feature/HashingTF.scala    |  2 +-
 .../scala/org/apache/spark/ml/feature/IDF.scala    |  4 +-
 .../org/apache/spark/ml/feature/Imputer.scala      |  4 +-
 .../org/apache/spark/ml/feature/MaxAbsScaler.scala |  4 +-
 .../org/apache/spark/ml/feature/MinHashLSH.scala   |  4 +-
 .../org/apache/spark/ml/feature/MinMaxScaler.scala |  4 +-
 .../apache/spark/ml/feature/OneHotEncoder.scala    |  4 +-
 .../scala/org/apache/spark/ml/feature/PCA.scala    |  4 +-
 .../org/apache/spark/ml/feature/RFormula.scala     | 12 ++---
 .../org/apache/spark/ml/feature/RobustScaler.scala |  4 +-
 .../apache/spark/ml/feature/StandardScaler.scala   |  4 +-
 .../apache/spark/ml/feature/StringIndexer.scala    |  4 +-
 .../ml/feature/UnivariateFeatureSelector.scala     |  4 +-
 .../ml/feature/VarianceThresholdSelector.scala     |  4 +-
 .../apache/spark/ml/feature/VectorIndexer.scala    |  4 +-
 .../org/apache/spark/ml/feature/Word2Vec.scala     |  4 +-
 .../scala/org/apache/spark/ml/fpm/FPGrowth.scala   |  5 +-
 .../org/apache/spark/ml/recommendation/ALS.scala   |  4 +-
 .../ml/regression/AFTSurvivalRegression.scala      |  4 +-
 .../ml/regression/DecisionTreeRegressor.scala      |  4 +-
 .../apache/spark/ml/regression/FMRegressor.scala   |  4 +-
 .../regression/GeneralizedLinearRegression.scala   |  4 +-
 .../spark/ml/regression/IsotonicRegression.scala   |  4 +-
 .../spark/ml/regression/LinearRegression.scala     |  4 +-
 .../org/apache/spark/ml/tree/treeModels.scala      |  4 +-
 .../scala/org/apache/spark/ml/util/ReadWrite.scala | 61 +++++++++++++++++++---
 39 files changed, 137 insertions(+), 90 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 7deefda2eeaf..c5f1d7f39b6b 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -293,7 +293,7 @@ object DecisionTreeClassificationModel extends 
MLReadable[DecisionTreeClassifica
       val extraMetadata: JObject = Map(
         "numFeatures" -> instance.numFeatures,
         "numClasses" -> instance.numClasses)
-      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession, 
Some(extraMetadata))
       val (nodeData, _) = NodeData.build(instance.rootNode, 0)
       val dataPath = new Path(path, "data").toString
       val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
@@ -309,7 +309,7 @@ object DecisionTreeClassificationModel extends 
MLReadable[DecisionTreeClassifica
 
     override def load(path: String): DecisionTreeClassificationModel = {
       implicit val format = DefaultFormats
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
       val root = loadTreeNodes(path, metadata, sparkSession)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index aec740a932ac..33e7c1fdd5e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -339,7 +339,7 @@ object FMClassificationModel extends 
MLReadable[FMClassificationModel] {
       factors: Matrix)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.intercept, instance.linear, instance.factors)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -351,7 +351,7 @@ object FMClassificationModel extends 
MLReadable[FMClassificationModel] {
     private val className = classOf[FMClassificationModel].getName
 
     override def load(path: String): FMClassificationModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 3e27f781d561..161e8f4cbd2c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -446,7 +446,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.coefficients, instance.intercept)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -459,7 +459,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
     private val className = classOf[LinearSVCModel].getName
 
     override def load(path: String): LinearSVCModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
       val Row(coefficients: Vector, intercept: Double) =
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index ac0682f1df5b..745cb61bb7aa 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1310,7 +1310,7 @@ object LogisticRegressionModel extends 
MLReadable[LogisticRegressionModel] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: numClasses, numFeatures, intercept, coefficients
       val data = Data(instance.numClasses, instance.numFeatures, 
instance.interceptVector,
         instance.coefficientMatrix, instance.isMultinomial)
@@ -1325,7 +1325,7 @@ object LogisticRegressionModel extends 
MLReadable[LogisticRegressionModel] {
     private val className = classOf[LogisticRegressionModel].getName
 
     override def load(path: String): LogisticRegressionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val (major, minor) = 
VersionUtils.majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 16984bf9aed8..106282b9dc3a 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -365,7 +365,7 @@ object MultilayerPerceptronClassificationModel
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: weights
       val data = Data(instance.weights)
       val dataPath = new Path(path, "data").toString
@@ -380,7 +380,7 @@ object MultilayerPerceptronClassificationModel
     private val className = 
classOf[MultilayerPerceptronClassificationModel].getName
 
     override def load(path: String): MultilayerPerceptronClassificationModel = 
{
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 52486cb8aa24..4a511581d31a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -580,7 +580,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val dataPath = new Path(path, "data").toString
 
       instance.getModelType match {
@@ -602,7 +602,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
 
     override def load(path: String): NaiveBayesModel = {
       implicit val format = DefaultFormats
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val (major, minor) = 
VersionUtils.majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 2d809151384b..b4f1565362b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -186,7 +186,7 @@ object BisectingKMeansModel extends 
MLReadable[BisectingKMeansModel] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val dataPath = new Path(path, "data").toString
       instance.parentModel.save(sc, dataPath)
     }
@@ -198,7 +198,7 @@ object BisectingKMeansModel extends 
MLReadable[BisectingKMeansModel] {
     private val className = classOf[BisectingKMeansModel].getName
 
     override def load(path: String): BisectingKMeansModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
       val model = new BisectingKMeansModel(metadata.uid, mllibModel)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 0f6648bb4cda..d0db5dcba87b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -235,7 +235,7 @@ object GaussianMixtureModel extends 
MLReadable[GaussianMixtureModel] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: weights and gaussians
       val weights = instance.weights
       val gaussians = instance.gaussians
@@ -253,7 +253,7 @@ object GaussianMixtureModel extends 
MLReadable[GaussianMixtureModel] {
     private val className = classOf[GaussianMixtureModel].getName
 
     override def load(path: String): GaussianMixtureModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val row = sparkSession.read.parquet(dataPath).select("weights", "mus", 
"sigmas").head()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 04f76660aee6..50fb18bb620a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -219,9 +219,8 @@ private class InternalKMeansModelWriter extends 
MLWriterFormat with MLFormatRegi
   override def write(path: String, sparkSession: SparkSession,
     optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
     val instance = stage.asInstanceOf[KMeansModel]
-    val sc = sparkSession.sparkContext
     // Save metadata and Params
-    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
     // Save model data: cluster centers
     val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
       case (center, idx) =>
@@ -272,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val sparkSession = super.sparkSession
       import sparkSession.implicits._
 
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
 
       val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 7cbfc732a19c..b3d3c84db051 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -654,7 +654,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
         gammaShape: Double)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val oldModel = instance.oldLocalModel
       val data = Data(instance.vocabSize, oldModel.topicsMatrix, 
oldModel.docConcentration,
         oldModel.topicConcentration, oldModel.gammaShape)
@@ -668,7 +668,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
     private val className = classOf[LocalLDAModel].getName
 
     override def load(path: String): LocalLDAModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
       val vectorConverted = MLUtils.convertVectorColumnsToML(data, 
"docConcentration")
@@ -809,7 +809,7 @@ object DistributedLDAModel extends 
MLReadable[DistributedLDAModel] {
   class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val modelPath = new Path(path, "oldModel").toString
       instance.oldDistributedModel.save(sc, modelPath)
     }
@@ -820,7 +820,7 @@ object DistributedLDAModel extends 
MLReadable[DistributedLDAModel] {
     private val className = classOf[DistributedLDAModel].getName
 
     override def load(path: String): DistributedLDAModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val modelPath = new Path(path, "oldModel").toString
       val oldModel = OldDistributedLDAModel.load(sc, modelPath)
       val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
@@ -1008,7 +1008,7 @@ object LDA extends MLReadable[LDA] {
     private val className = classOf[LDA].getName
 
     override def load(path: String): LDA = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val model = new LDA(metadata.uid)
       LDAParams.getAndSetParams(model, metadata)
       model
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index d30962088cb8..537cb5020c88 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -227,7 +227,7 @@ object BucketedRandomProjectionLSHModel extends 
MLReadable[BucketedRandomProject
     private case class Data(randUnitVectors: Matrix)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.randMatrix)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -241,7 +241,7 @@ object BucketedRandomProjectionLSHModel extends 
MLReadable[BucketedRandomProject
     private val className = classOf[BucketedRandomProjectionLSHModel].getName
 
     override def load(path: String): BucketedRandomProjectionLSHModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 3062f643e950..eb2122b09b2f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -173,7 +173,7 @@ object ChiSqSelectorModel extends 
MLReadable[ChiSqSelectorModel] {
     private case class Data(selectedFeatures: Seq[Int])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.selectedFeatures.toImmutableArraySeq)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -185,7 +185,7 @@ object ChiSqSelectorModel extends 
MLReadable[ChiSqSelectorModel] {
     private val className = classOf[ChiSqSelectorModel].getName
 
     override def load(path: String): ChiSqSelectorModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = 
sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
       val selectedFeatures = data.getAs[Seq[Int]](0).toArray
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index b81914f86fbb..611b5c710add 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -372,7 +372,7 @@ object CountVectorizerModel extends 
MLReadable[CountVectorizerModel] {
     private case class Data(vocabulary: Seq[String])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.vocabulary.toImmutableArraySeq)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -384,7 +384,7 @@ object CountVectorizerModel extends 
MLReadable[CountVectorizerModel] {
     private val className = classOf[CountVectorizerModel].getName
 
     override def load(path: String): CountVectorizerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
         .select("vocabulary")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index f4223bc85943..3b42105958c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -154,7 +154,7 @@ object HashingTF extends DefaultParamsReadable[HashingTF] {
     private val className = classOf[HashingTF].getName
 
     override def load(path: String): HashingTF = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       // We support loading old `HashingTF` saved by previous Spark versions.
       // Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but 
new `HashingTF` uses
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 696e1516582d..3025a7b04af5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -198,7 +198,7 @@ object IDFModel extends MLReadable[IDFModel] {
     private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.idf, instance.docFreq, instance.numDocs)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -210,7 +210,7 @@ object IDFModel extends MLReadable[IDFModel] {
     private val className = classOf[IDFModel].getName
 
     override def load(path: String): IDFModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index ae65b17d7a81..38fb25903dca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -308,7 +308,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
   private[ImputerModel] class ImputerModelWriter(instance: ImputerModel) 
extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val dataPath = new Path(path, "data").toString
       instance.surrogateDF.repartition(1).write.parquet(dataPath)
     }
@@ -319,7 +319,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
     private val className = classOf[ImputerModel].getName
 
     override def load(path: String): ImputerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val surrogateDF = sqlContext.read.parquet(dataPath)
       val model = new ImputerModel(metadata.uid, surrogateDF)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 05ee59d1627d..1a378cd85f3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -162,7 +162,7 @@ object MaxAbsScalerModel extends 
MLReadable[MaxAbsScalerModel] {
     private case class Data(maxAbs: Vector)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = new Data(instance.maxAbs)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -174,7 +174,7 @@ object MaxAbsScalerModel extends 
MLReadable[MaxAbsScalerModel] {
     private val className = classOf[MaxAbsScalerModel].getName
 
     override def load(path: String): MaxAbsScalerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath)
         .select("maxAbs")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index d94aadd1ce1f..3f2a3327128a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -220,7 +220,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
     private case class Data(randCoefficients: Array[Int])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.randCoefficients.flatMap(tuple => 
Array(tuple._1, tuple._2)))
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -233,7 +233,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
     private val className = classOf[MinHashLSHModel].getName
 
     override def load(path: String): MinHashLSHModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = 
sparkSession.read.parquet(dataPath).select("randCoefficients").head()
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 4111e559a5c2..c311f4260424 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -247,7 +247,7 @@ object MinMaxScalerModel extends 
MLReadable[MinMaxScalerModel] {
     private case class Data(originalMin: Vector, originalMax: Vector)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = new Data(instance.originalMin, instance.originalMax)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -259,7 +259,7 @@ object MinMaxScalerModel extends 
MLReadable[MinMaxScalerModel] {
     private val className = classOf[MinMaxScalerModel].getName
 
     override def load(path: String): MinMaxScalerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
       val Row(originalMin: Vector, originalMax: Vector) =
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index e7cf0105754a..823f767eebbe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -403,7 +403,7 @@ object OneHotEncoderModel extends 
MLReadable[OneHotEncoderModel] {
     private case class Data(categorySizes: Array[Int])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.categorySizes)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -415,7 +415,7 @@ object OneHotEncoderModel extends 
MLReadable[OneHotEncoderModel] {
     private val className = classOf[OneHotEncoderModel].getName
 
     override def load(path: String): OneHotEncoderModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
         .select("categorySizes")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index f7ec18b38a0e..0bd9a3c38a1e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -184,7 +184,7 @@ object PCAModel extends MLReadable[PCAModel] {
     private case class Data(pc: DenseMatrix, explainedVariance: DenseVector)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.pc, instance.explainedVariance)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -205,7 +205,7 @@ object PCAModel extends MLReadable[PCAModel] {
      * @return a [[PCAModel]]
      */
     override def load(path: String): PCAModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val model = if (majorVersion(metadata.sparkVersion) >= 2) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 7a47e73e5ef4..77bd18423ef1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -427,7 +427,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: resolvedFormula
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(instance.resolvedFormula))
@@ -444,7 +444,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
     private val className = classOf[RFormulaModel].getName
 
     override def load(path: String): RFormulaModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath).select("label", "terms", 
"hasIntercept").head()
@@ -502,7 +502,7 @@ private object ColumnPruner extends 
MLReadable[ColumnPruner] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: columnsToPrune
       val data = Data(instance.columnsToPrune.toSeq)
       val dataPath = new Path(path, "data").toString
@@ -516,7 +516,7 @@ private object ColumnPruner extends 
MLReadable[ColumnPruner] {
     private val className = classOf[ColumnPruner].getName
 
     override def load(path: String): ColumnPruner = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = 
sparkSession.read.parquet(dataPath).select("columnsToPrune").head()
@@ -594,7 +594,7 @@ private object VectorAttributeRewriter extends 
MLReadable[VectorAttributeRewrite
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: vectorCol, prefixesToRewrite
       val data = Data(instance.vectorCol, instance.prefixesToRewrite)
       val dataPath = new Path(path, "data").toString
@@ -608,7 +608,7 @@ private object VectorAttributeRewriter extends 
MLReadable[VectorAttributeRewrite
     private val className = classOf[VectorAttributeRewriter].getName
 
     override def load(path: String): VectorAttributeRewriter = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath).select("vectorCol", 
"prefixesToRewrite").head()
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index 0950dc55dccb..f3e068f04920 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -284,7 +284,7 @@ object RobustScalerModel extends 
MLReadable[RobustScalerModel] {
     private case class Data(range: Vector, median: Vector)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.range, instance.median)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -296,7 +296,7 @@ object RobustScalerModel extends 
MLReadable[RobustScalerModel] {
     private val className = classOf[RobustScalerModel].getName
 
     override def load(path: String): RobustScalerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
       val Row(range: Vector, median: Vector) = MLUtils
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index c0a6392c29c3..f1e48b053d88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -205,7 +205,7 @@ object StandardScalerModel extends 
MLReadable[StandardScalerModel] {
     private case class Data(std: Vector, mean: Vector)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.std, instance.mean)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -217,7 +217,7 @@ object StandardScalerModel extends 
MLReadable[StandardScalerModel] {
     private val className = classOf[StandardScalerModel].getName
 
     override def load(path: String): StandardScalerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
       val Row(std: Vector, mean: Vector) = 
MLUtils.convertVectorColumnsToML(data, "std", "mean")
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 2ca640445b55..94d4fa6fe6f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -509,7 +509,7 @@ object StringIndexerModel extends 
MLReadable[StringIndexerModel] {
     private case class Data(labelsArray: Array[Array[String]])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.labelsArray)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -521,7 +521,7 @@ object StringIndexerModel extends 
MLReadable[StringIndexerModel] {
     private val className = classOf[StringIndexerModel].getName
 
     override def load(path: String): StringIndexerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
 
       // We support loading old `StringIndexerModel` saved by previous Spark 
versions.
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 29a091012495..9c2033c28430 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -349,7 +349,7 @@ object UnivariateFeatureSelectorModel extends 
MLReadable[UnivariateFeatureSelect
     private case class Data(selectedFeatures: Seq[Int])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.selectedFeatures.toImmutableArraySeq)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -363,7 +363,7 @@ object UnivariateFeatureSelectorModel extends 
MLReadable[UnivariateFeatureSelect
     private val className = classOf[UnivariateFeatureSelectorModel].getName
 
     override def load(path: String): UnivariateFeatureSelectorModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
         .select("selectedFeatures").head()
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
index df57e19f1a72..d767e113144c 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
@@ -187,7 +187,7 @@ object VarianceThresholdSelectorModel extends 
MLReadable[VarianceThresholdSelect
     private case class Data(selectedFeatures: Seq[Int])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.selectedFeatures.toImmutableArraySeq)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -201,7 +201,7 @@ object VarianceThresholdSelectorModel extends 
MLReadable[VarianceThresholdSelect
     private val className = classOf[VarianceThresholdSelectorModel].getName
 
     override def load(path: String): VarianceThresholdSelectorModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
         .select("selectedFeatures").head()
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 4fed325e19e9..ff89dee68ea3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -519,7 +519,7 @@ object VectorIndexerModel extends 
MLReadable[VectorIndexerModel] {
     private case class Data(numFeatures: Int, categoryMaps: Map[Int, 
Map[Double, Int]])
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.numFeatures, instance.categoryMaps)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -531,7 +531,7 @@ object VectorIndexerModel extends 
MLReadable[VectorIndexerModel] {
     private val className = classOf[VectorIndexerModel].getName
 
     override def load(path: String): VectorIndexerModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
         .select("numFeatures", "categoryMaps")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 66b56f8b88ef..0329190a239e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -352,7 +352,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
   class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
 
       val wordVectors = instance.wordVectors.getVectors
       val dataPath = new Path(path, "data").toString
@@ -407,7 +407,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
       val spark = sparkSession
       import spark.implicits._
 
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val (major, minor) = 
VersionUtils.majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala 
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 081a40bfbe80..d054ea8ebdb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -336,7 +336,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
 
     override protected def saveImpl(path: String): Unit = {
       val extraMetadata: JObject = Map("numTrainingRecords" -> 
instance.numTrainingRecords)
-      DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = 
Some(extraMetadata))
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
+        extraMetadata = Some(extraMetadata))
       val dataPath = new Path(path, "data").toString
       instance.freqItemsets.write.parquet(dataPath)
     }
@@ -349,7 +350,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
 
     override def load(path: String): FPGrowthModel = {
       implicit val format = DefaultFormats
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val (major, minor) = 
VersionUtils.majorMinorVersion(metadata.sparkVersion)
       val numTrainingRecords = if (major < 2 || (major == 2 && minor < 4)) {
         // 2.3 and before don't store the count
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 50f94a579944..1a004f71749e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -556,7 +556,7 @@ object ALSModel extends MLReadable[ALSModel] {
 
     override protected def saveImpl(path: String): Unit = {
       val extraMetadata = "rank" -> instance.rank
-      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession, 
Some(extraMetadata))
       val userPath = new Path(path, "userFactors").toString
       instance.userFactors.write.format("parquet").save(userPath)
       val itemPath = new Path(path, "itemFactors").toString
@@ -570,7 +570,7 @@ object ALSModel extends MLReadable[ALSModel] {
     private val className = classOf[ALSModel].getName
 
     override def load(path: String): ALSModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       implicit val format = DefaultFormats
       val rank = (metadata.metadata \ "rank").extract[Int]
       val userPath = new Path(path, "userFactors").toString
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index d77d79dae4b8..6451cbf0329d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -494,7 +494,7 @@ object AFTSurvivalRegressionModel extends 
MLReadable[AFTSurvivalRegressionModel]
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: coefficients, intercept, scale
       val data = Data(instance.coefficients, instance.intercept, 
instance.scale)
       val dataPath = new Path(path, "data").toString
@@ -508,7 +508,7 @@ object AFTSurvivalRegressionModel extends 
MLReadable[AFTSurvivalRegressionModel]
     private val className = classOf[AFTSurvivalRegressionModel].getName
 
     override def load(path: String): AFTSurvivalRegressionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 481e8c8357f1..dace99f214b1 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -302,7 +302,7 @@ object DecisionTreeRegressionModel extends 
MLReadable[DecisionTreeRegressionMode
     override protected def saveImpl(path: String): Unit = {
       val extraMetadata: JObject = Map(
         "numFeatures" -> instance.numFeatures)
-      DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession, 
Some(extraMetadata))
       val (nodeData, _) = NodeData.build(instance.rootNode, 0)
       val dataPath = new Path(path, "data").toString
       val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
@@ -318,7 +318,7 @@ object DecisionTreeRegressionModel extends 
MLReadable[DecisionTreeRegressionMode
 
     override def load(path: String): DecisionTreeRegressionModel = {
       implicit val format = DefaultFormats
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val root = loadTreeNodes(path, metadata, sparkSession)
       val model = new DecisionTreeRegressionModel(metadata.uid, root, 
numFeatures)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index 8c797295e671..182107a443c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -504,7 +504,7 @@ object FMRegressionModel extends 
MLReadable[FMRegressionModel] {
         factors: Matrix)
 
     override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       val data = Data(instance.intercept, instance.linear, instance.factors)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -516,7 +516,7 @@ object FMRegressionModel extends 
MLReadable[FMRegressionModel] {
     private val className = classOf[FMRegressionModel].getName
 
     override def load(path: String): FMRegressionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
 
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 181a1a03e6f3..dc0b553e2c91 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1141,7 +1141,7 @@ object GeneralizedLinearRegressionModel extends 
MLReadable[GeneralizedLinearRegr
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: intercept, coefficients
       val data = Data(instance.intercept, instance.coefficients)
       val dataPath = new Path(path, "data").toString
@@ -1156,7 +1156,7 @@ object GeneralizedLinearRegressionModel extends 
MLReadable[GeneralizedLinearRegr
     private val className = classOf[GeneralizedLinearRegressionModel].getName
 
     override def load(path: String): GeneralizedLinearRegressionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 29d8a00a4384..d624270af89d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -301,7 +301,7 @@ object IsotonicRegressionModel extends 
MLReadable[IsotonicRegressionModel] {
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
       // Save model data: boundaries, predictions, isotonic
       val data = Data(
         instance.oldModel.boundaries, instance.oldModel.predictions, 
instance.oldModel.isotonic)
@@ -316,7 +316,7 @@ object IsotonicRegressionModel extends 
MLReadable[IsotonicRegressionModel] {
     private val className = classOf[IsotonicRegressionModel].getName
 
     override def load(path: String): IsotonicRegressionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index d5dce782770b..abac9db8df02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -780,7 +780,7 @@ private class InternalLinearRegressionModelWriter
     val instance = stage.asInstanceOf[LinearRegressionModel]
     val sc = sparkSession.sparkContext
     // Save metadata and Params
-    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
     // Save model data: intercept, coefficients, scale
     val data = Data(instance.intercept, instance.coefficients, instance.scale)
     val dataPath = new Path(path, "data").toString
@@ -824,7 +824,7 @@ object LinearRegressionModel extends 
MLReadable[LinearRegressionModel] {
     private val className = classOf[LinearRegressionModel].getName
 
     override def load(path: String): LinearRegressionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 6a7615fb149b..cdd40ae35503 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -471,7 +471,7 @@ private[ml] object EnsembleModelReadWrite {
       path: String,
       sparkSession: SparkSession,
       extraMetadata: JObject): Unit = {
-    DefaultParamsWriter.saveMetadata(instance, path, 
sparkSession.sparkContext, Some(extraMetadata))
+    DefaultParamsWriter.saveMetadata(instance, path, sparkSession, 
Some(extraMetadata))
     val treesMetadataWeights = instance.trees.zipWithIndex.map { case (tree, 
treeID) =>
       (treeID,
         DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], 
sparkSession.sparkContext),
@@ -510,7 +510,7 @@ private[ml] object EnsembleModelReadWrite {
       treeClassName: String): (Metadata, Array[(Metadata, Node)], 
Array[Double]) = {
     import sparkSession.implicits._
     implicit val format = DefaultFormats
-    val metadata = DefaultParamsReader.loadMetadata(path, 
sparkSession.sparkContext, className)
+    val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, 
className)
 
     // Get impurity to construct ImpurityCalculator for each node
     val impurityType: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c127575e1470..d338c267d823 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -382,7 +382,7 @@ trait DefaultParamsReadable[T] extends MLReadable[T] {
 private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
 
   override protected def saveImpl(path: String): Unit = {
-    DefaultParamsWriter.saveMetadata(instance, path, sc)
+    DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
   }
 }
 
@@ -403,20 +403,58 @@ private[ml] object DefaultParamsWriter {
    *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are 
encoded using
    *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
+  @deprecated("use saveMetadata with SparkSession", "4.0.0")
   def saveMetadata(
       instance: Params,
       path: String,
       sc: SparkContext,
       extraMetadata: Option[JObject] = None,
-      paramMap: Option[JValue] = None): Unit = {
+      paramMap: Option[JValue] = None): Unit =
+    saveMetadata(
+      instance,
+      path,
+      SparkSession.builder().sparkContext(sc).getOrCreate(),
+      extraMetadata,
+      paramMap)
+
+  /**
+   * Saves metadata + Params to: path + "/metadata"
+   *  - class
+   *  - timestamp
+   *  - sparkVersion
+   *  - uid
+   *  - defaultParamMap
+   *  - paramMap
+   *  - (optionally, extra metadata)
+   *
+   * @param extraMetadata  Extra metadata to be saved at same level as uid, 
paramMap, etc.
+   * @param paramMap  If given, this is saved in the "paramMap" field.
+   *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are 
encoded using
+   *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
+   */
+  def saveMetadata(
+      instance: Params,
+      path: String,
+      spark: SparkSession,
+      extraMetadata: Option[JObject],
+      paramMap: Option[JValue]): Unit = {
     val metadataPath = new Path(path, "metadata").toString
-    val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
-    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+    val metadataJson = getMetadataToSave(instance, spark.sparkContext, 
extraMetadata, paramMap)
     // Note that we should write single file. If there are more than one row
     // it produces more partitions.
     spark.createDataFrame(Seq(Tuple1(metadataJson))).write.text(metadataPath)
   }
 
+  def saveMetadata(
+      instance: Params,
+      path: String,
+      spark: SparkSession,
+      extraMetadata: Option[JObject]): Unit =
+    saveMetadata(instance, path, spark, extraMetadata, None)
+
+  def saveMetadata(instance: Params, path: String, spark: SparkSession): Unit =
+    saveMetadata(instance, path, spark, None, None)
+
   /**
    * Helper for [[saveMetadata()]] which extracts the JSON to save.
    * This is useful for ensemble models which need to save metadata for many 
sub-models.
@@ -466,7 +504,7 @@ private[ml] object DefaultParamsWriter {
 private[ml] class DefaultParamsReader[T] extends MLReader[T] {
 
   override def load(path: String): T = {
-    val metadata = DefaultParamsReader.loadMetadata(path, sc)
+    val metadata = DefaultParamsReader.loadMetadata(path, sparkSession)
     val cls = Utils.classForName(metadata.className)
     val instance =
       
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
@@ -586,13 +624,22 @@ private[ml] object DefaultParamsReader {
    * @param expectedClassName  If non empty, this is checked against the 
loaded metadata.
    * @throws IllegalArgumentException if expectedClassName is specified and 
does not match metadata
    */
-  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = 
""): Metadata = {
+  @deprecated("use loadMetadata with SparkSession", "4.0.0")
+  def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = 
""): Metadata =
+    loadMetadata(
+      path,
+      SparkSession.builder().sparkContext(sc).getOrCreate(),
+      expectedClassName)
+
+  def loadMetadata(path: String, spark: SparkSession, expectedClassName: 
String): Metadata = {
     val metadataPath = new Path(path, "metadata").toString
-    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     val metadataStr = spark.read.text(metadataPath).first().getString(0)
     parseMetadata(metadataStr, expectedClassName)
   }
 
+  def loadMetadata(path: String, spark: SparkSession): Metadata =
+    loadMetadata(path, spark, "")
+
   /**
    * Parse metadata JSON string produced by 
[[DefaultParamsWriter.getMetadataToSave()]].
    * This is a helper function for [[loadMetadata()]].


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to