Repository: spark Updated Branches: refs/heads/master e88476c8c -> 6c5a837c5
[SPARK-12301][ML] Made all tree and ensemble classes not final ## What changes were proposed in this pull request? There have been continuing requests (e.g., SPARK-7131) for allowing users to extend and modify MLlib models and algorithms. This PR makes tree and ensemble classes, Node types, and Split types in spark.ml no longer final. This matches most other spark.ml algorithms. Constructors for models are still private since we may need to refactor how stats are maintained in tree nodes. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley <[email protected]> Closes #12711 from jkbradley/final-trees. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6c5a837c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6c5a837c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6c5a837c Branch: refs/heads/master Commit: 6c5a837c509233d4008cffeaede111f17fea5289 Parents: e88476c Author: Joseph K. Bradley <[email protected]> Authored: Tue Apr 26 14:44:39 2016 -0700 Committer: Joseph K. Bradley <[email protected]> Committed: Tue Apr 26 14:44:39 2016 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/classification/DecisionTreeClassifier.scala | 4 ++-- .../scala/org/apache/spark/ml/classification/GBTClassifier.scala | 4 ++-- .../apache/spark/ml/classification/RandomForestClassifier.scala | 4 ++-- .../org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 4 ++-- .../main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 4 ++-- .../org/apache/spark/ml/regression/RandomForestRegressor.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 1943a4a..ecb218e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.Dataset */ @Since("1.4.0") @Experimental -final class DecisionTreeClassifier @Since("1.4.0") ( +class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeClassifierParams with DefaultParamsWritable { @@ -138,7 +138,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi */ @Since("1.4.0") @Experimental -final class DecisionTreeClassificationModel private[ml] ( +class DecisionTreeClassificationModel private[ml] ( @Since("1.4.0")override val uid: String, @Since("1.4.0")override val rootNode: Node, @Since("1.6.0")override val numFeatures: Int, http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 1bd6dae..e736f01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -57,7 +57,7 @@ import org.apache.spark.sql.functions._ */ @Since("1.4.0") @Experimental -final class GBTClassifier @Since("1.4.0") ( +class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTClassifierParams with DefaultParamsWritable with Logging { @@ -170,7 +170,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { */ @Since("1.6.0") @Experimental -final class GBTClassificationModel private[ml]( +class GBTClassificationModel private[ml]( @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index c04ecc8..28364c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.functions._ */ @Since("1.4.0") @Experimental -final class RandomForestClassifier @Since("1.4.0") ( +class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestClassifierParams with DefaultParamsWritable { @@ -149,7 +149,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi */ @Since("1.4.0") @Experimental -final class RandomForestClassificationModel private[ml] ( +class RandomForestClassificationModel private[ml] ( @Since("1.5.0") override val uid: String, private val _trees: Array[DecisionTreeClassificationModel], @Since("1.6.0") override val numFeatures: Int, http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index c04c416..339a8cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.functions._ */ @Since("1.4.0") @Experimental -final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) +class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeRegressorParams with DefaultParamsWritable { @@ -129,7 +129,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor */ @Since("1.4.0") @Experimental -final class DecisionTreeRegressionModel private[ml] ( +class DecisionTreeRegressionModel private[ml] ( override val uid: String, override val rootNode: Node, override val numFeatures: Int) http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index da51cb7..c41fb4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -57,7 +57,7 @@ import org.apache.spark.sql.functions._ */ @Since("1.4.0") @Experimental -final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) +class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTRegressorParams with DefaultParamsWritable with Logging { @@ -157,7 +157,7 @@ object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { */ @Since("1.4.0") @Experimental -final class GBTRegressionModel private[ml]( +class GBTRegressionModel private[ml]( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 8eaed8b..b6ab2fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.functions._ */ @Since("1.4.0") @Experimental -final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) +class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestRegressorParams with DefaultParamsWritable { @@ -137,7 +137,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor */ @Since("1.4.0") @Experimental -final class RandomForestRegressionModel private[ml] ( +class RandomForestRegressionModel private[ml] ( override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index b5cb378..f71d28c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -115,7 +115,7 @@ private[ml] object Node { * @param impurity Impurity measure at this node (for training data) */ @DeveloperApi -final class LeafNode private[ml] ( +class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double, override private[ml] val impurityStats: ImpurityCalculator) extends Node { @@ -158,7 +158,7 @@ final class LeafNode private[ml] ( * @param split Information about the test used to split to the left or right child. */ @DeveloperApi -final class InternalNode private[ml] ( +class InternalNode private[ml] ( override val prediction: Double, override val impurity: Double, val gain: Double, http://git-wip-us.apache.org/repos/asf/spark/blob/6c5a837c/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 5d11ed0..a428748 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -75,7 +75,7 @@ private[tree] object Split { * @param numCategories Number of categories for this feature. */ @DeveloperApi -final class CategoricalSplit private[ml] ( +class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], @Since("2.0.0") val numCategories: Int) @@ -160,7 +160,7 @@ final class CategoricalSplit private[ml] ( * Otherwise, it goes right. */ @DeveloperApi -final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) +class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) extends Split { override private[ml] def shouldGoLeft(features: Vector): Boolean = { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
