Github user srowen commented on a diff in the pull request:
https://github.com/apache/spark/pull/20632#discussion_r168922781
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
@@ -640,4 +689,96 @@ private object RandomForestSuite {
val (indices, values) = map.toSeq.sortBy(_._1).unzip
Vectors.sparse(size, indices.toArray, values.toArray)
}
+
+ /** Generate a label. */
+ private def generateLabel(rnd: Random, numClasses: Int): Double = {
+ rnd.nextInt(numClasses)
+ }
+
+ /** Generate a numeric value in the range [numericMin, numericMax]. */
+ private def generateNumericValue(rnd: Random, numericMin: Double,
numericMax: Double) : Double = {
+ rnd.nextDouble() * (Math.abs(numericMax) + Math.abs(numericMin)) +
numericMin
+ }
+
+ /** Generate a binary value. */
+ private def generateBinaryValue(rnd: Random) : Double = if
(rnd.nextBoolean()) 1 else 0
+
+ /** Generate an array of binary values of length numBinary. */
+ private def generateBinaryArray(rnd: Random, numBinary: Int):
Array[Double] = {
+ Range.apply(0, numBinary).map(_ => generateBinaryValue(rnd)).toArray
+ }
+
+ /** Generate an array of binary values of length numNumeric in the range
+ * [numericMin, numericMax]. */
+ private def generateNumericArray(rnd: Random,
+ numNumeric: Int,
+ numericMin: Double,
+ numericMax: Double) : Array[Double] = {
+ Range.apply(0, numNumeric).map(_ => generateNumericValue(rnd,
numericMin, numericMax)).toArray
+ }
+
+ /** Generate a LabeledPoint with numNumeric numeric values followed by
numBinary binary values. */
+ private def generatePoint(rnd: Random,
+ numClasses: Int,
+ numNumeric: Int = 10,
+ numericMin: Double = -100,
+ numericMax: Double = 100,
+ numBinary: Int = 10): LabeledPoint = {
+ val label = generateLabel(rnd, numClasses)
+ val numericArray = generateNumericArray(rnd, numNumeric, numericMin,
numericMax)
+ val binaryArray = generateBinaryArray(rnd, numBinary)
+ val vector = Vectors.dense(numericArray ++ binaryArray)
+
+ new LabeledPoint(label, vector)
+ }
+
+ /** Data for tree redundancy tests which produces a non-trivial tree. */
+ private def generateRedundantPoints(sc: SparkContext,
+ numClasses: Int,
+ numPoints: Int,
+ rnd: Random): RDD[LabeledPoint] = {
+ sc.parallelize(Range.apply(1, numPoints)
+ .map(_ => generatePoint(rnd, numClasses = numClasses, numNumeric =
0, numBinary = 3)))
+ }
+
+ /**
+ * Returns true iff the decision tree has at least one subtree that can
be pruned
+ * (i.e., all its leaves share the same prediction).
+ * @param tree the tree to be tested
+ * @return true iff the decision tree has at least one subtree that can
be pruned.
+ */
+ private def isRedundant(tree: DecisionTreeModel): Boolean =
_isRedundant(tree.rootNode)
--- End diff --
This method doesn't seem to add enough value to justify being a whole new
method with four lines of docs.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]