imatiach-msft commented on a change in pull request #25926: [SPARK-9612][ML] Add instance weight support for GBTs
URL: https://github.com/apache/spark/pull/25926#discussion_r335027506
 
 

 ##########
 File path: 
mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
 ##########
 @@ -166,29 +163,50 @@ private[spark] object GradientBoostedTrees extends Logging {
    * Method to calculate error of the base learner for the gradient boosting 
calculation.
    * Note: This method is not used by the gradient boosting algorithm but is 
useful for debugging
    * purposes.
-   * @param data Training dataset: RDD of `LabeledPoint`.
+   * @param data Training dataset: RDD of `Instance`.
    * @param trees Boosted Decision Tree models
    * @param treeWeights Learning rates at each boosting iteration.
    * @param loss evaluation metric.
    * @return Measure of model error on data
    */
   def computeError(
-      data: RDD[LabeledPoint],
+      data: RDD[Instance],
       trees: Array[DecisionTreeRegressionModel],
       treeWeights: Array[Double],
       loss: OldLoss): Double = {
-    data.map { lp =>
+    val (errSum, weightSum) = data.map { case Instance(label, weight, 
features) =>
       val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, 
(model, weight)) =>
-        updatePrediction(lp.features, acc, model, weight)
+        updatePrediction(features, acc, model, weight)
       }
-      loss.computeError(predicted, lp.label)
-    }.mean()
+      (loss.computeError(predicted, label) * weight, weight)
+    }.treeReduce{ case ((err1, weight1), (err2, weight2)) =>
+        (err1 + err2, weight1 + weight2)
+    }
+    errSum / weightSum
+  }
+
+  /**
+   * Method to calculate error of the base learner for the gradient boosting 
calculation.
+   * @param data Training dataset: RDD of `Instance`.
+   * @param predError Prediction and error.
+   * @return Measure of model error on data
+   */
+  def computeError(
 
 Review comment:
   Maybe rename this to `computeWeightedError` to make the weighting explicit, since the methods above also compute error, but unweighted.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to