Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21097#discussion_r186865863
--- Diff: mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---
@@ -367,11 +367,31 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
   test("model evaluateEachIteration") {
     val gbt = new GBTClassifier()
+      .setSeed(1L)
       .setMaxDepth(2)
-      .setMaxIter(2)
-    val model = gbt.fit(trainData.toDF)
-    val eval = model.evaluateEachIteration(validationData.toDF)
-    assert(Vectors.dense(eval) ~== Vectors.dense(1.7641, 1.8209) relTol 1E-3)
+      .setMaxIter(3)
+      .setLossType("logistic")
+    val model3 = gbt.fit(trainData.toDF)
+    val model1 = new GBTClassificationModel("gbt-cls-model-test1",
+      model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses)
+    val model2 = new GBTClassificationModel("gbt-cls-model-test2",
+      model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)
+
+    for (evalLossType <- GBTClassifier.supportedLossTypes) {
--- End diff --
evalLossType is not used, so I'd remove this loop.
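For illustration only, here is a minimal sketch of how the assertions could read once the loop is dropped. It assumes the fixtures already present in this suite and diff (trainData, validationData, model1/model2/model3, and the TestingUtils implicits used elsewhere in the file); the specific assertions are hypothetical, not the PR's final test code.

    // Sketch only: GBT classification has a single supported (logistic) loss,
    // so no per-loss-type loop is needed.
    val eval3 = model3.evaluateEachIteration(validationData.toDF)
    assert(eval3.length === 3)

    // Hypothetical cross-check: the truncated models' last per-iteration loss
    // should match the corresponding entry of the full model's evaluation.
    assert(eval3(0) ~== model1.evaluateEachIteration(validationData.toDF).last relTol 1E-3)
    assert(eval3(1) ~== model2.evaluateEachIteration(validationData.toDF).last relTol 1E-3)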
---