Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20786#discussion_r178202685
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
    @@ -84,35 +86,85 @@ private[ml] object Node {
       /**
        * Create a new Node from the old Node format, recursively creating 
child nodes as needed.
        */
    -  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node 
= {
    +  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int],
    +    isClassification: Boolean): Node = {
         if (oldNode.isLeaf) {
           // TODO: Once the implementation has been moved to this API, then 
include sufficient
           //       statistics here.
    -      new LeafNode(prediction = oldNode.predict.predict,
    -        impurity = oldNode.impurity, impurityStats = null)
    +      if (isClassification) {
    +        new ClassificationLeafNode(prediction = oldNode.predict.predict,
    +          impurity = oldNode.impurity, impurityStats = null)
    +      } else {
    +        new RegressionLeafNode(prediction = oldNode.predict.predict,
    +          impurity = oldNode.impurity, impurityStats = null)
    +      }
         } else {
           val gain = if (oldNode.stats.nonEmpty) {
             oldNode.stats.get.gain
           } else {
             0.0
           }
    -      new InternalNode(prediction = oldNode.predict.predict, impurity = 
oldNode.impurity,
    -        gain = gain, leftChild = fromOld(oldNode.leftNode.get, 
categoricalFeatures),
    -        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
    -        split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
    +      if (isClassification) {
    +        new ClassificationInternalNode(prediction = 
oldNode.predict.predict,
    +          impurity = oldNode.impurity, gain = gain,
    +          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, 
true)
    +            .asInstanceOf[ClassificationNode],
    +          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
true)
    +            .asInstanceOf[ClassificationNode],
    +          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
    +      } else {
    +        new RegressionInternalNode(prediction = oldNode.predict.predict,
    +          impurity = oldNode.impurity, gain = gain,
    +          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, 
false)
    +            .asInstanceOf[RegressionNode],
    +          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
false)
    +            .asInstanceOf[RegressionNode],
    +          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
    +      }
         }
       }
     }
     
    -/**
    - * Decision tree leaf node.
    - * @param prediction  Prediction this node makes
    - * @param impurity  Impurity measure at this node (for training data)
    - */
    -class LeafNode private[ml] (
    -    override val prediction: Double,
    -    override val impurity: Double,
    -    override private[ml] val impurityStats: ImpurityCalculator) extends 
Node {
    +@Since("2.4.0")
    +trait ClassificationNode extends Node {
    +
    +  /**
    +   * Get count for specified label in this node
    +   * @param label label number in the range [0, numClasses)
    +   */
    +  @Since("2.4.0")
    +  def getLabelCount(label: Int): Double = {
    +    require(label >= 0 && label < impurityStats.stats.length,
    +      "label should be in the rangle between 0 (inclusive) " +
    +      s"and ${impurityStats.stats.length} (exclusive).")
    +    impurityStats.stats(label)
    +  }
    +}
    +
    +@Since("2.4.0")
    +trait RegressionNode extends Node {
    +
    +  /** Number of data points in this node */
    +  @Since("2.4.0")
    +  def getCount: Double = impurityStats.stats(0)
    +
    +  /** Sum of data points labels in this node */
    +  @Since("2.4.0")
    +  def getSum: Double = impurityStats.stats(1)
    +
    +  /** Sum of data points label squares in this node */
    --- End diff --
    
    Suggested rewording of this doc comment: "Sum of data points label squares" -> "Sum over training data points of the square of the labels"


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to