Repository: spark Updated Branches: refs/heads/master 9d8666cac -> 890479123
[SPARK-2659][SQL] Fix division semantics for hive Author: Michael Armbrust <mich...@databricks.com> Closes #1557 from marmbrus/fixDivision and squashes the following commits: b85077f [Michael Armbrust] Fix unit tests. af98f29 [Michael Armbrust] Change DIV to long type 0c29ae8 [Michael Armbrust] Fix division semantics for hive Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/89047912 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/89047912 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/89047912 Branch: refs/heads/master Commit: 8904791230a0fae336db93e5a80f65c4d9d584dc Parents: 9d8666c Author: Michael Armbrust <mich...@databricks.com> Authored: Fri Jul 25 19:17:49 2014 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Fri Jul 25 19:17:49 2014 -0700 ---------------------------------------------------------------------- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 18 ++++++++++++++++++ .../catalyst/optimizer/ConstantFoldingSuite.scala | 2 +- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 3 ++- .../golden/div-0-3760f9b354ddacd7c7b01b28791d4585 | 1 + .../division-0-63b19f8a22471c8ba0415c1d3bc276f7 | 1 + .../sql/hive/execution/HiveComparisonTest.scala | 6 ------ .../spark/sql/hive/execution/HiveQuerySuite.scala | 5 ++++- 7 files changed, 27 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 67a8ce9..47c7ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -50,6 +50,7 @@ trait HiveTypeCoercion { StringToIntegralCasts :: FunctionArgumentConversion :: CastNulls :: + Division :: Nil /** @@ -318,6 +319,23 @@ trait HiveTypeCoercion { } /** + * Hive only performs integral division with the DIV operator. The arguments to / are always + * converted to fractional types. + */ + object Division extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + // Decimal and Double remain the same + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType == DecimalType => d + + case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) + } + } + + /** * Ensures that NullType gets casted to some other types under certain circumstances. */ object CastNulls extends Rule[LogicalPlan] { http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index d607eed..0a27cce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -83,7 +83,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(10) as Symbol("2*3+4"), Literal(14) as Symbol("2*(3+4)")) .where(Literal(true)) - .groupBy(Literal(3))(Literal(3) as Symbol("9/3")) + .groupBy(Literal(3.0))(Literal(3.0) as Symbol("9/3")) .analyze comparePlans(optimized, correctAnswer) http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 4395874..e6ab68b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -925,7 +925,8 @@ private[hive] object HiveQl { case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) + case Token(DIV(), left :: right:: Nil) => + Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) /* Comparisons */ http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 new file mode 100644 index 0000000..17ba0be --- /dev/null +++ b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 @@ -0,0 +1 @@ +0 0 0 1 2 http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 new file mode 100644 index 0000000..7b7a917 --- /dev/null +++ b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 @@ -0,0 +1 @@ +2.0 0.5 0.3333333333333333 0.002 http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 08ef4d9..b4dbf2b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -350,12 +350,6 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") - println("hive output") - hive.foreach(println) - - println("catalyst printout") - catalyst.foreach(println) - if (recomputeCache) { logger.warn(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6f36a4f..a8623b6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -52,7 +52,10 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT * FROM src WHERE key Between 1 and 2") createQueryTest("div", - "SELECT 1 DIV 2, 1 div 2, 1 dIv 2 FROM src LIMIT 1") + "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1") + + createQueryTest("division", + "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") test("Query expressed in SQL") { assert(sql("SELECT 1").collect() === Array(Seq(1)))