Github user asolimando commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20632#discussion_r170139957
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala ---
    @@ -402,20 +405,40 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
           LabeledPoint(1.0, Vectors.dense(2.0)))
         val input = sc.parallelize(arr)
     
    +    val seed = 42
    +    val numTrees = 1
    +
         // Must set maxBins s.t. the feature will be treated as an ordered 
categorical feature.
         val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity 
= Gini, maxDepth = 1,
           numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
     
    -    val model = RandomForest.run(input, strategy, numTrees = 1, 
featureSubsetStrategy = "all",
    -      seed = 42, instr = None).head
    -    model.rootNode match {
    -      case n: InternalNode => n.split match {
    -        case s: CategoricalSplit =>
    -          assert(s.leftCategories === Array(1.0))
    -        case _ => throw new AssertionError("model.rootNode.split was not a 
CategoricalSplit")
    -      }
    -      case _ => throw new AssertionError("model.rootNode was not an 
InternalNode")
    -    }
    +    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy, 
numTrees = numTrees,
    +      featureSubsetStrategy = "all")
    +    val splits = RandomForest.findSplits(input, metadata, seed = seed)
    +
    +    val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata)
    +    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
    +      strategy.subsamplingRate, numTrees, false, seed = seed)
    +
    +    val topNode = LearningNode.emptyNode(nodeIndex = 1)
    +    assert(topNode.isLeaf === false)
    +    assert(topNode.stats === null)
    +
    +    val nodesForGroup = Map(0 -> Array(topNode))
    +    val treeToNodeToIndexInfo = Map(0 -> Map(
    +      topNode.id -> new RandomForest.NodeIndexInfo(0, None)
    +    ))
    +    val nodeStack = new mutable.ArrayStack[(Int, LearningNode)]
    +    val bestSplit = RandomForest.findBestSplits(baggedInput, metadata, 
Map(0 -> topNode),
    +      nodesForGroup, treeToNodeToIndexInfo, splits, nodeStack)
    +
    +    assert(topNode.split.isDefined, "rootNode does not have a split")
    --- End diff --
    
    It is true, but I need to call these internal methods to initialise the 
structure correctly, including _rootNode_.
    
    I have removed the only lines that did not look necessary to me:
    `assert(topNode.isLeaf === false)`
    `assert(topNode.stats === null)`
    
    What do you think?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to