http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5657bc5..6bfa0db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.scalactic.TripleEqualsSupport.Spread @@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - actual.asInstanceOf[Double] shouldBe expected + actual.asInstanceOf[Double] shouldBe expected } test("IN") { @@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(InSet(three, nS, three +: nullS), false) checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) } - + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) @@ -265,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite { val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType, null) + checkEvaluation("abdef" cast DecimalType.Unlimited, null) checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) + checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) checkEvaluation(Literal(1) cast LongType, 1) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) @@ -289,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) @@ -302,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" 
cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) checkEvaluation("23" cast ByteType, 23.toByte) checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) @@ -311,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) @@ -325,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast IntegerType).nullable === true) assert(("abcdef" cast ShortType).nullable === true) assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType).nullable === true) + assert(("abcdef" cast DecimalType.Unlimited).nullable === true) + assert(("abcdef" cast DecimalType(4, 2)).nullable === true) assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) @@ -338,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(d1) < Literal(d2), true) } + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) + + assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + + checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) + checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) + + checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) + 
checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) + + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -374,7 +434,7 @@ class ExpressionEvaluationSuite extends FunSuite { millis.toFloat / 1000) checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) - checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) // A test for higher precision than millis checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) @@ -673,7 +733,7 @@ class ExpressionEvaluationSuite extends FunSuite { val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) val d = 'a.double.at(0) - + for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) }
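A note on the semantics these new cast tests pin down: lowering the scale rounds with HALF_UP, and a rounded value that no longer fits the target precision becomes null, which is also why casts to fixed-precision decimals are nullable while casts to DecimalType.Unlimited are not. A minimal standalone sketch of that contract (not Spark's Cast implementation; castToFixedDecimal is a hypothetical helper):

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Round to the target scale, then reject values whose digit count
// exceeds the target precision.
def castToFixedDecimal(v: JBigDecimal, precision: Int, scale: Int): Option[JBigDecimal] = {
  val rescaled = v.setScale(scale, RoundingMode.HALF_UP)
  if (rescaled.precision > precision) None else Some(rescaled)
}

castToFixedDecimal(new JBigDecimal("10.05"), 3, 1) // Some(10.1): HALF_UP rounds the dropped 5 up
castToFixedDecimal(new JBigDecimal("9.95"), 2, 0)  // Some(10)
castToFixedDecimal(new JBigDecimal("123"), 2, 0)   // None: 123 needs 3 digits
castToFixedDecimal(new JBigDecimal("10.03"), 3, 2) // None: 10.03 needs 4 digits at scale 2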
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala new file mode 100644 index 0000000..5aa2634 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.scalatest.{PrivateMethodTester, FunSuite} + +import scala.language.postfixOps + +class DecimalSuite extends FunSuite with PrivateMethodTester { + test("creating decimals") { + /** Check that a Decimal has the given string representation, precision and scale */ + def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + + checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) + checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) + checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) + checkDecimal(Decimal("10.030"), "10.030", 5, 3) + checkDecimal(Decimal(10.03), "10.03", 4, 2) + checkDecimal(Decimal(17L), "17", 20, 0) + checkDecimal(Decimal(17), "17", 10, 0) + checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1) + checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2) + checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1) + checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0) + checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) + checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) + intercept[IllegalArgumentException](Decimal(170L, 2, 1)) + intercept[IllegalArgumentException](Decimal(170L, 2, 0)) + intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1)) + intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1)) + intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) + } + + test("double and long values") { + /** Check that a Decimal converts to the given double and long values */ + def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { + assert(d.toDouble === doubleValue) + assert(d.toLong === longValue) + } + + checkValues(new Decimal(), 0.0, 0L) + checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L) + checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L) + checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L) + checkValues(Decimal(10.03), 
10.03, 10L) + checkValues(Decimal(17L), 17.0, 17L) + checkValues(Decimal(17), 17.0, 17L) + checkValues(Decimal(17L, 2, 1), 1.7, 1L) + checkValues(Decimal(170L, 4, 2), 1.7, 1L) + checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong) + checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong) + checkValues(Decimal(1e18.toLong), 1e18, 1e18.toLong) + checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong) + checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue) + checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue) + checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L) + checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) + } + + // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs + private val decimalVal = PrivateMethod[BigDecimal]('decimalVal) + + /** Check whether a decimal is represented compactly (passing whether we expect it to be) */ + private def checkCompact(d: Decimal, expected: Boolean): Unit = { + val isCompact = d.invokePrivate(decimalVal()).eq(null) + assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact") + } + + test("small decimals represented as unscaled long") { + checkCompact(new Decimal(), true) + checkCompact(Decimal(BigDecimal(10.03)), false) + checkCompact(Decimal(BigDecimal(1e20)), false) + checkCompact(Decimal(17L), true) + checkCompact(Decimal(17), true) + checkCompact(Decimal(17L, 2, 1), true) + checkCompact(Decimal(170L, 4, 2), true) + checkCompact(Decimal(17L, 24, 1), true) + checkCompact(Decimal(1e16.toLong), true) + checkCompact(Decimal(1e17.toLong), true) + checkCompact(Decimal(1e18.toLong - 1), true) + checkCompact(Decimal(- 1e18.toLong + 1), true) + checkCompact(Decimal(1e18.toLong - 1, 30, 10), true) + checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true) + checkCompact(Decimal(1e18.toLong), false) + checkCompact(Decimal(-1e18.toLong), false) + checkCompact(Decimal(1e18.toLong, 30, 10), false) + checkCompact(Decimal(-1e18.toLong, 30, 10), false) + checkCompact(Decimal(Long.MaxValue), false) + checkCompact(Decimal(Long.MinValue), false) + } + + test("hash code") { + assert(Decimal(123).hashCode() === (123).##) + assert(Decimal(-123).hashCode() === (-123).##) + assert(Decimal(123.312).hashCode() === (123.312).##) + assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##) + assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##) + assert(Decimal(BigDecimal(123)).hashCode() === (123).##) + + val reallyBig = BigDecimal("123182312312313232112312312123.1231231231") + assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode) + } + + test("equals") { + // The decimals on the left are stored compactly, while the ones on the right aren't + checkCompact(Decimal(123), true) + checkCompact(Decimal(BigDecimal(123)), false) + checkCompact(Decimal("123"), false) + assert(Decimal(123) === Decimal(BigDecimal(123))) + assert(Decimal(123) === Decimal(BigDecimal("123.00"))) + assert(Decimal(-123) === Decimal(BigDecimal(-123))) + assert(Decimal(-123) === Decimal(BigDecimal("-123.00"))) + } + + test("isZero") { + assert(Decimal(0).isZero) + assert(Decimal(0, 4, 2).isZero) + assert(Decimal("0").isZero) + assert(Decimal("0.000").isZero) + assert(!Decimal(1).isZero) + assert(!Decimal(1, 4, 2).isZero) + assert(!Decimal("1").isZero) + assert(!Decimal("0.001").isZero) + } + + test("arithmetic") { + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) * Decimal(-100) === Decimal(-10000)) + 
assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26)) + assert(Decimal(100) / Decimal(-100) === Decimal(-1)) + assert(Decimal(100) / Decimal(0) === null) + assert(Decimal(100) % Decimal(-100) === Decimal(0)) + assert(Decimal(100) % Decimal(3) === Decimal(1)) + assert(Decimal(-100) % Decimal(3) === Decimal(-1)) + assert(Decimal(100) % Decimal(0) === null) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 0c85cdc..c383540 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -53,11 +53,6 @@ public abstract class DataType { public static final TimestampType TimestampType = new TimestampType(); /** - * Gets the DecimalType object. - */ - public static final DecimalType DecimalType = new DecimalType(); - - /** * Gets the DoubleType object. */ public static final DoubleType DoubleType = new DoubleType(); http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java index bc54c07..6075245 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -19,9 +19,61 @@ package org.apache.spark.sql.api.java; /** * The data type representing java.math.BigDecimal values. - * - * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. */ public class DecimalType extends DataType { - protected DecimalType() {} + private boolean hasPrecisionInfo; + private int precision; + private int scale; + + public DecimalType(int precision, int scale) { + this.hasPrecisionInfo = true; + this.precision = precision; + this.scale = scale; + } + + public DecimalType() { + this.hasPrecisionInfo = false; + this.precision = -1; + this.scale = -1; + } + + public boolean isUnlimited() { + return !hasPrecisionInfo; + } + + public boolean isFixed() { + return hasPrecisionInfo; + } + + /** Return the precision, or -1 if no precision is set */ + public int getPrecision() { + return precision; + } + + /** Return the scale, or -1 if no precision is set */ + public int getScale() { + return scale; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DecimalType that = (DecimalType) o; + + if (hasPrecisionInfo != that.hasPrecisionInfo) return false; + if (precision != that.precision) return false; + if (scale != that.scale) return false; + + return true; + } + + @Override + public int hashCode() { + int result = (hasPrecisionInfo ? 
1 : 0); + result = 31 * result + precision; + result = 31 * result + scale; + return result; + } } http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 8b96df1..018a18c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.util.{Map => JMap, List => JList} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.storage.StorageLevel import scala.collection.JavaConversions._ @@ -113,7 +114,7 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(_.copy()) + firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala) override def getPartitions: Array[Partition] = firstParent[Row].partitions http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 082ae03..876b1c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -230,7 +230,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { case c: Class[_] if c == classOf[java.lang.Boolean] => (org.apache.spark.sql.BooleanType, true) case c: Class[_] if c == classOf[java.math.BigDecimal] => - (org.apache.spark.sql.DecimalType, true) + (org.apache.spark.sql.DecimalType(), true) case c: Class[_] if c == classOf[java.sql.Date] => (org.apache.spark.sql.DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index df01411..401798e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.annotation.varargs import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions @@ -106,6 +108,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { } override def hashCode(): Int = row.hashCode() + + override def toString: String = row.toString } object Row { http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b3edd50..087b0ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -70,16 +70,29 @@ case class GeneratedAggregate( val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } val currentCount = AttributeReference("currentCount", LongType, nullable = false)() val initialValue = Literal(0L) - val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val result = currentCount AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case Sum(expr) => - val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() - val initialValue = Cast(Literal(0L), expr.dataType) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) + case _ => + expr.dataType + } + + val currentSum = AttributeReference("currentSum", resultType, nullable = false)() + val initialValue = Cast(Literal(0L), resultType) // Coalasce avoids double calculation... // but really, common sub expression elimination would be better.... @@ -93,10 +106,26 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() val initialCount = Literal(0L) val initialSum = Cast(Literal(0L), expr.dataType) - val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } + + val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) - val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + DoubleType + } + val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType)) AggregateEvaluation( currentCount :: currentSum :: Nil, @@ -142,7 +171,7 @@ case class GeneratedAggregate( val computationSchema = computeFunctions.flatMap(_.schema) - val resultMap: Map[TreeNodeRef, Expression] = + val resultMap: Map[TreeNodeRef, Expression] = aggregatesToCompute.zip(computeFunctions).map { case (agg, func) => new TreeNodeRef(agg) -> func.result }.toMap http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 
b1a7948..aafcce0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{ScalaReflection, trees} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -82,7 +82,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = execute().map(_.copy()).collect() + def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect() protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 077e6eb..84d96e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -29,6 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils @@ -51,6 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[LongHashSet], new LongHashSetSerializer) kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) + kryo.register(classOf[Decimal]) kryo.setReferences(false) kryo.setClassLoader(Utils.getSparkClassLoader) http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 977f3c9..e6cd1a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray + buf.toArray.map(ScalaReflection.convertRowToScala) } override def execute() = { @@ -176,10 +176,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) override def output = child.output override def outputPartitioning = SinglePartition - val ordering = new RowOrdering(sortOrder, child.output) + val ord = new RowOrdering(sortOrder, child.output) // TODO: Is this copying for no reason? 
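A pattern worth flagging here and in the TakeOrdered change just below: every path that hands rows back to user code now maps ScalaReflection.convertRowToScala, so internal values such as Decimal surface as scala.math.BigDecimal, while engine-internal exchanges (see BroadcastHashJoin further down) deliberately keep the unconverted rows. A condensed, hypothetical sketch of that conversion (assuming Decimal.toBigDecimal; this is not the actual ScalaReflection code):

import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}
import org.apache.spark.sql.catalyst.types.decimal.Decimal

// Internal Decimal values become scala.math.BigDecimal before rows
// escape the execution engine (condensed, hypothetical version).
def convertRowToScalaSketch(row: Row): Row =
  new GenericRow(row.map {
    case d: Decimal => d.toBigDecimal
    case other => other
  }.toArray)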
- override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) + override def executeCollect() = + child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 8fd3588..5cf2a78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -49,7 +49,8 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - val input: Array[Row] = buildPlan.executeCollect() + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) } http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index a1961bb..9976690 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -116,7 +118,7 @@ object EvaluatePython { def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Row, struct: StructType) => + case (row: Seq[Any], struct: StructType) => val fields = struct.fields.map(field => field.dataType) row.zip(fields).map { case (obj, dataType) => toJava(obj, dataType) @@ -133,6 +135,8 @@ object EvaluatePython { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type }.asJava + case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal + // Pyrolite can handle Timestamp case (other, _) => other } http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index eabe312..5bb6f6c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.json +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import 
scala.math.BigDecimal @@ -175,9 +177,9 @@ private[sql] object JsonRDD extends Logging { ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType + case value: java.math.BigDecimal => DecimalType.Unlimited // Unexpected data type. case _ => StringType } @@ -319,13 +321,13 @@ private[sql] object JsonRDD extends Logging { } } - private def toDecimal(value: Any): BigDecimal = { + private def toDecimal(value: Any): Decimal = { value match { - case value: java.lang.Integer => BigDecimal(value) - case value: java.lang.Long => BigDecimal(value) - case value: java.math.BigInteger => BigDecimal(value) - case value: java.lang.Double => BigDecimal(value) - case value: java.math.BigDecimal => BigDecimal(value) + case value: java.lang.Integer => Decimal(value) + case value: java.lang.Long => Decimal(value) + case value: java.math.BigInteger => Decimal(BigDecimal(value)) + case value: java.lang.Double => Decimal(value) + case value: java.math.BigDecimal => Decimal(BigDecimal(value)) } } @@ -391,7 +393,7 @@ private[sql] object JsonRDD extends Logging { case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) case DoubleType => toDouble(value) - case DecimalType => toDecimal(value) + case DecimalType() => toDecimal(value) case BooleanType => value.asInstanceOf[BooleanType.JvmType] case NullType => null http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index f0e57e2..05926a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -183,6 +183,20 @@ package object sql { * * The data type representing `scala.math.BigDecimal` values. * + * TODO(matei): explain precision and scale + * + * @group dataType + */ + @DeveloperApi + type DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. 
+ * + * TODO(matei): explain precision and scale + * * @group dataType */ @DeveloperApi http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 2fc7e1c..08feced 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} @@ -117,6 +119,12 @@ private[sql] object CatalystConverter { parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) } } + case d: DecimalType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateDecimal(fieldIndex, value, d) + } + } // All other primitive types use the default converter case ctype: PrimitiveType => { // note: need the type tag here! new CatalystPrimitiveConverter(parent, fieldIndex) @@ -191,6 +199,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) + } + protected[parquet] def isRootConverter: Boolean = parent == null protected[parquet] def clearBuffer(): Unit @@ -201,6 +213,27 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { * @return */ def getCurrentRecord: Row = throw new UnsupportedOperationException + + /** + * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in + * a long (i.e. 
precision <= 18) + */ + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + val precision = ctype.precisionInfo.get.precision + val scale = ctype.precisionInfo.get.scale + val bytes = value.getBytes + require(bytes.length <= 16, "Decimal field too large to read") + var unscaled = 0L + var i = 0 + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xFF) + i += 1 + } + // Make sure unscaled has the right sign, by sign-extending the first bit + val numBits = 8 * bytes.length + unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) + dest.set(unscaled, precision, scale) + } } /** @@ -352,6 +385,16 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = current.setString(fieldIndex, value.toStringUsingUTF8) + + override protected[parquet] def updateDecimal( + fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + var decimal = current(fieldIndex).asInstanceOf[Decimal] + if (decimal == null) { + decimal = new Decimal + current(fieldIndex) = decimal + } + readDecimal(decimal, value, ctype) + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index bdf0240..2a5f23b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration +import org.apache.spark.sql.catalyst.types.decimal.Decimal import parquet.column.ParquetProperties import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.api.ReadSupport.ReadContext @@ -204,6 +205,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -283,6 +289,23 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } writer.endGroup() } + + // Scratch array used to write decimals as fixed-length binary + private val scratchBytes = new Array[Byte](8) + + private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { + val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) + val unscaledLong = decimal.toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + scratchBytes(i) = (unscaledLong >> shift).toByte + i += 1 + shift -= 8 + } + writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) + } + } // Optimized for non-nested rows @@ -326,6 +349,11 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => 
writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e6389cf..e5077de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -29,8 +29,8 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns} +import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} +import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition @@ -41,17 +41,25 @@ import org.apache.spark.sql.catalyst.types._ // Implicits import scala.collection.JavaConversions._ +/** A class representing Parquet info fields we care about, for passing back to Parquet */ +private[parquet] case class ParquetTypeInfo( + primitiveType: ParquetPrimitiveTypeName, + originalType: Option[ParquetOriginalType] = None, + decimalMetadata: Option[DecimalMetadata] = None, + length: Option[Int] = None) + private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass def toPrimitiveDataType( parquetType: ParquetPrimitiveType, - binayAsString: Boolean): DataType = + binaryAsString: Boolean): DataType = { + val originalType = parquetType.getOriginalType + val decimalInfo = parquetType.getDecimalMetadata parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || - binayAsString) => StringType + if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -61,9 +69,14 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
sys.error("Potential loss of precision: cannot convert INT96") + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY + if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => + // TODO: for now, our reader only supports decimals that fit in a Long + DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) case _ => sys.error( s"Unsupported parquet datatype $parquetType") } + } /** * Converts a given Parquet `Type` into the corresponding @@ -183,24 +196,41 @@ private[parquet] object ParquetTypesConverter extends Logging { * is not primitive. * * @param ctype The type to convert - * @return The name of the corresponding Parquet primitive type + * @return The name of the corresponding Parquet type properties */ - def fromPrimitiveDataType(ctype: DataType): - Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match { - case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)) - case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None) - case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None) - case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None) - case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None) - case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None) + def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { + case StringType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) + case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) + case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) + case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) + case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) + case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetPrimitiveTypeName.INT32, None) - case ByteType => Some(ParquetPrimitiveTypeName.INT32, None) - case LongType => Some(ParquetPrimitiveTypeName.INT64, None) + case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) + case DecimalType.Fixed(precision, scale) if precision <= 18 => + // TODO: for now, our writer only supports decimals that fit in a Long + Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, + Some(ParquetOriginalType.DECIMAL), + Some(new DecimalMetadata(precision, scale)), + Some(BYTES_FOR_PRECISION(precision)))) case _ => None } /** + * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. + */ + private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => + var length = 1 + while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { + length += 1 + } + length + } + + /** * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into * the corresponding Parquet `Type`. 
* @@ -247,10 +277,17 @@ private[parquet] object ParquetTypesConverter extends Logging { } else { if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED } - val primitiveType = fromPrimitiveDataType(ctype) - primitiveType.map { - case (primitiveType, originalType) => - new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) + val typeInfo = fromPrimitiveDataType(ctype) + typeInfo.map { + case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => + val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) + for (len <- length) { + builder.length(len) + } + for (metadata <- decimalMetadata) { + builder.precision(metadata.getPrecision).scale(metadata.getScale) + } + builder.named(name) }.getOrElse { ctype match { case ArrayType(elementType, false) => { http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 142598c..7564bf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types.util import org.apache.spark.sql._ import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder} +import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConverters._ @@ -44,7 +46,8 @@ protected[sql] object DataTypeConversions { case BooleanType => JDataType.BooleanType case DateType => JDataType.DateType case TimestampType => JDataType.TimestampType - case DecimalType => JDataType.DecimalType + case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale) + case DecimalType.Unlimited => new JDecimalType() case DoubleType => JDataType.DoubleType case FloatType => JDataType.FloatType case ByteType => JDataType.ByteType @@ -88,7 +91,11 @@ protected[sql] object DataTypeConversions { case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType case decimalType: org.apache.spark.sql.api.java.DecimalType => - DecimalType + if (decimalType.isFixed) { + DecimalType(decimalType.getPrecision, decimalType.getScale) + } else { + DecimalType.Unlimited + } case doubleType: org.apache.spark.sql.api.java.DoubleType => DoubleType case floatType: org.apache.spark.sql.api.java.FloatType => @@ -115,7 +122,7 @@ protected[sql] object DataTypeConversions { /** Converts Java objects to catalyst rows / types */ def convertJavaToCatalyst(a: Any): Any = a match { - case d: java.math.BigDecimal => BigDecimal(d) + case d: java.math.BigDecimal => Decimal(BigDecimal(d)) case other => other } http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 9435a88..a04b806 100644 --- 
a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -118,7 +118,7 @@ public class JavaApplySchemaSuite implements Serializable { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List<StructField> fields = new ArrayList<StructField>(7); - fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + fields.add(DataType.createStructField("bigInteger", new DecimalType(), true)); fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); fields.add(DataType.createStructField("double", DataType.DoubleType, true)); fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index d04396a..8396a29 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -41,7 +41,8 @@ public class JavaSideDataTypeConversionSuite { checkDataType(DataType.BooleanType); checkDataType(DataType.DateType); checkDataType(DataType.TimestampType); - checkDataType(DataType.DecimalType); + checkDataType(new DecimalType()); + checkDataType(new DecimalType(10, 4)); checkDataType(DataType.DoubleType); checkDataType(DataType.FloatType); checkDataType(DataType.ByteType); @@ -59,7 +60,7 @@ public class JavaSideDataTypeConversionSuite { // Simple StructType. 
List<StructField> simpleFields = new ArrayList<StructField>(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); @@ -128,7 +129,7 @@ public class JavaSideDataTypeConversionSuite { // StructType try { List<StructField> simpleFields = new ArrayList<StructField>(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(null); @@ -138,7 +139,7 @@ public class JavaSideDataTypeConversionSuite { } try { List<StructField> simpleFields = new ArrayList<StructField>(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); DataType.createStructType(simpleFields); http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 6c9db63..e9740d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -69,7 +69,7 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(LongType) checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(DecimalType.Unlimited) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) checkDataTypeJsonRepr(BinaryType) http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index bfa9ea4..cf3a59e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ @@ -81,7 +82,9 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) + assert(sql("SELECT * FROM reflectData").collect().head === + Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) } 
test("query case class RDD with nulls") { http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index d83f3e2..c9012c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.beans.BeanProperty import org.scalatest.FunSuite http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index e0e0ff9..62fe59d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -38,7 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { checkDataType(org.apache.spark.sql.BooleanType) checkDataType(org.apache.spark.sql.DateType) checkDataType(org.apache.spark.sql.TimestampType) - checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DecimalType.Unlimited) checkDataType(org.apache.spark.sql.DoubleType) checkDataType(org.apache.spark.sql.FloatType) checkDataType(org.apache.spark.sql.ByteType) @@ -58,7 +58,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { // Simple StructType. 
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
index e0e0ff9..62fe59d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -38,7 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
     checkDataType(org.apache.spark.sql.BooleanType)
     checkDataType(org.apache.spark.sql.DateType)
     checkDataType(org.apache.spark.sql.TimestampType)
-    checkDataType(org.apache.spark.sql.DecimalType)
+    checkDataType(org.apache.spark.sql.DecimalType.Unlimited)
     checkDataType(org.apache.spark.sql.DoubleType)
     checkDataType(org.apache.spark.sql.FloatType)
     checkDataType(org.apache.spark.sql.ByteType)
@@ -58,7 +58,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {

     // Simple StructType.
     val simpleScalaStructType = SStructType(
-      SStructField("a", org.apache.spark.sql.DecimalType, false) ::
+      SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) ::
       SStructField("b", org.apache.spark.sql.BooleanType, true) ::
       SStructField("c", org.apache.spark.sql.LongType, true) ::
       SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index ce6184f..1cb6c23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.json

 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
 import org.apache.spark.sql.QueryTest
@@ -44,19 +45,22 @@ class JsonSuite extends QueryTest {
     checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
     checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
     checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
-    checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType))
+    checkTypePromotion(
+      Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited))

     val longNumber: Long = 9223372036854775807L
     checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
     checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
-    checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType))
+    checkTypePromotion(
+      Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited))

     val doubleNumber: Double = 1.7976931348623157E308d
     checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
-    checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType))
-
+    checkTypePromotion(
+      Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
+
     checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
-    checkTypePromotion(new Timestamp(intNumber.toLong),
+    checkTypePromotion(new Timestamp(intNumber.toLong),
       enforceCorrectType(intNumber.toLong, TimestampType))
     val strTime = "2014-09-30 12:34:56"
     checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))
@@ -80,7 +84,7 @@ class JsonSuite extends QueryTest {
     checkDataType(NullType, IntegerType, IntegerType)
     checkDataType(NullType, LongType, LongType)
     checkDataType(NullType, DoubleType, DoubleType)
-    checkDataType(NullType, DecimalType, DecimalType)
+    checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited)
     checkDataType(NullType, StringType, StringType)
     checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType))
     checkDataType(NullType, StructType(Nil), StructType(Nil))
@@ -91,7 +95,7 @@ class JsonSuite extends QueryTest {
     checkDataType(BooleanType, IntegerType, StringType)
     checkDataType(BooleanType, LongType, StringType)
     checkDataType(BooleanType, DoubleType, StringType)
-    checkDataType(BooleanType, DecimalType, StringType)
+    checkDataType(BooleanType, DecimalType.Unlimited, StringType)
     checkDataType(BooleanType, StringType, StringType)
     checkDataType(BooleanType, ArrayType(IntegerType), StringType)
     checkDataType(BooleanType, StructType(Nil), StringType)
@@ -100,7 +104,7 @@ class JsonSuite extends QueryTest {
     checkDataType(IntegerType, IntegerType, IntegerType)
     checkDataType(IntegerType, LongType, LongType)
     checkDataType(IntegerType, DoubleType, DoubleType)
-    checkDataType(IntegerType, DecimalType, DecimalType)
+    checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited)
     checkDataType(IntegerType, StringType, StringType)
     checkDataType(IntegerType, ArrayType(IntegerType), StringType)
     checkDataType(IntegerType, StructType(Nil), StringType)
@@ -108,23 +112,23 @@ class JsonSuite extends QueryTest {
     // LongType
     checkDataType(LongType, LongType, LongType)
     checkDataType(LongType, DoubleType, DoubleType)
-    checkDataType(LongType, DecimalType, DecimalType)
+    checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited)
     checkDataType(LongType, StringType, StringType)
     checkDataType(LongType, ArrayType(IntegerType), StringType)
     checkDataType(LongType, StructType(Nil), StringType)

     // DoubleType
     checkDataType(DoubleType, DoubleType, DoubleType)
-    checkDataType(DoubleType, DecimalType, DecimalType)
+    checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited)
     checkDataType(DoubleType, StringType, StringType)
     checkDataType(DoubleType, ArrayType(IntegerType), StringType)
     checkDataType(DoubleType, StructType(Nil), StringType)

     // DoubleType
-    checkDataType(DecimalType, DecimalType, DecimalType)
-    checkDataType(DecimalType, StringType, StringType)
-    checkDataType(DecimalType, ArrayType(IntegerType), StringType)
-    checkDataType(DecimalType, StructType(Nil), StringType)
+    checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited)
+    checkDataType(DecimalType.Unlimited, StringType, StringType)
+    checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType)
+    checkDataType(DecimalType.Unlimited, StructType(Nil), StringType)

     // StringType
     checkDataType(StringType, StringType, StringType)
@@ -178,7 +182,7 @@ class JsonSuite extends QueryTest {
     checkDataType(
       StructType(
         StructField("f1", IntegerType, true) :: Nil),
-      DecimalType,
+      DecimalType.Unlimited,
       StringType)
   }
@@ -186,7 +190,7 @@ class JsonSuite extends QueryTest {
     val jsonSchemaRDD = jsonRDD(primitiveFieldAndType)

     val expectedSchema = StructType(
-      StructField("bigInteger", DecimalType, true) ::
+      StructField("bigInteger", DecimalType.Unlimited, true) ::
       StructField("boolean", BooleanType, true) ::
       StructField("double", DoubleType, true) ::
       StructField("integer", IntegerType, true) ::
@@ -216,7 +220,7 @@ class JsonSuite extends QueryTest {
     val expectedSchema = StructType(
       StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) ::
       StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) ::
-      StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) ::
+      StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) ::
       StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
       StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
       StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
@@ -230,7 +234,7 @@ class JsonSuite extends QueryTest {
         StructField("field3", StringType, true) :: Nil), false), true) ::
       StructField("struct", StructType(
         StructField("field1", BooleanType, true) ::
-        StructField("field2", DecimalType, true) :: Nil), true) ::
+        StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
       StructField("structWithArrayFields", StructType(
         StructField("field1", ArrayType(IntegerType, false), true) ::
         StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
@@ -331,7 +335,7 @@ class JsonSuite extends QueryTest {
     val expectedSchema = StructType(
       StructField("num_bool", StringType, true) ::
       StructField("num_num_1", LongType, true) ::
-      StructField("num_num_2", DecimalType, true) ::
+      StructField("num_num_2", DecimalType.Unlimited, true) ::
       StructField("num_num_3", DoubleType, true) ::
       StructField("num_str", StringType, true) ::
       StructField("str_bool", StringType, true) :: Nil)
@@ -521,7 +525,7 @@ class JsonSuite extends QueryTest {
     val jsonSchemaRDD = jsonFile(path)

     val expectedSchema = StructType(
-      StructField("bigInteger", DecimalType, true) ::
+      StructField("bigInteger", DecimalType.Unlimited, true) ::
       StructField("boolean", BooleanType, true) ::
       StructField("double", DoubleType, true) ::
       StructField("integer", IntegerType, true) ::
@@ -551,7 +555,7 @@ class JsonSuite extends QueryTest {
     primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)

     val schema = StructType(
-      StructField("bigInteger", DecimalType, true) ::
+      StructField("bigInteger", DecimalType.Unlimited, true) ::
       StructField("boolean", BooleanType, true) ::
       StructField("double", DoubleType, true) ::
       StructField("integer", IntegerType, true) ::
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 9979ab4..08d9da2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -77,6 +77,8 @@ case class AllDataTypesWithNonPrimitiveType(

 case class BinaryData(binaryData: Array[Byte])

+case class NumericData(i: Int, d: Double)
+
 class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
   TestData // Load test data tables.
@@ -560,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     assert(stringResult.size === 1)
     assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
     assert(stringResult(0).getInt(1) === 100)
-
+
     val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40")
     assert(
       query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
@@ -869,4 +871,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
       assert(a.dataType === b.dataType)
     }
   }
+
+  test("read/write fixed-length decimals") {
+    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+      val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+      val data = sparkContext.parallelize(0 to 1000)
+        .map(i => NumericData(i, i / 100.0))
+        .select('i, 'd cast DecimalType(precision, scale))
+      data.saveAsParquetFile(tempDir)
+      checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+    }
+
+    // Decimals with precision above 18 are not yet supported
+    intercept[RuntimeException] {
+      val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+      val data = sparkContext.parallelize(0 to 1000)
+        .map(i => NumericData(i, i / 100.0))
+        .select('i, 'd cast DecimalType(19, 10))
+      data.saveAsParquetFile(tempDir)
+      checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+    }
+
+    // Unlimited-length decimals are not yet supported
+    intercept[RuntimeException] {
+      val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+      val data = sparkContext.parallelize(0 to 1000)
+        .map(i => NumericData(i, i / 100.0))
+        .select('i, 'd cast DecimalType.Unlimited)
+      data.saveAsParquetFile(tempDir)
+      checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+    }
+  }
 }
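Note: the 18-digit ceiling in the new Parquet test is presumably tied to the on-disk representation: the unscaled value of a fixed-precision decimal is carried in a signed 64-bit long (serialized as a fixed-length byte array), and 18 decimal digits is the largest precision that always fits. A quick standalone arithmetic check of that bound:

    // 999999999999999999 (18 nines) fits in a signed Long...
    assert(math.pow(10, 18).toLong - 1 < Long.MaxValue)
    // ...but 19 nines does not, hence the RuntimeException for DecimalType(19, 10).
    assert(BigInt(10).pow(19) - 1 > BigInt(Long.MaxValue))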
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index 2a4f241..99c4f46 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -47,7 +47,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
     val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)(
       hiveContext, sessionToActivePool)
-    handleToOperation.put(operation.getHandle, operation)
-    operation
+    handleToOperation.put(operation.getHandle, operation)
+    operation
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
index bbd727c..8077d0e 100644
--- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
+++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
@@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation(
         to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal)))
       case FloatType =>
         to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
-      case DecimalType =>
+      case DecimalType() =>
         val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
         to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
       case LongType =>
@@ -156,7 +156,7 @@ private[hive] class SparkExecuteStatementOperation(
         to.addColumnValue(ColumnValue.doubleValue(null))
       case FloatType =>
         to.addColumnValue(ColumnValue.floatValue(null))
-      case DecimalType =>
+      case DecimalType() =>
         to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
       case LongType =>
         to.addColumnValue(ColumnValue.longValue(null))

http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
----------------------------------------------------------------------
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
index e59681b..2c1983d 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
@@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation(
         to += from.getDouble(ordinal)
       case FloatType =>
         to += from.getFloat(ordinal)
-      case DecimalType =>
+      case DecimalType() =>
         to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
       case LongType =>
         to += from.getLong(ordinal)
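Note: the recurring `case DecimalType =>` to `case DecimalType() =>` edit in both shims is forced by the same API change: the bare companion no longer matches decimal values as a pattern, while the companion's unapply matches any decimal, fixed-precision or unlimited. A minimal sketch of the pattern, assuming the companion shape this diff implies:

    import org.apache.spark.sql.catalyst.types.{DataType, DecimalType}

    def isDecimal(dt: DataType): Boolean = dt match {
      case DecimalType() => true   // matches DecimalType(10, 2) and DecimalType.Unlimited alike
      case _ => false
    }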
String(bin, "UTF-8") + case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString + HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -406,6 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 0439ab9..1e2bf5c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -38,7 +39,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -54,8 +55,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -90,7 +91,7 @@ private[hive] trait HiveInspectors { case hvoi: HiveVarcharObjectInspector => if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue case hdoi: HiveDecimalObjectInspector => - if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data) // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object // if next timestamp is null, so Timestamp object is cloned case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() @@ -137,8 +138,9 @@ private[hive] trait HiveInspectors { case l: Short => l: java.lang.Short case l: Byte => l: java.lang.Byte case b: BigDecimal => HiveShim.createDecimal(b.underlying()) + case d: Decimal => 
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 0439ab9..1e2bf5c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.{io => hadoopIo}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.types
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal

 /* Implicit conversions */
 import scala.collection.JavaConversions._
@@ -38,7 +39,7 @@ private[hive] trait HiveInspectors {
     // writable
     case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
     case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
-    case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+    case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited
     case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
     case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
     case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType
@@ -54,8 +55,8 @@ private[hive] trait HiveInspectors {
     case c: Class[_] if c == classOf[java.lang.String] => StringType
     case c: Class[_] if c == classOf[java.sql.Date] => DateType
     case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
-    case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
-    case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
+    case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited
+    case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited
     case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
     case c: Class[_] if c == classOf[java.lang.Short] => ShortType
     case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
@@ -90,7 +91,7 @@ private[hive] trait HiveInspectors {
     case hvoi: HiveVarcharObjectInspector =>
       if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
     case hdoi: HiveDecimalObjectInspector =>
-      if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+      if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data)
     // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
     // if next timestamp is null, so Timestamp object is cloned
     case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
@@ -137,8 +138,9 @@ private[hive] trait HiveInspectors {
       case l: Short => l: java.lang.Short
       case l: Byte => l: java.lang.Byte
       case b: BigDecimal => HiveShim.createDecimal(b.underlying())
+      case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying())
       case b: Array[Byte] => b
-      case d: java.sql.Date => d
+      case d: java.sql.Date => d
       case t: java.sql.Timestamp => t
     }
     case x: StructObjectInspector =>
@@ -200,7 +202,7 @@ private[hive] trait HiveInspectors {
     case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
     case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
     case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
-    case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
+    case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
     case StructType(fields) =>
       ObjectInspectorFactory.getStandardStructObjectInspector(
         fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
@@ -229,8 +231,10 @@ private[hive] trait HiveInspectors {
       HiveShim.getPrimitiveWritableConstantObjectInspector(value)
     case Literal(value: java.sql.Timestamp, TimestampType) =>
       HiveShim.getPrimitiveWritableConstantObjectInspector(value)
-    case Literal(value: BigDecimal, DecimalType) =>
+    case Literal(value: BigDecimal, DecimalType()) =>
       HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+    case Literal(value: Decimal, DecimalType()) =>
+      HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal)
     case Literal(_, NullType) =>
       HiveShim.getPrimitiveNullWritableConstantObjectInspector
     case Literal(value: Seq[_], ArrayType(dt, _)) =>
@@ -277,8 +281,8 @@ private[hive] trait HiveInspectors {
     case _: JavaFloatObjectInspector => FloatType
     case _: WritableBinaryObjectInspector => BinaryType
     case _: JavaBinaryObjectInspector => BinaryType
-    case _: WritableHiveDecimalObjectInspector => DecimalType
-    case _: JavaHiveDecimalObjectInspector => DecimalType
+    case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w)
+    case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j)
     case _: WritableDateObjectInspector => DateType
     case _: JavaDateObjectInspector => DateType
     case _: WritableTimestampObjectInspector => TimestampType
@@ -307,7 +311,7 @@ private[hive] trait HiveInspectors {
     case LongType => longTypeInfo
     case ShortType => shortTypeInfo
     case StringType => stringTypeInfo
-    case DecimalType => decimalTypeInfo
+    case d: DecimalType => HiveShim.decimalTypeInfo(d)
     case DateType => dateTypeInfo
     case TimestampType => timestampTypeInfo
     case NullType => voidTypeInfo
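Note: HiveInspectors now bridges three decimal representations: catalyst's Decimal, Scala's BigDecimal, and Hive's HiveDecimal. The wrapping path used above, spelled out (HiveShim.createDecimal hides the version-specific HiveDecimal construction; the 0.12 thrift-server shim earlier in this commit still calls `new HiveDecimal(...)` directly):

    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val d = Decimal(10.03)
    // catalyst Decimal -> scala.math.BigDecimal -> java.math.BigDecimal:
    val javaBig: java.math.BigDecimal = d.toBigDecimal.underlying()
    // HiveShim.createDecimal(javaBig) then yields the version-appropriate HiveDecimal.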
http://git-wip-us.apache.org/repos/asf/spark/blob/23f966f4/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 2dd2c88..096b4a0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
 import java.io.IOException
 import java.util.{List => JList}

+import scala.util.matching.Regex
 import scala.util.parsing.combinator.RegexParsers

 import org.apache.hadoop.util.ReflectionUtils
@@ -321,11 +322,18 @@ object HiveMetastoreTypes extends RegexParsers {
     "bigint" ^^^ LongType |
     "binary" ^^^ BinaryType |
     "boolean" ^^^ BooleanType |
-    HiveShim.metastoreDecimal ^^^ DecimalType |
+    fixedDecimalType |                     // Hive 0.13+ decimal with precision/scale
+    "decimal" ^^^ DecimalType.Unlimited |  // Hive 0.12 decimal with no precision/scale
     "date" ^^^ DateType |
     "timestamp" ^^^ TimestampType |
     "varchar\\((\\d+)\\)".r ^^^ StringType

+  protected lazy val fixedDecimalType: Parser[DataType] =
+    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
+      case precision ~ scale =>
+        DecimalType(precision.toInt, scale.toInt)
+    }
+
   protected lazy val arrayType: Parser[DataType] =
     "array" ~> "<" ~> dataType <~ ">" ^^ {
       case tpe => ArrayType(tpe)
@@ -373,7 +381,7 @@ object HiveMetastoreTypes extends RegexParsers {
     case BinaryType => "binary"
     case BooleanType => "boolean"
     case DateType => "date"
-    case DecimalType => "decimal"
+    case d: DecimalType => HiveShim.decimalMetastoreString(d)
     case TimestampType => "timestamp"
     case NullType => "void"
   }
@@ -441,7 +449,7 @@ private[hive] case class MetastoreRelation
   val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute)

   /** Non-partitionKey attributes */
-  val attributes = hiveQlTable.getCols.map(_.toAttribute)
+  val attributes = hiveQlTable.getCols.map(_.toAttribute)

   val output = attributes ++ partitionKeys
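Note: the grammar change in HiveMetastoreTypes is the one non-mechanical edit in this file: Hive 0.13 metastores describe columns as decimal(p,s), while Hive 0.12 wrote a bare decimal, and fixedDecimalType must be tried before the bare keyword since "decimal" is a prefix of "decimal(p,s)". A self-contained sketch of the same parsing approach (hypothetical object name, reusing only the catalyst types):

    import scala.util.parsing.combinator.RegexParsers
    import org.apache.spark.sql.catalyst.types.{DataType, DecimalType}

    object DecimalTypeString extends RegexParsers {
      // "decimal(10,2)" -> DecimalType(10, 2)
      lazy val fixed: Parser[DataType] =
        ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
          case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
        }

      // Bare "decimal" (Hive 0.12) -> unlimited precision; try the fixed form first.
      lazy val decimal: Parser[DataType] = fixed | "decimal" ^^^ DecimalType.Unlimited

      def parse(s: String): DataType = parseAll(decimal, s) match {
        case Success(t, _) => t
        case failure => sys.error(s"cannot parse '$s': $failure")
      }
    }

    // DecimalTypeString.parse("decimal(10,2)") == DecimalType(10, 2)
    // DecimalTypeString.parse("decimal")       == DecimalType.Unlimited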
