KUDU-721: [Spark] Add DECIMAL type support

Adds DECIMAL support to the Kudu Spark source and RDD.
Change-Id: Ia5f7a801778ed81b68949bbf8d7c08d1a13ed840 Reviewed-on: http://gerrit.cloudera.org:8080/9213 Tested-by: Kudu Jenkins Reviewed-by: Dan Burkert <d...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/kudu/repo Commit: http://git-wip-us.apache.org/repos/asf/kudu/commit/cbd34fa8 Tree: http://git-wip-us.apache.org/repos/asf/kudu/tree/cbd34fa8 Diff: http://git-wip-us.apache.org/repos/asf/kudu/diff/cbd34fa8 Branch: refs/heads/master Commit: cbd34fa850036b4e22e3ff5be92238677e99adb3 Parents: 4f34b69 Author: Grant Henke <granthe...@gmail.com> Authored: Sun Feb 4 21:40:39 2018 -0600 Committer: Grant Henke <granthe...@gmail.com> Committed: Tue Feb 13 22:20:22 2018 +0000 ---------------------------------------------------------------------- .../apache/kudu/spark/kudu/DefaultSource.scala | 10 ++-- .../apache/kudu/spark/kudu/KuduContext.scala | 54 ++++++++++++++------ .../org/apache/kudu/spark/kudu/KuduRDD.scala | 1 + .../kudu/spark/kudu/DefaultSourceTest.scala | 15 +++++- .../kudu/spark/kudu/KuduContextTest.scala | 7 ++- .../apache/kudu/spark/kudu/TestContext.scala | 22 +++++++- 6 files changed, 84 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala index c0b1b0a..7987bf8 100644 --- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala +++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/DefaultSource.scala @@ -17,6 +17,7 @@ package org.apache.kudu.spark.kudu +import java.math.BigDecimal import java.net.InetAddress import java.sql.Timestamp @@ -31,7 +32,7 @@ import 
org.apache.yetus.audience.InterfaceStability import org.apache.kudu.client.KuduPredicate.ComparisonOp import org.apache.kudu.client._ -import org.apache.kudu.{ColumnSchema, Type} +import org.apache.kudu.{ColumnSchema, ColumnTypeAttributes, Type} /** * Data source for integration with Spark's [[DataFrame]] API. @@ -183,7 +184,7 @@ class KuduRelation(private val tableName: String, def kuduColumnToSparkField: (ColumnSchema) => StructField = { columnSchema => - val sparkType = kuduTypeToSparkType(columnSchema.getType) + val sparkType = kuduTypeToSparkType(columnSchema.getType, columnSchema.getTypeAttributes) new StructField(columnSchema.getName, sparkType, columnSchema.isNullable) } @@ -271,6 +272,7 @@ class KuduRelation(private val tableName: String, case value: Double => KuduPredicate.newComparisonPredicate(columnSchema, operator, value) case value: String => KuduPredicate.newComparisonPredicate(columnSchema, operator, value) case value: Array[Byte] => KuduPredicate.newComparisonPredicate(columnSchema, operator, value) + case value: BigDecimal => KuduPredicate.newComparisonPredicate(columnSchema, operator, value) } } @@ -327,9 +329,10 @@ private[spark] object KuduRelation { * Converts a Kudu [[Type]] to a Spark SQL [[DataType]]. 
* * @param t the Kudu type + * @param a the Kudu type attributes * @return the corresponding Spark SQL type */ - private def kuduTypeToSparkType(t: Type): DataType = t match { + private def kuduTypeToSparkType(t: Type, a: ColumnTypeAttributes): DataType = t match { case Type.BOOL => BooleanType case Type.INT8 => ByteType case Type.INT16 => ShortType @@ -340,6 +343,7 @@ private[spark] object KuduRelation { case Type.DOUBLE => DoubleType case Type.STRING => StringType case Type.BINARY => BinaryType + case Type.DECIMAL => DecimalType(a.getPrecision, a.getScale) } /** http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala index a981f5e..ece4df0 100644 --- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala +++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduContext.scala @@ -26,9 +26,10 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.hadoop.util.ShutdownHookManager +import org.apache.kudu.ColumnTypeAttributes.ColumnTypeAttributesBuilder import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{DataType, DataTypes, StructType} +import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.AccumulatorV2 import org.apache.yetus.audience.InterfaceStability @@ -161,17 +162,32 @@ class KuduContext(val kuduMaster: String, val kuduCols = new util.ArrayList[ColumnSchema]() // add the key columns first, in the order specified for (key <- keys) { - val f = schema.fields(schema.fieldIndex(key)) - kuduCols.add(new 
ColumnSchema.ColumnSchemaBuilder(f.name, kuduType(f.dataType)).key(true).build()) + val field = schema.fields(schema.fieldIndex(key)) + val col = createColumn(field, isKey = true) + kuduCols.add(col) } // now add the non-key columns - for (f <- schema.fields.filter(field=> !keys.contains(field.name))) { - kuduCols.add(new ColumnSchema.ColumnSchemaBuilder(f.name, kuduType(f.dataType)).nullable(f.nullable).key(false).build()) + for (field <- schema.fields.filter(field => !keys.contains(field.name))) { + val col = createColumn(field, isKey = false) + kuduCols.add(col) } syncClient.createTable(tableName, new Schema(kuduCols), options) } + private def createColumn(field: StructField, isKey: Boolean): ColumnSchema = { + val kt = kuduType(field.dataType) + val col = new ColumnSchema.ColumnSchemaBuilder(field.name, kt).key(isKey).nullable(field.nullable) + // Add ColumnTypeAttributesBuilder to DECIMAL columns + if (kt == Type.DECIMAL) { + val dt = field.dataType.asInstanceOf[DecimalType] + col.typeAttributes( + new ColumnTypeAttributesBuilder().precision(dt.precision).scale(dt.scale).build() + ) + } + col.build() + } + /** Map Spark SQL type to Kudu type */ def kuduType(dt: DataType) : Type = dt match { case DataTypes.BinaryType => Type.BINARY @@ -184,6 +200,7 @@ class KuduContext(val kuduMaster: String, case DataTypes.LongType => Type.INT64 case DataTypes.FloatType => Type.FLOAT case DataTypes.DoubleType => Type.DOUBLE + case DecimalType() => Type.DECIMAL case _ => throw new IllegalArgumentException(s"No support for Spark SQL type $dt") } @@ -276,18 +293,21 @@ class KuduContext(val kuduMaster: String, for ((sparkIdx, kuduIdx) <- indices) { if (row.isNullAt(sparkIdx)) { operation.getRow.setNull(kuduIdx) - } else schema.fields(sparkIdx).dataType match { - case DataTypes.StringType => operation.getRow.addString(kuduIdx, row.getString(sparkIdx)) - case DataTypes.BinaryType => operation.getRow.addBinary(kuduIdx, row.getAs[Array[Byte]](sparkIdx)) - case DataTypes.BooleanType => 
operation.getRow.addBoolean(kuduIdx, row.getBoolean(sparkIdx)) - case DataTypes.ByteType => operation.getRow.addByte(kuduIdx, row.getByte(sparkIdx)) - case DataTypes.ShortType => operation.getRow.addShort(kuduIdx, row.getShort(sparkIdx)) - case DataTypes.IntegerType => operation.getRow.addInt(kuduIdx, row.getInt(sparkIdx)) - case DataTypes.LongType => operation.getRow.addLong(kuduIdx, row.getLong(sparkIdx)) - case DataTypes.FloatType => operation.getRow.addFloat(kuduIdx, row.getFloat(sparkIdx)) - case DataTypes.DoubleType => operation.getRow.addDouble(kuduIdx, row.getDouble(sparkIdx)) - case DataTypes.TimestampType => operation.getRow.addLong(kuduIdx, KuduRelation.timestampToMicros(row.getTimestamp(sparkIdx))) - case t => throw new IllegalArgumentException(s"No support for Spark SQL type $t") + } else { + schema.fields(sparkIdx).dataType match { + case DataTypes.StringType => operation.getRow.addString(kuduIdx, row.getString(sparkIdx)) + case DataTypes.BinaryType => operation.getRow.addBinary(kuduIdx, row.getAs[Array[Byte]](sparkIdx)) + case DataTypes.BooleanType => operation.getRow.addBoolean(kuduIdx, row.getBoolean(sparkIdx)) + case DataTypes.ByteType => operation.getRow.addByte(kuduIdx, row.getByte(sparkIdx)) + case DataTypes.ShortType => operation.getRow.addShort(kuduIdx, row.getShort(sparkIdx)) + case DataTypes.IntegerType => operation.getRow.addInt(kuduIdx, row.getInt(sparkIdx)) + case DataTypes.LongType => operation.getRow.addLong(kuduIdx, row.getLong(sparkIdx)) + case DataTypes.FloatType => operation.getRow.addFloat(kuduIdx, row.getFloat(sparkIdx)) + case DataTypes.DoubleType => operation.getRow.addDouble(kuduIdx, row.getDouble(sparkIdx)) + case DataTypes.TimestampType => operation.getRow.addLong(kuduIdx, KuduRelation.timestampToMicros(row.getTimestamp(sparkIdx))) + case DecimalType() => operation.getRow.addDecimal(kuduIdx, row.getDecimal(sparkIdx)) + case t => throw new IllegalArgumentException(s"No support for Spark SQL type $t") + } } } 
session.apply(operation) http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala index daed3f0..5a4ed8b 100644 --- a/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala +++ b/java/kudu-spark/src/main/scala/org/apache/kudu/spark/kudu/KuduRDD.scala @@ -123,6 +123,7 @@ private class RowIterator(private val scanner: KuduScanner, case Type.DOUBLE => rowResult.getDouble(i) case Type.STRING => rowResult.getString(i) case Type.BINARY => rowResult.getBinaryCopy(i) + case Type.DECIMAL => rowResult.getDecimal(i) } } http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala index 69d1f20..59997d2 100644 --- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala +++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/DefaultSourceTest.scala @@ -249,7 +249,18 @@ class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter wi sqlContext.sql(s"""SELECT key, c7_float FROM $tableName where c7_float > 5""").count()) } - + test("table scan with projection and predicate decimal32") { + assertEquals(rows.count { case (key, i, s, ts) => i > 5}, + sqlContext.sql(s"""SELECT key, c11_decimal32 FROM $tableName where c11_decimal32 > 5""").count()) + } + test("table scan with projection and predicate decimal64") { + assertEquals(rows.count { case (key, i, s, ts) => i > 5}, + 
sqlContext.sql(s"""SELECT key, c12_decimal64 FROM $tableName where c12_decimal64 > 5""").count()) + } + test("table scan with projection and predicate decimal128") { + assertEquals(rows.count { case (key, i, s, ts) => i > 5}, + sqlContext.sql(s"""SELECT key, c13_decimal128 FROM $tableName where c13_decimal128 > 5""").count()) + } test("table scan with projection and predicate ") { assertEquals(rows.count { case (key, i, s, ts) => s != null && s > "5" }, sqlContext.sql(s"""SELECT key FROM $tableName where c2_s > "5"""").count()) @@ -474,7 +485,7 @@ class DefaultSourceTest extends FunSuite with TestContext with BeforeAndAfter wi )) val dfDefaultSchema = sqlContext.read.options(kuduOptions).kudu - assertEquals(11, dfDefaultSchema.schema.fields.length) + assertEquals(14, dfDefaultSchema.schema.fields.length) val dfWithUserSchema = sqlContext.read.options(kuduOptions).schema(userSchema).kudu assertEquals(2, dfWithUserSchema.schema.fields.length) http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala index 8156365..8e12dcd 100644 --- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala +++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduContextTest.scala @@ -17,6 +17,7 @@ package org.apache.kudu.spark.kudu import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import java.math.BigDecimal import java.sql.Timestamp import org.apache.spark.sql.functions.decode @@ -60,7 +61,8 @@ class KuduContextTest extends FunSuite with TestContext with Matchers { test("Test basic kuduRDD") { val rows = insertRows(rowCount) val scanList = kuduContext.kuduRDD(ss.sparkContext, 
"test", Seq("key", "c1_i", "c2_s", "c3_double", - "c4_long", "c5_bool", "c6_short", "c7_float", "c8_binary", "c9_unixtime_micros", "c10_byte")) + "c4_long", "c5_bool", "c6_short", "c7_float", "c8_binary", "c9_unixtime_micros", "c10_byte", + "c11_decimal32", "c12_decimal64", "c13_decimal128")) .map(r => r.toSeq).collect() scanList.foreach(r => { val index = r.apply(0).asInstanceOf[Int] @@ -77,6 +79,9 @@ class KuduContextTest extends FunSuite with TestContext with Matchers { assert(r.apply(9).asInstanceOf[Timestamp] == KuduRelation.microsToTimestamp(rows.apply(index)._4)) assert(r.apply(10).asInstanceOf[Byte] == rows.apply(index)._2.toByte) + assert(r.apply(11).asInstanceOf[BigDecimal] == BigDecimal.valueOf(rows.apply(index)._2)) + assert(r.apply(12).asInstanceOf[BigDecimal] == BigDecimal.valueOf(rows.apply(index)._2)) + assert(r.apply(13).asInstanceOf[BigDecimal] == BigDecimal.valueOf(rows.apply(index)._2)) }) } http://git-wip-us.apache.org/repos/asf/kudu/blob/cbd34fa8/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala ---------------------------------------------------------------------- diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala index 9ce0991..62b41cd 100644 --- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala +++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/TestContext.scala @@ -16,6 +16,7 @@ */ package org.apache.kudu.spark.kudu +import java.math.BigDecimal import java.util.Date import scala.collection.JavaConverters._ @@ -26,10 +27,12 @@ import org.apache.spark.SparkConf import org.scalatest.{BeforeAndAfterAll, Suite} import org.apache.kudu.ColumnSchema.ColumnSchemaBuilder +import org.apache.kudu.ColumnTypeAttributes.ColumnTypeAttributesBuilder import org.apache.kudu.client.KuduClient.KuduClientBuilder import org.apache.kudu.client.MiniKuduCluster.MiniKuduClusterBuilder 
import org.apache.kudu.client.{CreateTableOptions, KuduClient, KuduTable, MiniKuduCluster} import org.apache.kudu.{Schema, Type} +import org.apache.kudu.util.DecimalUtil import org.apache.spark.sql.SparkSession trait TestContext extends BeforeAndAfterAll { self: Suite => @@ -54,8 +57,20 @@ trait TestContext extends BeforeAndAfterAll { self: Suite => new ColumnSchemaBuilder("c7_float", Type.FLOAT).build(), new ColumnSchemaBuilder("c8_binary", Type.BINARY).build(), new ColumnSchemaBuilder("c9_unixtime_micros", Type.UNIXTIME_MICROS).build(), - new ColumnSchemaBuilder("c10_byte", Type.INT8).build()) - new Schema(columns) + new ColumnSchemaBuilder("c10_byte", Type.INT8).build(), + new ColumnSchemaBuilder("c11_decimal32", Type.DECIMAL) + .typeAttributes( + new ColumnTypeAttributesBuilder().precision(DecimalUtil.MAX_DECIMAL32_PRECISION).build() + ).build(), + new ColumnSchemaBuilder("c12_decimal64", Type.DECIMAL) + .typeAttributes( + new ColumnTypeAttributesBuilder().precision(DecimalUtil.MAX_DECIMAL64_PRECISION).build() + ).build(), + new ColumnSchemaBuilder("c13_decimal128", Type.DECIMAL) + .typeAttributes( + new ColumnTypeAttributesBuilder().precision(DecimalUtil.MAX_DECIMAL128_PRECISION).build() + ).build()) + new Schema(columns) } val appID: String = new Date().toString + math.floor(math.random * 10E4).toLong.toString @@ -113,6 +128,9 @@ trait TestContext extends BeforeAndAfterAll { self: Suite => val ts = System.currentTimeMillis() * 1000 row.addLong(9, ts) row.addByte(10, i.toByte) + row.addDecimal(11, BigDecimal.valueOf(i)) + row.addDecimal(12, BigDecimal.valueOf(i)) + row.addDecimal(13, BigDecimal.valueOf(i)) // Sprinkling some nulls so that queries see them. val s = if (i % 2 == 0) {