This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 4be3390 [SPARK-31519][SQL][2.4] Datetime functions in having
aggregate expressions returns the wrong result
4be3390 is described below
commit 4be3390200617edbeec9f86bfba7b85c34f9600e
Author: Yuanjian Li <[email protected]>
AuthorDate: Wed Apr 29 09:18:40 2020 -0700
[SPARK-31519][SQL][2.4] Datetime functions in having aggregate expressions
returns the wrong result
### What changes were proposed in this pull request?
Add a new logical node AggregateWithHaving, which the parser creates for a
HAVING clause over an aggregate. The analyzer resolves it to Filter(..., Aggregate(...)).
### Why are the changes needed?
The SQL parser in Spark creates Filter(..., Aggregate(...)) for the HAVING
query, and Spark has a special analyzer rule ResolveAggregateFunctions to
resolve the aggregate functions and grouping columns in the Filter operator.
It works for simple cases in a very tricky way as it relies on rule
execution order:
1. Rule ResolveReferences hits the Aggregate operator and resolves
attributes inside aggregate functions, but the function itself is still
unresolved as it's an UnresolvedFunction. This prevents the Filter operator from
being resolved, as its child Aggregate operator is still unresolved.
2. Rule ResolveFunctions resolves the UnresolvedFunction. This makes the
Aggregate operator resolved.
3. Rule ResolveAggregateFunctions resolves the Filter operator if its child
is a resolved Aggregate. This rule can correctly resolve the grouping columns.
In the example query below, the datetime function `hour` needs to be
resolved by rule ResolveTimeZone, which runs after ResolveAggregateFunctions.
This breaks step 3, as the Aggregate operator is still unresolved at that time.
The analyzer then starts the next round, and the Filter operator is resolved by
ResolveReferences, which wrongly resolves the grouping columns.
See the demo below:
```
SELECT SUM(a) AS b, '2020-01-01 12:12:12' AS fake FROM VALUES (1, 10), (2,
20) AS T(a, b) GROUP BY b HAVING b > 10
```
The query's result is
```
+---+-------------------+
| b| fake|
+---+-------------------+
| 2|2020-01-01 12:12:12|
+---+-------------------+
```
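But if we use the `hour` function, the query returns an empty result: `b` in
`HAVING b > 10` is wrongly bound to the output alias `b` (i.e. `SUM(a)`) instead
of the grouping column `T.b`, and no group has `SUM(a) > 10`.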
```
SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1,
10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
```
### Does this PR introduce any user-facing change?
Yes, bug fix for HAVING queries whose aggregate expressions contain
time-zone-aware expressions, such as datetime functions or casts.
### How was this patch tested?
New UT added.
Closes #28397 from xuanyuanking/SPARK-31519-backport.
Authored-by: Yuanjian Li <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 129 ++++++++++++---------
.../spark/sql/catalyst/analysis/unresolved.scala | 13 ++-
.../apache/spark/sql/catalyst/dsl/package.scala | 8 ++
.../spark/sql/catalyst/parser/AstBuilder.scala | 7 +-
.../sql/catalyst/parser/PlanParserSuite.scala | 5 +-
.../src/test/resources/sql-tests/inputs/having.sql | 3 +
.../resources/sql-tests/results/having.sql.out | 10 +-
.../spark/sql/DataFrameWindowFunctionsSuite.scala | 44 ++++---
8 files changed, 137 insertions(+), 82 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 61f77be2..f10276d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -963,6 +963,9 @@ class Analyzer(
// rule: ResolveDeserializer.
case plan if containsDeserializer(plan.expressions) => plan
+ // Skip the having clause here, this will be handled in
ResolveAggregateFunctions.
+ case h: AggregateWithHaving => h
+
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q.mapExpressions(resolve(_, q))
@@ -1536,62 +1539,14 @@ class Analyzer(
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
- case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs,
child)) if agg.resolved =>
+ // Resolve aggregate with having clause to Filter(..., Aggregate()).
Note, to avoid wrongly
+ // resolve the having condition expression, here we skip resolving it in
ResolveReferences
+ // and transform it to Filter after aggregate is resolved. See more
details in SPARK-31519.
+ case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+ resolveHaving(Filter(cond, agg), agg)
- // Try resolving the condition of the filter as though it is in the
aggregate clause
- try {
- val aggregatedCondition =
- Aggregate(
- grouping,
- Alias(cond, "havingCondition")() :: Nil,
- child)
- val resolvedOperator = executeSameContext(aggregatedCondition)
- def resolvedAggregateFilter =
- resolvedOperator
- .asInstanceOf[Aggregate]
- .aggregateExpressions.head
-
- // If resolution was successful and we see the filter has an
aggregate in it, add it to
- // the original aggregate operator.
- if (resolvedOperator.resolved) {
- // Try to replace all aggregate expressions in the filter by an
alias.
- val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
- val transformedAggregateFilter = resolvedAggregateFilter.transform
{
- case ae: AggregateExpression =>
- val alias = Alias(ae, ae.toString)()
- aggregateExpressions += alias
- alias.toAttribute
- // Grouping functions are handled in the rule
[[ResolveGroupingAnalytics]].
- case e: Expression if grouping.exists(_.semanticEquals(e)) &&
- !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
- !agg.output.exists(_.semanticEquals(e)) =>
- e match {
- case ne: NamedExpression =>
- aggregateExpressions += ne
- ne.toAttribute
- case _ =>
- val alias = Alias(e, e.toString)()
- aggregateExpressions += alias
- alias.toAttribute
- }
- }
-
- // Push the aggregate expressions into the aggregate (if any).
- if (aggregateExpressions.nonEmpty) {
- Project(agg.output,
- Filter(transformedAggregateFilter,
- agg.copy(aggregateExpressions = originalAggExprs ++
aggregateExpressions)))
- } else {
- f
- }
- } else {
- f
- }
- } catch {
- // Attempting to resolve in the aggregate can result in ambiguity.
When this happens,
- // just return the original plan.
- case ae: AnalysisException => f
- }
+ case f @ Filter(_, agg: Aggregate) if agg.resolved =>
+ resolveHaving(f, agg)
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if
aggregate.resolved =>
@@ -1662,6 +1617,63 @@ class Analyzer(
def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
+
+ def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+ // Try resolving the condition of the filter as though it is in the
aggregate clause
+ try {
+ val aggregatedCondition =
+ Aggregate(
+ agg.groupingExpressions,
+ Alias(filter.condition, "havingCondition")() :: Nil,
+ agg.child)
+ val resolvedOperator = executeSameContext(aggregatedCondition)
+ def resolvedAggregateFilter =
+ resolvedOperator
+ .asInstanceOf[Aggregate]
+ .aggregateExpressions.head
+
+ // If resolution was successful and we see the filter has an aggregate
in it, add it to
+ // the original aggregate operator.
+ if (resolvedOperator.resolved) {
+ // Try to replace all aggregate expressions in the filter by an
alias.
+ val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
+ val transformedAggregateFilter = resolvedAggregateFilter.transform {
+ case ae: AggregateExpression =>
+ val alias = Alias(ae, ae.toString)()
+ aggregateExpressions += alias
+ alias.toAttribute
+ // Grouping functions are handled in the rule
[[ResolveGroupingAnalytics]].
+ case e: Expression if
agg.groupingExpressions.exists(_.semanticEquals(e)) &&
+ !ResolveGroupingAnalytics.hasGroupingFunction(e) &&
+ !agg.output.exists(_.semanticEquals(e)) =>
+ e match {
+ case ne: NamedExpression =>
+ aggregateExpressions += ne
+ ne.toAttribute
+ case _ =>
+ val alias = Alias(e, e.toString)()
+ aggregateExpressions += alias
+ alias.toAttribute
+ }
+ }
+
+ // Push the aggregate expressions into the aggregate (if any).
+ if (aggregateExpressions.nonEmpty) {
+ Project(agg.output,
+ Filter(transformedAggregateFilter,
+ agg.copy(aggregateExpressions = agg.aggregateExpressions ++
aggregateExpressions)))
+ } else {
+ filter
+ }
+ } else {
+ filter
+ }
+ } catch {
+ // Attempting to resolve in the aggregate can result in ambiguity.
When this happens,
+ // just return the original plan.
+ case ae: AnalysisException => filter
+ }
+ }
}
/**
@@ -2050,11 +2062,14 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
case Filter(condition, _) if hasWindowFunction(condition) =>
- failAnalysis("It is not allowed to use window functions inside WHERE
and HAVING clauses")
+ failAnalysis("It is not allowed to use window functions inside WHERE
clause")
+
+ case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+ failAnalysis("It is not allowed to use window functions inside HAVING
clause")
// Aggregate with Having clause. This rule works with an unresolved
Aggregate because
// a resolved Aggregate will not have Window Functions.
- case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs,
child))
+ case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs,
aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 36cad3c..bcd2ff7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
import org.apache.spark.sql.catalyst.parser.ParserUtils
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan,
UnaryNode}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode,
LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.types.{DataType, Metadata, StructType}
@@ -513,3 +513,14 @@ case class UnresolvedOrdinal(ordinal: Int)
override def nullable: Boolean = throw new UnresolvedException(this,
"nullable")
override lazy val resolved = false
}
+
+/**
+ * Represents unresolved aggregate with having clause, it is turned by the
analyzer into a Filter.
+ */
+case class AggregateWithHaving(
+ havingCondition: Expression,
+ child: Aggregate)
+ extends UnaryNode {
+ override lazy val resolved: Boolean = false
+ override def output: Seq[Attribute] = child.output
+}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index d3ccd18..f7b1638 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -356,6 +356,14 @@ package object dsl {
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}
+ def having(
+ groupingExprs: Expression*)(
+ aggregateExprs: Expression*)(
+ havingCondition: Expression): LogicalPlan = {
+ AggregateWithHaving(havingCondition,
+ groupBy(groupingExprs: _*)(aggregateExprs:
_*).asInstanceOf[Aggregate])
+ }
+
def window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 80a4d18..22d5f1d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -401,7 +401,12 @@ class AstBuilder(conf: SQLConf) extends
SqlBaseBaseVisitor[AnyRef] with Logging
case p: Predicate => p
case e => Cast(e, BooleanType)
}
- Filter(predicate, plan)
+ plan match {
+ case aggregate: Aggregate =>
+ AggregateWithHaving(predicate, aggregate)
+ case _ =>
+ Filter(predicate, plan)
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index f5da90f..da69c12 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -108,7 +108,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x
< 1).select('a, 'b))
assertEqual(
"select a, b from db.c having x < 1",
- table("db", "c").groupBy()('a, 'b).where('x < 1))
+ table("db", "c").having()('a, 'b)('x < 1))
assertEqual("select distinct a, b from db.c", Distinct(table("db",
"c").select('a, 'b)))
assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
assertEqual("select from tbl", OneRowRelation().select('from.as("tbl")))
@@ -473,8 +473,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select g from t group by g having a > (select b from s)",
table("t")
- .groupBy('g)('g)
- .where('a > ScalarSubquery(table("s").select('b))))
+ .having('g)('g)('a > ScalarSubquery(table("s").select('b))))
}
test("table reference") {
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql
b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 868a911..179686e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t
HAVING(COUNT(1) > 0);
-- SPARK-20329: make sure we handle timezones correctly
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a +
b > 1;
+
+-- SPARK-31519: Datetime functions in having aggregate expressions returns the
wrong result
+SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10),
(2, 20) AS T(a, b) GROUP BY b HAVING b > 10
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out
b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index d87ee52..fd594f5 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 5
+-- Number of queries: 6
-- !query 0
@@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
-- !query 4 output
3
7
+
+
+-- !query 5
+SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10),
(2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+-- !query 5 schema
+struct<b:bigint,fake:int>
+-- !query 5 output
+2 12
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 97a8439..f8db6fa 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -623,39 +623,45 @@ class DataFrameWindowFunctionsSuite extends QueryTest
with SharedSQLContext {
}
test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
- def checkAnalysisError(df: => DataFrame): Unit = {
- val thrownException = the [AnalysisException] thrownBy {
+ def checkAnalysisError(df: => DataFrame, clause: String): Unit = {
+ val thrownException = the[AnalysisException] thrownBy {
df.queryExecution.analyzed
}
- assert(thrownException.message.contains("window functions inside WHERE
and HAVING clauses"))
+ assert(thrownException.message.contains(s"window functions inside
$clause clause"))
}
-
checkAnalysisError(testData2.select('a).where(rank().over(Window.orderBy('b))
=== 1))
- checkAnalysisError(testData2.where('b === 2 &&
rank().over(Window.orderBy('b)) === 1))
checkAnalysisError(
- testData2.groupBy('a)
- .agg(avg('b).as("avgb"))
- .where('a > 'avgb && rank().over(Window.orderBy('a)) === 1))
+ testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1),
"WHERE")
checkAnalysisError(
- testData2.groupBy('a)
- .agg(max('b).as("maxb"), sum('b).as("sumb"))
- .where(rank().over(Window.orderBy('a)) === 1))
+ testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1),
"WHERE")
checkAnalysisError(
- testData2.groupBy('a)
- .agg(max('b).as("maxb"), sum('b).as("sumb"))
- .where('sumb === 5 && rank().over(Window.orderBy('a)) === 1))
+ testData2.groupBy($"a")
+ .agg(avg($"b").as("avgb"))
+ .where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1),
"WHERE")
+ checkAnalysisError(
+ testData2.groupBy($"a")
+ .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
+ .where(rank().over(Window.orderBy($"a")) === 1), "WHERE")
+ checkAnalysisError(
+ testData2.groupBy($"a")
+ .agg(max($"b").as("maxb"), sum($"b").as("sumb"))
+ .where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1),
"WHERE")
- checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY
b) = 1"))
- checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK()
OVER(ORDER BY b) = 1"))
+ checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY
b) = 1"), "WHERE")
+ checkAnalysisError(
+ sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) =
1"), "WHERE")
checkAnalysisError(
- sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK()
OVER(ORDER BY a) = 1"))
+ sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK()
OVER(ORDER BY a) = 1"),
+ "HAVING")
checkAnalysisError(
- sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK()
OVER(ORDER BY a) = 1"))
+ sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK()
OVER(ORDER BY a) = 1"),
+ "HAVING")
checkAnalysisError(
sql(
s"""SELECT a, MAX(b)
|FROM testData2
|GROUP BY a
- |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
+ |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin),
+ "HAVING")
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]