GitHub user smurching commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19381#discussion_r146986798
  
    --- Diff: mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala ---
    @@ -267,6 +268,24 @@ class DecisionTreeClassifierSuite
           Vector, DecisionTreeClassificationModel](newTree, newData)
       }
     
    +  test("prediction on single instance") {
    +    val rdd = continuousDataPointsForMulticlassRDD
    +    val dt = new DecisionTreeClassifier()
    +      .setImpurity("Gini")
    +      .setMaxDepth(4)
    +      .setMaxBins(100)
    +    val categoricalFeatures = Map(0 -> 3)
    +    val numClasses = 3
    +
    +    val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
    +    val newTree = dt.fit(newData)
    +
    +    newTree.transform(newData).select(dt.getFeaturesCol, dt.getPredictionCol).collect().foreach {
    +      case Row(features: Vector, prediction: Double) =>
    +        assert(prediction ~== newTree.predict(features) relTol 1E-5)
    --- End diff --
    
    Can we test exact equality (e.g. `prediction === newTree.predict(features)`) here and in other unit tests?
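To illustrate the suggested check, here is a minimal, Spark-free sketch (not Spark's code). Spark ML classifiers emit the predicted class index as a Double (0.0, 1.0, or 2.0 here, since numClasses = 3). So if transform and single-instance predict apply the same per-instance logic, the two results should be identical rather than merely within a relative tolerance. `toyPredict`, `toyTransform`, and `countExactMatches` are hypothetical stand-ins for the model's predict and transform paths.

```scala
// Minimal Spark-free sketch. toyPredict / toyTransform are HYPOTHETICAL
// stand-ins for DecisionTreeClassificationModel.predict / transform; they are
// not Spark's implementation.
object ExactEqualitySketch {
  // Hypothetical depth-2 tree over two continuous features. Like Spark ML
  // classifiers, it returns the predicted class index encoded as a Double.
  def toyPredict(features: Array[Double]): Double =
    if (features(0) <= 1.5) 0.0
    else if (features(1) <= 0.5) 1.0
    else 2.0

  // Hypothetical "batch" path standing in for transform(): applies the same
  // per-instance logic to every row and keeps (features, prediction) pairs.
  def toyTransform(rows: Seq[Array[Double]]): Seq[(Array[Double], Double)] =
    rows.map(r => (r, toyPredict(r)))

  // Counts rows whose batch prediction exactly equals the single-instance
  // prediction (no tolerance involved).
  def countExactMatches(rows: Seq[Array[Double]]): Int =
    toyTransform(rows).count { case (features, prediction) =>
      prediction == toyPredict(features)
    }

  def main(args: Array[String]): Unit = {
    val rows = Seq(Array(0.0, 1.0), Array(2.0, 0.0), Array(2.0, 1.0))
    // Exact comparison, mirroring the suggested
    // `prediction === newTree.predict(features)` instead of `~== ... relTol 1E-5`.
    toyTransform(rows).foreach { case (features, prediction) =>
      assert(prediction == toyPredict(features))
    }
    println(s"${countExactMatches(rows)} of ${rows.size} predictions match exactly")
  }
}
```

Running `main` prints `3 of 3 predictions match exactly`. In the real suite the assertion would compare `newTree.predict(features)` against the `prediction` column collected from `transform`.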

