GitHub user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/20632#discussion_r171071499
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
@@ -266,15 +265,24 @@ private[tree] class LearningNode(
var isLeaf: Boolean,
var stats: ImpurityStats) extends Serializable {
+ def toNode: Node = toNode(prune = true)
+
/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
- def toNode: Node = {
- if (leftChild.nonEmpty) {
- assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
+ def toNode(prune: Boolean = true): Node = {
+
+ if (!leftChild.isEmpty || !rightChild.isEmpty) {
+ assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
- new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
- leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
+ (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match {
+ // when both children make the same prediction, collapse into a single leaf
--- End diff --
On second thought, I'm not sure the comment is useful, since it just
restates what the code does. I vote that we either drop the comment or explain
why this happens, i.e. a split can reduce impurity without changing the
prediction, so both children can end up predicting the same value.
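
To make the "why" concrete, here is a minimal sketch of how a split can have positive gain while both children predict the same class. It is not Spark's code: `SimpleNode`, `Leaf`, `Internal`, and `PruneSketch` are hypothetical stand-ins, Gini impurity is assumed as the impurity measure, and the class counts are made up for illustration.

```scala
// Hypothetical, simplified stand-ins for the tree node classes (not Spark's API).
sealed trait SimpleNode { def prediction: Double }
case class Leaf(prediction: Double) extends SimpleNode
case class Internal(prediction: Double, left: SimpleNode, right: SimpleNode)
  extends SimpleNode

object PruneSketch {
  // Gini impurity computed from per-class counts.
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  // Predict the majority class (index of the largest count).
  def predict(counts: Array[Double]): Double =
    counts.indexOf(counts.max).toDouble

  // Impurity decrease of a split: parent impurity minus the
  // size-weighted impurity of the two children.
  def gain(parent: Array[Double], left: Array[Double], right: Array[Double]): Double = {
    val n = parent.sum
    gini(parent) - (left.sum / n) * gini(left) - (right.sum / n) * gini(right)
  }

  // Collapse an internal node into a leaf when both (already pruned)
  // children are leaves that make the same prediction.
  def prune(node: SimpleNode): SimpleNode = node match {
    case Internal(p, l, r) =>
      (prune(l), prune(r)) match {
        case (Leaf(a), Leaf(b)) if a == b => Leaf(a)
        case (pl, pr) => Internal(p, pl, pr)
      }
    case leaf: Leaf => leaf
  }

  def main(args: Array[String]): Unit = {
    // Made-up class counts: the split separates a purer group from a
    // noisier one, but class 0 stays the majority on both sides.
    val parent = Array(80.0, 20.0)
    val left = Array(50.0, 5.0)
    val right = Array(30.0, 15.0)

    println(s"gain = ${gain(parent, left, right)}")
    println(s"left predicts ${predict(left)}, right predicts ${predict(right)}")

    val tree = Internal(predict(parent), Leaf(predict(left)), Leaf(predict(right)))
    println(s"pruned = ${prune(tree)}")
  }
}
```

With these counts the gain is positive, so the learner would accept the split, yet both children predict class 0 and the subtree collapses to a single leaf without changing any prediction.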
---