GitHub user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/16441#discussion_r95455906
--- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---
@@ -241,19 +261,42 @@ class GBTClassificationModel private[ml](
}
override protected def predict(features: Vector): Double = {
- // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
- // Classifies by thresholding sum of weighted tree predictions
- val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
- val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
- if (prediction > 0.0) 1.0 else 0.0
+ // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
+ if (isDefined(thresholds)) {
+ super.predict(features)
+ } else {
+ val prediction: Double = margin(features)
+ if (prediction > 0.0) 1.0 else 0.0
+ }
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
+ val prediction: Double = margin(features)
+ Vectors.dense(Array(-prediction, prediction))
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ // The probability can be calculated for positive result:
+ // p+(x) = 1 / (1 + e^(-2 * F(x)))
+ // and negative result:
+ // p-(x) = 1 / (1 + e^(2 * F(x)))
+ case dv: DenseVector =>
+ dv.values(0) = getOldLossType.computeProbability(dv.values(0))
--- End diff --
Should we make a private class member `private val loss = getOldLossType`?
Otherwise we call `getOldLossType` (which in turn calls `getLossType`) once for
every single instance.
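
A rough sketch of the idea, in case it helps. This uses hypothetical stand-in types (`Loss`, `LogLoss`, `Model`, and the `lookups` counter are all made up for illustration), not the actual Spark classes, and it assumes the loss type is fixed by the time the model is constructed. The formula is the one from the comment in the diff.

```scala
// Hypothetical stand-ins for illustration only; these are not Spark classes.
trait Loss {
  def computeProbability(margin: Double): Double
}

object LogLoss extends Loss {
  // p+(x) = 1 / (1 + e^(-2 * F(x))), as stated in the diff comment.
  def computeProbability(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))
}

class Model(lossType: String) {
  // Counts how often the loss is resolved; exists only to make the point visible.
  // Declared before `loss` so it is initialized before the first lookup.
  var lookups: Int = 0

  // Stand-in for getOldLossType: resolves the loss from a string on every call.
  def getOldLossType: Loss = {
    lookups += 1
    lossType.toLowerCase match {
      case "logistic" => LogLoss
      case other => throw new IllegalArgumentException(s"unknown loss: $other")
    }
  }

  // Resolved once at construction and reused for every instance.
  private val loss: Loss = getOldLossType

  def probability(margin: Double): Double = loss.computeProbability(margin)
}
```

With this shape, scoring any number of rows through `probability` resolves the loss only once. If the loss type could still change after construction, a plain `val` would go stale, so that case would need a different approach.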