Repository: spark
Updated Branches:
  refs/heads/branch-1.6 59eaec2d4 -> 6731dd668


[SPARK-11842][ML] Small cleanups to existing Readers and Writers

Updates:
* Add repartition(1) when the save() methods write model data for 
LogisticRegressionModel and LinearRegressionModel.
* Tighten visibility of Writers and Readers to their class and companion object
* Change LogisticRegressionSuite read/write test to fit intercept
* Add Since versions for read/write methods in Pipeline, LogisticRegression
* Switch from hand-written class names in Readers to using getClass

CC: mengxr

CC: yanboliang Would you mind taking a look at this PR?  mengxr might not be 
able to soon.  Thank you!

Author: Joseph K. Bradley <[email protected]>

Closes #9829 from jkbradley/ml-io-cleanups.

(cherry picked from commit d02d5b9295b169c3ebb0967453b2835edb8a121f)
Signed-off-by: Xiangrui Meng <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6731dd66
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6731dd66
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6731dd66

Branch: refs/heads/branch-1.6
Commit: 6731dd66896bb4cf4797a12a9f48cf7fa2ce7112
Parents: 59eaec2
Author: Joseph K. Bradley <[email protected]>
Authored: Wed Nov 18 21:44:01 2015 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Wed Nov 18 21:44:09 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Pipeline.scala    | 22 +++++++++++++-------
 .../ml/classification/LogisticRegression.scala  | 19 ++++++++++-------
 .../spark/ml/feature/CountVectorizer.scala      |  2 +-
 .../scala/org/apache/spark/ml/feature/IDF.scala |  2 +-
 .../apache/spark/ml/feature/MinMaxScaler.scala  |  2 +-
 .../spark/ml/feature/StandardScaler.scala       |  2 +-
 .../apache/spark/ml/feature/StringIndexer.scala |  2 +-
 .../apache/spark/ml/recommendation/ALS.scala    |  6 +++---
 .../spark/ml/regression/LinearRegression.scala  |  4 ++--
 .../LogisticRegressionSuite.scala               |  2 +-
 10 files changed, 38 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index b0f22e0..6f15b37 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -27,7 +27,7 @@ import org.json4s._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{SparkContext, Logging}
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util.MLReader
 import org.apache.spark.ml.util.MLWriter
@@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends 
Estimator[PipelineModel] with M
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
   }
 
+  @Since("1.6.0")
   override def write: MLWriter = new Pipeline.PipelineWriter(this)
 }
 
+@Since("1.6.0")
 object Pipeline extends MLReadable[Pipeline] {
 
+  @Since("1.6.0")
   override def read: MLReader[Pipeline] = new PipelineReader
 
+  @Since("1.6.0")
   override def load(path: String): Pipeline = super.load(path)
 
-  private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter {
+  private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.getStages)
 
@@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] {
       SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
   }
 
-  private[ml] class PipelineReader extends MLReader[Pipeline] {
+  private class PipelineReader extends MLReader[Pipeline] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.Pipeline"
+    private val className = classOf[Pipeline].getName
 
     override def load(path: String): Pipeline = {
       val (uid: String, stages: Array[PipelineStage]) = 
SharedReadWrite.load(className, sc, path)
@@ -333,18 +337,22 @@ class PipelineModel private[ml] (
     new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
   }
 
+  @Since("1.6.0")
   override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
 }
 
+@Since("1.6.0")
 object PipelineModel extends MLReadable[PipelineModel] {
 
   import Pipeline.SharedReadWrite
 
+  @Since("1.6.0")
   override def read: MLReader[PipelineModel] = new PipelineModelReader
 
+  @Since("1.6.0")
   override def load(path: String): PipelineModel = super.load(path)
 
-  private[ml] class PipelineModelWriter(instance: PipelineModel) extends 
MLWriter {
+  private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) 
extends MLWriter {
 
     
SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
 
@@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] {
       instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
   }
 
-  private[ml] class PipelineModelReader extends MLReader[PipelineModel] {
+  private class PipelineModelReader extends MLReader[PipelineModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.PipelineModel"
+    private val className = classOf[PipelineModel].getName
 
     override def load(path: String): PipelineModel = {
       val (uid: String, stages: Array[PipelineStage]) = 
SharedReadWrite.load(className, sc, path)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a3cc49f..418bbdc 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, 
LBFGS => BreezeLBFGS,
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] (
    *
    * This also does not save the [[parent]] currently.
    */
+  @Since("1.6.0")
   override def write: MLWriter = new 
LogisticRegressionModel.LogisticRegressionModelWriter(this)
 }
 
 
+@Since("1.6.0")
 object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
 
+  @Since("1.6.0")
   override def read: MLReader[LogisticRegressionModel] = new 
LogisticRegressionModelReader
 
+  @Since("1.6.0")
   override def load(path: String): LogisticRegressionModel = super.load(path)
 
   /** [[MLWriter]] instance for [[LogisticRegressionModel]] */
-  private[classification] class LogisticRegressionModelWriter(instance: 
LogisticRegressionModel)
+  private[LogisticRegressionModel]
+  class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
     extends MLWriter with Logging {
 
     private case class Data(
@@ -552,15 +557,15 @@ object LogisticRegressionModel extends 
MLReadable[LogisticRegressionModel] {
       val data = Data(instance.numClasses, instance.numFeatures, 
instance.intercept,
         instance.coefficients)
       val dataPath = new Path(path, "data").toString
-      
sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+      
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
-  private[classification] class LogisticRegressionModelReader
+  private class LogisticRegressionModelReader
     extends MLReader[LogisticRegressionModel] {
 
     /** Checked against metadata when loading model */
-    private val className = 
"org.apache.spark.ml.classification.LogisticRegressionModel"
+    private val className = classOf[LogisticRegressionModel].getName
 
     override def load(path: String): LogisticRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
@@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends 
Serializable {
    * @return This MultilabelSummarizer
    */
   def add(label: Double, weight: Double = 1.0): this.type = {
-    require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+    require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
     if (weight == 0.0) return this
 
@@ -839,7 +844,7 @@ private class LogisticAggregator(
     instance match { case Instance(label, weight, features) =>
       require(dim == features.size, s"Dimensions mismatch when adding new 
instance." +
         s" Expecting $dim but got ${features.size}.")
-      require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
       if (weight == 0.0) return this
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 4969cf4..b9e2144 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -266,7 +266,7 @@ object CountVectorizerModel extends 
MLReadable[CountVectorizerModel] {
 
   private class CountVectorizerModelReader extends 
MLReader[CountVectorizerModel] {
 
-    private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+    private val className = classOf[CountVectorizerModel].getName
 
     override def load(path: String): CountVectorizerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 0e00ef6..f7b0f29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] {
 
   private class IDFModelReader extends MLReader[IDFModel] {
 
-    private val className = "org.apache.spark.ml.feature.IDFModel"
+    private val className = classOf[IDFModel].getName
 
     override def load(path: String): IDFModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index ed24eab..c2866f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -210,7 +210,7 @@ object MinMaxScalerModel extends 
MLReadable[MinMaxScalerModel] {
 
   private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] {
 
-    private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+    private val className = classOf[MinMaxScalerModel].getName
 
     override def load(path: String): MinMaxScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 1f689c1..6d54521 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -180,7 +180,7 @@ object StandardScalerModel extends 
MLReadable[StandardScalerModel] {
 
   private class StandardScalerModelReader extends 
MLReader[StandardScalerModel] {
 
-    private val className = "org.apache.spark.ml.feature.StandardScalerModel"
+    private val className = classOf[StandardScalerModel].getName
 
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 97a2e4f..5c40c35 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -210,7 +210,7 @@ object StringIndexerModel extends 
MLReadable[StringIndexerModel] {
 
   private class StringIndexerModelReader extends MLReader[StringIndexerModel] {
 
-    private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+    private val className = classOf[StringIndexerModel].getName
 
     override def load(path: String): StringIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 795b73c..4d35177 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] {
   @Since("1.6.0")
   override def load(path: String): ALSModel = super.load(path)
 
-  private[recommendation] class ALSModelWriter(instance: ALSModel) extends 
MLWriter {
+  private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
       val extraMetadata = render("rank" -> instance.rank)
@@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] {
     }
   }
 
-  private[recommendation] class ALSModelReader extends MLReader[ALSModel] {
+  private class ALSModelReader extends MLReader[ALSModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.recommendation.ALSModel"
+    private val className = classOf[ALSModel].getName
 
     override def load(path: String): ALSModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7ba1a60..70ccec7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -467,14 +467,14 @@ object LinearRegressionModel extends 
MLReadable[LinearRegressionModel] {
       // Save model data: intercept, coefficients
       val data = Data(instance.intercept, instance.coefficients)
       val dataPath = new Path(path, "data").toString
-      
sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+      
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
   private class LinearRegressionModelReader extends 
MLReader[LinearRegressionModel] {
 
     /** Checked against metadata when loading model */
-    private val className = 
"org.apache.spark.ml.regression.LinearRegressionModel"
+    private val className = classOf[LinearRegressionModel].getName
 
     override def load(path: String): LinearRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)

http://git-wip-us.apache.org/repos/asf/spark/blob/6731dd66/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 48ce1bb..a9a6ff8 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -898,7 +898,7 @@ object LogisticRegressionSuite {
     "regParam" -> 0.01,
     "elasticNetParam" -> 0.1,
     "maxIter" -> 2,  // intentionally small
-    "fitIntercept" -> false,
+    "fitIntercept" -> true,
     "tol" -> 0.8,
     "standardization" -> false,
     "threshold" -> 0.6


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to