Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20786#discussion_r175957326
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -84,35 +86,73 @@ private[ml] object Node {
/**
* Create a new Node from the old Node format, recursively creating
child nodes as needed.
*/
- def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node
= {
+ def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int],
+ isClassification: Boolean): Node = {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then
include sufficient
// statistics here.
- new LeafNode(prediction = oldNode.predict.predict,
- impurity = oldNode.impurity, impurityStats = null)
+ if (isClassification) {
+ new ClassificationLeafNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, impurityStats = null)
+ } else {
+ new RegressionLeafNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, impurityStats = null)
+ }
} else {
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
} else {
0.0
}
- new InternalNode(prediction = oldNode.predict.predict, impurity =
oldNode.impurity,
- gain = gain, leftChild = fromOld(oldNode.leftNode.get,
categoricalFeatures),
- rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
- split = Split.fromOld(oldNode.split.get, categoricalFeatures),
impurityStats = null)
+ if (isClassification) {
+ new ClassificationInternalNode(prediction =
oldNode.predict.predict,
+ impurity = oldNode.impurity, gain = gain,
+ leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures,
true)
+ .asInstanceOf[ClassificationNode],
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures,
true)
+ .asInstanceOf[ClassificationNode],
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures),
impurityStats = null)
+ } else {
+ new RegressionInternalNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, gain = gain,
+ leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures,
false)
+ .asInstanceOf[RegressionNode],
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures,
false)
+ .asInstanceOf[RegressionNode],
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures),
impurityStats = null)
+ }
}
}
}
-/**
- * Decision tree leaf node.
- * @param prediction Prediction this node makes
- * @param impurity Impurity measure at this node (for training data)
- */
-class LeafNode private[ml] (
- override val prediction: Double,
- override val impurity: Double,
- override private[ml] val impurityStats: ImpurityCalculator) extends
Node {
+trait ClassificationNode extends Node {
+
+ @Since("2.4.0")
+ def getLabelCount(label: Int): Double = {
--- End diff --
Please add a doc string to this method (`getLabelCount`), and likewise to the corresponding regression node methods.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]