Repository: spark
Updated Branches:
  refs/heads/branch-2.1 a5ec2a7b2 -> 8cd466e83


[SPARK-18622][SQL] Fix the datatype of the Sum aggregate function

## What changes were proposed in this pull request?
The result of a `sum` aggregate function is typically a Decimal, Double or Long.
Currently the output dataType is based on the input's dataType.

The `FunctionArgumentConversion` rule will make sure that the input is promoted
to the largest type, which also ensures that the output has a (hopefully)
sufficiently large dataType. The issue is that `sum` is already in a resolved
state when we cast the input type. This means that rules assuming that the
dataType of the expression does not change anymore could have been applied in
the meantime. This is what happens if we apply `WidenSetOperationTypes` before
applying the casts, and this breaks analysis.

The most straightforward and future-proof solution is to make `sum` always
output the widest dataType in its class (Long for IntegralTypes, Decimal for
DecimalTypes, and Double for FloatType and DoubleType). This PR implements that
solution.
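The new result-type rule can be restated as a small pure function. This is a hedged sketch with hypothetical stand-in types (`NumType`, `DecimalT` and the `...T` objects), not Spark's `DataType` hierarchy; the decimal case mirrors the `Sum.scala` change in this commit, assuming that `DecimalType.bounded` caps precision and scale at Spark's maximum of 38.

```scala
// Restatement of Sum's result type after SPARK-18622, using hypothetical
// stand-in types rather than Spark's DataType classes.
object SumResultTypeSketch {
  sealed trait NumType
  case object ByteT extends NumType
  case object ShortT extends NumType
  case object IntT extends NumType
  case object LongT extends NumType
  case object FloatT extends NumType
  case object DoubleT extends NumType
  final case class DecimalT(precision: Int, scale: Int) extends NumType

  // Spark caps both decimal precision and scale at 38.
  val MaxPrecision = 38
  val MaxScale = 38

  // Analogue of DecimalType.bounded: clamp precision and scale to the maximum.
  def boundedDecimal(precision: Int, scale: Int): DecimalT =
    DecimalT(math.min(precision, MaxPrecision), math.min(scale, MaxScale))

  // Widest type in the input's class: decimals get 10 extra digits of
  // precision, integral types become LongT, floating-point types DoubleT.
  def sumResultType(input: NumType): NumType = input match {
    case DecimalT(p, s)                => boundedDecimal(p + 10, s)
    case ByteT | ShortT | IntT | LongT => LongT
    case FloatT | DoubleT              => DoubleT
  }
}
```

For example, `decimal(11,1)` sums to `decimal(21,1)`, while `decimal(30,2)` is clamped to `decimal(38,2)`. Because the result depends only on the input's type class, an int argument and the same argument cast to bigint yield the same result type, which is what keeps `sum` stable if a cast is inserted after other rules have run.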

We should move expression-specific type-casting rules into the given Expression
at some point.

## How was this patch tested?
Added (regression) tests to SQLQueryTestSuite's `union.sql`.

Author: Herman van Hovell <[email protected]>

Closes #16063 from hvanhovell/SPARK-18622.

(cherry picked from commit 879ba71110b6c85a4e47133620fbae7580650a6f)
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8cd466e8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8cd466e8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8cd466e8

Branch: refs/heads/branch-2.1
Commit: 8cd466e831a7987a6fb04833c31b9b442da092db
Parents: a5ec2a7
Author: Herman van Hovell <[email protected]>
Authored: Wed Nov 30 15:25:33 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Wed Nov 30 15:25:52 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Sum.scala    |  6 +-
 .../test/resources/sql-tests/inputs/union.sql   | 27 +++++++
 .../resources/sql-tests/results/union.sql.out   | 80 ++++++++++++++++++++
 3 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8cd466e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index f3731d4..3c77b11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -33,8 +33,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, DoubleType, DecimalType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function sum")
@@ -42,7 +41,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
-    case _ => child.dataType
+    case _: IntegralType => LongType
+    case _ => DoubleType
   }
 
   private lazy val sumDataType = resultType

http://git-wip-us.apache.org/repos/asf/spark/blob/8cd466e8/sql/core/src/test/resources/sql-tests/inputs/union.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql
new file mode 100644
index 0000000..1f4780a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -0,0 +1,27 @@
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2);
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2);
+
+-- Simple Union
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t1);
+
+-- Type Coerced Union
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t2
+        UNION ALL
+        SELECT * FROM t2);
+
+-- Regression test for SPARK-18622
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T;
+
+-- Clean-up
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;

http://git-wip-us.apache.org/repos/asf/spark/blob/8cd466e8/sql/core/src/test/resources/sql-tests/results/union.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out
new file mode 100644
index 0000000..c57028c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out
@@ -0,0 +1,80 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t1)
+-- !query 2 schema
+struct<c1:int,c2:string>
+-- !query 2 output
+1      a
+1      a
+2      b
+2      b
+
+
+-- !query 3
+SELECT *
+FROM   (SELECT * FROM t1
+        UNION ALL
+        SELECT * FROM t2
+        UNION ALL
+        SELECT * FROM t2)
+-- !query 3 schema
+struct<c1:decimal(11,1),c2:string>
+-- !query 3 output
+1      1
+1      1
+1      a
+2      4
+2      4
+2      b
+
+
+-- !query 4
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T
+-- !query 4 schema
+struct<a:bigint>
+-- !query 4 output
+0
+0
+1
+
+
+-- !query 5
+DROP VIEW IF EXISTS t1
+-- !query 5 schema
+struct<>
+-- !query 5 output
+
+
+
+-- !query 6
+DROP VIEW IF EXISTS t2
+-- !query 6 schema
+struct<>
+-- !query 6 output
+

