This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c34fee4  [SPARK-38548][SQL] New SQL function: try_sum
c34fee4 is described below

commit c34fee4d20da9ab5b4f1f26185fc1a9a83b99d05
Author: Gengliang Wang <gengli...@apache.org>
AuthorDate: Mon Mar 21 14:26:46 2022 +0800

    [SPARK-38548][SQL] New SQL function: try_sum
    
    ### What changes were proposed in this pull request?
    
    Add a new SQL function: try_sum. It is identical to the function `sum`, 
except that it returns `NULL` result instead of throwing an exception on 
integral/decimal value overflow.
    Note it is different from sum when ANSI mode is off:
    | Function         | Sum                                | TrySum      |
    |------------------|------------------------------------|-------------|
    | Decimal overflow | Return NULL                        | Return NULL |
    | Integer overflow | Return lower 64 bits of the result | Return NULL |
    
    ### Why are the changes needed?
    
    * Users can run queries to completion without interruption in ANSI mode.
    * Users can get NULLs instead of unreasonable results if overflow occurs 
when ANSI mode is off. For example
    ```
    > SELECT sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col);
    -9223372036854775808
    
    > SELECT try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col);
    NULL
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, a new SQL function: try_sum which is identical to the function `sum`, 
except that it returns `NULL` result instead of throwing an exception on 
integral/decimal value overflow.
    
    ### How was this patch tested?
    
    UT
    
    Closes #35848 from gengliangwang/trySum2.
    
    Authored-by: Gengliang Wang <gengli...@apache.org>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 docs/sql-ref-ansi-compliance.md                    |   1 +
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   1 +
 .../sql/catalyst/expressions/aggregate/Sum.scala   | 129 ++++++++++++++++-----
 .../sql-functions/sql-expression-schema.md         |   3 +-
 .../sql-tests/inputs/ansi/try_aggregates.sql       |   1 +
 .../resources/sql-tests/inputs/try_aggregates.sql  |  13 +++
 .../sql-tests/results/ansi/try_aggregates.sql.out  |  82 +++++++++++++
 .../sql-tests/results/try_aggregates.sql.out       |  82 +++++++++++++
 8 files changed, 282 insertions(+), 30 deletions(-)

diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index ccfc601..0f7dd5d 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -316,6 +316,7 @@ When ANSI mode is on, it throws exceptions for invalid 
operations. You can use t
  - `try_subtract`: identical to the subtraction operator `-`, except that it returns 
`NULL` result instead of throwing an exception on integral value overflow.
  - `try_multiply`: identical to the multiplication operator `*`, except that it returns 
`NULL` result instead of throwing an exception on integral value overflow.
   - `try_divide`: identical to the division operator `/`, except that it 
returns `NULL` result instead of throwing an exception on dividing 0.
+  - `try_sum`: identical to the function `sum`, except that it returns `NULL` 
result instead of throwing an exception on integral/decimal value overflow.
   - `try_element_at`: identical to the function `element_at`, except that it 
returns `NULL` result instead of throwing an exception on array's index out of 
bound or map's key not found.
 
 ### SQL Keywords (optional, disabled by default)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e5954c8..a37d4b2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -452,6 +452,7 @@ object FunctionRegistry {
     expression[TrySubtract]("try_subtract"),
     expression[TryMultiply]("try_multiply"),
     expression[TryElementAt]("try_element_at"),
+    expression[TrySum]("try_sum"),
 
     // aggregate functions
     expression[HyperLogLogPlusPlus]("approx_count_distinct"),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index ec7479a..5d8fd70 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -26,27 +26,11 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-@ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col);
-       30
-      > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col);
-       25
-      > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
-       NULL
-  """,
-  group = "agg_funcs",
-  since = "1.0.0")
-case class Sum(
-    child: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled)
-  extends DeclarativeAggregate
+abstract class SumBase(child: Expression) extends DeclarativeAggregate
   with ImplicitCastInputTypes
   with UnaryLike[Expression] {
 
-  def this(child: Expression) = this(child, failOnError = 
SQLConf.get.ansiEnabled)
+  def failOnError: Boolean
 
   override def nullable: Boolean = true
 
@@ -57,7 +41,7 @@ case class Sum(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, 
DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, prettyName)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
 
@@ -86,16 +70,17 @@ case class Sum(
     case _ => Seq(Literal(null, resultType))
   }
 
-  override lazy val updateExpressions: Seq[Expression] = {
+  protected def getUpdateExpressions: Seq[Expression] = {
     resultType match {
       case _: DecimalType =>
         // For decimal type, the initial value of `sum` is 0. We need to keep 
`sum` unchanged if
         // the input is null, as SUM function ignores null input. The `sum` 
can only be null if
         // overflow happens under non-ansi mode.
         val sumExpr = if (child.nullable) {
-          If(child.isNull, sum, sum + KnownNotNull(child).cast(resultType))
+          If(child.isNull, sum,
+            Add(sum, KnownNotNull(child).cast(resultType), failOnError = 
failOnError))
         } else {
-          sum + child.cast(resultType)
+          Add(sum, child.cast(resultType), failOnError = failOnError)
         }
         // The buffer becomes non-empty after seeing the first not-null input.
         val isEmptyExpr = if (child.nullable) {
@@ -110,9 +95,10 @@ case class Sum(
         // in case the input is nullable. The `sum` can only be null if there 
is no value, as
         // non-decimal type can produce overflowed value under non-ansi mode.
         if (child.nullable) {
-          Seq(coalesce(coalesce(sum, zero) + child.cast(resultType), sum))
+          Seq(coalesce(Add(coalesce(sum, zero), child.cast(resultType), 
failOnError = failOnError),
+            sum))
         } else {
-          Seq(coalesce(sum, zero) + child.cast(resultType))
+          Seq(Add(coalesce(sum, zero), child.cast(resultType), failOnError = 
failOnError))
         }
     }
   }
@@ -129,7 +115,7 @@ case class Sum(
   * isEmpty:  Set to false if either the left or the right side is set to 
false. This
   * means we have seen at least one value that was not null.
    */
-  override lazy val mergeExpressions: Seq[Expression] = {
+  protected def getMergeExpressions: Seq[Expression] = {
     resultType match {
       case _: DecimalType =>
         val bufferOverflow = !isEmpty.left && sum.left.isNull
@@ -143,7 +129,9 @@ case class Sum(
             // overflow happens.
             KnownNotNull(sum.left) + KnownNotNull(sum.right)),
           isEmpty.left && isEmpty.right)
-      case _ => Seq(coalesce(coalesce(sum.left, zero) + sum.right, sum.left))
+      case _ => Seq(coalesce(
+        Add(coalesce(sum.left, zero), sum.right, failOnError = failOnError),
+        sum.left))
     }
   }
 
@@ -154,15 +142,98 @@ case class Sum(
    * So now, if ansi is enabled, then throw exception, if not then return null.
    * If sum is not null, then return the sum.
    */
-  override lazy val evaluateExpression: Expression = resultType match {
+  protected def getEvaluateExpression: Expression = resultType match {
     case d: DecimalType =>
       If(isEmpty, Literal.create(null, resultType),
         CheckOverflowInSum(sum, d, !failOnError))
     case _ => sum
   }
 
-  override protected def withNewChildInternal(newChild: Expression): Sum = 
copy(child = newChild)
-
   // The flag `failOnError` won't be shown in the `toString` or `toAggString` 
methods
   override def flatArguments: Iterator[Any] = Iterator(child)
 }
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col);
+       30
+      > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col);
+       25
+      > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
+       NULL
+  """,
+  group = "agg_funcs",
+  since = "1.0.0")
+case class Sum(
+    child: Expression,
+    failOnError: Boolean = SQLConf.get.ansiEnabled)
+  extends SumBase(child) {
+  def this(child: Expression) = this(child, failOnError = 
SQLConf.get.ansiEnabled)
+
+  override protected def withNewChildInternal(newChild: Expression): Sum = 
copy(child = newChild)
+
+  override lazy val updateExpressions: Seq[Expression] = getUpdateExpressions
+
+  override lazy val mergeExpressions: Seq[Expression] = getMergeExpressions
+
+  override lazy val evaluateExpression: Expression = getEvaluateExpression
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the sum calculated from values of a group 
and the result is null on overflow.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(col) FROM VALUES (5), (10), (15) AS tab(col);
+       30
+      > SELECT _FUNC_(col) FROM VALUES (NULL), (10), (15) AS tab(col);
+       25
+      > SELECT _FUNC_(col) FROM VALUES (NULL), (NULL) AS tab(col);
+       NULL
+      > SELECT _FUNC_(col) FROM VALUES (9223372036854775807L), (1L) AS 
tab(col);
+       NULL
+  """,
+  since = "3.3.0",
+  group = "agg_funcs")
+// scalastyle:on line.size.limit
+case class TrySum(child: Expression) extends SumBase(child) {
+
+  override def failOnError: Boolean = dataType match {
+    // Double type won't fail, thus the failOnError is always false
+    // For decimal type, Sum already returns NULL on overflow when `failOnError` is
+    // false, which matches the behavior TrySum requires.
+    case _: DoubleType | _: DecimalType => false
+    case _ => true
+  }
+
+  override lazy val updateExpressions: Seq[Expression] =
+    if (failOnError) {
+      val expressions = getUpdateExpressions
+      // If the length of updateExpressions is larger than 1, the tail 
expressions are for
+      // tracking whether the input is empty, which doesn't need `TryEval` 
execution.
+      Seq(TryEval(expressions.head)) ++ expressions.tail
+    } else {
+      getUpdateExpressions
+    }
+
+  override lazy val mergeExpressions: Seq[Expression] =
+    if (failOnError) {
+      getMergeExpressions.map(TryEval)
+    } else {
+      getMergeExpressions
+    }
+
+  override lazy val evaluateExpression: Expression =
+    if (failOnError) {
+      TryEval(getEvaluateExpression)
+    } else {
+      getEvaluateExpression
+    }
+
+  override protected def withNewChildInternal(newChild: Expression): 
Expression =
+    copy(child = newChild)
+
+  override def prettyName: String = "try_sum"
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md 
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 386dd1f..1afba46 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,6 +1,6 @@
 <!-- Automatically generated by ExpressionsSchemaSuite -->
 ## Summary
-  - Number of queries: 382
+  - Number of queries: 383
   - Number of expressions that missing example: 12
   - Expressions missing examples: 
bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint
 ## Schema of Built-in Functions
@@ -376,6 +376,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev | 
SELECT stddev(col) FROM VALUES (1), (2), (3) AS tab(col) | 
struct<stddev(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.StddevSamp | stddev_samp 
| SELECT stddev_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | 
struct<stddev_samp(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Sum | sum | SELECT 
sum(col) FROM VALUES (5), (10), (15) AS tab(col) | struct<sum(col):bigint> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.TrySum | try_sum | 
SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col) | 
struct<try_sum(col):bigint> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | 
SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | 
struct<var_pop(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp 
| SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | 
struct<var_samp(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance 
| SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | 
struct<variance(col):double> |
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/ansi/try_aggregates.sql 
b/sql/core/src/test/resources/sql-tests/inputs/ansi/try_aggregates.sql
new file mode 100644
index 0000000..f5b44d2
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/ansi/try_aggregates.sql
@@ -0,0 +1 @@
+--IMPORT try_aggregates.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql 
b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
new file mode 100644
index 0000000..ffa8eef
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/try_aggregates.sql
@@ -0,0 +1,13 @@
+-- try_sum
+SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col);
+SELECT try_sum(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col);
+SELECT try_sum(col) FROM VALUES (NULL), (10), (15) AS tab(col);
+SELECT try_sum(col) FROM VALUES (NULL), (NULL) AS tab(col);
+SELECT try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col);
+-- test overflow in Decimal(38, 0)
+SELECT try_sum(col) FROM VALUES (98765432109876543210987654321098765432BD), 
(98765432109876543210987654321098765432BD) AS tab(col);
+
+SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') 
AS tab(col);
+SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 
months') AS tab(col);
+SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') 
AS tab(col);
+SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 
DAYS') AS tab(col);
diff --git 
a/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
new file mode 100644
index 0000000..7ae217a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_aggregates.sql.out
@@ -0,0 +1,82 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col)
+-- !query schema
+struct<try_sum(col):bigint>
+-- !query output
+30
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col)
+-- !query schema
+struct<try_sum(col):decimal(13,1)>
+-- !query output
+30.0
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (NULL), (10), (15) AS tab(col)
+-- !query schema
+struct<try_sum(col):bigint>
+-- !query output
+25
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (NULL), (NULL) AS tab(col)
+-- !query schema
+struct<try_sum(col):double>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col)
+-- !query schema
+struct<try_sum(col):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (98765432109876543210987654321098765432BD), 
(98765432109876543210987654321098765432BD) AS tab(col)
+-- !query schema
+struct<try_sum(col):decimal(38,0)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') 
AS tab(col)
+-- !query schema
+struct<try_sum(col):interval month>
+-- !query output
+0-2
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 
months') AS tab(col)
+-- !query schema
+struct<try_sum(col):interval month>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') 
AS tab(col)
+-- !query schema
+struct<try_sum(col):interval second>
+-- !query output
+0 00:00:02.000000000
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 
DAYS') AS tab(col)
+-- !query schema
+struct<try_sum(col):interval day>
+-- !query output
+NULL
diff --git 
a/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out 
b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
new file mode 100644
index 0000000..7ae217a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/try_aggregates.sql.out
@@ -0,0 +1,82 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 10
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (5), (10), (15) AS tab(col)
+-- !query schema
+struct<try_sum(col):bigint>
+-- !query output
+30
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (5.0), (10.0), (15.0) AS tab(col)
+-- !query schema
+struct<try_sum(col):decimal(13,1)>
+-- !query output
+30.0
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (NULL), (10), (15) AS tab(col)
+-- !query schema
+struct<try_sum(col):bigint>
+-- !query output
+25
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (NULL), (NULL) AS tab(col)
+-- !query schema
+struct<try_sum(col):double>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (9223372036854775807L), (1L) AS tab(col)
+-- !query schema
+struct<try_sum(col):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (98765432109876543210987654321098765432BD), 
(98765432109876543210987654321098765432BD) AS tab(col)
+-- !query schema
+struct<try_sum(col):decimal(38,0)>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '1 months'), (interval '1 months') 
AS tab(col)
+-- !query schema
+struct<try_sum(col):interval month>
+-- !query output
+0-2
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '2147483647 months'), (interval '1 
months') AS tab(col)
+-- !query schema
+struct<try_sum(col):interval month>
+-- !query output
+NULL
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '1 seconds'), (interval '1 seconds') 
AS tab(col)
+-- !query schema
+struct<try_sum(col):interval second>
+-- !query output
+0 00:00:02.000000000
+
+
+-- !query
+SELECT try_sum(col) FROM VALUES (interval '106751991 DAYS'), (interval '1 
DAYS') AS tab(col)
+-- !query schema
+struct<try_sum(col):interval day>
+-- !query output
+NULL

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to