Repository: spark
Updated Branches:
  refs/heads/master 669ade3a8 -> ebd899b8a


[SPARK-25321][ML] Revert SPARK-14681 to avoid API breaking change

## What changes were proposed in this pull request?

This is the same as #22492 but for the master branch. It reverts SPARK-14681 to 
avoid API-breaking changes.

cc: WeichenXu123

## How was this patch tested?

Existing unit tests.

Closes #22618 from mengxr/SPARK-25321.master.

Authored-by: WeichenXu <weichen...@databricks.com>
Signed-off-by: Dongjoon Hyun <dongj...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ebd899b8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ebd899b8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ebd899b8

Branch: refs/heads/master
Commit: ebd899b8a865395e6f1137163cb508086696879b
Parents: 669ade3
Author: WeichenXu <weichen...@databricks.com>
Authored: Sun Oct 7 10:06:44 2018 -0700
Committer: Dongjoon Hyun <dongj...@apache.org>
Committed: Sun Oct 7 10:06:44 2018 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |  14 +-
 .../spark/ml/classification/GBTClassifier.scala |   6 +-
 .../classification/RandomForestClassifier.scala |   6 +-
 .../ml/regression/DecisionTreeRegressor.scala   |  13 +-
 .../spark/ml/regression/GBTRegressor.scala      |   6 +-
 .../ml/regression/RandomForestRegressor.scala   |   6 +-
 .../scala/org/apache/spark/ml/tree/Node.scala   | 247 ++++---------------
 .../spark/ml/tree/impl/RandomForest.scala       |  10 +-
 .../org/apache/spark/ml/tree/treeModels.scala   |  36 +--
 .../DecisionTreeClassifierSuite.scala           |  31 +--
 .../ml/classification/GBTClassifierSuite.scala  |   4 +-
 .../RandomForestClassifierSuite.scala           |   5 +-
 .../regression/DecisionTreeRegressorSuite.scala |  14 --
 .../spark/ml/tree/impl/RandomForestSuite.scala  |  22 +-
 .../apache/spark/ml/tree/impl/TreeTests.scala   |  12 +-
 project/MimaExcludes.scala                      |   7 -
 16 files changed, 107 insertions(+), 332 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8a57bfc..6648e78 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -168,7 +168,7 @@ object DecisionTreeClassifier extends 
DefaultParamsReadable[DecisionTreeClassifi
 @Since("1.4.0")
 class DecisionTreeClassificationModel private[ml] (
     @Since("1.4.0")override val uid: String,
-    @Since("1.4.0")override val rootNode: ClassificationNode,
+    @Since("1.4.0")override val rootNode: Node,
     @Since("1.6.0")override val numFeatures: Int,
     @Since("1.5.0")override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, 
DecisionTreeClassificationModel]
@@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, 
numClasses: Int) =
+  private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
     this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
   override def predict(features: Vector): Double = {
@@ -279,9 +279,8 @@ object DecisionTreeClassificationModel extends 
MLReadable[DecisionTreeClassifica
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession, isClassification 
= true)
-      val model = new DecisionTreeClassificationModel(metadata.uid,
-        root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
+      val root = loadTreeNodes(path, metadata, sparkSession)
+      val model = new DecisionTreeClassificationModel(metadata.uid, root, 
numFeatures, numClasses)
       metadata.getAndSetParams(model)
       model
     }
@@ -296,10 +295,9 @@ object DecisionTreeClassificationModel extends 
MLReadable[DecisionTreeClassifica
     require(oldModel.algo == OldAlgo.Classification,
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
         s" DecisionTreeClassificationModel (new API).  Algo is: 
${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, 
isClassification = true)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
     // Can't infer number of features from old model, so default to -1
-    new DecisionTreeClassificationModel(uid,
-      rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
+    new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 33acd99..62cfa39 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -412,14 +412,14 @@ object GBTClassificationModel extends 
MLReadable[GBTClassificationModel] {
     override def load(path: String): GBTClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], 
treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, false)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
       val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
       val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
-            root.asInstanceOf[RegressionNode], numFeatures)
+          val tree =
+            new DecisionTreeRegressionModel(treeMetadata.uid, root, 
numFeatures)
           treeMetadata.getAndSetParams(tree)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 94887ac..5713238 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -313,15 +313,15 @@ object RandomForestClassificationModel extends 
MLReadable[RandomForestClassifica
     override def load(path: String): RandomForestClassificationModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, true)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numClasses = (metadata.metadata \ "numClasses").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeClassificationModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
-            root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
+          val tree =
+            new DecisionTreeClassificationModel(treeMetadata.uid, root, 
numFeatures, numClasses)
           treeMetadata.getAndSetParams(tree)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 018290f..6fa6562 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -160,7 +160,7 @@ object DecisionTreeRegressor extends 
DefaultParamsReadable[DecisionTreeRegressor
 @Since("1.4.0")
 class DecisionTreeRegressionModel private[ml] (
     override val uid: String,
-    override val rootNode: RegressionNode,
+    override val rootNode: Node,
     override val numFeatures: Int)
   extends PredictionModel[Vector, DecisionTreeRegressionModel]
   with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with 
Serializable {
@@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] (
    * Construct a decision tree regression model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
+  private[ml] def this(rootNode: Node, numFeatures: Int) =
     this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
 
   override def predict(features: Vector): Double = {
@@ -279,9 +279,8 @@ object DecisionTreeRegressionModel extends 
MLReadable[DecisionTreeRegressionMode
       implicit val format = DefaultFormats
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-      val root = loadTreeNodes(path, metadata, sparkSession, isClassification 
= false)
-      val model = new DecisionTreeRegressionModel(metadata.uid,
-        root.asInstanceOf[RegressionNode], numFeatures)
+      val root = loadTreeNodes(path, metadata, sparkSession)
+      val model = new DecisionTreeRegressionModel(metadata.uid, root, 
numFeatures)
       metadata.getAndSetParams(model)
       model
     }
@@ -296,8 +295,8 @@ object DecisionTreeRegressionModel extends 
MLReadable[DecisionTreeRegressionMode
     require(oldModel.algo == OldAlgo.Regression,
       s"Cannot convert non-regression DecisionTreeModel (old API) to" +
         s" DecisionTreeRegressionModel (new API).  Algo is: ${oldModel.algo}")
-    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, 
isClassification = false)
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
-    new DecisionTreeRegressionModel(uid, 
rootNode.asInstanceOf[RegressionNode], numFeatures)
+    new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 3305881..07f88d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -338,15 +338,15 @@ object GBTRegressionModel extends 
MLReadable[GBTRegressionModel] {
     override def load(path: String): GBTRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], 
treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, false)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
 
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
-          val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
-            root.asInstanceOf[RegressionNode], numFeatures)
+          val tree =
+            new DecisionTreeRegressionModel(treeMetadata.uid, root, 
numFeatures)
           treeMetadata.getAndSetParams(tree)
           tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 3587572..82bf66f 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -271,13 +271,13 @@ object RandomForestRegressionModel extends 
MLReadable[RandomForestRegressionMode
     override def load(path: String): RandomForestRegressionModel = {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], 
treeWeights: Array[Double]) =
-        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName, false)
+        EnsembleModelReadWrite.loadImpl(path, sparkSession, className, 
treeClassName)
       val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
       val numTrees = (metadata.metadata \ "numTrees").extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map { case 
(treeMetadata, root) =>
-        val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
-          root.asInstanceOf[RegressionNode], numFeatures)
+        val tree =
+          new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
         treeMetadata.getAndSetParams(tree)
         tree
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index 0242bc76..d30be45 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -17,16 +17,14 @@
 
 package org.apache.spark.ml.tree
 
-import org.apache.spark.annotation.Since
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats 
=> OldInformationGainStats,
-  Node => OldNode, Predict => OldPredict}
+import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats 
=> OldInformationGainStats, Node => OldNode, Predict => OldPredict}
 
 /**
  * Decision tree node interface.
  */
-sealed trait Node extends Serializable {
+sealed abstract class Node extends Serializable {
 
   // TODO: Add aggregate stats (once available).  This will happen after we 
move the DecisionTree
   //       code into the new API and deprecate the old API.  SPARK-3727
@@ -86,86 +84,35 @@ private[ml] object Node {
   /**
    * Create a new Node from the old Node format, recursively creating child 
nodes as needed.
    */
-  def fromOld(
-      oldNode: OldNode,
-      categoricalFeatures: Map[Int, Int],
-      isClassification: Boolean): Node = {
+  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
     if (oldNode.isLeaf) {
       // TODO: Once the implementation has been moved to this API, then 
include sufficient
       //       statistics here.
-      if (isClassification) {
-        new ClassificationLeafNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, impurityStats = null)
-      } else {
-        new RegressionLeafNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, impurityStats = null)
-      }
+      new LeafNode(prediction = oldNode.predict.predict,
+        impurity = oldNode.impurity, impurityStats = null)
     } else {
       val gain = if (oldNode.stats.nonEmpty) {
         oldNode.stats.get.gain
       } else {
         0.0
       }
-      if (isClassification) {
-        new ClassificationInternalNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, gain = gain,
-          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true)
-            .asInstanceOf[ClassificationNode],
-          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
true)
-            .asInstanceOf[ClassificationNode],
-          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
-      } else {
-        new RegressionInternalNode(prediction = oldNode.predict.predict,
-          impurity = oldNode.impurity, gain = gain,
-          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false)
-            .asInstanceOf[RegressionNode],
-          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
false)
-            .asInstanceOf[RegressionNode],
-          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
-      }
+      new InternalNode(prediction = oldNode.predict.predict, impurity = 
oldNode.impurity,
+        gain = gain, leftChild = fromOld(oldNode.leftNode.get, 
categoricalFeatures),
+        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+        split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
     }
   }
 }
 
-@Since("2.4.0")
-sealed trait ClassificationNode extends Node {
-
-  /**
-   * Get count of training examples for specified label in this node
-   * @param label label number in the range [0, numClasses)
-   */
-  @Since("2.4.0")
-  def getLabelCount(label: Int): Double = {
-    require(label >= 0 && label < impurityStats.stats.length,
-      "label should be in the range between 0 (inclusive) " +
-      s"and ${impurityStats.stats.length} (exclusive).")
-    impurityStats.stats(label)
-  }
-}
-
-@Since("2.4.0")
-sealed trait RegressionNode extends Node {
-
-  /** Number of training data points in this node */
-  @Since("2.4.0")
-  def getCount: Double = impurityStats.stats(0)
-
-  /** Sum over training data points of the labels in this node */
-  @Since("2.4.0")
-  def getSum: Double = impurityStats.stats(1)
-
-  /** Sum over training data points of the square of the labels in this node */
-  @Since("2.4.0")
-  def getSumOfSquares: Double = impurityStats.stats(2)
-}
-
-@Since("2.4.0")
-sealed trait LeafNode extends Node {
-
-  /** Prediction this node makes. */
-  def prediction: Double
-
-  def impurity: Double
+/**
+ * Decision tree leaf node.
+ * @param prediction  Prediction this node makes
+ * @param impurity  Impurity measure at this node (for training data)
+ */
+class LeafNode private[ml] (
+    override val prediction: Double,
+    override val impurity: Double,
+    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
 
   override def toString: String =
     s"LeafNode(prediction = $prediction, impurity = $impurity)"
@@ -188,58 +135,32 @@ sealed trait LeafNode extends Node {
 
   override private[ml] def maxSplitFeatureIndex(): Int = -1
 
-}
-
-/**
- * Decision tree leaf node for classification.
- */
-@Since("2.4.0")
-class ClassificationLeafNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends ClassificationNode with LeafNode {
-
   override private[tree] def deepCopy(): Node = {
-    new ClassificationLeafNode(prediction, impurity, impurityStats)
+    new LeafNode(prediction, impurity, impurityStats)
   }
 }
 
 /**
- * Decision tree leaf node for regression.
+ * Internal Decision Tree node.
+ * @param prediction  Prediction this node would make if it were a leaf node
+ * @param impurity  Impurity measure at this node (for training data)
+ * @param gain Information gain value. Values less than 0 indicate missing 
values;
+ *             this quirk will be removed with future updates.
+ * @param leftChild  Left-hand child node
+ * @param rightChild  Right-hand child node
+ * @param split  Information about the test used to split to the left or right 
child.
  */
-@Since("2.4.0")
-class RegressionLeafNode private[ml] (
+class InternalNode private[ml] (
     override val prediction: Double,
     override val impurity: Double,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends RegressionNode with LeafNode {
-
-  override private[tree] def deepCopy(): Node = {
-    new RegressionLeafNode(prediction, impurity, impurityStats)
-  }
-}
-
-/**
- * Internal Decision Tree node.
- */
-@Since("2.4.0")
-sealed trait InternalNode extends Node {
-
-  /**
-   * Information gain value. Values less than 0 indicate missing values;
-   * this quirk will be removed with future updates.
-   */
-  def gain: Double
-
-  /** Left-hand child node */
-  def leftChild: Node
-
-  /** Right-hand child node */
-  def rightChild: Node
+    val gain: Double,
+    val leftChild: Node,
+    val rightChild: Node,
+    val split: Split,
+    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
 
-  /** Information about the test used to split to the left or right child. */
-  def split: Split
+  // Note to developers: The constructor argument impurityStats should be 
reconsidered before we
+  //                     make the constructor public.  We may be able to 
improve the representation.
 
   override def toString: String = {
     s"InternalNode(prediction = $prediction, impurity = $impurity, split = 
$split)"
@@ -284,6 +205,11 @@ sealed trait InternalNode extends Node {
     math.max(split.featureIndex,
       math.max(leftChild.maxSplitFeatureIndex(), 
rightChild.maxSplitFeatureIndex()))
   }
+
+  override private[tree] def deepCopy(): Node = {
+    new InternalNode(prediction, impurity, gain, leftChild.deepCopy(), 
rightChild.deepCopy(),
+      split, impurityStats)
+  }
 }
 
 private object InternalNode {
@@ -315,57 +241,6 @@ private object InternalNode {
 }
 
 /**
- * Internal Decision Tree node for regression.
- */
-@Since("2.4.0")
-class ClassificationInternalNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override val gain: Double,
-    override val leftChild: ClassificationNode,
-    override val rightChild: ClassificationNode,
-    override val split: Split,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends ClassificationNode with InternalNode {
-
-  // Note to developers: The constructor argument impurityStats should be 
reconsidered before we
-  //                     make the constructor public.  We may be able to 
improve the representation.
-
-  override private[tree] def deepCopy(): Node = {
-    new ClassificationInternalNode(prediction, impurity, gain,
-      leftChild.deepCopy().asInstanceOf[ClassificationNode],
-      rightChild.deepCopy().asInstanceOf[ClassificationNode],
-      split, impurityStats)
-  }
-}
-
-/**
- * Internal Decision Tree node for regression.
- */
-@Since("2.4.0")
-class RegressionInternalNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override val gain: Double,
-    override val leftChild: RegressionNode,
-    override val rightChild: RegressionNode,
-    override val split: Split,
-    override private[ml] val impurityStats: ImpurityCalculator)
-  extends RegressionNode with InternalNode {
-
-  // Note to developers: The constructor argument impurityStats should be 
reconsidered before we
-  //                     make the constructor public.  We may be able to 
improve the representation.
-
-  override private[tree] def deepCopy(): Node = {
-    new RegressionInternalNode(prediction, impurity, gain,
-      leftChild.deepCopy().asInstanceOf[RegressionNode],
-      rightChild.deepCopy().asInstanceOf[RegressionNode],
-      split, impurityStats)
-  }
-}
-
-
-/**
  * Version of a node used in learning.  This uses vars so that we can modify 
nodes as we split the
  * tree by adding children, etc.
  *
@@ -390,52 +265,30 @@ private[tree] class LearningNode(
     var isLeaf: Boolean,
     var stats: ImpurityStats) extends Serializable {
 
-  def toNode(isClassification: Boolean): Node = toNode(isClassification, prune 
= true)
-
-  def toClassificationNode(prune: Boolean = true): ClassificationNode = {
-    toNode(true, prune).asInstanceOf[ClassificationNode]
-  }
-
-  def toRegressionNode(prune: Boolean = true): RegressionNode = {
-    toNode(false, prune).asInstanceOf[RegressionNode]
-  }
+  def toNode: Node = toNode(prune = true)
 
   /**
    * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any 
children.
    */
-  def toNode(isClassification: Boolean, prune: Boolean): Node = {
+  def toNode(prune: Boolean = true): Node = {
 
     if (!leftChild.isEmpty || !rightChild.isEmpty) {
       assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && 
stats != null,
         "Unknown error during Decision Tree learning.  Could not convert 
LearningNode to Node.")
-      (leftChild.get.toNode(isClassification, prune),
-       rightChild.get.toNode(isClassification, prune)) match {
+      (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
         case (l: LeafNode, r: LeafNode) if prune && l.prediction == 
r.prediction =>
-          if (isClassification) {
-            new ClassificationLeafNode(l.prediction, stats.impurity, 
stats.impurityCalculator)
-          } else {
-            new RegressionLeafNode(l.prediction, stats.impurity, 
stats.impurityCalculator)
-          }
+          new LeafNode(l.prediction, stats.impurity, stats.impurityCalculator)
         case (l, r) =>
-          if (isClassification) {
-            new ClassificationInternalNode(stats.impurityCalculator.predict, 
stats.impurity,
-              stats.gain, l.asInstanceOf[ClassificationNode], 
r.asInstanceOf[ClassificationNode],
-              split.get, stats.impurityCalculator)
-          } else {
-            new RegressionInternalNode(stats.impurityCalculator.predict, 
stats.impurity, stats.gain,
-              l.asInstanceOf[RegressionNode], r.asInstanceOf[RegressionNode],
-              split.get, stats.impurityCalculator)
-          }
+          new InternalNode(stats.impurityCalculator.predict, stats.impurity, 
stats.gain,
+            l, r, split.get, stats.impurityCalculator)
       }
     } else {
-      // Here we want to keep same behavior with the old 
mllib.DecisionTreeModel
-      val impurity = if (stats.valid) stats.impurity else -1.0
-      if (isClassification) {
-        new ClassificationLeafNode(stats.impurityCalculator.predict, impurity,
+      if (stats.valid) {
+        new LeafNode(stats.impurityCalculator.predict, stats.impurity,
           stats.impurityCalculator)
       } else {
-        new RegressionLeafNode(stats.impurityCalculator.predict, impurity,
-          stats.impurityCalculator)
+        // Here we want to keep same behavior with the old 
mllib.DecisionTreeModel
+        new LeafNode(stats.impurityCalculator.predict, -1.0, 
stats.impurityCalculator)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 4cdd172..822abd2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -226,23 +226,23 @@ private[spark] object RandomForest extends Logging with 
Serializable {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, 
rootNode.toClassificationNode(prune),
-              numFeatures, strategy.getNumClasses)
+            new DecisionTreeClassificationModel(uid, rootNode.toNode(prune), 
numFeatures,
+              strategy.getNumClasses)
           }
         } else {
           topNodes.map { rootNode =>
-            new DecisionTreeRegressionModel(uid, 
rootNode.toRegressionNode(prune), numFeatures)
+            new DecisionTreeRegressionModel(uid, rootNode.toNode(prune), 
numFeatures)
           }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new 
DecisionTreeClassificationModel(rootNode.toClassificationNode(prune), 
numFeatures,
+            new DecisionTreeClassificationModel(rootNode.toNode(prune), 
numFeatures,
               strategy.getNumClasses)
           }
         } else {
           topNodes.map(rootNode =>
-            new DecisionTreeRegressionModel(rootNode.toRegressionNode(prune), 
numFeatures))
+            new DecisionTreeRegressionModel(rootNode.toNode(prune), 
numFeatures))
         }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index f027b14..4aa4c36 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -219,10 +219,8 @@ private[ml] object TreeEnsembleModel {
         importances.changeValue(feature, scaledGain, _ + scaledGain)
         computeFeatureImportance(n.leftChild, importances)
         computeFeatureImportance(n.rightChild, importances)
-      case _: LeafNode =>
+      case n: LeafNode =>
       // do nothing
-      case _ =>
-        throw new IllegalArgumentException(s"Unknown node type: 
${node.getClass.toString}")
     }
   }
 
@@ -319,8 +317,6 @@ private[ml] object DecisionTreeModelReadWrite {
         (Seq(NodeData(id, node.prediction, node.impurity, 
node.impurityStats.stats,
           -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
           id)
-      case _ =>
-        throw new IllegalArgumentException(s"Unknown node type: 
${node.getClass.toString}")
     }
   }
 
@@ -331,7 +327,7 @@ private[ml] object DecisionTreeModelReadWrite {
   def loadTreeNodes(
       path: String,
       metadata: DefaultParamsReader.Metadata,
-      sparkSession: SparkSession, isClassification: Boolean): Node = {
+      sparkSession: SparkSession): Node = {
     import sparkSession.implicits._
     implicit val format = DefaultFormats
 
@@ -343,7 +339,7 @@ private[ml] object DecisionTreeModelReadWrite {
 
     val dataPath = new Path(path, "data").toString
     val data = sparkSession.read.parquet(dataPath).as[NodeData]
-    buildTreeFromNodes(data.collect(), impurityType, isClassification)
+    buildTreeFromNodes(data.collect(), impurityType)
   }
 
   /**
@@ -352,8 +348,7 @@ private[ml] object DecisionTreeModelReadWrite {
    * @param impurityType  Impurity type for this tree
    * @return Root node of reconstructed tree
    */
-  def buildTreeFromNodes(data: Array[NodeData], impurityType: String,
-      isClassification: Boolean): Node = {
+  def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
     // Load all nodes, sorted by ID.
     val nodes = data.sortBy(_.id)
     // Sanity checks; could remove
@@ -369,21 +364,10 @@ private[ml] object DecisionTreeModelReadWrite {
       val node = if (n.leftChild != -1) {
         val leftChild = finalNodes(n.leftChild)
         val rightChild = finalNodes(n.rightChild)
-        if (isClassification) {
-          new ClassificationInternalNode(n.prediction, n.impurity, n.gain,
-            leftChild.asInstanceOf[ClassificationNode], 
rightChild.asInstanceOf[ClassificationNode],
-            n.split.getSplit, impurityStats)
-        } else {
-          new RegressionInternalNode(n.prediction, n.impurity, n.gain,
-            leftChild.asInstanceOf[RegressionNode], 
rightChild.asInstanceOf[RegressionNode],
-            n.split.getSplit, impurityStats)
-        }
+        new InternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild,
+          n.split.getSplit, impurityStats)
       } else {
-        if (isClassification) {
-          new ClassificationLeafNode(n.prediction, n.impurity, impurityStats)
-        } else {
-          new RegressionLeafNode(n.prediction, n.impurity, impurityStats)
-        }
+        new LeafNode(n.prediction, n.impurity, impurityStats)
       }
       finalNodes(n.id) = node
     }
@@ -437,8 +421,7 @@ private[ml] object EnsembleModelReadWrite {
       path: String,
       sql: SparkSession,
       className: String,
-      treeClassName: String,
-      isClassification: Boolean): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
+      treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
     import sql.implicits._
     implicit val format = DefaultFormats
     val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, 
className)
@@ -466,8 +449,7 @@ private[ml] object EnsembleModelReadWrite {
     val rootNodesRDD: RDD[(Int, Node)] =
       nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
         case (treeID: Int, nodeData: Iterable[NodeData]) =>
-          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(
-            nodeData.toArray, impurityType, isClassification)
+          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
       }
     val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
     (metadata, treesMetadata.zip(rootNodes), treesWeights)

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index d3dbb4e..2930f49 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.ClassificationLeafNode
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -61,8 +61,7 @@ class DecisionTreeClassifierSuite extends MLTest with 
DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new DecisionTreeClassifier)
-    val model = new DecisionTreeClassificationModel("dtc",
-      new ClassificationLeafNode(0.0, 0.0, null), 1, 2)
+    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 
0.0, null), 1, 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -376,32 +375,6 @@ class DecisionTreeClassifierSuite extends MLTest with 
DefaultReadWriteTest {
 
     testDefaultReadWrite(model)
   }
-
-  test("label/impurity stats") {
-    val arr = Array(
-      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
-      LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
-      LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
-    val rdd = sc.parallelize(arr)
-    val df = TreeTests.setMetadata(rdd, Map.empty[Int, Int], 2)
-    val dt1 = new DecisionTreeClassifier()
-      .setImpurity("entropy")
-      .setMaxDepth(2)
-      .setMinInstancesPerNode(2)
-    val model1 = dt1.fit(df)
-
-    val rootNode1 = model1.rootNode
-    assert(Array(rootNode1.getLabelCount(0), rootNode1.getLabelCount(1)) === 
Array(2.0, 1.0))
-
-    val dt2 = new DecisionTreeClassifier()
-      .setImpurity("gini")
-      .setMaxDepth(2)
-      .setMinInstancesPerNode(2)
-    val model2 = dt2.fit(df)
-
-    val rootNode2 = model2.rootNode
-    assert(Array(rootNode2.getLabelCount(0), rootNode2.getLabelCount(1)) === 
Array(2.0, 1.0))
-  }
 }
 
 private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index e6d2a8e..3049776 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.RegressionLeafNode
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
@@ -70,7 +70,7 @@ class GBTClassifierSuite extends MLTest with 
DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
-      Array(new DecisionTreeRegressionModel("dtr", new RegressionLeafNode(0.0, 
0.0, null), 1)),
+      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, 
null), 1)),
       Array(1.0), 1, 2)
     ParamsSuite.checkParams(model)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 3062aa9..ba4a9cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.ClassificationLeafNode
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -71,8 +71,7 @@ class RandomForestClassifierSuite extends MLTest with 
DefaultReadWriteTest {
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc",
-        new ClassificationLeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
+      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, 
null), 1, 2)), 2, 2)
     ParamsSuite.checkParams(model)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 9ae2733..29a4383 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -191,20 +191,6 @@ class DecisionTreeRegressorSuite extends MLTest with 
DefaultReadWriteTest {
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
-
-  test("label/impurity stats") {
-    val categoricalFeatures = Map(0 -> 2, 1 -> 2)
-    val df = TreeTests.setMetadata(categoricalDataPointsRDD, 
categoricalFeatures, numClasses = 0)
-    val dtr = new DecisionTreeRegressor()
-      .setImpurity("variance")
-      .setMaxDepth(2)
-      .setMaxBins(8)
-    val model = dtr.fit(df)
-    val statInfo = model.rootNode
-
-    assert(statInfo.getCount == 1000.0 && statInfo.getSum == 600.0
-      && statInfo.getSumOfSquares == 600.0)
-  }
 }
 
 private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 4dbbd75..743dacf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -340,8 +340,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
 
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode(isClassification = true).prediction 
=== 0.0)
-    assert(topNode.rightChild.get.toNode(isClassification = true).prediction 
=== 1.0)
+    assert(topNode.leftChild.get.toNode.prediction === 0.0)
+    assert(topNode.rightChild.get.toNode.prediction === 1.0)
     assert(topNode.leftChild.get.stats.impurity === 0.0)
     assert(topNode.rightChild.get.stats.impurity === 0.0)
   }
@@ -382,8 +382,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     assert(topNode.stats.impurity > 0.0)
 
     // set impurity and predict for child nodes
-    assert(topNode.leftChild.get.toNode(isClassification = true).prediction 
=== 0.0)
-    assert(topNode.rightChild.get.toNode(isClassification = true).prediction 
=== 1.0)
+    assert(topNode.leftChild.get.toNode.prediction === 0.0)
+    assert(topNode.rightChild.get.toNode.prediction === 1.0)
     assert(topNode.leftChild.get.stats.impurity === 0.0)
     assert(topNode.rightChild.get.stats.impurity === 0.0)
   }
@@ -582,18 +582,18 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
                 left  right
      */
     val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
-    val left = new ClassificationLeafNode(0.0, leftImp.calculate(), leftImp)
+    val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
 
     val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
-    val right = new ClassificationLeafNode(2.0, rightImp.calculate(), rightImp)
+    val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
 
-    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 
0.5), true)
+    val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 
0.5))
     val parentImp = parent.impurityStats
 
     val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
-    val left2 = new ClassificationLeafNode(0.0, left2Imp.calculate(), left2Imp)
+    val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
 
-    val grandParent = TreeTests.buildParentNode(left2, parent, new 
ContinuousSplit(1, 1.0), true)
+    val grandParent = TreeTests.buildParentNode(left2, parent, new 
ContinuousSplit(1, 1.0))
     val grandImp = grandParent.impurityStats
 
     // Test feature importance computed at different subtrees.
@@ -618,8 +618,8 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
     // Forest consisting of (full tree) + (internal node with 2 leafs)
     val trees = Array(parent, grandParent).map { root =>
-      new 
DecisionTreeClassificationModel(root.asInstanceOf[ClassificationNode],
-        numFeatures = 2, numClasses = 3).asInstanceOf[DecisionTreeModel]
+      new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 
3)
+        .asInstanceOf[DecisionTreeModel]
     }
     val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
     val tree2norm = feature0importance + feature1importance

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 3f03d90..b6894b3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -159,7 +159,7 @@ private[ml] object TreeTests extends SparkFunSuite {
    * @param split  Split for parent node
    * @return  Parent node with children attached
    */
-  def buildParentNode(left: Node, right: Node, split: Split, isClassification: Boolean): Node = {
+  def buildParentNode(left: Node, right: Node, split: Split): Node = {
     val leftImp = left.impurityStats
     val rightImp = right.impurityStats
     val parentImp = leftImp.copy.add(rightImp)
@@ -168,15 +168,7 @@ private[ml] object TreeTests extends SparkFunSuite {
     val gain = parentImp.calculate() -
       (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
     val pred = parentImp.predict
-    if (isClassification) {
-      new ClassificationInternalNode(pred, parentImp.calculate(), gain,
-        left.asInstanceOf[ClassificationNode], 
right.asInstanceOf[ClassificationNode],
-        split, parentImp)
-    } else {
-      new RegressionInternalNode(pred, parentImp.calculate(), gain,
-        left.asInstanceOf[RegressionNode], right.asInstanceOf[RegressionNode],
-        split, parentImp)
-    }
+    new InternalNode(pred, parentImp.calculate(), gain, left, right, split, parentImp)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ebd899b8/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index a931738..0b074fb 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -103,13 +103,6 @@ object MimaExcludes {
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="),
 
-    // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes
-    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"),
-    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
-    ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"),
-    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"),
-
     // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
     ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

Reply via email to