This is an automated email from the ASF dual-hosted git repository. yamamuro pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b79cf0d [SPARK-28224][SQL] Check overflow in decimal Sum aggregate b79cf0d is described below commit b79cf0d14351c741efe4f27523919a0e24b8b2ed Author: Mick Jermsurawong <mickjermsuraw...@stripe.com> AuthorDate: Tue Aug 20 09:47:04 2019 +0900 [SPARK-28224][SQL] Check overflow in decimal Sum aggregate ## What changes were proposed in this pull request? - Currently `sum` in aggregates for decimal type can overflow and return null. - The `Sum` expression codegens arithmetic on `sql.Decimal`, and the output, which preserves scale and precision, goes into `UnsafeRowWriter`. Here an overflowed value will be converted to null when writing out. - It also does not go through this branch in `DecimalAggregates`, because that branch expects the precision of the sum (not of the elements to be summed) to be less than 5. https://github.com/apache/spark/blob/4ebff5b6d68f26cc1ff9265a5489e0d7c2e05449/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L1400-L1403 - This PR adds the check at the final result of the sum operator itself. https://github.com/apache/spark/blob/4ebff5b6d68f26cc1ff9265a5489e0d7c2e05449/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala#L372-L376 https://issues.apache.org/jira/browse/SPARK-28224 ## How was this patch tested? - Added an integration test on the dataframe suite cc mgaido91 JoshRosen Closes #25033 from mickjermsurawong-stripe/SPARK-28224. 
Authored-by: Mick Jermsurawong <mickjermsuraw...@stripe.com> Signed-off-by: Takeshi Yamamuro <yamam...@apache.org> --- .../sql/catalyst/expressions/aggregate/Sum.scala | 7 ++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 23 +++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index ef204ec..d04fe92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @ExpressionDescription( @@ -89,5 +90,9 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) } - override lazy val evaluateExpression: Expression = sum + override lazy val evaluateExpression: Expression = resultType match { + case d: DecimalType => CheckOverflow(sum, d, SQLConf.get.decimalOperationsNullOnOverflow) + case _ => sum + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ba8fced..c6daff1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExc import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, 
SharedSparkSession} -import org.apache.spark.sql.test.SQLTestData.{NullStrings, TestData2} +import org.apache.spark.sql.test.SQLTestData.{DecimalData, NullStrings, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -156,6 +156,27 @@ class DataFrameSuite extends QueryTest with SharedSparkSession { structDf.select(xxhash64($"a", $"record.*"))) } + test("SPARK-28224: Aggregate sum big decimal overflow") { + val largeDecimals = spark.sparkContext.parallelize( + DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: + DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF() + + Seq(true, false).foreach { nullOnOverflow => + withSQLConf((SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key, nullOnOverflow.toString)) { + val structDf = largeDecimals.select("a").agg(sum("a")) + if (nullOnOverflow) { + checkAnswer(structDf, Row(null)) + } else { + val e = intercept[SparkException] { + structDf.collect + } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + } + } + } + } + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { val df = Seq(("1,2"), ("4"), ("7,8,9")).toDF("csv") val e = intercept[AnalysisException] { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org