Repository: spark
Updated Branches:
  refs/heads/master 59741887e -> f97326bcd


[SPARK-25977][SQL] Parsing decimals from CSV using locale

## What changes were proposed in this pull request?

In this PR, I propose using the locale option to parse decimals from CSV
input. After the changes, `UnivocityParser` converts the input string to
`BigDecimal`, and then to Spark's `Decimal`, by using `java.text.DecimalFormat`.
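
For illustration, here is a minimal sketch of such a locale-aware parser. The actual helper added by this PR is `ExprUtils.getDecimalParser`; the `localeDecimalParser` function below is a hypothetical stand-in:

```scala
import java.math.BigDecimal
import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
import java.util.Locale

// Hypothetical sketch of a locale-aware String => BigDecimal parser.
def localeDecimalParser(locale: Locale): String => BigDecimal = {
  val df = new DecimalFormat("", new DecimalFormatSymbols(locale))
  df.setParseBigDecimal(true) // parse() now returns java.math.BigDecimal
  (s: String) => {
    val pos = new ParsePosition(0)
    val result = df.parse(s, pos).asInstanceOf[BigDecimal]
    // Reject outright failures and partial matches such as "1,000xyz".
    if (result == null || pos.getIndex != s.length) {
      throw new IllegalArgumentException(s"Cannot parse '$s' as a decimal")
    }
    result
  }
}

// de-DE uses '.' for grouping and ',' as the decimal separator:
localeDecimalParser(Locale.forLanguageTag("de-DE"))("1.000,001") // 1000.001
```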

## How was this patch tested?

Added tests for the `en-US`, `ko-KR`, `ru-RU` and `de-DE` locales.
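
To see why the locale matters, the same value renders differently per locale. A quick illustration with JDK `NumberFormat` defaults (not output from the test suite; exact separators depend on the JDK's locale data):

```scala
import java.text.NumberFormat
import java.util.Locale

// Format 1000.001 the way each tested locale writes it.
Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach { tag =>
  val nf = NumberFormat.getInstance(Locale.forLanguageTag(tag))
  println(s"$tag -> ${nf.format(1000.001)}")
}
// Typically: en-US -> 1,000.001, ko-KR -> 1,000.001,
// ru-RU -> 1 000,001 (non-breaking space), de-DE -> 1.000,001
```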

Closes #22979 from MaxGekk/decimal-parsing-locale.

Lead-authored-by: Maxim Gekk <maxim.g...@databricks.com>
Co-authored-by: Maxim Gekk <max.g...@gmail.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f97326bc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f97326bc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f97326bc

Branch: refs/heads/master
Commit: f97326bcdba532eabf25d4899b13709e9af2bfea
Parents: 5974188
Author: Maxim Gekk <maxim.g...@databricks.com>
Authored: Fri Nov 30 08:27:55 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Fri Nov 30 08:27:55 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/csv/CSVExprUtils.scala   |   4 +
 .../spark/sql/catalyst/csv/CSVInferSchema.scala |  72 ++++-----
 .../sql/catalyst/csv/UnivocityParser.scala      |   8 +-
 .../catalyst/expressions/csvExpressions.scala   |   5 +-
 .../sql/catalyst/csv/CSVInferSchemaSuite.scala  | 147 ++++++++++++-------
 .../sql/catalyst/csv/UnivocityParserSuite.scala |  22 ++-
 .../datasources/csv/CSVDataSource.scala         |   4 +-
 7 files changed, 168 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
index bbe2783..6c982a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtils.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.catalyst.csv
 
+import java.math.BigDecimal
+import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
+import java.util.Locale
+
 object CSVExprUtils {
   /**
    * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).

http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index 799e999..94cb4b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -17,16 +17,19 @@
 
 package org.apache.spark.sql.catalyst.csv
 
-import java.math.BigDecimal
-
 import scala.util.control.Exception.allCatch
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
-object CSVInferSchema {
+class CSVInferSchema(options: CSVOptions) extends Serializable {
+
+  private val decimalParser = {
+    ExprUtils.getDecimalParser(options.locale)
+  }
 
   /**
    * Similar to the JSON schema inference
@@ -36,14 +39,13 @@ object CSVInferSchema {
    */
   def infer(
       tokenRDD: RDD[Array[String]],
-      header: Array[String],
-      options: CSVOptions): StructType = {
+      header: Array[String]): StructType = {
     val fields = if (options.inferSchemaFlag) {
       val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
       val rootTypes: Array[DataType] =
-        tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes)
+        tokenRDD.aggregate(startType)(inferRowType, mergeRowTypes)
 
-      toStructFields(rootTypes, header, options)
+      toStructFields(rootTypes, header)
     } else {
       // By default fields are assumed to be StringType
       header.map(fieldName => StructField(fieldName, StringType, nullable = true))
@@ -54,8 +56,7 @@ object CSVInferSchema {
 
   def toStructFields(
       fieldTypes: Array[DataType],
-      header: Array[String],
-      options: CSVOptions): Array[StructField] = {
+      header: Array[String]): Array[StructField] = {
     header.zip(fieldTypes).map { case (thisHeader, rootType) =>
       val dType = rootType match {
         case _: NullType => StringType
@@ -65,11 +66,10 @@ object CSVInferSchema {
     }
   }
 
-  def inferRowType(options: CSVOptions)
-      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
+  def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) {  // May have columns on right missing.
-      rowSoFar(i) = inferField(rowSoFar(i), next(i), options)
+      rowSoFar(i) = inferField(rowSoFar(i), next(i))
       i+=1
     }
     rowSoFar
@@ -85,20 +85,20 @@ object CSVInferSchema {
    * Infer type of string field. Given known type Double, and a string "1", there is no
    * point checking if it is an Int, as the final type must be Double or higher.
    */
-  def inferField(typeSoFar: DataType, field: String, options: CSVOptions): DataType = {
+  def inferField(typeSoFar: DataType, field: String): DataType = {
     if (field == null || field.isEmpty || field == options.nullValue) {
       typeSoFar
     } else {
       typeSoFar match {
-        case NullType => tryParseInteger(field, options)
-        case IntegerType => tryParseInteger(field, options)
-        case LongType => tryParseLong(field, options)
+        case NullType => tryParseInteger(field)
+        case IntegerType => tryParseInteger(field)
+        case LongType => tryParseLong(field)
         case _: DecimalType =>
           // DecimalTypes have different precisions and scales, so we try to find the common type.
-          compatibleType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType)
-        case DoubleType => tryParseDouble(field, options)
-        case TimestampType => tryParseTimestamp(field, options)
-        case BooleanType => tryParseBoolean(field, options)
+          compatibleType(typeSoFar, tryParseDecimal(field)).getOrElse(StringType)
+        case DoubleType => tryParseDouble(field)
+        case TimestampType => tryParseTimestamp(field)
+        case BooleanType => tryParseBoolean(field)
         case StringType => StringType
         case other: DataType =>
           throw new UnsupportedOperationException(s"Unexpected data type $other")
@@ -106,30 +106,30 @@ object CSVInferSchema {
     }
   }
 
-  private def isInfOrNan(field: String, options: CSVOptions): Boolean = {
+  private def isInfOrNan(field: String): Boolean = {
     field == options.nanValue || field == options.negativeInf || field == options.positiveInf
   }
 
-  private def tryParseInteger(field: String, options: CSVOptions): DataType = {
+  private def tryParseInteger(field: String): DataType = {
     if ((allCatch opt field.toInt).isDefined) {
       IntegerType
     } else {
-      tryParseLong(field, options)
+      tryParseLong(field)
     }
   }
 
-  private def tryParseLong(field: String, options: CSVOptions): DataType = {
+  private def tryParseLong(field: String): DataType = {
     if ((allCatch opt field.toLong).isDefined) {
       LongType
     } else {
-      tryParseDecimal(field, options)
+      tryParseDecimal(field)
     }
   }
 
-  private def tryParseDecimal(field: String, options: CSVOptions): DataType = {
+  private def tryParseDecimal(field: String): DataType = {
     val decimalTry = allCatch opt {
-      // `BigDecimal` conversion can fail when the `field` is not a form of number.
-      val bigDecimal = new BigDecimal(field)
+      // The conversion can fail when the `field` is not a form of number.
+      val bigDecimal = decimalParser(field)
       // Because many other formats do not support decimal, it reduces the cases for
       // decimals by disallowing values having scale (eg. `1.1`).
       if (bigDecimal.scale <= 0) {
@@ -138,21 +138,21 @@ object CSVInferSchema {
         //   2. scale is bigger than precision.
         DecimalType(bigDecimal.precision, bigDecimal.scale)
       } else {
-        tryParseDouble(field, options)
+        tryParseDouble(field)
       }
     }
-    decimalTry.getOrElse(tryParseDouble(field, options))
+    decimalTry.getOrElse(tryParseDouble(field))
   }
 
-  private def tryParseDouble(field: String, options: CSVOptions): DataType = {
-    if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field, options)) {
+  private def tryParseDouble(field: String): DataType = {
+    if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) {
       DoubleType
     } else {
-      tryParseTimestamp(field, options)
+      tryParseTimestamp(field)
     }
   }
 
-  private def tryParseTimestamp(field: String, options: CSVOptions): DataType = {
+  private def tryParseTimestamp(field: String): DataType = {
     // This case infers a custom `dataFormat` is set.
     if ((allCatch opt options.timestampFormat.parse(field)).isDefined) {
       TimestampType
@@ -160,11 +160,11 @@ object CSVInferSchema {
       // We keep this for backwards compatibility.
       TimestampType
     } else {
-      tryParseBoolean(field, options)
+      tryParseBoolean(field)
     }
   }
 
-  private def tryParseBoolean(field: String, options: CSVOptions): DataType = {
+  private def tryParseBoolean(field: String): DataType = {
     if ((allCatch opt field.toBoolean).isDefined) {
       BooleanType
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
index ed19693..85e1292 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.csv
 
 import java.io.InputStream
-import java.math.BigDecimal
 
 import scala.util.Try
 import scala.util.control.NonFatal
@@ -27,7 +26,7 @@ import com.univocity.parsers.csv.CsvParser
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow}
 import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -104,6 +103,8 @@ class UnivocityParser(
     requiredSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
   }
 
+  private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
   /**
    * Create a converter which converts the string value to a value according to a desired type.
    * Currently, we do not support complex types (`ArrayType`, `MapType`, `StructType`).
@@ -149,8 +150,7 @@ class UnivocityParser(
 
     case dt: DecimalType => (d: String) =>
       nullSafeDatum(d, name, nullable, options) { datum =>
-        val value = new BigDecimal(datum.replaceAll(",", ""))
-        Decimal(value, dt.precision, dt.scale)
+        Decimal(decimalParser(datum), dt.precision, dt.scale)
       }
 
     case _: TimestampType => (d: String) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 1e4e1c6..83b0299 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -180,8 +180,9 @@ case class SchemaOfCsv(
 
     val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
     val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row)
-    val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions))
+    val inferSchema = new CSVInferSchema(parsedOptions)
+    val fieldTypes = inferSchema.inferRowType(startType, row)
+    val st = StructType(inferSchema.toStructFields(fieldTypes, header))
     UTF8String.fromString(st.catalogString)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
index 651846d..1a020e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchemaSuite.scala
@@ -17,126 +17,175 @@
 
 package org.apache.spark.sql.catalyst.csv
 
+import java.text.{DecimalFormat, DecimalFormatSymbols}
+import java.util.Locale
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-class CSVInferSchemaSuite extends SparkFunSuite {
+class CSVInferSchemaSuite extends SparkFunSuite with SQLHelper {
 
   test("String fields types are inferred correctly from null types") {
     val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(NullType, "", options) == NullType)
-    assert(CSVInferSchema.inferField(NullType, null, options) == NullType)
-    assert(CSVInferSchema.inferField(NullType, "100000000000", options) == 
LongType)
-    assert(CSVInferSchema.inferField(NullType, "60", options) == IntegerType)
-    assert(CSVInferSchema.inferField(NullType, "3.5", options) == DoubleType)
-    assert(CSVInferSchema.inferField(NullType, "test", options) == StringType)
-    assert(CSVInferSchema.inferField(NullType, "2015-08-20 15:57:00", options) 
== TimestampType)
-    assert(CSVInferSchema.inferField(NullType, "True", options) == BooleanType)
-    assert(CSVInferSchema.inferField(NullType, "FAlSE", options) == 
BooleanType)
+    val inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(NullType, "") == NullType)
+    assert(inferSchema.inferField(NullType, null) == NullType)
+    assert(inferSchema.inferField(NullType, "100000000000") == LongType)
+    assert(inferSchema.inferField(NullType, "60") == IntegerType)
+    assert(inferSchema.inferField(NullType, "3.5") == DoubleType)
+    assert(inferSchema.inferField(NullType, "test") == StringType)
+    assert(inferSchema.inferField(NullType, "2015-08-20 15:57:00") == 
TimestampType)
+    assert(inferSchema.inferField(NullType, "True") == BooleanType)
+    assert(inferSchema.inferField(NullType, "FAlSE") == BooleanType)
 
     val textValueOne = Long.MaxValue.toString + "0"
     val decimalValueOne = new java.math.BigDecimal(textValueOne)
     val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
-    assert(CSVInferSchema.inferField(NullType, textValueOne, options) == expectedTypeOne)
+    assert(inferSchema.inferField(NullType, textValueOne) == expectedTypeOne)
   }
 
   test("String fields types are inferred correctly from other types") {
     val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(LongType, "1.0", options) == DoubleType)
-    assert(CSVInferSchema.inferField(LongType, "test", options) == StringType)
-    assert(CSVInferSchema.inferField(IntegerType, "1.0", options) == 
DoubleType)
-    assert(CSVInferSchema.inferField(DoubleType, null, options) == DoubleType)
-    assert(CSVInferSchema.inferField(DoubleType, "test", options) == 
StringType)
-    assert(CSVInferSchema.inferField(LongType, "2015-08-20 14:57:00", options) 
== TimestampType)
-    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 15:57:00", 
options) == TimestampType)
-    assert(CSVInferSchema.inferField(LongType, "True", options) == BooleanType)
-    assert(CSVInferSchema.inferField(IntegerType, "FALSE", options) == 
BooleanType)
-    assert(CSVInferSchema.inferField(TimestampType, "FALSE", options) == 
BooleanType)
+    val inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(LongType, "1.0") == DoubleType)
+    assert(inferSchema.inferField(LongType, "test") == StringType)
+    assert(inferSchema.inferField(IntegerType, "1.0") == DoubleType)
+    assert(inferSchema.inferField(DoubleType, null) == DoubleType)
+    assert(inferSchema.inferField(DoubleType, "test") == StringType)
+    assert(inferSchema.inferField(LongType, "2015-08-20 14:57:00") == 
TimestampType)
+    assert(inferSchema.inferField(DoubleType, "2015-08-20 15:57:00") == 
TimestampType)
+    assert(inferSchema.inferField(LongType, "True") == BooleanType)
+    assert(inferSchema.inferField(IntegerType, "FALSE") == BooleanType)
+    assert(inferSchema.inferField(TimestampType, "FALSE") == BooleanType)
 
     val textValueOne = Long.MaxValue.toString + "0"
     val decimalValueOne = new java.math.BigDecimal(textValueOne)
     val expectedTypeOne = DecimalType(decimalValueOne.precision, decimalValueOne.scale)
-    assert(CSVInferSchema.inferField(IntegerType, textValueOne, options) == expectedTypeOne)
+    assert(inferSchema.inferField(IntegerType, textValueOne) == expectedTypeOne)
   }
 
   test("Timestamp field types are inferred correctly via custom data format") {
     var options = new CSVOptions(Map("timestampFormat" -> "yyyy-mm"), false, 
"GMT")
-    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == 
TimestampType)
+    var inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType)
+
     options = new CSVOptions(Map("timestampFormat" -> "yyyy"), false, "GMT")
-    assert(CSVInferSchema.inferField(TimestampType, "2015", options) == 
TimestampType)
+    inferSchema = new CSVInferSchema(options)
+    assert(inferSchema.inferField(TimestampType, "2015") == TimestampType)
   }
 
   test("Timestamp field types are inferred correctly from other types") {
     val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(IntegerType, "2015-08-20 14", options) == 
StringType)
-    assert(CSVInferSchema.inferField(DoubleType, "2015-08-20 14:10", options) 
== StringType)
-    assert(CSVInferSchema.inferField(LongType, "2015-08 14:49:00", options) == 
StringType)
+    val inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(IntegerType, "2015-08-20 14") == StringType)
+    assert(inferSchema.inferField(DoubleType, "2015-08-20 14:10") == 
StringType)
+    assert(inferSchema.inferField(LongType, "2015-08 14:49:00") == StringType)
   }
 
   test("Boolean fields types are inferred correctly from other types") {
     val options = new CSVOptions(Map.empty[String, String], false, "GMT")
-    assert(CSVInferSchema.inferField(LongType, "Fale", options) == StringType)
-    assert(CSVInferSchema.inferField(DoubleType, "TRUEe", options) == 
StringType)
+    val inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(LongType, "Fale") == StringType)
+    assert(inferSchema.inferField(DoubleType, "TRUEe") == StringType)
   }
 
   test("Type arrays are merged to highest common type") {
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    val inferSchema = new CSVInferSchema(options)
+
     assert(
-      CSVInferSchema.mergeRowTypes(Array(StringType),
+      inferSchema.mergeRowTypes(Array(StringType),
         Array(DoubleType)).deep == Array(StringType).deep)
     assert(
-      CSVInferSchema.mergeRowTypes(Array(IntegerType),
+      inferSchema.mergeRowTypes(Array(IntegerType),
         Array(LongType)).deep == Array(LongType).deep)
     assert(
-      CSVInferSchema.mergeRowTypes(Array(DoubleType),
+      inferSchema.mergeRowTypes(Array(DoubleType),
         Array(LongType)).deep == Array(DoubleType).deep)
   }
 
   test("Null fields are handled properly when a nullValue is specified") {
     var options = new CSVOptions(Map("nullValue" -> "null"), false, "GMT")
-    assert(CSVInferSchema.inferField(NullType, "null", options) == NullType)
-    assert(CSVInferSchema.inferField(StringType, "null", options) == 
StringType)
-    assert(CSVInferSchema.inferField(LongType, "null", options) == LongType)
+    var inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(NullType, "null") == NullType)
+    assert(inferSchema.inferField(StringType, "null") == StringType)
+    assert(inferSchema.inferField(LongType, "null") == LongType)
 
     options = new CSVOptions(Map("nullValue" -> "\\N"), false, "GMT")
-    assert(CSVInferSchema.inferField(IntegerType, "\\N", options) == 
IntegerType)
-    assert(CSVInferSchema.inferField(DoubleType, "\\N", options) == DoubleType)
-    assert(CSVInferSchema.inferField(TimestampType, "\\N", options) == 
TimestampType)
-    assert(CSVInferSchema.inferField(BooleanType, "\\N", options) == 
BooleanType)
-    assert(CSVInferSchema.inferField(DecimalType(1, 1), "\\N", options) == 
DecimalType(1, 1))
+    inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(IntegerType, "\\N") == IntegerType)
+    assert(inferSchema.inferField(DoubleType, "\\N") == DoubleType)
+    assert(inferSchema.inferField(TimestampType, "\\N") == TimestampType)
+    assert(inferSchema.inferField(BooleanType, "\\N") == BooleanType)
+    assert(inferSchema.inferField(DecimalType(1, 1), "\\N") == DecimalType(1, 
1))
   }
 
   test("Merging Nulltypes should yield Nulltype.") {
-    val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
+    val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    val inferSchema = new CSVInferSchema(options)
+
+    val mergedNullTypes = inferSchema.mergeRowTypes(Array(NullType), Array(NullType))
     assert(mergedNullTypes.deep == Array(NullType).deep)
   }
 
   test("SPARK-18433: Improve DataSource option keys to be more 
case-insensitive") {
     val options = new CSVOptions(Map("TiMeStampFormat" -> "yyyy-mm"), false, 
"GMT")
-    assert(CSVInferSchema.inferField(TimestampType, "2015-08", options) == 
TimestampType)
+    val inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(TimestampType, "2015-08") == TimestampType)
   }
 
   test("SPARK-18877: `inferField` on DecimalType should find a common type 
with `typeSoFar`") {
     val options = new CSVOptions(Map.empty[String, String], false, "GMT")
+    val inferSchema = new CSVInferSchema(options)
 
     // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9).
-    assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) 
==
+    assert(inferSchema.inferField(DecimalType(3, -10), "1.19E11") ==
       DecimalType(4, -9))
 
     // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 
and scale 20.
     val value = "12345678901234567890.01234567890123456789"
-    assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType)
+    assert(inferSchema.inferField(DecimalType(3, -10), value) == DoubleType)
 
     // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType
-    assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) 
== DecimalType(20, 0))
-    assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 
00:00:00", options)
+    assert(inferSchema.inferField(NullType, s"${Long.MaxValue}1") == 
DecimalType(20, 0))
+    assert(inferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00")
       == StringType)
   }
 
   test("DoubleType should be inferred when user defined nan/inf are provided") 
{
     val options = new CSVOptions(Map("nanValue" -> "nan", "negativeInf" -> 
"-inf",
       "positiveInf" -> "inf"), false, "GMT")
-    assert(CSVInferSchema.inferField(NullType, "nan", options) == DoubleType)
-    assert(CSVInferSchema.inferField(NullType, "inf", options) == DoubleType)
-    assert(CSVInferSchema.inferField(NullType, "-inf", options) == DoubleType)
+    val inferSchema = new CSVInferSchema(options)
+
+    assert(inferSchema.inferField(NullType, "nan") == DoubleType)
+    assert(inferSchema.inferField(NullType, "inf") == DoubleType)
+    assert(inferSchema.inferField(NullType, "-inf") == DoubleType)
+  }
+
+  test("inferring the decimal type using locale") {
+    def checkDecimalInfer(langTag: String, expectedType: DataType): Unit = {
+      val options = new CSVOptions(
+        parameters = Map("locale" -> langTag, "inferSchema" -> "true", "sep" 
-> "|"),
+        columnPruning = false,
+        defaultTimeZoneId = "GMT")
+      val inferSchema = new CSVInferSchema(options)
+
+      val df = new DecimalFormat("", new 
DecimalFormatSymbols(Locale.forLanguageTag(langTag)))
+      val input = df.format(Decimal(1000001).toBigDecimal)
+
+      assert(inferSchema.inferField(NullType, input) == expectedType)
+    }
+
+    Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalInfer(_, 
DecimalType(7, 0)))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
index e4e7dc2..7212402 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala
@@ -18,13 +18,17 @@
 package org.apache.spark.sql.catalyst.csv
 
 import java.math.BigDecimal
+import java.text.{DecimalFormat, DecimalFormatSymbols}
+import java.util.Locale
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-class UnivocityParserSuite extends SparkFunSuite {
+class UnivocityParserSuite extends SparkFunSuite with SQLHelper {
   private val parser = new UnivocityParser(
     StructType(Seq.empty),
     new CSVOptions(Map.empty[String, String], false, "GMT"))
@@ -196,4 +200,20 @@ class UnivocityParserSuite extends SparkFunSuite {
     assert(doubleVal2 == Double.PositiveInfinity)
   }
 
+  test("parse decimals using locale") {
+    def checkDecimalParsing(langTag: String): Unit = {
+      val decimalVal = new BigDecimal("1000.001")
+      val decimalType = new DecimalType(10, 5)
+      val expected = Decimal(decimalVal, decimalType.precision, decimalType.scale)
+      val df = new DecimalFormat("", new DecimalFormatSymbols(Locale.forLanguageTag(langTag)))
+      val input = df.format(expected.toBigDecimal)
+
+      val options = new CSVOptions(Map("locale" -> langTag), false, "GMT")
+      val parser = new UnivocityParser(new StructType().add("d", decimalType), options)
+
+      assert(parser.makeConverter("_1", decimalType, options = 
options).apply(input) === expected)
+    }
+
+    Seq("en-US", "ko-KR", "ru-RU", "de-DE").foreach(checkDecimalParsing)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f97326bc/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index b35b885..b46dfb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -135,7 +135,7 @@ object TextInputCSVDataSource extends CSVDataSource {
           val parser = new CsvParser(parsedOptions.asParserSettings)
           linesWithoutHeader.map(parser.parseLine)
         }
-        CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+        new CSVInferSchema(parsedOptions).infer(tokenRDD, header)
       case _ =>
         // If the first line could not be read, just return the empty schema.
         StructType(Nil)
@@ -208,7 +208,7 @@ object MultiLineCSVDataSource extends CSVDataSource {
             encoding = parsedOptions.charset)
         }
         val sampled = CSVUtils.sample(tokenRDD, parsedOptions)
-        CSVInferSchema.infer(sampled, header, parsedOptions)
+        new CSVInferSchema(parsedOptions).infer(sampled, header)
       case None =>
         // If the first row could not be read, just return the empty schema.
         StructType(Nil)

