Repository: spark
Updated Branches:
  refs/heads/master a1d9138ab -> 879ba7111


[SPARK-18622][SQL] Fix the datatype of the Sum aggregate function

## What changes were proposed in this pull request?
The result of a `sum` aggregate function is typically a Decimal, a Double or a
Long. Currently, however, the output dataType is derived from the input's
dataType.
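To make the old behaviour concrete, here is a minimal sketch over a hypothetical, heavily simplified model of Catalyst types and expressions. `DType`, `Lit`, `CastTo`, `OldSum` and `OldSumDemo` are illustrative names, not Spark's API. Under the old rule the output type of a non-decimal `sum` simply mirrors its child's type. As a result, wrapping the child of an already-resolved `sum` in a cast silently changes the type that the `sum` itself reports.

```scala
// Hypothetical, simplified stand-ins for Catalyst data types (not Spark's DataType).
sealed trait DType
case object IntT extends DType
case object LongT extends DType
case object DoubleT extends DType

// Hypothetical expression nodes: a typed literal and an explicit cast.
sealed trait Expr { def dataType: DType }
final case class Lit(dataType: DType) extends Expr
final case class CastTo(child: Expr, dataType: DType) extends Expr

// Old rule (before SPARK-18622), non-decimal case: the result type is the child's type.
final case class OldSum(child: Expr) extends Expr {
  def dataType: DType = child.dataType
}

object OldSumDemo {
  // SUM(1) as first resolved: its child is still an INT literal.
  val beforeCast: OldSum = OldSum(Lit(IntT))
  // The same SUM after a later rule wraps its input in a cast to BIGINT.
  val afterCast: OldSum = OldSum(CastTo(beforeCast.child, LongT))
}
```

Here `OldSumDemo.beforeCast.dataType` is `IntT`, while `OldSumDemo.afterCast.dataType` is `LongT`. The same logical expression reports two different types depending on when it is inspected.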

The `FunctionArgumentConversion` rule makes sure that the input is promoted to
the largest type, which in turn ensures that the output uses a (hopefully)
sufficiently large dataType. The problem is that `sum` is already in a resolved
state when we cast its input type. This means that rules which assume that the
dataType of the expression will not change anymore could have been applied in
the meantime. This is exactly what happens when `WidenSetOperationTypes` is
applied before the casts, and it breaks analysis.

The most straightforward and future-proof solution is to make `sum` always
output the widest dataType in its class (Long for IntegralTypes, Decimal for
DecimalTypes, and Double for FloatType and DoubleType). This PR implements that
solution.
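The new result-type rule can be sketched as follows, again over a hypothetical simplified type model rather than Spark's `DataType` hierarchy. The sketch mirrors the `resultType` match in the `Sum.scala` diff:

- integral inputs map to LongType;
- decimals map to `DecimalType.bounded(precision + 10, scale)`;
- everything else (FloatType, DoubleType) maps to DoubleType.

The cap of 38 digits is an assumption that `MaxPrecision` and `MaxScale` match Spark's `DecimalType.MAX_PRECISION` and `DecimalType.MAX_SCALE`. `SumResultType` and the type names are illustrative only.

```scala
// Hypothetical, simplified stand-ins for Spark SQL numeric types.
sealed trait NumType
case object ByteT extends NumType
case object ShortT extends NumType
case object IntT extends NumType
case object LongT extends NumType
case object FloatT extends NumType
case object DoubleT extends NumType
final case class DecimalT(precision: Int, scale: Int) extends NumType

object SumResultType {
  // Assumed to match Spark's DecimalType.MAX_PRECISION and DecimalType.MAX_SCALE.
  val MaxPrecision = 38
  val MaxScale = 38

  // Mirrors DecimalType.bounded: clamp precision and scale to the supported maximum.
  def boundedDecimal(precision: Int, scale: Int): DecimalT =
    DecimalT(math.min(precision, MaxPrecision), math.min(scale, MaxScale))

  // The widest type in the input's class, so a later cast of the input
  // within that class cannot change the result.
  def of(input: NumType): NumType = input match {
    case DecimalT(p, s)                => boundedDecimal(p + 10, s)
    case ByteT | ShortT | IntT | LongT => LongT
    case FloatT | DoubleT              => DoubleT
  }
}
```

`SumResultType.of(IntT)` and `SumResultType.of(LongT)` are both `LongT`, so casting an INT input to BIGINT no longer changes the type of the sum. For decimals, `of(DecimalT(11, 1))` is `DecimalT(21, 1)`, while `of(DecimalT(35, 2))` is capped at `DecimalT(38, 2)`.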

At some point we should move expression-specific type-casting rules into the
expressions themselves.

## How was this patch tested?
Added (regression) tests to SQLQueryTestSuite's `union.sql`.

Author: Herman van Hovell

Closes #16063 from hvanhovell/SPARK-18622.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/879ba711
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/879ba711
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/879ba711

Branch: refs/heads/master
Commit: 879ba71110b6c85a4e47133620fbae7580650a6f
Parents: a1d9138
Author: Herman van Hovell
Authored: Wed Nov 30 15:25:33 2016 +0800
Committer: Wenchen Fan
Committed: Wed Nov 30 15:25:33 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Sum.scala    |  6 +-
 .../test/resources/sql-tests/inputs/union.sql   | 27 +++++++
 .../resources/sql-tests/results/union.sql.out   | 80 ++++++++++++++++++++
 3 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/879ba711/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 96e8cee..86e40a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -33,8 +33,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   // Return data type.
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, DoubleType, DecimalType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function sum")
@@ -42,7 +41,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
-    case _ => child.dataType
+    case _: IntegralType => LongType
+    case _ => DoubleType
   }
 
   private lazy val sumDataType = resultType

http://git-wip-us.apache.org/repos/asf/spark/blob/879ba711/sql/core/src/test/resources/sql-tests/inputs/union.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql
new file mode 100644
index 0000000..1f4780a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -0,0 +1,27 @@
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2);
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2);
+
+-- Simple Union
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t1);
+
+-- Type Coerced Union
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t2
+        UNION ALL
+        SELECT * FROM t2);
+
+-- Regression test for SPARK-18622
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T;
+
+-- Clean-up
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;

http://git-wip-us.apache.org/repos/asf/spark/blob/879ba711/sql/core/src/test/resources/sql-tests/results/union.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out
new file mode 100644
index 0000000..c57028c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out
@@ -0,0 +1,80 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t1)
+-- !query 2 schema
+struct<c1:int,c2:string>
+-- !query 2 output
+1      a
+1      a
+2      b
+2      b
+
+
+-- !query 3
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t2
+        UNION ALL
+        SELECT * FROM t2)
+-- !query 3 schema
+struct<c1:decimal(11,1),c2:string>
+-- !query 3 output
+1      1
+1      1
+1      a
+2      4
+2      4
+2      b
+
+
+-- !query 4
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T
+-- !query 4 schema
+struct<a:bigint>
+-- !query 4 output
+0
+0
+1
+
+
+-- !query 5
+DROP VIEW IF EXISTS t1
+-- !query 5 schema
+struct<>
+-- !query 5 output
+
+
+
+-- !query 6
+DROP VIEW IF EXISTS t2
+-- !query 6 schema
+struct<>
+-- !query 6 output
+

