Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/20786#discussion_r178202596
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -84,35 +86,85 @@ private[ml] object Node {
   /**
    * Create a new Node from the old Node format, recursively creating child nodes as needed.
    */
-  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int],
+      isClassification: Boolean): Node = {
     if (oldNode.isLeaf) {
       // TODO: Once the implementation has been moved to this API, then include sufficient
       // statistics here.
-      new LeafNode(prediction = oldNode.predict.predict,
-        impurity = oldNode.impurity, impurityStats = null)
+      if (isClassification) {
+        new ClassificationLeafNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, impurityStats = null)
+      } else {
+        new RegressionLeafNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, impurityStats = null)
+      }
     } else {
       val gain = if (oldNode.stats.nonEmpty) {
         oldNode.stats.get.gain
       } else {
         0.0
       }
-      new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
-        gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
-        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
-        split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+      if (isClassification) {
+        new ClassificationInternalNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, gain = gain,
+          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, true)
+            .asInstanceOf[ClassificationNode],
+          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, true)
+            .asInstanceOf[ClassificationNode],
+          split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+      } else {
+        new RegressionInternalNode(prediction = oldNode.predict.predict,
+          impurity = oldNode.impurity, gain = gain,
+          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, false)
+            .asInstanceOf[RegressionNode],
+          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, false)
+            .asInstanceOf[RegressionNode],
+          split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
+      }
     }
   }
 }

-/**
- * Decision tree leaf node.
- * @param prediction Prediction this node makes
- * @param impurity Impurity measure at this node (for training data)
- */
-class LeafNode private[ml] (
-    override val prediction: Double,
-    override val impurity: Double,
-    override private[ml] val impurityStats: ImpurityCalculator) extends Node {
+@Since("2.4.0")
+trait ClassificationNode extends Node {
+
+  /**
+   * Get the count for the specified label in this node.
+   * @param label label number in the range [0, numClasses)
+   */
+  @Since("2.4.0")
+  def getLabelCount(label: Int): Double = {
+    require(label >= 0 && label < impurityStats.stats.length,
+      "label should be in the range between 0 (inclusive) " +
+      s"and ${impurityStats.stats.length} (exclusive).")
+    impurityStats.stats(label)
+  }
+}
+
+@Since("2.4.0")
+trait RegressionNode extends Node {
+
+  /** Number of data points in this node */
--- End diff --
"data points" -> "training data points" (for other methods too)