Repository: spark
Updated Branches:
  refs/heads/master 7151011b3 -> 1dbb725db


[SPARK-16462][SPARK-16460][SPARK-15144][SQL] Make CSV cast null values properly

## Problem

CSV in Spark 2.0.0:
- does not read null values back correctly for certain data types such as 
`BooleanType`, `TimestampType`, and `DateType` -- this is a regression compared to 1.6;
- does not read values matching `options.nullValue` (the empty string by default) as `null`s for 
`StringType` -- this matches 1.6 behavior but leads to problems like 
SPARK-16903.

## What changes were proposed in this pull request?

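This patch reads every value equal to `nullValue` (the empty string by default) back as `null` for all supported types, including `StringType`, whenever the field is nullable.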

## How was this patch tested?

New test cases.

Author: Liwei Lin <lwl...@gmail.com>

Closes #14118 from lw-lin/csv-cast-null.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1dbb725d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1dbb725d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1dbb725d

Branch: refs/heads/master
Commit: 1dbb725dbef30bf7633584ce8efdb573f2d92bca
Parents: 7151011
Author: Liwei Lin <lwl...@gmail.com>
Authored: Sun Sep 18 19:25:58 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Sep 18 19:25:58 2016 +0100

----------------------------------------------------------------------
 python/pyspark/sql/readwriter.py                |   3 +-
 python/pyspark/sql/streaming.py                 |   3 +-
 .../org/apache/spark/sql/DataFrameReader.scala  |   3 +-
 .../datasources/csv/CSVInferSchema.scala        | 108 +++++++++----------
 .../spark/sql/streaming/DataStreamReader.scala  |   3 +-
 .../execution/datasources/csv/CSVSuite.scala    |   2 +-
 .../datasources/csv/CSVTypeCastSuite.scala      |  54 ++++++----
 7 files changed, 93 insertions(+), 83 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1dbb725d/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 3d79e0c..a6860ef 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -329,7 +329,8 @@ class DataFrameReader(OptionUtils):
                                          being read should be skipped. If None 
is set, it uses
                                          the default value, ``false``.
         :param nullValue: sets the string representation of a null value. If 
None is set, it uses
-                          the default value, empty string.
+                          the default value, empty string. Since 2.0.1, this 
``nullValue`` param
+                          applies to all supported types including the string 
type.
         :param nanValue: sets the string representation of a non-number value. 
If None is set, it
                          uses the default value, ``NaN``.
         :param positiveInf: sets the string representation of a positive 
infinity value. If None

http://git-wip-us.apache.org/repos/asf/spark/blob/1dbb725d/python/pyspark/sql/streaming.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index 67375f6..0136451 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -497,7 +497,8 @@ class DataStreamReader(OptionUtils):
                                          being read should be skipped. If None 
is set, it uses
                                          the default value, ``false``.
         :param nullValue: sets the string representation of a null value. If 
None is set, it uses
-                          the default value, empty string.
+                          the default value, empty string. Since 2.0.1, this 
``nullValue`` param
+                          applies to all supported types including the string 
type.
         :param nanValue: sets the string representation of a non-number value. 
If None is set, it
                          uses the default value, ``NaN``.
         :param positiveInf: sets the string representation of a positive 
infinity value. If None

http://git-wip-us.apache.org/repos/asf/spark/blob/1dbb725d/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index d29d90c..30f39c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -376,7 +376,8 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    * from values being read should be skipped.</li>
    * <li>`ignoreTrailingWhiteSpace` (default `false`): defines whether or not 
trailing
    * whitespaces from values being read should be skipped.</li>
-   * <li>`nullValue` (default empty string): sets the string representation of 
a null value.</li>
+   * <li>`nullValue` (default empty string): sets the string representation of 
a null value. Since
+   * 2.0.1, this applies to all supported types including the string type.</li>
    * <li>`nanValue` (default `NaN`): sets the string representation of a 
non-number" value.</li>
    * <li>`positiveInf` (default `Inf`): sets the string representation of a 
positive infinity
    * value.</li>

http://git-wip-us.apache.org/repos/asf/spark/blob/1dbb725d/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 1ca6eff..3ab775c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -232,66 +232,58 @@ private[csv] object CSVTypeCast {
       nullable: Boolean = true,
       options: CSVOptions = CSVOptions()): Any = {
 
-    castType match {
-      case _: ByteType => if (datum == options.nullValue && nullable) null 
else datum.toByte
-      case _: ShortType => if (datum == options.nullValue && nullable) null 
else datum.toShort
-      case _: IntegerType => if (datum == options.nullValue && nullable) null 
else datum.toInt
-      case _: LongType => if (datum == options.nullValue && nullable) null 
else datum.toLong
-      case _: FloatType =>
-        if (datum == options.nullValue && nullable) {
-          null
-        } else if (datum == options.nanValue) {
-          Float.NaN
-        } else if (datum == options.negativeInf) {
-          Float.NegativeInfinity
-        } else if (datum == options.positiveInf) {
-          Float.PositiveInfinity
-        } else {
-          Try(datum.toFloat)
-            
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
-        }
-      case _: DoubleType =>
-        if (datum == options.nullValue && nullable) {
-          null
-        } else if (datum == options.nanValue) {
-          Double.NaN
-        } else if (datum == options.negativeInf) {
-          Double.NegativeInfinity
-        } else if (datum == options.positiveInf) {
-          Double.PositiveInfinity
-        } else {
-          Try(datum.toDouble)
-            
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
-        }
-      case _: BooleanType => datum.toBoolean
-      case dt: DecimalType =>
-        if (datum == options.nullValue && nullable) {
-          null
-        } else {
-          val value = new BigDecimal(datum.replaceAll(",", ""))
-          Decimal(value, dt.precision, dt.scale)
-        }
-      case _: TimestampType =>
-        // This one will lose microseconds parts.
-        // See https://issues.apache.org/jira/browse/SPARK-10681.
-        Try(options.timestampFormat.parse(datum).getTime * 1000L)
-          .getOrElse {
-            // If it fails to parse, then tries the way used in 2.0 and 1.x 
for backwards
-            // compatibility.
-            DateTimeUtils.stringToTime(datum).getTime  * 1000L
+    if (nullable && datum == options.nullValue) {
+      null
+    } else {
+      castType match {
+        case _: ByteType => datum.toByte
+        case _: ShortType => datum.toShort
+        case _: IntegerType => datum.toInt
+        case _: LongType => datum.toLong
+        case _: FloatType =>
+          datum match {
+            case options.nanValue => Float.NaN
+            case options.negativeInf => Float.NegativeInfinity
+            case options.positiveInf => Float.PositiveInfinity
+            case _ =>
+              Try(datum.toFloat)
+                
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).floatValue())
           }
-      case _: DateType =>
-        // This one will lose microseconds parts.
-        // See https://issues.apache.org/jira/browse/SPARK-10681.x
-        
Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
-          .getOrElse {
-            // If it fails to parse, then tries the way used in 2.0 and 1.x 
for backwards
-            // compatibility.
-            
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+        case _: DoubleType =>
+          datum match {
+            case options.nanValue => Double.NaN
+            case options.negativeInf => Double.NegativeInfinity
+            case options.positiveInf => Double.PositiveInfinity
+            case _ =>
+              Try(datum.toDouble)
+                
.getOrElse(NumberFormat.getInstance(Locale.getDefault).parse(datum).doubleValue())
           }
-      case _: StringType => UTF8String.fromString(datum)
-      case udt: UserDefinedType[_] => castTo(datum, udt.sqlType, nullable, 
options)
-      case _ => throw new RuntimeException(s"Unsupported type: 
${castType.typeName}")
+        case _: BooleanType => datum.toBoolean
+        case dt: DecimalType =>
+          val value = new BigDecimal(datum.replaceAll(",", ""))
+          Decimal(value, dt.precision, dt.scale)
+        case _: TimestampType =>
+          // This one will lose microseconds parts.
+          // See https://issues.apache.org/jira/browse/SPARK-10681.
+          Try(options.timestampFormat.parse(datum).getTime * 1000L)
+            .getOrElse {
+              // If it fails to parse, then tries the way used in 2.0 and 1.x 
for backwards
+              // compatibility.
+              DateTimeUtils.stringToTime(datum).getTime * 1000L
+            }
+        case _: DateType =>
+          // This one will lose microseconds parts.
+          // See https://issues.apache.org/jira/browse/SPARK-10681.x
+          
Try(DateTimeUtils.millisToDays(options.dateFormat.parse(datum).getTime))
+            .getOrElse {
+              // If it fails to parse, then tries the way used in 2.0 and 1.x 
for backwards
+              // compatibility.
+              
DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(datum).getTime)
+            }
+        case _: StringType => UTF8String.fromString(datum)
+        case udt: UserDefinedType[_] => castTo(datum, udt.sqlType, nullable, 
options)
+        case _ => throw new RuntimeException(s"Unsupported type: 
${castType.typeName}")
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1dbb725d/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index c25f71a..9d17405 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -232,7 +232,8 @@ final class DataStreamReader private[sql](sparkSession: 
SparkSession) extends Lo
    * from values being read should be skipped.</li>
    * <li>`ignoreTrailingWhiteSpace` (default `false`): defines whether or not 
trailing
    * whitespaces from values being read should be skipped.</li>
-   * <li>`nullValue` (default empty string): sets the string representation of 
a null value.</li>
+   * <li>`nullValue` (default empty string): sets the string representation of 
a null value. Since
+   * 2.0.1, this applies to all supported types including the string type.</li>
    * <li>`nanValue` (default `NaN`): sets the string representation of a 
non-number" value.</li>
    * <li>`positiveInf` (default `Inf`): sets the string representation of a 
positive infinity
    * value.</li>

http://git-wip-us.apache.org/repos/asf/spark/blob/1dbb725d/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 1930862..29aac9d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -554,7 +554,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with 
SQLTestUtils {
 
     verifyCars(cars, withHeader = true, checkValues = false)
     val results = cars.collect()
-    assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null"))
+    assert(results(0).toSeq === Array(2012, "Tesla", "S", null, null))
     assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1dbb725d/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
index 3ce643e..dae92f6 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVTypeCastSuite.scala
@@ -68,16 +68,46 @@ class CSVTypeCastSuite extends SparkFunSuite {
   }
 
   test("Nullable types are handled") {
-    assert(CSVTypeCast.castTo("", IntegerType, nullable = true, CSVOptions()) 
== null)
+    assertNull(
+      CSVTypeCast.castTo("-", ByteType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", ShortType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", IntegerType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", LongType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", FloatType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", DoubleType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", BooleanType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", TimestampType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", DateType, nullable = true, 
CSVOptions("nullValue", "-")))
+    assertNull(
+      CSVTypeCast.castTo("-", StringType, nullable = true, 
CSVOptions("nullValue", "-")))
   }
 
-  test("String type should always return the same as the input") {
+  test("String type should also respect `nullValue`") {
+    assertNull(
+      CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()))
     assert(
-      CSVTypeCast.castTo("", StringType, nullable = true, CSVOptions()) ==
+      CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) ==
         UTF8String.fromString(""))
+
     assert(
-      CSVTypeCast.castTo("", StringType, nullable = false, CSVOptions()) ==
+      CSVTypeCast.castTo("", StringType, nullable = true, 
CSVOptions("nullValue", "null")) ==
+        UTF8String.fromString(""))
+    assert(
+      CSVTypeCast.castTo("", StringType, nullable = false, 
CSVOptions("nullValue", "null")) ==
         UTF8String.fromString(""))
+
+    assertNull(
+      CSVTypeCast.castTo(null, StringType, nullable = true, 
CSVOptions("nullValue", "null")))
   }
 
   test("Throws exception for empty string with non null type") {
@@ -170,20 +200,4 @@ class CSVTypeCastSuite extends SparkFunSuite {
     assert(doubleVal2 == Double.PositiveInfinity)
   }
 
-  test("Type-specific null values are used for casting") {
-    assertNull(
-      CSVTypeCast.castTo("-", ByteType, nullable = true, 
CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", ShortType, nullable = true, 
CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", IntegerType, nullable = true, 
CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", LongType, nullable = true, 
CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", FloatType, nullable = true, 
CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", DoubleType, nullable = true, 
CSVOptions("nullValue", "-")))
-    assertNull(
-      CSVTypeCast.castTo("-", DecimalType.DoubleDecimal, true, 
CSVOptions("nullValue", "-")))
-  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to