GitHub user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/5669#discussion_r29004162
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala ---
@@ -177,102 +177,108 @@ object GradientBoostedTrees extends Logging {
     treeStrategy.assertValid()

     // Cache input
-    if (input.getStorageLevel == StorageLevel.NONE) {
+    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
       input.persist(StorageLevel.MEMORY_AND_DISK)
-    }
+      true
+    } else false

-    timer.stop("init")
+    try {
+      timer.stop("init")

-    logDebug("##########")
-    logDebug("Building tree 0")
-    logDebug("##########")
-    var data = input
+      logDebug("##########")
+      logDebug("Building tree 0")
+      logDebug("##########")
+      var data = input

-    // Initialize tree
-    timer.start("building tree 0")
-    val firstTreeModel = new DecisionTree(treeStrategy).run(data)
-    val firstTreeWeight = 1.0
-    baseLearners(0) = firstTreeModel
-    baseLearnerWeights(0) = firstTreeWeight
+      // Initialize tree
+      timer.start("building tree 0")
+      val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+      val firstTreeWeight = 1.0
+      baseLearners(0) = firstTreeModel
+      baseLearnerWeights(0) = firstTreeWeight

-    var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
-      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
-    logDebug("error of gbt = " + predError.values.mean())
+      var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+        computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+      logDebug("error of gbt = " + predError.values.mean())

-    // Note: A model of type regression is used since we require raw prediction
-    timer.stop("building tree 0")
+      // Note: A model of type regression is used since we require raw prediction
+      timer.stop("building tree 0")

-    var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
-      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
-    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
-    var bestM = 1
+      var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+        computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+      var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
+      var bestM = 1

-    // pseudo-residual for second iteration
-    data = predError.zip(input).map { case ((pred, _), point) =>
-      LabeledPoint(-loss.gradient(pred, point.label), point.features)
-    }
+      // pseudo-residual for second iteration
+      data = predError.zip(input).map { case ((pred, _), point) =>
+        LabeledPoint(-loss.gradient(pred, point.label), point.features)
+      }

-    var m = 1
-    while (m < numIterations) {
-      timer.start(s"building tree $m")
-      logDebug("###################################################")
-      logDebug("Gradient boosting tree iteration " + m)
-      logDebug("###################################################")
-      val model = new DecisionTree(treeStrategy).run(data)
-      timer.stop(s"building tree $m")
-      // Create partial model
-      baseLearners(m) = model
-      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
-      //       Technically, the weight should be optimized for the particular loss.
-      //       However, the behavior should be reasonable, though not optimal.
-      baseLearnerWeights(m) = learningRate
-      // Note: A model of type regression is used since we require raw prediction
-      val partialModel = new GradientBoostedTreesModel(
-        Regression, baseLearners.slice(0, m + 1),
-        baseLearnerWeights.slice(0, m + 1))
+      var m = 1
+      while (m < numIterations) {
+        timer.start(s"building tree $m")
+        logDebug("###################################################")
+        logDebug("Gradient boosting tree iteration " + m)
+        logDebug("###################################################")
+        val model = new DecisionTree(treeStrategy).run(data)
+        timer.stop(s"building tree $m")
+        // Create partial model
+        baseLearners(m) = model
+        // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
+        //       Technically, the weight should be optimized for the particular loss.
+        //       However, the behavior should be reasonable, though not optimal.
+        baseLearnerWeights(m) = learningRate
+        // Note: A model of type regression is used since we require raw prediction
+        val partialModel = new GradientBoostedTreesModel(
+          Regression, baseLearners.slice(0, m + 1),
+          baseLearnerWeights.slice(0, m + 1))

-      predError = GradientBoostedTreesModel.updatePredictionError(
-        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
-      logDebug("error of gbt = " + predError.values.mean())
+        predError = GradientBoostedTreesModel.updatePredictionError(
+          input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+        logDebug("error of gbt = " + predError.values.mean())

-      if (validate) {
-        // Stop training early if
-        // 1. Reduction in error is less than the validationTol or
-        // 2. If the error increases, that is if the model is overfit.
-        // We want the model returned corresponding to the best validation error.
+        if (validate) {
+          // Stop training early if
+          // 1. Reduction in error is less than the validationTol or
+          // 2. If the error increases, that is if the model is overfit.
+          // We want the model returned corresponding to the best validation error.

-        validatePredError = GradientBoostedTreesModel.updatePredictionError(
-          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
-        val currentValidateError = validatePredError.values.mean()
-        if (bestValidateError - currentValidateError < validationTol) {
-          return new GradientBoostedTreesModel(
-            boostingStrategy.treeStrategy.algo,
-            baseLearners.slice(0, bestM),
-            baseLearnerWeights.slice(0, bestM))
-        } else if (currentValidateError < bestValidateError) {
-          bestValidateError = currentValidateError
-          bestM = m + 1
+          validatePredError = GradientBoostedTreesModel.updatePredictionError(
+            validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+          val currentValidateError = validatePredError.values.mean()
+          if (bestValidateError - currentValidateError < validationTol) {
+            return new GradientBoostedTreesModel(
+              boostingStrategy.treeStrategy.algo,
+              baseLearners.slice(0, bestM),
+              baseLearnerWeights.slice(0, bestM))
+          } else if (currentValidateError < bestValidateError) {
+            bestValidateError = currentValidateError
+            bestM = m + 1
+          }
         }
+        // Update data with pseudo-residuals
+        data = predError.zip(input).map { case ((pred, _), point) =>
+          LabeledPoint(-loss.gradient(pred, point.label), point.features)
+        }
+        m += 1
       }
-      // Update data with pseudo-residuals
-      data = predError.zip(input).map { case ((pred, _), point) =>
-        LabeledPoint(-loss.gradient(pred, point.label), point.features)
-      }
-      m += 1
-    }
-    timer.stop("total")
+      timer.stop("total")
+
+      logInfo("Internal timing for DecisionTree:")
+      logInfo(s"$timer")
+      if (validate) {
+        new GradientBoostedTreesModel(
+          boostingStrategy.treeStrategy.algo,
+          baseLearners.slice(0, bestM),
+          baseLearnerWeights.slice(0, bestM))
+      } else {
+        new GradientBoostedTreesModel(
+          boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+      }
-    logInfo("Internal timing for DecisionTree:")
-    logInfo(s"$timer")
-    if (validate) {
-      new GradientBoostedTreesModel(
-        boostingStrategy.treeStrategy.algo,
-        baseLearners.slice(0, bestM),
-        baseLearnerWeights.slice(0, bestM))
-    } else {
-      new GradientBoostedTreesModel(
-        boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+    } finally {
+      if (persistedInput) input.unpersist()
--- End diff ---
I agree it's cleaner this way, if a little more complex. Concerning
consistency, let me ping @mengxr since, if we do this here, we might start
talking about doing it everywhere.
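
For context, here is the idiom under discussion factored out as a small
standalone helper. This is a sketch only; `withCachedInput`, `CacheIdiom`,
and `train` are hypothetical names for illustration, not code from this PR:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    object CacheIdiom {
      // Hypothetical helper illustrating the pattern; not part of this PR.
      def withCachedInput[T, R](input: RDD[T])(train: RDD[T] => R): R = {
        // Persist only if the caller has not already chosen a storage level,
        // and remember whether we did.
        val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
          input.persist(StorageLevel.MEMORY_AND_DISK)
          true
        } else false
        try {
          train(input)
        } finally {
          // Runs whether `train` returns normally or throws, so an RDD we
          // cached ourselves is always released; a caller-cached RDD is
          // left untouched.
          if (persistedInput) input.unpersist()
        }
      }
    }

Among other things, the `finally` covers the early `return` in the validation
branch above, which a single unpersist call at the end of the method would
miss.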