cloud-fan commented on code in PR #45392:
URL: https://github.com/apache/spark/pull/45392#discussion_r1514386052
##########
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala:
##########
@@ -2200,6 +2200,115 @@ class DataFrameAggregateSuite extends QueryTest
checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil)
}
}
+
+ private def assertDecimalSumOverflow(
+ df: DataFrame, ansiEnabled: Boolean, expectedAnswer: Row): Unit = {
+ if (!ansiEnabled) {
+ checkAnswer(df, expectedAnswer)
+ } else {
+ val e = intercept[ArithmeticException] {
+ df.collect()
+ }
+ assert(e.getMessage.contains("cannot be represented as Decimal") ||
+ e.getMessage.contains("Overflow in sum of decimals"))
+ }
+ }
+
+ def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
+ Seq("true", "false").foreach { wholeStageEnabled =>
+ withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled))
{
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val df0 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+ val df1 = Seq(
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+ val df = df0.union(df1)
+ val df2 = df.withColumnRenamed("decNum", "decNum2").
+ join(df, "intNum").agg(aggFn($"decNum"))
+
+ val expectedAnswer = Row(null)
+ assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)
+
+ val decStr = "1" + "0" * 19
+ val d1 = spark.range(0, 12, 1, 1)
+ val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as
d")).agg(aggFn($"d"))
+ assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)
+
+ val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+ val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as
d")).agg(aggFn($"d"))
+ assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)
+
+ val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as
d"),
+
lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
+ assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)
+
+ val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as
decimal(38,18)) as d"))
+
+ val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"),
BigDecimal("9"* 20 + ".123")).
+ toDF("d")
+ assertDecimalSumOverflow(
+ nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled,
expectedAnswer)
+
+ val df3 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("50000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df4 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df5 = Seq(
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("10000000000000000000"), 1),
+ (BigDecimal("20000000000000000000"), 2)).toDF("decNum", "intNum")
+
+ val df6 = df3.union(df4).union(df5)
+ val df7 = df6.groupBy("intNum").agg(sum("decNum"),
countDistinct("decNum")).
+ filter("intNum == 1")
+ assertDecimalSumOverflow(df7, ansiEnabled, Row(1, null, 2))
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-28067: Aggregate sum should not return wrong results for decimal
overflow") {
+ checkAggResultsForDecimalOverflow(c => sum(c))
+ }
+
+ test("SPARK-35955: Aggregate avg should not return wrong results for decimal
overflow") {
+ checkAggResultsForDecimalOverflow(c => avg(c))
+ }
+
+ test("SPARK-28224: Aggregate sum big decimal overflow") {
+ val largeDecimals = spark.sparkContext.parallelize(
+ DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123"))
::
+ DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 +
".123")) :: Nil).toDF()
+
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val structDf = largeDecimals.select("a").agg(sum("a"))
+ assertDecimalSumOverflow(structDf, ansiEnabled, Row(null))
+ }
+ }
+ }
+
+ test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
+ checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1,
1))
+ }
}
Review Comment:
Just out of curiosity, what is the runtime of `DataFrameAggregateSuite`? It has
more than 2000 LOC...
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]