Repository: spark
Updated Branches:
  refs/heads/master 32fa611b1 -> fb9027321
[SPARK-7047] [ML] ml.Model optional parent support

Made Model.parent transient. Added Model.hasParent to test for null parent

CC: mengxr

Author: Joseph K. Bradley <[email protected]>

Closes #5914 from jkbradley/parent-optional and squashes the following commits:

d501774 [Joseph K. Bradley] Made Model.parent transient. Added Model.hasParent to test for null parent

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fb902732
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fb902732
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fb902732

Branch: refs/heads/master
Commit: fb90273212dc7241c9a0c3446e25e0e0b9377750
Parents: 32fa611
Author: Joseph K. Bradley <[email protected]>
Authored: Tue May 19 10:55:21 2015 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Tue May 19 10:55:21 2015 -0700

----------------------------------------------------------------------
 mllib/src/main/scala/org/apache/spark/ml/Model.scala | 5 ++++-
 .../spark/ml/classification/LogisticRegressionSuite.scala | 1 +
 .../spark/ml/classification/RandomForestClassifierSuite.scala | 2 ++
 3 files changed, 7 insertions(+), 1 deletion(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/fb902732/mllib/src/main/scala/org/apache/spark/ml/Model.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
index 7fd5153..70e7495 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala
@@ -32,7 +32,7 @@ abstract class Model[M <: Model[M]] extends Transformer {
    * The parent estimator that produced this model.
    * Note: For ensembles' component Models, this value can be null.
    */
-  var parent: Estimator[M] = _
+  @transient var parent: Estimator[M] = _
 
   /**
    * Sets the parent of this model (Java API).
@@ -42,6 +42,9 @@ abstract class Model[M <: Model[M]] extends Transformer {
     this.asInstanceOf[M]
   }
 
+  /** Indicates whether this [[Model]] has a corresponding parent. */
+  def hasParent: Boolean = parent != null
+
   override def copy(extra: ParamMap): M = {
     // The default implementation of Params.copy doesn't work for models.
     throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")

http://git-wip-us.apache.org/repos/asf/spark/blob/fb902732/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 4376524..97f9749 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     assert(model.getRawPredictionCol === "rawPrediction")
     assert(model.getProbabilityCol === "probability")
     assert(model.intercept !== 0.0)
+    assert(model.hasParent)
   }
 
   test("logistic regression doesn't fit intercept when fitIntercept is off") {

http://git-wip-us.apache.org/repos/asf/spark/blob/fb902732/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 08f86fa..cdbbaca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -162,5 +162,7 @@ private object RandomForestClassifierSuite {
     val oldModelAsNew = RandomForestClassificationModel.fromOld(
       oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
+    assert(newModel.hasParent)
+    assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
