Repository: spark
Updated Branches:
  refs/heads/branch-1.6 39c8a995d -> bcc6813dd


[SPARK-6790][ML] Add spark.ml LinearRegression import/export

This replaces [https://github.com/apache/spark/pull/9656] with updates.

fayeshine should be the main author when this PR is committed.

CC: mengxr fayeshine

Author: Wenjian Huang <[email protected]>
Author: Joseph K. Bradley <[email protected]>

Closes #9814 from jkbradley/fayeshine-patch-6790.

(cherry picked from commit 045a4f045821dcf60442f0600c2df1b79bddb536)
Signed-off-by: Xiangrui Meng <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bcc6813d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bcc6813d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bcc6813d

Branch: refs/heads/branch-1.6
Commit: bcc6813dd8b050fd4bf9dbd2708e413b43b3e80d
Parents: 39c8a99
Author: Wenjian Huang <[email protected]>
Authored: Wed Nov 18 13:06:25 2015 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Wed Nov 18 13:06:32 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  | 77 +++++++++++++++++++-
 .../ml/regression/LinearRegressionSuite.scala   | 34 ++++++++-
 2 files changed, 106 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bcc6813d/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 913140e..ca55d59 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
 import breeze.stats.distributions.StudentsT
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.ml.feature.Instance
@@ -30,7 +31,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
@@ -65,7 +66,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
 @Experimental
 class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String)
   extends Regressor[Vector, LinearRegression, LinearRegressionModel]
-  with LinearRegressionParams with Logging {
+  with LinearRegressionParams with Writable with Logging {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("linReg"))
@@ -341,6 +342,19 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object LinearRegression extends Readable[LinearRegression] {
+
+  @Since("1.6.0")
+  override def read: Reader[LinearRegression] = new DefaultParamsReader[LinearRegression]
+
+  @Since("1.6.0")
+  override def load(path: String): LinearRegression = read.load(path)
 }
 
 /**
@@ -354,7 +368,7 @@ class LinearRegressionModel private[ml] (
     val coefficients: Vector,
     val intercept: Double)
   extends RegressionModel[Vector, LinearRegressionModel]
-  with LinearRegressionParams {
+  with LinearRegressionParams with Writable {
 
   private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
 
@@ -422,6 +436,63 @@ class LinearRegressionModel private[ml] (
     if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
     newModel.setParent(parent)
   }
+
+  /**
+   * Returns a [[Writer]] instance for this ML instance.
+   *
+   * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]].
+   * An option to save [[summary]] may be added in the future.
+   *
+   * This also does not save the [[parent]] currently.
+   */
+  @Since("1.6.0")
+  override def write: Writer = new LinearRegressionModel.LinearRegressionModelWriter(this)
+}
+
+@Since("1.6.0")
+object LinearRegressionModel extends Readable[LinearRegressionModel] {
+
+  @Since("1.6.0")
+  override def read: Reader[LinearRegressionModel] = new LinearRegressionModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): LinearRegressionModel = read.load(path)
+
+  /** [[Writer]] instance for [[LinearRegressionModel]] */
+  private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel)
+    extends Writer with Logging {
+
+    private case class Data(intercept: Double, coefficients: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: intercept, coefficients
+      val data = Data(instance.intercept, instance.coefficients)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+    }
+  }
+
+  private class LinearRegressionModelReader extends Reader[LinearRegressionModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
+
+    override def load(path: String): LinearRegressionModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.format("parquet").load(dataPath)
+        .select("intercept", "coefficients").head()
+      val intercept = data.getDouble(0)
+      val coefficients = data.getAs[Vector](1)
+      val model = new LinearRegressionModel(metadata.uid, coefficients, intercept)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/bcc6813d/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index a1d86fe..2bdc0e1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,14 +22,15 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class LinearRegressionSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   private val seed: Int = 42
   @transient var datasetWithDenseFeature: DataFrame = _
@@ -854,4 +855,33 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     model.summary.tValues.zip(tValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
     model.summary.pValues.zip(pValsR).foreach{ x => assert(x._1 ~== x._2 absTol 1E-3) }
   }
+
+  test("read/write") {
+    def checkModelData(model: LinearRegressionModel, model2: LinearRegressionModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients === model2.coefficients)
+    }
+    val lr = new LinearRegression()
+    testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
+      checkModelData)
+  }
+}
+
+object LinearRegressionSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "regParam" -> 0.01,
+    "elasticNetParam" -> 0.1,
+    "maxIter" -> 2,  // intentionally small
+    "fitIntercept" -> true,
+    "tol" -> 0.8,
+    "standardization" -> false,
+    "solver" -> "l-bfgs"
+  )
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to