This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 00e6acc  [SPARK-31676][ML] QuantileDiscretizer raise error parameter 
splits given invalid value (splits array includes -0.0 and 0.0)
00e6acc is described below

commit 00e6acc9c6d45c5dd3b3f70c87909743a8073dba
Author: Weichen Xu <[email protected]>
AuthorDate: Thu May 14 09:24:40 2020 -0500

    [SPARK-31676][ML] QuantileDiscretizer raise error parameter splits given 
invalid value (splits array includes -0.0 and 0.0)
    
    ### What changes were proposed in this pull request?
    
    In QuantileDiscretizer.getDistinctSplits, before invoking `distinct`, normalize every -0.0 split value to 0.0:
    ```
        for (i <- 0 until splits.length) {
          if (splits(i) == -0.0) {
            splits(i) = 0.0
          }
        }
    ```
    ### Why are the changes needed?
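    Fixes a bug where QuantileDiscretizer.fit fails with an invalid `splits` parameter when the computed splits contain both -0.0 and 0.0.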
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test.
    
    #### Manual test:
    
    ~~~scala
    import scala.util.Random
    val rng = new Random(3)
    
    val a1 = Array.tabulate(200)(_=>rng.nextDouble * 2.0 - 1.0) ++ 
Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
    
    import spark.implicits._
    val df1 = sc.parallelize(a1, 2).toDF("id")
    
    import org.apache.spark.ml.feature.QuantileDiscretizer
    val qd = new 
QuantileDiscretizer().setInputCol("id").setOutputCol("out").setNumBuckets(200).setRelativeError(0.0)
    
    val model = qd.fit(df1) // will raise error in spark master.
    ~~~
    
    ### Explain
    In Scala, `0.0 == -0.0` is true, but `0.0.hashCode == -0.0.hashCode()` is false. 
This breaks the contract between equals() and hashCode(): if two objects are 
equal, they must have the same hash code.
    
    Because `Array.distinct` relies on each element's hash code, it can keep both -0.0 and 0.0 as separate values, which leads to this error.
    
    Test code for `distinct`:
    ```
    import scala.util.Random
    val rng = new Random(3)
    
    val a1 = Array.tabulate(200)(_=>rng.nextDouble * 2.0 - 1.0) ++ 
Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
    a1.distinct.sorted.foreach(x => print(x.toString + "\n"))
    ```
    
    Then you will see output like:
    ```
    ...
    -0.009292684662246975
    -0.0033280686465135823
    -0.0
    0.0
    0.0022219556032221366
    0.02217419561977274
    ...
    ```
    
    Closes #28498 from WeichenXu123/SPARK-31676.
    
    Authored-by: Weichen Xu <[email protected]>
    Signed-off-by: Sean Owen <[email protected]>
    (cherry picked from commit b2300fca1e1a22d74c6eeda37942920a6c6299ff)
    Signed-off-by: Sean Owen <[email protected]>
---
 .../apache/spark/ml/feature/QuantileDiscretizer.scala  | 12 ++++++++++++
 .../spark/ml/feature/QuantileDiscretizerSuite.scala    | 18 ++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 216d99d..4eedfc4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -236,6 +236,18 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
   private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
+
+    // 0.0 and -0.0 are distinct values, array.distinct will preserve both of 
them.
+    // but 0.0 > -0.0 is False which will break the parameter validation 
checking.
+    // and in scala <= 2.12, there's bug which will cause array.distinct 
generate
+    // non-deterministic results when array contains both 0.0 and -0.0
+    // So that here we should first normalize all 0.0 and -0.0 to be 0.0
+    // See https://github.com/scala/bug/issues/11995
+    for (i <- 0 until splits.length) {
+      if (splits(i) == -0.0) {
+        splits(i) = 0.0
+      }
+    }
     val distinctSplits = splits.distinct
     if (splits.length != distinctSplits.length) {
       log.warn(s"Some quantiles were identical. Bucketing to 
${distinctSplits.length - 1}" +
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 6f6ab26..682b87a 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -512,4 +512,22 @@ class QuantileDiscretizerSuite extends MLTest with 
DefaultReadWriteTest {
     assert(observedNumBuckets === numBuckets,
       "Observed number of buckets does not equal expected number of buckets.")
   }
+
+  test("[SPARK-31676] QuantileDiscretizer raise error parameter splits given 
invalid value") {
+    import scala.util.Random
+    val rng = new Random(3)
+
+    val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++
+      Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
+
+    val df1 = sc.parallelize(a1, 2).toDF("id")
+
+    val qd = new QuantileDiscretizer()
+      .setInputCol("id")
+      .setOutputCol("out")
+      .setNumBuckets(200)
+      .setRelativeError(0.0)
+
+    qd.fit(df1) // assert no exception raised here.
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to