Repository: spark Updated Branches: refs/heads/master cad29a40b -> 0dd06485c
[SPARK-13615][ML] GeneralizedLinearRegression supports save/load ## What changes were proposed in this pull request? ```GeneralizedLinearRegression``` supports ```save/load```. cc mengxr ## How was this patch tested? unit test. Author: Yanbo Liang <yblia...@gmail.com> Closes #11465 from yanboliang/spark-13615. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0dd06485 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0dd06485 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0dd06485 Branch: refs/heads/master Commit: 0dd06485c4222a896c0d1ee6a04d30043de3626c Parents: cad29a4 Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Mar 9 11:59:22 2016 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Wed Mar 9 11:59:22 2016 -0800 ---------------------------------------------------------------------- .../GeneralizedLinearRegression.scala | 74 +++++++++++++++++--- .../GeneralizedLinearRegressionSuite.scala | 32 ++++++++- 2 files changed, 96 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0dd06485/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index a850dfe..de1dff9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.regression import breeze.stats.distributions.{Gaussian => GD} +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{Experimental, Since} @@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String) extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] - with GeneralizedLinearRegressionBase with Logging { + with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging { import GeneralizedLinearRegression._ @@ -236,10 +237,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } @Since("2.0.0") -private[ml] object GeneralizedLinearRegression { +object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] { + + @Since("2.0.0") + override def load(path: String): GeneralizedLinearRegression = super.load(path) /** Set of family and link pairs that GeneralizedLinearRegression supports. */ - lazy val supportedFamilyAndLinkPairs = Set( + private[ml] lazy val supportedFamilyAndLinkPairs = Set( Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, @@ -247,12 +251,12 @@ private[ml] object GeneralizedLinearRegression { ) /** Set of family names that GeneralizedLinearRegression supports. */ - lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) /** Set of link names that GeneralizedLinearRegression supports. */ - lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) - val epsilon: Double = 1E-16 + private[ml] val epsilon: Double = 1E-16 /** * Wrapper of family and link combination used in the model. @@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") val coefficients: Vector, @Since("2.0.0") val intercept: Double) extends RegressionModel[Vector, GeneralizedLinearRegressionModel] - with GeneralizedLinearRegressionBase { + with GeneralizedLinearRegressionBase with MLWritable { import GeneralizedLinearRegression._ @@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] ( copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) .setParent(parent) } + + @Since("2.0.0") + override def write: MLWriter = + new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this) +} + +@Since("2.0.0") +object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GeneralizedLinearRegressionModel] = + new GeneralizedLinearRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GeneralizedLinearRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */ + private[GeneralizedLinearRegressionModel] + class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel) + extends MLWriter with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class GeneralizedLinearRegressionModelReader + extends MLReader[GeneralizedLinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GeneralizedLinearRegressionModel].getName + + override def load(path: String): GeneralizedLinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + + val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/0dd06485/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 8bfa985..618304a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors} import org.apache.spark.mllib.random._ @@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class GeneralizedLinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetGaussianIdentity: DataFrame = _ @@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark } } } + + test("read/write") { + def checkModelData( + model: GeneralizedLinearRegressionModel, + model2: GeneralizedLinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + } + + val glr = new GeneralizedLinearRegression() + testEstimatorAndModelReadWrite(glr, datasetPoissonLog, + GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) + } } object GeneralizedLinearRegressionSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "family" -> "poisson", + "link" -> "log", + "fitIntercept" -> true, + "maxIter" -> 2, // intentionally small + "tol" -> 0.8, + "regParam" -> 0.01, + "predictionCol" -> "myPrediction") + def generateGeneralizedLinearRegressionInput( intercept: Double, coefficients: Array[Double], --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org