Repository: spark
Updated Branches:
  refs/heads/master bdb5e55c2 -> 874350905
[SPARK-22700][ML] Bucketizer.transform incorrectly drops row containing NaN

## What changes were proposed in this pull request?
Only drop rows that contain NaN in the input columns, rather than rows that contain NaN in any column.

## How was this patch tested?
Existing tests and added tests.

Author: Ruifeng Zheng <[email protected]>
Author: Zheng RuiFeng <[email protected]>

Closes #19894 from zhengruifeng/bucketizer_nan.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/87435090
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/87435090
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/87435090

Branch: refs/heads/master
Commit: 874350905ff41ffdd8dc07548aa3c4f5782a3b35
Parents: bdb5e55
Author: Ruifeng Zheng <[email protected]>
Authored: Wed Dec 13 09:10:03 2017 +0200
Committer: Nick Pentreath <[email protected]>
Committed: Wed Dec 13 09:10:03 2017 +0200

----------------------------------------------------------------------
 .../org/apache/spark/ml/feature/Bucketizer.scala      | 14 ++++++++------
 .../org/apache/spark/ml/feature/BucketizerSuite.scala |  9 +++++++++
 2 files changed, 17 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/87435090/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index e07f2a1..8299a3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -155,10 +155,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema)
 
+    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+      ($(inputCols).toSeq, $(outputCols).toSeq)
+    } else {
+      (Seq($(inputCol)), Seq($(outputCol)))
+    }
+
     val (filteredDataset, keepInvalid) = {
       if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
         // "skip" NaN option is set, will filter out NaN values in the dataset
-        (dataset.na.drop().toDF(), false)
+        (dataset.na.drop(inputColumns).toDF(), false)
       } else {
         (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
       }
@@ -176,11 +182,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       }.withName(s"bucketizer_$idx")
     }
 
-    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
-      ($(inputCols).toSeq, $(outputCols).toSeq)
-    } else {
-      (Seq($(inputCol)), Seq($(outputCol)))
-    }
+
    val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
       bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
     }
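
The fix hinges on DataFrameNaFunctions.drop accepting an explicit column list: dataset.na.drop() removes a row if any column holds null/NaN, while dataset.na.drop(cols) only consults the named columns. Below is a minimal spark-shell sketch of that difference, reusing the two-column data from the test added further down; the local SparkSession setup is illustrative and not part of the patch.

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[1]").appName("na-drop-sketch").getOrCreate()
  import spark.implicits._

  // NaN in input column "a" for the second row; NaN only in the non-input column "b" for the third.
  val df = Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN)).toDF("a", "b")

  df.na.drop().count()          // 1 -- old behavior: a NaN anywhere drops the row
  df.na.drop(Seq("a")).count()  // 2 -- what the fix relies on: only a NaN in "a" drops the row
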
http://git-wip-us.apache.org/repos/asf/spark/blob/87435090/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 748dbd1..d9c97ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -123,6 +123,15 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
   }
 
+  test("Bucketizer should only drop NaN in input columns, with handleInvalid=skip") {
+    val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN)))
+      .toDF("a", "b")
+    val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity)
+    val bucketizer = new Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits)
+    bucketizer.setHandleInvalid("skip")
+    assert(bucketizer.transform(df).count() == 2)
+  }
+
   test("Bucket continuous features, with NaN splits") {
     val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
     withClue("Invalid NaN split was not caught during Bucketizer initialization") {
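
For completeness, here is an illustrative, end-to-end sketch of the behavior the new test asserts, exercised through the multi-column setters that the patched transform also covers via $(inputCols)/$(outputCols). The column names, the extra "id" column, and the local SparkSession are assumptions for the example, not taken from the patch.

  import org.apache.spark.ml.feature.Bucketizer
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[1]").appName("bucketizer-skip-sketch").getOrCreate()
  import spark.implicits._

  // "id" is not an input column, so its NaN should no longer cause the first row to be skipped.
  val df = Seq((2.3, 3.0, Double.NaN), (Double.NaN, 3.0, 2.0), (6.7, 4.0, 3.0)).toDF("a", "b", "id")
  val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity)

  val bucketizer = new Bucketizer()
    .setInputCols(Array("a", "b"))
    .setOutputCols(Array("x", "y"))
    .setSplitsArray(Array(splits, splits))
    .setHandleInvalid("skip")

  // Only the second row (NaN in input column "a") is dropped.
  bucketizer.transform(df).count()  // 2 -- before the fix this was 1, since the NaN in "id" also triggered a drop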
