GitHub user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20786#discussion_r175957487
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala ---
    @@ -84,35 +86,73 @@ private[ml] object Node {
       /**
        * Create a new Node from the old Node format, recursively creating 
child nodes as needed.
        */
    -  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node 
= {
    +  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int],
    +    isClassification: Boolean): Node = {
         if (oldNode.isLeaf) {
           // TODO: Once the implementation has been moved to this API, then 
include sufficient
           //       statistics here.
    -      new LeafNode(prediction = oldNode.predict.predict,
    -        impurity = oldNode.impurity, impurityStats = null)
    +      if (isClassification) {
    +        new ClassificationLeafNode(prediction = oldNode.predict.predict,
    +          impurity = oldNode.impurity, impurityStats = null)
    +      } else {
    +        new RegressionLeafNode(prediction = oldNode.predict.predict,
    +          impurity = oldNode.impurity, impurityStats = null)
    +      }
         } else {
           val gain = if (oldNode.stats.nonEmpty) {
             oldNode.stats.get.gain
           } else {
             0.0
           }
    -      new InternalNode(prediction = oldNode.predict.predict, impurity = 
oldNode.impurity,
    -        gain = gain, leftChild = fromOld(oldNode.leftNode.get, 
categoricalFeatures),
    -        rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
    -        split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
    +      if (isClassification) {
    +        new ClassificationInternalNode(prediction = 
oldNode.predict.predict,
    +          impurity = oldNode.impurity, gain = gain,
    +          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, 
true)
    +            .asInstanceOf[ClassificationNode],
    +          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
true)
    +            .asInstanceOf[ClassificationNode],
    +          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
    +      } else {
    +        new RegressionInternalNode(prediction = oldNode.predict.predict,
    +          impurity = oldNode.impurity, gain = gain,
    +          leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures, 
false)
    +            .asInstanceOf[RegressionNode],
    +          rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures, 
false)
    +            .asInstanceOf[RegressionNode],
    +          split = Split.fromOld(oldNode.split.get, categoricalFeatures), 
impurityStats = null)
    +      }
         }
       }
     }
     
    -/**
    - * Decision tree leaf node.
    - * @param prediction  Prediction this node makes
    - * @param impurity  Impurity measure at this node (for training data)
    - */
    -class LeafNode private[ml] (
    -    override val prediction: Double,
    -    override val impurity: Double,
    -    override private[ml] val impurityStats: ImpurityCalculator) extends 
Node {
    +trait ClassificationNode extends Node {
    +
    +  @Since("2.4.0")
    +  def getLabelCount(label: Int): Double = {
    +    require(label >= 0 && label < impurityStats.stats.length)
    +    impurityStats.stats(label)
    +  }
    +}
    +
    +trait RegressionNode extends Node {
    +
    +  @Since("2.4.0")
    +  def getCount(): Double = impurityStats.stats(0)
    +
    +  @Since("2.4.0")
    +  def getSum(): Double = impurityStats.stats(1)
    +
    +  @Since("2.4.0")
    +  def getSquareSum(): Double = impurityStats.stats(2)
    --- End diff --
    
    rename: `getSquareSum` should be renamed to `getSumOfSquares`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to