This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8597b78271fc [SPARK-48988][ML] Make `DefaultParamsReader/Writer` handle metadata with spark session
8597b78271fc is described below
commit 8597b78271fcc29276d611186455695636b7b503
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jul 24 19:13:32 2024 +0900
[SPARK-48988][ML] Make `DefaultParamsReader/Writer` handle metadata with spark session
### What changes were proposed in this pull request?
Make `DefaultParamsReader/Writer` handle metadata with a `SparkSession` instead of a `SparkContext`.
### Why are the changes needed?
In the existing ML implementations, loading/saving a model loads/saves the metadata with a `SparkContext` but the coefficients with a `SparkSession`.
This PR makes the metadata also load/save with `SparkSession`, by introducing new helper functions.
- Note I: third-party libraries (e.g. [xgboost](https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j-spark/src/main/scala/org/apache/spark/ml/util/XGBoostReadWrite.scala#L38-L53)) likely depend on the existing implementations of `saveMetadata`/`loadMetadata`, so we cannot simply remove them even though they are `private[ml]`.
- Note II: this PR only handles `loadMetadata` and `saveMetadata`. There are similar cases for meta-algorithms and param read/write, but those are left for follow-ups to avoid touching too many files in a single PR. The call-site change is sketched below.
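For illustration, a minimal sketch of the call-site change inside the built-in writers/readers (`instance`, `path`, `className`, `sc` and `sparkSession` are the usual `MLWriter`/`MLReader` members; not compilable standalone since these helpers are `private[ml]`):

```scala
// Before (now deprecated): metadata went through the SparkContext
// DefaultParamsWriter.saveMetadata(instance, path, sc)
// val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

// After: metadata goes through the same SparkSession already used for the model data
DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)

// The old SparkContext-based overloads are kept for third-party callers and now
// delegate to the SparkSession variants via
// SparkSession.builder().sparkContext(sc).getOrCreate()
```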
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47467 from zhengruifeng/ml_load_with_spark.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../ml/classification/DecisionTreeClassifier.scala | 4 +-
.../spark/ml/classification/FMClassifier.scala | 4 +-
.../apache/spark/ml/classification/LinearSVC.scala | 4 +-
.../ml/classification/LogisticRegression.scala | 4 +-
.../MultilayerPerceptronClassifier.scala | 4 +-
.../spark/ml/classification/NaiveBayes.scala | 4 +-
.../spark/ml/clustering/BisectingKMeans.scala | 4 +-
.../spark/ml/clustering/GaussianMixture.scala | 4 +-
.../org/apache/spark/ml/clustering/KMeans.scala | 5 +-
.../scala/org/apache/spark/ml/clustering/LDA.scala | 10 ++--
.../ml/feature/BucketedRandomProjectionLSH.scala | 4 +-
.../apache/spark/ml/feature/ChiSqSelector.scala | 4 +-
.../apache/spark/ml/feature/CountVectorizer.scala | 4 +-
.../org/apache/spark/ml/feature/HashingTF.scala | 2 +-
.../scala/org/apache/spark/ml/feature/IDF.scala | 4 +-
.../org/apache/spark/ml/feature/Imputer.scala | 4 +-
.../org/apache/spark/ml/feature/MaxAbsScaler.scala | 4 +-
.../org/apache/spark/ml/feature/MinHashLSH.scala | 4 +-
.../org/apache/spark/ml/feature/MinMaxScaler.scala | 4 +-
.../apache/spark/ml/feature/OneHotEncoder.scala | 4 +-
.../scala/org/apache/spark/ml/feature/PCA.scala | 4 +-
.../org/apache/spark/ml/feature/RFormula.scala | 12 ++---
.../org/apache/spark/ml/feature/RobustScaler.scala | 4 +-
.../apache/spark/ml/feature/StandardScaler.scala | 4 +-
.../apache/spark/ml/feature/StringIndexer.scala | 4 +-
.../ml/feature/UnivariateFeatureSelector.scala | 4 +-
.../ml/feature/VarianceThresholdSelector.scala | 4 +-
.../apache/spark/ml/feature/VectorIndexer.scala | 4 +-
.../org/apache/spark/ml/feature/Word2Vec.scala | 4 +-
.../scala/org/apache/spark/ml/fpm/FPGrowth.scala | 5 +-
.../org/apache/spark/ml/recommendation/ALS.scala | 4 +-
.../ml/regression/AFTSurvivalRegression.scala | 4 +-
.../ml/regression/DecisionTreeRegressor.scala | 4 +-
.../apache/spark/ml/regression/FMRegressor.scala | 4 +-
.../regression/GeneralizedLinearRegression.scala | 4 +-
.../spark/ml/regression/IsotonicRegression.scala | 4 +-
.../spark/ml/regression/LinearRegression.scala | 4 +-
.../org/apache/spark/ml/tree/treeModels.scala | 4 +-
.../scala/org/apache/spark/ml/util/ReadWrite.scala | 61 +++++++++++++++++++---
39 files changed, 137 insertions(+), 90 deletions(-)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 7deefda2eeaf..c5f1d7f39b6b 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -293,7 +293,7 @@ object DecisionTreeClassificationModel extends
MLReadable[DecisionTreeClassifica
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures,
"numClasses" -> instance.numClasses)
- DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
@@ -309,7 +309,7 @@ object DecisionTreeClassificationModel extends
MLReadable[DecisionTreeClassifica
override def load(path: String): DecisionTreeClassificationModel = {
implicit val format = DefaultFormats
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index aec740a932ac..33e7c1fdd5e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -339,7 +339,7 @@ object FMClassificationModel extends
MLReadable[FMClassificationModel] {
factors: Matrix)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -351,7 +351,7 @@ object FMClassificationModel extends
MLReadable[FMClassificationModel] {
private val className = classOf[FMClassificationModel].getName
override def load(path: String): FMClassificationModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 3e27f781d561..161e8f4cbd2c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -446,7 +446,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.coefficients, instance.intercept)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -459,7 +459,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
private val className = classOf[LinearSVCModel].getName
override def load(path: String): LinearSVCModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
val Row(coefficients: Vector, intercept: Double) =
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index ac0682f1df5b..745cb61bb7aa 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1310,7 +1310,7 @@ object LogisticRegressionModel extends
MLReadable[LogisticRegressionModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: numClasses, numFeatures, intercept, coefficients
val data = Data(instance.numClasses, instance.numFeatures,
instance.interceptVector,
instance.coefficientMatrix, instance.isMultinomial)
@@ -1325,7 +1325,7 @@ object LogisticRegressionModel extends
MLReadable[LogisticRegressionModel] {
private val className = classOf[LogisticRegressionModel].getName
override def load(path: String): LogisticRegressionModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) =
VersionUtils.majorMinorVersion(metadata.sparkVersion)
val dataPath = new Path(path, "data").toString
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 16984bf9aed8..106282b9dc3a 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -365,7 +365,7 @@ object MultilayerPerceptronClassificationModel
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: weights
val data = Data(instance.weights)
val dataPath = new Path(path, "data").toString
@@ -380,7 +380,7 @@ object MultilayerPerceptronClassificationModel
private val className =
classOf[MultilayerPerceptronClassificationModel].getName
override def load(path: String): MultilayerPerceptronClassificationModel =
{
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)
val dataPath = new Path(path, "data").toString
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 52486cb8aa24..4a511581d31a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -580,7 +580,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
instance.getModelType match {
@@ -602,7 +602,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
override def load(path: String): NaiveBayesModel = {
implicit val format = DefaultFormats
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) =
VersionUtils.majorMinorVersion(metadata.sparkVersion)
val dataPath = new Path(path, "data").toString
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 2d809151384b..b4f1565362b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -186,7 +186,7 @@ object BisectingKMeansModel extends
MLReadable[BisectingKMeansModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
instance.parentModel.save(sc, dataPath)
}
@@ -198,7 +198,7 @@ object BisectingKMeansModel extends
MLReadable[BisectingKMeansModel] {
private val className = classOf[BisectingKMeansModel].getName
override def load(path: String): BisectingKMeansModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
val model = new BisectingKMeansModel(metadata.uid, mllibModel)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 0f6648bb4cda..d0db5dcba87b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -235,7 +235,7 @@ object GaussianMixtureModel extends
MLReadable[GaussianMixtureModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: weights and gaussians
val weights = instance.weights
val gaussians = instance.gaussians
@@ -253,7 +253,7 @@ object GaussianMixtureModel extends
MLReadable[GaussianMixtureModel] {
private val className = classOf[GaussianMixtureModel].getName
override def load(path: String): GaussianMixtureModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession,
className)
val dataPath = new Path(path, "data").toString
val row = sparkSession.read.parquet(dataPath).select("weights", "mus",
"sigmas").head()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 04f76660aee6..50fb18bb620a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -219,9 +219,8 @@ private class InternalKMeansModelWriter extends
MLWriterFormat with MLFormatRegi
override def write(path: String, sparkSession: SparkSession,
optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = {
val instance = stage.asInstanceOf[KMeansModel]
- val sc = sparkSession.sparkContext
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: cluster centers
val data: Array[ClusterData] = instance.clusterCenters.zipWithIndex.map {
case (center, idx) =>
@@ -272,7 +271,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
val sparkSession = super.sparkSession
import sparkSession.implicits._
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 7cbfc732a19c..b3d3c84db051 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -654,7 +654,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
gammaShape: Double)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val oldModel = instance.oldLocalModel
val data = Data(instance.vocabSize, oldModel.topicsMatrix,
oldModel.docConcentration,
oldModel.topicConcentration, oldModel.gammaShape)
@@ -668,7 +668,7 @@ object LocalLDAModel extends MLReadable[LocalLDAModel] {
private val className = classOf[LocalLDAModel].getName
override def load(path: String): LocalLDAModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val vectorConverted = MLUtils.convertVectorColumnsToML(data,
"docConcentration")
@@ -809,7 +809,7 @@ object DistributedLDAModel extends
MLReadable[DistributedLDAModel] {
class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val modelPath = new Path(path, "oldModel").toString
instance.oldDistributedModel.save(sc, modelPath)
}
@@ -820,7 +820,7 @@ object DistributedLDAModel extends
MLReadable[DistributedLDAModel] {
private val className = classOf[DistributedLDAModel].getName
override def load(path: String): DistributedLDAModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val modelPath = new Path(path, "oldModel").toString
val oldModel = OldDistributedLDAModel.load(sc, modelPath)
val model = new DistributedLDAModel(metadata.uid, oldModel.vocabSize,
@@ -1008,7 +1008,7 @@ object LDA extends MLReadable[LDA] {
private val className = classOf[LDA].getName
override def load(path: String): LDA = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val model = new LDA(metadata.uid)
LDAParams.getAndSetParams(model, metadata)
model
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index d30962088cb8..537cb5020c88 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -227,7 +227,7 @@ object BucketedRandomProjectionLSHModel extends
MLReadable[BucketedRandomProject
private case class Data(randUnitVectors: Matrix)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randMatrix)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -241,7 +241,7 @@ object BucketedRandomProjectionLSHModel extends
MLReadable[BucketedRandomProject
private val className = classOf[BucketedRandomProjectionLSHModel].getName
override def load(path: String): BucketedRandomProjectionLSHModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 3062f643e950..eb2122b09b2f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -173,7 +173,7 @@ object ChiSqSelectorModel extends
MLReadable[ChiSqSelectorModel] {
private case class Data(selectedFeatures: Seq[Int])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -185,7 +185,7 @@ object ChiSqSelectorModel extends
MLReadable[ChiSqSelectorModel] {
private val className = classOf[ChiSqSelectorModel].getName
override def load(path: String): ChiSqSelectorModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data =
sparkSession.read.parquet(dataPath).select("selectedFeatures").head()
val selectedFeatures = data.getAs[Seq[Int]](0).toArray
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index b81914f86fbb..611b5c710add 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -372,7 +372,7 @@ object CountVectorizerModel extends
MLReadable[CountVectorizerModel] {
private case class Data(vocabulary: Seq[String])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.vocabulary.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -384,7 +384,7 @@ object CountVectorizerModel extends
MLReadable[CountVectorizerModel] {
private val className = classOf[CountVectorizerModel].getName
override def load(path: String): CountVectorizerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("vocabulary")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index f4223bc85943..3b42105958c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -154,7 +154,7 @@ object HashingTF extends DefaultParamsReadable[HashingTF] {
private val className = classOf[HashingTF].getName
override def load(path: String): HashingTF = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
// We support loading old `HashingTF` saved by previous Spark versions.
// Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but
new `HashingTF` uses
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 696e1516582d..3025a7b04af5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -198,7 +198,7 @@ object IDFModel extends MLReadable[IDFModel] {
private case class Data(idf: Vector, docFreq: Array[Long], numDocs: Long)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.idf, instance.docFreq, instance.numDocs)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -210,7 +210,7 @@ object IDFModel extends MLReadable[IDFModel] {
private val className = classOf[IDFModel].getName
override def load(path: String): IDFModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index ae65b17d7a81..38fb25903dca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -308,7 +308,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
private[ImputerModel] class ImputerModelWriter(instance: ImputerModel)
extends MLWriter {
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val dataPath = new Path(path, "data").toString
instance.surrogateDF.repartition(1).write.parquet(dataPath)
}
@@ -319,7 +319,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
private val className = classOf[ImputerModel].getName
override def load(path: String): ImputerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 05ee59d1627d..1a378cd85f3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -162,7 +162,7 @@ object MaxAbsScalerModel extends
MLReadable[MaxAbsScalerModel] {
private case class Data(maxAbs: Vector)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = new Data(instance.maxAbs)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -174,7 +174,7 @@ object MaxAbsScalerModel extends
MLReadable[MaxAbsScalerModel] {
private val className = classOf[MaxAbsScalerModel].getName
override def load(path: String): MaxAbsScalerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val Row(maxAbs: Vector) = sparkSession.read.parquet(dataPath)
.select("maxAbs")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index d94aadd1ce1f..3f2a3327128a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -220,7 +220,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
private case class Data(randCoefficients: Array[Int])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.randCoefficients.flatMap(tuple =>
Array(tuple._1, tuple._2)))
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -233,7 +233,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
private val className = classOf[MinHashLSHModel].getName
override def load(path: String): MinHashLSHModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data =
sparkSession.read.parquet(dataPath).select("randCoefficients").head()
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 4111e559a5c2..c311f4260424 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -247,7 +247,7 @@ object MinMaxScalerModel extends
MLReadable[MinMaxScalerModel] {
private case class Data(originalMin: Vector, originalMax: Vector)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = new Data(instance.originalMin, instance.originalMax)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -259,7 +259,7 @@ object MinMaxScalerModel extends
MLReadable[MinMaxScalerModel] {
private val className = classOf[MinMaxScalerModel].getName
override def load(path: String): MinMaxScalerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val Row(originalMin: Vector, originalMax: Vector) =
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index e7cf0105754a..823f767eebbe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -403,7 +403,7 @@ object OneHotEncoderModel extends
MLReadable[OneHotEncoderModel] {
private case class Data(categorySizes: Array[Int])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.categorySizes)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -415,7 +415,7 @@ object OneHotEncoderModel extends
MLReadable[OneHotEncoderModel] {
private val className = classOf[OneHotEncoderModel].getName
override def load(path: String): OneHotEncoderModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("categorySizes")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index f7ec18b38a0e..0bd9a3c38a1e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -184,7 +184,7 @@ object PCAModel extends MLReadable[PCAModel] {
private case class Data(pc: DenseMatrix, explainedVariance: DenseVector)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.pc, instance.explainedVariance)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -205,7 +205,7 @@ object PCAModel extends MLReadable[PCAModel] {
* @return a [[PCAModel]]
*/
override def load(path: String): PCAModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val model = if (majorVersion(metadata.sparkVersion) >= 2) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 7a47e73e5ef4..77bd18423ef1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -427,7 +427,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: resolvedFormula
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(instance.resolvedFormula))
@@ -444,7 +444,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
private val className = classOf[RFormulaModel].getName
override def load(path: String): RFormulaModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("label", "terms",
"hasIntercept").head()
@@ -502,7 +502,7 @@ private object ColumnPruner extends
MLReadable[ColumnPruner] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: columnsToPrune
val data = Data(instance.columnsToPrune.toSeq)
val dataPath = new Path(path, "data").toString
@@ -516,7 +516,7 @@ private object ColumnPruner extends
MLReadable[ColumnPruner] {
private val className = classOf[ColumnPruner].getName
override def load(path: String): ColumnPruner = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data =
sparkSession.read.parquet(dataPath).select("columnsToPrune").head()
@@ -594,7 +594,7 @@ private object VectorAttributeRewriter extends
MLReadable[VectorAttributeRewrite
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: vectorCol, prefixesToRewrite
val data = Data(instance.vectorCol, instance.prefixesToRewrite)
val dataPath = new Path(path, "data").toString
@@ -608,7 +608,7 @@ private object VectorAttributeRewriter extends
MLReadable[VectorAttributeRewrite
private val className = classOf[VectorAttributeRewriter].getName
override def load(path: String): VectorAttributeRewriter = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("vectorCol",
"prefixesToRewrite").head()
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
index 0950dc55dccb..f3e068f04920 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala
@@ -284,7 +284,7 @@ object RobustScalerModel extends
MLReadable[RobustScalerModel] {
private case class Data(range: Vector, median: Vector)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.range, instance.median)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -296,7 +296,7 @@ object RobustScalerModel extends
MLReadable[RobustScalerModel] {
private val className = classOf[RobustScalerModel].getName
override def load(path: String): RobustScalerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val Row(range: Vector, median: Vector) = MLUtils
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index c0a6392c29c3..f1e48b053d88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -205,7 +205,7 @@ object StandardScalerModel extends
MLReadable[StandardScalerModel] {
private case class Data(std: Vector, mean: Vector)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.std, instance.mean)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -217,7 +217,7 @@ object StandardScalerModel extends
MLReadable[StandardScalerModel] {
private val className = classOf[StandardScalerModel].getName
override def load(path: String): StandardScalerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val Row(std: Vector, mean: Vector) =
MLUtils.convertVectorColumnsToML(data, "std", "mean")
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 2ca640445b55..94d4fa6fe6f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -509,7 +509,7 @@ object StringIndexerModel extends
MLReadable[StringIndexerModel] {
private case class Data(labelsArray: Array[Array[String]])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.labelsArray)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -521,7 +521,7 @@ object StringIndexerModel extends
MLReadable[StringIndexerModel] {
private val className = classOf[StringIndexerModel].getName
override def load(path: String): StringIndexerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
// We support loading old `StringIndexerModel` saved by previous Spark
versions.
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
index 29a091012495..9c2033c28430 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala
@@ -349,7 +349,7 @@ object UnivariateFeatureSelectorModel extends
MLReadable[UnivariateFeatureSelect
private case class Data(selectedFeatures: Seq[Int])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -363,7 +363,7 @@ object UnivariateFeatureSelectorModel extends
MLReadable[UnivariateFeatureSelect
private val className = classOf[UnivariateFeatureSelectorModel].getName
override def load(path: String): UnivariateFeatureSelectorModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("selectedFeatures").head()
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
index df57e19f1a72..d767e113144c 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/feature/VarianceThresholdSelector.scala
@@ -187,7 +187,7 @@ object VarianceThresholdSelectorModel extends
MLReadable[VarianceThresholdSelect
private case class Data(selectedFeatures: Seq[Int])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.selectedFeatures.toImmutableArraySeq)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -201,7 +201,7 @@ object VarianceThresholdSelectorModel extends
MLReadable[VarianceThresholdSelect
private val className = classOf[VarianceThresholdSelectorModel].getName
override def load(path: String): VarianceThresholdSelectorModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("selectedFeatures").head()
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 4fed325e19e9..ff89dee68ea3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -519,7 +519,7 @@ object VectorIndexerModel extends
MLReadable[VectorIndexerModel] {
private case class Data(numFeatures: Int, categoryMaps: Map[Int,
Map[Double, Int]])
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.numFeatures, instance.categoryMaps)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -531,7 +531,7 @@ object VectorIndexerModel extends
MLReadable[VectorIndexerModel] {
private val className = classOf[VectorIndexerModel].getName
override def load(path: String): VectorIndexerModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("numFeatures", "categoryMaps")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 66b56f8b88ef..0329190a239e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -352,7 +352,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val wordVectors = instance.wordVectors.getVectors
val dataPath = new Path(path, "data").toString
@@ -407,7 +407,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
val spark = sparkSession
import spark.implicits._
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) =
VersionUtils.majorMinorVersion(metadata.sparkVersion)
val dataPath = new Path(path, "data").toString
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 081a40bfbe80..d054ea8ebdb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -336,7 +336,8 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map("numTrainingRecords" ->
instance.numTrainingRecords)
- DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata = Some(extraMetadata))
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession,
+ extraMetadata = Some(extraMetadata))
val dataPath = new Path(path, "data").toString
instance.freqItemsets.write.parquet(dataPath)
}
@@ -349,7 +350,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
override def load(path: String): FPGrowthModel = {
implicit val format = DefaultFormats
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val (major, minor) =
VersionUtils.majorMinorVersion(metadata.sparkVersion)
val numTrainingRecords = if (major < 2 || (major == 2 && minor < 4)) {
// 2.3 and before don't store the count
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 50f94a579944..1a004f71749e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -556,7 +556,7 @@ object ALSModel extends MLReadable[ALSModel] {
override protected def saveImpl(path: String): Unit = {
val extraMetadata = "rank" -> instance.rank
- DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val userPath = new Path(path, "userFactors").toString
instance.userFactors.write.format("parquet").save(userPath)
val itemPath = new Path(path, "itemFactors").toString
@@ -570,7 +570,7 @@ object ALSModel extends MLReadable[ALSModel] {
private val className = classOf[ALSModel].getName
override def load(path: String): ALSModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
implicit val format = DefaultFormats
val rank = (metadata.metadata \ "rank").extract[Int]
val userPath = new Path(path, "userFactors").toString
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index d77d79dae4b8..6451cbf0329d 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -494,7 +494,7 @@ object AFTSurvivalRegressionModel extends
MLReadable[AFTSurvivalRegressionModel]
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: coefficients, intercept, scale
val data = Data(instance.coefficients, instance.intercept,
instance.scale)
val dataPath = new Path(path, "data").toString
@@ -508,7 +508,7 @@ object AFTSurvivalRegressionModel extends
MLReadable[AFTSurvivalRegressionModel]
private val className = classOf[AFTSurvivalRegressionModel].getName
override def load(path: String): AFTSurvivalRegressionModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 481e8c8357f1..dace99f214b1 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -302,7 +302,7 @@ object DecisionTreeRegressionModel extends
MLReadable[DecisionTreeRegressionMode
override protected def saveImpl(path: String): Unit = {
val extraMetadata: JObject = Map(
"numFeatures" -> instance.numFeatures)
- DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val (nodeData, _) = NodeData.build(instance.rootNode, 0)
val dataPath = new Path(path, "data").toString
val numDataParts = NodeData.inferNumPartitions(instance.numNodes)
@@ -318,7 +318,7 @@ object DecisionTreeRegressionModel extends
MLReadable[DecisionTreeRegressionMode
override def load(path: String): DecisionTreeRegressionModel = {
implicit val format = DefaultFormats
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeRegressionModel(metadata.uid, root,
numFeatures)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index 8c797295e671..182107a443c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -504,7 +504,7 @@ object FMRegressionModel extends
MLReadable[FMRegressionModel] {
factors: Matrix)
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
val data = Data(instance.intercept, instance.linear, instance.factors)
val dataPath = new Path(path, "data").toString
sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
@@ -516,7 +516,7 @@ object FMRegressionModel extends
MLReadable[FMRegressionModel] {
private val className = classOf[FMRegressionModel].getName
override def load(path: String): FMRegressionModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 181a1a03e6f3..dc0b553e2c91 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1141,7 +1141,7 @@ object GeneralizedLinearRegressionModel extends
MLReadable[GeneralizedLinearRegr
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: intercept, coefficients
val data = Data(instance.intercept, instance.coefficients)
val dataPath = new Path(path, "data").toString
@@ -1156,7 +1156,7 @@ object GeneralizedLinearRegressionModel extends
MLReadable[GeneralizedLinearRegr
private val className = classOf[GeneralizedLinearRegressionModel].getName
override def load(path: String): GeneralizedLinearRegressionModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 29d8a00a4384..d624270af89d 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -301,7 +301,7 @@ object IsotonicRegressionModel extends
MLReadable[IsotonicRegressionModel] {
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: boundaries, predictions, isotonic
val data = Data(
instance.oldModel.boundaries, instance.oldModel.predictions,
instance.oldModel.isotonic)
@@ -316,7 +316,7 @@ object IsotonicRegressionModel extends
MLReadable[IsotonicRegressionModel] {
private val className = classOf[IsotonicRegressionModel].getName
override def load(path: String): IsotonicRegressionModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index d5dce782770b..abac9db8df02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -780,7 +780,7 @@ private class InternalLinearRegressionModelWriter
val instance = stage.asInstanceOf[LinearRegressionModel]
val sc = sparkSession.sparkContext
// Save metadata and Params
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
// Save model data: intercept, coefficients, scale
val data = Data(instance.intercept, instance.coefficients, instance.scale)
val dataPath = new Path(path, "data").toString
@@ -824,7 +824,7 @@ object LinearRegressionModel extends
MLReadable[LinearRegressionModel] {
private val className = classOf[LinearRegressionModel].getName
override def load(path: String): LinearRegressionModel = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 6a7615fb149b..cdd40ae35503 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -471,7 +471,7 @@ private[ml] object EnsembleModelReadWrite {
path: String,
sparkSession: SparkSession,
extraMetadata: JObject): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sparkSession.sparkContext, Some(extraMetadata))
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val treesMetadataWeights = instance.trees.zipWithIndex.map { case (tree,
treeID) =>
(treeID,
DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params],
sparkSession.sparkContext),
@@ -510,7 +510,7 @@ private[ml] object EnsembleModelReadWrite {
treeClassName: String): (Metadata, Array[(Metadata, Node)],
Array[Double]) = {
import sparkSession.implicits._
implicit val format = DefaultFormats
- val metadata = DefaultParamsReader.loadMetadata(path, sparkSession.sparkContext, className)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
// Get impurity to construct ImpurityCalculator for each node
val impurityType: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c127575e1470..d338c267d823 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -382,7 +382,7 @@ trait DefaultParamsReadable[T] extends MLReadable[T] {
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
override protected def saveImpl(path: String): Unit = {
- DefaultParamsWriter.saveMetadata(instance, path, sc)
+ DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
}
}
@@ -403,20 +403,58 @@ private[ml] object DefaultParamsWriter {
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
* [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
+ @deprecated("use saveMetadata with SparkSession", "4.0.0")
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JObject] = None,
- paramMap: Option[JValue] = None): Unit = {
+ paramMap: Option[JValue] = None): Unit =
+ saveMetadata(
+ instance,
+ path,
+ SparkSession.builder().sparkContext(sc).getOrCreate(),
+ extraMetadata,
+ paramMap)
+
+ /**
+ * Saves metadata + Params to: path + "/metadata"
+ * - class
+ * - timestamp
+ * - sparkVersion
+ * - uid
+ * - defaultParamMap
+ * - paramMap
+ * - (optionally, extra metadata)
+ *
+ * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
+ * @param paramMap If given, this is saved in the "paramMap" field.
+ * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
+ * [[org.apache.spark.ml.param.Param.jsonEncode()]].
+ */
+ def saveMetadata(
+ instance: Params,
+ path: String,
+ spark: SparkSession,
+ extraMetadata: Option[JObject],
+ paramMap: Option[JValue]): Unit = {
val metadataPath = new Path(path, "metadata").toString
- val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
- val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+ val metadataJson = getMetadataToSave(instance, spark.sparkContext, extraMetadata, paramMap)
// Note that we should write single file. If there are more than one row
// it produces more partitions.
spark.createDataFrame(Seq(Tuple1(metadataJson))).write.text(metadataPath)
}
+ def saveMetadata(
+ instance: Params,
+ path: String,
+ spark: SparkSession,
+ extraMetadata: Option[JObject]): Unit =
+ saveMetadata(instance, path, spark, extraMetadata, None)
+
+ def saveMetadata(instance: Params, path: String, spark: SparkSession): Unit =
+ saveMetadata(instance, path, spark, None, None)
+
/**
* Helper for [[saveMetadata()]] which extracts the JSON to save.
* This is useful for ensemble models which need to save metadata for many sub-models.
@@ -466,7 +504,7 @@ private[ml] object DefaultParamsWriter {
private[ml] class DefaultParamsReader[T] extends MLReader[T] {
override def load(path: String): T = {
- val metadata = DefaultParamsReader.loadMetadata(path, sc)
+ val metadata = DefaultParamsReader.loadMetadata(path, sparkSession)
val cls = Utils.classForName(metadata.className)
val instance =
cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
@@ -586,13 +624,22 @@ private[ml] object DefaultParamsReader {
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
- def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
+ @deprecated("use loadMetadata with SparkSession", "4.0.0")
+ def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata =
+ loadMetadata(
+ path,
+ SparkSession.builder().sparkContext(sc).getOrCreate(),
+ expectedClassName)
+
+ def loadMetadata(path: String, spark: SparkSession, expectedClassName: String): Metadata = {
val metadataPath = new Path(path, "metadata").toString
- val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
val metadataStr = spark.read.text(metadataPath).first().getString(0)
parseMetadata(metadataStr, expectedClassName)
}
+ def loadMetadata(path: String, spark: SparkSession): Metadata =
+ loadMetadata(path, spark, "")
+
/**
* Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
* This is a helper function for [[loadMetadata()]].
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]