GitHub user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5330#discussion_r28195992
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala ---
    @@ -166,6 +150,68 @@ class GradientBoostedTreesModel(
     
     object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
     
    +  /**
    +   * Compute the initial predictions and errors for a dataset for the first
    +   * iteration of gradient boosting.
    +   * @param data training data.
    +   * @param initTreeWeight learning rate assigned to the first tree.
    +   * @param initTree first DecisionTreeModel.
    +   * @param loss evaluation metric.
    +   * @return an RDD with each element being the (prediction, error) pair
    +   *         for the corresponding sample.
    +   */
    +  def computeInitialPredictionAndError(
    +      data: RDD[LabeledPoint],
    +      initTreeWeight: Double,
    +      initTree: DecisionTreeModel,
    +      loss: Loss): RDD[(Double, Double)] = {
    +    data.map { lp =>
    +      val pred = initTreeWeight * initTree.predict(lp.features)
    +      val error = loss.computeError(pred, lp.label)
    +      (pred, error)
    +    }
    +  }
    +
    +  /**
    +   * Update a zipped predictionError RDD
    +   * (as obtained with computeInitialPredictionAndError)
    +   * @param data training data.
    +   * @param predictionAndError prediction and error RDD from the previous iteration.
    +   * @param treeWeight learning rate assigned to the tree.
    +   * @param tree the tree used to update the predictions and errors.
    +   * @param loss evaluation metric.
    +   * @return an RDD with each element being the (prediction, error) pair
    +   *         for the corresponding sample.
    +   */
    +  def updatePredictionError(
    +      data: RDD[LabeledPoint],
    +      predictionAndError: RDD[(Double, Double)],
    +      treeWeight: Double,
    +      tree: DecisionTreeModel,
    +      loss: Loss): RDD[(Double, Double)] = {
    +
    +    val sc = data.sparkContext
    +    val broadcastedTreeWeight = sc.broadcast(treeWeight)
    +    val broadcastedTree = sc.broadcast(tree)
    +
    +    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
    +      val currentTreeWeight = broadcastedTreeWeight.value
    +      val currentTree = broadcastedTree.value
    +      iter.map { case (lp, (pred, error)) =>
    +        val newPred = pred + currentTree.predict(lp.features) * currentTreeWeight
    +        val newError = loss.computeError(newPred, lp.label)
    +        (newPred, newError)
    +      }
    +    }
    +
    +    broadcastedTreeWeight.unpersist()
    --- End diff --
    
    I realized that the broadcast variables can't be unpersisted here since they were used in a map for an RDD which has yet to be materialized. (No action has been called on newPredError yet, so the broadcast data is still needed; see the sketch after this list.) We have 2 choices:
    * Remove the unpersist, and let GC handle it once all of these RDDs go out of scope (at the end of training).
      * Let's do this one for now since it's simpler and since we aren't broadcasting (persisting) too much data.
    * Return the broadcast variables, and keep them around until the caller can unpersist them safely.
      * We can do this in the future if users ever encounter problems with too many broadcast variables.

