Repository: spark
Updated Branches:
  refs/heads/branch-2.1 a5ec2a7b2 -> 8cd466e83

[SPARK-18622][SQL] Fix the datatype of the Sum aggregate function

## What changes were proposed in this pull request?
The result of a `sum` aggregate function is typically a Decimal, a Double or a Long. Currently the output dataType is based on the input's dataType. The `FunctionArgumentConversion` rule makes sure that the input is promoted to the largest type, which also ensures that the output uses a (hopefully) sufficiently large output dataType. The issue is that `sum` is already in a resolved state when we cast the input type; this means that rules which assume the dataType of the expression no longer changes could have been applied in the meantime. This is what happens if we apply `WidenSetOperationTypes` before applying the casts, and this breaks analysis. The most straightforward and future-proof solution is to make `sum` always output the widest dataType in its class (Long for IntegralTypes, Decimal for DecimalTypes & Double for FloatType and DoubleType). This PR implements that solution. We should move expression-specific type casting rules into the given Expression at some point.

## How was this patch tested?
Added (regression) tests to SQLQueryTestSuite's `union.sql`.

Author: Herman van Hovell <[email protected]>

Closes #16063 from hvanhovell/SPARK-18622.

(cherry picked from commit 879ba71110b6c85a4e47133620fbae7580650a6f)
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8cd466e8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8cd466e8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8cd466e8

Branch: refs/heads/branch-2.1
Commit: 8cd466e831a7987a6fb04833c31b9b442da092db
Parents: a5ec2a7
Author: Herman van Hovell <[email protected]>
Authored: Wed Nov 30 15:25:33 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Wed Nov 30 15:25:52 2016 +0800

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/Sum.scala    |  6 +-
 .../test/resources/sql-tests/inputs/union.sql   | 27 +++++++
 .../resources/sql-tests/results/union.sql.out   | 80 ++++++++++++++++++++
 3 files changed, 110 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
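
Editor's note: as background for the `resultType` change in the diff below, here is a minimal, self-contained Scala sketch of the widened result-type rule described above. The type names mirror Spark SQL's catalyst types but are stand-in definitions for illustration only, not the real org.apache.spark.sql.types classes, and the precision cap of 38 only approximates what DecimalType.bounded does.

// Sketch only: stand-in types that mirror (but are not) Spark's catalyst classes.
object SumResultTypeSketch {

  sealed trait DataType
  sealed trait IntegralType extends DataType
  case object IntType extends IntegralType
  case object LongType extends IntegralType
  case object FloatType extends DataType
  case object DoubleType extends DataType
  final case class DecimalType(precision: Int, scale: Int) extends DataType

  // After SPARK-18622, sum always widens to the largest type in its class:
  //   integral input  -> LongType
  //   decimal(p, s)   -> decimal(p + 10, s), capped here at precision 38
  //   other numerics  -> DoubleType
  def sumResultType(child: DataType): DataType = child match {
    case DecimalType(precision, scale) => DecimalType(math.min(precision + 10, 38), scale)
    case _: IntegralType               => LongType
    case _                             => DoubleType
  }

  def main(args: Array[String]): Unit = {
    println(sumResultType(IntType))            // LongType
    println(sumResultType(DecimalType(10, 2))) // DecimalType(20,2)
    println(sumResultType(FloatType))          // DoubleType
  }
}

With the old fallback (`case _ => child.dataType`), the output type could still change once the input was cast, which is the instability the description above refers to; the new rule fixes the output type up front.
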
http://git-wip-us.apache.org/repos/asf/spark/blob/8cd466e8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index f3731d4..3c77b11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -33,8 +33,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, DoubleType, DecimalType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function sum")
@@ -42,7 +41,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
-    case _ => child.dataType
+    case _: IntegralType => LongType
+    case _ => DoubleType
   }
 
   private lazy val sumDataType = resultType


http://git-wip-us.apache.org/repos/asf/spark/blob/8cd466e8/sql/core/src/test/resources/sql-tests/inputs/union.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/union.sql b/sql/core/src/test/resources/sql-tests/inputs/union.sql
new file mode 100644
index 0000000..1f4780a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/union.sql
@@ -0,0 +1,27 @@
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2);
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2);
+
+-- Simple Union
+SELECT *
+FROM (SELECT * FROM t1
+      UNION ALL
+      SELECT * FROM t1);
+
+-- Type Coerced Union
+SELECT *
+FROM (SELECT * FROM t1
+      UNION ALL
+      SELECT * FROM t2
+      UNION ALL
+      SELECT * FROM t2);
+
+-- Regression test for SPARK-18622
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T;
+
+-- Clean-up
+DROP VIEW IF EXISTS t1;
+DROP VIEW IF EXISTS t2;


http://git-wip-us.apache.org/repos/asf/spark/blob/8cd466e8/sql/core/src/test/resources/sql-tests/results/union.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/union.sql.out b/sql/core/src/test/resources/sql-tests/results/union.sql.out
new file mode 100644
index 0000000..c57028c
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/union.sql.out
@@ -0,0 +1,80 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (1, 'a'), (2, 'b') tbl(c1, c2)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE OR REPLACE TEMPORARY VIEW t2 AS VALUES (1.0, 1), (2.0, 4) tbl(c1, c2)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT *
+FROM (SELECT * FROM t1
+      UNION ALL
+      SELECT * FROM t1)
+-- !query 2 schema
+struct<c1:int,c2:string>
+-- !query 2 output
+1 a
+1 a
+2 b
+2 b
+
+
+-- !query 3
+SELECT *
+FROM (SELECT * FROM t1
+      UNION ALL
+      SELECT * FROM t2
+      UNION ALL
+      SELECT * FROM t2)
+-- !query 3 schema
+struct<c1:decimal(11,1),c2:string>
+-- !query 3 output
+1 1
+1 1
+1 a
+2 4
+2 4
+2 b
+
+
+-- !query 4
+SELECT a
+FROM (SELECT 0 a, 0 b
+      UNION ALL
+      SELECT SUM(1) a, CAST(0 AS BIGINT) b
+      UNION ALL SELECT 0 a, 0 b) T
+-- !query 4 schema
+struct<a:bigint>
+-- !query 4 output
+0
+0
+1
+
+
+-- !query 5
+DROP VIEW IF EXISTS t1
+-- !query 5 schema
+struct<>
+-- !query 5 output
+
+
+
+-- !query 6
+DROP VIEW IF EXISTS t2
+-- !query 6 schema
+struct<>
+-- !query 6 output
+
