This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2ce1a8769846 [SPARK-44219][SQL] Adds extra per-rule validations for optimization rewrites
2ce1a8769846 is described below
commit 2ce1a8769846477cfc2c596885f8005d8fc972b5
Author: Yannis Sismanis <[email protected]>
AuthorDate: Fri Oct 6 11:20:35 2023 +0800
[SPARK-44219][SQL] Adds extra per-rule validations for optimization rewrites
### What changes were proposed in this pull request?
Adds per-rule validation checks for the following:
1. Aggregate expressions in Aggregate plans are valid.
2. Grouping key types in Aggregate plans cannot be of type Map.
3. No dangling references have been generated.
This validation is enabled by default for all tests, and can be enabled selectively with the spark.sql.planChangeValidation=true flag.
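
For illustration, a minimal sketch of turning the validation on for a regular session (only the config key comes from this PR; the SparkSession setup below is assumed):

    import org.apache.spark.sql.SparkSession

    // Hypothetical usage sketch: everything except the config key is assumed.
    val spark = SparkSession.builder()
      .master("local[*]")
      // Validate the plan after each optimizer rule (key from this PR).
      .config("spark.sql.planChangeValidation", "true")
      .getOrCreate()

    // Every optimizer rewrite of this query is now validated rule by rule;
    // a rule that produces an invalid plan fails fast with an internal error.
    spark.range(10).selectExpr("id % 2 AS k").groupBy("k").count().show()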
### Why are the changes needed?
Extra validation for optimizer rewrites.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests
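
As an example, a minimal sketch (not part of this patch; it mirrors the new OptimizerSuite cases and assumes the catalyst test DSL is on the classpath) of invoking one of the new checks directly:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlanIntegrity}
    import org.apache.spark.sql.types.IntegerType

    // Grouping by a map-typed attribute is invalid, so the new
    // validateGroupByTypes check should report an error message.
    val relation = LocalRelation('m.map(IntegerType, IntegerType))
    val badAggregate = Aggregate(relation.output, relation.output, relation)
    assert(LogicalPlanIntegrity.validateGroupByTypes(badAggregate).isDefined)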
Closes #41763 from YannisSismanis/SC-130139_followup.
Authored-by: Yannis Sismanis <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/CheckAnalysis.scala | 72 +-----
.../spark/sql/catalyst/expressions/ExprUtils.scala | 77 ++++++
.../sql/catalyst/plans/logical/LogicalPlan.scala | 68 +++++-
.../sql/catalyst/optimizer/OptimizerSuite.scala | 270 ++++++++++++++++++++-
4 files changed, 412 insertions(+), 75 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 81ca59c0976e..e140625f47ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -431,77 +431,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
                 messageParameters = Map.empty)
             }
-          case Aggregate(groupingExprs, aggregateExprs, _) =>
-            def checkValidAggregateExpression(expr: Expression): Unit = expr match {
-              case expr: AggregateExpression =>
-                val aggFunction = expr.aggregateFunction
-                aggFunction.children.foreach { child =>
-                  child.foreach {
-                    case expr: AggregateExpression =>
-                      expr.failAnalysis(
-                        errorClass = "NESTED_AGGREGATE_FUNCTION",
-                        messageParameters = Map.empty)
-                    case other => // OK
-                  }
-
-                  if (!child.deterministic) {
-                    child.failAnalysis(
-                      errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
-                      messageParameters = Map("sqlExpr" -> toSQLExpr(expr)))
-                  }
-                }
-              case _: Attribute if groupingExprs.isEmpty =>
-                operator.failAnalysis(
-                  errorClass = "MISSING_GROUP_BY",
-                  messageParameters = Map.empty)
-              case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
-                throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
-              case s: ScalarSubquery
-                  if s.children.nonEmpty && !groupingExprs.exists(_.semanticEquals(s)) =>
-                s.failAnalysis(
-                  errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION",
-                  messageParameters = Map("sqlExpr" -> toSQLExpr(s)))
-              case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
-              // There should be no Window in Aggregate - this case will fail later check anyway.
-              // Perform this check for special case of lateral column alias, when the window
-              // expression is not eligible to propagate to upper plan because it is not valid,
-              // containing non-group-by or non-aggregate-expressions.
-              case WindowExpression(function, spec) =>
-                function.children.foreach(checkValidAggregateExpression)
-                checkValidAggregateExpression(spec)
-              case e => e.children.foreach(checkValidAggregateExpression)
-            }
-
-            def checkValidGroupingExprs(expr: Expression): Unit = {
-              if (expr.exists(_.isInstanceOf[AggregateExpression])) {
-                expr.failAnalysis(
-                  errorClass = "GROUP_BY_AGGREGATE",
-                  messageParameters = Map("sqlExpr" -> expr.sql))
-              }
-
-              // Check if the data type of expr is orderable.
-              if (!RowOrdering.isOrderable(expr.dataType)) {
-                expr.failAnalysis(
-                  errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
-                  messageParameters = Map(
-                    "sqlExpr" -> toSQLExpr(expr),
-                    "dataType" -> toSQLType(expr.dataType)))
-              }
-
-              if (!expr.deterministic) {
-                // This is just a sanity check, our analysis rule PullOutNondeterministic should
-                // already pull out those nondeterministic expressions and evaluate them in
-                // a Project node.
-                throw SparkException.internalError(
-                  msg = s"Non-deterministic expression '${toSQLExpr(expr)}' should not appear in " +
-                    "grouping expression.",
-                  context = expr.origin.getQueryContext,
-                  summary = expr.origin.context.summary)
-              }
-            }
-
-            groupingExprs.foreach(checkValidGroupingExprs)
-            aggregateExprs.foreach(checkValidAggregateExpression)
+          case a: Aggregate => ExprUtils.assertValidAggregation(a)

           case CollectMetrics(name, metrics, _, _) =>
             if (name == null || name.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 5a093037e424..29c9605db51b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -20,8 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
 import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
 import java.util.Locale

+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
@@ -140,4 +144,77 @@ object ExprUtils extends QueryErrorsBase {
       TypeCheckSuccess
     }
   }
+
+  def assertValidAggregation(a: Aggregate): Unit = {
+    def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+      case expr: AggregateExpression =>
+        val aggFunction = expr.aggregateFunction
+        aggFunction.children.foreach { child =>
+          child.foreach {
+            case expr: AggregateExpression =>
+              expr.failAnalysis(
+                errorClass = "NESTED_AGGREGATE_FUNCTION",
+                messageParameters = Map.empty)
+            case other => // OK
+          }
+
+          if (!child.deterministic) {
+            child.failAnalysis(
+              errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
+              messageParameters = Map("sqlExpr" -> toSQLExpr(expr)))
+          }
+        }
+      case _: Attribute if a.groupingExpressions.isEmpty =>
+        a.failAnalysis(
+          errorClass = "MISSING_GROUP_BY",
+          messageParameters = Map.empty)
+      case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) =>
+        throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
+      case s: ScalarSubquery
+          if s.children.nonEmpty && !a.groupingExpressions.exists(_.semanticEquals(s)) =>
+        s.failAnalysis(
+          errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION",
+          messageParameters = Map("sqlExpr" -> toSQLExpr(s)))
+      case e if a.groupingExpressions.exists(_.semanticEquals(e)) => // OK
+      // There should be no Window in Aggregate - this case will fail later check anyway.
+      // Perform this check for special case of lateral column alias, when the window
+      // expression is not eligible to propagate to upper plan because it is not valid,
+      // containing non-group-by or non-aggregate-expressions.
+      case WindowExpression(function, spec) =>
+        function.children.foreach(checkValidAggregateExpression)
+        checkValidAggregateExpression(spec)
+      case e => e.children.foreach(checkValidAggregateExpression)
+    }
+
+    def checkValidGroupingExprs(expr: Expression): Unit = {
+      if (expr.exists(_.isInstanceOf[AggregateExpression])) {
+        expr.failAnalysis(
+          errorClass = "GROUP_BY_AGGREGATE",
+          messageParameters = Map("sqlExpr" -> expr.sql))
+      }
+
+      // Check if the data type of expr is orderable.
+      if (!RowOrdering.isOrderable(expr.dataType)) {
+        expr.failAnalysis(
+          errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
+          messageParameters = Map(
+            "sqlExpr" -> toSQLExpr(expr),
+            "dataType" -> toSQLType(expr.dataType)))
+      }
+
+      if (!expr.deterministic) {
+        // This is just a sanity check, our analysis rule PullOutNondeterministic should
+        // already pull out those nondeterministic expressions and evaluate them in
+        // a Project node.
+        throw SparkException.internalError(
+          msg = s"Non-deterministic expression '${toSQLExpr(expr)}' should not appear in " +
+            "grouping expression.",
+          context = expr.origin.getQueryContext,
+          summary = expr.origin.context.summary)
+      }
+    }
+
+    a.groupingExpressions.foreach(checkValidGroupingExprs)
+    a.aggregateExpressions.foreach(checkValidAggregateExpression)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index a5f4c4d5c51a..ae3029b279da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
@@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{MapType, StructType}

 abstract class LogicalPlan
@@ -325,6 +326,62 @@ object LogicalPlanIntegrity {
       LogicalPlanIntegrity.hasUniqueExprIdsForOutput(plan))
   }
+  /**
+   * This method validates there are no dangling attribute references.
+   * Returns an error message if the check does not pass, or None if it does pass.
+   */
+  def validateNoDanglingReferences(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      // DML commands and multi instance relations (like InMemoryRelation caches)
+      // have different output semantics than typical queries.
+      case _: Command => None
+      case _: MultiInstanceRelation => None
+      case n if canGetOutputAttrs(n) =>
+        if (n.missingInput.nonEmpty) {
+          Some(s"Aliases ${n.missingInput.mkString(", ")} are dangling " +
+            s"in the references for plan:\n ${n.treeString}")
+        } else {
+          None
+        }
+    }.flatten
+  }
+
+  /**
+   * Validate that the grouping key types in Aggregate plans are valid.
+   * Returns an error message if the check fails, or None if it succeeds.
+   */
+  def validateGroupByTypes(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      case a @ Aggregate(groupingExprs, _, _) =>
+        val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString)
+        if (badExprs.nonEmpty) {
+          Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " +
+            s"for plan:\n ${a.treeString}")
+        } else {
+          None
+        }
+    }.flatten
+  }
+
+  /**
+   * Validate that the aggregation expressions in Aggregate plans are valid.
+   * Returns an error message if the check fails, or None if it succeeds.
+   */
+  def validateAggregateExpressions(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      case a: Aggregate =>
+        try {
+          ExprUtils.assertValidAggregation(a)
+          None
+        } catch {
+          case e: AnalysisException =>
+            Some(s"Aggregate: ${a.toString} is not a valid aggregate expression: " +
+              s"${e.getSimpleMessage}")
+        }
+    }.flatten
+  }
+
+
   /**
    * Validate the structural integrity of an optimized plan.
    * For example, we can check after the execution of each rule that each plan:
@@ -337,7 +394,7 @@ object LogicalPlanIntegrity {
   def validateOptimizedPlan(
       previousPlan: LogicalPlan,
       currentPlan: LogicalPlan): Option[String] = {
-    if (!currentPlan.resolved) {
+    var validation = if (!currentPlan.resolved) {
       Some("The plan becomes unresolved: " + currentPlan.treeString + "\nThe previous plan: " +
         previousPlan.treeString)
     } else if (currentPlan.exists(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty)) {
@@ -353,6 +410,13 @@ object LogicalPlanIntegrity {
         }
       }
     }
+    validation = validation
+      .orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan))
+      .orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan))
+      .orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan))
+      .map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", ")}" +
+        s"\nPrevious plan: ${previousPlan.treeString}")
+    validation
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
index 6b63f860b7da..e40fff22bc1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
@@ -17,13 +17,17 @@

 package org.apache.spark.sql.catalyst.optimizer

+import org.apache.spark.SparkException
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, IntegerLiteral, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, IntegerLiteral, Literal, Multiply, NamedExpression, Remainder}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.IntegerType

 /**
  * A dummy optimizer rule for testing that decrements integer literals until 0.
@@ -71,4 +75,266 @@ class OptimizerSuite extends PlanTest {
           s"test, please set '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}' to a larger value."))
     }
   }
+
+  test("Optimizer per rule validation catches dangling references") {
+    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil,
+      OneRowRelation()).analyze
+
+    /**
+     * A dummy optimizer rule for testing that dangling references are not allowed.
+     */
+    object DanglingReference extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        Project(Alias(
+          Add(AttributeReference("debug1", IntegerType, nullable = false)(),
+            AttributeReference("debug2", IntegerType, nullable = false)()), "attr")() :: Nil,
+          plan)
+      }
+    }
+
+    val optimizer = new SimpleTestOptimizer() {
+      override def defaultBatches: Seq[Batch] =
+        Batch("test", FixedPoint(1),
+          DanglingReference) :: Nil
+    }
+    val message1 = intercept[SparkException] {
+      optimizer.execute(analyzed)
+    }.getMessage
+    assert(message1.contains("are dangling"))
+  }
+
+  test("Optimizer per rule validation catches invalid grouping types") {
+    val analyzed = LocalRelation('a.map(IntegerType, IntegerType))
+      .select('a).analyze
+
+    /**
+     * A dummy optimizer rule for testing that invalid grouping types are not allowed.
+     */
+    object InvalidGroupingType extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        Aggregate(plan.output, plan.output, plan)
+      }
+    }
+
+    val optimizer = new SimpleTestOptimizer() {
+      override def defaultBatches: Seq[Batch] =
+        Batch("test", FixedPoint(1),
+          InvalidGroupingType) :: Nil
+    }
+    val message1 = intercept[SparkException] {
+      optimizer.execute(analyzed)
+    }.getMessage
+    assert(message1.contains("cannot be of type Map"))
+  }
+
+  test("Optimizer per rule validation catches invalid aggregation expressions") {
+    val analyzed = LocalRelation('a.long, 'b.long)
+      .select('a, 'b).analyze
+
+    /**
+     * A dummy optimizer rule for testing that a non-grouping key reference
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationReference extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = outputExpressions
+        // I.e. INVALID: select a, b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non-grouping key reference
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationReference2 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = Alias(Literal(1L), "a")() :: outputExpressions.last :: Nil
+        // I.e. INVALID: select 1 as a, b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non-grouping key expression
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationExpression extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = outputExpressions.head ::
+          Alias(Add(outputExpressions.last, Literal(1L)), "b")() :: Nil
+        // I.e. INVALID: select a, b + 1 as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non-grouping key expression
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationExpression2 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = Alias(Literal(1L), "a")() ::
+          Alias(Remainder(outputExpressions.last, outputExpressions.head), "b")() :: Nil
+        // I.e. INVALID: select 1 as a, b % a as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non-grouping key expression
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationExpression3 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = Alias(Literal(1L), "a")() ::
+          Alias(Multiply(outputExpressions.head,
+            Sum(outputExpressions.head).toAggregateExpression()), "b")() :: Nil
+        // I.e. VALID: select 1 as a, a * sum(a) as b from T group by a
+        // analyze() should not fail.
+        val goodAggregate =
+          Aggregate(groupingExpressions, aggregateExpressions, plan)
+            .analyze.asInstanceOf[Aggregate]
+        assert(goodAggregate.analyzed)
+        // I.e. INVALID: select 1 as a, a * sum(a) as b from T group by b
+        // Rule-validation should catch this.
+        Aggregate(outputExpressions.last :: Nil, goodAggregate.aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing a valid aggregate expression.
+     */
+    object ValidAggregationExpression extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions: Seq[NamedExpression] = outputExpressions.head ::
+          Alias(Add(outputExpressions.head, Literal(1L)), "b")() :: Nil
+        // I.e. VALID: select a, a + 1 as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression.
+     */
+    object ValidAggregationExpression2 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Add(outputExpressions.head, Literal(1L)) :: Nil
+        val aggregateExpressions: Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Add(outputExpressions.head, Literal(1L)), "b")() :: Nil
+        // I.e. VALID: select 1 as a, a + 1 as b from T group by a + 1
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression.
+     */
+    object ValidAggregationExpression3 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Add(outputExpressions.head, Literal(1L)) :: Nil
+        val aggregateExpressions: Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Add(Add(outputExpressions.head, Literal(1L)), Literal(1L)), "b")() :: Nil
+        // I.e. VALID: select 1 as a, (a + 1) + 1 as b from T group by a + 1
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression.
+     */
+    object ValidAggregationExpression4 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Add(outputExpressions.head, Literal(1L)) :: Nil
+        val aggregateExpressions: Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Sum(outputExpressions.last).toAggregateExpression(), "b")() :: Nil
+        // I.e. VALID: select 1 as a, sum(b) as b from T group by a + 1
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression.
+     */
+    object ValidAggregationExpression5 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions: Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Sum(outputExpressions.head).toAggregateExpression(), "b")() :: Nil
+        // I.e. VALID: select 1 as a, sum(a) as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression.
+     */
+    object ValidAggregationExpression6 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Remainder(outputExpressions.head, Literal(2L)) :: Nil
+        val aggregateExpressions: Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Sum(outputExpressions.head).toAggregateExpression(), "b")() :: Nil
+        // I.e. VALID: select 1 as a, sum(a) as b from T group by a % 2
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression.
+     */
+    object ValidAggregationExpression7 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Remainder(outputExpressions.head, Literal(2L)) :: Nil
+        val aggregateExpressions: Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Add(Sum(outputExpressions.head).toAggregateExpression(),
+            groupingExpressions.head), "b")() :: Nil
+        // I.e. VALID: select 1 as a, sum(a) + (a % 2) as b from T group by a % 2
+        Aggregate(groupingExpressions, aggregateExpressions, plan).analyze
+      }
+    }
+
+    // Valid rules do not trigger exceptions.
+    Seq(ValidAggregationExpression, ValidAggregationExpression2,
+      ValidAggregationExpression3, ValidAggregationExpression4,
+      ValidAggregationExpression5, ValidAggregationExpression6,
+      ValidAggregationExpression7).map { r =>
+      val optimizer = new SimpleTestOptimizer() {
+        override def defaultBatches: Seq[Batch] =
+          Batch("test", FixedPoint(1), r) :: Nil
+      }
+      assert(optimizer.execute(analyzed).resolved)
+    }
+
+    // Invalid rules trigger exceptions.
+    Seq(InvalidAggregationReference, InvalidAggregationReference2,
+      InvalidAggregationExpression, InvalidAggregationExpression2,
+      InvalidAggregationExpression3).map { r =>
+      val optimizer = new SimpleTestOptimizer() {
+        override def defaultBatches: Seq[Batch] =
+          Batch("test", FixedPoint(1), r) :: Nil
+      }
+      val message1 = intercept[SparkException] {
+        optimizer.execute(analyzed)
+      }.getMessage
+      assert(message1.contains("not a valid aggregate expression"))
+    }
+  }
}