Repository: spark
Updated Branches:
  refs/heads/master 39f743a62 -> 8c11d1aab


[SPARK-11893] Model export/import for spark.ml: TrainValidationSplit

https://issues.apache.org/jira/browse/SPARK-11893

jkbradley In order to share read/write with `TrainValidationSplit`, I moved 
`SharedReadWrite` out of `CrossValidator` into a new trait `SharedReadWrite` in 
the tuning package.

To reduce repeated tests, I moved the complex tests from 
`CrossValidatorSuite` to `SharedReadWriteSuite`, and created a fake validator 
called `MyValidator` to test the shared code.

With `SharedReadWrite`, any newly added `Validator` can share the common 
read/write code and only needs to implement save/load for its extra params.

Author: Xusen Yin <[email protected]>
Author: Joseph K. Bradley <[email protected]>

Closes #9971 from yinxusen/SPARK-11893.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8c11d1aa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8c11d1aa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8c11d1aa

Branch: refs/heads/master
Commit: 8c11d1aab8522c75d78bc6b30402c64e8d9ff065
Parents: 39f743a
Author: Xusen Yin <[email protected]>
Authored: Mon Mar 28 15:40:06 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Mon Mar 28 15:40:06 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/tuning/CrossValidator.scala | 148 ++-----------------
 .../spark/ml/tuning/TrainValidationSplit.scala  | 100 ++++++++++++-
 .../spark/ml/tuning/ValidatorParams.scala       | 117 ++++++++++++++-
 .../org/apache/spark/ml/util/ReadWrite.scala    |  42 +++++-
 .../ml/tuning/TrainValidationSplitSuite.scala   |  45 +++++-
 5 files changed, 310 insertions(+), 142 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8c11d1aa/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 963f81c..040b009 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -19,25 +19,19 @@ package org.apache.spark.ml.tuning
 
 import com.github.fommil.netlib.F2jBLAS
 import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JObject}
-import org.json4s.jackson.JsonMethods._
+import org.json4s.DefaultFormats
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
-
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
@@ -45,6 +39,7 @@ private[ml] trait CrossValidatorParams extends 
ValidatorParams with HasSeed {
   /**
    * Param for number of folds for cross validation.  Must be >= 2.
    * Default: 3
+   *
    * @group param
    */
   val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -163,10 +158,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
 
   private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) 
extends MLWriter {
 
-    SharedReadWrite.validateParams(instance)
+    ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit =
-      SharedReadWrite.saveImpl(path, instance, sc)
+      ValidatorParams.saveImpl(path, instance, sc)
   }
 
   private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,8 +170,11 @@ object CrossValidator extends MLReadable[CrossValidator] {
     private val className = classOf[CrossValidator].getName
 
     override def load(path: String): CrossValidator = {
-      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
-        SharedReadWrite.load(path, sc, className)
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
       new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
@@ -184,123 +182,6 @@ object CrossValidator extends MLReadable[CrossValidator] {
         .setNumFolds(numFolds)
     }
   }
-
-  private object CrossValidatorReader {
-    /**
-     * Examine the given estimator (which may be a compound estimator) and 
extract a mapping
-     * from UIDs to corresponding [[Params]] instances.
-     */
-    def getUidMap(instance: Params): Map[String, Params] = {
-      val uidList = getUidMapImpl(instance)
-      val uidMap = uidList.toMap
-      if (uidList.size != uidMap.size) {
-        throw new RuntimeException("CrossValidator.load found a compound 
estimator with stages" +
-          s" with duplicate UIDs.  List of UIDs: 
${uidList.map(_._1).mkString(", ")}")
-      }
-      uidMap
-    }
-
-    def getUidMapImpl(instance: Params): List[(String, Params)] = {
-      val subStages: Array[Params] = instance match {
-        case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
-        case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
-        case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
-        case ovr: OneVsRestParams =>
-          // TODO: SPARK-11892: This case may require special handling.
-          throw new UnsupportedOperationException("CrossValidator write will 
fail because it" +
-            " cannot yet handle an estimator containing type: 
${ovr.getClass.getName}")
-        case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
-        case _: Params => Array()
-      }
-      val subStageMaps = 
subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
-      List((instance.uid, instance)) ++ subStageMaps
-    }
-  }
-
-  private[tuning] object SharedReadWrite {
-
-    /**
-     * Check that [[CrossValidator.evaluator]] and 
[[CrossValidator.estimator]] are Writable.
-     * This does not check [[CrossValidator.estimatorParamMaps]].
-     */
-    def validateParams(instance: ValidatorParams): Unit = {
-      def checkElement(elem: Params, name: String): Unit = elem match {
-        case stage: MLWritable => // good
-        case other =>
-          throw new UnsupportedOperationException("CrossValidator write will 
fail " +
-            s" because it contains $name which does not implement Writable." +
-            s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
-      }
-      checkElement(instance.getEvaluator, "evaluator")
-      checkElement(instance.getEstimator, "estimator")
-      // Check to make sure all Params apply to this estimator.  Throw an 
error if any do not.
-      // Extraneous Params would cause problems when loading the 
estimatorParamMaps.
-      val uidToInstance: Map[String, Params] = 
CrossValidatorReader.getUidMap(instance)
-      instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
-        pMap.toSeq.foreach { case ParamPair(p, v) =>
-          require(uidToInstance.contains(p.parent), s"CrossValidator save 
requires all Params in" +
-            s" estimatorParamMaps to apply to this CrossValidator, its 
Estimator, or its" +
-            s" Evaluator.  An extraneous Param was found: $p")
-        }
-      }
-    }
-
-    private[tuning] def saveImpl(
-        path: String,
-        instance: CrossValidatorParams,
-        sc: SparkContext,
-        extraMetadata: Option[JObject] = None): Unit = {
-      import org.json4s.JsonDSL._
-
-      val estimatorParamMapsJson = compact(render(
-        instance.getEstimatorParamMaps.map { case paramMap =>
-          paramMap.toSeq.map { case ParamPair(p, v) =>
-            Map("parent" -> p.parent, "name" -> p.name, "value" -> 
p.jsonEncode(v))
-          }
-        }.toSeq
-      ))
-      val jsonParams = List(
-        "numFolds" -> 
parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
-        "estimatorParamMaps" -> parse(estimatorParamMapsJson)
-      )
-      DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, 
Some(jsonParams))
-
-      val evaluatorPath = new Path(path, "evaluator").toString
-      instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
-      val estimatorPath = new Path(path, "estimator").toString
-      instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
-    }
-
-    private[tuning] def load[M <: Model[M]](
-        path: String,
-        sc: SparkContext,
-        expectedClassName: String): (Metadata, Estimator[M], Evaluator, 
Array[ParamMap], Int) = {
-
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, 
expectedClassName)
-
-      implicit val format = DefaultFormats
-      val evaluatorPath = new Path(path, "evaluator").toString
-      val evaluator = 
DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
-      val estimatorPath = new Path(path, "estimator").toString
-      val estimator = 
DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-
-      val uidToParams = Map(evaluator.uid -> evaluator) ++ 
CrossValidatorReader.getUidMap(estimator)
-
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val estimatorParamMaps: Array[ParamMap] =
-        (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, 
String]]]].map {
-          pMap =>
-            val paramPairs = pMap.map { case pInfo: Map[String, String] =>
-              val est = uidToParams(pInfo("parent"))
-              val param = est.getParam(pInfo("name"))
-              val value = param.jsonDecode(pInfo("value"))
-              param -> value
-            }
-            ParamMap(paramPairs: _*)
-        }.toArray
-      (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
-    }
-  }
 }
 
 /**
@@ -346,8 +227,6 @@ class CrossValidatorModel private[ml] (
 @Since("1.6.0")
 object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
 
-  import CrossValidator.SharedReadWrite
-
   @Since("1.6.0")
   override def read: MLReader[CrossValidatorModel] = new 
CrossValidatorModelReader
 
@@ -357,12 +236,12 @@ object CrossValidatorModel extends 
MLReadable[CrossValidatorModel] {
   private[CrossValidatorModel]
   class CrossValidatorModelWriter(instance: CrossValidatorModel) extends 
MLWriter {
 
-    SharedReadWrite.validateParams(instance)
+    ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit = {
       import org.json4s.JsonDSL._
       val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
-      SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+      ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
       instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
     }
@@ -376,8 +255,9 @@ object CrossValidatorModel extends 
MLReadable[CrossValidatorModel] {
     override def load(path: String): CrossValidatorModel = {
       implicit val format = DefaultFormats
 
-      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
-        SharedReadWrite.load(path, sc, className)
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = 
DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ 
"avgMetrics").extract[Seq[Double]].toArray

http://git-wip-us.apache.org/repos/asf/spark/blob/8c11d1aa/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 70fa5f0..4d1d636 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.ml.tuning
 
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
@@ -33,6 +36,7 @@ private[ml] trait TrainValidationSplitParams extends 
ValidatorParams {
   /**
    * Param for ratio between train and validation data. Must be between 0 and 
1.
    * Default: 0.75
+   *
    * @group param
    */
   val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
@@ -55,7 +59,7 @@ private[ml] trait TrainValidationSplitParams extends 
ValidatorParams {
 @Experimental
 class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: 
String)
   extends Estimator[TrainValidationSplitModel]
-  with TrainValidationSplitParams with Logging {
+  with TrainValidationSplitParams with MLWritable with Logging {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("tvs"))
@@ -130,6 +134,47 @@ class TrainValidationSplit @Since("1.5.0") 
(@Since("1.5.0") override val uid: St
     }
     copied
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new 
TrainValidationSplit.TrainValidationSplitWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
+
+  @Since("2.0.0")
+  override def read: MLReader[TrainValidationSplit] = new 
TrainValidationSplitReader
+
+  @Since("2.0.0")
+  override def load(path: String): TrainValidationSplit = super.load(path)
+
+  private[TrainValidationSplit] class TrainValidationSplitWriter(instance: 
TrainValidationSplit)
+    extends MLWriter {
+
+    ValidatorParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit =
+      ValidatorParams.saveImpl(path, instance, sc)
+  }
+
+  private class TrainValidationSplitReader extends 
MLReader[TrainValidationSplit] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[TrainValidationSplit].getName
+
+    override def load(path: String): TrainValidationSplit = {
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+      new TrainValidationSplit(metadata.uid)
+        .setEstimator(estimator)
+        .setEvaluator(evaluator)
+        .setEstimatorParamMaps(estimatorParamMaps)
+        .setTrainRatio(trainRatio)
+    }
+  }
 }
 
 /**
@@ -146,7 +191,7 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") override val uid: String,
     @Since("1.5.0") val bestModel: Model[_],
     @Since("1.5.0") val validationMetrics: Array[Double])
-  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
+  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams 
with MLWritable {
 
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
@@ -167,4 +212,53 @@ class TrainValidationSplitModel private[ml] (
       validationMetrics.clone())
     copyValues(copied, extra)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new 
TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] 
{
+
+  @Since("2.0.0")
+  override def read: MLReader[TrainValidationSplitModel] = new 
TrainValidationSplitModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): TrainValidationSplitModel = super.load(path)
+
+  private[TrainValidationSplitModel]
+  class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) 
extends MLWriter {
+
+    ValidatorParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit = {
+      import org.json4s.JsonDSL._
+      val extraMetadata = "validationMetrics" -> 
instance.validationMetrics.toSeq
+      ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
+      val bestModelPath = new Path(path, "bestModel").toString
+      instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+    }
+  }
+
+  private class TrainValidationSplitModelReader extends 
MLReader[TrainValidationSplitModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[TrainValidationSplitModel].getName
+
+    override def load(path: String): TrainValidationSplitModel = {
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+      val bestModelPath = new Path(path, "bestModel").toString
+      val bestModel = 
DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+      val validationMetrics = (metadata.metadata \ 
"validationMetrics").extract[Seq[Double]].toArray
+      val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, 
validationMetrics)
+      tvs.set(tvs.estimator, estimator)
+        .set(tvs.evaluator, evaluator)
+        .set(tvs.estimatorParamMaps, estimatorParamMaps)
+        .set(tvs.trainRatio, trainRatio)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8c11d1aa/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 953456e..7a4e106 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -17,9 +17,17 @@
 
 package org.apache.spark.ml.tuning
 
-import org.apache.spark.ml.Estimator
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, _}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, 
MetaAlgorithmReadWrite,
+  MLWritable}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -69,3 +77,108 @@ private[ml] trait ValidatorParams extends Params {
     est.copy(firstEstimatorParamMap).transformSchema(schema)
   }
 }
+
+private[ml] object ValidatorParams {
+  /**
+   * Check that [[ValidatorParams.evaluator]] and 
[[ValidatorParams.estimator]] are Writable.
+   * This does not check [[ValidatorParams.estimatorParamMaps]].
+   */
+  def validateParams(instance: ValidatorParams): Unit = {
+    def checkElement(elem: Params, name: String): Unit = elem match {
+      case stage: MLWritable => // good
+      case other =>
+        throw new UnsupportedOperationException(instance.getClass.getName + " 
write will fail " +
+          s" because it contains $name which does not implement Writable." +
+          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+    }
+    checkElement(instance.getEvaluator, "evaluator")
+    checkElement(instance.getEstimator, "estimator")
+    // Check to make sure all Params apply to this estimator.  Throw an error 
if any do not.
+    // Extraneous Params would cause problems when loading the 
estimatorParamMaps.
+    val uidToInstance: Map[String, Params] = 
MetaAlgorithmReadWrite.getUidMap(instance)
+    instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+      pMap.toSeq.foreach { case ParamPair(p, v) =>
+        require(uidToInstance.contains(p.parent), s"ValidatorParams save 
requires all Params in" +
+          s" estimatorParamMaps to apply to this ValidatorParams, its 
Estimator, or its" +
+          s" Evaluator. An extraneous Param was found: $p")
+      }
+    }
+  }
+
+  /**
+   * Generic implementation of save for [[ValidatorParams]] types.
+   * This handles all [[ValidatorParams]] fields and saves [[Param]] values, 
but the implementing
+   * class needs to handle model data.
+   */
+  def saveImpl(
+      path: String,
+      instance: ValidatorParams,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None): Unit = {
+    import org.json4s.JsonDSL._
+
+    val estimatorParamMapsJson = compact(render(
+      instance.getEstimatorParamMaps.map { case paramMap =>
+        paramMap.toSeq.map { case ParamPair(p, v) =>
+          Map("parent" -> p.parent, "name" -> p.name, "value" -> 
p.jsonEncode(v))
+        }
+      }.toSeq
+    ))
+
+    val validatorSpecificParams = instance match {
+      case cv: CrossValidatorParams =>
+        List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
+      case tvs: TrainValidationSplitParams =>
+        List("trainRatio" -> 
parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
+      case _ =>
+        // This should not happen.
+        throw new NotImplementedError("ValidatorParams.saveImpl does not 
handle type: " +
+          instance.getClass.getCanonicalName)
+    }
+
+    val jsonParams = validatorSpecificParams ++ List(
+      "estimatorParamMaps" -> parse(estimatorParamMapsJson))
+
+    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, 
Some(jsonParams))
+
+    val evaluatorPath = new Path(path, "evaluator").toString
+    instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+    val estimatorPath = new Path(path, "estimator").toString
+    instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+  }
+
+  /**
+   * Generic implementation of load for [[ValidatorParams]] types.
+   * This handles all [[ValidatorParams]] fields, but the implementing
+   * class needs to handle model data and special [[Param]] values.
+   */
+  def loadImpl[M <: Model[M]](
+      path: String,
+      sc: SparkContext,
+      expectedClassName: String): (Metadata, Estimator[M], Evaluator, 
Array[ParamMap]) = {
+
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, 
expectedClassName)
+
+    implicit val format = DefaultFormats
+    val evaluatorPath = new Path(path, "evaluator").toString
+    val evaluator = 
DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+    val estimatorPath = new Path(path, "estimator").toString
+    val estimator = 
DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+    val uidToParams = Map(evaluator.uid -> evaluator) ++ 
MetaAlgorithmReadWrite.getUidMap(estimator)
+
+    val estimatorParamMaps: Array[ParamMap] =
+      (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, 
String]]]].map {
+        pMap =>
+          val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+            val est = uidToParams(pInfo("parent"))
+            val param = est.getParam(pInfo("name"))
+            val value = param.jsonDecode(pInfo("value"))
+            param -> value
+          }
+          ParamMap(paramPairs: _*)
+      }.toArray
+
+    (metadata, estimator, evaluator, estimatorParamMaps)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8c11d1aa/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536..5a596ca 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.util.Utils
 
@@ -352,3 +357,38 @@ private[ml] object DefaultParamsReader {
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
 }
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+  /**
+   * Examine the given estimator (which may be a compound estimator) and 
extract a mapping
+   * from UIDs to corresponding [[Params]] instances.
+   */
+  def getUidMap(instance: Params): Map[String, Params] = {
+    val uidList = getUidMapImpl(instance)
+    val uidMap = uidList.toMap
+    if (uidList.size != uidMap.size) {
+      throw new RuntimeException(s"${instance.getClass.getName}.load found a 
compound estimator" +
+        s" with stages with duplicate UIDs. List of UIDs: 
${uidList.map(_._1).mkString(", ")}.")
+    }
+    uidMap
+  }
+
+  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+    val subStages: Array[Params] = instance match {
+      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+      case ovr: OneVsRestParams =>
+        // TODO: SPARK-11892: This case may require special handling.
+        throw new UnsupportedOperationException(s"${instance.getClass.getName} 
write will fail" +
+          s" because it cannot yet handle an estimator containing type: 
${ovr.getClass.getName}.")
+      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+      case _: Params => Array()
+    }
+    val subStageMaps = 
subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+    List((instance.uid, instance)) ++ subStageMaps
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/8c11d1aa/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index cf8dcef..7cf7b3e 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -19,17 +19,20 @@ package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.classification.{LogisticRegression, 
LogisticRegressionModel}
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, 
Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import 
org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
-class TrainValidationSplitSuite extends SparkFunSuite with 
MLlibTestSparkContext {
+class TrainValidationSplitSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("train validation with logistic regression") {
     val dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
@@ -105,6 +108,44 @@ class TrainValidationSplitSuite extends SparkFunSuite with 
MLlibTestSparkContext
       cv.transformSchema(new StructType())
     }
   }
+
+  test("read/write: TrainValidationSplit") {
+    val lr = new LogisticRegression().setMaxIter(3)
+    val evaluator = new BinaryClassificationEvaluator()
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val tvs = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEvaluator(evaluator)
+      .setTrainRatio(0.5)
+      .setEstimatorParamMaps(paramMaps)
+
+    val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+    assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+  }
+
+  test("read/write: TrainValidationSplitModel") {
+    val lr = new LogisticRegression()
+      .setThreshold(0.6)
+    val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 
1.2)
+      .setThreshold(0.6)
+    val evaluator = new BinaryClassificationEvaluator()
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6))
+    tvs.set(tvs.estimator, lr)
+      .set(tvs.evaluator, evaluator)
+      .set(tvs.trainRatio, 0.5)
+      .set(tvs.estimatorParamMaps, paramMaps)
+
+    val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+    assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+    assert(tvs.validationMetrics === tvs2.validationMetrics)
+  }
 }
 
 object TrainValidationSplitSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to