This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e6839ad [SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2
partial aggregate push-down `AVG`
e6839ad is described below
commit e6839ad7340bc9eb5df03df2a62110bdda805e6b
Author: Jiaan Geng <[email protected]>
AuthorDate: Thu Mar 31 19:18:58 2022 +0800
[SPARK-37839][SQL][FOLLOWUP] Check overflow when DS V2 partial aggregate
push-down `AVG`
### What changes were proposed in this pull request?
https://github.com/apache/spark/pull/35130 supports partial aggregate
push-down `AVG` for DS V2.
The behavior isn't consistent with `Average` if overflow occurs in ANSI
mode.
This PR closely follows the implementation of `Average` to respect overflow in
ANSI mode.
### Why are the changes needed?
Make the behavior consistent with `Average` if overflow occurs in ANSI mode.
### Does this PR introduce _any_ user-facing change?
'Yes'.
Users could see the exception about overflow thrown in ANSI mode.
### How was this patch tested?
New tests.
Closes #35320 from beliefer/SPARK-37839_followup.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/expressions/aggregate/Average.scala | 4 +-
.../datasources/v2/V2ScanRelationPushDown.scala | 21 ++++------
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 45 +++++++++++++++++++++-
3 files changed, 52 insertions(+), 18 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 05f7eda..533f7f2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -76,8 +76,8 @@ case class Average(
case _ => DoubleType
}
- private lazy val sum = AttributeReference("sum", sumDataType)()
- private lazy val count = AttributeReference("count", LongType)()
+ lazy val sum = AttributeReference("sum", sumDataType)()
+ lazy val count = AttributeReference("count", LongType)()
override lazy val aggBufferAttributes = sum :: count :: Nil
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index c8ef8b0..5371829 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And,
Attribute, AttributeReference, Cast, Divide, DivideDTInterval,
DivideYMInterval, EqualTo, Expression, If, IntegerLiteral, Literal,
NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder,
SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And,
Attribute, AttributeReference, Cast, Expression, IntegerLiteral,
NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder,
SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
@@ -32,7 +32,7 @@ import
org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder,
SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, DayTimeIntervalType, LongType,
StructType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
import org.apache.spark.sql.util.SchemaUtils._
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper
with AliasHelper {
@@ -138,18 +138,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan]
with PredicateHelper wit
case AggregateExpression(avg: aggregate.Average, _,
isDistinct, _, _) =>
val sum =
aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
val count =
aggregate.Count(avg.child).toAggregateExpression(isDistinct)
- // Closely follow `Average.evaluateExpression`
- avg.dataType match {
- case _: YearMonthIntervalType =>
- If(EqualTo(count, Literal(0L)),
- Literal(null, YearMonthIntervalType()),
DivideYMInterval(sum, count))
- case _: DayTimeIntervalType =>
- If(EqualTo(count, Literal(0L)),
- Literal(null, DayTimeIntervalType()),
DivideDTInterval(sum, count))
- case _ =>
- // TODO deal with the overflow issue
- Divide(addCastIfNeeded(sum, avg.dataType),
- addCastIfNeeded(count, avg.dataType), false)
+ avg.evaluateExpression transform {
+ case a: Attribute if a.semanticEquals(avg.sum) =>
+ addCastIfNeeded(sum, avg.sum.dataType)
+ case a: Attribute if a.semanticEquals(avg.count) =>
+ addCastIfNeeded(count, avg.count.dataType)
}
}
}.asInstanceOf[Seq[NamedExpression]]
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index a5e3a71..67a0290 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -95,6 +95,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession
with ExplainSuiteHel
"""CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2"
INTEGER)""").executeUpdate()
conn.prepareStatement(
"""CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3"
INTEGER)""").executeUpdate()
+
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price
NUMERIC(23, 3))")
+ .executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+ "(1, 'bottle', 11111111111111111111.123)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+ "(1, 'bottle', 99999999999999999999.123)").executeUpdate()
}
}
@@ -484,8 +492,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession
with ExplainSuiteHel
test("show tables") {
checkAnswer(sql("SHOW TABLES IN h2.test"),
Seq(Row("test", "people", false), Row("test", "empty_table", false),
- Row("test", "employee", false), Row("test", "dept", false),
Row("test", "person", false),
- Row("test", "view1", false), Row("test", "view2", false)))
+ Row("test", "employee", false), Row("test", "item", false),
Row("test", "dept", false),
+ Row("test", "person", false), Row("test", "view1", false), Row("test",
"view2", false)))
}
test("SQL API: create table as select") {
@@ -1106,4 +1114,37 @@ class JDBCV2Suite extends QueryTest with
SharedSparkSession with ExplainSuiteHel
checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
}
+
+ test("scan with aggregate push-down: partial push-down AVG with overflow") {
+ def createDataFrame: DataFrame = spark.read
+ .option("partitionColumn", "id")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.item")
+ .agg(avg($"PRICE").as("avg"))
+
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val df = createDataFrame
+ checkAggregateRemoved(df, false)
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ if (ansiEnabled) {
+ val e = intercept[SparkException] {
+ df.collect()
+ }
+ assert(e.getCause.isInstanceOf[ArithmeticException])
+ assert(e.getCause.getMessage.contains("cannot be represented as
Decimal") ||
+ e.getCause.getMessage.contains("Overflow in sum of decimals"))
+ } else {
+ checkAnswer(df, Seq(Row(null)))
+ }
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]