Repository: spark
Updated Branches:
  refs/heads/branch-2.0 403ba6513 -> 381a82589


[SPARK-15241] [SPARK-15242] [SQL] fix 2 decimal-related issues in RowEncoder

## What changes were proposed in this pull request?

SPARK-15241: We already support Java decimals (`java.math.BigDecimal`) and Catalyst decimals (`Decimal`) in external rows, so it makes sense to also support Scala decimals (`scala.math.BigDecimal`).

SPARK-15242: This is a long-standing bug that was exposed by https://github.com/apache/spark/pull/12364, which eliminates the `If` expression when the field is not nullable:
```
val fieldValue = serializerFor(
  GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
  f.dataType)
if (f.nullable) {
  If(
    Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
    Literal.create(null, f.dataType),
    fieldValue)
} else {
  fieldValue
}
```

Previously, we always used `DecimalType.SYSTEM_DEFAULT` as the output type of the converted decimal field. This is wrong because it doesn't match the field's real decimal type. It happened to work because the converted field was always wrapped in an `If` expression for the null check. `If` uses its `trueValue`'s data type as its output type, and that `trueValue` is `Literal.create(null, f.dataType)`, which carries the field's declared decimal type.
Now, for a non-nullable decimal field, there is no `If` wrapper. The converted field's output type is therefore `DecimalType.SYSTEM_DEFAULT`, and wrong data is written into the unsafe row.

The fix is simple: use the given decimal type as the output type of the converted decimal field.

These 2 issues were found in https://github.com/apache/spark/pull/13008

## How was this patch tested?

New tests in RowEncoderSuite; the previously commented-out "java decimal" test in ExpressionEncoderSuite is re-enabled.

Author: Wenchen Fan <[email protected]>

Closes #13019 from cloud-fan/encoder-decimal.

(cherry picked from commit d8935db5ecb7c959585411da9bf1e9a9c4d5cb37)
Signed-off-by: Davies Liu <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/381a8258
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/381a8258
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/381a8258

Branch: refs/heads/branch-2.0
Commit: 381a825890b74afa0bd7325265aa4e15bbc8f10f
Parents: 403ba65
Author: Wenchen Fan <[email protected]>
Authored: Wed May 11 11:16:05 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Wed May 11 11:16:12 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/encoders/RowEncoder.scala      |  6 ++--
 .../org/apache/spark/sql/types/Decimal.scala    |  1 +
 .../encoders/ExpressionEncoderSuite.scala       |  3 +-
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 29 ++++++++++++++++----
 4 files changed, 29 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/381a8258/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index cfde3bf..33ac1fd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -84,10 +84,10 @@ object RowEncoder {
         "fromJavaDate",
         inputObject :: Nil)
 
-    case _: DecimalType =>
+    case d: DecimalType =>
       StaticInvoke(
         Decimal.getClass,
-        DecimalType.SYSTEM_DEFAULT,
+        d,
         "fromDecimal",
         inputObject :: Nil)
 
@@ -162,7 +162,7 @@ object RowEncoder {
    * `org.apache.spark.sql.types.Decimal`.
    */
   private def externalDataTypeForInput(dt: DataType): DataType = dt match {
-    // In order to support both Decimal and java BigDecimal in external row, 
we make this
+    // In order to support both Decimal and java/scala BigDecimal in external 
row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
     case _ => externalDataTypeFor(dt)

http://git-wip-us.apache.org/repos/asf/spark/blob/381a8258/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 6f4ec6b..2f7422b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -386,6 +386,7 @@ object Decimal {
   def fromDecimal(value: Any): Decimal = {
     value match {
       case j: java.math.BigDecimal => apply(j)
+      case d: BigDecimal => apply(d)
       case d: Decimal => d
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/381a8258/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index c3b20e2..177b139 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -108,7 +108,7 @@ class ExpressionEncoderSuite extends PlanTest with 
AnalysisTest {
   encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")
 
   encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
-  // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
+  encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
 
   encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
 
@@ -336,6 +336,7 @@ class ExpressionEncoderSuite extends PlanTest with 
AnalysisTest {
           Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], 
b2.asInstanceOf[Array[AnyRef]])
         case (b1: Array[_], b2: Array[_]) =>
           Arrays.equals(b1.asInstanceOf[Array[AnyRef]], 
b2.asInstanceOf[Array[AnyRef]])
+        case (left: Comparable[Any], right: Comparable[Any]) => 
left.compareTo(right) == 0
         case _ => input == convertedBack
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/381a8258/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 98be3b0..4800e2e 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -143,21 +143,38 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(input.getStruct(0) == convertedBack.getStruct(0))
   }
 
-  test("encode/decode Decimal") {
+  test("encode/decode decimal type") {
     val schema = new StructType()
       .add("int", IntegerType)
       .add("string", StringType)
       .add("double", DoubleType)
-      .add("decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("java_decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
+      .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)
 
     val encoder = RowEncoder(schema)
 
-    val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
+    val javaDecimal = new java.math.BigDecimal("1234.5678")
+    val scalaDecimal = BigDecimal("1234.5678")
+    val catalystDecimal = Decimal("1234.5678")
+
+    val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal, 
catalystDecimal)
     val row = encoder.toRow(input)
     val convertedBack = encoder.fromRow(row)
-    // Decimal inside external row will be converted back to Java BigDecimal 
when decoding.
-    assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
-      .compareTo(convertedBack.getDecimal(3)) == 0)
+    // Decimal will be converted back to Java BigDecimal when decoding.
+    assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0)
+    assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0)
+    
assert(convertedBack.getDecimal(5).compareTo(catalystDecimal.toJavaBigDecimal) 
== 0)
+  }
+
+  test("RowEncoder should preserve decimal precision and scale") {
+    val schema = new StructType().add("decimal", DecimalType(10, 5), false)
+    val encoder = RowEncoder(schema)
+    val decimal = Decimal("67123.45")
+    val input = Row(decimal)
+    val row = encoder.toRow(input)
+
+    assert(row.toSeq(schema).head == decimal)
   }
 
   test("RowEncoder should preserve schema nullability") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to