Repository: spark
Updated Branches:
  refs/heads/master bdb5e55c2 -> 874350905


[SPARK-22700][ML] Bucketizer.transform incorrectly drops row containing NaN

## What changes were proposed in this pull request?
Only drop the rows containing NaN in the input columns.

## How was this patch tested?
Existing tests and newly added tests.

Author: Ruifeng Zheng <[email protected]>
Author: Zheng RuiFeng <[email protected]>

Closes #19894 from zhengruifeng/bucketizer_nan.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/87435090
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/87435090
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/87435090

Branch: refs/heads/master
Commit: 874350905ff41ffdd8dc07548aa3c4f5782a3b35
Parents: bdb5e55
Author: Ruifeng Zheng <[email protected]>
Authored: Wed Dec 13 09:10:03 2017 +0200
Committer: Nick Pentreath <[email protected]>
Committed: Wed Dec 13 09:10:03 2017 +0200

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Bucketizer.scala      | 14 ++++++++------
 .../org/apache/spark/ml/feature/BucketizerSuite.scala |  9 +++++++++
 2 files changed, 17 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/87435090/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index e07f2a1..8299a3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -155,10 +155,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema)
 
+    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+      ($(inputCols).toSeq, $(outputCols).toSeq)
+    } else {
+      (Seq($(inputCol)), Seq($(outputCol)))
+    }
+
     val (filteredDataset, keepInvalid) = {
       if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
         // "skip" NaN option is set, will filter out NaN values in the dataset
-        (dataset.na.drop().toDF(), false)
+        (dataset.na.drop(inputColumns).toDF(), false)
       } else {
         (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
       }
@@ -176,11 +182,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") 
override val uid: String
       }.withName(s"bucketizer_$idx")
     }
 
-    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
-      ($(inputCols).toSeq, $(outputCols).toSeq)
-    } else {
-      (Seq($(inputCol)), Seq($(outputCol)))
-    }
+
     val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
       bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/87435090/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 748dbd1..d9c97ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -123,6 +123,15 @@ class BucketizerSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defa
     }
   }
 
+  test("Bucketizer should only drop NaN in input columns, with 
handleInvalid=skip") {
+    val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, 
Double.NaN)))
+      .toDF("a", "b")
+    val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity)
+    val bucketizer = new 
Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits)
+    bucketizer.setHandleInvalid("skip")
+    assert(bucketizer.transform(df).count() == 2)
+  }
+
   test("Bucket continuous features, with NaN splits") {
     val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, 
Double.PositiveInfinity, Double.NaN)
     withClue("Invalid NaN split was not caught during Bucketizer 
initialization") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to