Repository: spark
Updated Branches:
  refs/heads/master 328eb49e6 -> 6eb7008b7


[SPARK-11763][ML] Add save,load to LogisticRegression Estimator

Add save/load to the LogisticRegression Estimator, and refactor the tests
slightly to make it easier to add similar support to other Estimator/Model
pairs.
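
For reference, a rough usage sketch of the new API (the paths and the
`training` DataFrame below are placeholders, not part of this patch):

    import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

    val lr = new LogisticRegression().setMaxIter(10)
    // Estimator save/load, backed by DefaultParamsWriter / DefaultParamsReader
    lr.save("/tmp/lr-estimator")
    val lr2 = LogisticRegression.read.load("/tmp/lr-estimator")

    // Model save/load; Params plus intercept/coefficients are persisted,
    // but the training summary and parent are not
    val model = lr.fit(training)
    model.save("/tmp/lr-model")
    val model2 = LogisticRegressionModel.load("/tmp/lr-model")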

Moved LogisticRegressionReader/Writer into the LogisticRegressionModel companion object.
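
With the refactored test helper (testEstimatorAndModelReadWrite in
DefaultReadWriteTest), a suite for another Estimator/Model pair only needs to
supply its Param settings and a model-data check. A sketch of the intended
pattern; MyClassifier / MyClassificationModel are hypothetical classes assumed
to mix in Writable, and `dataset` is a small training DataFrame defined by the
suite:

    class MyClassifierReadWriteSuite extends SparkFunSuite
      with MLlibTestSparkContext with DefaultReadWriteTest {

      test("read/write") {
        // Compare persisted model data only; Param values are checked by the helper.
        def checkModelData(m: MyClassificationModel, m2: MyClassificationModel): Unit = {
          assert(m.coefficients === m2.coefficients)
        }
        val estimator = new MyClassifier()  // hypothetical Estimator
        testEstimatorAndModelReadWrite(estimator, dataset,
          Map("predictionCol" -> "myPrediction"), checkModelData)
      }
    }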

CC: mengxr

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #9749 from jkbradley/lr-io-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6eb7008b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6eb7008b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6eb7008b

Branch: refs/heads/master
Commit: 6eb7008b7f33a36b06d0615b68cc21ed90ad1d8a
Parents: 328eb49
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Tue Nov 17 14:03:49 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Nov 17 14:03:49 2015 -0800

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 91 +++++++++++---------
 .../org/apache/spark/ml/util/ReadWrite.scala    |  1 +
 .../org/apache/spark/ml/PipelineSuite.scala     |  7 --
 .../ml/classification/ClassifierSuite.scala     | 32 +++++++
 .../LogisticRegressionSuite.scala               | 37 ++++++--
 .../ProbabilisticClassifierSuite.scala          | 14 +++
 .../spark/ml/util/DefaultReadWriteTest.scala    | 50 ++++++++++-
 7 files changed, 173 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6eb7008b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a88f526..71c2533 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -157,7 +157,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
 @Experimental
 class LogisticRegression(override val uid: String)
  extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
-  with LogisticRegressionParams with Logging {
+  with LogisticRegressionParams with Writable with Logging {
 
   def this() = this(Identifiable.randomUID("logreg"))
 
@@ -385,6 +385,12 @@ class LogisticRegression(override val uid: String)
   }
 
   override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
+
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object LogisticRegression extends Readable[LogisticRegression] {
+  override def read: Reader[LogisticRegression] = new DefaultParamsReader[LogisticRegression]
 }
 
 /**
@@ -517,61 +523,62 @@ class LogisticRegressionModel private[ml] (
    *
   * For [[LogisticRegressionModel]], this does NOT currently save the training [[summary]].
    * An option to save [[summary]] may be added in the future.
+   *
+   * This also does not save the [[parent]] currently.
    */
-  override def write: Writer = new LogisticRegressionWriter(this)
-}
-
-
-/** [[Writer]] instance for [[LogisticRegressionModel]] */
-private[classification] class LogisticRegressionWriter(instance: LogisticRegressionModel)
-  extends Writer with Logging {
-
-  private case class Data(
-      numClasses: Int,
-      numFeatures: Int,
-      intercept: Double,
-      coefficients: Vector)
-
-  override protected def saveImpl(path: String): Unit = {
-    // Save metadata and Params
-    DefaultParamsWriter.saveMetadata(instance, path, sc)
-    // Save model data: numClasses, numFeatures, intercept, coefficients
-    val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
-      instance.coefficients)
-    val dataPath = new Path(path, "data").toString
-    sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
-  }
+  override def write: Writer = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
 }
 
 
 object LogisticRegressionModel extends Readable[LogisticRegressionModel] {
 
-  override def read: Reader[LogisticRegressionModel] = new LogisticRegressionReader
+  override def read: Reader[LogisticRegressionModel] = new LogisticRegressionModelReader
 
   override def load(path: String): LogisticRegressionModel = read.load(path)
-}
 
+  /** [[Writer]] instance for [[LogisticRegressionModel]] */
+  private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
+    extends Writer with Logging {
+
+    private case class Data(
+        numClasses: Int,
+        numFeatures: Int,
+        intercept: Double,
+        coefficients: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: numClasses, numFeatures, intercept, coefficients
+      val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
+        instance.coefficients)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+    }
+  }
 
-private[classification] class LogisticRegressionReader extends Reader[LogisticRegressionModel] {
+  private[classification] class LogisticRegressionModelReader
+    extends Reader[LogisticRegressionModel] {
 
-  /** Checked against metadata when loading model */
-  private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
 
-  override def load(path: String): LogisticRegressionModel = {
-    val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+    override def load(path: String): LogisticRegressionModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
-    val dataPath = new Path(path, "data").toString
-    val data = sqlContext.read.format("parquet").load(dataPath)
-      .select("numClasses", "numFeatures", "intercept", "coefficients").head()
-    // We will need numClasses, numFeatures in the future for multinomial logreg support.
-    // val numClasses = data.getInt(0)
-    // val numFeatures = data.getInt(1)
-    val intercept = data.getDouble(2)
-    val coefficients = data.getAs[Vector](3)
-    val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.format("parquet").load(dataPath)
+        .select("numClasses", "numFeatures", "intercept", "coefficients").head()
+      // We will need numClasses, numFeatures in the future for multinomial logreg support.
+      // val numClasses = data.getInt(0)
+      // val numFeatures = data.getInt(1)
+      val intercept = data.getDouble(2)
+      val coefficients = data.getAs[Vector](3)
+      val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
 
-    DefaultParamsReader.getAndSetParams(model, metadata)
-    model
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6eb7008b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 3169c9e..dddb72a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -217,6 +217,7 @@ private[ml] object DefaultParamsWriter {
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
  * data (e.g., models with coefficients).
  * @tparam T ML instance type
+ * TODO: Consider adding check for correct class name.
  */
 private[ml] class DefaultParamsReader[T] extends Reader[T] {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6eb7008b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 484026b..7f5c389 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -149,13 +149,6 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     assert(pipeline2.stages(0).isInstanceOf[WritableStage])
     val writableStage2 = pipeline2.stages(0).asInstanceOf[WritableStage]
     assert(writableStage.getIntParam === writableStage2.getIntParam)
-
-    val path = new File(tempDir, pipeline.uid).getPath
-    val stagesDir = new Path(path, "stages").toString
-    val expectedStagePath = SharedReadWrite.getStagePath(writableStage.uid, 0, 1, stagesDir)
-    assert(FileSystem.get(sc.hadoopConfiguration).exists(new Path(expectedStagePath)),
-      s"Expected stage 0 of 1 with uid ${writableStage.uid} in Pipeline with 
uid ${pipeline.uid}" +
-        s" to be saved to path: $expectedStagePath")
   }
 
   test("PipelineModel read/write: getStagePath") {

http://git-wip-us.apache.org/repos/asf/spark/blob/6eb7008b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
new file mode 100644
index 0000000..d0e3fe7
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ClassifierSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+object ClassifierSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "rawPredictionCol" -> "myRawPrediction"
+  )
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6eb7008b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 51b06b7..48ce1bb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -873,15 +873,34 @@ class LogisticRegressionSuite
   }
 
   test("read/write") {
-    // Set some Params to make sure set Params are serialized.
+    def checkModelData(model: LogisticRegressionModel, model2: LogisticRegressionModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients.toArray === model2.coefficients.toArray)
+      assert(model.numClasses === model2.numClasses)
+      assert(model.numFeatures === model2.numFeatures)
+    }
     val lr = new LogisticRegression()
-      .setElasticNetParam(0.1)
-      .setMaxIter(2)
-      .fit(dataset)
-    val lr2 = testDefaultReadWrite(lr)
-    assert(lr.intercept === lr2.intercept)
-    assert(lr.coefficients.toArray === lr2.coefficients.toArray)
-    assert(lr.numClasses === lr2.numClasses)
-    assert(lr.numFeatures === lr2.numFeatures)
+    testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
+      checkModelData)
   }
 }
+
+object LogisticRegressionSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map(
+    "probabilityCol" -> "myProbability",
+    "thresholds" -> Array(0.4, 0.6),
+    "regParam" -> 0.01,
+    "elasticNetParam" -> 0.1,
+    "maxIter" -> 2,  // intentionally small
+    "fitIntercept" -> false,
+    "tol" -> 0.8,
+    "standardization" -> false,
+    "threshold" -> 0.6
+  )
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6eb7008b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index fb5f00e..cfa75ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -57,3 +57,17 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
     assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
   }
 }
+
+object ProbabilisticClassifierSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = ClassifierSuite.allParamSettings ++ Map(
+    "probabilityCol" -> "myProbability",
+    "thresholds" -> Array(0.4, 0.6)
+  )
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6eb7008b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index c37f050..dd1e8ac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -22,13 +22,17 @@ import java.io.{File, IOException}
 import org.scalatest.Suite
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Model, Estimator}
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.DataFrame
 
 trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
 
   /**
    * Checks "overwrite" option and params.
+   * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
+   * in order to avoid conflicts from multiple calls to this method.
    * @param instance ML instance to test saving/loading
   * @param testParams  If true, then test values of Params.  Otherwise, just test overwrite option.
    * @tparam T ML instance type
@@ -38,7 +42,10 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
       instance: T,
       testParams: Boolean = true): T = {
     val uid = instance.uid
-    val path = new File(tempDir, uid).getPath
+    val subdirName = Identifiable.randomUID("test")
+
+    val subdir = new File(tempDir, subdirName)
+    val path = new File(subdir, uid).getPath
 
     instance.save(path)
     intercept[IOException] {
@@ -69,6 +76,47 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     assert(another.uid === instance.uid)
     another
   }
+
+  /**
+   * Default test for Estimator, Model pairs:
+   *  - Explicitly set Params, and train model
+   *  - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model
+   *  - Check Params on Estimator and Model
+   *
+   * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   * @param estimator  Estimator to test
+   * @param dataset  Dataset to pass to [[Estimator.fit()]]
+   * @param testParams  Set of [[Param]] values to set in estimator
+   * @param checkModelData  Method which takes the original and loaded [[Model]] and compares their
+   *                        data.  This method does not need to check [[Param]] values.
+   * @tparam E  Type of [[Estimator]]
+   * @tparam M  Type of [[Model]] produced by estimator
+   */
+  def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable](
+      estimator: E,
+      dataset: DataFrame,
+      testParams: Map[String, Any],
+      checkModelData: (M, M) => Unit): Unit = {
+    // Set some Params to make sure set Params are serialized.
+    testParams.foreach { case (p, v) =>
+      estimator.set(estimator.getParam(p), v)
+    }
+    val model = estimator.fit(dataset)
+
+    // Test Estimator save/load
+    val estimator2 = testDefaultReadWrite(estimator)
+    testParams.foreach { case (p, v) =>
+      val param = estimator.getParam(p)
+      assert(estimator.get(param).get === estimator2.get(param).get)
+    }
+
+    // Test Model save/load
+    val model2 = testDefaultReadWrite(model)
+    testParams.foreach { case (p, v) =>
+      val param = model.getParam(p)
+      assert(model.get(param).get === model2.get(param).get)
+    }
+  }
 }
 
 class MyParams(override val uid: String) extends Params with Writable {

