Repository: spark
Updated Branches:
  refs/heads/master 56f2c61cd -> 23f966f47
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index a3573e6..74f68d0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 /* Implicit conversions */
 import scala.collection.JavaConversions._
@@ -325,7 +326,11 @@ private[hive] object HiveQl {
   }
 
   protected def nodeToDataType(node: Node): DataType = node match {
-    case Token("TOK_DECIMAL", Nil) => DecimalType
+    case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
+      DecimalType(precision.getText.toInt, scale.getText.toInt)
+    case Token("TOK_DECIMAL", precision :: Nil) =>
+      DecimalType(precision.getText.toInt, 0)
+    case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited
     case Token("TOK_BIGINT", Nil) => LongType
     case Token("TOK_INT", Nil) => IntegerType
     case Token("TOK_TINYINT", Nil) => ByteType
@@ -942,8 +947,12 @@ private[hive] object HiveQl {
       Cast(nodeToExpr(arg), BinaryType)
     case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
       Cast(nodeToExpr(arg), BooleanType)
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt))
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0))
     case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
-      Cast(nodeToExpr(arg), DecimalType)
+      Cast(nodeToExpr(arg), DecimalType.Unlimited)
     case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
       Cast(nodeToExpr(arg), TimestampType)
     case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
@@ -1063,7 +1072,7 @@ private[hive] object HiveQl {
         } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) {
           // Literal decimal
           val strVal = ast.getText.stripSuffix("D").stripSuffix("B")
-          v = Literal(BigDecimal(strVal))
+          v = Literal(Decimal(strVal))
         } else {
           v = Literal(ast.getText.toDouble, DoubleType)
           v = Literal(ast.getText.toLong, LongType)
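For context, the nodeToDataType change means Hive decimal tokens with two, one, or zero arguments now map to distinct Catalyst types. A minimal sketch of that mapping, using a hypothetical parseDecimal helper over the token's child texts (the real code pattern-matches AST nodes and reads precision/scale via getText):

    import org.apache.spark.sql.catalyst.types.{DataType, DecimalType}

    // Hypothetical stand-in for the three TOK_DECIMAL cases above.
    def parseDecimal(args: List[String]): DataType = args match {
      case precision :: scale :: Nil => DecimalType(precision.toInt, scale.toInt) // decimal(p,s)
      case precision :: Nil          => DecimalType(precision.toInt, 0)           // decimal(p): scale defaults to 0
      case Nil                       => DecimalType.Unlimited                     // bare "decimal"
    }

    // e.g. parseDecimal(List("10", "2")) == DecimalType(10, 2)

The same three-way split is applied to CAST expressions in nodeToExpr, and decimal literals (suffix D/BD) now construct Catalyst's Decimal rather than scala.math.BigDecimal.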
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 79234f8..92bc1c6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode}
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc}
@@ -76,7 +77,7 @@ case class InsertIntoHiveTable(
       (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
 
     case _: JavaHiveDecimalObjectInspector =>
-      (o: Any) => HiveShim.createDecimal(o.asInstanceOf[BigDecimal].underlying())
+      (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying())
 
     case soi: StandardStructObjectInspector =>
       val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
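The wrapper change reflects that rows written to Hive now carry Catalyst Decimal values instead of scala.math.BigDecimal. A short sketch of the conversion chain, assuming the Hive 0.13 build where HiveShim.createDecimal delegates to HiveDecimal.create (on 0.12 it constructs new HiveDecimal instead, which is why the call goes through the shim):

    import org.apache.hadoop.hive.common.`type`.HiveDecimal
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // Catalyst Decimal -> scala.math.BigDecimal -> java.math.BigDecimal -> HiveDecimal
    val catalyst: Decimal = Decimal("12.345")
    val hive: HiveDecimal = HiveDecimal.create(catalyst.toBigDecimal.underlying())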
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
----------------------------------------------------------------------
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index afc252a..8e946b7 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -30,21 +30,24 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
 import org.apache.hadoop.hive.ql.processors._
 import org.apache.hadoop.hive.ql.stats.StatsSetupConst
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector}
+import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
 import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
 import org.apache.hadoop.hive.serde2.{io => hiveIo}
 import org.apache.hadoop.{io => hadoopIo}
 import org.apache.hadoop.mapred.InputFormat
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 
+import org.apache.spark.sql.catalyst.types.DecimalType
+
 /**
  * A compatibility layer for interacting with Hive version 0.12.0.
  */
 private[hive] object HiveShim {
   val version = "0.12.0"
-  val metastoreDecimal = "decimal"
 
   def getTableDesc(
       serdeClass: Class[_ <: Deserializer],
@@ -149,6 +152,19 @@ private[hive] object HiveShim {
   def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = {
     tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri())
   }
+
+  def decimalMetastoreString(decimalType: DecimalType): String = "decimal"
+
+  def decimalTypeInfo(decimalType: DecimalType): TypeInfo =
+    TypeInfoFactory.decimalTypeInfo
+
+  def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = {
+    DecimalType.Unlimited
+  }
+
+  def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+    Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+  }
 }
 
 class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean)
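Because Hive 0.12's metastore type is the single unparameterized "decimal", these shim methods deliberately discard precision and scale. A quick illustration of the expected behavior under that assumption:

    import org.apache.spark.sql.catalyst.types.DecimalType

    HiveShim.decimalMetastoreString(DecimalType(10, 2))     // "decimal" -- parameters dropped
    HiveShim.decimalMetastoreString(DecimalType.Unlimited)  // "decimal"
    // decimalTypeInfoToCatalyst always answers DecimalType.Unlimited, so a fixed
    // DecimalType(10, 2) does not survive a round trip through a 0.12 metastore.

toCatalystDecimal likewise builds an unconstrained Decimal, since a 0.12 inspector carries no precision or scale to enforce.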
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
----------------------------------------------------------------------
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 42cd65b..0bc330c 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -29,15 +29,15 @@ import org.apache.hadoop.hive.ql.Context
 import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition}
 import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
 import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
-import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
-import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer}
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector}
 import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
 import org.apache.hadoop.hive.serde2.{io => hiveIo}
 import org.apache.hadoop.{io => hadoopIo}
 import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.types.DecimalType
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
@@ -47,11 +47,6 @@ import scala.language.implicitConversions
  */
 private[hive] object HiveShim {
   val version = "0.13.1"
-  /*
-   * TODO: hive-0.13 support DECIMAL(precision, scale), DECIMAL in hive-0.12 is actually DECIMAL(38,unbounded)
-   * Full support of new decimal feature need to be fixed in seperate PR.
-   */
-  val metastoreDecimal = "decimal\\((\\d+),(\\d+)\\)".r
 
   def getTableDesc(
       serdeClass: Class[_ <: Deserializer],
@@ -197,6 +192,30 @@ private[hive] object HiveShim {
     f.setDestTableId(w.destTableId)
     f
   }
+
+  // Precision and scale to pass for unlimited decimals; these are the same as the precision and
+  // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
+  private val UNLIMITED_DECIMAL_PRECISION = 38
+  private val UNLIMITED_DECIMAL_SCALE = 18
+
+  def decimalMetastoreString(decimalType: DecimalType): String = decimalType match {
+    case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)"
+    case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)"
+  }
+
+  def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match {
+    case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
+    case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE)
+  }
+
+  def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = {
+    val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo]
+    DecimalType(info.precision(), info.scale())
+  }
+
+  def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+    Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+  }
 }
 
 /*
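On the 0.13 shim the same four methods preserve precision and scale end to end, with Unlimited mapped to Hive's widest inferred type. A sketch of the expected round trip under the defaults above:

    import org.apache.spark.sql.catalyst.types.DecimalType

    HiveShim.decimalMetastoreString(DecimalType(10, 2))     // "decimal(10,2)"
    HiveShim.decimalMetastoreString(DecimalType.Unlimited)  // "decimal(38,18)"
    HiveShim.decimalTypeInfo(DecimalType(10, 2))            // new DecimalTypeInfo(10, 2)
    // decimalTypeInfoToCatalyst inverts decimalTypeInfo, recovering DecimalType(10, 2),
    // and toCatalystDecimal applies the inspector's precision/scale to each value read.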
