This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ce1a8769846 [SPARK-44219][SQL] Adds extra per-rule validations for optimization rewrites
2ce1a8769846 is described below

commit 2ce1a8769846477cfc2c596885f8005d8fc972b5
Author: Yannis Sismanis <[email protected]>
AuthorDate: Fri Oct 6 11:20:35 2023 +0800

    [SPARK-44219][SQL] Adds extra per-rule validations for optimization rewrites
    
    ### What changes were proposed in this pull request?
    
    Adds per-rule validation checks for the following:
    
    1. Aggregate expressions in Aggregate plans are valid.
    2. Grouping key types in Aggregate plans cannot be of type Map.
    3. No dangling references have been generated.
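
    For example (a sketch mirroring the new OptimizerSuite tests), a rewrite
    that produces the equivalent of "SELECT a, b FROM t GROUP BY a", where b
    is neither a grouping key nor aggregated, now fails check 1:

        // Hypothetical rule body; assumes `plan` outputs attributes a and b.
        Aggregate(plan.output.head :: Nil, plan.output, plan)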
    
    This validation is enabled by default for all tests, and can be enabled
    selectively using the spark.sql.planChangeValidation=true flag.
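
    For example (a sketch; assumes an active SparkSession named `spark`, and
    note the conf is intended for testing/debugging):

        spark.conf.set("spark.sql.planChangeValidation", "true")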
    
    ### Why are the changes needed?
    Extra validation for optimizer rewrites.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit tests
    
    Closes #41763 from YannisSismanis/SC-130139_followup.
    
    Authored-by: Yannis Sismanis <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  72 +-----
 .../spark/sql/catalyst/expressions/ExprUtils.scala |  77 ++++++
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |  68 +++++-
 .../sql/catalyst/optimizer/OptimizerSuite.scala    | 270 ++++++++++++++++++++-
 4 files changed, 412 insertions(+), 75 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 81ca59c0976e..e140625f47ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -431,77 +431,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
                 messageParameters = Map.empty)
             }
 
-          case Aggregate(groupingExprs, aggregateExprs, _) =>
-            def checkValidAggregateExpression(expr: Expression): Unit = expr match {
-              case expr: AggregateExpression =>
-                val aggFunction = expr.aggregateFunction
-                aggFunction.children.foreach { child =>
-                  child.foreach {
-                    case expr: AggregateExpression =>
-                      expr.failAnalysis(
-                        errorClass = "NESTED_AGGREGATE_FUNCTION",
-                        messageParameters = Map.empty)
-                    case other => // OK
-                  }
-
-                  if (!child.deterministic) {
-                    child.failAnalysis(
-                      errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
-                      messageParameters = Map("sqlExpr" -> toSQLExpr(expr)))
-                  }
-                }
-              case _: Attribute if groupingExprs.isEmpty =>
-                operator.failAnalysis(
-                  errorClass = "MISSING_GROUP_BY",
-                  messageParameters = Map.empty)
-              case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
-                throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
-              case s: ScalarSubquery
-                  if s.children.nonEmpty && !groupingExprs.exists(_.semanticEquals(s)) =>
-                s.failAnalysis(
-                  errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION",
-                  messageParameters = Map("sqlExpr" -> toSQLExpr(s)))
-              case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
-              // There should be no Window in Aggregate - this case will fail later check anyway.
-              // Perform this check for special case of lateral column alias, when the window
-              // expression is not eligible to propagate to upper plan because it is not valid,
-              // containing non-group-by or non-aggregate-expressions.
-              case WindowExpression(function, spec) =>
-                function.children.foreach(checkValidAggregateExpression)
-                checkValidAggregateExpression(spec)
-              case e => e.children.foreach(checkValidAggregateExpression)
-            }
-
-            def checkValidGroupingExprs(expr: Expression): Unit = {
-              if (expr.exists(_.isInstanceOf[AggregateExpression])) {
-                expr.failAnalysis(
-                  errorClass = "GROUP_BY_AGGREGATE",
-                  messageParameters = Map("sqlExpr" -> expr.sql))
-              }
-
-              // Check if the data type of expr is orderable.
-              if (!RowOrdering.isOrderable(expr.dataType)) {
-                expr.failAnalysis(
-                  errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
-                  messageParameters = Map(
-                    "sqlExpr" -> toSQLExpr(expr),
-                    "dataType" -> toSQLType(expr.dataType)))
-              }
-
-              if (!expr.deterministic) {
-                // This is just a sanity check, our analysis rule PullOutNondeterministic should
-                // already pull out those nondeterministic expressions and evaluate them in
-                // a Project node.
-                throw SparkException.internalError(
-                  msg = s"Non-deterministic expression '${toSQLExpr(expr)}' should not appear in " +
-                    "grouping expression.",
-                  context = expr.origin.getQueryContext,
-                  summary = expr.origin.context.summary)
-              }
-            }
-
-            groupingExprs.foreach(checkValidGroupingExprs)
-            aggregateExprs.foreach(checkValidAggregateExpression)
+          case a: Aggregate => ExprUtils.assertValidAggregation(a)
 
           case CollectMetrics(name, metrics, _, _) =>
             if (name == null || name.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 5a093037e424..29c9605db51b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -20,8 +20,12 @@ package org.apache.spark.sql.catalyst.expressions
 import java.text.{DecimalFormat, DecimalFormatSymbols, ParsePosition}
 import java.util.Locale
 
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType}
@@ -140,4 +144,77 @@ object ExprUtils extends QueryErrorsBase {
       TypeCheckSuccess
     }
   }
+
+  def assertValidAggregation(a: Aggregate): Unit = {
+    def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+      case expr: AggregateExpression =>
+        val aggFunction = expr.aggregateFunction
+        aggFunction.children.foreach { child =>
+          child.foreach {
+            case expr: AggregateExpression =>
+              expr.failAnalysis(
+                errorClass = "NESTED_AGGREGATE_FUNCTION",
+                messageParameters = Map.empty)
+            case other => // OK
+          }
+
+          if (!child.deterministic) {
+            child.failAnalysis(
+              errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
+              messageParameters = Map("sqlExpr" -> toSQLExpr(expr)))
+          }
+        }
+      case _: Attribute if a.groupingExpressions.isEmpty =>
+        a.failAnalysis(
+          errorClass = "MISSING_GROUP_BY",
+          messageParameters = Map.empty)
+      case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) =>
+        throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
+      case s: ScalarSubquery
+        if s.children.nonEmpty && !a.groupingExpressions.exists(_.semanticEquals(s)) =>
+        s.failAnalysis(
+          errorClass = "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION",
+          messageParameters = Map("sqlExpr" -> toSQLExpr(s)))
+      case e if a.groupingExpressions.exists(_.semanticEquals(e)) => // OK
+      // There should be no Window in Aggregate - this case will fail later check anyway.
+      // Perform this check for special case of lateral column alias, when the window
+      // expression is not eligible to propagate to upper plan because it is not valid,
+      // containing non-group-by or non-aggregate-expressions.
+      case WindowExpression(function, spec) =>
+        function.children.foreach(checkValidAggregateExpression)
+        checkValidAggregateExpression(spec)
+      case e => e.children.foreach(checkValidAggregateExpression)
+    }
+
+    def checkValidGroupingExprs(expr: Expression): Unit = {
+      if (expr.exists(_.isInstanceOf[AggregateExpression])) {
+        expr.failAnalysis(
+          errorClass = "GROUP_BY_AGGREGATE",
+          messageParameters = Map("sqlExpr" -> expr.sql))
+      }
+
+      // Check if the data type of expr is orderable.
+      if (!RowOrdering.isOrderable(expr.dataType)) {
+        expr.failAnalysis(
+          errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
+          messageParameters = Map(
+            "sqlExpr" -> toSQLExpr(expr),
+            "dataType" -> toSQLType(expr.dataType)))
+      }
+
+      if (!expr.deterministic) {
+        // This is just a sanity check, our analysis rule PullOutNondeterministic should
+        // already pull out those nondeterministic expressions and evaluate them in
+        // a Project node.
+        throw SparkException.internalError(
+          msg = s"Non-deterministic expression '${toSQLExpr(expr)}' should not appear in " +
+            "grouping expression.",
+          context = expr.origin.getQueryContext,
+          summary = expr.origin.context.summary)
+      }
+    }
+
+    a.groupingExpressions.foreach(checkValidGroupingExprs)
+    a.aggregateExpressions.foreach(checkValidAggregateExpression)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index a5f4c4d5c51a..ae3029b279da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
@@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{MapType, StructType}
 
 
 abstract class LogicalPlan
@@ -325,6 +326,62 @@ object LogicalPlanIntegrity {
       LogicalPlanIntegrity.hasUniqueExprIdsForOutput(plan))
   }
 
+  /**
+   * This method validates there are no dangling attribute references.
+   * Returns an error message if the check does not pass, or None if it does pass.
+   */
+  def validateNoDanglingReferences(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      // DML commands and multi instance relations (like InMemoryRelation caches)
+      // have different output semantics than typical queries.
+      case _: Command => None
+      case _: MultiInstanceRelation => None
+      case n if canGetOutputAttrs(n) =>
+        if (n.missingInput.nonEmpty) {
+          Some(s"Aliases ${ n.missingInput.mkString(", ")} are dangling " +
+            s"in the references for plan:\n ${n.treeString}")
+        } else {
+          None
+        }
+    }.flatten
+  }
+
+  /**
+   * Validate that the grouping key types in Aggregate plans are valid.
+   * Returns an error message if the check fails, or None if it succeeds.
+   */
+  def validateGroupByTypes(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      case a @ Aggregate(groupingExprs, _, _) =>
+        val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString)
+        if (badExprs.nonEmpty) {
+          Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " +
+            s"for plan:\n ${a.treeString}")
+        } else {
+          None
+        }
+    }.flatten
+  }
+
+  /**
+   * Validate that the aggregation expressions in Aggregate plans are valid.
+   * Returns an error message if the check fails, or None if it succeeds.
+   */
+  def validateAggregateExpressions(plan: LogicalPlan): Option[String] = {
+    plan.collectFirst {
+      case a: Aggregate =>
+        try {
+          ExprUtils.assertValidAggregation(a)
+          None
+        } catch {
+          case e: AnalysisException =>
+            Some(s"Aggregate: ${a.toString} is not a valid aggregate 
expression: " +
+            s"${e.getSimpleMessage}")
+        }
+    }.flatten
+  }
+
+
   /**
    * Validate the structural integrity of an optimized plan.
    * For example, we can check after the execution of each rule that each plan:
@@ -337,7 +394,7 @@ object LogicalPlanIntegrity {
   def validateOptimizedPlan(
       previousPlan: LogicalPlan,
       currentPlan: LogicalPlan): Option[String] = {
-    if (!currentPlan.resolved) {
+    var validation = if (!currentPlan.resolved) {
       Some("The plan becomes unresolved: " + currentPlan.treeString + "\nThe 
previous plan: " +
         previousPlan.treeString)
     } else if (currentPlan.exists(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty)) {
@@ -353,6 +410,13 @@ object LogicalPlanIntegrity {
         }
       }
     }
+    validation = validation
+      .orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan))
+      .orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan))
+      .orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan))
+      .map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", 
")}" +
+        s"\nPrevious plan: ${previousPlan.treeString}")
+    validation
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
index 6b63f860b7da..e40fff22bc1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala
@@ -17,13 +17,17 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, IntegerLiteral, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, IntegerLiteral, Literal, Multiply, NamedExpression, Remainder}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.IntegerType
 
 /**
  * A dummy optimizer rule for testing that decrements integer literals until 0.
@@ -71,4 +75,266 @@ class OptimizerSuite extends PlanTest {
         s"test, please set '${SQLConf.OPTIMIZER_MAX_ITERATIONS.key}' to a 
larger value."))
     }
   }
+
+  test("Optimizer per rule validation catches dangling references") {
+    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil,
+      OneRowRelation()).analyze
+
+    /**
+     * A dummy optimizer rule for testing that dangling references are not allowed.
+     */
+    object DanglingReference extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+          Project(Alias(
+            Add(AttributeReference("debug1", IntegerType, nullable = false)(),
+            AttributeReference("debug2", IntegerType, nullable = false)()), 
"attr")() :: Nil,
+            plan)
+      }
+    }
+
+    val optimizer = new SimpleTestOptimizer() {
+        override def defaultBatches: Seq[Batch] =
+          Batch("test", FixedPoint(1),
+            DanglingReference) :: Nil
+    }
+    val message1 = intercept[SparkException] {
+        optimizer.execute(analyzed)
+    }.getMessage
+    assert(message1.contains("are dangling"))
+  }
+
+  test("Optimizer per rule validation catches invalid grouping types") {
+    val analyzed = LocalRelation('a.map(IntegerType, IntegerType))
+      .select('a).analyze
+
+    /**
+     * A dummy optimizer rule for testing that invalid grouping types are not allowed.
+     */
+    object InvalidGroupingType extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        Aggregate(plan.output, plan.output, plan)
+      }
+    }
+
+    val optimizer = new SimpleTestOptimizer() {
+      override def defaultBatches: Seq[Batch] =
+        Batch("test", FixedPoint(1),
+          InvalidGroupingType) :: Nil
+    }
+    val message1 = intercept[SparkException] {
+      optimizer.execute(analyzed)
+    }.getMessage
+    assert(message1.contains("cannot be of type Map"))
+  }
+
+  test("Optimizer per rule validation catches invalid aggregation 
expressions") {
+    val analyzed = LocalRelation('a.long, 'b.long)
+      .select('a, 'b).analyze
+
+    /**
+     * A dummy optimizer rule for testing that a non grouping key reference
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationReference extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = outputExpressions
+        // I.e INVALID: select a, b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non grouping key reference
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationReference2 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = Alias(Literal(1L), "a")() :: outputExpressions.last :: Nil
+        // I.e INVALID: select 1 as a, b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non grouping key expression
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationExpression extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = outputExpressions.head ::
+          Alias(Add(outputExpressions.last, Literal(1L)), "b")() :: Nil
+        // I.e INVALID: select a, b + 1 as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non grouping key expression
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationExpression2 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = Alias(Literal(1L), "a")() ::
+          Alias(Remainder(outputExpressions.last, outputExpressions.head), "b")() :: Nil
+        // I.e INVALID: select 1 as a, b % a as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing that a non grouping key expression
+     * should be aggregated (under an AggregateFunction).
+     */
+    object InvalidAggregationExpression3 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions = Alias(Literal(1L), "a")() ::
+          Alias(Multiply(outputExpressions.head,
+            Sum(outputExpressions.head).toAggregateExpression()), "b")() :: Nil
+        // I.e VALID: select 1 as a, a*sum(a) as b from T group by a
+        // analyze() should not fail.
+        val goodAggregate =
+          Aggregate(groupingExpressions, aggregateExpressions, plan)
+            .analyze.asInstanceOf[Aggregate]
+        assert(goodAggregate.analyzed)
+        // I.e INVALID: select 1 as a, a*sum(a) as b from T group by b
+        // Rule-validation should catch this.
+        Aggregate(outputExpressions.last :: Nil, goodAggregate.aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing valid aggregate expression
+     */
+    object ValidAggregationExpression extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions : Seq[NamedExpression] = outputExpressions.head ::
+          Alias(Add(outputExpressions.head, Literal(1L)), "b")() :: Nil
+        // I.e VALID: select a, a + 1 as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression
+     */
+    object ValidAggregationExpression2 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Add(outputExpressions.head, Literal(1L)) :: Nil
+        val aggregateExpressions : Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Add(outputExpressions.head, Literal(1L)), "b")() :: Nil
+        // I.e VALID: select 1 as a, a + 1 as b from T group by a + 1
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression
+     */
+    object ValidAggregationExpression3 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Add(outputExpressions.head, Literal(1L)) :: Nil
+        val aggregateExpressions : Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Add(Add(outputExpressions.head, Literal(1L)), Literal(1L)), "b")() :: Nil
+        // I.e VALID: select 1 as a, a + 1 + 1 as b from T group by a + 1
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression
+     */
+    object ValidAggregationExpression4 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Add(outputExpressions.head, Literal(1L)) :: Nil
+        val aggregateExpressions : Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Sum(outputExpressions.last).toAggregateExpression(), "b")() :: Nil
+        // I.e VALID: select 1 as a, sum(b) as b from T group by a + 1
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression
+     */
+    object ValidAggregationExpression5 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = outputExpressions.head :: Nil
+        val aggregateExpressions : Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Sum(outputExpressions.head).toAggregateExpression(), "b")() :: Nil
+        // I.e VALID: select 1 as a, sum(a) as b from T group by a
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression
+     */
+    object ValidAggregationExpression6 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Remainder(outputExpressions.head, Literal(2L)) :: Nil
+        val aggregateExpressions : Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Sum(outputExpressions.head).toAggregateExpression(), "b")() :: Nil
+        // I.e VALID: select 1 as a, sum(a) as b from T group by a % 2
+        Aggregate(groupingExpressions, aggregateExpressions, plan)
+      }
+    }
+
+    /**
+     * A dummy optimizer rule for testing another valid aggregate expression
+     */
+    object ValidAggregationExpression7 extends Rule[LogicalPlan] {
+      def apply(plan: LogicalPlan): LogicalPlan = {
+        val outputExpressions = plan.output
+        val groupingExpressions = Remainder(outputExpressions.head, Literal(2L)) :: Nil
+        val aggregateExpressions : Seq[NamedExpression] = Alias(Literal(1L), "a")() ::
+          Alias(Add(Sum(outputExpressions.head).toAggregateExpression(),
+            groupingExpressions.head), "b")() :: Nil
+        // I.e VALID: select 1 as a, sum(a) + (a % 2) as b from T group by a % 2
+        Aggregate(groupingExpressions, aggregateExpressions, plan).analyze
+      }
+    }
+
+    // Valid rules do not trigger exceptions.
+    Seq(ValidAggregationExpression, ValidAggregationExpression2,
+      ValidAggregationExpression3, ValidAggregationExpression4,
+      ValidAggregationExpression5, ValidAggregationExpression6,
+      ValidAggregationExpression7).map { r =>
+      val optimizer = new SimpleTestOptimizer() {
+        override def defaultBatches: Seq[Batch] =
+          Batch("test", FixedPoint(1), r) :: Nil
+      }
+      assert(optimizer.execute(analyzed).resolved)
+    }
+
+    // Invalid rules trigger exceptions.
+    Seq(InvalidAggregationReference, InvalidAggregationReference2,
+      InvalidAggregationExpression, InvalidAggregationExpression2,
+      InvalidAggregationExpression3).map { r =>
+      val optimizer = new SimpleTestOptimizer() {
+        override def defaultBatches: Seq[Batch] =
+          Batch("test", FixedPoint(1), r) :: Nil
+      }
+      val message1 = intercept[SparkException] {
+        optimizer.execute(analyzed)
+      }.getMessage
+      assert(message1.contains("not a valid aggregate expression"))
+    }
+  }
 }

