KUDU-721: [Spark] Add DECIMAL type support

Adds DECIMAL support to the Kudu Spark source and RDD.

Change-Id: Ia5f7a801778ed81b68949bbf8d7c08d1a13ed840
Reviewed-on: http://gerrit.cloudera.org:8080/9213
Tested-by: Kudu Jenkins
Reviewed-by: Dan Burkert <d...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/kudu/repo
Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/cbd34fa8
Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/cbd34fa8
Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/cbd34fa8

Branch: refs/heads/master
Commit: cbd34fa850036b4e22e3ff5be92238677e99adb3
Parents: 4f34b69
Author: Grant Henke <granthe...@gmail.com>
Authored: Sun Feb 4 21:40:39 2018 -0600
Committer: Grant Henke <granthe...@gmail.com>
Committed: Tue Feb 13 22:20:22 2018 +0000

----------------------------------------------------------------------
 .../apache/kudu/spark/kudu/DefaultSource.scala  | 10 ++--
 .../apache/kudu/spark/kudu/KuduContext.scala    | 54 ++++++++++++++------
 .../org/apache/kudu/spark/kudu/KuduRDD.scala    |  1 +
 .../kudu/spark/kudu/DefaultSourceTest.scala     | 15 +++++-
 .../kudu/spark/kudu/KuduContextTest.scala       |  7 ++-
 .../apache/kudu/spark/kudu/TestContext.scala    | 22 +++++++-
 6 files changed, 84 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
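
Before the per-file diffs, here is a minimal, hypothetical sketch of how a Kudu DECIMAL
column surfaces through the Spark integration once this change is in. The master address,
table name, and column names below are placeholders for illustration, not part of the patch:

    import org.apache.spark.sql.SparkSession
    import org.apache.kudu.spark.kudu._  // brings the .kudu reader implicit into scope

    val spark = SparkSession.builder().appName("kudu-decimal-read").getOrCreate()

    // Read a Kudu table that contains a DECIMAL column; it maps to Spark's
    // DecimalType carrying the column's precision and scale.
    val df = spark.read
      .options(Map("kudu.master" -> "kudu-master:7051", "kudu.table" -> "orders"))
      .kudu

    df.printSchema()

    // Comparison filters on the decimal column can be pushed down to Kudu as
    // BigDecimal comparison predicates (see DefaultSource.scala below).
    df.filter("total_price > 100.00").show()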


http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
index c0b1b0a..7987bf8 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala
@@ -17,6 +17,7 @@
 
 package org.apache.kudu.spark.kudu
 
+import java.math.BigDecimal
 import java.net.InetAddress
 import java.sql.Timestamp
 
@@ -31,7 +32,7 @@ import org.apache.yetus.audience.InterfaceStability
 
 import org.apache.kudu.client.KuduPredicate.ComparisonOp
 import org.apache.kudu.client._
-import org.apache.kudu.{ColumnSchema, Type}
+import org.apache.kudu.{ColumnSchema, ColumnTypeAttributes, Type}
 
 /**
   * Data source for integration with Spark's [[DataFrame]] API.
@@ -183,7 +184,7 @@ class KuduRelation(private val tableName: String,
 
   def kuduColumnToSparkField: (ColumnSchema) => StructField = {
     columnSchema =>
-      val sparkType = kuduTypeToSparkType(columnSchema.getType)
+      val sparkType = kuduTypeToSparkType(columnSchema.getType, columnSchema.getTypeAttributes)
       new StructField(columnSchema.getName, sparkType, columnSchema.isNullable)
   }
 
@@ -271,6 +272,7 @@ class KuduRelation(private val tableName: String,
       case value: Double => KuduPredicate.newComparisonPredicate(columnSchema, operator, value)
       case value: String => KuduPredicate.newComparisonPredicate(columnSchema, operator, value)
       case value: Array[Byte] => KuduPredicate.newComparisonPredicate(columnSchema, operator, value)
+      case value: BigDecimal => KuduPredicate.newComparisonPredicate(columnSchema, operator, value)
     }
   }
 
@@ -327,9 +329,10 @@ private[spark] object KuduRelation {
     * Converts a Kudu [[Type]] to a Spark SQL [[DataType]].
     *
     * @param t the Kudu type
+    * @param a the Kudu type attributes
     * @return the corresponding Spark SQL type
     */
-  private def kuduTypeToSparkType(t: Type): DataType = t match {
+  private def kuduTypeToSparkType(t: Type, a: ColumnTypeAttributes): DataType = t match {
     case Type.BOOL => BooleanType
     case Type.INT8 => ByteType
     case Type.INT16 => ShortType
@@ -340,6 +343,7 @@ private[spark] object KuduRelation {
     case Type.DOUBLE => DoubleType
     case Type.STRING => StringType
     case Type.BINARY => BinaryType
+    case Type.DECIMAL => DecimalType(a.getPrecision, a.getScale)
   }
 
   /**

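The change above threads each column's ColumnTypeAttributes through kuduTypeToSparkType so
that precision and scale survive the conversion. As a rough, hypothetical illustration of the
intended mapping (the column name is made up):

    import org.apache.kudu.{ColumnSchema, Type}
    import org.apache.kudu.ColumnTypeAttributes.ColumnTypeAttributesBuilder
    import org.apache.spark.sql.types.DecimalType

    // A Kudu DECIMAL(18, 2) column...
    val column = new ColumnSchema.ColumnSchemaBuilder("total_price", Type.DECIMAL)
      .typeAttributes(new ColumnTypeAttributesBuilder().precision(18).scale(2).build())
      .build()

    // ...is expected to appear in a DataFrame schema as DecimalType(18, 2),
    // mirroring the DecimalType(a.getPrecision, a.getScale) case added above.
    val attrs = column.getTypeAttributes
    val sparkType = DecimalType(attrs.getPrecision, attrs.getScale)
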
http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
index a981f5e..ece4df0 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala
@@ -26,9 +26,10 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.hadoop.util.ShutdownHookManager
+import org.apache.kudu.ColumnTypeAttributes.ColumnTypeAttributesBuilder
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{DataType, DataTypes, StructType}
+import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.AccumulatorV2
 import org.apache.yetus.audience.InterfaceStability
@@ -161,17 +162,32 @@ class KuduContext(val kuduMaster: String,
     val kuduCols = new util.ArrayList[ColumnSchema]()
     // add the key columns first, in the order specified
     for (key <- keys) {
-      val f = schema.fields(schema.fieldIndex(key))
-      kuduCols.add(new ColumnSchema.ColumnSchemaBuilder(f.name, kuduType(f.dataType)).key(true).build())
+      val field = schema.fields(schema.fieldIndex(key))
+      val col = createColumn(field, isKey = true)
+      kuduCols.add(col)
     }
     // now add the non-key columns
-    for (f <- schema.fields.filter(field=> !keys.contains(field.name))) {
-      kuduCols.add(new ColumnSchema.ColumnSchemaBuilder(f.name, kuduType(f.dataType)).nullable(f.nullable).key(false).build())
+    for (field <- schema.fields.filter(field => !keys.contains(field.name))) {
+      val col = createColumn(field, isKey = false)
+      kuduCols.add(col)
     }
 
     syncClient.createTable(tableName, new Schema(kuduCols), options)
   }
 
+  private def createColumn(field: StructField, isKey: Boolean): ColumnSchema = {
+    val kt = kuduType(field.dataType)
+    val col = new ColumnSchema.ColumnSchemaBuilder(field.name, kt).key(isKey).nullable(field.nullable)
+    // Add ColumnTypeAttributesBuilder to DECIMAL columns
+    if (kt == Type.DECIMAL) {
+      val dt = field.dataType.asInstanceOf[DecimalType]
+      col.typeAttributes(
+        new ColumnTypeAttributesBuilder().precision(dt.precision).scale(dt.scale).build()
+      )
+    }
+    col.build()
+  }
+
   /** Map Spark SQL type to Kudu type */
   def kuduType(dt: DataType) : Type = dt match {
     case DataTypes.BinaryType => Type.BINARY
@@ -184,6 +200,7 @@ class KuduContext(val kuduMaster: String,
     case DataTypes.LongType => Type.INT64
     case DataTypes.FloatType => Type.FLOAT
     case DataTypes.DoubleType => Type.DOUBLE
+    case DecimalType() => Type.DECIMAL
     case _ => throw new IllegalArgumentException(s"No support for Spark SQL type $dt")
   }
 
@@ -276,18 +293,21 @@ class KuduContext(val kuduMaster: String,
         for ((sparkIdx, kuduIdx) <- indices) {
           if (row.isNullAt(sparkIdx)) {
             operation.getRow.setNull(kuduIdx)
-          } else schema.fields(sparkIdx).dataType match {
-            case DataTypes.StringType => operation.getRow.addString(kuduIdx, row.getString(sparkIdx))
-            case DataTypes.BinaryType => operation.getRow.addBinary(kuduIdx, row.getAs[Array[Byte]](sparkIdx))
-            case DataTypes.BooleanType => operation.getRow.addBoolean(kuduIdx, row.getBoolean(sparkIdx))
-            case DataTypes.ByteType => operation.getRow.addByte(kuduIdx, row.getByte(sparkIdx))
-            case DataTypes.ShortType => operation.getRow.addShort(kuduIdx, row.getShort(sparkIdx))
-            case DataTypes.IntegerType => operation.getRow.addInt(kuduIdx, row.getInt(sparkIdx))
-            case DataTypes.LongType => operation.getRow.addLong(kuduIdx, row.getLong(sparkIdx))
-            case DataTypes.FloatType => operation.getRow.addFloat(kuduIdx, row.getFloat(sparkIdx))
-            case DataTypes.DoubleType => operation.getRow.addDouble(kuduIdx, row.getDouble(sparkIdx))
-            case DataTypes.TimestampType => operation.getRow.addLong(kuduIdx, KuduRelation.timestampToMicros(row.getTimestamp(sparkIdx)))
-            case t => throw new IllegalArgumentException(s"No support for Spark SQL type $t")
+          } else {
+            schema.fields(sparkIdx).dataType match {
+              case DataTypes.StringType => operation.getRow.addString(kuduIdx, row.getString(sparkIdx))
+              case DataTypes.BinaryType => operation.getRow.addBinary(kuduIdx, row.getAs[Array[Byte]](sparkIdx))
+              case DataTypes.BooleanType => operation.getRow.addBoolean(kuduIdx, row.getBoolean(sparkIdx))
+              case DataTypes.ByteType => operation.getRow.addByte(kuduIdx, row.getByte(sparkIdx))
+              case DataTypes.ShortType => operation.getRow.addShort(kuduIdx, row.getShort(sparkIdx))
+              case DataTypes.IntegerType => operation.getRow.addInt(kuduIdx, row.getInt(sparkIdx))
+              case DataTypes.LongType => operation.getRow.addLong(kuduIdx, row.getLong(sparkIdx))
+              case DataTypes.FloatType => operation.getRow.addFloat(kuduIdx, row.getFloat(sparkIdx))
+              case DataTypes.DoubleType => operation.getRow.addDouble(kuduIdx, row.getDouble(sparkIdx))
+              case DataTypes.TimestampType => operation.getRow.addLong(kuduIdx, KuduRelation.timestampToMicros(row.getTimestamp(sparkIdx)))
+              case DecimalType() => operation.getRow.addDecimal(kuduIdx, row.getDecimal(sparkIdx))
+              case t => throw new IllegalArgumentException(s"No support for Spark SQL type $t")
+            }
           }
         }
         session.apply(operation)
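
With createColumn translating DecimalType's precision and scale into Kudu column type
attributes, a DataFrame schema containing decimals can drive table creation and writes.
A hedged sketch, assuming a reachable cluster; the master address, table name, and columns
are placeholders:

    import scala.collection.JavaConverters._

    import org.apache.kudu.client.CreateTableOptions
    import org.apache.kudu.spark.kudu.KuduContext
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types._

    val spark = SparkSession.builder().appName("kudu-decimal-create").getOrCreate()
    val kuduContext = new KuduContext("kudu-master:7051", spark.sparkContext)

    val schema = StructType(Seq(
      StructField("id", LongType, nullable = false),
      StructField("amount", DecimalType(18, 2), nullable = true)))

    // DecimalType(18, 2) becomes a Kudu DECIMAL column with precision 18, scale 2.
    kuduContext.createTable("payments", schema, Seq("id"),
      new CreateTableOptions().addHashPartitions(List("id").asJava, 2).setNumReplicas(1))

On the write path, rows carrying java.math.BigDecimal values are routed through addDecimal,
as in the insert/upsert hunk above.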

http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
index daed3f0..5a4ed8b 100644
--- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
+++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala
@@ -123,6 +123,7 @@ private class RowIterator(private val scanner: KuduScanner,
       case Type.DOUBLE => rowResult.getDouble(i)
       case Type.STRING => rowResult.getString(i)
       case Type.BINARY => rowResult.getBinaryCopy(i)
+      case Type.DECIMAL => rowResult.getDecimal(i)
     }
   }
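
On the RDD read path, RowIterator now materializes DECIMAL cells via RowResult#getDecimal,
so consumers of kuduRDD receive java.math.BigDecimal values directly. A hedged sketch
(master address, table, and columns are hypothetical):

    import java.math.BigDecimal

    import org.apache.kudu.spark.kudu.KuduContext
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("kudu-decimal-rdd").getOrCreate()
    val sc = spark.sparkContext
    val kuduContext = new KuduContext("kudu-master:7051", sc)

    // Project two columns; the rows are schema-less, so access fields by position.
    val rdd = kuduContext.kuduRDD(sc, "payments", Seq("id", "amount"))
    val amounts = rdd.map(row => row.getAs[BigDecimal](1))  // DECIMAL -> java.math.BigDecimal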
 

http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
index 69d1f20..59997d2 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala
@@ -249,7 +249,18 @@ class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter wi
                  sqlContext.sql(s"""SELECT key, c7_float FROM $tableName where c7_float > 5""").count())
 
   }
-
+  test("table scan with projection and predicate decimal32") {
+    assertEquals(rows.count { case (key, i, s, ts) => i > 5},
+      sqlContext.sql(s"""SELECT key, c11_decimal32 FROM $tableName where c11_decimal32 > 5""").count())
+  }
+  test("table scan with projection and predicate decimal64") {
+    assertEquals(rows.count { case (key, i, s, ts) => i > 5},
+      sqlContext.sql(s"""SELECT key, c12_decimal64 FROM $tableName where c12_decimal64 > 5""").count())
+  }
+  test("table scan with projection and predicate decimal128") {
+    assertEquals(rows.count { case (key, i, s, ts) => i > 5},
+      sqlContext.sql(s"""SELECT key, c13_decimal128 FROM $tableName where c13_decimal128 > 5""").count())
+  }
   test("table scan with projection and predicate ") {
     assertEquals(rows.count { case (key, i, s, ts) => s != null && s > "5" },
       sqlContext.sql(s"""SELECT key FROM $tableName where c2_s > "5"""").count())
@@ -474,7 +485,7 @@ class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter wi
     ))
 
     val dfDefaultSchema = sqlContext.read.options(kuduOptions).kudu
-    assertEquals(11, dfDefaultSchema.schema.fields.length)
+    assertEquals(14, dfDefaultSchema.schema.fields.length)
 
     val dfWithUserSchema = sqlContext.read.options(kuduOptions).schema(userSchema).kudu
     assertEquals(2, dfWithUserSchema.schema.fields.length)

http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
index 8156365..8e12dcd 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala
@@ -17,6 +17,7 @@
 package org.apache.kudu.spark.kudu
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+import java.math.BigDecimal
 import java.sql.Timestamp
 
 import org.apache.spark.sql.functions.decode
@@ -60,7 +61,8 @@ class KuduContextTest extends FunSuite with TestContext with Matchers {
   test("Test basic kuduRDD") {
     val rows = insertRows(rowCount)
     val scanList = kuduContext.kuduRDD(ss.sparkContext, "test", Seq("key", "c1_i", "c2_s", "c3_double",
-        "c4_long", "c5_bool", "c6_short", "c7_float", "c8_binary", "c9_unixtime_micros", "c10_byte"))
+        "c4_long", "c5_bool", "c6_short", "c7_float", "c8_binary", "c9_unixtime_micros", "c10_byte",
+        "c11_decimal32", "c12_decimal64", "c13_decimal128"))
       .map(r => r.toSeq).collect()
     scanList.foreach(r => {
       val index = r.apply(0).asInstanceOf[Int]
@@ -77,6 +79,9 @@ class KuduContextTest extends FunSuite with TestContext with Matchers {
       assert(r.apply(9).asInstanceOf[Timestamp] ==
         KuduRelation.microsToTimestamp(rows.apply(index)._4))
       assert(r.apply(10).asInstanceOf[Byte] == rows.apply(index)._2.toByte)
+      assert(r.apply(11).asInstanceOf[BigDecimal] == BigDecimal.valueOf(rows.apply(index)._2))
+      assert(r.apply(12).asInstanceOf[BigDecimal] == BigDecimal.valueOf(rows.apply(index)._2))
+      assert(r.apply(13).asInstanceOf[BigDecimal] == BigDecimal.valueOf(rows.apply(index)._2))
     })
   }
 

http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala
----------------------------------------------------------------------
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala
index 9ce0991..62b41cd 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.kudu.spark.kudu
 
+import java.math.BigDecimal
 import java.util.Date
 
 import scala.collection.JavaConverters._
@@ -26,10 +27,12 @@ import org.apache.spark.SparkConf
 import org.scalatest.{BeforeAndAfterAll, Suite}
 
 import org.apache.kudu.ColumnSchema.ColumnSchemaBuilder
+import org.apache.kudu.ColumnTypeAttributes.ColumnTypeAttributesBuilder
 import org.apache.kudu.client.KuduClient.KuduClientBuilder
 import org.apache.kudu.client.MiniKuduCluster.MiniKuduClusterBuilder
 import org.apache.kudu.client.{CreateTableOptions, KuduClient, KuduTable, MiniKuduCluster}
 import org.apache.kudu.{Schema, Type}
+import org.apache.kudu.util.DecimalUtil
 import org.apache.spark.sql.SparkSession
 
 trait TestContext extends BeforeAndAfterAll { self: Suite =>
@@ -54,8 +57,20 @@ trait TestContext extends BeforeAndAfterAll { self: Suite =>
       new ColumnSchemaBuilder("c7_float", Type.FLOAT).build(),
       new ColumnSchemaBuilder("c8_binary", Type.BINARY).build(),
       new ColumnSchemaBuilder("c9_unixtime_micros", Type.UNIXTIME_MICROS).build(),
-      new ColumnSchemaBuilder("c10_byte", Type.INT8).build())
-    new Schema(columns)
+      new ColumnSchemaBuilder("c10_byte", Type.INT8).build(),
+      new ColumnSchemaBuilder("c11_decimal32", Type.DECIMAL)
+        .typeAttributes(
+          new ColumnTypeAttributesBuilder().precision(DecimalUtil.MAX_DECIMAL32_PRECISION).build()
+        ).build(),
+      new ColumnSchemaBuilder("c12_decimal64", Type.DECIMAL)
+        .typeAttributes(
+          new ColumnTypeAttributesBuilder().precision(DecimalUtil.MAX_DECIMAL64_PRECISION).build()
+        ).build(),
+      new ColumnSchemaBuilder("c13_decimal128", Type.DECIMAL)
+        .typeAttributes(
+          new ColumnTypeAttributesBuilder().precision(DecimalUtil.MAX_DECIMAL128_PRECISION).build()
+        ).build())
+      new Schema(columns)
   }
 
   val appID: String = new Date().toString + math.floor(math.random * 10E4).toLong.toString
@@ -113,6 +128,9 @@ trait TestContext extends BeforeAndAfterAll { self: Suite =>
       val ts = System.currentTimeMillis() * 1000
       row.addLong(9, ts)
       row.addByte(10, i.toByte)
+      row.addDecimal(11, BigDecimal.valueOf(i))
+      row.addDecimal(12, BigDecimal.valueOf(i))
+      row.addDecimal(13, BigDecimal.valueOf(i))
 
       // Sprinkling some nulls so that queries see them.
       val s = if (i % 2 == 0) {
