GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/21129#discussion_r186572374
--- Diff:
mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
---
@@ -365,6 +366,50 @@ class GBTClassifierSuite extends MLTest with
DefaultReadWriteTest {
assert(mostImportantFeature !== mostIF)
}
+ test("runWithValidation stops early and performs better on a validation
dataset") {
+ val validationIndicatorCol = "validationIndicator"
+ val trainDF = trainData.toDF().withColumn(validationIndicatorCol,
lit(false))
+ val validationDF =
validationData.toDF().withColumn(validationIndicatorCol, lit(true))
+
+ val numIter = 20
+ for (lossType <- GBTClassifier.supportedLossTypes) {
+ val gbt = new GBTClassifier()
+ .setSeed(123)
+ .setMaxDepth(2)
+ .setLossType(lossType)
+ .setMaxIter(numIter)
+ val modelWithoutValidation = gbt.fit(trainDF)
+
+ gbt.setValidationIndicatorCol(validationIndicatorCol)
+ val modelWithValidation = gbt.fit(trainDF.union(validationDF))
+
+ // early stop
+ assert(modelWithValidation.numTrees < numIter)
+
+ val (errorWithoutValidation, errorWithValidation) = {
+ val remappedRdd = validationData.map(x => new LabeledPoint(2 *
x.label - 1, x.features))
+ (GradientBoostedTrees.computeError(remappedRdd,
modelWithoutValidation.trees,
+ modelWithoutValidation.treeWeights,
modelWithoutValidation.getOldLossType),
+ GradientBoostedTrees.computeError(remappedRdd,
modelWithValidation.trees,
+ modelWithValidation.treeWeights,
modelWithValidation.getOldLossType))
+ }
+ assert(errorWithValidation <= errorWithoutValidation)
--- End diff --
It'd be nice to have this assertion be strict (`errorWithValidation < errorWithoutValidation` rather than `<=`). Is it not strictly true here?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]