Repository: spark
Updated Branches:
  refs/heads/master 9d8666cac -> 890479123


[SPARK-2659][SQL] Fix division semantics for hive

Author: Michael Armbrust <mich...@databricks.com>

Closes #1557 from marmbrus/fixDivision and squashes the following commits:

b85077f [Michael Armbrust] Fix unit tests.
af98f29 [Michael Armbrust] Change DIV to long type
0c29ae8 [Michael Armbrust] Fix division semantics for hive


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/89047912
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/89047912
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/89047912

Branch: refs/heads/master
Commit: 8904791230a0fae336db93e5a80f65c4d9d584dc
Parents: 9d8666c
Author: Michael Armbrust <mich...@databricks.com>
Authored: Fri Jul 25 19:17:49 2014 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Jul 25 19:17:49 2014 -0700

----------------------------------------------------------------------
 .../sql/catalyst/analysis/HiveTypeCoercion.scala  | 18 ++++++++++++++++++
 .../catalyst/optimizer/ConstantFoldingSuite.scala |  2 +-
 .../scala/org/apache/spark/sql/hive/HiveQl.scala  |  3 ++-
 .../golden/div-0-3760f9b354ddacd7c7b01b28791d4585 |  1 +
 .../division-0-63b19f8a22471c8ba0415c1d3bc276f7   |  1 +
 .../sql/hive/execution/HiveComparisonTest.scala   |  6 ------
 .../spark/sql/hive/execution/HiveQuerySuite.scala |  5 ++++-
 7 files changed, 27 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 67a8ce9..47c7ad0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -50,6 +50,7 @@ trait HiveTypeCoercion {
     StringToIntegralCasts ::
     FunctionArgumentConversion ::
     CastNulls ::
+    Division ::
     Nil
 
   /**
@@ -318,6 +319,23 @@ trait HiveTypeCoercion {
   }
 
   /**
+   * Hive only performs integral division with the DIV operator. The arguments to / are always
+   * converted to fractional types.
+   */
+  object Division extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      // Skip nodes whose children have not been resolved yet.
+      case e if !e.childrenResolved => e
+
+      // Decimal and Double remain the same
+      case d: Divide if d.dataType == DoubleType => d
+      case d: Divide if d.dataType == DecimalType => d
+
+      case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
+    }
+  }
+
+  /**
    * Ensures that NullType gets casted to some other types under certain circumstances.
    */
   object CastNulls extends Rule[LogicalPlan] {

http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index d607eed..0a27cce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -83,7 +83,7 @@ class ConstantFoldingSuite extends PlanTest {
           Literal(10) as Symbol("2*3+4"),
           Literal(14) as Symbol("2*(3+4)"))
         .where(Literal(true))
-        .groupBy(Literal(3))(Literal(3) as Symbol("9/3"))
+        .groupBy(Literal(3.0))(Literal(3.0) as Symbol("9/3"))
         .analyze
 
     comparePlans(optimized, correctAnswer)

http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 4395874..e6ab68b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -925,7 +925,8 @@ private[hive] object HiveQl {
     case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
     case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
     case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
-    case Token(DIV(), left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+    case Token(DIV(), left :: right:: Nil) =>
+      Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
     case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
 
     /* Comparisons */

http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585 b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
new file mode 100644
index 0000000..17ba0be
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/div-0-3760f9b354ddacd7c7b01b28791d4585
@@ -0,0 +1 @@
+0      0       0       1       2

http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
new file mode 100644
index 0000000..7b7a917
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7
@@ -0,0 +1 @@
+2.0    0.5     0.3333333333333333      0.002

http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 08ef4d9..b4dbf2b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -350,12 +350,6 @@ abstract class HiveComparisonTest
 
               val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n")
 
-              println("hive output")
-              hive.foreach(println)
-
-              println("catalyst printout")
-              catalyst.foreach(println)
-
               if (recomputeCache) {
                 logger.warn(s"Clearing cache files for failed test $testCaseName")
                 hiveCacheFiles.foreach(_.delete())

http://git-wip-us.apache.org/repos/asf/spark/blob/89047912/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6f36a4f..a8623b6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -52,7 +52,10 @@ class HiveQuerySuite extends HiveComparisonTest {
     "SELECT * FROM src WHERE key Between 1 and 2")
 
   createQueryTest("div",
-    "SELECT 1 DIV 2, 1 div 2, 1 dIv 2 FROM src LIMIT 1")
+    "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1")
+
+  createQueryTest("division",
+    "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1")
 
   test("Query expressed in SQL") {
     assert(sql("SELECT 1").collect() === Array(Seq(1)))

Reply via email to