Repository: spark
Updated Branches:
  refs/heads/master eda2800d4 -> 51841d77d


[SPARK-13866] [SQL] Handle decimal type in CSV inference at CSV data source.

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13866

This PR adds the support to infer `DecimalType`.
Here are the rules between `IntegerType`, `LongType` and `DecimalType`.

#### Inferring Types

1. `IntegerType` and then `LongType` are tried first.

  ```scala
  Int.MaxValue => IntegerType
  Long.MaxValue => LongType
  ```

2. If it fails, try `DecimalType`.

  ```scala
  (Long.MaxValue + 1) => DecimalType(20, 0)
  ```
  A value is not inferred as `DecimalType` when its scale is greater than 0.

3. If it fails, try `DoubleType`.
  ```scala
  0.1 => DoubleType // Not inferred as `DecimalType` because it has a scale of 1.
  ```

#### Compatible Types (Merging Types)

Merging types works the same way as in the JSON data source. If `DecimalType`
cannot hold the merged type (for example, the required precision exceeds 38),
the result becomes `DoubleType`.

## How was this patch tested?

Unit tests were added, and `./dev/run_tests` was run for the code style check.

Author: hyukjinkwon <gurwls...@gmail.com>
Author: Hyukjin Kwon <gurwls...@gmail.com>

Closes #11724 from HyukjinKwon/SPARK-13866.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/51841d77
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/51841d77
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/51841d77

Branch: refs/heads/master
Commit: 51841d77d99a858f8fa1256e923b0364b9b28fa0
Parents: eda2800
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Thu May 12 22:31:14 2016 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu May 12 22:31:14 2016 -0700

----------------------------------------------------------------------
 .../datasources/csv/CSVInferSchema.scala        | 50 +++++++++++++++++++-
 sql/core/src/test/resources/decimal.csv         |  7 +++
 .../datasources/csv/CSVInferSchemaSuite.scala   | 13 ++++-
 .../execution/datasources/csv/CSVSuite.scala    | 15 ++++++
 4 files changed, 81 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/51841d77/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index cfd66af..05c8d8e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.math.BigDecimal
-import java.text.{NumberFormat, SimpleDateFormat}
+import java.text.NumberFormat
 import java.util.Locale
 
 import scala.util.control.Exception._
@@ -85,6 +85,7 @@ private[csv] object CSVInferSchema {
         case NullType => tryParseInteger(field, options)
         case IntegerType => tryParseInteger(field, options)
         case LongType => tryParseLong(field, options)
+        case _: DecimalType => tryParseDecimal(field, options)
         case DoubleType => tryParseDouble(field, options)
         case TimestampType => tryParseTimestamp(field, options)
         case BooleanType => tryParseBoolean(field, options)
@@ -107,10 +108,28 @@ private[csv] object CSVInferSchema {
     if ((allCatch opt field.toLong).isDefined) {
       LongType
     } else {
-      tryParseDouble(field, options)
+      tryParseDecimal(field, options)
     }
   }
 
+  private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
+    val decimalTry = allCatch opt {
+      // `BigDecimal` conversion can fail when the `field` is not a form of 
number.
+      val bigDecimal = new BigDecimal(field)
+      // Because many other formats do not support decimal, it reduces the 
cases for
+      // decimals by disallowing values having scale (eg. `1.1`).
+      if (bigDecimal.scale <= 0) {
+        // `DecimalType` conversion can fail when
+        //   1. The precision is bigger than 38.
+        //   2. scale is bigger than precision.
+        DecimalType(bigDecimal.precision, bigDecimal.scale)
+      } else {
+        tryParseDouble(field, options)
+      }
+    }
+    decimalTry.getOrElse(tryParseDouble(field, options))
+  }
+
   private def tryParseDouble(field: String, options: CSVOptions): DataType = {
     if ((allCatch opt field.toDouble).isDefined) {
       DoubleType
@@ -170,6 +189,33 @@ private[csv] object CSVInferSchema {
       val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
       Some(numericPrecedence(index))
 
+    // These two cases below deal with when `DecimalType` is larger than 
`IntegralType`.
+    case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
+      Some(t2)
+    case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
+      Some(t1)
+
+    // These two cases below deal with when `IntegralType` is larger than 
`DecimalType`.
+    case (t1: IntegralType, t2: DecimalType) =>
+      findTightestCommonType(DecimalType.forType(t1), t2)
+    case (t1: DecimalType, t2: IntegralType) =>
+      findTightestCommonType(t1, DecimalType.forType(t2))
+
+    // Double support larger range than fixed decimal, DecimalType.Maximum 
should be enough
+    // in most case, also have better precision.
+    case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+      Some(DoubleType)
+
+    case (t1: DecimalType, t2: DecimalType) =>
+      val scale = math.max(t1.scale, t2.scale)
+      val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+      if (range + scale > 38) {
+        // DecimalType can't support precision > 38
+        Some(DoubleType)
+      } else {
+        Some(DecimalType(range + scale, scale))
+      }
+
     case _ => None
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/51841d77/sql/core/src/test/resources/decimal.csv
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/decimal.csv 
b/sql/core/src/test/resources/decimal.csv
new file mode 100644
index 0000000..870f6aa
--- /dev/null
+++ b/sql/core/src/test/resources/decimal.csv
@@ -0,0 +1,7 @@
+~ decimal field has integer, integer and decimal values. The last value cannot 
fit to a long
+~ long field has integer, long and integer values.
+~ double field has double, double and decimal values.
+decimal,long,double
+1,1,0.1
+1,9223372036854775807,1.0
+92233720368547758070,1,92233720368547758070

http://git-wip-us.apache.org/repos/asf/spark/blob/51841d77/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index daf85be..dbe3af4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
-import java.text.SimpleDateFormat
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.types._
 
@@ -35,6 +33,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) 
== TimestampType)
     assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
     assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == 
BooleanType)
+
+    val textValueOne = Long.MaxValue.toString + "0"
+    val decimalValueOne = new java.math.BigDecimal(textValueOne)
+    val expectedTypeOne = DecimalType(decimalValueOne.precision, 
decimalValueOne.scale)
+    assert(CSVInferSchema.inferField(NullType, textValueOne, options) == 
expectedTypeOne)
   }
 
   test("String fields types are inferred correctly from other types") {
@@ -49,6 +52,11 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
     assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == 
BooleanType)
     assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == 
BooleanType)
+
+    val textValueOne = Long.MaxValue.toString + "0"
+    val decimalValueOne = new java.math.BigDecimal(textValueOne)
+    val expectedTypeOne = DecimalType(decimalValueOne.precision, 
decimalValueOne.scale)
+    assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == 
expectedTypeOne)
   }
 
   test("Timestamp field types are inferred correctly via custom data format") {
@@ -94,6 +102,7 @@ class CSVInferSchemaSuite extends SparkFunSuite {
     assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
     assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == 
TimestampType)
     assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == 
BooleanType)
+    assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == 
DecimalType(1, 1))
   }
 
   test("Merging Nulltypes should yield Nulltype.") {

http://git-wip-us.apache.org/repos/asf/spark/blob/51841d77/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index ae91e0f..27d6dc9 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -43,6 +43,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with 
SQLTestUtils {
   private val commentsFile = "comments.csv"
   private val disableCommentsFile = "disable_comments.csv"
   private val boolFile = "bool.csv"
+  private val decimalFile = "decimal.csv"
   private val simpleSparseFile = "simple_sparse.csv"
   private val numbersFile = "numbers.csv"
   private val datesFile = "dates.csv"
@@ -133,6 +134,20 @@ class CSVSuite extends QueryTest with SharedSQLContext 
with SQLTestUtils {
     assert(result.schema === expectedSchema)
   }
 
+  test("test inferring decimals") {
+    val result = sqlContext.read
+      .format("csv")
+      .option("comment", "~")
+      .option("header", "true")
+      .option("inferSchema", "true")
+      .load(testFile(decimalFile))
+    val expectedSchema = StructType(List(
+      StructField("decimal", DecimalType(20, 0), nullable = true),
+      StructField("long", LongType, nullable = true),
+      StructField("double", DoubleType, nullable = true)))
+    assert(result.schema === expectedSchema)
+  }
+
   test("test with alternative delimiter and quote") {
     val cars = spark.read
       .format("csv")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to