Repository: spark Updated Branches: refs/heads/master 75d666b95 -> 28f9f3f22
[SPARK-22271][SQL] mean overflows and returns null for some decimal variables ## What changes were proposed in this pull request? In Average.scala, it has ``` override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) Cast(Cast(sum, dt) / Cast(count, dt), resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } def setChild (newchild: Expression) = { child = newchild } ``` It is possible that `Cast(count, dt)` overflows. `dt` is derived from the input's precision and scale, and `DecimalType.bounded` caps the precision at 38. As a result, a large input scale can leave too few integer digits in `dt` to hold the count. For example, for `DecimalType(38, 33)` the type `dt` becomes `DecimalType(38, 37)`, which has only one integer digit and cannot represent a count of 10. The cast then overflows, and the mean is returned as null. Since count is an integer and doesn't need a scale, it is now cast using `DecimalType.bounded(DecimalType.MAX_PRECISION, 0)`, i.e. `DecimalType(38, 0)`. ## How was this patch tested? A test case was added to DataFrameSuite. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Huaxin Gao <huax...@us.ibm.com> Closes #19496 from huaxingao/spark-22271. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/28f9f3f2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/28f9f3f2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/28f9f3f2 Branch: refs/heads/master Commit: 28f9f3f22511e9f2f900764d9bd5b90d2eeee773 Parents: 75d666b Author: Huaxin Gao <huax...@us.ibm.com> Authored: Tue Oct 17 12:50:41 2017 -0700 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Tue Oct 17 12:50:41 2017 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/expressions/aggregate/Average.scala | 3 ++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/28f9f3f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c423e17..708bdbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), + resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } 
http://git-wip-us.apache.org/repos/asf/spark/blob/28f9f3f2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50de2fd..473c355 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2105,4 +2105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + + test("SPARK-22271: mean overflows and returns null for some decimal variables") { + val d = 0.034567890 + val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") + val result = df.select('DecimalCol cast DecimalType(38, 33)) + .select(col("DecimalCol")).describe() + val mean = result.select("DecimalCol").where($"summary" === "mean") + assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org