Github user asolimando commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r168925279
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -640,4 +689,96 @@ private object RandomForestSuite {
         val (indices, values) = map.toSeq.sortBy(_._1).unzip
         Vectors.sparse(size, indices.toArray, values.toArray)
       }
    +
    +  /**
    +   * Draws a random class label in the range [0, numClasses) from `rnd`,
    +   * returned as a Double so it can be used directly as a LabeledPoint label.
    +   */
    +  private def generateLabel(rnd: Random, numClasses: Int): Double =
    +    rnd.nextInt(numClasses).toDouble
    +
    +  /**
    +   * Generate a numeric value in the range [numericMin, numericMax].
    +   *
    +   * The interval width is `numericMax - numericMin`, so the result stays within the
    +   * bounds even when both limits have the same sign (e.g. [10, 20]).
    +   */
    +  private def generateNumericValue(
    +      rnd: Random,
    +      numericMin: Double,
    +      numericMax: Double): Double = {
    +    rnd.nextDouble() * (numericMax - numericMin) + numericMin
    +  }
    +
    +  /** Returns 1.0 or 0.0 depending on a single `nextBoolean()` draw from `rnd`. */
    +  private def generateBinaryValue(rnd: Random): Double = {
    +    val isOne = rnd.nextBoolean()
    +    if (isOne) 1.0 else 0.0
    +  }
    +
    +  /** Returns `numBinary` binary (0.0 / 1.0) values drawn in sequence from `rnd`. */
    +  private def generateBinaryArray(rnd: Random, numBinary: Int): Array[Double] =
    +    Array.fill(numBinary)(generateBinaryValue(rnd))
    +
    +  /**
    +   * Generate an array of `numNumeric` numeric values, each drawn from `rnd` by
    +   * `generateNumericValue` with the bounds `numericMin` and `numericMax`.
    +   */
    +  private def generateNumericArray(
    +      rnd: Random,
    +      numNumeric: Int,
    +      numericMin: Double,
    +      numericMax: Double): Array[Double] = {
    +    Array.fill(numNumeric)(generateNumericValue(rnd, numericMin, numericMax))
    +  }
    +
    +  /**
    +   * Generate a LabeledPoint whose features are `numNumeric` numeric values followed by
    +   * `numBinary` binary values, with a label in [0, numClasses).
    +   */
    +  private def generatePoint(
    +      rnd: Random,
    +      numClasses: Int,
    +      numNumeric: Int = 10,
    +      numericMin: Double = -100,
    +      numericMax: Double = 100,
    +      numBinary: Int = 10): LabeledPoint = {
    +    // All draws share `rnd`: label first, then the numeric block, then the binary block.
    +    // Keep this order so a seeded `rnd` reproduces the same point.
    +    val label = generateLabel(rnd, numClasses)
    +    val features = generateNumericArray(rnd, numNumeric, numericMin, numericMax) ++
    +      generateBinaryArray(rnd, numBinary)
    +    new LabeledPoint(label, Vectors.dense(features))
    +  }
    +
    +  /**
    +   * Data for tree redundancy tests which produces a non-trivial tree: exactly `numPoints`
    +   * points, each with no numeric features and 3 binary features.
    +   */
    +  private def generateRedundantPoints(
    +      sc: SparkContext,
    +      numClasses: Int,
    +      numPoints: Int,
    +      rnd: Random): RDD[LabeledPoint] = {
    +    // Range(0, numPoints) has exactly numPoints elements (Range(1, numPoints) has one fewer).
    +    val points = Range(0, numPoints).map { _ =>
    +      generatePoint(rnd, numClasses = numClasses, numNumeric = 0, numBinary = 3)
    +    }
    +    sc.parallelize(points)
    +  }
    +
    +  /**
    +   * Checks the given decision tree for redundancy by running `_isRedundant` on its root.
    +   * Intended to be true iff some subtree can be pruned, i.e. all of its leaves share the
    +   * same prediction.
    +   * @param tree the tree to inspect
    +   * @return the result of `_isRedundant` on `tree.rootNode`
    +   */
    +  private def isRedundant(tree: DecisionTreeModel): Boolean = {
    +    val root = tree.rootNode
    +    _isRedundant(root)
    +  }
    +
    +  /**
    +   * Returns true iff the subtree rooted at `n` contains at least one internal node that
    +   * can be pruned (i.e., all the leaves below it share the same prediction).
    +   * Both children are searched recursively; a non-internal node is never redundant.
    +   * @param n the node to be tested
    +   * @return true iff the node has at least one subtree that can be pruned.
    +   */
    +  private def _isRedundant(n: Node): Boolean = n match {
    +    case internal: InternalNode =>
    +      // Recurse into BOTH children; otherwise prunable right subtrees go unnoticed.
    +      _isRedundant(internal.leftChild) ||
    +        _isRedundant(internal.rightChild) ||
    +        canBePruned(internal)
    +    case _ => false
    +  }
    +
    +  /**
    +   * Tells whether the subtree rooted at `n` collapses to a single prediction: the union
    +   * of the leaf predictions found under its two children has exactly one element.
    +   * Any node that is not an InternalNode is reported as not prunable.
    +   * @param n the node to be tested
    +   * @return true iff `n` is an internal node whose leaves all predict the same value
    +   */
    +  private def canBePruned(n: Node): Boolean = n match {
    +    case internal: InternalNode =>
    +      val predictions =
    +        leafPredictions(internal.leftChild).union(leafPredictions(internal.rightChild))
    +      predictions.size == 1
    +    case _ => false
    +  }
    +
    +  /**
    +    * Given a node, the method returns the set of predictions appearing in 
the subtree rooted at it.
    +    * @return the set of predictions appearing in the subtree rooted at 
the given node.
    +    */
    +  private def leafPredictions(n: Node): Set[Double] = n match {
    --- End diff --
    
    They are, but unfortunately here we work with (subclasses of) the Node class, 
which make up the final decision tree, whereas the other methods performing the 
optimization are defined at the LearningNode level, and those LearningNode objects 
are no longer available once the decision tree has been produced as output.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to