This is an automated email from the ASF dual-hosted git repository. gengliang pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c34fee4 [SPARK-38548][SQL] New SQL function: try_sum c34fee4 is described below commit c34fee4d20da9ab5b4f1f26185fc1a9a83b99d05 Author: Gengliang Wang <gengli...@apache.org> AuthorDate: Mon Mar 21 14:26:46 2022 +0800 [SPARK-38548][SQL] New SQL function: try_sum ### What changes were proposed in this pull request? Add a new SQL function: try_sum. It is identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal value overflow. Note it is different from sum when ANSI mode is off: | Function | Sum | TrySum | |------------------|------------------------------------|-------------| | Decimal overflow | Return NULL | Return NULL | | Integer overflow | Return lower 64 bits of the result | Return NULL | ### Why are the changes needed? * Users can manage to finish queries without interruptions in ANSI mode. * Users can get NULLs instead of unreasonable results if overflow occurs when ANSI mode is off. For example ``` > SELECT sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col); -9223372036854775808 > SELECT try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col); NULL ``` ### Does this PR introduce _any_ user-facing change? Yes, a new SQL function: try_sum which is identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal value overflow. ### How was this patch tested? UT Closes #35848 from gengliangwang/trySum2. 
Authored-by: Gengliang Wang <gengli...@apache.org> Signed-off-by: Gengliang Wang <gengli...@apache.org> --- docs/sql-ref-ansi-compliance.md | 1 + .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/aggregate/Sum.scala | 129 ++++++++++++++++----- .../sql-functions/sql-expression-schema.md | 3 +- .../sql-tests/inputs/ansi/try_aggregates.sql | 1 + .../resources/sql-tests/inputs/try_aggregates.sql | 13 +++ .../sql-tests/results/ansi/try_aggregates.sql.out | 82 +++++++++++++ .../sql-tests/results/try_aggregates.sql.out | 82 +++++++++++++ 8 files changed, 282 insertions(+), 30 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index ccfc601..0f7dd5d 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -316,6 +316,7 @@ When ANSI mode is on, it throws exceptions for invalid operations. You can use t - `try_subtract`: identical to the subtraction operator `-`, except that it returns `NULL` result instead of throwing an exception on integral value overflow. - `try_multiply`: identical to the multiplication operator `*`, except that it returns `NULL` result instead of throwing an exception on integral value overflow. - `try_divide`: identical to the division operator `/`, except that it returns `NULL` result instead of throwing an exception on dividing 0. + - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal value overflow. - `try_element_at`: identical to the function `element_at`, except that it returns `NULL` result instead of throwing an exception on array's index out of bound or map's key not found. 
### SQL Keywords (optional, disabled by default) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e5954c8..a37d4b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -452,6 +452,7 @@ object FunctionRegistry { expression[TrySubtract]("try_subtract"), expression[TryMultiply]("try_multiply"), expression[TryElementAt]("try_element_at"), + expression[TrySum]("try_sum"), // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index ec7479a..5d8fd70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -26,27 +26,11 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.", - examples = """ - Examples: - > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col); - 30 - > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col); - 25 - > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); - NULL - """, - group = "agg_funcs", - since = "1.0.0") -case class Sum( - child: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) - extends DeclarativeAggregate +abstract class SumBase(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes with UnaryLike[Expression] { - def this(child: Expression) = 
this(child, failOnError = SQLConf.get.ansiEnabled) + def failOnError: Boolean override def nullable: Boolean = true @@ -57,7 +41,7 @@ case class Sum( Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType)) override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum") + TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, prettyName) final override val nodePatterns: Seq[TreePattern] = Seq(SUM) @@ -86,16 +70,17 @@ case class Sum( case _ => Seq(Literal(null, resultType)) } - override lazy val updateExpressions: Seq[Expression] = { + protected def getUpdateExpressions: Seq[Expression] = { resultType match { case _: DecimalType => // For decimal type, the initial value of `sum` is 0. We need to keep `sum` unchanged if // the input is null, as SUM function ignores null input. The `sum` can only be null if // overflow happens under non-ansi mode. val sumExpr = if (child.nullable) { - If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType)) + If(child.isNull, sum, + Add(sum, KnownNotNull(child).cast(resultType), failOnError = failOnError)) } else { - sum + child.cast(resultType) + Add(sum, child.cast(resultType), failOnError = failOnError) } // The buffer becomes non-empty after seeing the first not-null input. val isEmptyExpr = if (child.nullable) { @@ -110,9 +95,10 @@ case class Sum( // in case the input is nullable. The `sum` can only be null if there is no value, as // non-decimal type can produce overflowed value under non-ansi mode. 
if (child.nullable) { - Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum)) + Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError), + sum)) } else { - Seq(coalesce(sum, zero) + child.cast(resultType)) + Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = failOnError)) } } } @@ -129,7 +115,7 @@ case class Sum( * isEmpty: Set to false if either one of the left or right is set to false. This * means we have seen at least one value that was not null. */ - override lazy val mergeExpressions: Seq[Expression] = { + protected def getMergeExpressions: Seq[Expression] = { resultType match { case _: DecimalType => val bufferOverflow = !isEmpty.left && sum.left.isNull @@ -143,7 +129,9 @@ case class Sum( // overflow happens. KnownNotNull(sum.left) + KnownNotNull(sum.right)), isEmpty.left && isEmpty.right) - case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left)) + case _ => Seq(coalesce( + Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError), + sum.left)) } } @@ -154,15 +142,98 @@ case class Sum( * So now, if ansi is enabled, then throw exception, if not then return null. * If sum is not null, then return the sum. 
*/ - override lazy val evaluateExpression: Expression = resultType match { + protected def getEvaluateExpression: Expression = resultType match { case d: DecimalType => If(isEmpty, Literal.create(null, resultType), CheckOverflowInSum(sum, d, !failOnError)) case _ => sum } - override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild) - // The flag `failOnError` won't be shown in the `toString` or `toAggString` methods override def flatArguments: Iterator[Any] = Iterator(child) } + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col); + 30 + > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col); + 25 + > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); + NULL + """, + group = "agg_funcs", + since = "1.0.0") +case class Sum( + child: Expression, + failOnError: Boolean = SQLConf.get.ansiEnabled) + extends SumBase(child) { + def this(child: Expression) = this(child, failOnError = SQLConf.get.ansiEnabled) + + override protected def withNewChildInternal(newChild: Expression): Sum = copy(child = newChild) + + override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions + + override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions + + override lazy val evaluateExpression: Expression = getEvaluateExpression +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the sum calculated from values of a group and the result is null on overflow.", + examples = """ + Examples: + > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col); + 30 + > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col); + 25 + > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col); + NULL + > SELECT _FUNC_(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col); + NULL + """, + since = "3.3.0", + group = "agg_funcs") 
+// scalastyle:on line.size.limit +case class TrySum(child: Expression) extends SumBase(child) { + + override def failOnError: Boolean = dataType match { + // Double type won't fail, thus the failOnError is always false + // For decimal type, the sum already returns NULL on overflow when `failOnError` is false, + // which is exactly the behavior TrySum needs. + case _: DoubleType | _: DecimalType => false + case _ => true + } + + override lazy val updateExpressions: Seq[Expression] = + if (failOnError) { + val expressions = getUpdateExpressions + // If the length of updateExpressions is larger than 1, the tail expressions are for + // tracking whether the input is empty, which doesn't need `TryEval` execution. + Seq(TryEval(expressions.head)) ++ expressions.tail + } else { + getUpdateExpressions + } + + override lazy val mergeExpressions: Seq[Expression] = + if (failOnError) { + getMergeExpressions.map(TryEval) + } else { + getMergeExpressions + } + + override lazy val evaluateExpression: Expression = + if (failOnError) { + TryEval(getEvaluateExpression) + } else { + getEvaluateExpression + } + + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) + + override def prettyName: String = "try_sum" +} diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 386dd1f..1afba46 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ <!-- Automatically generated by ExpressionsSchemaSuite --> ## Summary - - Number of queries: 382 + - Number of queries: 383 - Number of expressions that missing example: 12 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint ## Schema of Built-in Functions @@ -376,6 +376,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | 
stddev | SELECT stddev(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev_samp | SELECT stddev_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<stddev_samp(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.Sum | sum | SELECT sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<sum(col):bigint> | +| org.apache.spark.sql.catalyst.expressions.aggregate.TrySum | try_sum | SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<try_sum(col):bigint> | | org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_pop(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/ansi/try_aggregates.sql b/sql/core/src/test/resources/sql-tests/inputs/ansi/try_aggregates.sql new file mode 100644 index 0000000..f5b44d2 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/try_aggregates.sql @@ -0,0 +1 @@ +--IMPORT try_aggregates.sql diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql new file mode 100644 index 0000000..ffa8eef --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql @@ -0,0 +1,13 @@ +-- try_sum +SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col); +SELECT try_sum(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col); +SELECT try_sum(col) FROM VALUES (NULL), (10), (15) AS tab(col); +SELECT try_sum(col) FROM VALUES (NULL), (NULL) AS tab(col); +SELECT 
try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col); +-- test overflow in Decimal(38, 0) +SELECT try_sum(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col); + +SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col); +SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col); +SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col); +SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out new file mode 100644 index 0000000..7ae217a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out @@ -0,0 +1,82 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query +SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col) +-- !query schema +struct<try_sum(col):bigint> +-- !query output +30 + + +-- !query +SELECT try_sum(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col) +-- !query schema +struct<try_sum(col):decimal(13,1)> +-- !query output +30.0 + + +-- !query +SELECT try_sum(col) FROM VALUES (NULL), (10), (15) AS tab(col) +-- !query schema +struct<try_sum(col):bigint> +-- !query output +25 + + +-- !query +SELECT try_sum(col) FROM VALUES (NULL), (NULL) AS tab(col) +-- !query schema +struct<try_sum(col):double> +-- !query output +NULL + + +-- !query +SELECT try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col) +-- !query schema +struct<try_sum(col):bigint> +-- !query output +NULL + + +-- !query +SELECT try_sum(col) FROM VALUES (98765432109876543210987654321098765432BD), (98765432109876543210987654321098765432BD) AS tab(col) +-- !query schema +struct<try_sum(col):decimal(38,0)> +-- !query output +NULL + + 
+-- !query +SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_sum(col):interval month> +-- !query output +0-2 + + +-- !query +SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_sum(col):interval month> +-- !query output +NULL + + +-- !query +SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col) +-- !query schema +struct<try_sum(col):interval second> +-- !query output +0 00:00:02.000000000 + + +-- !query +SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col) +-- !query schema +struct<try_sum(col):interval day> +-- !query output +NULL diff --git a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out new file mode 100644 index 0000000..7ae217a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out @@ -0,0 +1,82 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 10 + + +-- !query +SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col) +-- !query schema +struct<try_sum(col):bigint> +-- !query output +30 + + +-- !query +SELECT try_sum(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col) +-- !query schema +struct<try_sum(col):decimal(13,1)> +-- !query output +30.0 + + +-- !query +SELECT try_sum(col) FROM VALUES (NULL), (10), (15) AS tab(col) +-- !query schema +struct<try_sum(col):bigint> +-- !query output +25 + + +-- !query +SELECT try_sum(col) FROM VALUES (NULL), (NULL) AS tab(col) +-- !query schema +struct<try_sum(col):double> +-- !query output +NULL + + +-- !query +SELECT try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col) +-- !query schema +struct<try_sum(col):bigint> +-- !query output +NULL + + +-- !query +SELECT try_sum(col) FROM VALUES (98765432109876543210987654321098765432BD), 
(98765432109876543210987654321098765432BD) AS tab(col) +-- !query schema +struct<try_sum(col):decimal(38,0)> +-- !query output +NULL + + +-- !query +SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_sum(col):interval month> +-- !query output +0-2 + + +-- !query +SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 months') AS tab(col) +-- !query schema +struct<try_sum(col):interval month> +-- !query output +NULL + + +-- !query +SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') AS tab(col) +-- !query schema +struct<try_sum(col):interval second> +-- !query output +0 00:00:02.000000000 + + +-- !query +SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 DAYS') AS tab(col) +-- !query schema +struct<try_sum(col):interval day> +-- !query output +NULL --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org