Repository: spark
Updated Branches:
  refs/heads/branch-1.3 e23c8f5c8 -> e26c14990


[SPARK-5757][MLLIB] replace SQL JSON usage in model import/export by json4s

This PR detaches MLlib model import/export code from SQL's JSON support, and hence unblocks #4544. yhuai
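
For context, the whole patch boils down to a json4s round trip replacing the
DataFrame-based JSON path. A minimal sketch of the pattern the diff applies
throughout (the model class name and numFeatures value below are placeholders,
not taken from the patch):

    import org.json4s._
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    implicit val formats = DefaultFormats

    // Write side: compose the metadata record with the JsonDSL (~ merges
    // fields into one JSON object) and render it as a single compact line.
    val metadata: String = compact(render(
      ("class" -> "org.example.MyModel") ~ ("version" -> "1.0") ~ ("numFeatures" -> 42)))

    // Read side: parse the line back and extract typed fields.
    val parsed: JValue = parse(metadata)
    val numFeatures = (parsed \ "numFeatures").extract[Int]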

Author: Xiangrui Meng <m...@databricks.com>

Closes #4555 from mengxr/SPARK-5757 and squashes the following commits:

b0415e8 [Xiangrui Meng] replace SQL JSON usage by json4s

(cherry picked from commit 99bd5006650bb15ec5465ffee1ebaca81354a3df)
Signed-off-by: Xiangrui Meng <m...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e26c1499
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e26c1499
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e26c1499

Branch: refs/heads/branch-1.3
Commit: e26c14990c477249241b429c1bb877c3d9339744
Parents: e23c8f5
Author: Xiangrui Meng <m...@databricks.com>
Authored: Thu Feb 12 10:48:13 2015 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Feb 12 10:48:22 2015 -0800

----------------------------------------------------------------------
 .../classification/ClassificationModel.scala    | 16 ++----
 .../classification/LogisticRegression.scala     |  3 +-
 .../spark/mllib/classification/NaiveBayes.scala | 18 +++----
 .../apache/spark/mllib/classification/SVM.scala |  6 +--
 .../impl/GLMClassificationModel.scala           | 17 ++++---
 .../MatrixFactorizationModel.scala              | 14 ++++--
 .../apache/spark/mllib/regression/Lasso.scala   |  2 +-
 .../mllib/regression/LinearRegression.scala     |  2 +-
 .../mllib/regression/RegressionModel.scala      | 16 ++----
 .../mllib/regression/RidgeRegression.scala      |  2 +-
 .../regression/impl/GLMRegressionModel.scala    | 11 +++--
 .../apache/spark/mllib/tree/DecisionTree.scala  |  8 +--
 .../mllib/tree/model/DecisionTreeModel.scala    | 28 +++++------
 .../mllib/tree/model/treeEnsembleModels.scala   | 51 ++++++++------------
 .../apache/spark/mllib/util/modelSaveLoad.scala | 25 +++-------
 15 files changed, 92 insertions(+), 127 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 348c1e8..35a0db7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.mllib.classification
 
+import org.json4s.{DefaultFormats, JValue}
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * :: Experimental ::
@@ -60,16 +60,10 @@ private[mllib] object ClassificationModel {
 
   /**
    * Helper method for loading GLM classification model metadata.
-   *
-   * @param modelClass  String name for model class (used for error messages)
    * @return (numFeatures, numClasses)
    */
-  def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
-    metadata.select("numFeatures", "numClasses").take(1)(0) match {
-      case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
-      case _ => throw new Exception(s"$modelClass unable to load" +
-        s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
-    }
+  def getNumFeaturesClasses(metadata: JValue): (Int, Int) = {
+    implicit val formats = DefaultFormats
+    ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int])
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 9a391bf..420d6e2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -173,8 +173,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
         // numFeatures, numClasses, weights are checked in model initialization
         val model =

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index d9ce282..f9142bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -18,15 +18,16 @@
 package org.apache.spark.mllib.classification
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.{SparkContext, SparkException, Logging}
+import org.apache.spark.{Logging, SparkContext, SparkException}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
-
 /**
  * Model for Naive Bayes Classifiers.
  *
@@ -78,7 +79,7 @@ class NaiveBayesModel private[mllib] (
 
 object NaiveBayesModel extends Loader[NaiveBayesModel] {
 
-  import Loader._
+  import org.apache.spark.mllib.util.Loader._
 
   private object SaveLoadV1_0 {
 
@@ -95,10 +96,10 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
-          .toDataFrame("class", "version", "numFeatures", "numClasses")
-      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
 
       // Create Parquet data.
       val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
@@ -126,8 +127,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val model = SaveLoadV1_0.load(sc, path)
         assert(model.pi.size == numClasses,
           s"NaiveBayesModel.load expected $numClasses classes," +

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 24d31e6..cfc7f86 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -23,10 +23,9 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
+import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
 import org.apache.spark.rdd.RDD
 
-
 /**
  * Model for Support Vector Machines (SVMs).
  *
@@ -97,8 +96,7 @@ object SVMModel extends Loader[SVMModel] {
     val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val (numFeatures, numClasses) =
-          ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+        val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
         val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
         val model = new SVMModel(data.weights, data.intercept)
         assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 8d60057..1d11896 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.mllib.classification.impl
 
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.Loader
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext}
 
 /**
  * Helper class for import/export of GLM classification models.
@@ -52,16 +55,14 @@ private[classification] object GLMClassificationModel {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
-          .toDataFrame("class", "version", "numFeatures", "numClasses")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+        ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Create Parquet data.
       val data = Data(weights, intercept, threshold)
-      val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
-      // TODO: repartition with 1 partition after SPARK-5532 gets fixed
-      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+      sc.parallelize(Seq(data), 1).saveAsParquetFile(Loader.dataPath(path))
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 16979c9..a3a3b5d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -22,6 +22,9 @@ import java.lang.{Integer => JavaInteger}
 
 import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
@@ -153,7 +156,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
   import org.apache.spark.mllib.util.Loader._
 
   override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
-    val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+    val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, formatVersion) match {
       case (className, "1.0") if className == classNameV1_0 =>
@@ -181,19 +184,20 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
       val sc = model.userFeatures.sparkContext
       val sqlContext = new SQLContext(sc)
       import sqlContext.implicits._
-      val metadata = (thisClassName, thisFormatVersion, model.rank)
-      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
-      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
       model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
       model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
     }
 
     def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+      implicit val formats = DefaultFormats
       val sqlContext = new SQLContext(sc)
       val (className, formatVersion, metadata) = loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
-      val rank = metadata.select("rank").first().getInt(0)
+      val rank = (metadata \ "rank").extract[Int]
       val userFeatures = sqlContext.parquetFile(userPath(path))
         .map { case Row(id: Int, features: Seq[Double]) =>
           (id, features.toArray)

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 1159e59..e8b0381 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -58,7 +58,7 @@ object LassoModel extends Loader[LassoModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new LassoModel(data.weights, data.intercept)
       case _ => throw new Exception(

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 0136dcf..6fa7ad5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -58,7 +58,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new LinearRegressionModel(data.weights, data.intercept)
       case _ => throw new Exception(

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index 843e59b..214ac4d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.mllib.regression
 
+import org.json4s.{DefaultFormats, JValue}
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
 
 @Experimental
 trait RegressionModel extends Serializable {
@@ -55,16 +55,10 @@ private[mllib] object RegressionModel {
 
   /**
    * Helper method for loading GLM regression model metadata.
-   *
-   * @param modelClass  String name for model class (used for error messages)
    * @return numFeatures
    */
-  def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
-    metadata.select("numFeatures").take(1)(0) match {
-      case Row(nFeatures: Int) => nFeatures
-      case _ => throw new Exception(s"$modelClass unable to load" +
-        s" numFeatures from metadata: ${Loader.metadataPath(path)}")
-    }
+  def getNumFeatures(metadata: JValue): Int = {
+    implicit val formats = DefaultFormats
+    (metadata \ "numFeatures").extract[Int]
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index f2a5f1d..8838ca8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -59,7 +59,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
     val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+        val numFeatures = RegressionModel.getNumFeatures(metadata)
         val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
         new RidgeRegressionModel(data.weights, data.intercept)
       case _ => throw new Exception(

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index 838100e..f75de6f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.mllib.regression.impl
 
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.Loader
@@ -48,10 +51,10 @@ private[regression] object GLMRegressionModel {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD =
-        sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
-          .toDataFrame("class", "version", "numFeatures")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+          ("numFeatures" -> weights.size)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Create Parquet data.
       val data = Data(weights, intercept)

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b3e8ed9..9a586b9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,14 +17,13 @@
 
 package org.apache.spark.mllib.tree
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
-
+import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
 import org.apache.spark.mllib.tree.configuration.Strategy
@@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.impl._
-import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.SparkContext._
-
 
 /**
  * :: Experimental ::

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 89ecf37..373192a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -19,6 +19,10 @@ package org.apache.spark.mllib.tree.model
 
 import scala.collection.mutable
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
@@ -184,10 +188,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadataRDD = sc.parallelize(
-        Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
-        .toDataFrame("class", "version", "algo", "numNodes")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+          ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Create Parquet data.
       val nodes = model.topNode.subtreeIterator.toSeq
@@ -269,20 +273,10 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
   }
 
   override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+    implicit val formats = DefaultFormats
     val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
-    val (algo: String, numNodes: Int) = try {
-      val algo_numNodes = metadata.select("algo", "numNodes").collect()
-      assert(algo_numNodes.length == 1)
-      algo_numNodes(0) match {
-        case Row(a: String, n: Int) => (a, n)
-      }
-    } catch {
-      // Catch both Error and Exception since the checks above can throw either.
-      case e: Throwable =>
-        throw new Exception(
-          s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
-          + s"  Error message: ${e.getMessage}")
-    }
+    val algo = (metadata \ "algo").extract[String]
+    val numNodes = (metadata \ "numNodes").extract[Int]
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 23bd46b..dbd69dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -20,18 +20,20 @@ package org.apache.spark.mllib.tree.model
 import scala.collection.mutable
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.Algo
+import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
-import org.apache.spark.mllib.util.{Saveable, Loader}
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
-
+import org.apache.spark.sql.SQLContext
 
 /**
  * :: Experimental ::
@@ -59,11 +61,11 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis
 object RandomForestModel extends Loader[RandomForestModel] {
 
   override def load(sc: SparkContext, path: String): RandomForestModel = {
-    val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+    val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
         assert(metadata.treeWeights.forall(_ == 1.0))
         val trees =
           TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
@@ -110,11 +112,11 @@ class GradientBoostedTreesModel(
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
 
   override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
-    val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+    val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName
     (loadedClassName, version) match {
       case (className, "1.0") if className == classNameV1_0 =>
-        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
         assert(metadata.combiningStrategy == Sum.toString)
         val trees =
           TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
@@ -252,7 +254,7 @@ private[tree] object TreeEnsembleModel {
 
   object SaveLoadV1_0 {
 
-    import DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
+    import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
 
     def thisFormatVersion = "1.0"
 
@@ -276,11 +278,13 @@ private[tree] object TreeEnsembleModel {
       import sqlContext.implicits._
 
       // Create JSON metadata.
-      val metadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
+      implicit val format = DefaultFormats
+      val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
         model.combiningStrategy.toString, model.treeWeights)
-      val metadataRDD = sc.parallelize(Seq((className, thisFormatVersion, metadata)), 1)
-        .toDataFrame("class", "version", "metadata")
-      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+      val metadata = compact(render(
+        ("class" -> className) ~ ("version" -> thisFormatVersion) ~
+          ("metadata" -> Extraction.decompose(ensembleMetadata))))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
 
       // Create Parquet data.
       val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
@@ -290,24 +294,11 @@ private[tree] object TreeEnsembleModel {
     }
 
     /**
-     * Read metadata from the loaded metadata DataFrame.
-     * @param path  Path for loading data, used for debug messages.
+     * Read metadata from the loaded JSON metadata.
      */
-    def readMetadata(metadata: DataFrame, path: String): Metadata = {
-      try {
-        // We rely on the try-catch for schema checking rather than creating a schema just for this.
-        val metadataArray = metadata.select("metadata.algo", "metadata.treeAlgo",
-          "metadata.combiningStrategy", "metadata.treeWeights").collect()
-        assert(metadataArray.size == 1)
-        Metadata(metadataArray(0).getString(0), metadataArray(0).getString(1),
-          metadataArray(0).getString(2), metadataArray(0).getAs[Seq[Double]](3).toArray)
-      } catch {
-        // Catch both Error and Exception since the checks above can throw either.
-        case e: Throwable =>
-          throw new Exception(
-            s"Unable to load TreeEnsembleModel metadata from: ${Loader.metadataPath(path)}."
-              + s"  Error message: ${e.getMessage}")
-      }
+    def readMetadata(metadata: JValue): Metadata = {
+      implicit val formats = DefaultFormats
+      (metadata \ "metadata").extract[Metadata]
     }
 
     /**

http://git-wip-us.apache.org/repos/asf/spark/blob/e26c1499/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
index 56b77a7..4458340 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.util
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{DataType, StructType, StructField}
-
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
 
 /**
  * :: DeveloperApi ::
@@ -120,20 +120,11 @@ private[mllib] object Loader {
    * Load metadata from the given path.
    * @return (class name, version, metadata)
    */
-  def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = {
-    val sqlContext = new SQLContext(sc)
-    val metadata = sqlContext.jsonFile(metadataPath(path))
-    val (clazz, version) = try {
-      val metadataArray = metadata.select("class", "version").take(1)
-      assert(metadataArray.size == 1)
-      metadataArray(0) match {
-        case Row(clazz: String, version: String) => (clazz, version)
-      }
-    } catch {
-      case e: Exception =>
-        throw new Exception(s"Unable to load model metadata from: 
${metadataPath(path)}")
-    }
+  def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = 
{
+    implicit val formats = DefaultFormats
+    val metadata = parse(sc.textFile(metadataPath(path)).first())
+    val clazz = (metadata \ "class").extract[String]
+    val version = (metadata \ "version").extract[String]
     (clazz, version, metadata)
   }
-
 }

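
Net effect on the on-disk format: model metadata is now a single JSON line
under <path>/metadata, readable without a SQLContext. A minimal read-side
sketch mirroring the new Loader.loadMetadata (assumes an existing SparkContext
sc; modelPath is a placeholder, and path + "/metadata" is spelled out here
because Loader.metadataPath is private[mllib]):

    import org.json4s._
    import org.json4s.jackson.JsonMethods._

    implicit val formats = DefaultFormats
    // The metadata directory holds one single-line JSON part file.
    val metadata: JValue = parse(sc.textFile(modelPath + "/metadata").first())
    val clazz = (metadata \ "class").extract[String]
    val version = (metadata \ "version").extract[String]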
