Repository: spark
Updated Branches:
  refs/heads/master d283223a5 -> 2cf46d5a9


[SPARK-11871] Add save/load for MLPC

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-11871

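Add ML persistence (save/load) support for the multilayer perceptron classifier (MLPC): `MultilayerPerceptronClassifier` becomes writable/readable through the default params mechanism, and `MultilayerPerceptronClassificationModel` gets a dedicated writer/reader that also persists its `layers` and `weights`.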

## How was this patch tested?

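Tested with Scala unit tests: `MultilayerPerceptronClassifierSuite` gains read/write round-trip tests for both the classifier and the fitted model.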

Author: Xusen Yin <[email protected]>

Closes #9854 from yinxusen/SPARK-11871.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2cf46d5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2cf46d5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2cf46d5a

Branch: refs/heads/master
Commit: 2cf46d5a96897d5f97b364db357d30566183c6e7
Parents: d283223
Author: Xusen Yin <[email protected]>
Authored: Thu Mar 24 15:29:17 2016 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Thu Mar 24 15:29:17 2016 -0700

----------------------------------------------------------------------
 .../MultilayerPerceptronClassifier.scala        | 69 +++++++++++++++++++-
 .../MultilayerPerceptronClassifierSuite.scala   | 43 ++++++++++--
 2 files changed, 103 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2cf46d5a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 719d107..f6de5f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification
 
 import scala.collection.JavaConverters._
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
 import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, 
ParamValidators}
 import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql.DataFrame
@@ -110,7 +112,7 @@ private object LabelConverter {
 class MultilayerPerceptronClassifier @Since("1.5.0") (
     @Since("1.5.0") override val uid: String)
   extends Predictor[Vector, MultilayerPerceptronClassifier, 
MultilayerPerceptronClassificationModel]
-  with MultilayerPerceptronParams {
+  with MultilayerPerceptronParams with DefaultParamsWritable {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("mlpc"))
@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
   }
 }
 
+@Since("2.0.0")
+object MultilayerPerceptronClassifier
+  extends DefaultParamsReadable[MultilayerPerceptronClassifier] {
+
+  @Since("2.0.0")
+  override def load(path: String): MultilayerPerceptronClassifier = 
super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Classification model based on the Multilayer Perceptron.
@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
     @Since("1.5.0") val layers: Array[Int],
     @Since("1.5.0") val weights: Vector)
   extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
-  with Serializable {
+  with Serializable with MLWritable {
 
   @Since("1.6.0")
   override val numFeatures: Int = layers.head
@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] (
   override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel 
= {
     copyValues(new MultilayerPerceptronClassificationModel(uid, layers, 
weights), extra)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter =
+    new 
MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this)
+}
+
+@Since("2.0.0")
+object MultilayerPerceptronClassificationModel
+  extends MLReadable[MultilayerPerceptronClassificationModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[MultilayerPerceptronClassificationModel] =
+    new MultilayerPerceptronClassificationModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): MultilayerPerceptronClassificationModel = 
super.load(path)
+
+  /** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */
+  private[MultilayerPerceptronClassificationModel]
+  class MultilayerPerceptronClassificationModelWriter(
+      instance: MultilayerPerceptronClassificationModel) extends MLWriter {
+
+    private case class Data(layers: Array[Int], weights: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: layers, weights
+      val data = Data(instance.layers, instance.weights)
+      val dataPath = new Path(path, "data").toString
+      
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class MultilayerPerceptronClassificationModelReader
+    extends MLReader[MultilayerPerceptronClassificationModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = 
classOf[MultilayerPerceptronClassificationModel].getName
+
+    override def load(path: String): MultilayerPerceptronClassificationModel = 
{
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("layers", 
"weights").head()
+      val layers = data.getAs[Seq[Int]](0).toArray
+      val weights = data.getAs[Vector](1)
+      val model = new MultilayerPerceptronClassificationModel(metadata.uid, 
layers, weights)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2cf46d5a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 602b5a8..5df8e6a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -18,31 +18,40 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
 
-class MultilayerPerceptronClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext {
+class MultilayerPerceptronClassifierSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  test("XOR function learning as binary classification problem with two 
outputs.") {
-    val dataFrame = sqlContext.createDataFrame(Seq(
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    dataset = sqlContext.createDataFrame(Seq(
         (Vectors.dense(0.0, 0.0), 0.0),
         (Vectors.dense(0.0, 1.0), 1.0),
         (Vectors.dense(1.0, 0.0), 1.0),
         (Vectors.dense(1.0, 1.0), 0.0))
     ).toDF("features", "label")
+  }
+
+  test("XOR function learning as binary classification problem with two 
outputs.") {
     val layers = Array[Int](2, 5, 2)
     val trainer = new MultilayerPerceptronClassifier()
       .setLayers(layers)
       .setBlockSize(1)
       .setSeed(11L)
       .setMaxIter(100)
-    val model = trainer.fit(dataFrame)
-    val result = model.transform(dataFrame)
+    val model = trainer.fit(dataset)
+    val result = model.transform(dataset)
     val predictionAndLabels = result.select("prediction", "label").collect()
     predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
       assert(p == l)
@@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends 
SparkFunSuite with MLlibTestSp
     val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
     assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
   }
+
+  test("read/write: MultilayerPerceptronClassifier") {
+    val mlp = new MultilayerPerceptronClassifier()
+      .setLayers(Array(2, 3, 2))
+      .setMaxIter(5)
+      .setBlockSize(2)
+      .setSeed(42)
+      .setTol(0.1)
+      .setFeaturesCol("myFeatures")
+      .setLabelCol("myLabel")
+      .setPredictionCol("myPrediction")
+
+    testDefaultReadWrite(mlp, testParams = true)
+  }
+
+  test("read/write: MultilayerPerceptronClassificationModel") {
+    val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 
2)).setMaxIter(5)
+    val mlpModel = mlp.fit(dataset)
+    val newMlpModel = testDefaultReadWrite(mlpModel, testParams = true)
+    assert(newMlpModel.layers === mlpModel.layers)
+    assert(newMlpModel.weights === mlpModel.weights)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to