Repository: spark
Updated Branches:
  refs/heads/master f3fe55439 -> 248916f55


[SPARK-17057][ML] ProbabilisticClassifierModels' thresholds should have at most 
one 0

## What changes were proposed in this pull request?

Bring ProbabilisticClassifier.thresholds requirements in line with R randomForest's cutoff:
all values must be > 0, except that at most one value may be 0.

## How was this patch tested?

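Existing Jenkins tests, plus new test cases in ProbabilisticClassifierSuite (tie-breaking, a single zero threshold, and rejection of invalid thresholds).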

Author: Sean Owen <[email protected]>

Closes #15149 from srowen/SPARK-17057.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/248916f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/248916f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/248916f5

Branch: refs/heads/master
Commit: 248916f5589155c0c3e93c3874781f17b08d598d
Parents: f3fe554
Author: Sean Owen <[email protected]>
Authored: Sat Sep 24 08:15:55 2016 +0100
Committer: Sean Owen <[email protected]>
Committed: Sat Sep 24 08:15:55 2016 +0100

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |  5 +--
 .../ProbabilisticClassifier.scala               | 20 +++++------
 .../ml/param/shared/SharedParamsCodeGen.scala   |  8 +++--
 .../spark/ml/param/shared/sharedParams.scala    |  4 +--
 .../ProbabilisticClassifierSuite.scala          | 35 ++++++++++++++++----
 .../pyspark/ml/param/_shared_params_code_gen.py |  5 +--
 python/pyspark/ml/param/shared.py               |  4 +--
 7 files changed, 52 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 343d50c..5ab63d1 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -123,9 +123,10 @@ private[classification] trait LogisticRegressionParams 
extends ProbabilisticClas
 
   /**
    * Set thresholds in multiclass (or binary) classification to adjust the 
probability of
-   * predicting each class. Array must have length equal to the number of 
classes, with values >= 0.
+   * predicting each class. Array must have length equal to the number of 
classes, with values > 0,
+   * excepting that at most one value may be 0.
    * The class with largest value p/t is predicted, where p is the original 
probability of that
-   * class and t is the class' threshold.
+   * class and t is the class's threshold.
    *
    * Note: When [[setThresholds()]] is called, any user-set value for 
[[threshold]] will be cleared.
    *       If both [[threshold]] and [[thresholds]] are set in a ParamMap, 
then they must be

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 1b6e775..e89da6f 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -200,22 +200,20 @@ abstract class ProbabilisticClassificationModel[
     if (!isDefined(thresholds)) {
       probability.argmax
     } else {
-      val thresholds: Array[Double] = getThresholds
-      val probabilities = probability.toArray
+      val thresholds = getThresholds
       var argMax = 0
       var max = Double.NegativeInfinity
       var i = 0
       val probabilitySize = probability.size
       while (i < probabilitySize) {
-        if (thresholds(i) == 0.0) {
-          max = Double.PositiveInfinity
+        // Thresholds are all > 0, excepting that at most one may be 0.
+        // The single class whose threshold is 0, if any, will always be 
predicted
+        // ('scaled' = +Infinity). However in the case that this class also has
+        // 0 probability, the class will not be selected ('scaled' is NaN).
+        val scaled = probability(i) / thresholds(i)
+        if (scaled > max) {
+          max = scaled
           argMax = i
-        } else {
-          val scaled = probabilities(i) / thresholds(i)
-          if (scaled > max) {
-            max = scaled
-            argMax = i
-          }
         }
         i += 1
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 480b03d..c94b8b4 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -50,10 +50,12 @@ private[shared] object SharedParamsCodeGen {
         isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
       ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class 
classification" +
         " to adjust the probability of predicting each class." +
-        " Array must have length equal to the number of classes, with values 
>= 0." +
+        " Array must have length equal to the number of classes, with values > 
0" +
+        " excepting that at most one value may be 0." +
         " The class with largest value p/t is predicted, where p is the 
original probability" +
-        " of that class and t is the class' threshold",
-        isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = 
false),
+        " of that class and t is the class's threshold",
+        isValid = "(t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) 
<= 1",
+        finalMethods = false),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + 
\"__output\"")),

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 9125d9e..fa45309 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -176,10 +176,10 @@ private[ml] trait HasThreshold extends Params {
 private[ml] trait HasThresholds extends Params {
 
   /**
-   * Param for Thresholds in multi-class classification to adjust the 
probability of predicting each class. Array must have length equal to the 
number of classes, with values >= 0. The class with largest value p/t is 
predicted, where p is the original probability of that class and t is the 
class' threshold.
+   * Param for Thresholds in multi-class classification to adjust the 
probability of predicting each class. Array must have length equal to the 
number of classes, with values > 0 excepting that at most one value may be 0. 
The class with largest value p/t is predicted, where p is the original 
probability of that class and t is the class's threshold.
    * @group param
    */
-  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, 
"thresholds", "Thresholds in multi-class classification to adjust the 
probability of predicting each class. Array must have length equal to the 
number of classes, with values >= 0. The class with largest value p/t is 
predicted, where p is the original probability of that class and t is the 
class' threshold", (t: Array[Double]) => t.forall(_ >= 0))
+  final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, 
"thresholds", "Thresholds in multi-class classification to adjust the 
probability of predicting each class. Array must have length equal to the 
number of classes, with values > 0 excepting that at most one value may be 0. 
The class with largest value p/t is predicted, where p is the original 
probability of that class and t is the class's threshold", (t: Array[Double]) 
=> t.forall(_ >= 0) && t.count(_ == 0) <= 1)
 
   /** @group getParam */
   def getThresholds: Array[Double] = $(thresholds)

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index b3bd2b3..172c64a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel(
     rawPrediction
   }
 
-  def friendlyPredict(input: Vector): Double = {
-    predict(input)
+  def friendlyPredict(values: Double*): Double = {
+    predict(Vectors.dense(values.toArray))
   }
 }
 
@@ -45,16 +45,37 @@ final class TestProbabilisticClassificationModel(
 class ProbabilisticClassifierSuite extends SparkFunSuite {
 
   test("test thresholding") {
-    val thresholds = Array(0.5, 0.2)
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-      .setThresholds(thresholds)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
+      .setThresholds(Array(0.5, 0.2))
+    assert(testModel.friendlyPredict(1.0, 1.0) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 0.2) === 0.0)
   }
 
   test("test thresholding not required") {
     val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
-    assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
+    assert(testModel.friendlyPredict(1.0, 2.0) === 1.0)
+  }
+
+  test("test tiebreak") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.4, 0.4))
+    assert(testModel.friendlyPredict(0.6, 0.6) === 0.0)
+  }
+
+  test("test one zero threshold") {
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(Array(0.0, 0.1))
+    assert(testModel.friendlyPredict(1.0, 10.0) === 0.0)
+    assert(testModel.friendlyPredict(0.0, 10.0) === 1.0)
+  }
+
+  test("bad thresholds") {
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 
2).setThresholds(Array(0.0, 0.0))
+    }
+    intercept[IllegalArgumentException] {
+      new TestProbabilisticClassificationModel("myuid", 2, 
2).setThresholds(Array(-0.1, 0.1))
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/python/pyspark/ml/param/_shared_params_code_gen.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py 
b/python/pyspark/ml/param/_shared_params_code_gen.py
index 4f4328b..9295912 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -139,8 +139,9 @@ if __name__ == "__main__":
          "model.", "True", "TypeConverters.toBoolean"),
         ("thresholds", "Thresholds in multi-class classification to adjust the 
probability of " +
          "predicting each class. Array must have length equal to the number of 
classes, with " +
-         "values >= 0. The class with largest value p/t is predicted, where p 
is the original " +
-         "probability of that class and t is the class' threshold.", None,
+         "values > 0, excepting that at most one value may be 0. " +
+         "The class with largest value p/t is predicted, where p is the 
original " +
+         "probability of that class and t is the class's threshold.", None,
          "TypeConverters.toListFloat"),
         ("weightCol", "weight column name. If this is not set or empty, we 
treat " +
          "all instance weights as 1.0.", None, "TypeConverters.toString"),

http://git-wip-us.apache.org/repos/asf/spark/blob/248916f5/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py 
b/python/pyspark/ml/param/shared.py
index 24af07a..cc59693 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -469,10 +469,10 @@ class HasStandardization(Params):
 
 class HasThresholds(Params):
     """
-    Mixin for param thresholds: Thresholds in multi-class classification to 
adjust the probability of predicting each class. Array must have length equal 
to the number of classes, with values >= 0. The class with largest value p/t is 
predicted, where p is the original probability of that class and t is the 
class' threshold.
+    Mixin for param thresholds: Thresholds in multi-class classification to 
adjust the probability of predicting each class. Array must have length equal 
to the number of classes, with values > 0, excepting that at most one value may 
be 0. The class with largest value p/t is predicted, where p is the original 
probability of that class and t is the class's threshold.
     """
 
-    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in 
multi-class classification to adjust the probability of predicting each class. 
Array must have length equal to the number of classes, with values >= 0. The 
class with largest value p/t is predicted, where p is the original probability 
of that class and t is the class' threshold.", 
typeConverter=TypeConverters.toListFloat)
+    thresholds = Param(Params._dummy(), "thresholds", "Thresholds in 
multi-class classification to adjust the probability of predicting each class. 
Array must have length equal to the number of classes, with values > 0, 
excepting that at most one value may be 0. The class with largest value p/t is 
predicted, where p is the original probability of that class and t is the 
class's threshold.", typeConverter=TypeConverters.toListFloat)
 
     def __init__(self):
         super(HasThresholds, self).__init__()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to