Repository: spark
Updated Branches:
  refs/heads/branch-1.6 4774897f9 -> d7b3d5785


[SPARK-11829][ML] Add read/write to estimators under ml.feature (II)

Add read/write support to the following estimators under spark.ml:
* ChiSqSelector
* PCA
* VectorIndexer
* Word2Vec

Author: Yanbo Liang <yblia...@gmail.com>

Closes #9838 from yanboliang/spark-11829.

(cherry picked from commit 3b7f056da87a23f3a96f0311b3a947a9b698f38b)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7b3d578
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7b3d578
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7b3d578

Branch: refs/heads/branch-1.6
Commit: d7b3d578555d6fabfacd80da97b88aae56f81f1b
Parents: 4774897
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Thu Nov 19 22:02:17 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Nov 19 22:02:25 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/feature/ChiSqSelector.scala | 65 +++++++++++++++++--
 .../scala/org/apache/spark/ml/feature/PCA.scala | 67 ++++++++++++++++++--
 .../apache/spark/ml/feature/VectorIndexer.scala | 66 +++++++++++++++++--
 .../org/apache/spark/ml/feature/Word2Vec.scala  | 67 ++++++++++++++++++--
 .../apache/spark/mllib/feature/Word2Vec.scala   |  6 +-
 .../spark/ml/feature/ChiSqSelectorSuite.scala   | 22 ++++++-
 .../org/apache/spark/ml/feature/PCASuite.scala  | 26 +++++++-
 .../spark/ml/feature/VectorIndexerSuite.scala   | 22 ++++++-
 .../apache/spark/ml/feature/Word2VecSuite.scala | 30 ++++++++-
 9 files changed, 338 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 5e4061f..dfec038 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute.{AttributeGroup, _}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params
  */
 @Experimental
 final class ChiSqSelector(override val uid: String)
-  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
+  extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("chiSqSelector"))
 
@@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String)
   override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
 }
 
+@Since("1.6.0")
+object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] {
+
+  @Since("1.6.0")
+  override def load(path: String): ChiSqSelector = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model fitted by [[ChiSqSelector]].
@@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String)
 final class ChiSqSelectorModel private[ml] (
     override val uid: String,
     private val chiSqSelector: feature.ChiSqSelectorModel)
-  extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
+  extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable {
+
+  import ChiSqSelectorModel._
+
+  /** list of indices to select (filter). Must be ordered asc */
+  val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures
 
   /** @group setParam */
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] (
     val copied = new ChiSqSelectorModel(uid, chiSqSelector)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new ChiSqSelectorModelWriter(this)
+}
+
+@Since("1.6.0")
+object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
+
+  private[ChiSqSelectorModel]
+  class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter {
+
+    private case class Data(selectedFeatures: Seq[Int])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.selectedFeatures.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] {
+
+    private val className = classOf[ChiSqSelectorModel].getName
+
+    override def load(path: String): ChiSqSelectorModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
+      val selectedFeatures = data.getAs[Seq[Int]](0).toArray
+      val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
+      val model = new ChiSqSelectorModel(metadata.uid, oldModel)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): ChiSqSelectorModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 5390847..32d7afe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -17,13 +17,15 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
  * PCA trains a model to project vectors to a low-dimensional space using PCA.
  */
 @Experimental
-class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
+class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
+  with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("pca"))
 
@@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   override def copy(extra: ParamMap): PCA = defaultCopy(extra)
 }
 
+@Since("1.6.0")
+object PCA extends DefaultParamsReadable[PCA] {
+
+  @Since("1.6.0")
+  override def load(path: String): PCA = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model fitted by [[PCA]].
@@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
 class PCAModel private[ml] (
     override val uid: String,
     pcaModel: feature.PCAModel)
-  extends Model[PCAModel] with PCAParams {
+  extends Model[PCAModel] with PCAParams with MLWritable {
+
+  import PCAModel._
+
+  /** a principal components Matrix. Each column is one principal component. */
+  val pc: DenseMatrix = pcaModel.pc
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -127,4 +142,46 @@ class PCAModel private[ml] (
     val copied = new PCAModel(uid, pcaModel)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new PCAModelWriter(this)
+}
+
+@Since("1.6.0")
+object PCAModel extends MLReadable[PCAModel] {
+
+  private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
+
+    private case class Data(k: Int, pc: DenseMatrix)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.getK, instance.pc)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class PCAModelReader extends MLReader[PCAModel] {
+
+    private val className = classOf[PCAModel].getName
+
+    override def load(path: String): PCAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
+        .select("k", "pc")
+        .head()
+      val oldModel = new feature.PCAModel(k, pc)
+      val model = new PCAModel(metadata.uid, oldModel)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[PCAModel] = new PCAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): PCAModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 52e0599..a637a6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -22,12 +22,14 @@ import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
+import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.udf
@@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
  */
 @Experimental
 class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
-  with VectorIndexerParams {
+  with VectorIndexerParams with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("vecIdx"))
 
@@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
   override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
 }
 
-private object VectorIndexer {
+@Since("1.6.0")
+object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {
+
+  @Since("1.6.0")
+  override def load(path: String): VectorIndexer = super.load(path)
 
   /**
    * Helper class for tracking unique values for each feature.
@@ -146,7 +152,7 @@ private object VectorIndexer {
   * @param numFeatures  This class fails if it encounters a Vector whose length is not numFeatures.
   * @param maxCategories  This class caps the number of unique values collected at maxCategories.
    */
-  class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
+  private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
     extends Serializable {
 
     /** featureValueSets[feature index] = set of unique values */
@@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] (
     override val uid: String,
     val numFeatures: Int,
     val categoryMaps: Map[Int, Map[Double, Int]])
-  extends Model[VectorIndexerModel] with VectorIndexerParams {
+  extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable {
+
+  import VectorIndexerModel._
 
   /** Java-friendly version of [[categoryMaps]] */
   def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = {
@@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] (
     val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new VectorIndexerModelWriter(this)
+}
+
+@Since("1.6.0")
+object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
+
+  private[VectorIndexerModel]
+  class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter {
+
+    private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.numFeatures, instance.categoryMaps)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] {
+
+    private val className = classOf[VectorIndexerModel].getName
+
+    override def load(path: String): VectorIndexerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("numFeatures", "categoryMaps")
+        .head()
+      val numFeatures = data.getAs[Int](0)
+      val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1)
+      val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): VectorIndexerModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 708dbee..a8d61b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,15 +17,17 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.SparkContext
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -90,7 +92,8 @@ private[feature] trait Word2VecBase extends Params
  * natural language processing or machine learning process.
  */
 @Experimental
-final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase {
+final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase
+  with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("w2v"))
 
@@ -139,6 +142,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
   override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
 }
 
+@Since("1.6.0")
+object Word2Vec extends DefaultParamsReadable[Word2Vec] {
+
+  @Since("1.6.0")
+  override def load(path: String): Word2Vec = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model fitted by [[Word2Vec]].
@@ -147,7 +157,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
 class Word2VecModel private[ml] (
     override val uid: String,
     @transient private val wordVectors: feature.Word2VecModel)
-  extends Model[Word2VecModel] with Word2VecBase {
+  extends Model[Word2VecModel] with Word2VecBase with MLWritable {
+
+  import Word2VecModel._
 
   /**
    * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
@@ -224,4 +236,49 @@ class Word2VecModel private[ml] (
     val copied = new Word2VecModel(uid, wordVectors)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new Word2VecModelWriter(this)
+}
+
+@Since("1.6.0")
+object Word2VecModel extends MLReadable[Word2VecModel] {
+
+  private[Word2VecModel]
+  class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
+
+    private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class Word2VecModelReader extends MLReader[Word2VecModel] {
+
+    private val className = classOf[Word2VecModel].getName
+
+    override def load(path: String): Word2VecModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("wordIndex", "wordVectors")
+        .head()
+      val wordIndex = data.getAs[Map[String, Int]](0)
+      val wordVectors = data.getAs[Seq[Float]](1).toArray
+      val oldModel = new feature.Word2VecModel(wordIndex, wordVectors)
+      val model = new Word2VecModel(metadata.uid, oldModel)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[Word2VecModel] = new Word2VecModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): Word2VecModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 7ab0d89..a47f27b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -432,9 +432,9 @@ class Word2Vec extends Serializable with Logging {
  *                    (i * vectorSize, i * vectorSize + vectorSize)
  */
 @Since("1.1.0")
-class Word2VecModel private[mllib] (
-    private val wordIndex: Map[String, Int],
-    private val wordVectors: Array[Float]) extends Serializable with Saveable {
+class Word2VecModel private[spark] (
+    private[spark] val wordIndex: Map[String, Int],
+    private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable {
 
   private val numWords = wordIndex.size
   // vectorSize: Dimension of each word's vector.

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index e5a4296..7827db2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -18,13 +18,17 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{Row, SQLContext}
 
-class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
+
   test("Test Chi-Square selector") {
     val sqlContext = SQLContext.getOrCreate(sc)
     import sqlContext.implicits._
@@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(vec1 ~== vec2 absTol 1e-1)
     }
   }
+
+  test("ChiSqSelector read/write") {
+    val t = new ChiSqSelector()
+      .setFeaturesCol("myFeaturesCol")
+      .setLabelCol("myLabelCol")
+      .setOutputCol("myOutputCol")
+      .setNumTopFeatures(2)
+    testDefaultReadWrite(t)
+  }
+
+  test("ChiSqSelectorModel read/write") {
+    val oldModel = new feature.ChiSqSelectorModel(Array(1, 3))
+    val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel)
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.selectedFeatures === instance.selectedFeatures)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 30c500f..5a21cd2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.distributed.RowMatrix
-import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
+import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
 import org.apache.spark.sql.Row
 
-class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
+class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new PCA)
@@ -65,4 +65,24 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
     }
   }
+
+  test("read/write") {
+
+    def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
+      assert(model1.pc === model2.pc)
+    }
+    val allParams: Map[String, Any] = Map(
+      "k" -> 3,
+      "inputCol" -> "features",
+      "outputCol" -> "pca_features"
+    )
+    val data = Seq(
+      (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
+      (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
+      (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
+    )
+    val df = sqlContext.createDataFrame(data).toDF("id", "features")
+    val pca = new PCA().setK(3)
+    testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 8cb0a2c..67817fa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -22,13 +22,14 @@ import scala.beans.{BeanInfo, BeanProperty}
 import org.apache.spark.{Logging, SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 
-class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
+class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest with Logging {
 
   import VectorIndexerSuite.FeatureData
 
@@ -251,6 +252,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L
       }
     }
   }
+
+  test("VectorIndexer read/write") {
+    val t = new VectorIndexer()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMaxCategories(30)
+    testDefaultReadWrite(t)
+  }
+
+  test("VectorIndexerModel read/write") {
+    val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1,
+      2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2))
+    val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps)
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.numFeatures === instance.numFeatures)
+    assert(newInstance.categoryMaps === instance.categoryMaps)
+  }
 }
 
 private[feature] object VectorIndexerSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/d7b3d578/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 23dfdaa..a773244 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -19,14 +19,14 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
 
-class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new Word2Vec)
@@ -143,5 +143,31 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
 
   }
+
+  test("Word2Vec read/write") {
+    val t = new Word2Vec()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMaxIter(2)
+      .setMinCount(8)
+      .setNumPartitions(1)
+      .setSeed(42L)
+      .setStepSize(0.01)
+      .setVectorSize(100)
+    testDefaultReadWrite(t)
+  }
+
+  test("Word2VecModel read/write") {
+    val word2VecMap = Map(
+      ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+    )
+    val oldModel = new OldWord2VecModel(word2VecMap)
+    val instance = new Word2VecModel("myWord2VecModel", oldModel)
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.getVectors.collect() === instance.getVectors.collect())
+  }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to