Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20121#discussion_r171796232
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
 ---
    @@ -169,59 +171,28 @@ class GBTClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext
         val blas = BLAS.getInstance()
     
         val validationDataset = validationData.toDF(labelCol, featuresCol)
    -    val results = gbtModel.transform(validationDataset)
    -    // check that raw prediction is tree predictions dot tree weights
    -    results.select(rawPredictionCol, featuresCol).collect().foreach {
    -      case Row(raw: Vector, features: Vector) =>
    +    testTransformer[(Double, Vector)](validationDataset, gbtModel,
    +      "rawPrediction", "features", "probability", "prediction") {
    +      case Row(raw: Vector, features: Vector, prob: Vector, pred: Double) 
=>
             assert(raw.size === 2)
    +        // check that raw prediction is tree predictions dot tree weights
             val treePredictions = 
gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
             val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, 
gbtModel.treeWeights, 1)
             assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
    -    }
     
    -    // Compare rawPrediction with probability
    -    results.select(rawPredictionCol, probabilityCol).collect().foreach {
    -      case Row(raw: Vector, prob: Vector) =>
    -        assert(raw.size === 2)
    +        // Compare rawPrediction with probability
             assert(prob.size === 2)
             // Note: we should check other loss types for classification if 
they are added
             val predFromRaw = raw.toDense.values.map(value => 
LogLoss.computeProbability(value))
             assert(prob(0) ~== predFromRaw(0) relTol eps)
             assert(prob(1) ~== predFromRaw(1) relTol eps)
             assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
    -    }
     
    -    // Compare prediction with probability
    -    results.select(predictionCol, probabilityCol).collect().foreach {
    -      case Row(pred: Double, prob: Vector) =>
    +        // Compare prediction with probability
             val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
             assert(pred == predFromProb)
         }
     
    -    // force it to use raw2prediction
    -    gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
    -    val resultsUsingRaw2Predict =
    -      
gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
    -    
resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach
 {
    -      case (pred1, pred2) => assert(pred1 === pred2)
    -    }
    -
    -    // force it to use probability2prediction
    -    gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
    -    val resultsUsingProb2Predict =
    -      
gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
    -    
resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach
 {
    -      case (pred1, pred2) => assert(pred1 === pred2)
    -    }
    -
    -    // force it to use predict
    --- End diff --
    
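    These code paths (forcing `raw2prediction`, `probability2prediction`, and `predict`) are already covered by 
`ProbabilisticClassifierSuite.testPredictMethods`, so they do not need to be repeated here. 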


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to