Repository: spark Updated Branches: refs/heads/branch-2.2 a95c3e29d -> 1cc34f3e5
[SPARK-22700][ML] Bucketizer.transform incorrectly drops row containing NaN - for branch-2.2 ## What changes were proposed in this pull request? For branch-2.2: when handleInvalid is "skip", only drop rows that contain NaN in the input column, instead of dropping rows that contain NaN in any column. ## How was this patch tested? Existing tests, plus a new test in BucketizerSuite. Author: Zheng RuiFeng <ruife...@foxmail.com> Closes #20539 from zhengruifeng/bucketizer_nan_2.2. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1cc34f3e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1cc34f3e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1cc34f3e Branch: refs/heads/branch-2.2 Commit: 1cc34f3e58c92dd06545727e9d931008a1082bbf Parents: a95c3e2 Author: Zheng RuiFeng <ruife...@foxmail.com> Authored: Wed Feb 21 17:26:33 2018 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Wed Feb 21 17:26:33 2018 -0800 ---------------------------------------------------------------------- .../main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 2 +- .../scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1cc34f3e/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index bb8f2a3..f585ff0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -106,7 +106,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String val (filteredDataset, keepInvalid) = { if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN 
values in the dataset - (dataset.na.drop().toDF(), false) + (dataset.na.drop(Seq($(inputCol))).toDF(), false) } else { (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) } http://git-wip-us.apache.org/repos/asf/spark/blob/1cc34f3e/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 420fb17..32e50a9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -187,6 +187,15 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } } + + test("Bucketizer should only drop NaN in input columns, with handleInvalid=skip") { + val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN))) + .toDF("a", "b") + val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity) + val bucketizer = new Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits) + bucketizer.setHandleInvalid("skip") + assert(bucketizer.transform(df).count() == 2) + } } private object BucketizerSuite extends SparkFunSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org