Repository: spark
Updated Branches:
  refs/heads/branch-2.0 87be9652b -> 735e2039a


[SPARK-18555][SQL] DataFrameNaFunctions.fill messes up original values in long
integers

## What changes were proposed in this pull request?

   When `Dataset.na.fill(0)` is used on a Dataset that has a long-valued column,
it changes the original long values.

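   The reason is that the `value` parameter of `fill` has type Double, so the
numeric columns are always cast to double (`fillCol[Double](f, value)`). A double
has only 53 bits of precision, so a long whose magnitude exceeds 2^53 cannot
always be represented exactly and is rounded by the cast.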
```
  def fill(value: Double, cols: Seq[String]): DataFrame = {
    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
    val projections = df.schema.fields.map { f =>
      // Only fill if the column is part of the cols list.
      if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => 
columnEquals(f.name, col))) {
        fillCol[Double](f, value)
      } else {
        df.col(f.name)
      }
    }
    df.select(projections : _*)
  }
```

 For example:
```
scala> val df = Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 
9123146560113991650L)).toDF("a", "b")
df: org.apache.spark.sql.DataFrame = [a: bigint, b: bigint]

scala> df.show
+-------------------+-------------------+
|                  a|                  b|
+-------------------+-------------------+
|                  1|                  2|
|                 -1|                 -2|
|9123146099426677101|9123146560113991650|
+-------------------+-------------------+

scala> df.na.fill(0).show
+-------------------+-------------------+
|                  a|                  b|
+-------------------+-------------------+
|                  1|                  2|
|                 -1|                 -2|
|9123146099426676736|9123146560113991680|
+-------------------+-------------------+
 ```

The original values are changed, which is not the expected result:
```
 9123146099426677101 -> 9123146099426676736
 9123146560113991650 -> 9123146560113991680
```

## How was this patch tested?

A unit test was added.

Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)>

Closes #15994 from windpiger/nafillMissupOriginalValue.

(cherry picked from commit 508de38c9928d160cf70e8e7d69ddb1dca5c1a64)
Signed-off-by: DB Tsai <dbt...@dbtsai.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/735e2039
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/735e2039
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/735e2039

Branch: refs/heads/branch-2.0
Commit: 735e2039a5095ecaeac6e987ac3738136e4e4849
Parents: 87be965
Author: root <root@iZbp1gsnrlfzjxh82cz80vZ.(none)>
Authored: Mon Dec 5 18:39:56 2016 -0800
Committer: DB Tsai <dbt...@dbtsai.com>
Committed: Tue Apr 11 00:11:29 2017 +0000

----------------------------------------------------------------------
 .../apache/spark/sql/DataFrameNaFunctions.scala | 89 ++++++++++++++------
 .../spark/sql/DataFrameNaFunctionsSuite.scala   | 18 ++++
 2 files changed, 80 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/735e2039/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index ad00966..3df8bb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -129,6 +129,12 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
   /**
    * Returns a new [[DataFrame]] that replaces null or NaN values in numeric 
columns with `value`.
    *
+   * @since 2.2.0
+   */
+  def fill(value: Long): DataFrame = fill(value, df.columns)
+
+  /**
+   * Returns a new `DataFrame` that replaces null or NaN values in numeric 
columns with `value`.
    * @since 1.3.1
    */
   def fill(value: Double): DataFrame = fill(value, df.columns)
@@ -144,6 +150,14 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    * Returns a new [[DataFrame]] that replaces null or NaN values in specified 
numeric columns.
    * If a specified column is not a numeric column, it is ignored.
    *
+   * @since 2.2.0
+   */
+  def fill(value: Long, cols: Array[String]): DataFrame = fill(value, 
cols.toSeq)
+
+  /**
+   * Returns a new `DataFrame` that replaces null or NaN values in specified 
numeric columns.
+   * If a specified column is not a numeric column, it is ignored.
+   *
    * @since 1.3.1
    */
   def fill(value: Double, cols: Array[String]): DataFrame = fill(value, 
cols.toSeq)
@@ -152,20 +166,18 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    * (Scala-specific) Returns a new [[DataFrame]] that replaces null or NaN 
values in specified
    * numeric columns. If a specified column is not a numeric column, it is 
ignored.
    *
+   * @since 2.2.0
+   */
+  def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
+
+  /**
+   * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN 
values in specified
+   * numeric columns. If a specified column is not a numeric column, it is 
ignored.
+   *
    * @since 1.3.1
    */
-  def fill(value: Double, cols: Seq[String]): DataFrame = {
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      // Only fill if the column is part of the cols list.
-      if (f.dataType.isInstanceOf[NumericType] && cols.exists(col => 
columnEquals(f.name, col))) {
-        fillCol[Double](f, value)
-      } else {
-        df.col(f.name)
-      }
-    }
-    df.select(projections : _*)
-  }
+  def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, 
cols)
+
 
   /**
    * Returns a new [[DataFrame]] that replaces null values in specified string 
columns.
@@ -181,18 +193,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    *
    * @since 1.3.1
    */
-  def fill(value: String, cols: Seq[String]): DataFrame = {
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      // Only fill if the column is part of the cols list.
-      if (f.dataType.isInstanceOf[StringType] && cols.exists(col => 
columnEquals(f.name, col))) {
-        fillCol[String](f, value)
-      } else {
-        df.col(f.name)
-      }
-    }
-    df.select(projections : _*)
-  }
+  def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, 
cols)
 
   /**
    * Returns a new [[DataFrame]] that replaces null values.
@@ -211,7 +212,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    *
    * @since 1.3.1
    */
-  def fill(valueMap: java.util.Map[String, Any]): DataFrame = 
fill0(valueMap.asScala.toSeq)
+  def fill(valueMap: java.util.Map[String, Any]): DataFrame = 
fillMap(valueMap.asScala.toSeq)
 
   /**
    * (Scala-specific) Returns a new [[DataFrame]] that replaces null values.
@@ -231,7 +232,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
    *
    * @since 1.3.1
    */
-  def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
+  def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)
 
   /**
    * Replaces values matching keys in `replacement` map with the corresponding 
values.
@@ -369,7 +370,7 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
     df.select(projections : _*)
   }
 
-  private def fill0(values: Seq[(String, Any)]): DataFrame = {
+  private def fillMap(values: Seq[(String, Any)]): DataFrame = {
     // Error handling
     values.foreach { case (colName, replaceValue) =>
       // Check column name exists
@@ -436,4 +437,38 @@ final class DataFrameNaFunctions private[sql](df: 
DataFrame) {
     case v => throw new IllegalArgumentException(
       s"Unsupported value type ${v.getClass.getName} ($v).")
   }
+
+  /**
+   * Returns a new `DataFrame` that replaces null or NaN values in specified
+   * numeric, string columns. If a specified column is not a numeric, string 
column,
+   * it is ignored.
+   */
+  private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
+    // the fill[T] which T is  Long/Double,
+    // should apply on all the NumericType Column, for example:
+    // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 
164.3)).toDF("a","b")
+    // input.na.fill(3.1)
+    // the result is (3,164.3), not (null, 164.3)
+    val targetType = value match {
+      case _: Double | _: Long => NumericType
+      case _: String => StringType
+      case _ => throw new IllegalArgumentException(
+        s"Unsupported value type ${value.getClass.getName} ($value).")
+    }
+
+    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
+    val projections = df.schema.fields.map { f =>
+      val typeMatches = (targetType, f.dataType) match {
+        case (NumericType, dt) => dt.isInstanceOf[NumericType]
+        case (StringType, dt) => dt == StringType
+      }
+      // Only fill if the column is part of the cols list.
+      if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
+        fillCol[T](f, value)
+      } else {
+        df.col(f.name)
+      }
+    }
+    df.select(projections : _*)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/735e2039/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 47b55e2..fd82984 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -138,6 +138,24 @@ class DataFrameNaFunctionsSuite extends QueryTest with 
SharedSQLContext {
     checkAnswer(
       Seq[(String, String)]((null, null)).toDF("col1", "col2").na.fill("test", 
"col1" :: Nil),
       Row("test", null))
+
+    checkAnswer(
+      Seq[(Long, Long)]((1, 2), (-1, -2), (9123146099426677101L, 
9123146560113991650L))
+        .toDF("a", "b").na.fill(0),
+      Row(1, 2) :: Row(-1, -2) :: Row(9123146099426677101L, 
9123146560113991650L) :: Nil
+    )
+
+    checkAnswer(
+      Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 
3.45))
+        .toDF("a", "b").na.fill(2.34),
+      Row(2, 1.23) :: Row(3, 2.34) :: Row(4, 3.45) :: Nil
+    )
+
+    checkAnswer(
+      Seq[(java.lang.Long, java.lang.Double)]((null, 1.23), (3L, null), (4L, 
3.45))
+        .toDF("a", "b").na.fill(5),
+      Row(5, 1.23) :: Row(3, 5.0) :: Row(4, 3.45) :: Nil
+    )
   }
 
   test("fill with map") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to